From f656ff1e011738b73d1c9dd0a4353aea5c6bacb8 Mon Sep 17 00:00:00 2001
From: Eric Zhu
Date: Thu, 30 Jan 2025 11:03:54 -0800
Subject: [PATCH] feat: Support R1 reasoning text in model create result;
 enhance API docs (#5262)

Resolves #5255

---------

Co-authored-by: afourney
---
 .../src/autogen_core/models/_model_client.py  |   3 +-
 .../src/autogen_core/models/_types.py         |  35 +++
 python/packages/autogen-ext/pyproject.toml    |   1 +
 .../models/_utils/parse_r1_content.py         |  33 +++
 .../models/azure/_azure_ai_client.py          |  15 ++
 .../models/openai/_openai_client.py           |  31 ++-
 .../_sk_chat_completion_adapter.py            |  15 ++
 .../models/test_azure_ai_model_client.py      |  81 ++++++-
 .../tests/models/test_openai_model_client.py  | 214 ++++++++++++++++++
 .../models/test_sk_chat_completion_adapter.py |  72 ++++++
 .../autogen-ext/tests/models/test_utils.py    |  43 ++++
 python/uv.lock                                |   2 +
 12 files changed, 536 insertions(+), 9 deletions(-)
 create mode 100644 python/packages/autogen-ext/src/autogen_ext/models/_utils/parse_r1_content.py
 create mode 100644 python/packages/autogen-ext/tests/models/test_utils.py

diff --git a/python/packages/autogen-core/src/autogen_core/models/_model_client.py b/python/packages/autogen-core/src/autogen_core/models/_model_client.py
index 356fad5487c..62c26c51a67 100644
--- a/python/packages/autogen-core/src/autogen_core/models/_model_client.py
+++ b/python/packages/autogen-core/src/autogen_core/models/_model_client.py
@@ -22,9 +22,10 @@ class ModelFamily:
     O1 = "o1"
     GPT_4 = "gpt-4"
     GPT_35 = "gpt-35"
+    R1 = "r1"
     UNKNOWN = "unknown"
 
-    ANY: TypeAlias = Literal["gpt-4o", "o1", "gpt-4", "gpt-35", "unknown"]
+    ANY: TypeAlias = Literal["gpt-4o", "o1", "gpt-4", "gpt-35", "r1", "unknown"]
 
     def __new__(cls, *args: Any, **kwargs: Any) -> ModelFamily:
         raise TypeError(f"{cls.__name__} is a namespace class and cannot be instantiated.")
diff --git a/python/packages/autogen-core/src/autogen_core/models/_types.py b/python/packages/autogen-core/src/autogen_core/models/_types.py
index a3d6af1edde..239af52bf14 100644
--- a/python/packages/autogen-core/src/autogen_core/models/_types.py
+++ b/python/packages/autogen-core/src/autogen_core/models/_types.py
@@ -8,11 +8,25 @@
 
 
 class SystemMessage(BaseModel):
+    """System message contains instructions for the model coming from the developer.
+
+    .. note::
+
+        OpenAI is moving away from using the 'system' role in favor of the 'developer' role.
+        See the `Model Spec <https://cdn.openai.com/spec/model-spec-2024-05-08.html>`_ for more details.
+        However, the 'system' role is still allowed in their API and will be automatically converted to the 'developer' role
+        on the server side.
+        So, you can use `SystemMessage` for developer messages.
+
+    """
+
     content: str
     type: Literal["SystemMessage"] = "SystemMessage"
 
 
 class UserMessage(BaseModel):
+    """User message contains input from end users, or a catch-all for data provided to the model."""
+
     content: Union[str, List[Union[str, Image]]]
 
     # Name of the agent that sent this message
@@ -22,6 +36,8 @@ class UserMessage(BaseModel):
 
 
 class AssistantMessage(BaseModel):
+    """Assistant messages are sampled from the language model."""
+
     content: Union[str, List[FunctionCall]]
 
     # Name of the agent that sent this message
@@ -31,11 +47,15 @@ class AssistantMessage(BaseModel):
 
 
 class FunctionExecutionResult(BaseModel):
+    """Function execution result contains the output of a function call."""
+
     content: str
     call_id: str
 
 
 class FunctionExecutionResultMessage(BaseModel):
+    """Function execution result message contains the output of multiple function calls."""
+
     content: List[FunctionExecutionResult]
 
     type: Literal["FunctionExecutionResultMessage"] = "FunctionExecutionResultMessage"
@@ -69,8 +89,23 @@ class ChatCompletionTokenLogprob(BaseModel):
 
 
 class CreateResult(BaseModel):
+    """Create result contains the output of a model completion."""
+
     finish_reason: FinishReasons
+    """The reason the model finished generating the completion."""
+
     content: Union[str, List[FunctionCall]]
+    """The output of the model completion."""
+
     usage: RequestUsage
+    """The usage of tokens in the prompt and completion."""
+
     cached: bool
+    """Whether the completion was generated from a cached response."""
+
     logprobs: Optional[List[ChatCompletionTokenLogprob] | None] = None
+    """The logprobs of the tokens in the completion."""
+
+    thought: Optional[str] = None
+    """The reasoning text for the completion if available. Used for reasoning models
+    and additional text content besides function calls."""
diff --git a/python/packages/autogen-ext/pyproject.toml b/python/packages/autogen-ext/pyproject.toml
index 060b80327b0..e8a07bd998e 100644
--- a/python/packages/autogen-ext/pyproject.toml
+++ b/python/packages/autogen-ext/pyproject.toml
@@ -120,6 +120,7 @@ dev = [
     "autogen_test_utils",
     "langchain-experimental",
     "pandas-stubs>=2.2.3.241126",
+    "httpx>=0.28.1",
 ]
 
 [tool.ruff]
diff --git a/python/packages/autogen-ext/src/autogen_ext/models/_utils/parse_r1_content.py b/python/packages/autogen-ext/src/autogen_ext/models/_utils/parse_r1_content.py
new file mode 100644
index 00000000000..6b31c361f00
--- /dev/null
+++ b/python/packages/autogen-ext/src/autogen_ext/models/_utils/parse_r1_content.py
@@ -0,0 +1,33 @@
+import warnings
+from typing import Tuple
+
+
+def parse_r1_content(content: str) -> Tuple[str | None, str]:
+    """Parse the content of an R1-style message that contains a `<think>...</think>` field."""
+    # Find the start and end of the think field
+    think_start = content.find("<think>")
+    think_end = content.find("</think>")
+
+    if think_start == -1 or think_end == -1:
+        warnings.warn(
+            "Could not find <think>..</think> field in model response content. " "No thought was extracted.",
+            UserWarning,
+            stacklevel=2,
+        )
+        return None, content
+
+    if think_end < think_start:
+        warnings.warn(
+            "Found </think> before <think> in model response content. " "No thought was extracted.",
+            UserWarning,
+            stacklevel=2,
+        )
+        return None, content
+
+    # Extract the think field
+    thought = content[think_start + len("<think>") : think_end].strip()
+
+    # Extract the rest of the content, skipping the think field.
+    content = content[think_end + len("</think>") :].strip()
+
+    return thought, content
diff --git a/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py
index bb0af506b65..4d8d5eb5063 100644
--- a/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py
+++ b/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py
@@ -12,6 +12,7 @@
     FinishReasons,
     FunctionExecutionResultMessage,
     LLMMessage,
+    ModelFamily,
     ModelInfo,
     RequestUsage,
     SystemMessage,
@@ -55,6 +56,8 @@
     AzureAIChatCompletionClientConfig,
 )
 
+from .._utils.parse_r1_content import parse_r1_content
+
 create_kwargs = set(getfullargspec(ChatCompletionsClient.complete).kwonlyargs)
 
 AzureMessage = Union[AzureSystemMessage, AzureUserMessage, AzureAssistantMessage, AzureToolMessage]
@@ -354,11 +357,17 @@ async def create(
 
         finish_reason = choice.finish_reason  # type: ignore
         content = choice.message.content or ""
 
+        if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
+            thought, content = parse_r1_content(content)
+        else:
+            thought = None
+
         response = CreateResult(
             finish_reason=finish_reason,  # type: ignore
             content=content,
             usage=usage,
             cached=False,
+            thought=thought,
         )
 
         self.add_usage(usage)
@@ -464,11 +473,17 @@ async def create_stream(
             prompt_tokens=prompt_tokens,
         )
 
+        if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
+            thought, content = parse_r1_content(content)
+        else:
+            thought = None
+
         result = CreateResult(
             finish_reason=finish_reason,
             content=content,
             usage=usage,
             cached=False,
+            thought=thought,
         )
 
         self.add_usage(usage)
diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
index c743debf992..ca652d6580c 100644
--- a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
+++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
@@ -72,6 +72,7 @@
 from pydantic import BaseModel
 from typing_extensions import Self, Unpack
 
+from .._utils.parse_r1_content import parse_r1_content
 from . import _model_info
 from .config import (
     AzureOpenAIClientConfiguration,
@@ -605,12 +606,19 @@ async def create(
                 )
                 for x in choice.logprobs.content
             ]
+
+        if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
+            thought, content = parse_r1_content(content)
+        else:
+            thought = None
+
         response = CreateResult(
             finish_reason=normalize_stop_reason(finish_reason),
             content=content,
             usage=usage,
             cached=False,
             logprobs=logprobs,
+            thought=thought,
         )
 
         self._total_usage = _add_usage(self._total_usage, usage)
@@ -818,12 +826,18 @@ async def create_stream(
             completion_tokens=completion_tokens,
         )
 
+        if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
+            thought, content = parse_r1_content(content)
+        else:
+            thought = None
+
         result = CreateResult(
             finish_reason=normalize_stop_reason(stop_reason),
             content=content,
             usage=usage,
             cached=False,
             logprobs=logprobs,
+            thought=thought,
         )
 
         self._total_usage = _add_usage(self._total_usage, usage)
@@ -992,20 +1006,23 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA
 
             print(result)
 
 
-    To use the client with a non-OpenAI model, you need to provide the base URL of the model and the model capabilities:
+    To use the client with a non-OpenAI model, you need to provide the base URL of the model and the model info.
+    For example, to use Ollama, you can use the following code snippet:
 
     .. code-block:: python
 
         from autogen_ext.models.openai import OpenAIChatCompletionClient
+        from autogen_core.models import ModelFamily
 
         custom_model_client = OpenAIChatCompletionClient(
-            model="custom-model-name",
-            base_url="https://custom-model.com/reset/of/the/path",
+            model="deepseek-r1:1.5b",
+            base_url="http://localhost:11434/v1",
             api_key="placeholder",
-            model_capabilities={
-                "vision": True,
-                "function_calling": True,
-                "json_output": True,
+            model_info={
+                "vision": False,
+                "function_calling": False,
+                "json_output": False,
+                "family": ModelFamily.R1,
             },
         )
 
diff --git a/python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py b/python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py
index fbc76b83627..c6a80a35871 100644
--- a/python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py
+++ b/python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py
@@ -25,6 +25,8 @@
 
 from autogen_ext.tools.semantic_kernel import KernelFunctionFromTool
 
+from .._utils.parse_r1_content import parse_r1_content
+
 
 class SKChatCompletionAdapter(ChatCompletionClient):
     """
@@ -402,11 +404,17 @@ async def create(
         content = result[0].content
         finish_reason = "stop"
 
+        if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
+            thought, content = parse_r1_content(content)
+        else:
+            thought = None
+
         return CreateResult(
             content=content,
             finish_reason=finish_reason,
             usage=RequestUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
             cached=False,
+            thought=thought,
         )
 
     async def create_stream(
@@ -485,11 +493,18 @@ async def create_stream(
         if accumulated_content:
             self._total_prompt_tokens += prompt_tokens
             self._total_completion_tokens += completion_tokens
+
+            if isinstance(accumulated_content, str) and self._model_info["family"] == ModelFamily.R1:
+                thought, accumulated_content = parse_r1_content(accumulated_content)
+            else:
+                thought = None
+
             yield CreateResult(
                 content=accumulated_content,
                 finish_reason="stop",
                 usage=RequestUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
                 cached=False,
+                thought=thought,
             )
 
     def actual_usage(self) -> RequestUsage:
diff --git a/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py b/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py
index 0e510f93e87..d2662a0a270 100644
--- a/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py
+++ b/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py
@@ -5,7 +5,7 @@
 
 import pytest
 from autogen_core import CancellationToken, FunctionCall, Image
-from autogen_core.models import CreateResult, UserMessage
+from autogen_core.models import CreateResult, ModelFamily, UserMessage
 from autogen_ext.models.azure import AzureAIChatCompletionClient
 from azure.ai.inference.aio import (
     ChatCompletionsClient,
@@ -295,3 +295,82 @@ async def _mock_create_noop(*args: Any, **kwargs: Any) -> ChatCompletions:
         ]
     )
     assert result.content == "Handled image"
+
+
+@pytest.mark.asyncio
+async def test_r1_content(monkeypatch: pytest.MonkeyPatch) -> None:
+    """
+    Ensures that the content is parsed correctly when it contains an R1-style think field.
+ """ + + async def _mock_create_r1_content_stream( + *args: Any, **kwargs: Any + ) -> AsyncGenerator[StreamingChatCompletionsUpdate, None]: + mock_chunks_content = ["Thought Hello", " Another Hello", " Yet Another Hello"] + + mock_chunks = [ + StreamingChatChoiceUpdate( + index=0, + finish_reason="stop", + delta=StreamingChatResponseMessageUpdate(role="assistant", content=chunk_content), + ) + for chunk_content in mock_chunks_content + ] + + for mock_chunk in mock_chunks: + await asyncio.sleep(0.1) + yield StreamingChatCompletionsUpdate( + id="id", + choices=[mock_chunk], + created=datetime.now(), + model="model", + usage=CompletionsUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0), + ) + + async def _mock_create_r1_content( + *args: Any, **kwargs: Any + ) -> ChatCompletions | AsyncGenerator[StreamingChatCompletionsUpdate, None]: + stream = kwargs.get("stream", False) + + if not stream: + await asyncio.sleep(0.1) + return ChatCompletions( + id="id", + created=datetime.now(), + model="model", + choices=[ + ChatChoice( + index=0, + finish_reason="stop", + message=ChatResponseMessage(content="Thought Hello", role="assistant"), + ) + ], + usage=CompletionsUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0), + ) + else: + return _mock_create_r1_content_stream(*args, **kwargs) + + monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create_r1_content) + + client = AzureAIChatCompletionClient( + endpoint="endpoint", + credential=AzureKeyCredential("api_key"), + model_info={ + "json_output": False, + "function_calling": False, + "vision": True, + "family": ModelFamily.R1, + }, + model="model", + ) + + result = await client.create(messages=[UserMessage(content="Hello", source="user")]) + assert result.content == "Hello" + assert result.thought == "Thought" + + chunks: List[str | CreateResult] = [] + async for chunk in client.create_stream(messages=[UserMessage(content="Hello", source="user")]): + chunks.append(chunk) + assert isinstance(chunks[-1], CreateResult) + assert chunks[-1].content == "Hello Another Hello Yet Another Hello" + assert chunks[-1].thought == "Thought" diff --git a/python/packages/autogen-ext/tests/models/test_openai_model_client.py b/python/packages/autogen-ext/tests/models/test_openai_model_client.py index ab67917998a..5f8d39af815 100644 --- a/python/packages/autogen-ext/tests/models/test_openai_model_client.py +++ b/python/packages/autogen-ext/tests/models/test_openai_model_client.py @@ -4,6 +4,7 @@ from typing import Annotated, Any, AsyncGenerator, Dict, Generic, List, Literal, Tuple, TypeVar from unittest.mock import MagicMock +import httpx import pytest from autogen_core import CancellationToken, FunctionCall, Image from autogen_core.models import ( @@ -12,6 +13,7 @@ FunctionExecutionResult, FunctionExecutionResultMessage, LLMMessage, + ModelInfo, RequestUsage, SystemMessage, UserMessage, @@ -468,6 +470,154 @@ class AgentResponse(BaseModel): assert response.response == "happy" +@pytest.mark.asyncio +async def test_r1_think_field(monkeypatch: pytest.MonkeyPatch) -> None: + async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatCompletionChunk, None]: + chunks = [" Hello", " Another Hello", " Yet Another Hello"] + for i, chunk in enumerate(chunks): + await asyncio.sleep(0.1) + yield ChatCompletionChunk( + id="id", + choices=[ + ChunkChoice( + finish_reason="stop" if i == len(chunks) - 1 else None, + index=0, + delta=ChoiceDelta( + content=chunk, + role="assistant", + ), + ), + ], + created=0, + model="r1", + 
object="chat.completion.chunk", + usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0), + ) + + async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]: + stream = kwargs.get("stream", False) + if not stream: + await asyncio.sleep(0.1) + return ChatCompletion( + id="id", + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage( + content=" Hello Another Hello Yet Another Hello", role="assistant" + ), + ) + ], + created=0, + model="r1", + object="chat.completion", + usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0), + ) + else: + return _mock_create_stream(*args, **kwargs) + + monkeypatch.setattr(AsyncCompletions, "create", _mock_create) + + model_client = OpenAIChatCompletionClient( + model="r1", + api_key="", + model_info={"family": ModelFamily.R1, "vision": False, "function_calling": False, "json_output": False}, + ) + + # Successful completion with think field. + create_result = await model_client.create(messages=[UserMessage(content="I am happy.", source="user")]) + assert create_result.content == "Another Hello Yet Another Hello" + assert create_result.finish_reason == "stop" + assert not create_result.cached + assert create_result.thought == "Hello" + + # Stream completion with think field. + chunks: List[str | CreateResult] = [] + async for chunk in model_client.create_stream(messages=[UserMessage(content="Hello", source="user")]): + chunks.append(chunk) + assert len(chunks) > 0 + assert isinstance(chunks[-1], CreateResult) + assert chunks[-1].content == "Another Hello Yet Another Hello" + assert chunks[-1].thought == "Hello" + assert not chunks[-1].cached + + +@pytest.mark.asyncio +async def test_r1_think_field_not_present(monkeypatch: pytest.MonkeyPatch) -> None: + async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatCompletionChunk, None]: + chunks = ["Hello", " Another Hello", " Yet Another Hello"] + for i, chunk in enumerate(chunks): + await asyncio.sleep(0.1) + yield ChatCompletionChunk( + id="id", + choices=[ + ChunkChoice( + finish_reason="stop" if i == len(chunks) - 1 else None, + index=0, + delta=ChoiceDelta( + content=chunk, + role="assistant", + ), + ), + ], + created=0, + model="r1", + object="chat.completion.chunk", + usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0), + ) + + async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]: + stream = kwargs.get("stream", False) + if not stream: + await asyncio.sleep(0.1) + return ChatCompletion( + id="id", + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage( + content="Hello Another Hello Yet Another Hello", role="assistant" + ), + ) + ], + created=0, + model="r1", + object="chat.completion", + usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0), + ) + else: + return _mock_create_stream(*args, **kwargs) + + monkeypatch.setattr(AsyncCompletions, "create", _mock_create) + + model_client = OpenAIChatCompletionClient( + model="r1", + api_key="", + model_info={"family": ModelFamily.R1, "vision": False, "function_calling": False, "json_output": False}, + ) + + # Warning completion when think field is not present. + with pytest.warns(UserWarning, match="Could not find .. 
field in model response content."): + create_result = await model_client.create(messages=[UserMessage(content="I am happy.", source="user")]) + assert create_result.content == "Hello Another Hello Yet Another Hello" + assert create_result.finish_reason == "stop" + assert not create_result.cached + assert create_result.thought is None + + # Stream completion with think field. + with pytest.warns(UserWarning, match="Could not find .. field in model response content."): + chunks: List[str | CreateResult] = [] + async for chunk in model_client.create_stream(messages=[UserMessage(content="Hello", source="user")]): + chunks.append(chunk) + assert len(chunks) > 0 + assert isinstance(chunks[-1], CreateResult) + assert chunks[-1].content == "Hello Another Hello Yet Another Hello" + assert chunks[-1].thought is None + assert not chunks[-1].cached + + @pytest.mark.asyncio async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None: model = "gpt-4o-2024-05-13" @@ -836,4 +986,68 @@ async def test_hugging_face() -> None: await _test_model_client_basic_completion(model_client) +@pytest.mark.asyncio +async def test_ollama() -> None: + model = "deepseek-r1:1.5b" + model_info: ModelInfo = { + "function_calling": False, + "json_output": False, + "vision": False, + "family": ModelFamily.R1, + } + # Check if the model is running locally. + try: + async with httpx.AsyncClient() as client: + response = await client.get(f"http://localhost:11434/v1/models/{model}") + response.raise_for_status() + except httpx.HTTPStatusError as e: + pytest.skip(f"{model} model is not running locally: {e}") + except httpx.ConnectError as e: + pytest.skip(f"Ollama is not running locally: {e}") + + model_client = OpenAIChatCompletionClient( + model=model, + api_key="placeholder", + base_url="http://localhost:11434/v1", + model_info=model_info, + ) + + # Test basic completion with the Ollama deepseek-r1:1.5b model. + create_result = await model_client.create( + messages=[ + UserMessage( + content="Taking two balls from a bag of 10 green balls and 20 red balls, " + "what is the probability of getting a green and a red balls?", + source="user", + ), + ] + ) + assert isinstance(create_result.content, str) + assert len(create_result.content) > 0 + assert create_result.finish_reason == "stop" + assert create_result.usage is not None + if model_info["family"] == ModelFamily.R1: + assert create_result.thought is not None + + # Test streaming completion with the Ollama deepseek-r1:1.5b model. + chunks: List[str | CreateResult] = [] + async for chunk in model_client.create_stream( + messages=[ + UserMessage( + content="Taking two balls from a bag of 10 green balls and 20 red balls, " + "what is the probability of getting a green and a red balls?", + source="user", + ), + ] + ): + chunks.append(chunk) + assert len(chunks) > 0 + assert isinstance(chunks[-1], CreateResult) + assert chunks[-1].finish_reason == "stop" + assert len(chunks[-1].content) > 0 + assert chunks[-1].usage is not None + if model_info["family"] == ModelFamily.R1: + assert chunks[-1].thought is not None + + # TODO: add integration tests for Azure OpenAI using AAD token. 
diff --git a/python/packages/autogen-ext/tests/models/test_sk_chat_completion_adapter.py b/python/packages/autogen-ext/tests/models/test_sk_chat_completion_adapter.py
index 69be349c38c..71b74a9a0cb 100644
--- a/python/packages/autogen-ext/tests/models/test_sk_chat_completion_adapter.py
+++ b/python/packages/autogen-ext/tests/models/test_sk_chat_completion_adapter.py
@@ -377,3 +377,75 @@ async def test_sk_chat_completion_custom_model_info(sk_client: AzureChatCompleti
 
     # Verify capabilities returns the same ModelInfo
     assert adapter.capabilities == adapter.model_info
+
+
+@pytest.mark.asyncio
+async def test_sk_chat_completion_r1_content() -> None:
+    async def mock_get_chat_message_contents(
+        chat_history: ChatHistory,
+        settings: PromptExecutionSettings,
+        **kwargs: Any,
+    ) -> list[ChatMessageContent]:
+        return [
+            ChatMessageContent(
+                ai_model_id="r1",
+                role=AuthorRole.ASSISTANT,
+                metadata={"usage": {"prompt_tokens": 20, "completion_tokens": 9}},
+                items=[TextContent(text="<think>Reasoning...</think> Hello!")],
+                finish_reason=FinishReason.STOP,
+            )
+        ]
+
+    async def mock_get_streaming_chat_message_contents(
+        chat_history: ChatHistory,
+        settings: PromptExecutionSettings,
+        **kwargs: Any,
+    ) -> AsyncGenerator[list["StreamingChatMessageContent"], Any]:
+        chunks = ["<think>Reasoning...</think>", " Hello!"]
+        for i, chunk in enumerate(chunks):
+            yield [
+                StreamingChatMessageContent(
+                    choice_index=0,
+                    inner_content=ChatCompletionChunk(
+                        id=f"chatcmpl-{i}",
+                        choices=[Choice(delta=ChoiceDelta(content=chunk), finish_reason=None, index=0)],
+                        created=1736674044,
+                        model="r1",
+                        object="chat.completion.chunk",
+                        service_tier="scale",
+                        system_fingerprint="fingerprint",
+                        usage=CompletionUsage(prompt_tokens=20, completion_tokens=9, total_tokens=29),
+                    ),
+                    ai_model_id="gpt-4o-mini",
+                    metadata={"id": f"chatcmpl-{i}", "created": 1736674044},
+                    role=AuthorRole.ASSISTANT,
+                    items=[StreamingTextContent(choice_index=0, text=chunk)],
+                    finish_reason=FinishReason.STOP if i == len(chunks) - 1 else None,
+                )
+            ]
+
+    mock_client = AsyncMock(spec=AzureChatCompletion)
+    mock_client.get_chat_message_contents = mock_get_chat_message_contents
+    mock_client.get_streaming_chat_message_contents = mock_get_streaming_chat_message_contents
+
+    kernel = Kernel(memory=NullMemory())
+
+    adapter = SKChatCompletionAdapter(
+        mock_client,
+        kernel=kernel,
+        model_info=ModelInfo(vision=False, function_calling=False, json_output=False, family=ModelFamily.R1),
+    )
+
+    result = await adapter.create(messages=[UserMessage(content="Say hello!", source="user")])
+    assert result.finish_reason == "stop"
+    assert result.content == "Hello!"
+    assert result.thought == "Reasoning..."
+
+    response_chunks: list[CreateResult | str] = []
+    async for chunk in adapter.create_stream(messages=[UserMessage(content="Say hello!", source="user")]):
+        response_chunks.append(chunk)
+    assert len(response_chunks) > 0
+    assert isinstance(response_chunks[-1], CreateResult)
+    assert response_chunks[-1].finish_reason == "stop"
+    assert response_chunks[-1].content == "Hello!"
+    assert response_chunks[-1].thought == "Reasoning..."
diff --git a/python/packages/autogen-ext/tests/models/test_utils.py b/python/packages/autogen-ext/tests/models/test_utils.py
new file mode 100644
index 00000000000..dca0fb2ad53
--- /dev/null
+++ b/python/packages/autogen-ext/tests/models/test_utils.py
@@ -0,0 +1,43 @@
+import pytest
+from autogen_ext.models._utils.parse_r1_content import parse_r1_content
+
+
+def test_parse_r1_content() -> None:
+    content = "Hello, <think>world</think> How are you?"
+    thought, content = parse_r1_content(content)
+    assert thought == "world"
+    assert content == "How are you?"
+
+    with pytest.warns(
+        UserWarning,
+        match="Could not find <think>..</think> field in model response content. " "No thought was extracted.",
+    ):
+        content = "Hello, world How are you?"
+        thought, content = parse_r1_content(content)
+        assert thought is None
+        assert content == "Hello, world How are you?"
+
+    with pytest.warns(
+        UserWarning,
+        match="Could not find <think>..</think> field in model response content. " "No thought was extracted.",
+    ):
+        content = "Hello, <think>world How are you?"
+        thought, content = parse_r1_content(content)
+        assert thought is None
+        assert content == "Hello, <think>world How are you?"
+
+    with pytest.warns(
+        UserWarning, match="Found </think> before <think> in model response content. " "No thought was extracted."
+    ):
+        content = "Hello, </think>world<think>"
+        thought, content = parse_r1_content(content)
+        assert thought is None
+        assert content == "Hello, </think>world<think>"
+
+    with pytest.warns(
+        UserWarning, match="Found </think> before <think> in model response content. " "No thought was extracted."
+    ):
+        content = "</think>Hello, world<think>"
+        thought, content = parse_r1_content(content)
+        assert thought is None
+        assert content == "</think>Hello, world<think>"
diff --git a/python/uv.lock b/python/uv.lock
index 3c07e7e246c..d8e78aadb64 100644
--- a/python/uv.lock
+++ b/python/uv.lock
@@ -654,6 +654,7 @@ web-surfer = [
 [package.dev-dependencies]
 dev = [
     { name = "autogen-test-utils" },
+    { name = "httpx" },
     { name = "langchain-experimental" },
     { name = "pandas-stubs" },
 ]
@@ -706,6 +707,7 @@ requires-dist = [
 [package.metadata.requires-dev]
 dev = [
     { name = "autogen-test-utils", editable = "packages/autogen-test-utils" },
    { name = "httpx", specifier = ">=0.28.1" },
     { name = "langchain-experimental" },
     { name = "pandas-stubs", specifier = ">=2.2.3.241126" },
 ]
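A minimal usage sketch of the new `thought` field, assuming the same setup used in the docstring example and `test_ollama` above (a local Ollama server at http://localhost:11434/v1 serving `deepseek-r1:1.5b`); this is an illustrative sketch, not part of the patch itself:

import asyncio

from autogen_core.models import ModelFamily, UserMessage
from autogen_ext.models.openai import OpenAIChatCompletionClient


async def main() -> None:
    # Same configuration as the docstring example: an OpenAI-compatible endpoint
    # serving an R1-family model, so the client splits <think>...</think> content.
    client = OpenAIChatCompletionClient(
        model="deepseek-r1:1.5b",
        base_url="http://localhost:11434/v1",
        api_key="placeholder",
        model_info={
            "vision": False,
            "function_calling": False,
            "json_output": False,
            "family": ModelFamily.R1,
        },
    )
    result = await client.create(messages=[UserMessage(content="What is 2 + 2?", source="user")])
    # For R1 models the reasoning text is returned in `thought`,
    # while the final answer stays in `content`.
    print("Thought:", result.thought)
    print("Content:", result.content)


asyncio.run(main())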