feat: Support R1 reasoning text in model create result; enhance API docs #5262

Merged: 11 commits, Jan 30, 2025
@@ -22,9 +22,10 @@ class ModelFamily:
    O1 = "o1"
    GPT_4 = "gpt-4"
    GPT_35 = "gpt-35"
    R1 = "r1"
    UNKNOWN = "unknown"

    ANY: TypeAlias = Literal["gpt-4o", "o1", "gpt-4", "gpt-35", "r1", "unknown"]

    def __new__(cls, *args: Any, **kwargs: Any) -> ModelFamily:
        raise TypeError(f"{cls.__name__} is a namespace class and cannot be instantiated.")
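The new `R1` family threads through `ModelInfo`, which is how the client implementations below decide whether to split `<think>...</think>` reasoning out of the content. A minimal sketch of declaring an R1-family model (illustrative values, not part of this diff):

.. code-block:: python

    from autogen_core.models import ModelFamily, ModelInfo

    # Declare a model as part of the R1 family so a chat completion client
    # knows to separate <think>...</think> reasoning from the final answer.
    r1_info: ModelInfo = {
        "vision": False,
        "function_calling": False,
        "json_output": False,
        "family": ModelFamily.R1,
    }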
35 changes: 35 additions & 0 deletions python/packages/autogen-core/src/autogen_core/models/_types.py
@@ -8,11 +8,25 @@


class SystemMessage(BaseModel):
"""System message contains instructions for the model coming from the developer.

.. note::

Open AI is moving away from using 'system' role in favor of 'developer' role.
See `Model Spec <https://cdn.openai.com/spec/model-spec-2024-05-08.html#definitions>`_ for more details.
However, the 'system' role is still allowed in their API and will be automatically converted to 'developer' role
on the server side.
So, you can use `SystemMessage` for developer messages.

"""

content: str
type: Literal["SystemMessage"] = "SystemMessage"


class UserMessage(BaseModel):
"""User message contains input from end users, or a catch-all for data provided to the model."""

content: Union[str, List[Union[str, Image]]]

# Name of the agent that sent this message
@@ -22,6 +36,8 @@ class UserMessage(BaseModel):


class AssistantMessage(BaseModel):
"""Assistant message are sampled from the language model."""

content: Union[str, List[FunctionCall]]

# Name of the agent that sent this message
@@ -31,11 +47,15 @@ class AssistantMessage(BaseModel):


class FunctionExecutionResult(BaseModel):
"""Function execution result contains the output of a function call."""

content: str
call_id: str


class FunctionExecutionResultMessage(BaseModel):
"""Function execution result message contains the output of multiple function calls."""

content: List[FunctionExecutionResult]

type: Literal["FunctionExecutionResultMessage"] = "FunctionExecutionResultMessage"
@@ -69,8 +89,23 @@ class ChatCompletionTokenLogprob(BaseModel):


class CreateResult(BaseModel):
"""Create result contains the output of a model completion."""

finish_reason: FinishReasons
"""The reason the model finished generating the completion."""

content: Union[str, List[FunctionCall]]
"""The output of the model completion."""

usage: RequestUsage
"""The usage of tokens in the prompt and completion."""

cached: bool
"""Whether the completion was generated from a cached response."""

logprobs: Optional[List[ChatCompletionTokenLogprob] | None] = None
"""The logprobs of the tokens in the completion."""

thought: Optional[str] = None
ekzhu marked this conversation as resolved.
Show resolved Hide resolved
"""The reasoning text for the completion if available. Used for reasoning models
and additional text content besides function calls."""
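For orientation, here is what an R1-style completion looks like once the `<think>` block has been split off into the new `thought` field; a minimal sketch with made-up token counts:

.. code-block:: python

    from autogen_core.models import CreateResult, RequestUsage

    result = CreateResult(
        finish_reason="stop",
        content="Hello",  # the final answer, with the think block removed
        usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
        cached=False,
        thought="The user greeted me, so I should greet them back.",
    )

    print(result.thought)  # reasoning text, kept separate from `content`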
1 change: 1 addition & 0 deletions python/packages/autogen-ext/pyproject.toml
@@ -120,6 +120,7 @@ dev = [
"autogen_test_utils",
"langchain-experimental",
"pandas-stubs>=2.2.3.241126",
"httpx>=0.28.1",
]

[tool.ruff]
@@ -0,0 +1,21 @@
import warnings
from typing import Tuple


def parse_r1_content(content: str) -> Tuple[str | None, str]:
    """Parse the content of an R1-style message that contains a `<think>...</think>` field."""
    # Find the start and end of the think field
    think_start = content.find("<think>")
    think_end = content.find("</think>")

    # Extract the think field
    if think_start == -1 or think_end == -1:
        warnings.warn("Could not find <think>...</think> field in model response content.", UserWarning, stacklevel=2)
        return None, content

    thought = content[think_start + len("<think>") : think_end].strip()

    # Extract the rest of the content, skipping the think field.
    content = content[think_end + len("</think>") :].strip()

    return thought, content
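A quick usage sketch of the helper above, with illustrative inputs, covering both the matched and the fallback path:

.. code-block:: python

    thought, answer = parse_r1_content("<think>Check the docs first.</think> Use create().")
    assert thought == "Check the docs first."
    assert answer == "Use create()."

    # Without a <think> field, a UserWarning is emitted and the thought is None.
    thought, answer = parse_r1_content("Plain response")
    assert thought is None
    assert answer == "Plain response"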
@@ -12,6 +12,7 @@
    FinishReasons,
    FunctionExecutionResultMessage,
    LLMMessage,
    ModelFamily,
    ModelInfo,
    RequestUsage,
    SystemMessage,
@@ -55,6 +56,8 @@
    AzureAIChatCompletionClientConfig,
)

from .._utils.parse_r1_content import parse_r1_content

create_kwargs = set(getfullargspec(ChatCompletionsClient.complete).kwonlyargs)
AzureMessage = Union[AzureSystemMessage, AzureUserMessage, AzureAssistantMessage, AzureToolMessage]

@@ -354,11 +357,17 @@ async def create(
        finish_reason = choice.finish_reason  # type: ignore
        content = choice.message.content or ""

        if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
            thought, content = parse_r1_content(content)
        else:
            thought = None

        response = CreateResult(
            finish_reason=finish_reason,  # type: ignore
            content=content,
            usage=usage,
            cached=False,
            thought=thought,
        )

        self.add_usage(usage)
@@ -464,11 +473,17 @@ async def create_stream(
            prompt_tokens=prompt_tokens,
        )

        if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
            thought, content = parse_r1_content(content)
        else:
            thought = None

        result = CreateResult(
            finish_reason=finish_reason,
            content=content,
            usage=usage,
            cached=False,
            thought=thought,
        )

        self.add_usage(usage)
@@ -72,6 +72,7 @@
from pydantic import BaseModel
from typing_extensions import Self, Unpack

from .._utils.parse_r1_content import parse_r1_content
from . import _model_info
from .config import (
    AzureOpenAIClientConfiguration,
@@ -595,12 +596,19 @@ async def create(
                )
                for x in choice.logprobs.content
            ]

        if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
            thought, content = parse_r1_content(content)
        else:
            thought = None

        response = CreateResult(
            finish_reason=normalize_stop_reason(finish_reason),
            content=content,
            usage=usage,
            cached=False,
            logprobs=logprobs,
            thought=thought,
        )

        self._total_usage = _add_usage(self._total_usage, usage)
@@ -808,12 +816,18 @@ async def create_stream(
            completion_tokens=completion_tokens,
        )

        if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
            thought, content = parse_r1_content(content)
        else:
            thought = None

        result = CreateResult(
            finish_reason=normalize_stop_reason(stop_reason),
            content=content,
            usage=usage,
            cached=False,
            logprobs=logprobs,
            thought=thought,
        )

        self._total_usage = _add_usage(self._total_usage, usage)
@@ -982,20 +996,23 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA
            print(result)


    To use the client with a non-OpenAI model, you need to provide the base URL of the model and the model info.
    For example, to use Ollama, you can use the following code snippet:

    .. code-block:: python

        from autogen_ext.models.openai import OpenAIChatCompletionClient
        from autogen_core.models import ModelFamily

        custom_model_client = OpenAIChatCompletionClient(
            model="deepseek-r1:1.5b",
            base_url="http://localhost:11434/v1",
            api_key="placeholder",
            model_info={
                "vision": False,
                "function_calling": False,
                "json_output": False,
                "family": ModelFamily.R1,
            },
        )
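Putting the pieces together: once a client is configured with `family=ModelFamily.R1`, the `<think>` block is split off into `CreateResult.thought` automatically. A minimal end-to-end sketch, assuming a local Ollama server exposing `deepseek-r1:1.5b` on its default port:

.. code-block:: python

    import asyncio

    from autogen_core.models import ModelFamily, UserMessage
    from autogen_ext.models.openai import OpenAIChatCompletionClient


    async def main() -> None:
        client = OpenAIChatCompletionClient(
            model="deepseek-r1:1.5b",
            base_url="http://localhost:11434/v1",
            api_key="placeholder",
            model_info={
                "vision": False,
                "function_calling": False,
                "json_output": False,
                "family": ModelFamily.R1,
            },
        )
        result = await client.create([UserMessage(content="What is 2 + 2?", source="user")])
        print(result.thought)  # reasoning extracted from <think>...</think>
        print(result.content)  # final answer with the think block removed


    asyncio.run(main())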

@@ -25,6 +25,8 @@

from autogen_ext.tools.semantic_kernel import KernelFunctionFromTool

from .._utils.parse_r1_content import parse_r1_content


class SKChatCompletionAdapter(ChatCompletionClient):
"""
@@ -402,11 +404,17 @@ async def create(
        content = result[0].content
        finish_reason = "stop"

        if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
            thought, content = parse_r1_content(content)
        else:
            thought = None

        return CreateResult(
            content=content,
            finish_reason=finish_reason,
            usage=RequestUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
            cached=False,
            thought=thought,
        )

    async def create_stream(
@@ -485,11 +493,18 @@ async def create_stream(
            if accumulated_content:
                self._total_prompt_tokens += prompt_tokens
                self._total_completion_tokens += completion_tokens

                if isinstance(accumulated_content, str) and self._model_info["family"] == ModelFamily.R1:
                    thought, accumulated_content = parse_r1_content(accumulated_content)
                else:
                    thought = None

                yield CreateResult(
                    content=accumulated_content,
                    finish_reason="stop",
                    usage=RequestUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
                    cached=False,
                    thought=thought,
                )

    def actual_usage(self) -> RequestUsage:
@@ -5,7 +5,7 @@

import pytest
from autogen_core import CancellationToken, FunctionCall, Image
from autogen_core.models import CreateResult, ModelFamily, UserMessage
from autogen_ext.models.azure import AzureAIChatCompletionClient
from azure.ai.inference.aio import (
    ChatCompletionsClient,
@@ -295,3 +295,82 @@ async def _mock_create_noop(*args: Any, **kwargs: Any) -> ChatCompletions:
        ]
    )
    assert result.content == "Handled image"


@pytest.mark.asyncio
async def test_r1_content(monkeypatch: pytest.MonkeyPatch) -> None:
    """
    Ensures that the content is parsed correctly when it contains an R1-style think field.
    """

    async def _mock_create_r1_content_stream(
        *args: Any, **kwargs: Any
    ) -> AsyncGenerator[StreamingChatCompletionsUpdate, None]:
        mock_chunks_content = ["<think>Thought</think> Hello", " Another Hello", " Yet Another Hello"]

        mock_chunks = [
            StreamingChatChoiceUpdate(
                index=0,
                finish_reason="stop",
                delta=StreamingChatResponseMessageUpdate(role="assistant", content=chunk_content),
            )
            for chunk_content in mock_chunks_content
        ]

        for mock_chunk in mock_chunks:
            await asyncio.sleep(0.1)
            yield StreamingChatCompletionsUpdate(
                id="id",
                choices=[mock_chunk],
                created=datetime.now(),
                model="model",
                usage=CompletionsUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
            )

    async def _mock_create_r1_content(
        *args: Any, **kwargs: Any
    ) -> ChatCompletions | AsyncGenerator[StreamingChatCompletionsUpdate, None]:
        stream = kwargs.get("stream", False)

        if not stream:
            await asyncio.sleep(0.1)
            return ChatCompletions(
                id="id",
                created=datetime.now(),
                model="model",
                choices=[
                    ChatChoice(
                        index=0,
                        finish_reason="stop",
                        message=ChatResponseMessage(content="<think>Thought</think> Hello", role="assistant"),
                    )
                ],
                usage=CompletionsUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
            )
        else:
            return _mock_create_r1_content_stream(*args, **kwargs)

    monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create_r1_content)

    client = AzureAIChatCompletionClient(
        endpoint="endpoint",
        credential=AzureKeyCredential("api_key"),
        model_info={
            "json_output": False,
            "function_calling": False,
            "vision": True,
            "family": ModelFamily.R1,
        },
        model="model",
    )

    result = await client.create(messages=[UserMessage(content="Hello", source="user")])
    assert result.content == "Hello"
    assert result.thought == "Thought"

    chunks: List[str | CreateResult] = []
    async for chunk in client.create_stream(messages=[UserMessage(content="Hello", source="user")]):
        chunks.append(chunk)
    assert isinstance(chunks[-1], CreateResult)
    assert chunks[-1].content == "Hello Another Hello Yet Another Hello"
    assert chunks[-1].thought == "Thought"