feat: Add thought process handling in tool calls and expose ThoughtEvent through stream in AgentChat (#5500)

Resolves #5192

Test

```python
import asyncio
import os
from random import randint
from typing import List
from autogen_core.tools import BaseTool, FunctionTool
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.ui import Console

async def get_current_time(city: str) -> str:
    return f"The current time in {city} is {randint(0, 23)}:{randint(0, 59)}."

tools: List[BaseTool] = [
    FunctionTool(
        get_current_time,
        name="get_current_time",
        description="Get current time for a city.",
    ),
]

model_client = OpenAIChatCompletionClient(
    model="anthropic/claude-3.5-haiku-20241022",
    base_url="https://openrouter.ai/api/v1",
    api_key=os.environ["OPENROUTER_API_KEY"],
    model_info={
        "family": "claude-3.5-haiku",
        "function_calling": True,
        "vision": False,
        "json_output": False,
    }
)

agent = AssistantAgent(
    name="Agent",
    model_client=model_client,
    tools=tools,
    system_message="You are an assistant with tools that can be used to answer questions.",
)

async def main() -> None:
    await Console(agent.run_stream(task="What is current time of Paris and Toronto?"))

asyncio.run(main())
```

```
---------- user ----------
What is current time of Paris and Toronto?
---------- Agent ----------
I'll help you find the current time for Paris and Toronto by using the get_current_time function for each city.
---------- Agent ----------
[FunctionCall(id='toolu_01NwP3fNAwcYKn1x656Dq9xW', arguments='{"city": "Paris"}', name='get_current_time'), FunctionCall(id='toolu_018d4cWSy3TxXhjgmLYFrfRt', arguments='{"city": "Toronto"}', name='get_current_time')]
---------- Agent ----------
[FunctionExecutionResult(content='The current time in Paris is 1:10.', call_id='toolu_01NwP3fNAwcYKn1x656Dq9xW', is_error=False), FunctionExecutionResult(content='The current time in Toronto is 7:28.', call_id='toolu_018d4cWSy3TxXhjgmLYFrfRt', is_error=False)]
---------- Agent ----------
The current time in Paris is 1:10.
The current time in Toronto is 7:28.
```
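
As a usage note (not part of this change), the new `ThoughtEvent` can also be consumed directly from the stream instead of through `Console`. The sketch below is a minimal illustration that reuses the `agent` defined above and assumes `ThoughtEvent` is importable from `autogen_agentchat.messages` and `TaskResult` from `autogen_agentchat.base`:

```python
from autogen_agentchat.base import TaskResult
from autogen_agentchat.messages import ThoughtEvent

async def main() -> None:
    # Iterate the stream manually so thought events can be handled
    # separately from the other messages and the final TaskResult.
    async for item in agent.run_stream(task="What is current time of Paris and Toronto?"):
        if isinstance(item, TaskResult):
            print(f"Finished: {item.stop_reason}")
        elif isinstance(item, ThoughtEvent):
            print(f"[thought from {item.source}] {item.content}")
        else:
            # Most AgentChat messages carry a `content` field; fall back to
            # the repr for anything that does not.
            print(f"[{item.source}] {getattr(item, 'content', item)}")

asyncio.run(main())
```

With the transcript above, the thought text ("I'll help you find the current time...") should show up as its own `[thought ...]` line before the tool-call events.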

---------

Co-authored-by: Jack Gerrits <[email protected]>
ekzhu and jackgerrits authored Feb 21, 2025
1 parent 45c6d13 commit 7784f44
Showing 6 changed files with 184 additions and 60 deletions.
@@ -44,6 +44,7 @@
MemoryQueryEvent,
ModelClientStreamingChunkEvent,
TextMessage,
ThoughtEvent,
ToolCallExecutionEvent,
ToolCallRequestEvent,
ToolCallSummaryMessage,
@@ -418,7 +419,15 @@ async def on_messages_stream(
)

# Add the response to the model context.
await self._model_context.add_message(AssistantMessage(content=model_result.content, source=self.name))
await self._model_context.add_message(
AssistantMessage(content=model_result.content, source=self.name, thought=model_result.thought)
)

# Add thought to the inner messages.
if model_result.thought:
thought_event = ThoughtEvent(content=model_result.thought, source=self.name)
inner_messages.append(thought_event)
yield thought_event

# Check if the response is a string and return it.
if isinstance(model_result.content, str):
@@ -479,7 +488,9 @@ async def on_messages_stream(
# Current context for handoff.
handoff_context: List[LLMMessage] = []
if len(tool_calls) > 0:
handoff_context.append(AssistantMessage(content=tool_calls, source=self.name))
handoff_context.append(
AssistantMessage(content=tool_calls, source=self.name, thought=model_result.thought)
)
handoff_context.append(FunctionExecutionResultMessage(content=tool_call_results))
# Return the output messages to signal the handoff.
yield Response(
@@ -515,7 +526,9 @@ async def on_messages_stream(
assert isinstance(reflection_model_result.content, str)
# Add the response to the model context.
await self._model_context.add_message(
AssistantMessage(content=reflection_model_result.content, source=self.name)
AssistantMessage(
content=reflection_model_result.content, source=self.name, thought=reflection_model_result.thought
)
)
# Yield the response.
yield Response(
@@ -137,6 +137,17 @@ class ModelClientStreamingChunkEvent(BaseAgentEvent):
type: Literal["ModelClientStreamingChunkEvent"] = "ModelClientStreamingChunkEvent"


class ThoughtEvent(BaseAgentEvent):
"""An event signaling the thought process of an agent.
It is used to communicate the reasoning tokens generated by a reasoning model,
or the extra text content generated by a function call."""

content: str
"""The thought process."""

type: Literal["ThoughtEvent"] = "ThoughtEvent"


ChatMessage = Annotated[
TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type")
]
@@ -148,7 +159,8 @@ class ModelClientStreamingChunkEvent(BaseAgentEvent):
| ToolCallExecutionEvent
| MemoryQueryEvent
| UserInputRequestedEvent
| ModelClientStreamingChunkEvent,
| ModelClientStreamingChunkEvent
| ThoughtEvent,
Field(discriminator="type"),
]
"""Events emitted by agents and teams when they work, not used for agent-to-agent communication."""
@@ -168,4 +180,5 @@ class ModelClientStreamingChunkEvent(BaseAgentEvent):
"MemoryQueryEvent",
"UserInputRequestedEvent",
"ModelClientStreamingChunkEvent",
"ThoughtEvent",
]
49 changes: 27 additions & 22 deletions python/packages/autogen-agentchat/tests/test_assistant_agent.py
@@ -17,6 +17,7 @@
ToolCallExecutionEvent,
ToolCallRequestEvent,
ToolCallSummaryMessage,
ThoughtEvent,
)
from autogen_core import ComponentModel, FunctionCall, Image
from autogen_core.memory import ListMemory, Memory, MemoryContent, MemoryMimeType, MemoryQueryResult
@@ -89,7 +90,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
finish_reason="tool_calls",
index=0,
message=ChatCompletionMessage(
content=None,
content="Calling pass function",
tool_calls=[
ChatCompletionMessageToolCall(
id="1",
@@ -151,18 +152,20 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
)
result = await agent.run(task="task")

assert len(result.messages) == 4
assert len(result.messages) == 5
assert isinstance(result.messages[0], TextMessage)
assert result.messages[0].models_usage is None
assert isinstance(result.messages[1], ToolCallRequestEvent)
assert result.messages[1].models_usage is not None
assert result.messages[1].models_usage.completion_tokens == 5
assert result.messages[1].models_usage.prompt_tokens == 10
assert isinstance(result.messages[2], ToolCallExecutionEvent)
assert result.messages[2].models_usage is None
assert isinstance(result.messages[3], ToolCallSummaryMessage)
assert result.messages[3].content == "pass"
assert isinstance(result.messages[1], ThoughtEvent)
assert result.messages[1].content == "Calling pass function"
assert isinstance(result.messages[2], ToolCallRequestEvent)
assert result.messages[2].models_usage is not None
assert result.messages[2].models_usage.completion_tokens == 5
assert result.messages[2].models_usage.prompt_tokens == 10
assert isinstance(result.messages[3], ToolCallExecutionEvent)
assert result.messages[3].models_usage is None
assert isinstance(result.messages[4], ToolCallSummaryMessage)
assert result.messages[4].content == "pass"
assert result.messages[4].models_usage is None

# Test streaming.
mock.curr_index = 0 # Reset the mock
@@ -302,7 +305,7 @@ async def test_run_with_parallel_tools(monkeypatch: pytest.MonkeyPatch) -> None:
finish_reason="tool_calls",
index=0,
message=ChatCompletionMessage(
content=None,
content="Calling pass and echo functions",
tool_calls=[
ChatCompletionMessageToolCall(
id="1",
@@ -380,30 +383,32 @@ async def test_run_with_parallel_tools(monkeypatch: pytest.MonkeyPatch) -> None:
)
result = await agent.run(task="task")

assert len(result.messages) == 4
assert len(result.messages) == 5
assert isinstance(result.messages[0], TextMessage)
assert result.messages[0].models_usage is None
assert isinstance(result.messages[1], ToolCallRequestEvent)
assert result.messages[1].content == [
assert isinstance(result.messages[1], ThoughtEvent)
assert result.messages[1].content == "Calling pass and echo functions"
assert isinstance(result.messages[2], ToolCallRequestEvent)
assert result.messages[2].content == [
FunctionCall(id="1", arguments=r'{"input": "task1"}', name="_pass_function"),
FunctionCall(id="2", arguments=r'{"input": "task2"}', name="_pass_function"),
FunctionCall(id="3", arguments=r'{"input": "task3"}', name="_echo_function"),
]
assert result.messages[1].models_usage is not None
assert result.messages[1].models_usage.completion_tokens == 5
assert result.messages[1].models_usage.prompt_tokens == 10
assert isinstance(result.messages[2], ToolCallExecutionEvent)
assert result.messages[2].models_usage is not None
assert result.messages[2].models_usage.completion_tokens == 5
assert result.messages[2].models_usage.prompt_tokens == 10
assert isinstance(result.messages[3], ToolCallExecutionEvent)
expected_content = [
FunctionExecutionResult(call_id="1", content="pass", is_error=False),
FunctionExecutionResult(call_id="2", content="pass", is_error=False),
FunctionExecutionResult(call_id="3", content="task3", is_error=False),
]
for expected in expected_content:
assert expected in result.messages[2].content
assert result.messages[2].models_usage is None
assert isinstance(result.messages[3], ToolCallSummaryMessage)
assert result.messages[3].content == "pass\npass\ntask3"
assert expected in result.messages[3].content
assert result.messages[3].models_usage is None
assert isinstance(result.messages[4], ToolCallSummaryMessage)
assert result.messages[4].content == "pass\npass\ntask3"
assert result.messages[4].models_usage is None

# Test streaming.
mock.curr_index = 0 # Reset the mock
@@ -44,6 +44,9 @@ class AssistantMessage(BaseModel):
content: Union[str, List[FunctionCall]]
"""The content of the message."""

thought: str | None = None
"""The reasoning text for the completion if available. Used for reasoning model and additional text content besides function calls."""

source: str
"""The name of the agent that sent this message."""

@@ -208,11 +208,19 @@ def assistant_message_to_oai(
) -> ChatCompletionAssistantMessageParam:
assert_valid_name(message.source)
if isinstance(message.content, list):
return ChatCompletionAssistantMessageParam(
tool_calls=[func_call_to_oai(x) for x in message.content],
role="assistant",
name=message.source,
)
if message.thought is not None:
return ChatCompletionAssistantMessageParam(
content=message.thought,
tool_calls=[func_call_to_oai(x) for x in message.content],
role="assistant",
name=message.source,
)
else:
return ChatCompletionAssistantMessageParam(
tool_calls=[func_call_to_oai(x) for x in message.content],
role="assistant",
name=message.source,
)
else:
return ChatCompletionAssistantMessageParam(
content=message.content,
@@ -572,6 +580,7 @@ async def create(
# Detect whether it is a function call or not.
# We don't rely on choice.finish_reason as it is not always accurate, depending on the API used.
content: Union[str, List[FunctionCall]]
thought: str | None = None
if choice.message.function_call is not None:
raise ValueError("function_call is deprecated and is not supported by this model client.")
elif choice.message.tool_calls is not None and len(choice.message.tool_calls) > 0:
@@ -583,11 +592,8 @@
stacklevel=2,
)
if choice.message.content is not None and choice.message.content != "":
warnings.warn(
"Both tool_calls and content are present in the message. "
"This is unexpected. content will be ignored, tool_calls will be used.",
stacklevel=2,
)
# Put the content in the thought field.
thought = choice.message.content
# NOTE: If OAI response type changes, this will need to be updated
content = []
for tool_call in choice.message.tool_calls:
@@ -626,8 +632,6 @@

if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
thought, content = parse_r1_content(content)
else:
thought = None

response = CreateResult(
finish_reason=normalize_stop_reason(finish_reason),
@@ -788,6 +792,8 @@ async def create_stream(
content_deltas.append(choice.delta.content)
if len(choice.delta.content) > 0:
yield choice.delta.content
# NOTE: for OpenAI, tool_calls and content are mutually exclusive it seems, so we can skip the rest of the loop.
# However, this may not be the case for other APIs -- we should expect this may need to be updated.
continue

# Otherwise, get tool calls
@@ -832,20 +838,24 @@
raise ValueError("Function calls are not supported in this context")

content: Union[str, List[FunctionCall]]
if len(content_deltas) > 1:
thought: str | None = None
if full_tool_calls:
# This is a tool call.
content = list(full_tool_calls.values())
if len(content_deltas) > 1:
# Put additional text content in the thought field.
thought = "".join(content_deltas)
elif len(content_deltas) > 0:
# This is a text-only content.
content = "".join(content_deltas)
if chunk and chunk.usage:
completion_tokens = chunk.usage.completion_tokens
else:
completion_tokens = 0
else:
warnings.warn("No text content or tool calls are available. Model returned empty result.", stacklevel=2)
content = ""

if chunk and chunk.usage:
completion_tokens = chunk.usage.completion_tokens
else:
completion_tokens = 0
# TODO: fix assumption that dict values were added in order and actually order by int index
# for tool_call in full_tool_calls.values():
# # value = json.dumps(tool_call)
# # completion_tokens += count_token(value, model=model)
# completion_tokens += 0
content = list(full_tool_calls.values())

usage = RequestUsage(
prompt_tokens=prompt_tokens,
Expand All @@ -854,8 +864,6 @@ async def create_stream(

if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
thought, content = parse_r1_content(content)
else:
thought = None

result = CreateResult(
finish_reason=normalize_stop_reason(stop_reason),