From 7784f44ea62e4b57c8efdbcf78a93e8877260b33 Mon Sep 17 00:00:00 2001
From: Eric Zhu
Date: Fri, 21 Feb 2025 14:58:32 -0700
Subject: [PATCH] feat: Add thought process handling in tool calls and expose
 ThoughtEvent through stream in AgentChat (#5500)

Resolves #5192

Test:

```python
import asyncio
import os
from random import randint
from typing import List

from autogen_core.tools import BaseTool, FunctionTool
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.ui import Console


async def get_current_time(city: str) -> str:
    return f"The current time in {city} is {randint(0, 23)}:{randint(0, 59)}."


tools: List[BaseTool] = [
    FunctionTool(
        get_current_time,
        name="get_current_time",
        description="Get current time for a city.",
    ),
]

model_client = OpenAIChatCompletionClient(
    model="anthropic/claude-3.5-haiku-20241022",
    base_url="https://openrouter.ai/api/v1",
    api_key=os.environ["OPENROUTER_API_KEY"],
    model_info={
        "family": "claude-3.5-haiku",
        "function_calling": True,
        "vision": False,
        "json_output": False,
    },
)

agent = AssistantAgent(
    name="Agent",
    model_client=model_client,
    tools=tools,
    system_message="You are an assistant with some tools that can be used to answer some questions",
)


async def main() -> None:
    await Console(agent.run_stream(task="What is current time of Paris and Toronto?"))


asyncio.run(main())
```

```
---------- user ----------
What is current time of Paris and Toronto?
---------- Agent ----------
I'll help you find the current time for Paris and Toronto by using the get_current_time function for each city.
---------- Agent ----------
[FunctionCall(id='toolu_01NwP3fNAwcYKn1x656Dq9xW', arguments='{"city": "Paris"}', name='get_current_time'), FunctionCall(id='toolu_018d4cWSy3TxXhjgmLYFrfRt', arguments='{"city": "Toronto"}', name='get_current_time')]
---------- Agent ----------
[FunctionExecutionResult(content='The current time in Paris is 1:10.', call_id='toolu_01NwP3fNAwcYKn1x656Dq9xW', is_error=False), FunctionExecutionResult(content='The current time in Toronto is 7:28.', call_id='toolu_018d4cWSy3TxXhjgmLYFrfRt', is_error=False)]
---------- Agent ----------
The current time in Paris is 1:10.
The current time in Toronto is 7:28.
```

---------

Co-authored-by: Jack Gerrits
---
 .../agents/_assistant_agent.py                |  19 +++-
 .../src/autogen_agentchat/messages.py         |  15 ++-
 .../tests/test_assistant_agent.py             |  49 +++++----
 .../src/autogen_core/models/_types.py         |   3 +
 .../models/openai/_openai_client.py           |  58 +++++-----
 .../tests/models/test_openai_model_client.py  | 100 ++++++++++++++++--
 6 files changed, 184 insertions(+), 60 deletions(-)

diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
index 3b109c19670c..f6a6ec61a298 100644
--- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
+++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
@@ -44,6 +44,7 @@
     MemoryQueryEvent,
     ModelClientStreamingChunkEvent,
     TextMessage,
+    ThoughtEvent,
     ToolCallExecutionEvent,
     ToolCallRequestEvent,
     ToolCallSummaryMessage,
@@ -418,7 +419,15 @@ async def on_messages_stream(
         )
         # Add the response to the model context.
-        await self._model_context.add_message(AssistantMessage(content=model_result.content, source=self.name))
+        await self._model_context.add_message(
+            AssistantMessage(content=model_result.content, source=self.name, thought=model_result.thought)
+        )
+
+        # Add thought to the inner messages.
+        if model_result.thought:
+            thought_event = ThoughtEvent(content=model_result.thought, source=self.name)
+            inner_messages.append(thought_event)
+            yield thought_event
 
         # Check if the response is a string and return it.
         if isinstance(model_result.content, str):
@@ -479,7 +488,9 @@ async def on_messages_stream(
             # Current context for handoff.
             handoff_context: List[LLMMessage] = []
             if len(tool_calls) > 0:
-                handoff_context.append(AssistantMessage(content=tool_calls, source=self.name))
+                handoff_context.append(
+                    AssistantMessage(content=tool_calls, source=self.name, thought=model_result.thought)
+                )
                 handoff_context.append(FunctionExecutionResultMessage(content=tool_call_results))
             # Return the output messages to signal the handoff.
             yield Response(
@@ -515,7 +526,9 @@ async def on_messages_stream(
                 assert isinstance(reflection_model_result.content, str)
             # Add the response to the model context.
             await self._model_context.add_message(
-                AssistantMessage(content=reflection_model_result.content, source=self.name)
+                AssistantMessage(
+                    content=reflection_model_result.content, source=self.name, thought=reflection_model_result.thought
+                )
             )
             # Yield the response.
             yield Response(
diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py
index 17249e674854..a4a962b62204 100644
--- a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py
+++ b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py
@@ -137,6 +137,17 @@ class ModelClientStreamingChunkEvent(BaseAgentEvent):
     type: Literal["ModelClientStreamingChunkEvent"] = "ModelClientStreamingChunkEvent"
 
 
+class ThoughtEvent(BaseAgentEvent):
+    """An event signaling the thought process of an agent.
+    It is used to communicate the reasoning tokens generated by a reasoning model,
+    or the extra text content generated alongside a function call."""
+
+    content: str
+    """The thought process."""
+
+    type: Literal["ThoughtEvent"] = "ThoughtEvent"
+
+
 ChatMessage = Annotated[
     TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type")
 ]
@@ -148,7 +159,8 @@ class ModelClientStreamingChunkEvent(BaseAgentEvent):
     | ToolCallExecutionEvent
     | MemoryQueryEvent
     | UserInputRequestedEvent
-    | ModelClientStreamingChunkEvent,
+    | ModelClientStreamingChunkEvent
+    | ThoughtEvent,
     Field(discriminator="type"),
 ]
 """Events emitted by agents and teams when they work, not used for agent-to-agent communication."""
@@ -168,4 +180,5 @@
     "MemoryQueryEvent",
     "UserInputRequestedEvent",
     "ModelClientStreamingChunkEvent",
+    "ThoughtEvent",
 ]
diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py
index 827a547660e4..e420d23fb05d 100644
--- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py
+++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py
@@ -17,6 +17,7 @@
     ToolCallExecutionEvent,
     ToolCallRequestEvent,
     ToolCallSummaryMessage,
+    ThoughtEvent,
 )
 from autogen_core import ComponentModel, FunctionCall, Image
 from autogen_core.memory import ListMemory, Memory, MemoryContent, MemoryMimeType, MemoryQueryResult
@@ -89,7 +90,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
             finish_reason="tool_calls",
             index=0,
             message=ChatCompletionMessage(
-                content=None,
+                content="Calling pass function",
                 tool_calls=[
                     ChatCompletionMessageToolCall(
                         id="1",
@@ -151,18 +152,20 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
     )
     result = await agent.run(task="task")
-    assert len(result.messages) == 4
+    assert len(result.messages) == 5
     assert isinstance(result.messages[0], TextMessage)
     assert result.messages[0].models_usage is None
-    assert isinstance(result.messages[1], ToolCallRequestEvent)
-    assert result.messages[1].models_usage is not None
-    assert result.messages[1].models_usage.completion_tokens == 5
-    assert result.messages[1].models_usage.prompt_tokens == 10
-    assert isinstance(result.messages[2], ToolCallExecutionEvent)
-    assert result.messages[2].models_usage is None
-    assert isinstance(result.messages[3], ToolCallSummaryMessage)
-    assert result.messages[3].content == "pass"
+    assert isinstance(result.messages[1], ThoughtEvent)
+    assert result.messages[1].content == "Calling pass function"
+    assert isinstance(result.messages[2], ToolCallRequestEvent)
+    assert result.messages[2].models_usage is not None
+    assert result.messages[2].models_usage.completion_tokens == 5
+    assert result.messages[2].models_usage.prompt_tokens == 10
+    assert isinstance(result.messages[3], ToolCallExecutionEvent)
     assert result.messages[3].models_usage is None
+    assert isinstance(result.messages[4], ToolCallSummaryMessage)
+    assert result.messages[4].content == "pass"
+    assert result.messages[4].models_usage is None
 
     # Test streaming.
     mock.curr_index = 0  # Reset the mock
@@ -302,7 +305,7 @@ async def test_run_with_parallel_tools(monkeypatch: pytest.MonkeyPatch) -> None:
             finish_reason="tool_calls",
             index=0,
             message=ChatCompletionMessage(
-                content=None,
+                content="Calling pass and echo functions",
                 tool_calls=[
                     ChatCompletionMessageToolCall(
                         id="1",
@@ -380,30 +383,32 @@ async def test_run_with_parallel_tools(monkeypatch: pytest.MonkeyPatch) -> None:
     )
     result = await agent.run(task="task")
-    assert len(result.messages) == 4
+    assert len(result.messages) == 5
     assert isinstance(result.messages[0], TextMessage)
     assert result.messages[0].models_usage is None
-    assert isinstance(result.messages[1], ToolCallRequestEvent)
-    assert result.messages[1].content == [
+    assert isinstance(result.messages[1], ThoughtEvent)
+    assert result.messages[1].content == "Calling pass and echo functions"
+    assert isinstance(result.messages[2], ToolCallRequestEvent)
+    assert result.messages[2].content == [
         FunctionCall(id="1", arguments=r'{"input": "task1"}', name="_pass_function"),
         FunctionCall(id="2", arguments=r'{"input": "task2"}', name="_pass_function"),
         FunctionCall(id="3", arguments=r'{"input": "task3"}', name="_echo_function"),
     ]
-    assert result.messages[1].models_usage is not None
-    assert result.messages[1].models_usage.completion_tokens == 5
-    assert result.messages[1].models_usage.prompt_tokens == 10
-    assert isinstance(result.messages[2], ToolCallExecutionEvent)
+    assert result.messages[2].models_usage is not None
+    assert result.messages[2].models_usage.completion_tokens == 5
+    assert result.messages[2].models_usage.prompt_tokens == 10
+    assert isinstance(result.messages[3], ToolCallExecutionEvent)
     expected_content = [
         FunctionExecutionResult(call_id="1", content="pass", is_error=False),
         FunctionExecutionResult(call_id="2", content="pass", is_error=False),
         FunctionExecutionResult(call_id="3", content="task3", is_error=False),
     ]
     for expected in expected_content:
-        assert expected in result.messages[2].content
-    assert result.messages[2].models_usage is None
-    assert isinstance(result.messages[3], ToolCallSummaryMessage)
-    assert result.messages[3].content == "pass\npass\ntask3"
+        assert expected in result.messages[3].content
     assert result.messages[3].models_usage is None
+    assert isinstance(result.messages[4], ToolCallSummaryMessage)
+    assert result.messages[4].content == "pass\npass\ntask3"
+    assert result.messages[4].models_usage is None
 
     # Test streaming.
     mock.curr_index = 0  # Reset the mock
diff --git a/python/packages/autogen-core/src/autogen_core/models/_types.py b/python/packages/autogen-core/src/autogen_core/models/_types.py
index 76d87ab8d46d..39ffcd64af9b 100644
--- a/python/packages/autogen-core/src/autogen_core/models/_types.py
+++ b/python/packages/autogen-core/src/autogen_core/models/_types.py
@@ -44,6 +44,9 @@ class AssistantMessage(BaseModel):
     content: Union[str, List[FunctionCall]]
     """The content of the message."""
 
+    thought: str | None = None
+    """The reasoning text for the completion if available.
+    Used for reasoning models and for additional text content besides function calls."""
+
     source: str
     """The name of the agent that sent this message."""
diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
index 380b79bd4be5..ad2017ec4053 100644
--- a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
+++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
@@ -208,11 +208,19 @@ def assistant_message_to_oai(
 ) -> ChatCompletionAssistantMessageParam:
     assert_valid_name(message.source)
     if isinstance(message.content, list):
-        return ChatCompletionAssistantMessageParam(
-            tool_calls=[func_call_to_oai(x) for x in message.content],
-            role="assistant",
-            name=message.source,
-        )
+        if message.thought is not None:
+            return ChatCompletionAssistantMessageParam(
+                content=message.thought,
+                tool_calls=[func_call_to_oai(x) for x in message.content],
+                role="assistant",
+                name=message.source,
+            )
+        else:
+            return ChatCompletionAssistantMessageParam(
+                tool_calls=[func_call_to_oai(x) for x in message.content],
+                role="assistant",
+                name=message.source,
+            )
     else:
         return ChatCompletionAssistantMessageParam(
             content=message.content,
@@ -572,6 +580,7 @@ async def create(
         # Detect whether it is a function call or not.
         # We don't rely on choice.finish_reason as it is not always accurate, depending on the API used.
         content: Union[str, List[FunctionCall]]
+        thought: str | None = None
         if choice.message.function_call is not None:
             raise ValueError("function_call is deprecated and is not supported by this model client.")
         elif choice.message.tool_calls is not None and len(choice.message.tool_calls) > 0:
@@ -583,11 +592,8 @@ async def create(
                 stacklevel=2,
             )
             if choice.message.content is not None and choice.message.content != "":
-                warnings.warn(
-                    "Both tool_calls and content are present in the message. "
-                    "This is unexpected. content will be ignored, tool_calls will be used.",
-                    stacklevel=2,
-                )
+                # Put the content in the thought field.
+                thought = choice.message.content
             # NOTE: If OAI response type changes, this will need to be updated
             content = []
             for tool_call in choice.message.tool_calls:
@@ -626,8 +632,6 @@ async def create(
 
         if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
             thought, content = parse_r1_content(content)
-        else:
-            thought = None
 
         response = CreateResult(
             finish_reason=normalize_stop_reason(finish_reason),
@@ -788,6 +792,8 @@ async def create_stream(
                     content_deltas.append(choice.delta.content)
                     if len(choice.delta.content) > 0:
                         yield choice.delta.content
+                    # NOTE: for OpenAI, tool_calls and content appear to be mutually exclusive, so we can skip the rest of the loop.
+                    # However, this may not be the case for other APIs, so this may need to be updated.
                     continue
 
                 # Otherwise, get tool calls
@@ -832,20 +838,24 @@ async def create_stream(
             raise ValueError("Function calls are not supported in this context")
 
         content: Union[str, List[FunctionCall]]
-        if len(content_deltas) > 1:
+        thought: str | None = None
+        if full_tool_calls:
+            # This is a tool call.
+            content = list(full_tool_calls.values())
+            if len(content_deltas) > 1:
+                # Put additional text content in the thought field.
+                thought = "".join(content_deltas)
+        elif len(content_deltas) > 0:
+            # This is text-only content.
content = "".join(content_deltas) - if chunk and chunk.usage: - completion_tokens = chunk.usage.completion_tokens - else: - completion_tokens = 0 + else: + warnings.warn("No text content or tool calls are available. Model returned empty result.", stacklevel=2) + content = "" + + if chunk and chunk.usage: + completion_tokens = chunk.usage.completion_tokens else: completion_tokens = 0 - # TODO: fix assumption that dict values were added in order and actually order by int index - # for tool_call in full_tool_calls.values(): - # # value = json.dumps(tool_call) - # # completion_tokens += count_token(value, model=model) - # completion_tokens += 0 - content = list(full_tool_calls.values()) usage = RequestUsage( prompt_tokens=prompt_tokens, @@ -854,8 +864,6 @@ async def create_stream( if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1: thought, content = parse_r1_content(content) - else: - thought = None result = CreateResult( finish_reason=normalize_stop_reason(stop_reason), diff --git a/python/packages/autogen-ext/tests/models/test_openai_model_client.py b/python/packages/autogen-ext/tests/models/test_openai_model_client.py index ea5b28e28a63..45655e653379 100644 --- a/python/packages/autogen-ext/tests/models/test_openai_model_client.py +++ b/python/packages/autogen-ext/tests/models/test_openai_model_client.py @@ -26,7 +26,12 @@ from openai.resources.beta.chat.completions import AsyncCompletions as BetaAsyncCompletions from openai.resources.chat.completions import AsyncCompletions from openai.types.chat.chat_completion import ChatCompletion, Choice -from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta +from openai.types.chat.chat_completion_chunk import ( + ChatCompletionChunk, + ChoiceDelta, + ChoiceDeltaToolCall, + ChoiceDeltaToolCallFunction, +) from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice from openai.types.chat.chat_completion_message import ChatCompletionMessage from openai.types.chat.chat_completion_message_tool_call import ( @@ -734,7 +739,7 @@ async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None: object="chat.completion", usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0), ), - # Warning completion when content is not None. + # Thought field is populated when content is not None. ChatCompletion( id="id4", choices=[ @@ -850,13 +855,11 @@ async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None: assert create_result.content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")] assert create_result.finish_reason == "function_calls" - # Warning completion when content is not None. - with pytest.warns(UserWarning, match="Both tool_calls and content are present in the message"): - create_result = await model_client.create( - messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool] - ) - assert create_result.content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")] - assert create_result.finish_reason == "function_calls" + # Thought field is populated when content is not None. + create_result = await model_client.create(messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool]) + assert create_result.content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")] + assert create_result.finish_reason == "function_calls" + assert create_result.thought == "I should make a tool call." 
     # Should not be returning tool calls when the tool_calls are empty
     create_result = await model_client.create(messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool])
@@ -872,6 +875,85 @@ async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None:
     assert create_result.finish_reason == "function_calls"
 
 
+@pytest.mark.asyncio
+async def test_tool_calling_with_stream(monkeypatch: pytest.MonkeyPatch) -> None:
+    async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatCompletionChunk, None]:
+        model = resolve_model(kwargs.get("model", "gpt-4o"))
+        mock_chunks_content = ["Hello", " Another Hello", " Yet Another Hello"]
+        mock_chunks = [
+            # generate the list of mock chunk content
+            MockChunkDefinition(
+                chunk_choice=ChunkChoice(
+                    finish_reason=None,
+                    index=0,
+                    delta=ChoiceDelta(
+                        content=mock_chunk_content,
+                        role="assistant",
+                    ),
+                ),
+                usage=None,
+            )
+            for mock_chunk_content in mock_chunks_content
+        ] + [
+            # generate the function call chunk
+            MockChunkDefinition(
+                chunk_choice=ChunkChoice(
+                    finish_reason="tool_calls",
+                    index=0,
+                    delta=ChoiceDelta(
+                        content=None,
+                        role="assistant",
+                        tool_calls=[
+                            ChoiceDeltaToolCall(
+                                index=0,
+                                id="1",
+                                type="function",
+                                function=ChoiceDeltaToolCallFunction(
+                                    name="_pass_function",
+                                    arguments=json.dumps({"input": "task"}),
+                                ),
+                            )
+                        ],
+                    ),
+                ),
+                usage=None,
+            )
+        ]
+        for mock_chunk in mock_chunks:
+            await asyncio.sleep(0.1)
+            yield ChatCompletionChunk(
+                id="id",
+                choices=[mock_chunk.chunk_choice],
+                created=0,
+                model=model,
+                object="chat.completion.chunk",
+                usage=mock_chunk.usage,
+            )
+
+    async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
+        stream = kwargs.get("stream", False)
+        if not stream:
+            raise ValueError("Stream is not True")
+        else:
+            return _mock_create_stream(*args, **kwargs)
+
+    monkeypatch.setattr(AsyncCompletions, "create", _mock_create)
+
+    model_client = OpenAIChatCompletionClient(model="gpt-4o", api_key="")
+    pass_tool = FunctionTool(_pass_function, description="pass tool.")
+    stream = model_client.create_stream(messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool])
+    chunks: List[str | CreateResult] = []
+    async for chunk in stream:
+        chunks.append(chunk)
+    assert chunks[0] == "Hello"
+    assert chunks[1] == " Another Hello"
+    assert chunks[2] == " Yet Another Hello"
+    assert isinstance(chunks[-1], CreateResult)
+    assert chunks[-1].content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")]
+    assert chunks[-1].finish_reason == "function_calls"
+    assert chunks[-1].thought == "Hello Another Hello Yet Another Hello"
+
+
 async def _test_model_client_basic_completion(model_client: OpenAIChatCompletionClient) -> None:
     # Test basic completion
     create_result = await model_client.create(
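For application code, here is a minimal sketch (not part of this patch) of consuming the new `ThoughtEvent` from an AgentChat stream. The model name, environment-based API key, and task string are illustrative only; any model client whose completions carry text content alongside tool calls, or reasoning tokens, will populate `thought` and hence produce these events:

```python
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.messages import ThoughtEvent
from autogen_ext.models.openai import OpenAIChatCompletionClient


async def main() -> None:
    # Illustrative client; reads OPENAI_API_KEY from the environment.
    model_client = OpenAIChatCompletionClient(model="gpt-4o")
    agent = AssistantAgent(name="Agent", model_client=model_client)

    # ThoughtEvent is now part of the AgentEvent union, so it arrives in the
    # stream alongside the other agent events and the final TaskResult.
    async for event in agent.run_stream(task="Say hello."):
        if isinstance(event, ThoughtEvent):
            print(f"[thought] {event.content}")
        else:
            print(event)


asyncio.run(main())
```

At the model-client level the same information is exposed as `CreateResult.thought`, which is what `test_tool_calling_with_stream` above asserts on.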