Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow returning thought in the tool call responses #5173

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
Sequence,
)

from autogen_core import CancellationToken, Component, ComponentModel, FunctionCall
from autogen_core import CancellationToken, Component, ComponentModel, FunctionCall, FunctionCalls
from autogen_core.memory import Memory
from autogen_core.model_context import (
ChatCompletionContext,
Expand Down Expand Up @@ -401,9 +401,10 @@ async def on_messages_stream(
return

# Process tool calls.
assert isinstance(model_result.content, list) and all(
isinstance(item, FunctionCall) for item in model_result.content
assert isinstance(model_result.content, FunctionCalls) and all(
isinstance(item, FunctionCall) or isinstance(item, str) for item in model_result.content.function_calls
)

tool_call_msg = ToolCallRequestEvent(
content=model_result.content, source=self.name, models_usage=model_result.usage
)
Expand All @@ -412,9 +413,15 @@ async def on_messages_stream(
inner_messages.append(tool_call_msg)
yield tool_call_msg

function_calls = model_result.content.function_calls

# Execute the tool calls.
exec_results = await asyncio.gather(
*[self._execute_tool_call(call, cancellation_token) for call in model_result.content]
*[
self._execute_tool_call(call, cancellation_token)
for call in function_calls
if isinstance(call, FunctionCall)
]
)
tool_call_result_msg = ToolCallExecutionEvent(content=exec_results, source=self.name)
event_logger.debug(tool_call_result_msg)
Expand All @@ -423,7 +430,7 @@ async def on_messages_stream(
yield tool_call_result_msg

# Correlate tool call results with tool calls.
tool_calls = [call for call in model_result.content if call.name not in self._handoffs]
tool_calls = [call for call in function_calls if call.name not in self._handoffs]
tool_call_results: List[FunctionExecutionResult] = []
for tool_call in tool_calls:
found = False
Expand All @@ -436,7 +443,7 @@ async def on_messages_stream(
raise RuntimeError(f"Tool call result not found for call id: {tool_call.id}")

# Detect handoff requests.
handoff_reqs = [call for call in model_result.content if call.name in self._handoffs]
handoff_reqs = [call for call in function_calls if call.name in self._handoffs]
if len(handoff_reqs) > 0:
handoffs = [self._handoffs[call.name] for call in handoff_reqs]
if len(handoffs) > 1:
Expand All @@ -451,7 +458,12 @@ async def on_messages_stream(
# Current context for handoff.
handoff_context: List[LLMMessage] = []
if len(tool_calls) > 0:
handoff_context.append(AssistantMessage(content=tool_calls, source=self.name))
handoff_context.append(
AssistantMessage(
content=FunctionCalls(function_calls=tool_calls, thought=model_result.content.thought),
source=self.name,
)
)
handoff_context.append(FunctionExecutionResultMessage(content=tool_call_results))
# Return the output messages to signal the handoff.
yield Response(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ class and includes specific fields relevant to the type of message being sent.
from abc import ABC
from typing import List, Literal

from autogen_core import FunctionCall, Image
from autogen_core import FunctionCalls, Image
from autogen_core.memory import MemoryContent
from autogen_core.models import FunctionExecutionResult, LLMMessage, RequestUsage
from pydantic import BaseModel, ConfigDict, Field
Expand Down Expand Up @@ -83,7 +83,7 @@ class HandoffMessage(BaseChatMessage):
class ToolCallRequestEvent(BaseAgentEvent):
"""An event signaling a request to use tools."""

content: List[FunctionCall]
content: FunctionCalls
"""The tool calls."""

type: Literal["ToolCallRequestEvent"] = "ToolCallRequestEvent"
Expand Down
52 changes: 49 additions & 3 deletions python/packages/autogen-agentchat/tests/test_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import logging
from typing import Any, AsyncGenerator, List
from unittest.mock import AsyncMock

import pytest
from autogen_agentchat import EVENT_LOGGER_NAME
Expand All @@ -17,10 +18,10 @@
ToolCallRequestEvent,
ToolCallSummaryMessage,
)
from autogen_core import FunctionCall, Image
from autogen_core import FunctionCall, FunctionCalls, Image
from autogen_core.memory import ListMemory, Memory, MemoryContent, MemoryMimeType, MemoryQueryResult
from autogen_core.model_context import BufferedChatCompletionContext
from autogen_core.models import FunctionExecutionResult, LLMMessage
from autogen_core.models import CreateResult, FunctionExecutionResult, LLMMessage, RequestUsage
from autogen_core.models._model_client import ModelFamily
from autogen_core.tools import FunctionTool
from autogen_ext.models.openai import OpenAIChatCompletionClient
Expand Down Expand Up @@ -176,6 +177,51 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
assert state == state2


@pytest.mark.asyncio
async def test_run_stream_with_thought() -> None:
    """Verify that a model-produced ``thought`` attached to a batch of tool
    calls is surfaced on the streamed ``ToolCallRequestEvent``.

    The model client is mocked so that its successive ``create`` calls return:
    1. a ``FunctionCalls`` result (one ``echo`` call plus a thought),
    2. a plain-text reply ``"ok!"``,
    3. ``"TERMINATE"`` to end the run.
    """
    # AsyncMock lets `create` be awaited; side_effect yields one CreateResult
    # per model invocation, in order.
    mocked_model = AsyncMock()
    mocked_model.create.side_effect = [
        CreateResult(
            finish_reason="function_calls",
            content=FunctionCalls(
                function_calls=[FunctionCall(id="call_foo", name="echo", arguments=json.dumps({"input": "foo"}))],
                thought="going to say foo!",
            ),
            usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
            cached=False,
        ),
        CreateResult(
            finish_reason="stop",
            content="ok!",
            usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
            cached=False,
        ),
        CreateResult(
            finish_reason="stop",
            content="TERMINATE",
            usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
            cached=False,
        ),
    ]

    # NOTE(review): "thoughtfull" is a typo, but it is a runtime identifier
    # (the agent name), so it is left unchanged here.
    agent = AssistantAgent(
        "thoughtfull_tool_use_agent",
        model_client=mocked_model,
        tools=[
            FunctionTool(_echo_function, description="Echo", name="echo"),
        ],
    )

    # Expected stream: user task echo, tool-call request (carrying the
    # thought), tool execution event, tool-call summary, final TaskResult.
    streamed_messages = [item async for item in agent.run_stream(task="prompt")]
    assert len(streamed_messages) == 5
    assert isinstance(streamed_messages[0], TextMessage)
    assert isinstance(streamed_messages[1], ToolCallRequestEvent)
    assert streamed_messages[1].content.thought == "going to say foo!"
    assert isinstance(streamed_messages[2], ToolCallExecutionEvent)
    assert isinstance(streamed_messages[3], ToolCallSummaryMessage)
    assert isinstance(streamed_messages[4], TaskResult)


@pytest.mark.asyncio
async def test_run_with_tools_and_reflection(monkeypatch: pytest.MonkeyPatch) -> None:
model = "gpt-4o-2024-05-13"
Expand Down Expand Up @@ -374,7 +420,7 @@ async def test_run_with_parallel_tools(monkeypatch: pytest.MonkeyPatch) -> None:
assert isinstance(result.messages[0], TextMessage)
assert result.messages[0].models_usage is None
assert isinstance(result.messages[1], ToolCallRequestEvent)
assert result.messages[1].content == [
assert result.messages[1].content.function_calls == [
FunctionCall(id="1", arguments=r'{"input": "task1"}', name="_pass_function"),
FunctionCall(id="2", arguments=r'{"input": "task2"}', name="_pass_function"),
FunctionCall(id="3", arguments=r'{"input": "task3"}', name="_echo_function"),
Expand Down
13 changes: 8 additions & 5 deletions python/packages/autogen-agentchat/tests/test_group_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from autogen_agentchat.teams._group_chat._selector_group_chat import SelectorGroupChatManager
from autogen_agentchat.teams._group_chat._swarm_group_chat import SwarmGroupChatManager
from autogen_agentchat.ui import Console
from autogen_core import AgentId, CancellationToken, FunctionCall
from autogen_core import AgentId, CancellationToken, FunctionCall, FunctionCalls
from autogen_core.models import (
AssistantMessage,
FunctionExecutionResult,
Expand Down Expand Up @@ -220,6 +220,7 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
result_2 = await team.run(
task=MultiModalMessage(content=["Write a program that prints 'Hello, world!'"], source="user")
)
assert isinstance(result_2.messages[0].content, list)
assert result.messages[0].content == result_2.messages[0].content[0]
assert result.messages[1:] == result_2.messages[1:]

Expand Down Expand Up @@ -1065,10 +1066,12 @@ async def test_swarm_with_parallel_tool_calls(monkeypatch: pytest.MonkeyPatch) -
expected_handoff_context: List[LLMMessage] = [
AssistantMessage(
source="agent1",
content=[
FunctionCall(id="1", name="tool1", arguments="{}"),
FunctionCall(id="2", name="tool2", arguments="{}"),
],
content=FunctionCalls(
function_calls=[
FunctionCall(id="1", name="tool1", arguments="{}"),
FunctionCall(id="2", name="tool2", arguments="{}"),
]
),
),
FunctionExecutionResultMessage(
content=[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ from autogen_agentchat.messages import (
ToolCallRequestEvent,
ToolCallSummaryMessage,
)
from autogen_core import FunctionCall, Image
from autogen_core import FunctionCall, FunctionCalls, Image
from autogen_core.models import FunctionExecutionResult


Expand Down Expand Up @@ -660,7 +660,7 @@ def convert_to_v02_message(
raise ValueError(f"Invalid multimodal message content: {modal}")
elif isinstance(message, ToolCallRequestEvent):
v02_message = {"tool_calls": [], "role": "assistant", "content": None, "name": message.source}
for tool_call in message.content:
for tool_call in message.content.function_calls:
v02_message["tool_calls"].append(
{
"id": tool_call.id,
Expand Down Expand Up @@ -697,7 +697,7 @@ def convert_to_v04_message(message: Dict[str, Any]) -> AgentEvent | ChatMessage:
arguments=tool_call["function"]["args"],
)
)
return ToolCallRequestEvent(source=message["name"], content=tool_calls)
return ToolCallRequestEvent(source=message["name"], content=FunctionCalls(function_calls=tool_calls))
elif "tool_responses" in message:
tool_results: List[FunctionExecutionResult] = []
for tool_response in message["tool_responses"]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
"\n",
"from autogen_core import (\n",
" FunctionCall,\n",
" FunctionCalls,\n",
" MessageContext,\n",
" RoutedAgent,\n",
" SingleThreadedAgentRuntime,\n",
Expand Down Expand Up @@ -173,11 +174,13 @@
" )\n",
" print(f\"{'-'*80}\\n{self.id.type}:\\n{llm_result.content}\", flush=True)\n",
" # Process the LLM result.\n",
" while isinstance(llm_result.content, list) and all(isinstance(m, FunctionCall) for m in llm_result.content):\n",
" while isinstance(llm_result.content, FunctionCalls) and all(\n",
" isinstance(m, FunctionCall) for m in llm_result.content.function_calls\n",
" ):\n",
" tool_call_results: List[FunctionExecutionResult] = []\n",
" delegate_targets: List[Tuple[str, UserTask]] = []\n",
" # Process each function call.\n",
" for call in llm_result.content:\n",
" for call in llm_result.content.function_calls:\n",
" arguments = json.loads(call.arguments)\n",
" if call.name in self._tools:\n",
" # Execute the tool directly.\n",
Expand All @@ -190,7 +193,7 @@
" topic_type = self._delegate_tools[call.name].return_value_as_string(result)\n",
" # Create the context for the delegate agent, including the function call and the result.\n",
" delegate_messages = list(message.context) + [\n",
" AssistantMessage(content=[call], source=self.id.type),\n",
" AssistantMessage(content=FunctionCalls(function_calls=[call]), source=self.id.type),\n",
" FunctionExecutionResultMessage(\n",
" content=[\n",
" FunctionExecutionResult(\n",
Expand Down
3 changes: 2 additions & 1 deletion python/packages/autogen-core/src/autogen_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
from ._topic import TopicId
from ._type_prefix_subscription import TypePrefixSubscription
from ._type_subscription import TypeSubscription
from ._types import FunctionCall
from ._types import FunctionCall, FunctionCalls

EVENT_LOGGER_NAME = EVENT_LOGGER_NAME_ALIAS
"""The name of the logger used for structured events."""
Expand Down Expand Up @@ -107,6 +107,7 @@
"event",
"rpc",
"FunctionCall",
"FunctionCalls",
"TypeSubscription",
"DefaultSubscription",
"DefaultTopicId",
Expand Down
7 changes: 7 additions & 0 deletions python/packages/autogen-core/src/autogen_core/_types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Optional


@dataclass
Expand All @@ -10,3 +11,9 @@ class FunctionCall:
arguments: str
# Function to call
name: str


@dataclass
class FunctionCalls:
    """A batch of function (tool) calls produced by a model in a single
    response, together with an optional free-text "thought" the model
    emitted alongside the calls."""

    # The function calls requested by the model.
    function_calls: List[FunctionCall]
    # Optional reasoning text that accompanied the calls; None when the
    # model produced no thought.
    thought: Optional[str] = None
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing_extensions import Self

from .._component_config import Component
from .._types import FunctionCall
from .._types import FunctionCall, FunctionCalls
from ..models import AssistantMessage, FunctionExecutionResultMessage, LLMMessage, UserMessage
from ._chat_completion_context import ChatCompletionContext

Expand Down Expand Up @@ -45,8 +45,8 @@ async def get_messages(self) -> List[LLMMessage]:
if (
head_messages
and isinstance(head_messages[-1], AssistantMessage)
and isinstance(head_messages[-1].content, list)
and all(isinstance(item, FunctionCall) for item in head_messages[-1].content)
and isinstance(head_messages[-1].content, FunctionCalls)
and all(isinstance(item, FunctionCall) for item in head_messages[-1].content.function_calls)
):
# Remove the last message from the head.
head_messages = head_messages[:-1]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pydantic import BaseModel, Field
from typing_extensions import Annotated

from .. import FunctionCall, Image
from .. import FunctionCalls, Image


class SystemMessage(BaseModel):
Expand All @@ -22,7 +22,7 @@ class UserMessage(BaseModel):


class AssistantMessage(BaseModel):
content: Union[str, List[FunctionCall]]
content: Union[str, FunctionCalls]

# Name of the agent that sent this message
source: str
Expand Down Expand Up @@ -70,7 +70,7 @@ class ChatCompletionTokenLogprob(BaseModel):

class CreateResult(BaseModel):
    """Result of a single chat-completion ``create`` call."""

    # Why generation stopped (e.g. "stop", "function_calls").
    finish_reason: FinishReasons
    # Either plain assistant text, or a FunctionCalls batch
    # (tool calls plus optional thought).
    content: Union[str, FunctionCalls]
    # Token accounting for this request.
    usage: RequestUsage
    # True when this result was served from a cache rather than the model.
    cached: bool
    # Per-token log probabilities when requested; None otherwise.
    # NOTE(review): `Optional[...] | None` is redundant — Optional already
    # includes None; confirm before simplifying the annotation.
    logprobs: Optional[List[ChatCompletionTokenLogprob] | None] = None
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
from typing import List

from .. import AgentId, AgentRuntime, BaseAgent, CancellationToken, FunctionCall
from .. import AgentId, AgentRuntime, BaseAgent, CancellationToken, FunctionCall, FunctionCalls
from ..models import (
AssistantMessage,
ChatCompletionClient,
Expand Down Expand Up @@ -43,7 +43,9 @@ async def tool_agent_caller_loop(
generated_messages.append(AssistantMessage(content=response.content, source=caller_source))

# Keep iterating until the model stops generating tool calls.
while isinstance(response.content, list) and all(isinstance(item, FunctionCall) for item in response.content):
while isinstance(response.content, FunctionCalls) and all(
isinstance(item, FunctionCall) for item in response.content.function_calls
):
# Execute functions called by the model by sending messages to tool agent.
results: List[FunctionExecutionResult | BaseException] = await asyncio.gather(
*[
Expand All @@ -52,7 +54,7 @@ async def tool_agent_caller_loop(
recipient=tool_agent_id,
cancellation_token=cancellation_token,
)
for call in response.content
for call in response.content.function_calls
],
return_exceptions=True,
)
Expand Down
1 change: 1 addition & 0 deletions python/packages/autogen-core/tests/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ async def test_list_memory_update_context() -> None:
context_messages = await context.get_messages()
assert len(results.memories.results) == 2
assert len(context_messages) == 1
assert isinstance(context_messages[0].content, str)
assert "test1" in context_messages[0].content
assert "test2" in context_messages[0].content

Expand Down
Loading