feat: Support R1 reasoning text in model create result; enhance API docs (#5262)

Resolves #5255 

---------

Co-authored-by: afourney <[email protected]>
ekzhu and afourney authored Jan 30, 2025
1 parent 44db2cc commit f656ff1
Showing 12 changed files with 536 additions and 9 deletions.
@@ -22,9 +22,10 @@ class ModelFamily:
     O1 = "o1"
     GPT_4 = "gpt-4"
     GPT_35 = "gpt-35"
+    R1 = "r1"
     UNKNOWN = "unknown"
 
-    ANY: TypeAlias = Literal["gpt-4o", "o1", "gpt-4", "gpt-35", "unknown"]
+    ANY: TypeAlias = Literal["gpt-4o", "o1", "gpt-4", "gpt-35", "r1", "unknown"]
 
     def __new__(cls, *args: Any, **kwargs: Any) -> ModelFamily:
         raise TypeError(f"{cls.__name__} is a namespace class and cannot be instantiated.")
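For orientation, here is a minimal sketch (not part of the diff) of how the new ModelFamily.R1 value is meant to be used when declaring model metadata; it mirrors the model_info dictionaries that appear later in this commit, and the capability flags are placeholders that should match the actual model:

from autogen_core.models import ModelFamily, ModelInfo

# ModelInfo is a TypedDict, so calling it builds a plain dict of model metadata.
r1_model_info: ModelInfo = ModelInfo(
    vision=False,
    function_calling=False,
    json_output=False,
    family=ModelFamily.R1,
)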
35 changes: 35 additions & 0 deletions python/packages/autogen-core/src/autogen_core/models/_types.py
@@ -8,11 +8,25 @@


class SystemMessage(BaseModel):
"""System message contains instructions for the model coming from the developer.
.. note::
    OpenAI is moving away from using the 'system' role in favor of the 'developer' role.
    See the `Model Spec <https://cdn.openai.com/spec/model-spec-2024-05-08.html#definitions>`_ for more details.
    However, the 'system' role is still allowed in their API and will be automatically converted to the 'developer' role
    on the server side, so you can use `SystemMessage` for developer messages.
"""

content: str
type: Literal["SystemMessage"] = "SystemMessage"


class UserMessage(BaseModel):
"""User message contains input from end users, or a catch-all for data provided to the model."""

content: Union[str, List[Union[str, Image]]]

# Name of the agent that sent this message
@@ -22,6 +36,8 @@ class UserMessage(BaseModel):


class AssistantMessage(BaseModel):
"""Assistant message are sampled from the language model."""

content: Union[str, List[FunctionCall]]

# Name of the agent that sent this message
@@ -31,11 +47,15 @@ class AssistantMessage(BaseModel):


class FunctionExecutionResult(BaseModel):
"""Function execution result contains the output of a function call."""

content: str
call_id: str


class FunctionExecutionResultMessage(BaseModel):
"""Function execution result message contains the output of multiple function calls."""

content: List[FunctionExecutionResult]

type: Literal["FunctionExecutionResultMessage"] = "FunctionExecutionResultMessage"
@@ -69,8 +89,23 @@ class ChatCompletionTokenLogprob(BaseModel):


class CreateResult(BaseModel):
"""Create result contains the output of a model completion."""

finish_reason: FinishReasons
"""The reason the model finished generating the completion."""

content: Union[str, List[FunctionCall]]
"""The output of the model completion."""

usage: RequestUsage
"""The usage of tokens in the prompt and completion."""

cached: bool
"""Whether the completion was generated from a cached response."""

logprobs: Optional[List[ChatCompletionTokenLogprob] | None] = None
"""The logprobs of the tokens in the completion."""

thought: Optional[str] = None
"""The reasoning text for the completion if available. Used for reasoning models
and additional text content besides function calls."""
1 change: 1 addition & 0 deletions python/packages/autogen-ext/pyproject.toml
@@ -120,6 +120,7 @@ dev = [
"autogen_test_utils",
"langchain-experimental",
"pandas-stubs>=2.2.3.241126",
"httpx>=0.28.1",
]

[tool.ruff]
@@ -0,0 +1,33 @@
import warnings
from typing import Tuple


def parse_r1_content(content: str) -> Tuple[str | None, str]:
"""Parse the content of an R1-style message that contains a `<think>...</think>` field."""
# Find the start and end of the think field
think_start = content.find("<think>")
think_end = content.find("</think>")

if think_start == -1 or think_end == -1:
warnings.warn(
"Could not find <think>..</think> field in model response content. " "No thought was extracted.",
UserWarning,
stacklevel=2,
)
return None, content

if think_end < think_start:
warnings.warn(
"Found </think> before <think> in model response content. " "No thought was extracted.",
UserWarning,
stacklevel=2,
)
return None, content

# Extract the think field
thought = content[think_start + len("<think>") : think_end].strip()

# Extract the rest of the content, skipping the think field.
content = content[think_end + len("</think>") :].strip()

return thought, content
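A quick illustration (hypothetical, not part of the commit) of how the helper behaves; the absolute import path is inferred from the relative import `from .._utils.parse_r1_content import parse_r1_content` used elsewhere in this diff:

from autogen_ext.models._utils.parse_r1_content import parse_r1_content

# Well-formed R1 output: the think block becomes the thought, the rest is the content.
thought, content = parse_r1_content("<think>Reasoning goes here.</think> Final answer.")
assert thought == "Reasoning goes here."
assert content == "Final answer."

# Missing tags: a UserWarning is emitted and no thought is extracted.
thought, content = parse_r1_content("Just a plain answer.")
assert thought is None
assert content == "Just a plain answer."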
@@ -12,6 +12,7 @@
FinishReasons,
FunctionExecutionResultMessage,
LLMMessage,
ModelFamily,
ModelInfo,
RequestUsage,
SystemMessage,
@@ -55,6 +56,8 @@
AzureAIChatCompletionClientConfig,
)

from .._utils.parse_r1_content import parse_r1_content

create_kwargs = set(getfullargspec(ChatCompletionsClient.complete).kwonlyargs)
AzureMessage = Union[AzureSystemMessage, AzureUserMessage, AzureAssistantMessage, AzureToolMessage]

@@ -354,11 +357,17 @@ async def create(
finish_reason = choice.finish_reason # type: ignore
content = choice.message.content or ""

if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
thought, content = parse_r1_content(content)
else:
thought = None

response = CreateResult(
finish_reason=finish_reason, # type: ignore
content=content,
usage=usage,
cached=False,
thought=thought,
)

self.add_usage(usage)
@@ -464,11 +473,17 @@ async def create_stream(
prompt_tokens=prompt_tokens,
)

if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
thought, content = parse_r1_content(content)
else:
thought = None

result = CreateResult(
finish_reason=finish_reason,
content=content,
usage=usage,
cached=False,
thought=thought,
)

self.add_usage(usage)
@@ -72,6 +72,7 @@
from pydantic import BaseModel
from typing_extensions import Self, Unpack

from .._utils.parse_r1_content import parse_r1_content
from . import _model_info
from .config import (
AzureOpenAIClientConfiguration,
@@ -605,12 +606,19 @@ async def create(
)
for x in choice.logprobs.content
]

if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
thought, content = parse_r1_content(content)
else:
thought = None

response = CreateResult(
finish_reason=normalize_stop_reason(finish_reason),
content=content,
usage=usage,
cached=False,
logprobs=logprobs,
thought=thought,
)

self._total_usage = _add_usage(self._total_usage, usage)
@@ -818,12 +826,18 @@ async def create_stream(
completion_tokens=completion_tokens,
)

if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
thought, content = parse_r1_content(content)
else:
thought = None

result = CreateResult(
finish_reason=normalize_stop_reason(stop_reason),
content=content,
usage=usage,
cached=False,
logprobs=logprobs,
thought=thought,
)

self._total_usage = _add_usage(self._total_usage, usage)
@@ -992,20 +1006,23 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA
print(result)
-    To use the client with a non-OpenAI model, you need to provide the base URL of the model and the model capabilities:
+    To use the client with a non-OpenAI model, you need to provide the base URL of the model and the model info.
     For example, to use Ollama, you can use the following code snippet:
     .. code-block:: python
         from autogen_ext.models.openai import OpenAIChatCompletionClient
+        from autogen_core.models import ModelFamily
         custom_model_client = OpenAIChatCompletionClient(
-            model="custom-model-name",
-            base_url="https://custom-model.com/reset/of/the/path",
+            model="deepseek-r1:1.5b",
+            base_url="http://localhost:11434/v1",
+            api_key="placeholder",
-            model_capabilities={
-                "vision": True,
-                "function_calling": True,
-                "json_output": True,
+            model_info={
+                "vision": False,
+                "function_calling": False,
+                "json_output": False,
+                "family": ModelFamily.R1,
             },
         )
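A short hypothetical follow-up to the snippet above (not part of the documented example), showing how the parsed reasoning would surface on the result; it assumes the custom_model_client defined above and a locally running Ollama server:

import asyncio
from autogen_core.models import UserMessage

async def main() -> None:
    result = await custom_model_client.create(messages=[UserMessage(content="What is 2 + 2?", source="user")])
    print(result.thought)  # text extracted from the <think>...</think> block, if any
    print(result.content)  # remaining content after the think block

asyncio.run(main())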
@@ -25,6 +25,8 @@

from autogen_ext.tools.semantic_kernel import KernelFunctionFromTool

from .._utils.parse_r1_content import parse_r1_content


class SKChatCompletionAdapter(ChatCompletionClient):
"""
@@ -402,11 +404,17 @@ async def create(
content = result[0].content
finish_reason = "stop"

if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
thought, content = parse_r1_content(content)
else:
thought = None

return CreateResult(
content=content,
finish_reason=finish_reason,
usage=RequestUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
cached=False,
thought=thought,
)

async def create_stream(
@@ -485,11 +493,18 @@ async def create_stream(
if accumulated_content:
self._total_prompt_tokens += prompt_tokens
self._total_completion_tokens += completion_tokens

if isinstance(accumulated_content, str) and self._model_info["family"] == ModelFamily.R1:
thought, accumulated_content = parse_r1_content(accumulated_content)
else:
thought = None

yield CreateResult(
content=accumulated_content,
finish_reason="stop",
usage=RequestUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
cached=False,
thought=thought,
)

def actual_usage(self) -> RequestUsage:
@@ -5,7 +5,7 @@

import pytest
from autogen_core import CancellationToken, FunctionCall, Image
-from autogen_core.models import CreateResult, UserMessage
+from autogen_core.models import CreateResult, ModelFamily, UserMessage
from autogen_ext.models.azure import AzureAIChatCompletionClient
from azure.ai.inference.aio import (
ChatCompletionsClient,
@@ -295,3 +295,82 @@ async def _mock_create_noop(*args: Any, **kwargs: Any) -> ChatCompletions:
]
)
assert result.content == "Handled image"


@pytest.mark.asyncio
async def test_r1_content(monkeypatch: pytest.MonkeyPatch) -> None:
    """
    Ensures that the content is parsed correctly when it contains an R1-style think field.
    """

    async def _mock_create_r1_content_stream(
        *args: Any, **kwargs: Any
    ) -> AsyncGenerator[StreamingChatCompletionsUpdate, None]:
        mock_chunks_content = ["<think>Thought</think> Hello", " Another Hello", " Yet Another Hello"]

        mock_chunks = [
            StreamingChatChoiceUpdate(
                index=0,
                finish_reason="stop",
                delta=StreamingChatResponseMessageUpdate(role="assistant", content=chunk_content),
            )
            for chunk_content in mock_chunks_content
        ]

        for mock_chunk in mock_chunks:
            await asyncio.sleep(0.1)
            yield StreamingChatCompletionsUpdate(
                id="id",
                choices=[mock_chunk],
                created=datetime.now(),
                model="model",
                usage=CompletionsUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
            )

    async def _mock_create_r1_content(
        *args: Any, **kwargs: Any
    ) -> ChatCompletions | AsyncGenerator[StreamingChatCompletionsUpdate, None]:
        stream = kwargs.get("stream", False)

        if not stream:
            await asyncio.sleep(0.1)
            return ChatCompletions(
                id="id",
                created=datetime.now(),
                model="model",
                choices=[
                    ChatChoice(
                        index=0,
                        finish_reason="stop",
                        message=ChatResponseMessage(content="<think>Thought</think> Hello", role="assistant"),
                    )
                ],
                usage=CompletionsUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
            )
        else:
            return _mock_create_r1_content_stream(*args, **kwargs)

    monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create_r1_content)

    client = AzureAIChatCompletionClient(
        endpoint="endpoint",
        credential=AzureKeyCredential("api_key"),
        model_info={
            "json_output": False,
            "function_calling": False,
            "vision": True,
            "family": ModelFamily.R1,
        },
        model="model",
    )

    result = await client.create(messages=[UserMessage(content="Hello", source="user")])
    assert result.content == "Hello"
    assert result.thought == "Thought"

    chunks: List[str | CreateResult] = []
    async for chunk in client.create_stream(messages=[UserMessage(content="Hello", source="user")]):
        chunks.append(chunk)
    assert isinstance(chunks[-1], CreateResult)
    assert chunks[-1].content == "Hello Another Hello Yet Another Hello"
    assert chunks[-1].thought == "Thought"
