Skip to content

Commit

Permalink
Define AgentEvent, rename tool call messages to events. (#4750)
Browse files Browse the repository at this point in the history
* Define AgentEvent, rename tool call messages to events.

* update doc

* Use AgentEvent | ChatMessage to replace AgentMessage

* Update docs

* update deprecation notice

* remove unused

* fix doc

* format
  • Loading branch information
ekzhu authored Dec 18, 2024
1 parent 7a7eb74 commit e902e94
Show file tree
Hide file tree
Showing 34 changed files with 3,642 additions and 3,615 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@
from ..base import Handoff as HandoffBase
from ..base import Response
from ..messages import (
AgentMessage,
AgentEvent,
ChatMessage,
HandoffMessage,
MultiModalMessage,
TextMessage,
ToolCallMessage,
ToolCallResultMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
)
from ..state import AssistantAgentState
from ._base_chat_agent import BaseChatAgent
Expand Down Expand Up @@ -292,15 +292,15 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token:

async def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[AgentMessage | Response, None]:
) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
# Add messages to the model context.
for msg in messages:
if isinstance(msg, MultiModalMessage) and self._model_client.capabilities["vision"] is False:
raise ValueError("The model does not support vision.")
self._model_context.append(UserMessage(content=msg.content, source=msg.source))

# Inner messages.
inner_messages: List[AgentMessage] = []
inner_messages: List[AgentEvent | ChatMessage] = []

# Generate an inference result based on the current model context.
llm_messages = self._system_messages + self._model_context
Expand All @@ -321,15 +321,15 @@ async def on_messages_stream(

# Process tool calls.
assert isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content)
tool_call_msg = ToolCallMessage(content=result.content, source=self.name, models_usage=result.usage)
tool_call_msg = ToolCallRequestEvent(content=result.content, source=self.name, models_usage=result.usage)
event_logger.debug(tool_call_msg)
# Add the tool call message to the output.
inner_messages.append(tool_call_msg)
yield tool_call_msg

# Execute the tool calls.
results = await asyncio.gather(*[self._execute_tool_call(call, cancellation_token) for call in result.content])
tool_call_result_msg = ToolCallResultMessage(content=results, source=self.name)
tool_call_result_msg = ToolCallExecutionEvent(content=results, source=self.name)
event_logger.debug(tool_call_result_msg)
self._model_context.append(FunctionExecutionResultMessage(content=results))
inner_messages.append(tool_call_result_msg)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from ..base import ChatAgent, Response, TaskResult
from ..messages import (
AgentMessage,
AgentEvent,
ChatMessage,
TextMessage,
)
Expand Down Expand Up @@ -58,7 +58,7 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token:

async def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[AgentMessage | Response, None]:
) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
"""Handles incoming messages and returns a stream of messages and
and the final item is the response. The base implementation in
:class:`BaseChatAgent` simply calls :meth:`on_messages` and yields
Expand Down Expand Up @@ -89,7 +89,7 @@ async def run(
if cancellation_token is None:
cancellation_token = CancellationToken()
input_messages: List[ChatMessage] = []
output_messages: List[AgentMessage] = []
output_messages: List[AgentEvent | ChatMessage] = []
if task is None:
pass
elif isinstance(task, str):
Expand Down Expand Up @@ -119,13 +119,13 @@ async def run_stream(
*,
task: str | ChatMessage | List[ChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
) -> AsyncGenerator[AgentEvent | ChatMessage | TaskResult, None]:
"""Run the agent with the given task and return a stream of messages
and the final task result as the last item in the stream."""
if cancellation_token is None:
cancellation_token = CancellationToken()
input_messages: List[ChatMessage] = []
output_messages: List[AgentMessage] = []
output_messages: List[AgentEvent | ChatMessage] = []
if task is None:
pass
elif isinstance(task, str):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from ..base import TaskResult, Team
from ..messages import (
AgentMessage,
AgentEvent,
ChatMessage,
HandoffMessage,
MultiModalMessage,
Expand Down Expand Up @@ -119,13 +119,13 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token:

async def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[AgentMessage | Response, None]:
) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
# Prepare the task for the team of agents.
task = list(messages)

# Run the team of agents.
result: TaskResult | None = None
inner_messages: List[AgentMessage] = []
inner_messages: List[AgentEvent | ChatMessage] = []
count = 0
async for inner_msg in self._team.run_stream(task=task, cancellation_token=cancellation_token):
if isinstance(inner_msg, TaskResult):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from autogen_core import CancellationToken

from ..messages import AgentMessage, ChatMessage
from ..messages import AgentEvent, ChatMessage
from ._task import TaskRunner


Expand All @@ -14,7 +14,7 @@ class Response:
chat_message: ChatMessage
"""A chat message produced by the agent as the response."""

inner_messages: List[AgentMessage] | None = None
inner_messages: List[AgentEvent | ChatMessage] | None = None
"""Inner messages produced by the agent."""


Expand Down Expand Up @@ -46,7 +46,7 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token:

def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[AgentMessage | Response, None]:
) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
"""Handles incoming messages and returns a stream of inner messages and
and the final item is the response."""
...
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@

from autogen_core import CancellationToken

from ..messages import AgentMessage, ChatMessage
from ..messages import AgentEvent, ChatMessage


@dataclass
class TaskResult:
"""Result of running a task."""

messages: Sequence[AgentMessage]
messages: Sequence[AgentEvent | ChatMessage]
"""Messages produced by the task."""

stop_reason: str | None = None
Expand Down Expand Up @@ -38,7 +38,7 @@ def run_stream(
*,
task: str | ChatMessage | List[ChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
) -> AsyncGenerator[AgentEvent | ChatMessage | TaskResult, None]:
"""Run the task and produces a stream of messages and the final result
:class:`TaskResult` as the last item in the stream.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from abc import ABC, abstractmethod
from typing import List, Sequence

from ..messages import AgentMessage, StopMessage
from ..messages import AgentEvent, ChatMessage, StopMessage


class TerminatedException(BaseException): ...
Expand Down Expand Up @@ -50,7 +50,7 @@ def terminated(self) -> bool:
...

@abstractmethod
async def __call__(self, messages: Sequence[AgentMessage]) -> StopMessage | None:
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
"""Check if the conversation should be terminated based on the messages received
since the last time the condition was called.
Return a StopMessage if the conversation should be terminated, or None otherwise.
Expand Down Expand Up @@ -88,7 +88,7 @@ def __init__(self, *conditions: TerminationCondition) -> None:
def terminated(self) -> bool:
return all(condition.terminated for condition in self._conditions)

async def __call__(self, messages: Sequence[AgentMessage]) -> StopMessage | None:
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
if self.terminated:
raise TerminatedException("Termination condition has already been reached.")
# Check all remaining conditions.
Expand Down Expand Up @@ -120,7 +120,7 @@ def __init__(self, *conditions: TerminationCondition) -> None:
def terminated(self) -> bool:
return any(condition.terminated for condition in self._conditions)

async def __call__(self, messages: Sequence[AgentMessage]) -> StopMessage | None:
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
if self.terminated:
raise RuntimeError("Termination condition has already been reached")
stop_messages = await asyncio.gather(*[condition(messages) for condition in self._conditions])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import List, Sequence

from ..base import TerminatedException, TerminationCondition
from ..messages import AgentMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage
from ..messages import AgentEvent, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage


class StopMessageTermination(TerminationCondition):
Expand All @@ -15,7 +15,7 @@ def __init__(self) -> None:
def terminated(self) -> bool:
return self._terminated

async def __call__(self, messages: Sequence[AgentMessage]) -> StopMessage | None:
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
if self._terminated:
raise TerminatedException("Termination condition has already been reached")
for message in messages:
Expand Down Expand Up @@ -43,7 +43,7 @@ def __init__(self, max_messages: int) -> None:
def terminated(self) -> bool:
return self._message_count >= self._max_messages

async def __call__(self, messages: Sequence[AgentMessage]) -> StopMessage | None:
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
if self.terminated:
raise TerminatedException("Termination condition has already been reached")
self._message_count += len(messages)
Expand Down Expand Up @@ -73,7 +73,7 @@ def __init__(self, text: str) -> None:
def terminated(self) -> bool:
return self._terminated

async def __call__(self, messages: Sequence[AgentMessage]) -> StopMessage | None:
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
if self._terminated:
raise TerminatedException("Termination condition has already been reached")
for message in messages:
Expand Down Expand Up @@ -128,7 +128,7 @@ def terminated(self) -> bool:
or (self._max_completion_token is not None and self._completion_token_count >= self._max_completion_token)
)

async def __call__(self, messages: Sequence[AgentMessage]) -> StopMessage | None:
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
if self.terminated:
raise TerminatedException("Termination condition has already been reached")
for message in messages:
Expand Down Expand Up @@ -163,7 +163,7 @@ def __init__(self, target: str) -> None:
def terminated(self) -> bool:
return self._terminated

async def __call__(self, messages: Sequence[AgentMessage]) -> StopMessage | None:
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
if self._terminated:
raise TerminatedException("Termination condition has already been reached")
for message in messages:
Expand Down Expand Up @@ -194,7 +194,7 @@ def __init__(self, timeout_seconds: float) -> None:
def terminated(self) -> bool:
return self._terminated

async def __call__(self, messages: Sequence[AgentMessage]) -> StopMessage | None:
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
if self._terminated:
raise TerminatedException("Termination condition has already been reached")

Expand Down Expand Up @@ -242,7 +242,7 @@ def set(self) -> None:
"""Set the termination condition to terminated."""
self._setted = True

async def __call__(self, messages: Sequence[AgentMessage]) -> StopMessage | None:
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
if self._terminated:
raise TerminatedException("Termination condition has already been reached")
if self._setted:
Expand Down Expand Up @@ -273,7 +273,7 @@ def __init__(self, sources: List[str]) -> None:
def terminated(self) -> bool:
return self._terminated

async def __call__(self, messages: Sequence[AgentMessage]) -> StopMessage | None:
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
if self._terminated:
raise TerminatedException("Termination condition has already been reached")
if not messages:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from autogen_core import FunctionCall, Image
from autogen_core.models import FunctionExecutionResult, RequestUsage
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated
from typing_extensions import Annotated, deprecated


class BaseMessage(BaseModel):
Expand Down Expand Up @@ -63,6 +63,7 @@ class HandoffMessage(BaseMessage):
type: Literal["HandoffMessage"] = "HandoffMessage"


@deprecated("Will be removed in 0.4.0, use ToolCallRequestEvent instead.")
class ToolCallMessage(BaseMessage):
"""A message signaling the use of tools."""

Expand All @@ -72,6 +73,7 @@ class ToolCallMessage(BaseMessage):
type: Literal["ToolCallMessage"] = "ToolCallMessage"


@deprecated("Will be removed in 0.4.0, use ToolCallExecutionEvent instead.")
class ToolCallResultMessage(BaseMessage):
"""A message signaling the results of tool calls."""

Expand All @@ -81,15 +83,37 @@ class ToolCallResultMessage(BaseMessage):
type: Literal["ToolCallResultMessage"] = "ToolCallResultMessage"


class ToolCallRequestEvent(BaseMessage):
"""An event signaling a request to use tools."""

content: List[FunctionCall]
"""The tool calls."""

type: Literal["ToolCallRequestEvent"] = "ToolCallRequestEvent"


class ToolCallExecutionEvent(BaseMessage):
"""An event signaling the execution of tool calls."""

content: List[FunctionExecutionResult]
"""The tool call results."""

type: Literal["ToolCallExecutionEvent"] = "ToolCallExecutionEvent"


ChatMessage = Annotated[TextMessage | MultiModalMessage | StopMessage | HandoffMessage, Field(discriminator="type")]
"""Messages for agent-to-agent communication."""
"""Messages for agent-to-agent communication only."""


AgentEvent = Annotated[ToolCallRequestEvent | ToolCallExecutionEvent, Field(discriminator="type")]
"""Events emitted by agents and teams when they work, not used for agent-to-agent communication."""


AgentMessage = Annotated[
TextMessage | MultiModalMessage | StopMessage | HandoffMessage | ToolCallMessage | ToolCallResultMessage,
TextMessage | MultiModalMessage | StopMessage | HandoffMessage | ToolCallRequestEvent | ToolCallExecutionEvent,
Field(discriminator="type"),
]
"""All message types."""
"""(Deprecated, will be removed in 0.4.0) All message and event types."""


__all__ = [
Expand All @@ -98,8 +122,11 @@ class ToolCallResultMessage(BaseMessage):
"MultiModalMessage",
"StopMessage",
"HandoffMessage",
"ToolCallRequestEvent",
"ToolCallExecutionEvent",
"ToolCallMessage",
"ToolCallResultMessage",
"ChatMessage",
"AgentEvent",
"AgentMessage",
]
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pydantic import BaseModel, Field

from ..messages import (
AgentMessage,
AgentEvent,
ChatMessage,
)

Expand Down Expand Up @@ -36,7 +36,7 @@ class TeamState(BaseState):
class BaseGroupChatManagerState(BaseState):
"""Base state for all group chat managers."""

message_thread: List[AgentMessage] = Field(default_factory=list)
message_thread: List[AgentEvent | ChatMessage] = Field(default_factory=list)
current_turn: int = Field(default=0)
type: str = Field(default="BaseGroupChatManagerState")

Expand Down
Loading

0 comments on commit e902e94

Please sign in to comment.