Commit ff9e0b7
Merge branch 'dotnet_unit_AgentId' of github.com:microsoft/autogen into dotnet_unit_AgentId
2 parents: d7df79a + f4adce3
File tree: 12 files changed, +536 -9 lines

python/packages/autogen-core/src/autogen_core/models/_model_client.py (+2 -1)

@@ -22,9 +22,10 @@ class ModelFamily:
     O1 = "o1"
     GPT_4 = "gpt-4"
     GPT_35 = "gpt-35"
+    R1 = "r1"
     UNKNOWN = "unknown"
 
-    ANY: TypeAlias = Literal["gpt-4o", "o1", "gpt-4", "gpt-35", "unknown"]
+    ANY: TypeAlias = Literal["gpt-4o", "o1", "gpt-4", "gpt-35", "r1", "unknown"]
 
     def __new__(cls, *args: Any, **kwargs: Any) -> ModelFamily:
         raise TypeError(f"{cls.__name__} is a namespace class and cannot be instantiated.")
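
In short: the clients touched by this commit key off the new constant to decide whether to split reasoning text out of a completion. A minimal illustrative sketch:

    from autogen_core.models import ModelFamily

    # ModelFamily is a namespace of string constants; it cannot be instantiated.
    model_family = ModelFamily.R1
    if model_family == ModelFamily.R1:
        # R1-family clients strip <think>...</think> reasoning out of the content.
        print("R1-style thought parsing applies")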

python/packages/autogen-core/src/autogen_core/models/_types.py (+35 -0)

@@ -8,11 +8,25 @@
 
 
 class SystemMessage(BaseModel):
+    """System message contains instructions for the model coming from the developer.
+
+    .. note::
+
+        OpenAI is moving away from using the 'system' role in favor of the 'developer' role.
+        See the `Model Spec <https://cdn.openai.com/spec/model-spec-2024-05-08.html#definitions>`_ for more details.
+        However, the 'system' role is still allowed in their API and will be automatically converted to the 'developer' role
+        on the server side.
+        So, you can use `SystemMessage` for developer messages.
+
+    """
+
     content: str
     type: Literal["SystemMessage"] = "SystemMessage"
 
 
 class UserMessage(BaseModel):
+    """User message contains input from end users, or a catch-all for data provided to the model."""
+
     content: Union[str, List[Union[str, Image]]]
 
     # Name of the agent that sent this message
@@ -22,6 +36,8 @@ class UserMessage(BaseModel):
 
 
 class AssistantMessage(BaseModel):
+    """Assistant messages are sampled from the language model."""
+
     content: Union[str, List[FunctionCall]]
 
     # Name of the agent that sent this message
@@ -31,11 +47,15 @@ class AssistantMessage(BaseModel):
 
 
 class FunctionExecutionResult(BaseModel):
+    """Function execution result contains the output of a function call."""
+
     content: str
     call_id: str
 
 
 class FunctionExecutionResultMessage(BaseModel):
+    """Function execution result message contains the output of multiple function calls."""
+
     content: List[FunctionExecutionResult]
 
     type: Literal["FunctionExecutionResultMessage"] = "FunctionExecutionResultMessage"
@@ -69,8 +89,23 @@ class ChatCompletionTokenLogprob(BaseModel):
 
 
 class CreateResult(BaseModel):
+    """Create result contains the output of a model completion."""
+
     finish_reason: FinishReasons
+    """The reason the model finished generating the completion."""
+
     content: Union[str, List[FunctionCall]]
+    """The output of the model completion."""
+
     usage: RequestUsage
+    """The usage of tokens in the prompt and completion."""
+
     cached: bool
+    """Whether the completion was generated from a cached response."""
+
     logprobs: Optional[List[ChatCompletionTokenLogprob] | None] = None
+    """The logprobs of the tokens in the completion."""
+
+    thought: Optional[str] = None
+    """The reasoning text for the completion if available. Used for reasoning models
+    and additional text content besides function calls."""
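
A hedged sketch of consuming the new `thought` field; `describe` is a hypothetical helper, not part of this commit:

    from autogen_core.models import CreateResult


    def describe(result: CreateResult) -> str:
        # `thought` stays None unless the client extracted reasoning text,
        # e.g. an R1-style <think>...</think> block.
        if result.thought is not None:
            return f"[thought] {result.thought}\n[content] {result.content}"
        return str(result.content)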

python/packages/autogen-ext/pyproject.toml (+1 -0)

@@ -120,6 +120,7 @@ dev = [
     "autogen_test_utils",
     "langchain-experimental",
     "pandas-stubs>=2.2.3.241126",
+    "httpx>=0.28.1",
 ]
 
 [tool.ruff]
python/packages/autogen-ext/src/autogen_ext/models/_utils/parse_r1_content.py (new file, +33 -0; path inferred from the relative imports in the files below)

@@ -0,0 +1,33 @@
+import warnings
+from typing import Tuple
+
+
+def parse_r1_content(content: str) -> Tuple[str | None, str]:
+    """Parse the content of an R1-style message that contains a `<think>...</think>` field."""
+    # Find the start and end of the think field
+    think_start = content.find("<think>")
+    think_end = content.find("</think>")
+
+    if think_start == -1 or think_end == -1:
+        warnings.warn(
+            "Could not find <think>..</think> field in model response content. No thought was extracted.",
+            UserWarning,
+            stacklevel=2,
+        )
+        return None, content
+
+    if think_end < think_start:
+        warnings.warn(
+            "Found </think> before <think> in model response content. No thought was extracted.",
+            UserWarning,
+            stacklevel=2,
+        )
+        return None, content
+
+    # Extract the think field
+    thought = content[think_start + len("<think>") : think_end].strip()
+
+    # Extract the rest of the content, skipping the think field.
+    content = content[think_end + len("</think>") :].strip()
+
+    return thought, content
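
How the helper behaves on typical inputs, as an illustrative check (the import path is inferred from the relative imports in the files below):

    from autogen_ext.models._utils.parse_r1_content import parse_r1_content

    # Well-formed R1 response: the thought is split out and both parts are stripped.
    thought, content = parse_r1_content("<think>Check the units first.</think> The answer is 42.")
    assert thought == "Check the units first."
    assert content == "The answer is 42."

    # No <think> block: a UserWarning is emitted and the content passes through unchanged.
    thought, content = parse_r1_content("No reasoning here.")
    assert thought is None
    assert content == "No reasoning here."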

python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py (+15 -0)

@@ -12,6 +12,7 @@
     FinishReasons,
     FunctionExecutionResultMessage,
     LLMMessage,
+    ModelFamily,
     ModelInfo,
     RequestUsage,
     SystemMessage,
@@ -55,6 +56,8 @@
     AzureAIChatCompletionClientConfig,
 )
 
+from .._utils.parse_r1_content import parse_r1_content
+
 create_kwargs = set(getfullargspec(ChatCompletionsClient.complete).kwonlyargs)
 AzureMessage = Union[AzureSystemMessage, AzureUserMessage, AzureAssistantMessage, AzureToolMessage]
 
@@ -354,11 +357,17 @@ async def create(
         finish_reason = choice.finish_reason  # type: ignore
         content = choice.message.content or ""
 
+        if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
+            thought, content = parse_r1_content(content)
+        else:
+            thought = None
+
         response = CreateResult(
             finish_reason=finish_reason,  # type: ignore
             content=content,
             usage=usage,
             cached=False,
+            thought=thought,
         )
 
         self.add_usage(usage)
@@ -464,11 +473,17 @@ async def create_stream(
             prompt_tokens=prompt_tokens,
         )
 
+        if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
+            thought, content = parse_r1_content(content)
+        else:
+            thought = None
+
         result = CreateResult(
             finish_reason=finish_reason,
             content=content,
             usage=usage,
             cached=False,
+            thought=thought,
         )
 
         self.add_usage(usage)
python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py (+24 -7)

@@ -72,6 +72,7 @@
 from pydantic import BaseModel
 from typing_extensions import Self, Unpack
 
+from .._utils.parse_r1_content import parse_r1_content
 from . import _model_info
 from .config import (
     AzureOpenAIClientConfiguration,
@@ -605,12 +606,19 @@ async def create(
                )
                for x in choice.logprobs.content
            ]
+
+        if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
+            thought, content = parse_r1_content(content)
+        else:
+            thought = None
+
         response = CreateResult(
             finish_reason=normalize_stop_reason(finish_reason),
             content=content,
             usage=usage,
             cached=False,
             logprobs=logprobs,
+            thought=thought,
         )
 
         self._total_usage = _add_usage(self._total_usage, usage)
@@ -818,12 +826,18 @@ async def create_stream(
             completion_tokens=completion_tokens,
         )
 
+        if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
+            thought, content = parse_r1_content(content)
+        else:
+            thought = None
+
         result = CreateResult(
             finish_reason=normalize_stop_reason(stop_reason),
             content=content,
             usage=usage,
             cached=False,
             logprobs=logprobs,
+            thought=thought,
         )
 
         self._total_usage = _add_usage(self._total_usage, usage)
@@ -992,20 +1006,23 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA
         print(result)
 
-    To use the client with a non-OpenAI model, you need to provide the base URL of the model and the model capabilities:
+    To use the client with a non-OpenAI model, you need to provide the base URL of the model and the model info.
+    For example, to use Ollama, you can use the following code snippet:
 
     .. code-block:: python
 
         from autogen_ext.models.openai import OpenAIChatCompletionClient
+        from autogen_core.models import ModelFamily
 
         custom_model_client = OpenAIChatCompletionClient(
-            model="custom-model-name",
-            base_url="https://custom-model.com/reset/of/the/path",
+            model="deepseek-r1:1.5b",
+            base_url="http://localhost:11434/v1",
             api_key="placeholder",
-            model_capabilities={
-                "vision": True,
-                "function_calling": True,
-                "json_output": True,
+            model_info={
+                "vision": False,
+                "function_calling": False,
+                "json_output": False,
+                "family": ModelFamily.R1,
             },
         )
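
Putting it together, a hedged end-to-end sketch based on the docstring example above; it assumes a local Ollama server on the default port with deepseek-r1:1.5b pulled:

    import asyncio

    from autogen_core.models import ModelFamily, UserMessage
    from autogen_ext.models.openai import OpenAIChatCompletionClient


    async def main() -> None:
        client = OpenAIChatCompletionClient(
            model="deepseek-r1:1.5b",
            base_url="http://localhost:11434/v1",
            api_key="placeholder",
            model_info={
                "vision": False,
                "function_calling": False,
                "json_output": False,
                "family": ModelFamily.R1,
            },
        )
        result = await client.create([UserMessage(content="What is 2 + 2?", source="user")])
        # With family=R1, the client splits the <think>...</think> block out of
        # `content` and surfaces it on `thought`.
        print("thought:", result.thought)
        print("content:", result.content)


    asyncio.run(main())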

python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py (+15 -0)

@@ -25,6 +25,8 @@
 
 from autogen_ext.tools.semantic_kernel import KernelFunctionFromTool
 
+from .._utils.parse_r1_content import parse_r1_content
+
 
 class SKChatCompletionAdapter(ChatCompletionClient):
     """
@@ -402,11 +404,17 @@ async def create(
         content = result[0].content
         finish_reason = "stop"
 
+        if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
+            thought, content = parse_r1_content(content)
+        else:
+            thought = None
+
         return CreateResult(
             content=content,
             finish_reason=finish_reason,
             usage=RequestUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
             cached=False,
+            thought=thought,
         )
 
     async def create_stream(
@@ -485,11 +493,18 @@ async def create_stream(
         if accumulated_content:
             self._total_prompt_tokens += prompt_tokens
             self._total_completion_tokens += completion_tokens
+
+            if isinstance(accumulated_content, str) and self._model_info["family"] == ModelFamily.R1:
+                thought, accumulated_content = parse_r1_content(accumulated_content)
+            else:
+                thought = None
+
             yield CreateResult(
                 content=accumulated_content,
                 finish_reason="stop",
                 usage=RequestUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
                 cached=False,
+                thought=thought,
             )
 
     def actual_usage(self) -> RequestUsage:

python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py (+80 -1)

@@ -5,7 +5,7 @@
 
 import pytest
 from autogen_core import CancellationToken, FunctionCall, Image
-from autogen_core.models import CreateResult, UserMessage
+from autogen_core.models import CreateResult, ModelFamily, UserMessage
 from autogen_ext.models.azure import AzureAIChatCompletionClient
 from azure.ai.inference.aio import (
     ChatCompletionsClient,
@@ -295,3 +295,82 @@ async def _mock_create_noop(*args: Any, **kwargs: Any) -> ChatCompletions:
         ]
     )
     assert result.content == "Handled image"
+
+
+@pytest.mark.asyncio
+async def test_r1_content(monkeypatch: pytest.MonkeyPatch) -> None:
+    """
+    Ensures that the content is parsed correctly when it contains an R1-style think field.
+    """
+
+    async def _mock_create_r1_content_stream(
+        *args: Any, **kwargs: Any
+    ) -> AsyncGenerator[StreamingChatCompletionsUpdate, None]:
+        mock_chunks_content = ["<think>Thought</think> Hello", " Another Hello", " Yet Another Hello"]
+
+        mock_chunks = [
+            StreamingChatChoiceUpdate(
+                index=0,
+                finish_reason="stop",
+                delta=StreamingChatResponseMessageUpdate(role="assistant", content=chunk_content),
+            )
+            for chunk_content in mock_chunks_content
+        ]
+
+        for mock_chunk in mock_chunks:
+            await asyncio.sleep(0.1)
+            yield StreamingChatCompletionsUpdate(
+                id="id",
+                choices=[mock_chunk],
+                created=datetime.now(),
+                model="model",
+                usage=CompletionsUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
+            )
+
+    async def _mock_create_r1_content(
+        *args: Any, **kwargs: Any
+    ) -> ChatCompletions | AsyncGenerator[StreamingChatCompletionsUpdate, None]:
+        stream = kwargs.get("stream", False)
+
+        if not stream:
+            await asyncio.sleep(0.1)
+            return ChatCompletions(
+                id="id",
+                created=datetime.now(),
+                model="model",
+                choices=[
+                    ChatChoice(
+                        index=0,
+                        finish_reason="stop",
+                        message=ChatResponseMessage(content="<think>Thought</think> Hello", role="assistant"),
+                    )
+                ],
+                usage=CompletionsUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
+            )
+        else:
+            return _mock_create_r1_content_stream(*args, **kwargs)
+
+    monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create_r1_content)
+
+    client = AzureAIChatCompletionClient(
+        endpoint="endpoint",
+        credential=AzureKeyCredential("api_key"),
+        model_info={
+            "json_output": False,
+            "function_calling": False,
+            "vision": True,
+            "family": ModelFamily.R1,
+        },
+        model="model",
+    )
+
+    result = await client.create(messages=[UserMessage(content="Hello", source="user")])
+    assert result.content == "Hello"
+    assert result.thought == "Thought"
+
+    chunks: List[str | CreateResult] = []
+    async for chunk in client.create_stream(messages=[UserMessage(content="Hello", source="user")]):
+        chunks.append(chunk)
+    assert isinstance(chunks[-1], CreateResult)
+    assert chunks[-1].content == "Hello Another Hello Yet Another Hello"
+    assert chunks[-1].thought == "Thought"