-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain4-function-calling.py
137 lines (118 loc) · 4.12 KB
/
main4-function-calling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import json
import os
from openai import AzureOpenAI
from rich import print as rprint
from tools import book_flight, search_flights
# System prompt: sent as the first ("system") message of every conversation.
# Adjacent string literals inside the parentheses are joined at compile time;
# each piece ends with a space so the sentences do not run together.
background = (
    "You are an expert travel planner with 25+ years of experience. "
    "Leverage your expertise to provide recommendations, "
    "while considering the user's personal interests and preferences."
)
def _strict_function_tool(name, description, param_names):
    """Build one strict-mode function-tool spec whose parameters are all required strings.

    Args:
        name: Function name the model will use when requesting the call.
        description: Human-readable summary shown to the model.
        param_names: Ordered parameter names; each becomes a required
            ``{"type": "string"}`` property.

    Returns:
        A ``{"type": "function", "function": {...}}`` tool specification with
        ``additionalProperties`` disabled and ``strict`` enabled.
    """
    string_props = {param: {"type": "string"} for param in param_names}
    schema = {
        "type": "object",
        "properties": string_props,
        "required": list(param_names),
        "additionalProperties": False,
    }
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": schema,
            "strict": True,
        },
    }


# Tool specifications advertised to the model on every completion request.
tools = [
    _strict_function_tool(
        "search_flights",
        "Search for flights between the provided cities on the given date.",
        ("departure_city", "destination_city", "travel_date"),
    ),
    _strict_function_tool(
        "book_flight",
        "Book a flight based on the provided flight number and travel date.",
        ("flight_number", "travel_date"),
    ),
]
def format_msg(msg: dict):
    """Return the part of a chat message worth displaying.

    Args:
        msg: A chat message dict.

    Returns:
        The message's ``tool_calls`` value when it is present and non-empty;
        otherwise ``msg["content"]`` (a missing ``content`` key raises KeyError).
    """
    requested_calls = msg.get("tool_calls")
    return requested_calls if requested_calls else msg["content"]
def main():
user_request = input("How can I help you today? ")
azure_endpoint = os.environ["AZURE_OPENAI_ENDPOINT"]
azure_deployment = os.environ["AZURE_OPENAI_DEPLOYMENT"]
client = AzureOpenAI(
azure_endpoint=azure_endpoint,
api_version="2024-08-01-preview",
azure_deployment=azure_deployment,
)
messages = [
{
"role": "system",
"content": background,
},
{
"role": "user",
"content": user_request,
},
]
while True:
chat_completion = client.chat.completions.create(
messages=messages,
model="gpt-4o",
temperature=0.1,
tools=tools,
)
messages.append(chat_completion.choices[0].message.to_dict())
rprint(
"\n".join(
[
f"[red]{format_msg(m)}[/red]"
for m in messages
if m["role"] != "system"
]
)
)
print()
if chat_completion.choices[0].message.tool_calls:
for tool_call in chat_completion.choices[0].message.tool_calls:
func_name = tool_call.function.name
func_args = json.loads(tool_call.function.arguments)
match func_name:
case "search_flights":
result = search_flights(*func_args)
messages.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"content": f"{result}",
}
)
case "book_flight":
booking_result = book_flight(*func_args)
messages.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"content": f"{booking_result}",
}
)
case _:
raise ValueError(f"Unsupported function: {func_name}")
else:
print(chat_completion.choices[0].message.content)
human_response = input("> ")
messages.append({"role": "user", "content": human_response})
if __name__ == "__main__":
main()