|
5 | 5 |
|
6 | 6 | import pytest
|
7 | 7 | from autogen_core import CancellationToken, FunctionCall, Image
|
8 |
| -from autogen_core.models import CreateResult, UserMessage |
| 8 | +from autogen_core.models import CreateResult, ModelFamily, UserMessage |
9 | 9 | from autogen_ext.models.azure import AzureAIChatCompletionClient
|
10 | 10 | from azure.ai.inference.aio import (
|
11 | 11 | ChatCompletionsClient,
|
@@ -295,3 +295,82 @@ async def _mock_create_noop(*args: Any, **kwargs: Any) -> ChatCompletions:
|
295 | 295 | ]
|
296 | 296 | )
|
297 | 297 | assert result.content == "Handled image"
|
| 298 | + |
| 299 | + |
@pytest.mark.asyncio
async def test_r1_content(monkeypatch: pytest.MonkeyPatch) -> None:
    """
    Verify that an R1-style response (a leading <think>...</think> section) is
    split into `thought` and `content`, for both the blocking and streaming paths.
    """

    # Fragments the streaming mock will emit; their concatenation carries one
    # think block followed by the visible assistant text.
    r1_chunk_texts = ["<think>Thought</think> Hello", " Another Hello", " Yet Another Hello"]

    async def _stream_r1_chunks(
        *args: Any, **kwargs: Any
    ) -> AsyncGenerator[StreamingChatCompletionsUpdate, None]:
        # One streaming update per fragment, each wrapping a single choice.
        for chunk_text in r1_chunk_texts:
            await asyncio.sleep(0.1)
            choice = StreamingChatChoiceUpdate(
                index=0,
                finish_reason="stop",
                delta=StreamingChatResponseMessageUpdate(role="assistant", content=chunk_text),
            )
            yield StreamingChatCompletionsUpdate(
                id="id",
                choices=[choice],
                created=datetime.now(),
                model="model",
                usage=CompletionsUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
            )

    async def _complete_r1(
        *args: Any, **kwargs: Any
    ) -> ChatCompletions | AsyncGenerator[StreamingChatCompletionsUpdate, None]:
        # Route to the streaming generator when the client requested a stream.
        if kwargs.get("stream", False):
            return _stream_r1_chunks(*args, **kwargs)

        await asyncio.sleep(0.1)
        return ChatCompletions(
            id="id",
            created=datetime.now(),
            model="model",
            choices=[
                ChatChoice(
                    index=0,
                    finish_reason="stop",
                    message=ChatResponseMessage(content="<think>Thought</think> Hello", role="assistant"),
                )
            ],
            usage=CompletionsUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        )

    monkeypatch.setattr(ChatCompletionsClient, "complete", _complete_r1)

    # Declaring the model as ModelFamily.R1 is what enables think-tag parsing.
    client = AzureAIChatCompletionClient(
        endpoint="endpoint",
        credential=AzureKeyCredential("api_key"),
        model_info={
            "json_output": False,
            "function_calling": False,
            "vision": True,
            "family": ModelFamily.R1,
        },
        model="model",
    )

    # Non-streaming: the think block becomes `thought`, the remainder `content`.
    result = await client.create(messages=[UserMessage(content="Hello", source="user")])
    assert result.content == "Hello"
    assert result.thought == "Thought"

    # Streaming: the last yielded item is the aggregated CreateResult.
    stream_parts: List[str | CreateResult] = []
    async for part in client.create_stream(messages=[UserMessage(content="Hello", source="user")]):
        stream_parts.append(part)
    final = stream_parts[-1]
    assert isinstance(final, CreateResult)
    assert final.content == "Hello Another Hello Yet Another Hello"
    assert final.thought == "Thought"
0 commit comments