Pass context between AssistantAgent instances for handoffs
ekzhu committed Jan 16, 2025
1 parent 1a3ac62 commit a990f74
Showing 3 changed files with 203 additions and 26 deletions.
python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
@@ -24,6 +24,7 @@
     ChatCompletionClient,
     FunctionExecutionResult,
     FunctionExecutionResultMessage,
+    LLMMessage,
     SystemMessage,
     UserMessage,
 )
@@ -89,6 +90,7 @@ class AssistantAgent(BaseChatAgent):
     * If a handoff is triggered, a :class:`~autogen_agentchat.messages.HandoffMessage` will be returned in :attr:`~autogen_agentchat.base.Response.chat_message`.
     * If there are tool calls, they will also be executed right away before returning the handoff.
+    * The tool calls and results are passed to the target agent through :attr:`~autogen_agentchat.messages.HandoffMessage.context`.

     .. note::
@@ -334,6 +336,10 @@ async def on_messages_stream(
         for msg in messages:
             if isinstance(msg, MultiModalMessage) and self._model_client.model_info["vision"] is False:
                 raise ValueError("The model does not support vision.")
+            if isinstance(msg, HandoffMessage):
+                # Add handoff context to the model context.
+                for context_msg in msg.context:
+                    await self._model_context.add_message(context_msg)
             await self._model_context.add_message(UserMessage(content=msg.content, source=msg.source))

         # Inner messages.
@@ -352,77 +358,108 @@

         # Generate an inference result based on the current model context.
         llm_messages = self._system_messages + await self._model_context.get_messages()
-        result = await self._model_client.create(
+        model_result = await self._model_client.create(
             llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
         )

         # Add the response to the model context.
-        await self._model_context.add_message(AssistantMessage(content=result.content, source=self.name))
+        await self._model_context.add_message(AssistantMessage(content=model_result.content, source=self.name))

         # Check if the response is a string and return it.
-        if isinstance(result.content, str):
+        if isinstance(model_result.content, str):
             yield Response(
-                chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage),
+                chat_message=TextMessage(
+                    content=model_result.content, source=self.name, models_usage=model_result.usage
+                ),
                 inner_messages=inner_messages,
             )
             return

         # Process tool calls.
-        assert isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content)
-        tool_call_msg = ToolCallRequestEvent(content=result.content, source=self.name, models_usage=result.usage)
+        assert isinstance(model_result.content, list) and all(
+            isinstance(item, FunctionCall) for item in model_result.content
+        )
+        tool_call_msg = ToolCallRequestEvent(
+            content=model_result.content, source=self.name, models_usage=model_result.usage
+        )
         event_logger.debug(tool_call_msg)
         # Add the tool call message to the output.
         inner_messages.append(tool_call_msg)
         yield tool_call_msg

         # Execute the tool calls.
-        results = await asyncio.gather(*[self._execute_tool_call(call, cancellation_token) for call in result.content])
-        tool_call_result_msg = ToolCallExecutionEvent(content=results, source=self.name)
+        exec_results = await asyncio.gather(
+            *[self._execute_tool_call(call, cancellation_token) for call in model_result.content]
+        )
+        tool_call_result_msg = ToolCallExecutionEvent(content=exec_results, source=self.name)
         event_logger.debug(tool_call_result_msg)
-        await self._model_context.add_message(FunctionExecutionResultMessage(content=results))
+        await self._model_context.add_message(FunctionExecutionResultMessage(content=exec_results))
         inner_messages.append(tool_call_result_msg)
         yield tool_call_result_msg

+        # Correlate tool call results with tool calls.
+        tool_calls = [call for call in model_result.content if call.name not in self._handoffs]
+        tool_call_results: List[FunctionExecutionResult] = []
+        for tool_call in tool_calls:
+            found = False
+            for exec_result in exec_results:
+                if exec_result.call_id == tool_call.id:
+                    found = True
+                    tool_call_results.append(exec_result)
+                    break
+            if not found:
+                raise RuntimeError(f"Tool call result not found for call id: {tool_call.id}")
+
         # Detect handoff requests.
-        handoffs: List[HandoffBase] = []
-        for call in result.content:
-            if call.name in self._handoffs:
-                handoffs.append(self._handoffs[call.name])
-        if len(handoffs) > 0:
+        handoff_reqs = [call for call in model_result.content if call.name in self._handoffs]
+        if len(handoff_reqs) > 0:
+            handoffs = [self._handoffs[call.name] for call in handoff_reqs]
             if len(handoffs) > 1:
                 # show warning if multiple handoffs detected
                 warnings.warn(
-                    f"Multiple handoffs detected; only the first is executed: {[handoff.name for handoff in handoffs]}",
+                    (
+                        f"Multiple handoffs detected; only the first is executed: {[handoff.name for handoff in handoffs]}. "
+                        "Disable parallel tool calls in the model client to avoid this warning."
+                    ),
                     stacklevel=2,
                 )
+            # Current context for handoff.
+            handoff_context: List[LLMMessage] = []
+            if len(tool_calls) > 0:
+                handoff_context.append(AssistantMessage(content=tool_calls, source=self.name))
+                handoff_context.append(FunctionExecutionResultMessage(content=tool_call_results))
             # Return the output messages to signal the handoff.
             yield Response(
-                chat_message=HandoffMessage(content=handoffs[0].message, target=handoffs[0].target, source=self.name),
+                chat_message=HandoffMessage(
+                    content=handoffs[0].message, target=handoffs[0].target, source=self.name, context=handoff_context
+                ),
                 inner_messages=inner_messages,
             )
             return

         if self._reflect_on_tool_use:
             # Generate another inference result based on the tool call and result.
             llm_messages = self._system_messages + await self._model_context.get_messages()
-            result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
-            assert isinstance(result.content, str)
+            model_result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
+            assert isinstance(model_result.content, str)
             # Add the response to the model context.
-            await self._model_context.add_message(AssistantMessage(content=result.content, source=self.name))
+            await self._model_context.add_message(AssistantMessage(content=model_result.content, source=self.name))
             # Yield the response.
             yield Response(
-                chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage),
+                chat_message=TextMessage(
+                    content=model_result.content, source=self.name, models_usage=model_result.usage
+                ),
                 inner_messages=inner_messages,
             )
         else:
             # Return tool call result as the response.
             tool_call_summaries: List[str] = []
-            for i in range(len(tool_call_msg.content)):
+            for tool_call, tool_call_result in zip(tool_calls, tool_call_results, strict=False):
                 tool_call_summaries.append(
                     self._tool_call_summary_format.format(
-                        tool_name=tool_call_msg.content[i].name,
-                        arguments=tool_call_msg.content[i].arguments,
-                        result=tool_call_result_msg.content[i].content,
+                        tool_name=tool_call.name,
+                        arguments=tool_call.arguments,
+                        result=tool_call_result.content,
                     ),
                 )
             tool_call_summary = "\n".join(tool_call_summaries)
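A design note on the correlation step added above: it rescans `exec_results` once per tool call, which is quadratic in the number of calls. A dict-keyed variant over the same `autogen_core` types (a sketch, not what the commit ships) would preserve the call order and the missing-result error in a single pass:

```python
# Sketch of an equivalent, index-based correlation; not part of the commit.
from typing import Dict, List

from autogen_core import FunctionCall
from autogen_core.models import FunctionExecutionResult


def correlate(
    tool_calls: List[FunctionCall],
    exec_results: List[FunctionExecutionResult],
) -> List[FunctionExecutionResult]:
    # Index results by call id once, then look each call up in O(1).
    by_id: Dict[str, FunctionExecutionResult] = {r.call_id: r for r in exec_results}
    ordered: List[FunctionExecutionResult] = []
    for call in tool_calls:
        if call.id not in by_id:
            raise RuntimeError(f"Tool call result not found for call id: {call.id}")
        ordered.append(by_id[call.id])
    return ordered
```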
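End to end, the change means a Swarm handoff no longer drops the tool activity that preceded it. A minimal usage sketch, assuming the 0.4-era public API (`Handoff` from `autogen_agentchat.base`, `TextMentionTermination` from `autogen_agentchat.conditions`); the `refund_flight` tool, agent names, and task are hypothetical, and running it requires a real OpenAI API key:

```python
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.base import Handoff
from autogen_agentchat.conditions import TextMentionTermination
from autogen_agentchat.teams import Swarm
from autogen_ext.models.openai import OpenAIChatCompletionClient


def refund_flight(flight_id: str) -> str:
    """Hypothetical tool: refund a flight by id."""
    return f"Flight {flight_id} refunded."


async def main() -> None:
    model_client = OpenAIChatCompletionClient(model="gpt-4o")
    # agent1 may call refund_flight and request a handoff in the same turn;
    # with this commit, the tool calls and their results ride along in
    # HandoffMessage.context, so agent2 sees why it was invoked.
    agent1 = AssistantAgent(
        "agent1",
        model_client=model_client,
        tools=[refund_flight],
        handoffs=[Handoff(target="agent2", message="Transferring to agent2.")],
    )
    agent2 = AssistantAgent("agent2", model_client=model_client)
    team = Swarm([agent1, agent2], termination_condition=TextMentionTermination("TERMINATE"))
    await team.run(task="Refund flight ABC123 and summarize the outcome. Say TERMINATE when done.")


asyncio.run(main())
```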
python/packages/autogen-agentchat/src/autogen_agentchat/messages.py
@@ -9,7 +9,7 @@ class and includes specific fields relevant to the type of message being sent.

 from autogen_core import FunctionCall, Image
 from autogen_core.memory import MemoryContent
-from autogen_core.models import FunctionExecutionResult, RequestUsage
+from autogen_core.models import FunctionExecutionResult, LLMMessage, RequestUsage
 from pydantic import BaseModel, ConfigDict, Field
 from typing_extensions import Annotated
@@ -74,6 +74,9 @@ class HandoffMessage(BaseChatMessage):
     content: str
     """The handoff message to the target agent."""

+    context: List[LLMMessage] = []
+    """The model context to be passed to the target agent."""
+
     type: Literal["HandoffMessage"] = "HandoffMessage"
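With the new field in place, a `HandoffMessage` carrying prior tool activity can be built directly. A sketch with illustrative call ids, tool names, and agent names, mirroring what `AssistantAgent` now constructs internally:

```python
from autogen_core import FunctionCall
from autogen_core.models import (
    AssistantMessage,
    FunctionExecutionResult,
    FunctionExecutionResultMessage,
)

from autogen_agentchat.messages import HandoffMessage

handoff = HandoffMessage(
    content="Transferring to billing_agent.",
    target="billing_agent",
    source="support_agent",
    context=[
        # The tool calls the source agent made...
        AssistantMessage(
            source="support_agent",
            content=[FunctionCall(id="1", name="lookup_invoice", arguments="{}")],
        ),
        # ...and their results, correlated by call id.
        FunctionExecutionResultMessage(
            content=[FunctionExecutionResult(call_id="1", content="invoice #42: $10")]
        ),
    ],
)
# Per the on_messages_stream change above, the receiving agent replays
# handoff.context into its own model context before appending
# handoff.content as a user message.
```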
139 changes: 138 additions & 1 deletion python/packages/autogen-agentchat/tests/test_group_chat.py
@@ -33,7 +33,14 @@
 from autogen_agentchat.teams._group_chat._selector_group_chat import SelectorGroupChatManager
 from autogen_agentchat.teams._group_chat._swarm_group_chat import SwarmGroupChatManager
 from autogen_agentchat.ui import Console
-from autogen_core import AgentId, CancellationToken
+from autogen_core import AgentId, CancellationToken, FunctionCall
+from autogen_core.models import (
+    AssistantMessage,
+    FunctionExecutionResult,
+    FunctionExecutionResultMessage,
+    LLMMessage,
+    UserMessage,
+)
 from autogen_core.tools import FunctionTool
 from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor
 from autogen_ext.models.openai import OpenAIChatCompletionClient
@@ -986,6 +993,136 @@ async def test_swarm_pause_and_resume() -> None:
     assert result.messages[0].content == "Transferred to second_agent."


+@pytest.mark.asyncio
+async def test_swarm_with_parallel_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
+    model = "gpt-4o-2024-05-13"
+    chat_completions = [
+        ChatCompletion(
+            id="id1",
+            choices=[
+                Choice(
+                    finish_reason="tool_calls",
+                    index=0,
+                    message=ChatCompletionMessage(
+                        content=None,
+                        tool_calls=[
+                            ChatCompletionMessageToolCall(
+                                id="1",
+                                type="function",
+                                function=Function(
+                                    name="tool1",
+                                    arguments=json.dumps({}),
+                                ),
+                            ),
+                            ChatCompletionMessageToolCall(
+                                id="2",
+                                type="function",
+                                function=Function(
+                                    name="tool2",
+                                    arguments=json.dumps({}),
+                                ),
+                            ),
+                            ChatCompletionMessageToolCall(
+                                id="3",
+                                type="function",
+                                function=Function(
+                                    name="handoff_to_agent2",
+                                    arguments=json.dumps({}),
+                                ),
+                            ),
+                        ],
+                        role="assistant",
+                    ),
+                )
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
+        ),
+        ChatCompletion(
+            id="id2",
+            choices=[
+                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant"))
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
+        ),
+        ChatCompletion(
+            id="id2",
+            choices=[
+                Choice(
+                    finish_reason="stop", index=0, message=ChatCompletionMessage(content="TERMINATE", role="assistant")
+                )
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
+        ),
+    ]
+    mock = _MockChatCompletion(chat_completions)
+    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
+
+    expected_handoff_context: List[LLMMessage] = [
+        AssistantMessage(
+            source="agent1",
+            content=[
+                FunctionCall(id="1", name="tool1", arguments="{}"),
+                FunctionCall(id="2", name="tool2", arguments="{}"),
+            ],
+        ),
+        FunctionExecutionResultMessage(
+            content=[
+                FunctionExecutionResult(content="tool1", call_id="1"),
+                FunctionExecutionResult(content="tool2", call_id="2"),
+            ]
+        ),
+    ]
+
+    def tool1() -> str:
+        return "tool1"
+
+    def tool2() -> str:
+        return "tool2"
+
+    agent1 = AssistantAgent(
+        "agent1",
+        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
+        handoffs=[Handoff(target="agent2", name="handoff_to_agent2", message="handoff to agent2")],
+        tools=[tool1, tool2],
+    )
+    agent2 = AssistantAgent(
+        "agent2",
+        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
+    )
+    termination = TextMentionTermination("TERMINATE")
+    team = Swarm([agent1, agent2], termination_condition=termination)
+    result = await team.run(task="task")
+    assert len(result.messages) == 6
+    assert result.messages[0] == TextMessage(content="task", source="user")
+    assert isinstance(result.messages[1], ToolCallRequestEvent)
+    assert isinstance(result.messages[2], ToolCallExecutionEvent)
+    assert result.messages[3] == HandoffMessage(
+        content="handoff to agent2",
+        target="agent2",
+        source="agent1",
+        context=expected_handoff_context,
+    )
+    assert result.messages[4].content == "Hello"
+    assert result.messages[4].source == "agent2"
+    assert result.messages[5].content == "TERMINATE"
+    assert result.messages[5].source == "agent2"
+
+    # Verify the tool calls are in agent2's context.
+    agent2_model_ctx_messages = await agent2._model_context.get_messages()  # pyright: ignore
+    assert agent2_model_ctx_messages[0] == UserMessage(content="task", source="user")
+    assert agent2_model_ctx_messages[1] == expected_handoff_context[0]
+    assert agent2_model_ctx_messages[2] == expected_handoff_context[1]
+
+
 @pytest.mark.asyncio
 async def test_swarm_with_handoff_termination() -> None:
     first_agent = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent")
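The test scripts the model's replies through `_MockChatCompletion`, a helper defined elsewhere in this test module and patched over `AsyncCompletions.create`. A rough reconstruction of its shape, for orientation only (the real helper may differ):

```python
from typing import Any, List

from openai.types.chat import ChatCompletion


class _MockChatCompletion:
    """Returns scripted ChatCompletion objects in order, one per model call."""

    def __init__(self, chat_completions: List[ChatCompletion]) -> None:
        self._chat_completions = list(chat_completions)
        self._index = 0

    async def mock_create(self, *args: Any, **kwargs: Any) -> ChatCompletion:
        # Stands in for openai's AsyncCompletions.create via monkeypatch.
        completion = self._chat_completions[self._index]
        self._index += 1
        return completion
```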
