Skip to content

Commit f17e26a

Browse files
committed
feat: enhance Azure and OpenAI models with response format and tools support
1 parent 621c167 commit f17e26a

File tree

2 files changed

+176
-16
lines changed

2 files changed

+176
-16
lines changed

camel/models/azure_openai_model.py

+88-8
Original file line numberDiff line numberDiff line change
@@ -128,18 +128,23 @@ def _run(
128128
Args:
129129
messages (List[OpenAIMessage]): Message list with the chat history
130130
in OpenAI API format.
131+
response_format (Optional[Type[BaseModel]]): The format of the
132+
response.
133+
tools (Optional[List[Dict[str, Any]]]): The schema of the tools to
134+
use for the request.
131135
132136
Returns:
133137
Union[ChatCompletion, Stream[ChatCompletionChunk]]:
134138
`ChatCompletion` in the non-stream mode, or
135139
`Stream[ChatCompletionChunk]` in the stream mode.
136140
"""
137-
response = self._client.chat.completions.create(
138-
messages=messages,
139-
model=self.azure_deployment_name, # type:ignore[arg-type]
140-
**self.model_config_dict,
141+
response_format = response_format or self.model_config_dict.get(
142+
"response_format", None
141143
)
142-
return response
144+
if response_format:
145+
return self._request_parse(messages, response_format, tools)
146+
else:
147+
return self._request_chat_completion(messages, tools)
143148

144149
async def _arun(
145150
self,
@@ -152,18 +157,93 @@ async def _arun(
152157
Args:
153158
messages (List[OpenAIMessage]): Message list with the chat history
154159
in OpenAI API format.
160+
response_format (Optional[Type[BaseModel]]): The format of the
161+
response.
162+
tools (Optional[List[Dict[str, Any]]]): The schema of the tools to
163+
use for the request.
155164
156165
Returns:
157166
Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]:
158167
`ChatCompletion` in the non-stream mode, or
159168
`AsyncStream[ChatCompletionChunk]` in the stream mode.
160169
"""
161-
response = await self._async_client.chat.completions.create(
170+
response_format = response_format or self.model_config_dict.get(
171+
"response_format", None
172+
)
173+
if response_format:
174+
return await self._arequest_parse(messages, response_format, tools)
175+
else:
176+
return await self._arequest_chat_completion(messages, tools)
177+
178+
def _request_chat_completion(
179+
self,
180+
messages: List[OpenAIMessage],
181+
tools: Optional[List[Dict[str, Any]]] = None,
182+
) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
183+
request_config = self.model_config_dict.copy()
184+
185+
if tools:
186+
request_config["tools"] = tools
187+
188+
return self._client.chat.completions.create(
189+
messages=messages,
190+
model=self.azure_deployment_name, # type:ignore[arg-type]
191+
**request_config,
192+
)
193+
194+
async def _arequest_chat_completion(
195+
self,
196+
messages: List[OpenAIMessage],
197+
tools: Optional[List[Dict[str, Any]]] = None,
198+
) -> Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]:
199+
request_config = self.model_config_dict.copy()
200+
201+
if tools:
202+
request_config["tools"] = tools
203+
204+
return await self._async_client.chat.completions.create(
205+
messages=messages,
206+
model=self.azure_deployment_name, # type:ignore[arg-type]
207+
**request_config,
208+
)
209+
210+
def _request_parse(
211+
self,
212+
messages: List[OpenAIMessage],
213+
response_format: Type[BaseModel],
214+
tools: Optional[List[Dict[str, Any]]] = None,
215+
) -> ChatCompletion:
216+
request_config = self.model_config_dict.copy()
217+
218+
request_config["response_format"] = response_format
219+
request_config.pop("stream", None)
220+
if tools is not None:
221+
request_config["tools"] = tools
222+
223+
return self._client.beta.chat.completions.parse(
224+
messages=messages,
225+
model=self.azure_deployment_name, # type:ignore[arg-type]
226+
**request_config,
227+
)
228+
229+
async def _arequest_parse(
230+
self,
231+
messages: List[OpenAIMessage],
232+
response_format: Type[BaseModel],
233+
tools: Optional[List[Dict[str, Any]]] = None,
234+
) -> ChatCompletion:
235+
request_config = self.model_config_dict.copy()
236+
237+
request_config["response_format"] = response_format
238+
request_config.pop("stream", None)
239+
if tools is not None:
240+
request_config["tools"] = tools
241+
242+
return await self._async_client.beta.chat.completions.parse(
162243
messages=messages,
163244
model=self.azure_deployment_name, # type:ignore[arg-type]
164-
**self.model_config_dict,
245+
**request_config,
165246
)
166-
return response
167247

168248
def check_model_config(self):
169249
r"""Check whether the model configuration contains any

camel/models/openai_compatible_model.py

+88-8
Original file line numberDiff line numberDiff line change
@@ -86,18 +86,23 @@ def _run(
8686
Args:
8787
messages (List[OpenAIMessage]): Message list with the chat history
8888
in OpenAI API format.
89+
response_format (Optional[Type[BaseModel]]): The format of the
90+
response.
91+
tools (Optional[List[Dict[str, Any]]]): The schema of the tools to
92+
use for the request.
8993
9094
Returns:
9195
Union[ChatCompletion, Stream[ChatCompletionChunk]]:
9296
`ChatCompletion` in the non-stream mode, or
9397
`Stream[ChatCompletionChunk]` in the stream mode.
9498
"""
95-
response = self._client.chat.completions.create(
96-
messages=messages,
97-
model=self.model_type,
98-
**self.model_config_dict,
99+
response_format = response_format or self.model_config_dict.get(
100+
"response_format", None
99101
)
100-
return response
102+
if response_format:
103+
return self._request_parse(messages, response_format, tools)
104+
else:
105+
return self._request_chat_completion(messages, tools)
101106

102107
async def _arun(
103108
self,
@@ -110,18 +115,93 @@ async def _arun(
110115
Args:
111116
messages (List[OpenAIMessage]): Message list with the chat history
112117
in OpenAI API format.
118+
response_format (Optional[Type[BaseModel]]): The format of the
119+
response.
120+
tools (Optional[List[Dict[str, Any]]]): The schema of the tools to
121+
use for the request.
113122
114123
Returns:
115124
Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]:
116125
`ChatCompletion` in the non-stream mode, or
117126
`AsyncStream[ChatCompletionChunk]` in the stream mode.
118127
"""
119-
response = await self._async_client.chat.completions.create(
128+
response_format = response_format or self.model_config_dict.get(
129+
"response_format", None
130+
)
131+
if response_format:
132+
return await self._arequest_parse(messages, response_format, tools)
133+
else:
134+
return await self._arequest_chat_completion(messages, tools)
135+
136+
def _request_chat_completion(
137+
self,
138+
messages: List[OpenAIMessage],
139+
tools: Optional[List[Dict[str, Any]]] = None,
140+
) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
141+
request_config = self.model_config_dict.copy()
142+
143+
if tools:
144+
request_config["tools"] = tools
145+
146+
return self._client.chat.completions.create(
147+
messages=messages,
148+
model=self.model_type,
149+
**request_config,
150+
)
151+
152+
async def _arequest_chat_completion(
153+
self,
154+
messages: List[OpenAIMessage],
155+
tools: Optional[List[Dict[str, Any]]] = None,
156+
) -> Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]:
157+
request_config = self.model_config_dict.copy()
158+
159+
if tools:
160+
request_config["tools"] = tools
161+
162+
return await self._async_client.chat.completions.create(
163+
messages=messages,
164+
model=self.model_type,
165+
**request_config,
166+
)
167+
168+
def _request_parse(
169+
self,
170+
messages: List[OpenAIMessage],
171+
response_format: Type[BaseModel],
172+
tools: Optional[List[Dict[str, Any]]] = None,
173+
) -> ChatCompletion:
174+
request_config = self.model_config_dict.copy()
175+
176+
request_config["response_format"] = response_format
177+
request_config.pop("stream", None)
178+
if tools is not None:
179+
request_config["tools"] = tools
180+
181+
return self._client.beta.chat.completions.parse(
182+
messages=messages,
183+
model=self.model_type,
184+
**request_config,
185+
)
186+
187+
async def _arequest_parse(
188+
self,
189+
messages: List[OpenAIMessage],
190+
response_format: Type[BaseModel],
191+
tools: Optional[List[Dict[str, Any]]] = None,
192+
) -> ChatCompletion:
193+
request_config = self.model_config_dict.copy()
194+
195+
request_config["response_format"] = response_format
196+
request_config.pop("stream", None)
197+
if tools is not None:
198+
request_config["tools"] = tools
199+
200+
return await self._async_client.beta.chat.completions.parse(
120201
messages=messages,
121202
model=self.model_type,
122-
**self.model_config_dict,
203+
**request_config,
123204
)
124-
return response
125205

126206
@property
127207
def token_counter(self) -> BaseTokenCounter:

0 commit comments

Comments
 (0)