
Commit

Merge branch 'main' into assistant-agent-diagram
ekzhu authored Jan 17, 2025
2 parents 7004b2b + 8643ff6 commit 4531de3
Showing 4 changed files with 342 additions and 29 deletions.
@@ -24,6 +24,7 @@
ChatCompletionClient,
FunctionExecutionResult,
FunctionExecutionResultMessage,
+LLMMessage,
SystemMessage,
UserMessage,
)
@@ -109,7 +110,8 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
Hand off behavior:
* If a handoff is triggered, a :class:`~autogen_agentchat.messages.HandoffMessage` will be returned in :attr:`~autogen_agentchat.base.Response.chat_message`.
-* If there are tool calls, they will also be executed right away before returning the handoff. The tool call messages will be added to :attr:`~autogen_agentchat.messages.HandoffMessage.context` and sent to the target agent.
+* If there are tool calls, they will also be executed right away before returning the handoff.
+* The tool calls and results are passed to the target agent through :attr:`~autogen_agentchat.messages.HandoffMessage.context`.
.. note::
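The two new bullets above describe the behavior this commit adds: pending tool calls are executed before the handoff is returned, and the calls together with their results travel to the target agent. A minimal sketch of such a handoff, using only the message types that appear in this diff; the tool name and arguments are invented for illustration:

from autogen_agentchat.messages import HandoffMessage
from autogen_core import FunctionCall
from autogen_core.models import (
    AssistantMessage,
    FunctionExecutionResult,
    FunctionExecutionResultMessage,
)

# Hypothetical tool call issued by "agent1" just before handing off.
calls = [FunctionCall(id="1", name="lookup_order", arguments='{"order_id": "42"}')]
results = [FunctionExecutionResult(call_id="1", content="shipped")]

# Mirrors the handoff_context assembled in on_messages_stream below: the target
# agent receives the tool calls and their results via HandoffMessage.context and
# adds them to its own model context before responding.
handoff = HandoffMessage(
    content="Transferring to agent2.",
    target="agent2",
    source="agent1",
    context=[
        AssistantMessage(content=calls, source="agent1"),
        FunctionExecutionResultMessage(content=results),
    ],
)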
@@ -359,6 +361,10 @@ async def on_messages_stream(
for msg in messages:
if isinstance(msg, MultiModalMessage) and self._model_client.model_info["vision"] is False:
raise ValueError("The model does not support vision.")
+if isinstance(msg, HandoffMessage):
+    # Add handoff context to the model context.
+    for context_msg in msg.context:
+        await self._model_context.add_message(context_msg)
await self._model_context.add_message(UserMessage(content=msg.content, source=msg.source))

# Inner messages.
@@ -377,77 +383,108 @@ async def on_messages_stream(

# Generate an inference result based on the current model context.
llm_messages = self._system_messages + await self._model_context.get_messages()
-result = await self._model_client.create(
+model_result = await self._model_client.create(
llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
)

# Add the response to the model context.
-await self._model_context.add_message(AssistantMessage(content=result.content, source=self.name))
+await self._model_context.add_message(AssistantMessage(content=model_result.content, source=self.name))

# Check if the response is a string and return it.
-if isinstance(result.content, str):
+if isinstance(model_result.content, str):
yield Response(
-chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage),
+chat_message=TextMessage(
+    content=model_result.content, source=self.name, models_usage=model_result.usage
+),
inner_messages=inner_messages,
)
return

# Process tool calls.
-assert isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content)
-tool_call_msg = ToolCallRequestEvent(content=result.content, source=self.name, models_usage=result.usage)
+assert isinstance(model_result.content, list) and all(
+    isinstance(item, FunctionCall) for item in model_result.content
+)
+tool_call_msg = ToolCallRequestEvent(
+    content=model_result.content, source=self.name, models_usage=model_result.usage
+)
event_logger.debug(tool_call_msg)
# Add the tool call message to the output.
inner_messages.append(tool_call_msg)
yield tool_call_msg

# Execute the tool calls.
-results = await asyncio.gather(*[self._execute_tool_call(call, cancellation_token) for call in result.content])
-tool_call_result_msg = ToolCallExecutionEvent(content=results, source=self.name)
+exec_results = await asyncio.gather(
+    *[self._execute_tool_call(call, cancellation_token) for call in model_result.content]
+)
+tool_call_result_msg = ToolCallExecutionEvent(content=exec_results, source=self.name)
event_logger.debug(tool_call_result_msg)
-await self._model_context.add_message(FunctionExecutionResultMessage(content=results))
+await self._model_context.add_message(FunctionExecutionResultMessage(content=exec_results))
inner_messages.append(tool_call_result_msg)
yield tool_call_result_msg

+# Correlate tool call results with tool calls.
+tool_calls = [call for call in model_result.content if call.name not in self._handoffs]
+tool_call_results: List[FunctionExecutionResult] = []
+for tool_call in tool_calls:
+    found = False
+    for exec_result in exec_results:
+        if exec_result.call_id == tool_call.id:
+            found = True
+            tool_call_results.append(exec_result)
+            break
+    if not found:
+        raise RuntimeError(f"Tool call result not found for call id: {tool_call.id}")

# Detect handoff requests.
-handoffs: List[HandoffBase] = []
-for call in result.content:
-    if call.name in self._handoffs:
-        handoffs.append(self._handoffs[call.name])
-if len(handoffs) > 0:
+handoff_reqs = [call for call in model_result.content if call.name in self._handoffs]
+if len(handoff_reqs) > 0:
+    handoffs = [self._handoffs[call.name] for call in handoff_reqs]
if len(handoffs) > 1:
# show warning if multiple handoffs detected
warnings.warn(
f"Multiple handoffs detected only the first is executed: {[handoff.name for handoff in handoffs]}",
(
f"Multiple handoffs detected only the first is executed: {[handoff.name for handoff in handoffs]}. "
"Disable parallel tool call in the model client to avoid this warning."
),
stacklevel=2,
)
+# Current context for handoff.
+handoff_context: List[LLMMessage] = []
+if len(tool_calls) > 0:
+    handoff_context.append(AssistantMessage(content=tool_calls, source=self.name))
+    handoff_context.append(FunctionExecutionResultMessage(content=tool_call_results))
# Return the output messages to signal the handoff.
yield Response(
-chat_message=HandoffMessage(content=handoffs[0].message, target=handoffs[0].target, source=self.name),
+chat_message=HandoffMessage(
+    content=handoffs[0].message, target=handoffs[0].target, source=self.name, context=handoff_context
+),
inner_messages=inner_messages,
)
return

if self._reflect_on_tool_use:
# Generate another inference result based on the tool call and result.
llm_messages = self._system_messages + await self._model_context.get_messages()
-result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
-assert isinstance(result.content, str)
+model_result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
+assert isinstance(model_result.content, str)
# Add the response to the model context.
-await self._model_context.add_message(AssistantMessage(content=result.content, source=self.name))
+await self._model_context.add_message(AssistantMessage(content=model_result.content, source=self.name))
# Yield the response.
yield Response(
-chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage),
+chat_message=TextMessage(
+    content=model_result.content, source=self.name, models_usage=model_result.usage
+),
inner_messages=inner_messages,
)
else:
# Return tool call result as the response.
tool_call_summaries: List[str] = []
-for i in range(len(tool_call_msg.content)):
+for tool_call, tool_call_result in zip(tool_calls, tool_call_results, strict=False):
tool_call_summaries.append(
self._tool_call_summary_format.format(
-tool_name=tool_call_msg.content[i].name,
-arguments=tool_call_msg.content[i].arguments,
-result=tool_call_result_msg.content[i].content,
+tool_name=tool_call.name,
+arguments=tool_call.arguments,
+result=tool_call_result.content,
),
)
tool_call_summary = "\n".join(tool_call_summaries)
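In the non-reflection branch above, each executed tool call is rendered through self._tool_call_summary_format and the pieces are joined with newlines. A standalone sketch of that formatting, assuming the default format string is simply "{result}", which would match the "pass\npass\ntask3" summary asserted in the new test below:

# Assumed default summary format; a custom format such as
# "{tool_name}({arguments}) -> {result}" would change the output accordingly.
summary_format = "{result}"

calls_and_results = [
    ("_pass_function", '{"input": "task1"}', "pass"),
    ("_pass_function", '{"input": "task2"}', "pass"),
    ("_echo_function", '{"input": "task3"}', "task3"),
]
summaries = [
    summary_format.format(tool_name=name, arguments=args, result=result)
    for name, args, result in calls_and_results
]
print("\n".join(summaries))  # -> "pass\npass\ntask3"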
@@ -9,7 +9,7 @@ class and includes specific fields relevant to the type of message being sent.

from autogen_core import FunctionCall, Image
from autogen_core.memory import MemoryContent
-from autogen_core.models import FunctionExecutionResult, RequestUsage
+from autogen_core.models import FunctionExecutionResult, LLMMessage, RequestUsage
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated

@@ -74,6 +74,9 @@ class HandoffMessage(BaseChatMessage):
content: str
"""The handoff message to the target agent."""

+context: List[LLMMessage] = []
+"""The model context to be passed to the target agent."""

type: Literal["HandoffMessage"] = "HandoffMessage"
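Because context defaults to an empty list, handoff messages built without any tool-call history continue to validate exactly as before; a quick sketch:

from autogen_agentchat.messages import HandoffMessage

# No context supplied: the field defaults to an empty list, so existing
# HandoffMessage usage is unaffected by this change.
plain = HandoffMessage(content="over to you", target="agent2", source="agent1")
assert plain.context == []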


140 changes: 138 additions & 2 deletions python/packages/autogen-agentchat/tests/test_assistant_agent.py
@@ -17,10 +17,10 @@
ToolCallRequestEvent,
ToolCallSummaryMessage,
)
-from autogen_core import Image
+from autogen_core import FunctionCall, Image
from autogen_core.memory import ListMemory, Memory, MemoryContent, MemoryMimeType, MemoryQueryResult
from autogen_core.model_context import BufferedChatCompletionContext
-from autogen_core.models import LLMMessage
+from autogen_core.models import FunctionExecutionResult, LLMMessage
from autogen_core.models._model_client import ModelFamily
from autogen_core.tools import FunctionTool
from autogen_ext.models.openai import OpenAIChatCompletionClient
@@ -281,6 +281,142 @@ async def test_run_with_tools_and_reflection(monkeypatch: pytest.MonkeyPatch) ->
assert state == state2


@pytest.mark.asyncio
async def test_run_with_parallel_tools(monkeypatch: pytest.MonkeyPatch) -> None:
    model = "gpt-4o-2024-05-13"
    chat_completions = [
        ChatCompletion(
            id="id1",
            choices=[
                Choice(
                    finish_reason="tool_calls",
                    index=0,
                    message=ChatCompletionMessage(
                        content=None,
                        tool_calls=[
                            ChatCompletionMessageToolCall(
                                id="1",
                                type="function",
                                function=Function(
                                    name="_pass_function",
                                    arguments=json.dumps({"input": "task1"}),
                                ),
                            ),
                            ChatCompletionMessageToolCall(
                                id="2",
                                type="function",
                                function=Function(
                                    name="_pass_function",
                                    arguments=json.dumps({"input": "task2"}),
                                ),
                            ),
                            ChatCompletionMessageToolCall(
                                id="3",
                                type="function",
                                function=Function(
                                    name="_echo_function",
                                    arguments=json.dumps({"input": "task3"}),
                                ),
                            ),
                        ],
                        role="assistant",
                    ),
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
        ),
        ChatCompletion(
            id="id2",
            choices=[
                Choice(
                    finish_reason="stop",
                    index=0,
                    message=ChatCompletionMessage(content="pass", role="assistant"),
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
        ),
        ChatCompletion(
            id="id2",
            choices=[
                Choice(
                    finish_reason="stop",
                    index=0,
                    message=ChatCompletionMessage(content="TERMINATE", role="assistant"),
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
        ),
    ]
    mock = _MockChatCompletion(chat_completions)
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
    agent = AssistantAgent(
        "tool_use_agent",
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
        tools=[
            _pass_function,
            _fail_function,
            FunctionTool(_echo_function, description="Echo"),
        ],
    )
    result = await agent.run(task="task")

    assert len(result.messages) == 4
    assert isinstance(result.messages[0], TextMessage)
    assert result.messages[0].models_usage is None
    assert isinstance(result.messages[1], ToolCallRequestEvent)
    assert result.messages[1].content == [
        FunctionCall(id="1", arguments=r'{"input": "task1"}', name="_pass_function"),
        FunctionCall(id="2", arguments=r'{"input": "task2"}', name="_pass_function"),
        FunctionCall(id="3", arguments=r'{"input": "task3"}', name="_echo_function"),
    ]
    assert result.messages[1].models_usage is not None
    assert result.messages[1].models_usage.completion_tokens == 5
    assert result.messages[1].models_usage.prompt_tokens == 10
    assert isinstance(result.messages[2], ToolCallExecutionEvent)
    expected_content = [
        FunctionExecutionResult(call_id="1", content="pass"),
        FunctionExecutionResult(call_id="2", content="pass"),
        FunctionExecutionResult(call_id="3", content="task3"),
    ]
    for expected in expected_content:
        assert expected in result.messages[2].content
    assert result.messages[2].models_usage is None
    assert isinstance(result.messages[3], ToolCallSummaryMessage)
    assert result.messages[3].content == "pass\npass\ntask3"
    assert result.messages[3].models_usage is None

    # Test streaming.
    mock.curr_index = 0  # Reset the mock
    index = 0
    async for message in agent.run_stream(task="task"):
        if isinstance(message, TaskResult):
            assert message == result
        else:
            assert message == result.messages[index]
        index += 1

    # Test state saving and loading.
    state = await agent.save_state()
    agent2 = AssistantAgent(
        "tool_use_agent",
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
        tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
    )
    await agent2.load_state(state)
    state2 = await agent2.save_state()
    assert state == state2


@pytest.mark.asyncio
async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
handoff = Handoff(target="agent2")