Pass context between AssistantAgent for handoffs #5084

Merged: 6 commits, Jan 17, 2025
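As a quick, hedged illustration of what this change enables (the agent, tool, and model names below are assumptions for the sketch, not part of the PR): an AssistantAgent that calls a tool and then hands off now attaches the tool call and its result to the outgoing HandoffMessage via the new context field, e.g. for use inside a Swarm team.

import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.messages import HandoffMessage
from autogen_ext.models.openai import OpenAIChatCompletionClient


async def refund_flight(flight_id: str) -> str:
    # Illustrative tool; stands in for any side-effecting call.
    return f"Refunded flight {flight_id}"


async def main() -> None:
    triage = AssistantAgent(
        "triage",
        model_client=OpenAIChatCompletionClient(model="gpt-4o"),  # assumes OPENAI_API_KEY is set
        tools=[refund_flight],
        handoffs=["billing"],
        system_message="Refund the flight, then hand off to billing.",
    )
    result = await triage.run(task="Refund flight AB123 and pass the case to billing.")
    last = result.messages[-1]
    if isinstance(last, HandoffMessage):
        # New in this PR: the tool call and its result ride along for the target agent.
        print(last.target, [type(m).__name__ for m in last.context])


asyncio.run(main())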
Changes from all commits
@@ -24,6 +24,7 @@
ChatCompletionClient,
FunctionExecutionResult,
FunctionExecutionResultMessage,
+ LLMMessage,
SystemMessage,
UserMessage,
)
@@ -105,6 +106,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):

* If a handoff is triggered, a :class:`~autogen_agentchat.messages.HandoffMessage` will be returned in :attr:`~autogen_agentchat.base.Response.chat_message`.
* If there are tool calls, they will also be executed right away before returning the handoff.
+ * The tool calls and results are passed to the target agent through :attr:`~autogen_agentchat.messages.HandoffMessage.context`.


.. note::
@@ -353,6 +355,10 @@ async def on_messages_stream(
for msg in messages:
if isinstance(msg, MultiModalMessage) and self._model_client.model_info["vision"] is False:
raise ValueError("The model does not support vision.")
+ if isinstance(msg, HandoffMessage):
+ # Add handoff context to the model context.
+ for context_msg in msg.context:
+ await self._model_context.add_message(context_msg)
await self._model_context.add_message(UserMessage(content=msg.content, source=msg.source))

# Inner messages.
Expand All @@ -371,77 +377,108 @@ async def on_messages_stream(

# Generate an inference result based on the current model context.
llm_messages = self._system_messages + await self._model_context.get_messages()
- result = await self._model_client.create(
+ model_result = await self._model_client.create(
llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
)

# Add the response to the model context.
- await self._model_context.add_message(AssistantMessage(content=result.content, source=self.name))
+ await self._model_context.add_message(AssistantMessage(content=model_result.content, source=self.name))

# Check if the response is a string and return it.
- if isinstance(result.content, str):
+ if isinstance(model_result.content, str):
yield Response(
- chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage),
+ chat_message=TextMessage(
+ content=model_result.content, source=self.name, models_usage=model_result.usage
+ ),
inner_messages=inner_messages,
)
return

# Process tool calls.
- assert isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content)
- tool_call_msg = ToolCallRequestEvent(content=result.content, source=self.name, models_usage=result.usage)
+ assert isinstance(model_result.content, list) and all(
+ isinstance(item, FunctionCall) for item in model_result.content
+ )
+ tool_call_msg = ToolCallRequestEvent(
+ content=model_result.content, source=self.name, models_usage=model_result.usage
+ )
event_logger.debug(tool_call_msg)
# Add the tool call message to the output.
inner_messages.append(tool_call_msg)
yield tool_call_msg

# Execute the tool calls.
- results = await asyncio.gather(*[self._execute_tool_call(call, cancellation_token) for call in result.content])
- tool_call_result_msg = ToolCallExecutionEvent(content=results, source=self.name)
+ exec_results = await asyncio.gather(
+ *[self._execute_tool_call(call, cancellation_token) for call in model_result.content]
+ )
+ tool_call_result_msg = ToolCallExecutionEvent(content=exec_results, source=self.name)
event_logger.debug(tool_call_result_msg)
- await self._model_context.add_message(FunctionExecutionResultMessage(content=results))
+ await self._model_context.add_message(FunctionExecutionResultMessage(content=exec_results))
inner_messages.append(tool_call_result_msg)
yield tool_call_result_msg

+ # Correlate tool call results with tool calls.
+ tool_calls = [call for call in model_result.content if call.name not in self._handoffs]
+ tool_call_results: List[FunctionExecutionResult] = []
+ for tool_call in tool_calls:
+ found = False
+ for exec_result in exec_results:
+ if exec_result.call_id == tool_call.id:
+ found = True
+ tool_call_results.append(exec_result)
+ break
+ if not found:
+ raise RuntimeError(f"Tool call result not found for call id: {tool_call.id}")
+
# Detect handoff requests.
- handoffs: List[HandoffBase] = []
- for call in result.content:
- if call.name in self._handoffs:
- handoffs.append(self._handoffs[call.name])
- if len(handoffs) > 0:
+ handoff_reqs = [call for call in model_result.content if call.name in self._handoffs]
+ if len(handoff_reqs) > 0:
+ handoffs = [self._handoffs[call.name] for call in handoff_reqs]
if len(handoffs) > 1:
# show warning if multiple handoffs detected
warnings.warn(
f"Multiple handoffs detected only the first is executed: {[handoff.name for handoff in handoffs]}",
(
f"Multiple handoffs detected only the first is executed: {[handoff.name for handoff in handoffs]}. "
"Disable parallel tool call in the model client to avoid this warning."
),
stacklevel=2,
)
+ # Current context for handoff.
+ handoff_context: List[LLMMessage] = []
+ if len(tool_calls) > 0:
+ handoff_context.append(AssistantMessage(content=tool_calls, source=self.name))
+ handoff_context.append(FunctionExecutionResultMessage(content=tool_call_results))
# Return the output messages to signal the handoff.
yield Response(
- chat_message=HandoffMessage(content=handoffs[0].message, target=handoffs[0].target, source=self.name),
+ chat_message=HandoffMessage(
+ content=handoffs[0].message, target=handoffs[0].target, source=self.name, context=handoff_context
+ ),
inner_messages=inner_messages,
)
return

if self._reflect_on_tool_use:
# Generate another inference result based on the tool call and result.
llm_messages = self._system_messages + await self._model_context.get_messages()
- result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
- assert isinstance(result.content, str)
+ model_result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
+ assert isinstance(model_result.content, str)
# Add the response to the model context.
- await self._model_context.add_message(AssistantMessage(content=result.content, source=self.name))
+ await self._model_context.add_message(AssistantMessage(content=model_result.content, source=self.name))
# Yield the response.
yield Response(
- chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage),
+ chat_message=TextMessage(
+ content=model_result.content, source=self.name, models_usage=model_result.usage
+ ),
inner_messages=inner_messages,
)
else:
# Return tool call result as the response.
tool_call_summaries: List[str] = []
- for i in range(len(tool_call_msg.content)):
+ for tool_call, tool_call_result in zip(tool_calls, tool_call_results, strict=False):
tool_call_summaries.append(
self._tool_call_summary_format.format(
- tool_name=tool_call_msg.content[i].name,
- arguments=tool_call_msg.content[i].arguments,
- result=tool_call_result_msg.content[i].content,
+ tool_name=tool_call.name,
+ arguments=tool_call.arguments,
+ result=tool_call_result.content,
),
)
tool_call_summary = "\n".join(tool_call_summaries)
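To make the receive-side path above concrete, here is a hedged, self-contained sketch of handing an AssistantAgent a HandoffMessage whose context already carries a tool call and its result. The tool name, arguments, agent names, and model are assumptions for the sketch; a valid OpenAI API key is assumed for the model client.

import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.messages import HandoffMessage
from autogen_core import CancellationToken, FunctionCall
from autogen_core.models import (
    AssistantMessage,
    FunctionExecutionResult,
    FunctionExecutionResultMessage,
)
from autogen_ext.models.openai import OpenAIChatCompletionClient


async def main() -> None:
    billing = AssistantAgent("billing", model_client=OpenAIChatCompletionClient(model="gpt-4o"))
    # Context the sending agent would attach: the tool call it already made and the
    # matching execution result (call_id ties the two together).
    context = [
        AssistantMessage(
            content=[FunctionCall(id="1", arguments='{"flight_id": "AB123"}', name="refund_flight")],
            source="triage",
        ),
        FunctionExecutionResultMessage(
            content=[FunctionExecutionResult(call_id="1", content="Refunded flight AB123")]
        ),
    ]
    handoff = HandoffMessage(
        content="Transferring to billing.", source="triage", target="billing", context=context
    )
    # The new code path replays `context` into billing's model context before the
    # handoff text, so the reply can reference the earlier tool result.
    response = await billing.on_messages([handoff], CancellationToken())
    print(response.chat_message.content)


asyncio.run(main())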
@@ -9,7 +9,7 @@ class and includes specific fields relevant to the type of message being sent.

from autogen_core import FunctionCall, Image
from autogen_core.memory import MemoryContent
- from autogen_core.models import FunctionExecutionResult, RequestUsage
+ from autogen_core.models import FunctionExecutionResult, LLMMessage, RequestUsage
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated

@@ -74,6 +74,9 @@ class HandoffMessage(BaseChatMessage):
content: str
"""The handoff message to the target agent."""

+ context: List[LLMMessage] = []
+ """The model context to be passed to the target agent."""

type: Literal["HandoffMessage"] = "HandoffMessage"


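A small sketch of the new field on its own (pure pydantic, no model client needed; the message contents are illustrative): context defaults to an empty list, so existing callers are unaffected, and it accepts any LLMMessage instances the sender wants the target agent to replay.

from autogen_agentchat.messages import HandoffMessage
from autogen_core.models import UserMessage

# Without context, the field defaults to an empty list (backwards compatible).
plain = HandoffMessage(content="Transfer to agent2.", source="agent1", target="agent2")
assert plain.context == []

# With context: messages the target agent adds to its own model context before acting.
rich = HandoffMessage(
    content="Transfer to agent2.",
    source="agent1",
    target="agent2",
    context=[UserMessage(content="The user asked for a refund of flight AB123.", source="user")],
)
print(rich.model_dump())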
python/packages/autogen-agentchat/tests/test_assistant_agent.py (140 changes: 138 additions, 2 deletions)
@@ -17,10 +17,10 @@
ToolCallRequestEvent,
ToolCallSummaryMessage,
)
- from autogen_core import Image
+ from autogen_core import FunctionCall, Image
from autogen_core.memory import ListMemory, Memory, MemoryContent, MemoryMimeType, MemoryQueryResult
from autogen_core.model_context import BufferedChatCompletionContext
- from autogen_core.models import LLMMessage
+ from autogen_core.models import FunctionExecutionResult, LLMMessage
from autogen_core.models._model_client import ModelFamily
from autogen_core.tools import FunctionTool
from autogen_ext.models.openai import OpenAIChatCompletionClient
@@ -281,6 +281,142 @@ async def test_run_with_tools_and_reflection(monkeypatch: pytest.MonkeyPatch) ->
assert state == state2


@pytest.mark.asyncio
async def test_run_with_parallel_tools(monkeypatch: pytest.MonkeyPatch) -> None:
model = "gpt-4o-2024-05-13"
chat_completions = [
ChatCompletion(
id="id1",
choices=[
Choice(
finish_reason="tool_calls",
index=0,
message=ChatCompletionMessage(
content=None,
tool_calls=[
ChatCompletionMessageToolCall(
id="1",
type="function",
function=Function(
name="_pass_function",
arguments=json.dumps({"input": "task1"}),
),
),
ChatCompletionMessageToolCall(
id="2",
type="function",
function=Function(
name="_pass_function",
arguments=json.dumps({"input": "task2"}),
),
),
ChatCompletionMessageToolCall(
id="3",
type="function",
function=Function(
name="_echo_function",
arguments=json.dumps({"input": "task3"}),
),
),
],
role="assistant",
),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
ChatCompletion(
id="id2",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(content="pass", role="assistant"),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
ChatCompletion(
id="id2",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(content="TERMINATE", role="assistant"),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
]
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
agent = AssistantAgent(
"tool_use_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
tools=[
_pass_function,
_fail_function,
FunctionTool(_echo_function, description="Echo"),
],
)
result = await agent.run(task="task")

assert len(result.messages) == 4
assert isinstance(result.messages[0], TextMessage)
assert result.messages[0].models_usage is None
assert isinstance(result.messages[1], ToolCallRequestEvent)
assert result.messages[1].content == [
FunctionCall(id="1", arguments=r'{"input": "task1"}', name="_pass_function"),
FunctionCall(id="2", arguments=r'{"input": "task2"}', name="_pass_function"),
FunctionCall(id="3", arguments=r'{"input": "task3"}', name="_echo_function"),
]
assert result.messages[1].models_usage is not None
assert result.messages[1].models_usage.completion_tokens == 5
assert result.messages[1].models_usage.prompt_tokens == 10
assert isinstance(result.messages[2], ToolCallExecutionEvent)
expected_content = [
FunctionExecutionResult(call_id="1", content="pass"),
FunctionExecutionResult(call_id="2", content="pass"),
FunctionExecutionResult(call_id="3", content="task3"),
]
for expected in expected_content:
assert expected in result.messages[2].content
assert result.messages[2].models_usage is None
assert isinstance(result.messages[3], ToolCallSummaryMessage)
assert result.messages[3].content == "pass\npass\ntask3"
assert result.messages[3].models_usage is None

# Test streaming.
mock.curr_index = 0 # Reset the mock
index = 0
async for message in agent.run_stream(task="task"):
if isinstance(message, TaskResult):
assert message == result
else:
assert message == result.messages[index]
index += 1

# Test state saving and loading.
state = await agent.save_state()
agent2 = AssistantAgent(
"tool_use_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
)
await agent2.load_state(state)
state2 = await agent2.save_state()
assert state == state2


@pytest.mark.asyncio
async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
handoff = Handoff(target="agent2")
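The new parallel-tool-call test above depends on the correlation step added to AssistantAgent: each non-handoff tool call is matched to its execution result by call id, so the summary lines line up with the model's call order even when results are inspected out of order or a handoff call is interleaved. A standalone sketch of that matching, using a dict lookup rather than the nested loop in the diff; the call ids, names, and contents mirror the test but are otherwise illustrative.

from autogen_core import FunctionCall
from autogen_core.models import FunctionExecutionResult

# Calls as returned by the model; a handoff tool call would simply be filtered out first.
calls = [
    FunctionCall(id="1", arguments='{"input": "task1"}', name="_pass_function"),
    FunctionCall(id="2", arguments='{"input": "task2"}', name="_pass_function"),
    FunctionCall(id="3", arguments='{"input": "task3"}', name="_echo_function"),
]
# Execution results keyed by call_id; their order does not matter.
exec_results = [
    FunctionExecutionResult(call_id="3", content="task3"),
    FunctionExecutionResult(call_id="1", content="pass"),
    FunctionExecutionResult(call_id="2", content="pass"),
]

results_by_id = {result.call_id: result for result in exec_results}
ordered = []
for call in calls:
    if call.id not in results_by_id:
        raise RuntimeError(f"Tool call result not found for call id: {call.id}")
    ordered.append(results_by_id[call.id])

# Matches the summary the test expects: "pass\npass\ntask3".
assert [result.content for result in ordered] == ["pass", "pass", "task3"]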