diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
index 2c71a2558df0..85825e47beb5 100644
--- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
+++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
@@ -24,6 +24,7 @@
     ChatCompletionClient,
     FunctionExecutionResult,
     FunctionExecutionResultMessage,
+    LLMMessage,
     SystemMessage,
     UserMessage,
 )
@@ -105,6 +106,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
     * If a handoff is triggered, a :class:`~autogen_agentchat.messages.HandoffMessage` will be returned in :attr:`~autogen_agentchat.base.Response.chat_message`.
     * If there are tool calls, they will also be executed right away before returning the handoff.
+    * The tool calls and results are passed to the target agent through :attr:`~autogen_agentchat.messages.HandoffMessage.context`.

     .. note::

@@ -353,6 +355,10 @@ async def on_messages_stream(
         for msg in messages:
             if isinstance(msg, MultiModalMessage) and self._model_client.model_info["vision"] is False:
                 raise ValueError("The model does not support vision.")
+            if isinstance(msg, HandoffMessage):
+                # Add handoff context to the model context.
+                for context_msg in msg.context:
+                    await self._model_context.add_message(context_msg)
             await self._model_context.add_message(UserMessage(content=msg.content, source=msg.source))

         # Inner messages.
@@ -371,52 +377,81 @@

         # Generate an inference result based on the current model context.
         llm_messages = self._system_messages + await self._model_context.get_messages()
-        result = await self._model_client.create(
+        model_result = await self._model_client.create(
             llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
         )

         # Add the response to the model context.
-        await self._model_context.add_message(AssistantMessage(content=result.content, source=self.name))
+        await self._model_context.add_message(AssistantMessage(content=model_result.content, source=self.name))

         # Check if the response is a string and return it.
-        if isinstance(result.content, str):
+        if isinstance(model_result.content, str):
             yield Response(
-                chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage),
+                chat_message=TextMessage(
+                    content=model_result.content, source=self.name, models_usage=model_result.usage
+                ),
                 inner_messages=inner_messages,
             )
             return

         # Process tool calls.
-        assert isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content)
-        tool_call_msg = ToolCallRequestEvent(content=result.content, source=self.name, models_usage=result.usage)
+        assert isinstance(model_result.content, list) and all(
+            isinstance(item, FunctionCall) for item in model_result.content
+        )
+        tool_call_msg = ToolCallRequestEvent(
+            content=model_result.content, source=self.name, models_usage=model_result.usage
+        )
         event_logger.debug(tool_call_msg)
         # Add the tool call message to the output.
         inner_messages.append(tool_call_msg)
         yield tool_call_msg

         # Execute the tool calls.
-        results = await asyncio.gather(*[self._execute_tool_call(call, cancellation_token) for call in result.content])
-        tool_call_result_msg = ToolCallExecutionEvent(content=results, source=self.name)
+        exec_results = await asyncio.gather(
+            *[self._execute_tool_call(call, cancellation_token) for call in model_result.content]
+        )
+        tool_call_result_msg = ToolCallExecutionEvent(content=exec_results, source=self.name)
         event_logger.debug(tool_call_result_msg)
-        await self._model_context.add_message(FunctionExecutionResultMessage(content=results))
+        await self._model_context.add_message(FunctionExecutionResultMessage(content=exec_results))
         inner_messages.append(tool_call_result_msg)
         yield tool_call_result_msg

+        # Correlate tool call results with tool calls.
+        tool_calls = [call for call in model_result.content if call.name not in self._handoffs]
+        tool_call_results: List[FunctionExecutionResult] = []
+        for tool_call in tool_calls:
+            found = False
+            for exec_result in exec_results:
+                if exec_result.call_id == tool_call.id:
+                    found = True
+                    tool_call_results.append(exec_result)
+                    break
+            if not found:
+                raise RuntimeError(f"Tool call result not found for call id: {tool_call.id}")
+
         # Detect handoff requests.
-        handoffs: List[HandoffBase] = []
-        for call in result.content:
-            if call.name in self._handoffs:
-                handoffs.append(self._handoffs[call.name])
-        if len(handoffs) > 0:
+        handoff_reqs = [call for call in model_result.content if call.name in self._handoffs]
+        if len(handoff_reqs) > 0:
+            handoffs = [self._handoffs[call.name] for call in handoff_reqs]
             if len(handoffs) > 1:
                 # show warning if multiple handoffs detected
                 warnings.warn(
-                    f"Multiple handoffs detected only the first is executed: {[handoff.name for handoff in handoffs]}",
+                    (
+                        f"Multiple handoffs detected, only the first is executed: {[handoff.name for handoff in handoffs]}. "
+                        "Disable parallel tool calls in the model client to avoid this warning."
+                    ),
                     stacklevel=2,
                 )
+            # Current context for handoff.
+            handoff_context: List[LLMMessage] = []
+            if len(tool_calls) > 0:
+                handoff_context.append(AssistantMessage(content=tool_calls, source=self.name))
+                handoff_context.append(FunctionExecutionResultMessage(content=tool_call_results))
             # Return the output messages to signal the handoff.
             yield Response(
-                chat_message=HandoffMessage(content=handoffs[0].message, target=handoffs[0].target, source=self.name),
+                chat_message=HandoffMessage(
+                    content=handoffs[0].message, target=handoffs[0].target, source=self.name, context=handoff_context
+                ),
                 inner_messages=inner_messages,
             )
             return
@@ -424,24 +459,26 @@ async def on_messages_stream(
         if self._reflect_on_tool_use:
             # Generate another inference result based on the tool call and result.
             llm_messages = self._system_messages + await self._model_context.get_messages()
-            result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
-            assert isinstance(result.content, str)
+            model_result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
+            assert isinstance(model_result.content, str)
             # Add the response to the model context.
-            await self._model_context.add_message(AssistantMessage(content=result.content, source=self.name))
+            await self._model_context.add_message(AssistantMessage(content=model_result.content, source=self.name))
             # Yield the response.
             yield Response(
-                chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage),
+                chat_message=TextMessage(
+                    content=model_result.content, source=self.name, models_usage=model_result.usage
+                ),
                 inner_messages=inner_messages,
             )
         else:
             # Return tool call result as the response.
             tool_call_summaries: List[str] = []
-            for i in range(len(tool_call_msg.content)):
+            for tool_call, tool_call_result in zip(tool_calls, tool_call_results, strict=False):
                 tool_call_summaries.append(
                     self._tool_call_summary_format.format(
-                        tool_name=tool_call_msg.content[i].name,
-                        arguments=tool_call_msg.content[i].arguments,
-                        result=tool_call_result_msg.content[i].content,
+                        tool_name=tool_call.name,
+                        arguments=tool_call.arguments,
+                        result=tool_call_result.content,
                     ),
                 )
             tool_call_summary = "\n".join(tool_call_summaries)
diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py
index 6069c8ddc8dd..25d9e732d335 100644
--- a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py
+++ b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py
@@ -9,7 +9,7 @@ class and includes specific fields relevant to the type of message being sent.
 from autogen_core import FunctionCall, Image
 from autogen_core.memory import MemoryContent
-from autogen_core.models import FunctionExecutionResult, RequestUsage
+from autogen_core.models import FunctionExecutionResult, LLMMessage, RequestUsage
 from pydantic import BaseModel, ConfigDict, Field
 from typing_extensions import Annotated

@@ -74,6 +74,9 @@ class HandoffMessage(BaseChatMessage):
     content: str
     """The handoff message to the target agent."""

+    context: List[LLMMessage] = []
+    """The model context to be passed to the target agent."""
+
     type: Literal["HandoffMessage"] = "HandoffMessage"

diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py
index e477ed3f1245..7b281e39f6f5 100644
--- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py
+++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py
@@ -17,10 +17,10 @@
     ToolCallRequestEvent,
     ToolCallSummaryMessage,
 )
-from autogen_core import Image
+from autogen_core import FunctionCall, Image
 from autogen_core.memory import ListMemory, Memory, MemoryContent, MemoryMimeType, MemoryQueryResult
 from autogen_core.model_context import BufferedChatCompletionContext
-from autogen_core.models import LLMMessage
+from autogen_core.models import FunctionExecutionResult, LLMMessage
 from autogen_core.models._model_client import ModelFamily
 from autogen_core.tools import FunctionTool
 from autogen_ext.models.openai import OpenAIChatCompletionClient
@@ -281,6 +281,142 @@ async def test_run_with_tools_and_reflection(monkeypatch: pytest.MonkeyPatch) ->
     assert state == state2


+@pytest.mark.asyncio
+async def test_run_with_parallel_tools(monkeypatch: pytest.MonkeyPatch) -> None:
+    model = "gpt-4o-2024-05-13"
+    chat_completions = [
+        ChatCompletion(
+            id="id1",
+            choices=[
+                Choice(
+                    finish_reason="tool_calls",
+                    index=0,
+                    message=ChatCompletionMessage(
+                        content=None,
+                        tool_calls=[
+                            ChatCompletionMessageToolCall(
+                                id="1",
+                                type="function",
+                                function=Function(
+                                    name="_pass_function",
+                                    arguments=json.dumps({"input": "task1"}),
+                                ),
+                            ),
+                            ChatCompletionMessageToolCall(
+                                id="2",
+                                type="function",
+                                function=Function(
+                                    name="_pass_function",
+                                    arguments=json.dumps({"input": "task2"}),
+                                ),
+                            ),
+                            ChatCompletionMessageToolCall(
+                                id="3",
+                                type="function",
+                                function=Function(
+                                    name="_echo_function",
+                                    arguments=json.dumps({"input": "task3"}),
+                                ),
+                            ),
+                        ],
+                        role="assistant",
+                    ),
+                )
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
+        ),
+        ChatCompletion(
+            id="id2",
+            choices=[
+                Choice(
+                    finish_reason="stop",
+                    index=0,
+                    message=ChatCompletionMessage(content="pass", role="assistant"),
+                )
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
+        ),
+        ChatCompletion(
+            id="id2",
+            choices=[
+                Choice(
+                    finish_reason="stop",
+                    index=0,
+                    message=ChatCompletionMessage(content="TERMINATE", role="assistant"),
+                )
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
+        ),
+    ]
+    mock = _MockChatCompletion(chat_completions)
+    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
+    agent = AssistantAgent(
+        "tool_use_agent",
+        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
+        tools=[
+            _pass_function,
+            _fail_function,
+            FunctionTool(_echo_function, description="Echo"),
+        ],
+    )
+    result = await agent.run(task="task")
+
+    assert len(result.messages) == 4
+    assert isinstance(result.messages[0], TextMessage)
+    assert result.messages[0].models_usage is None
+    assert isinstance(result.messages[1], ToolCallRequestEvent)
+    assert result.messages[1].content == [
+        FunctionCall(id="1", arguments=r'{"input": "task1"}', name="_pass_function"),
+        FunctionCall(id="2", arguments=r'{"input": "task2"}', name="_pass_function"),
+        FunctionCall(id="3", arguments=r'{"input": "task3"}', name="_echo_function"),
+    ]
+    assert result.messages[1].models_usage is not None
+    assert result.messages[1].models_usage.completion_tokens == 5
+    assert result.messages[1].models_usage.prompt_tokens == 10
+    assert isinstance(result.messages[2], ToolCallExecutionEvent)
+    expected_content = [
+        FunctionExecutionResult(call_id="1", content="pass"),
+        FunctionExecutionResult(call_id="2", content="pass"),
+        FunctionExecutionResult(call_id="3", content="task3"),
+    ]
+    for expected in expected_content:
+        assert expected in result.messages[2].content
+    assert result.messages[2].models_usage is None
+    assert isinstance(result.messages[3], ToolCallSummaryMessage)
+    assert result.messages[3].content == "pass\npass\ntask3"
+    assert result.messages[3].models_usage is None
+
+    # Test streaming.
+    mock.curr_index = 0  # Reset the mock
+    index = 0
+    async for message in agent.run_stream(task="task"):
+        if isinstance(message, TaskResult):
+            assert message == result
+        else:
+            assert message == result.messages[index]
+            index += 1
+
+    # Test state saving and loading.
+    state = await agent.save_state()
+    agent2 = AssistantAgent(
+        "tool_use_agent",
+        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
+        tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
+    )
+    await agent2.load_state(state)
+    state2 = await agent2.save_state()
+    assert state == state2
+
+
 @pytest.mark.asyncio
 async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
     handoff = Handoff(target="agent2")
diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py
index 25194d6049af..3d51e1b58a40 100644
--- a/python/packages/autogen-agentchat/tests/test_group_chat.py
+++ b/python/packages/autogen-agentchat/tests/test_group_chat.py
@@ -33,7 +33,14 @@
 from autogen_agentchat.teams._group_chat._selector_group_chat import SelectorGroupChatManager
 from autogen_agentchat.teams._group_chat._swarm_group_chat import SwarmGroupChatManager
 from autogen_agentchat.ui import Console
-from autogen_core import AgentId, CancellationToken
+from autogen_core import AgentId, CancellationToken, FunctionCall
+from autogen_core.models import (
+    AssistantMessage,
+    FunctionExecutionResult,
+    FunctionExecutionResultMessage,
+    LLMMessage,
+    UserMessage,
+)
 from autogen_core.tools import FunctionTool
 from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor
 from autogen_ext.models.openai import OpenAIChatCompletionClient
@@ -986,6 +993,136 @@ async def test_swarm_pause_and_resume() -> None:
     assert result.messages[0].content == "Transferred to second_agent."


+@pytest.mark.asyncio
+async def test_swarm_with_parallel_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
+    model = "gpt-4o-2024-05-13"
+    chat_completions = [
+        ChatCompletion(
+            id="id1",
+            choices=[
+                Choice(
+                    finish_reason="tool_calls",
+                    index=0,
+                    message=ChatCompletionMessage(
+                        content=None,
+                        tool_calls=[
+                            ChatCompletionMessageToolCall(
+                                id="1",
+                                type="function",
+                                function=Function(
+                                    name="tool1",
+                                    arguments=json.dumps({}),
+                                ),
+                            ),
+                            ChatCompletionMessageToolCall(
+                                id="2",
+                                type="function",
+                                function=Function(
+                                    name="tool2",
+                                    arguments=json.dumps({}),
+                                ),
+                            ),
+                            ChatCompletionMessageToolCall(
+                                id="3",
+                                type="function",
+                                function=Function(
+                                    name="handoff_to_agent2",
+                                    arguments=json.dumps({}),
+                                ),
+                            ),
+                        ],
+                        role="assistant",
+                    ),
+                )
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
+        ),
+        ChatCompletion(
+            id="id2",
+            choices=[
+                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant"))
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
+        ),
+        ChatCompletion(
+            id="id2",
+            choices=[
+                Choice(
+                    finish_reason="stop", index=0, message=ChatCompletionMessage(content="TERMINATE", role="assistant")
+                )
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
+        ),
+    ]
+    mock = _MockChatCompletion(chat_completions)
+    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
+
+    expected_handoff_context: List[LLMMessage] = [
+        AssistantMessage(
+            source="agent1",
+            content=[
+                FunctionCall(id="1", name="tool1", arguments="{}"),
+                FunctionCall(id="2", name="tool2", arguments="{}"),
+            ],
+        ),
+        FunctionExecutionResultMessage(
+            content=[
+                FunctionExecutionResult(content="tool1", call_id="1"),
call_id="1"), + FunctionExecutionResult(content="tool2", call_id="2"), + ] + ), + ] + + def tool1() -> str: + return "tool1" + + def tool2() -> str: + return "tool2" + + agent1 = AssistantAgent( + "agent1", + model_client=OpenAIChatCompletionClient(model=model, api_key=""), + handoffs=[Handoff(target="agent2", name="handoff_to_agent2", message="handoff to agent2")], + tools=[tool1, tool2], + ) + agent2 = AssistantAgent( + "agent2", + model_client=OpenAIChatCompletionClient(model=model, api_key=""), + ) + termination = TextMentionTermination("TERMINATE") + team = Swarm([agent1, agent2], termination_condition=termination) + result = await team.run(task="task") + assert len(result.messages) == 6 + assert result.messages[0] == TextMessage(content="task", source="user") + assert isinstance(result.messages[1], ToolCallRequestEvent) + assert isinstance(result.messages[2], ToolCallExecutionEvent) + assert result.messages[3] == HandoffMessage( + content="handoff to agent2", + target="agent2", + source="agent1", + context=expected_handoff_context, + ) + assert result.messages[4].content == "Hello" + assert result.messages[4].source == "agent2" + assert result.messages[5].content == "TERMINATE" + assert result.messages[5].source == "agent2" + + # Verify the tool calls are in agent2's context. + agent2_model_ctx_messages = await agent2._model_context.get_messages() # pyright: ignore + assert agent2_model_ctx_messages[0] == UserMessage(content="task", source="user") + assert agent2_model_ctx_messages[1] == expected_handoff_context[0] + assert agent2_model_ctx_messages[2] == expected_handoff_context[1] + + @pytest.mark.asyncio async def test_swarm_with_handoff_termination() -> None: first_agent = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent")