@@ -128,18 +128,23 @@ def _run(
128
128
Args:
129
129
messages (List[OpenAIMessage]): Message list with the chat history
130
130
in OpenAI API format.
131
+ response_format (Optional[Type[BaseModel]]): The format of the
132
+ response.
133
+ tools (Optional[List[Dict[str, Any]]]): The schema of the tools to
134
+ use for the request.
131
135
132
136
Returns:
133
137
Union[ChatCompletion, Stream[ChatCompletionChunk]]:
134
138
`ChatCompletion` in the non-stream mode, or
135
139
`Stream[ChatCompletionChunk]` in the stream mode.
136
140
"""
137
- response = self ._client .chat .completions .create (
138
- messages = messages ,
139
- model = self .azure_deployment_name , # type:ignore[arg-type]
140
- ** self .model_config_dict ,
141
+ response_format = response_format or self .model_config_dict .get (
142
+ "response_format" , None
141
143
)
142
- return response
144
+ if response_format :
145
+ return self ._request_parse (messages , response_format , tools )
146
+ else :
147
+ return self ._request_chat_completion (messages , tools )
143
148
144
149
async def _arun (
145
150
self ,
@@ -152,18 +157,93 @@ async def _arun(
152
157
Args:
153
158
messages (List[OpenAIMessage]): Message list with the chat history
154
159
in OpenAI API format.
160
+ response_format (Optional[Type[BaseModel]]): The format of the
161
+ response.
162
+ tools (Optional[List[Dict[str, Any]]]): The schema of the tools to
163
+ use for the request.
155
164
156
165
Returns:
157
166
Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]:
158
167
`ChatCompletion` in the non-stream mode, or
159
168
`AsyncStream[ChatCompletionChunk]` in the stream mode.
160
169
"""
161
- response = await self ._async_client .chat .completions .create (
170
+ response_format = response_format or self .model_config_dict .get (
171
+ "response_format" , None
172
+ )
173
+ if response_format :
174
+ return await self ._arequest_parse (messages , response_format , tools )
175
+ else :
176
+ return await self ._arequest_chat_completion (messages , tools )
177
+
178
+ def _request_chat_completion (
179
+ self ,
180
+ messages : List [OpenAIMessage ],
181
+ tools : Optional [List [Dict [str , Any ]]] = None ,
182
+ ) -> Union [ChatCompletion , Stream [ChatCompletionChunk ]]:
183
+ request_config = self .model_config_dict .copy ()
184
+
185
+ if tools :
186
+ request_config ["tools" ] = tools
187
+
188
+ return self ._client .chat .completions .create (
189
+ messages = messages ,
190
+ model = self .azure_deployment_name , # type:ignore[arg-type]
191
+ ** request_config ,
192
+ )
193
+
194
+ async def _arequest_chat_completion (
195
+ self ,
196
+ messages : List [OpenAIMessage ],
197
+ tools : Optional [List [Dict [str , Any ]]] = None ,
198
+ ) -> Union [ChatCompletion , AsyncStream [ChatCompletionChunk ]]:
199
+ request_config = self .model_config_dict .copy ()
200
+
201
+ if tools :
202
+ request_config ["tools" ] = tools
203
+
204
+ return await self ._async_client .chat .completions .create (
205
+ messages = messages ,
206
+ model = self .azure_deployment_name , # type:ignore[arg-type]
207
+ ** request_config ,
208
+ )
209
+
210
+ def _request_parse (
211
+ self ,
212
+ messages : List [OpenAIMessage ],
213
+ response_format : Type [BaseModel ],
214
+ tools : Optional [List [Dict [str , Any ]]] = None ,
215
+ ) -> ChatCompletion :
216
+ request_config = self .model_config_dict .copy ()
217
+
218
+ request_config ["response_format" ] = response_format
219
+ request_config .pop ("stream" , None )
220
+ if tools is not None :
221
+ request_config ["tools" ] = tools
222
+
223
+ return self ._client .beta .chat .completions .parse (
224
+ messages = messages ,
225
+ model = self .azure_deployment_name , # type:ignore[arg-type]
226
+ ** request_config ,
227
+ )
228
+
229
+ async def _arequest_parse (
230
+ self ,
231
+ messages : List [OpenAIMessage ],
232
+ response_format : Type [BaseModel ],
233
+ tools : Optional [List [Dict [str , Any ]]] = None ,
234
+ ) -> ChatCompletion :
235
+ request_config = self .model_config_dict .copy ()
236
+
237
+ request_config ["response_format" ] = response_format
238
+ request_config .pop ("stream" , None )
239
+ if tools is not None :
240
+ request_config ["tools" ] = tools
241
+
242
+ return await self ._async_client .beta .chat .completions .parse (
162
243
messages = messages ,
163
244
model = self .azure_deployment_name , # type:ignore[arg-type]
164
- ** self . model_config_dict ,
245
+ ** request_config ,
165
246
)
166
- return response
167
247
168
248
def check_model_config (self ):
169
249
r"""Check whether the model configuration contains any
0 commit comments