Pass context between AssistantAgent for handoffs #5084

Merged · Jan 17, 2025 · 6 commits
Changes from 3 commits
@@ -24,6 +24,7 @@
     ChatCompletionClient,
     FunctionExecutionResult,
     FunctionExecutionResultMessage,
+    LLMMessage,
     SystemMessage,
     UserMessage,
 )
@@ -105,6 +106,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):

     * If a handoff is triggered, a :class:`~autogen_agentchat.messages.HandoffMessage` will be returned in :attr:`~autogen_agentchat.base.Response.chat_message`.
     * If there are tool calls, they will also be executed right away before returning the handoff.
+    * The tool calls and results are passed to the target agent through :attr:`~autogen_agentchat.messages.HandoffMessage.context`.

     .. note::
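
To make the documented behavior concrete, here is a minimal usage sketch. The model name, task, and lookup_order tool are illustrative assumptions, not part of this diff; only the HandoffMessage/context semantics come from this PR.

# Minimal sketch: an AssistantAgent with a handoff executes its tool calls
# and returns a HandoffMessage whose `context` carries the calls and results.
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.base import Handoff
from autogen_agentchat.messages import HandoffMessage
from autogen_ext.models.openai import OpenAIChatCompletionClient


def lookup_order(order_id: str) -> str:
    """Toy tool the agent may call before handing off (illustrative)."""
    return f"Order {order_id}: shipped"


async def main() -> None:
    triage = AssistantAgent(
        "triage",
        model_client=OpenAIChatCompletionClient(model="gpt-4o"),
        handoffs=[Handoff(target="billing", name="handoff_to_billing", message="Transferring to billing.")],
        tools=[lookup_order],
    )
    result = await triage.run(task="Check order 42, then hand off to billing.")
    last = result.messages[-1]
    if isinstance(last, HandoffMessage):
        # The tool calls and their results ride along for the target agent.
        print(last.target, len(last.context))


asyncio.run(main())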
@@ -353,6 +355,10 @@ async def on_messages_stream(
         for msg in messages:
             if isinstance(msg, MultiModalMessage) and self._model_client.model_info["vision"] is False:
                 raise ValueError("The model does not support vision.")
+            if isinstance(msg, HandoffMessage):
+                # Add handoff context to the model context.
+                for context_msg in msg.context:
+                    await self._model_context.add_message(context_msg)
             await self._model_context.add_message(UserMessage(content=msg.content, source=msg.source))

         # Inner messages.
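
The four added lines above mean that when an agent receives a HandoffMessage, its context messages are replayed into the model context before the handoff text itself is appended as a UserMessage. A small sketch of the resulting ordering (all values are illustrative, not from this diff):

# Sketch of what the receiving agent's model context ends with after it
# processes a HandoffMessage produced by this PR.
from autogen_agentchat.messages import HandoffMessage
from autogen_core import FunctionCall
from autogen_core.models import (
    AssistantMessage,
    FunctionExecutionResult,
    FunctionExecutionResultMessage,
    UserMessage,
)

msg = HandoffMessage(
    content="Transferring to billing.",
    target="billing",
    source="triage",
    context=[
        AssistantMessage(
            content=[FunctionCall(id="1", name="lookup_order", arguments="{}")], source="triage"
        ),
        FunctionExecutionResultMessage(content=[FunctionExecutionResult(content="shipped", call_id="1")]),
    ],
)

# After the loop above runs, the model context holds, in order: the two
# context messages, then the handoff text itself as a UserMessage.
expected_tail = [*msg.context, UserMessage(content=msg.content, source=msg.source)]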
@@ -371,77 +377,108 @@ async def on_messages_stream(

         # Generate an inference result based on the current model context.
         llm_messages = self._system_messages + await self._model_context.get_messages()
-        result = await self._model_client.create(
+        model_result = await self._model_client.create(
             llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
         )

         # Add the response to the model context.
-        await self._model_context.add_message(AssistantMessage(content=result.content, source=self.name))
+        await self._model_context.add_message(AssistantMessage(content=model_result.content, source=self.name))

         # Check if the response is a string and return it.
-        if isinstance(result.content, str):
+        if isinstance(model_result.content, str):
             yield Response(
-                chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage),
+                chat_message=TextMessage(
+                    content=model_result.content, source=self.name, models_usage=model_result.usage
+                ),
                 inner_messages=inner_messages,
             )
             return

         # Process tool calls.
-        assert isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content)
-        tool_call_msg = ToolCallRequestEvent(content=result.content, source=self.name, models_usage=result.usage)
+        assert isinstance(model_result.content, list) and all(
+            isinstance(item, FunctionCall) for item in model_result.content
+        )
+        tool_call_msg = ToolCallRequestEvent(
+            content=model_result.content, source=self.name, models_usage=model_result.usage
+        )
         event_logger.debug(tool_call_msg)
         # Add the tool call message to the output.
         inner_messages.append(tool_call_msg)
         yield tool_call_msg

         # Execute the tool calls.
-        results = await asyncio.gather(*[self._execute_tool_call(call, cancellation_token) for call in result.content])
-        tool_call_result_msg = ToolCallExecutionEvent(content=results, source=self.name)
+        exec_results = await asyncio.gather(
+            *[self._execute_tool_call(call, cancellation_token) for call in model_result.content]
+        )
+        tool_call_result_msg = ToolCallExecutionEvent(content=exec_results, source=self.name)
         event_logger.debug(tool_call_result_msg)
-        await self._model_context.add_message(FunctionExecutionResultMessage(content=results))
+        await self._model_context.add_message(FunctionExecutionResultMessage(content=exec_results))
         inner_messages.append(tool_call_result_msg)
         yield tool_call_result_msg

+        # Correlate tool call results with tool calls.
+        tool_calls = [call for call in model_result.content if call.name not in self._handoffs]
+        tool_call_results: List[FunctionExecutionResult] = []
+        for tool_call in tool_calls:
+            found = False
+            for exec_result in exec_results:
+                if exec_result.call_id == tool_call.id:
+                    found = True
+                    tool_call_results.append(exec_result)
+                    break
+            if not found:
+                raise RuntimeError(f"Tool call result not found for call id: {tool_call.id}")
+
         # Detect handoff requests.
-        handoffs: List[HandoffBase] = []
-        for call in result.content:
-            if call.name in self._handoffs:
-                handoffs.append(self._handoffs[call.name])
-        if len(handoffs) > 0:
+        handoff_reqs = [call for call in model_result.content if call.name in self._handoffs]
+        if len(handoff_reqs) > 0:
+            handoffs = [self._handoffs[call.name] for call in handoff_reqs]
             if len(handoffs) > 1:
                 # show warning if multiple handoffs detected
                 warnings.warn(
-                    f"Multiple handoffs detected only the first is executed: {[handoff.name for handoff in handoffs]}",
+                    (
+                        f"Multiple handoffs detected only the first is executed: {[handoff.name for handoff in handoffs]}. "
+                        "Disable parallel tool call in the model client to avoid this warning."
+                    ),
                     stacklevel=2,
                 )
+            # Current context for handoff.
+            handoff_context: List[LLMMessage] = []
+            if len(tool_calls) > 0:
+                handoff_context.append(AssistantMessage(content=tool_calls, source=self.name))
+                handoff_context.append(FunctionExecutionResultMessage(content=tool_call_results))
             # Return the output messages to signal the handoff.
             yield Response(
-                chat_message=HandoffMessage(content=handoffs[0].message, target=handoffs[0].target, source=self.name),
+                chat_message=HandoffMessage(
+                    content=handoffs[0].message, target=handoffs[0].target, source=self.name, context=handoff_context
+                ),
                 inner_messages=inner_messages,
             )
             return

         if self._reflect_on_tool_use:
             # Generate another inference result based on the tool call and result.
             llm_messages = self._system_messages + await self._model_context.get_messages()
-            result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
-            assert isinstance(result.content, str)
+            model_result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
+            assert isinstance(model_result.content, str)
             # Add the response to the model context.
-            await self._model_context.add_message(AssistantMessage(content=result.content, source=self.name))
+            await self._model_context.add_message(AssistantMessage(content=model_result.content, source=self.name))
             # Yield the response.
             yield Response(
-                chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage),
+                chat_message=TextMessage(
+                    content=model_result.content, source=self.name, models_usage=model_result.usage
+                ),
                 inner_messages=inner_messages,
             )
         else:
             # Return tool call result as the response.
             tool_call_summaries: List[str] = []
-            for i in range(len(tool_call_msg.content)):
+            for tool_call, tool_call_result in zip(tool_calls, tool_call_results, strict=False):
                 tool_call_summaries.append(
                     self._tool_call_summary_format.format(
-                        tool_name=tool_call_msg.content[i].name,
-                        arguments=tool_call_msg.content[i].arguments,
-                        result=tool_call_result_msg.content[i].content,
+                        tool_name=tool_call.name,
+                        arguments=tool_call.arguments,
+                        result=tool_call_result.content,
                     ),
                 )
             tool_call_summary = "\n".join(tool_call_summaries)
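
A review-style aside on the correlation loop introduced above: because each execution result carries a unique call_id, the nested scan can be replaced with a single dict lookup. A sketch of an equivalent (not part of this PR; the PR keeps the explicit loop):

# Equivalent correlation via a dict keyed by call_id. Raises the same
# RuntimeError on a missing result.
from typing import Dict, List

from autogen_core import FunctionCall
from autogen_core.models import FunctionExecutionResult


def correlate(
    tool_calls: List[FunctionCall], exec_results: List[FunctionExecutionResult]
) -> List[FunctionExecutionResult]:
    results_by_id: Dict[str, FunctionExecutionResult] = {r.call_id: r for r in exec_results}
    try:
        return [results_by_id[call.id] for call in tool_calls]
    except KeyError as e:
        raise RuntimeError(f"Tool call result not found for call id: {e.args[0]}") from e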
@@ -9,7 +9,7 @@ class and includes specific fields relevant to the type of message being sent.

 from autogen_core import FunctionCall, Image
 from autogen_core.memory import MemoryContent
-from autogen_core.models import FunctionExecutionResult, RequestUsage
+from autogen_core.models import FunctionExecutionResult, LLMMessage, RequestUsage
 from pydantic import BaseModel, ConfigDict, Field
 from typing_extensions import Annotated

@@ -74,6 +74,9 @@ class HandoffMessage(BaseChatMessage):
     content: str
     """The handoff message to the target agent."""

+    context: List[LLMMessage] = []
+    """The model context to be passed to the target agent."""
+
     type: Literal["HandoffMessage"] = "HandoffMessage"
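
Two small notes on the new field. The mutable default `[]` is safe here because pydantic deep-copies field defaults per instance, so HandoffMessage objects do not share a list; and existing call sites keep working since the field is optional at construction. For illustration:

# Constructing HandoffMessage with and without the new context field.
from autogen_agentchat.messages import HandoffMessage

plain = HandoffMessage(content="over to you", target="agent2", source="agent1")
assert plain.context == []

a = HandoffMessage(content="x", target="agent2", source="agent1")
b = HandoffMessage(content="y", target="agent2", source="agent1")
# pydantic copies the default per instance, so these do not alias each other.
assert a.context is not b.context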
139 changes: 138 additions & 1 deletion python/packages/autogen-agentchat/tests/test_group_chat.py
@@ -33,7 +33,14 @@

 from autogen_agentchat.teams._group_chat._selector_group_chat import SelectorGroupChatManager
 from autogen_agentchat.teams._group_chat._swarm_group_chat import SwarmGroupChatManager
 from autogen_agentchat.ui import Console
-from autogen_core import AgentId, CancellationToken
+from autogen_core import AgentId, CancellationToken, FunctionCall
+from autogen_core.models import (
+    AssistantMessage,
+    FunctionExecutionResult,
+    FunctionExecutionResultMessage,
+    LLMMessage,
+    UserMessage,
+)
 from autogen_core.tools import FunctionTool
 from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor
 from autogen_ext.models.openai import OpenAIChatCompletionClient
@@ -986,6 +993,136 @@ async def test_swarm_pause_and_resume() -> None:
     assert result.messages[0].content == "Transferred to second_agent."


+@pytest.mark.asyncio
+async def test_swarm_with_parallel_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
+    model = "gpt-4o-2024-05-13"
+    chat_completions = [
+        ChatCompletion(
+            id="id1",
+            choices=[
+                Choice(
+                    finish_reason="tool_calls",
+                    index=0,
+                    message=ChatCompletionMessage(
+                        content=None,
+                        tool_calls=[
+                            ChatCompletionMessageToolCall(
+                                id="1",
+                                type="function",
+                                function=Function(
+                                    name="tool1",
+                                    arguments=json.dumps({}),
+                                ),
+                            ),
+                            ChatCompletionMessageToolCall(
+                                id="2",
+                                type="function",
+                                function=Function(
+                                    name="tool2",
+                                    arguments=json.dumps({}),
+                                ),
+                            ),
+                            ChatCompletionMessageToolCall(
+                                id="3",
+                                type="function",
+                                function=Function(
+                                    name="handoff_to_agent2",
+                                    arguments=json.dumps({}),
+                                ),
+                            ),
+                        ],
+                        role="assistant",
+                    ),
+                )
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
+        ),
+        ChatCompletion(
+            id="id2",
+            choices=[
+                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant"))
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
+        ),
+        ChatCompletion(
+            id="id2",
+            choices=[
+                Choice(
+                    finish_reason="stop", index=0, message=ChatCompletionMessage(content="TERMINATE", role="assistant")
+                )
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
+        ),
+    ]
+    mock = _MockChatCompletion(chat_completions)
+    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
+
+    expected_handoff_context: List[LLMMessage] = [
+        AssistantMessage(
+            source="agent1",
+            content=[
+                FunctionCall(id="1", name="tool1", arguments="{}"),
+                FunctionCall(id="2", name="tool2", arguments="{}"),
+            ],
+        ),
+        FunctionExecutionResultMessage(
+            content=[
+                FunctionExecutionResult(content="tool1", call_id="1"),
+                FunctionExecutionResult(content="tool2", call_id="2"),
+            ]
+        ),
+    ]
+
+    def tool1() -> str:
+        return "tool1"
+
+    def tool2() -> str:
+        return "tool2"
+
+    agent1 = AssistantAgent(
+        "agent1",
+        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
+        handoffs=[Handoff(target="agent2", name="handoff_to_agent2", message="handoff to agent2")],
+        tools=[tool1, tool2],
+    )
+    agent2 = AssistantAgent(
+        "agent2",
+        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
+    )
+    termination = TextMentionTermination("TERMINATE")
+    team = Swarm([agent1, agent2], termination_condition=termination)
+    result = await team.run(task="task")
+    assert len(result.messages) == 6
+    assert result.messages[0] == TextMessage(content="task", source="user")
+    assert isinstance(result.messages[1], ToolCallRequestEvent)
+    assert isinstance(result.messages[2], ToolCallExecutionEvent)
+    assert result.messages[3] == HandoffMessage(
+        content="handoff to agent2",
+        target="agent2",
+        source="agent1",
+        context=expected_handoff_context,
+    )
+    assert result.messages[4].content == "Hello"
+    assert result.messages[4].source == "agent2"
+    assert result.messages[5].content == "TERMINATE"
+    assert result.messages[5].source == "agent2"
+
+    # Verify the tool calls are in agent2's context.
+    agent2_model_ctx_messages = await agent2._model_context.get_messages()  # pyright: ignore
+    assert agent2_model_ctx_messages[0] == UserMessage(content="task", source="user")
+    assert agent2_model_ctx_messages[1] == expected_handoff_context[0]
+    assert agent2_model_ctx_messages[2] == expected_handoff_context[1]
+
+
 @pytest.mark.asyncio
 async def test_swarm_with_handoff_termination() -> None:
     first_agent = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent")
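
The _MockChatCompletion helper used in the test above is defined earlier in this test module and is not part of the diff. As a rough, hypothetical stand-in for readers following along (the real implementation may differ), it replays the canned ChatCompletion objects in order:

# Hypothetical stand-in for _MockChatCompletion (not shown in this diff).
# It returns one canned ChatCompletion per call to the patched create().
from typing import Any, List

from openai.types.chat import ChatCompletion


class _MockChatCompletion:
    def __init__(self, chat_completions: List[ChatCompletion]) -> None:
        self._chat_completions = list(chat_completions)
        self._index = 0

    async def mock_create(self, *args: Any, **kwargs: Any) -> ChatCompletion:
        # Replay the next canned completion, mimicking AsyncCompletions.create.
        completion = self._chat_completions[self._index]
        self._index += 1
        return completion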