From 18efc2314a96d28841645a6585095e495b3c1ebf Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Sat, 28 Sep 2024 08:40:13 -0700 Subject: [PATCH] Use agentchat message types rather than core's model client message types (#662) * Use agentchat message types rather than core's model client message types * Merge remote-tracking branch 'origin/main' into ekzhu-tool-use-assistant --- .../src/autogen_agentchat/agents/__init__.py | 19 +++++-- .../agents/_base_chat_agent.py | 52 ++++++++++++++++--- .../{coding => }/_code_executor_agent.py | 14 ++--- .../{coding => }/_coding_assistant_agent.py | 34 +++++++----- .../agents/coding/__init__.py | 0 .../group_chat/_base_chat_agent_container.py | 18 ++++--- .../group_chat/_base_group_chat_manager.py | 24 +++++---- .../teams/group_chat/_events.py | 21 ++++++++ .../teams/group_chat/_messages.py | 25 --------- .../group_chat/_round_robin_group_chat.py | 13 +++-- .../_round_robin_group_chat_manager.py | 5 +- 11 files changed, 142 insertions(+), 83 deletions(-) rename python/packages/autogen-agentchat/src/autogen_agentchat/agents/{coding => }/_code_executor_agent.py (70%) rename python/packages/autogen-agentchat/src/autogen_agentchat/agents/{coding => }/_coding_assistant_agent.py (72%) delete mode 100644 python/packages/autogen-agentchat/src/autogen_agentchat/agents/coding/__init__.py create mode 100644 python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_events.py delete mode 100644 python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_messages.py diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py index f47329ac2541..e466164d944c 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py @@ -1,10 +1,23 @@ -from ._base_chat_agent import BaseChatAgent, ChatMessage -from .coding._code_executor_agent import CodeExecutorAgent -from .coding._coding_assistant_agent import CodingAssistantAgent +from ._base_chat_agent import ( + BaseChatAgent, + ChatMessage, + MultiModalMessage, + StopMessage, + TextMessage, + ToolCallMessage, + ToolCallResultMessage, +) +from ._code_executor_agent import CodeExecutorAgent +from ._coding_assistant_agent import CodingAssistantAgent __all__ = [ "BaseChatAgent", "ChatMessage", + "TextMessage", + "MultiModalMessage", + "ToolCallMessage", + "ToolCallResultMessage", + "StopMessage", "CodeExecutorAgent", "CodingAssistantAgent", ] diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py index 8199ebabbc93..eb3fc875f6e6 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py @@ -1,20 +1,56 @@ from abc import ABC, abstractmethod -from typing import Sequence +from typing import List, Sequence from autogen_core.base import CancellationToken -from autogen_core.components.models import AssistantMessage, UserMessage +from autogen_core.components import FunctionCall, Image +from autogen_core.components.models import FunctionExecutionResult from pydantic import BaseModel -class ChatMessage(BaseModel): - """A chat message from a user or agent.""" +class BaseMessage(BaseModel): + """A base message.""" - content: UserMessage | AssistantMessage + source: str + """The name of the agent that sent this message.""" + + +class TextMessage(BaseMessage): + """A text message.""" + + content: str """The content of the message.""" - request_pause: bool - """A flag indicating whether the current conversation session should be - paused after processing this message.""" + +class MultiModalMessage(BaseMessage): + """A multimodal message.""" + + content: List[str | Image] + """The content of the message.""" + + +class ToolCallMessage(BaseMessage): + """A message containing a list of function calls.""" + + content: List[FunctionCall] + """The list of function calls.""" + + +class ToolCallResultMessage(BaseMessage): + """A message containing the results of function calls.""" + + content: List[FunctionExecutionResult] + """The list of function execution results.""" + + +class StopMessage(BaseMessage): + """A message requesting stop of a conversation.""" + + content: str + """The content for the stop message.""" + + +ChatMessage = TextMessage | MultiModalMessage | ToolCallMessage | ToolCallResultMessage | StopMessage +"""A message used by agents in a team.""" class BaseChatAgent(ABC): diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/coding/_code_executor_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_code_executor_agent.py similarity index 70% rename from python/packages/autogen-agentchat/src/autogen_agentchat/agents/coding/_code_executor_agent.py rename to python/packages/autogen-agentchat/src/autogen_agentchat/agents/_code_executor_agent.py index 0b3e47297515..7cdc182f1ee7 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/coding/_code_executor_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_code_executor_agent.py @@ -2,9 +2,8 @@ from autogen_core.base import CancellationToken from autogen_core.components.code_executor import CodeBlock, CodeExecutor, extract_markdown_code_blocks -from autogen_core.components.models import UserMessage -from .._base_chat_agent import BaseChatAgent, ChatMessage +from ._base_chat_agent import BaseChatAgent, ChatMessage, TextMessage class CodeExecutorAgent(BaseChatAgent): @@ -21,14 +20,11 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: # Extract code blocks from the messages. code_blocks: List[CodeBlock] = [] for msg in messages: - if isinstance(msg.content, UserMessage) and isinstance(msg.content.content, str): - code_blocks.extend(extract_markdown_code_blocks(msg.content.content)) + if isinstance(msg, TextMessage): + code_blocks.extend(extract_markdown_code_blocks(msg.content)) if code_blocks: # Execute the code blocks. result = await self._code_executor.execute_code_blocks(code_blocks, cancellation_token=cancellation_token) - return ChatMessage(content=UserMessage(content=result.output, source=self.name), request_pause=False) + return TextMessage(content=result.output, source=self.name) else: - return ChatMessage( - content=UserMessage(content="No code blocks found in the thread.", source=self.name), - request_pause=False, - ) + return TextMessage(content="No code blocks found in the thread.", source=self.name) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/coding/_coding_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_coding_assistant_agent.py similarity index 72% rename from python/packages/autogen-agentchat/src/autogen_agentchat/agents/coding/_coding_assistant_agent.py rename to python/packages/autogen-agentchat/src/autogen_agentchat/agents/_coding_assistant_agent.py index 02e503cfa63d..2fa737c7e582 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/coding/_coding_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_coding_assistant_agent.py @@ -1,9 +1,15 @@ from typing import List, Sequence from autogen_core.base import CancellationToken -from autogen_core.components.models import AssistantMessage, ChatCompletionClient, SystemMessage, UserMessage +from autogen_core.components.models import ( + AssistantMessage, + ChatCompletionClient, + LLMMessage, + SystemMessage, + UserMessage, +) -from .._base_chat_agent import BaseChatAgent, ChatMessage +from ._base_chat_agent import BaseChatAgent, ChatMessage, MultiModalMessage, StopMessage, TextMessage class CodingAssistantAgent(BaseChatAgent): @@ -27,22 +33,26 @@ def __init__(self, name: str, model_client: ChatCompletionClient): super().__init__(name=name, description=self.DESCRIPTION) self._model_client = model_client self._system_messages = [SystemMessage(content=self.SYSTEM_MESSAGE)] - self._message_thread: List[UserMessage | AssistantMessage] = [] + self._model_context: List[LLMMessage] = [] async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage: - # Add messages to the thread. + # Add messages to the model context and detect stopping. for msg in messages: - self._message_thread.append(msg.content) + if not isinstance(msg, TextMessage | MultiModalMessage | StopMessage): + raise ValueError(f"Unsupported message type: {type(msg)}") + self._model_context.append(UserMessage(content=msg.content, source=msg.source)) - # Generate an inference result based on the thread. - llm_messages = self._system_messages + self._message_thread + # Generate an inference result based on the current model context. + llm_messages = self._system_messages + self._model_context result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token) assert isinstance(result.content, str) - # Add the response to the thread. - self._message_thread.append(AssistantMessage(content=result.content, source=self.name)) + # Add the response to the model context. + self._model_context.append(AssistantMessage(content=result.content, source=self.name)) - # Detect pause request. - request_pause = "terminate" in result.content.strip().lower() + # Detect stop request. + request_stop = "terminate" in result.content.strip().lower() + if request_stop: + return StopMessage(content=result.content, source=self.name) - return ChatMessage(content=UserMessage(content=result.content, source=self.name), request_pause=request_pause) + return TextMessage(content=result.content, source=self.name) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/coding/__init__.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/coding/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_base_chat_agent_container.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_base_chat_agent_container.py index 95e19eefcdcc..4260009e024b 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_base_chat_agent_container.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_base_chat_agent_container.py @@ -3,8 +3,8 @@ from autogen_core.base import MessageContext from autogen_core.components import DefaultTopicId, RoutedAgent, event -from ...agents import BaseChatAgent, ChatMessage -from ._messages import ContentPublishEvent, ContentRequestEvent +from ...agents import BaseChatAgent, MultiModalMessage, StopMessage, TextMessage +from ._events import ContentPublishEvent, ContentRequestEvent class BaseChatAgentContainer(RoutedAgent): @@ -21,20 +21,26 @@ def __init__(self, parent_topic_type: str, agent: BaseChatAgent) -> None: super().__init__(description=agent.description) self._parent_topic_type = parent_topic_type self._agent = agent - self._message_buffer: List[ChatMessage] = [] + self._message_buffer: List[TextMessage | MultiModalMessage | StopMessage] = [] @event async def handle_content_publish(self, message: ContentPublishEvent, ctx: MessageContext) -> None: """Handle a content publish event by appending the content to the buffer.""" - self._message_buffer.append(ChatMessage(content=message.content, request_pause=message.request_pause)) + if not isinstance(message.agent_message, TextMessage | MultiModalMessage | StopMessage): + raise ValueError( + f"Unexpected message type: {type(message.agent_message)}. " + "The message must be a text, multimodal, or stop message." + ) + self._message_buffer.append(message.agent_message) @event async def handle_content_request(self, message: ContentRequestEvent, ctx: MessageContext) -> None: """Handle a content request event by passing the messages in the buffer to the delegate agent and publish the response.""" response = await self._agent.on_messages(self._message_buffer, ctx.cancellation_token) + # TODO: handle tool call messages. + assert isinstance(response, TextMessage | MultiModalMessage | StopMessage) self._message_buffer.clear() await self.publish_message( - ContentPublishEvent(content=response.content, request_pause=response.request_pause), - topic_id=DefaultTopicId(type=self._parent_topic_type), + ContentPublishEvent(agent_message=response), topic_id=DefaultTopicId(type=self._parent_topic_type) ) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_base_group_chat_manager.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_base_group_chat_manager.py index 5bd4a32a88a7..6fb7847ad289 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_base_group_chat_manager.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_base_group_chat_manager.py @@ -3,9 +3,9 @@ from autogen_core.base import MessageContext, TopicId from autogen_core.components import RoutedAgent, event -from autogen_core.components.models import AssistantMessage, UserMessage -from ._messages import ContentPublishEvent, ContentRequestEvent +from ...agents import MultiModalMessage, StopMessage, TextMessage +from ._events import ContentPublishEvent, ContentRequestEvent class BaseGroupChatManager(RoutedAgent): @@ -47,7 +47,7 @@ def __init__( raise ValueError("The group topic type must not be the same as the parent topic type.") self._participant_topic_types = participant_topic_types self._participant_descriptions = participant_descriptions - self._message_thread: List[UserMessage | AssistantMessage] = [] + self._message_thread: List[TextMessage | MultiModalMessage | StopMessage] = [] @event async def handle_content_publish(self, message: ContentPublishEvent, ctx: MessageContext) -> None: @@ -61,23 +61,27 @@ async def handle_content_publish(self, message: ContentPublishEvent, ctx: Messag group_chat_topic_id = TopicId(type=self._group_topic_type, source=ctx.topic_id.source) # TODO: use something else other than print. - assert isinstance(message.content, UserMessage) or isinstance(message.content, AssistantMessage) - sys.stdout.write(f"{'-'*80}\n{message.content.source}:\n{message.content.content}\n") + sys.stdout.write(f"{'-'*80}\n{message.agent_message.source}:\n{message.agent_message.content}\n") # Process event from parent. if ctx.topic_id.type == self._parent_topic_type: - self._message_thread.append(message.content) + self._message_thread.append(message.agent_message) await self.publish_message(message, topic_id=group_chat_topic_id) return # Process event from the group chat this agent manages. assert ctx.topic_id.type == self._group_topic_type - self._message_thread.append(message.content) + self._message_thread.append(message.agent_message) - if message.request_pause: + # If the message is a stop message, publish the last message as a TextMessage to the parent topic. + # TODO: custom handling the final message. + if isinstance(message.agent_message, StopMessage): parent_topic_id = TopicId(type=self._parent_topic_type, source=ctx.topic_id.source) await self.publish_message( - ContentPublishEvent(content=message.content, request_pause=True), topic_id=parent_topic_id + ContentPublishEvent( + agent_message=TextMessage(content=message.agent_message.content, source=self.metadata["type"]) + ), + topic_id=parent_topic_id, ) return @@ -100,7 +104,7 @@ async def handle_content_request(self, message: ContentRequestEvent, ctx: Messag participant_topic_id = TopicId(type=speaker_topic_type, source=ctx.topic_id.source) await self.publish_message(ContentRequestEvent(), topic_id=participant_topic_id) - async def select_speaker(self, thread: List[UserMessage | AssistantMessage]) -> str: + async def select_speaker(self, thread: List[TextMessage | MultiModalMessage | StopMessage]) -> str: """Select a speaker from the participants and return the topic type of the selected speaker.""" raise NotImplementedError("Method not implemented") diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_events.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_events.py new file mode 100644 index 000000000000..70ab47e4b089 --- /dev/null +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_events.py @@ -0,0 +1,21 @@ +from pydantic import BaseModel + +from ...agents import MultiModalMessage, StopMessage, TextMessage + + +class ContentPublishEvent(BaseModel): + """An event for sharing some data. Agents receive this event should + update their internal state (e.g., append to message history) with the + content of the event. + """ + + agent_message: TextMessage | MultiModalMessage | StopMessage + """The message published by the agent.""" + + +class ContentRequestEvent(BaseModel): + """An event for requesting to publish a content event. + Upon receiving this event, the agent should publish a ContentPublishEvent. + """ + + ... diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_messages.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_messages.py deleted file mode 100644 index 57dcdbd75c9a..000000000000 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_messages.py +++ /dev/null @@ -1,25 +0,0 @@ -from autogen_core.components.models import AssistantMessage, UserMessage -from pydantic import BaseModel - - -class ContentPublishEvent(BaseModel): - """An event message for sharing some data. Agents receive this message should - update their internal state (e.g., append to message history) with the - content of the message. - """ - - content: UserMessage | AssistantMessage - """The content of the message.""" - - request_pause: bool - """A flag indicating whether the current conversation session should be - paused after processing this message.""" - - -class ContentRequestEvent(BaseModel): - """An event message for requesting to publish a content message. - Upon receiving this message, the agent should publish a ContentPublishEvent - message. - """ - - ... diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_round_robin_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_round_robin_group_chat.py index 6e782029648d..14a4e127f46d 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_round_robin_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_round_robin_group_chat.py @@ -5,12 +5,11 @@ from autogen_core.application import SingleThreadedAgentRuntime from autogen_core.base import AgentId, AgentInstantiationContext, AgentRuntime, AgentType, MessageContext, TopicId from autogen_core.components import ClosureAgent, TypeSubscription -from autogen_core.components.models import UserMessage -from ...agents import BaseChatAgent +from ...agents import BaseChatAgent, TextMessage from .._base_team import BaseTeam, TeamRunResult from ._base_chat_agent_container import BaseChatAgentContainer -from ._messages import ContentPublishEvent, ContentRequestEvent +from ._events import ContentPublishEvent, ContentRequestEvent from ._round_robin_group_chat_manager import RoundRobinGroupChatManager @@ -106,7 +105,7 @@ async def output_result( team_topic_id = TopicId(type=team_topic_type, source=self._team_id) group_chat_manager_topic_id = TopicId(type=group_chat_manager_topic_type, source=self._team_id) await runtime.publish_message( - ContentPublishEvent(content=UserMessage(content=task, source="user"), request_pause=False), + ContentPublishEvent(agent_message=TextMessage(content=task, source="user")), topic_id=team_topic_id, ) await runtime.publish_message(ContentRequestEvent(), topic_id=group_chat_manager_topic_id) @@ -121,7 +120,7 @@ async def output_result( assert ( last_message is not None - and isinstance(last_message.content, UserMessage) - and isinstance(last_message.content.content, str) + and isinstance(last_message.agent_message, TextMessage) + and isinstance(last_message.agent_message.content, str) ) - return TeamRunResult(last_message.content.content) + return TeamRunResult(last_message.agent_message.content) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_round_robin_group_chat_manager.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_round_robin_group_chat_manager.py index 129d79dd1a91..a0c5a97e90a3 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_round_robin_group_chat_manager.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_round_robin_group_chat_manager.py @@ -1,7 +1,6 @@ from typing import List -from autogen_core.components.models import AssistantMessage, UserMessage - +from ...agents import MultiModalMessage, StopMessage, TextMessage from ._base_group_chat_manager import BaseGroupChatManager @@ -23,7 +22,7 @@ def __init__( ) self._next_speaker_index = 0 - async def select_speaker(self, thread: List[UserMessage | AssistantMessage]) -> str: + async def select_speaker(self, thread: List[TextMessage | MultiModalMessage | StopMessage]) -> str: """Select a speaker from the participants in a round-robin fashion.""" current_speaker_index = self._next_speaker_index self._next_speaker_index = (current_speaker_index + 1) % len(self._participant_topic_types)