Use agentchat message types rather than core's model client message types (#662)

* Use agentchat message types rather than core's model client message types

* Merge remote-tracking branch 'origin/main' into ekzhu-tool-use-assistant
ekzhu authored Sep 28, 2024
1 parent 43c85d6 commit 18efc23
Showing 11 changed files with 142 additions and 83 deletions.
@@ -1,10 +1,23 @@
from ._base_chat_agent import BaseChatAgent, ChatMessage
from .coding._code_executor_agent import CodeExecutorAgent
from .coding._coding_assistant_agent import CodingAssistantAgent
from ._base_chat_agent import (
    BaseChatAgent,
    ChatMessage,
    MultiModalMessage,
    StopMessage,
    TextMessage,
    ToolCallMessage,
    ToolCallResultMessage,
)
from ._code_executor_agent import CodeExecutorAgent
from ._coding_assistant_agent import CodingAssistantAgent

__all__ = [
    "BaseChatAgent",
    "ChatMessage",
    "TextMessage",
    "MultiModalMessage",
    "ToolCallMessage",
    "ToolCallResultMessage",
    "StopMessage",
    "CodeExecutorAgent",
    "CodingAssistantAgent",
]
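These names form the public surface of the agents package after this change. A minimal import sketch, assuming the distribution is installed as autogen_agentchat:

# Hypothetical downstream usage of the re-exported message types; the package
# name is an assumption, not part of this diff.
from autogen_agentchat.agents import StopMessage, TextMessage

user_msg = TextMessage(content="Write a haiku about autumn.", source="user")
stop_msg = StopMessage(content="TERMINATE", source="user")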
@@ -1,20 +1,56 @@
from abc import ABC, abstractmethod
from typing import Sequence
from typing import List, Sequence

from autogen_core.base import CancellationToken
from autogen_core.components.models import AssistantMessage, UserMessage
from autogen_core.components import FunctionCall, Image
from autogen_core.components.models import FunctionExecutionResult
from pydantic import BaseModel


class ChatMessage(BaseModel):
    """A chat message from a user or agent."""
class BaseMessage(BaseModel):
    """A base message."""

    content: UserMessage | AssistantMessage
    source: str
    """The name of the agent that sent this message."""


class TextMessage(BaseMessage):
    """A text message."""

    content: str
    """The content of the message."""

    request_pause: bool
    """A flag indicating whether the current conversation session should be
    paused after processing this message."""

class MultiModalMessage(BaseMessage):
    """A multimodal message."""

    content: List[str | Image]
    """The content of the message."""


class ToolCallMessage(BaseMessage):
    """A message containing a list of function calls."""

    content: List[FunctionCall]
    """The list of function calls."""


class ToolCallResultMessage(BaseMessage):
    """A message containing the results of function calls."""

    content: List[FunctionExecutionResult]
    """The list of function execution results."""


class StopMessage(BaseMessage):
    """A message requesting stop of a conversation."""

    content: str
    """The content for the stop message."""


ChatMessage = TextMessage | MultiModalMessage | ToolCallMessage | ToolCallResultMessage | StopMessage
"""A message used by agents in a team."""


class BaseChatAgent(ABC):
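To make the new message model concrete, here is a hedged sketch that builds a few of the types defined above and narrows the ChatMessage union with isinstance checks, mirroring how the agents below do it. The FunctionCall and FunctionExecutionResult field names are assumptions about autogen_core at this commit, not something shown in the diff.

# Sketch only; uses the message classes defined above.
from autogen_core.components import FunctionCall
from autogen_core.components.models import FunctionExecutionResult

text = TextMessage(content="Hello!", source="user")
calls = ToolCallMessage(
    # FunctionCall field names (id, name, arguments) are assumed.
    content=[FunctionCall(id="1", name="lookup", arguments='{"query": "weather"}')],
    source="assistant",
)
results = ToolCallResultMessage(
    # FunctionExecutionResult field names (content, call_id) are assumed.
    content=[FunctionExecutionResult(content="sunny", call_id="1")],
    source="tools",
)


def describe(message: ChatMessage) -> str:
    # The ChatMessage union is narrowed by isinstance checks.
    if isinstance(message, TextMessage):
        return f"text from {message.source}"
    if isinstance(message, ToolCallMessage):
        return f"{len(message.content)} tool call(s) from {message.source}"
    if isinstance(message, ToolCallResultMessage):
        return f"{len(message.content)} tool result(s) from {message.source}"
    return f"{type(message).__name__} from {message.source}"


print(describe(text), describe(calls), describe(results), sep="\n")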
@@ -2,9 +2,8 @@

from autogen_core.base import CancellationToken
from autogen_core.components.code_executor import CodeBlock, CodeExecutor, extract_markdown_code_blocks
from autogen_core.components.models import UserMessage

from .._base_chat_agent import BaseChatAgent, ChatMessage
from ._base_chat_agent import BaseChatAgent, ChatMessage, TextMessage


class CodeExecutorAgent(BaseChatAgent):
@@ -21,14 +20,11 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token:
        # Extract code blocks from the messages.
        code_blocks: List[CodeBlock] = []
        for msg in messages:
            if isinstance(msg.content, UserMessage) and isinstance(msg.content.content, str):
                code_blocks.extend(extract_markdown_code_blocks(msg.content.content))
            if isinstance(msg, TextMessage):
                code_blocks.extend(extract_markdown_code_blocks(msg.content))
        if code_blocks:
            # Execute the code blocks.
            result = await self._code_executor.execute_code_blocks(code_blocks, cancellation_token=cancellation_token)
            return ChatMessage(content=UserMessage(content=result.output, source=self.name), request_pause=False)
            return TextMessage(content=result.output, source=self.name)
        else:
            return ChatMessage(
                content=UserMessage(content="No code blocks found in the thread.", source=self.name),
                request_pause=False,
            )
            return TextMessage(content="No code blocks found in the thread.", source=self.name)
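A hedged usage sketch of the updated agent follows. The on_messages signature and the message types come straight from this diff; the constructor keyword, the LocalCommandLineCodeExecutor import path, and the installed package name are assumptions.

# Sketch only, not shown in this diff.
import asyncio

from autogen_agentchat.agents import CodeExecutorAgent, TextMessage  # assumed package name
from autogen_core.base import CancellationToken
from autogen_core.components.code_executor import LocalCommandLineCodeExecutor  # assumed import path


async def main() -> None:
    agent = CodeExecutorAgent("executor", code_executor=LocalCommandLineCodeExecutor())  # assumed constructor keyword
    task = TextMessage(content="```python\nprint('hello from the executor')\n```", source="user")
    # on_messages now takes agentchat messages and returns a TextMessage with the execution output.
    reply = await agent.on_messages([task], CancellationToken())
    print(reply.content)


asyncio.run(main())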
@@ -1,9 +1,15 @@
from typing import List, Sequence

from autogen_core.base import CancellationToken
from autogen_core.components.models import AssistantMessage, ChatCompletionClient, SystemMessage, UserMessage
from autogen_core.components.models import (
    AssistantMessage,
    ChatCompletionClient,
    LLMMessage,
    SystemMessage,
    UserMessage,
)

from .._base_chat_agent import BaseChatAgent, ChatMessage
from ._base_chat_agent import BaseChatAgent, ChatMessage, MultiModalMessage, StopMessage, TextMessage


class CodingAssistantAgent(BaseChatAgent):
@@ -27,22 +33,26 @@ def __init__(self, name: str, model_client: ChatCompletionClient):
        super().__init__(name=name, description=self.DESCRIPTION)
        self._model_client = model_client
        self._system_messages = [SystemMessage(content=self.SYSTEM_MESSAGE)]
        self._message_thread: List[UserMessage | AssistantMessage] = []
        self._model_context: List[LLMMessage] = []

    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
        # Add messages to the thread.
        # Add messages to the model context and detect stopping.
        for msg in messages:
            self._message_thread.append(msg.content)
            if not isinstance(msg, TextMessage | MultiModalMessage | StopMessage):
                raise ValueError(f"Unsupported message type: {type(msg)}")
            self._model_context.append(UserMessage(content=msg.content, source=msg.source))

        # Generate an inference result based on the thread.
        llm_messages = self._system_messages + self._message_thread
        # Generate an inference result based on the current model context.
        llm_messages = self._system_messages + self._model_context
        result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
        assert isinstance(result.content, str)

        # Add the response to the thread.
        self._message_thread.append(AssistantMessage(content=result.content, source=self.name))
        # Add the response to the model context.
        self._model_context.append(AssistantMessage(content=result.content, source=self.name))

        # Detect pause request.
        request_pause = "terminate" in result.content.strip().lower()
        # Detect stop request.
        request_stop = "terminate" in result.content.strip().lower()
        if request_stop:
            return StopMessage(content=result.content, source=self.name)

        return ChatMessage(content=UserMessage(content=result.content, source=self.name), request_pause=request_pause)
        return TextMessage(content=result.content, source=self.name)
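The behavioral change in this file: termination is no longer signaled by a request_pause flag on the reply, but by the type of the reply itself. A minimal calling sketch, assuming the package is importable as autogen_agentchat:

# Sketch only: callers branch on the returned message type.
from autogen_agentchat.agents import CodingAssistantAgent, StopMessage, TextMessage  # assumed package name
from autogen_core.base import CancellationToken
from autogen_core.components.models import ChatCompletionClient


async def step(model_client: ChatCompletionClient) -> None:
    assistant = CodingAssistantAgent("coding_assistant", model_client=model_client)
    reply = await assistant.on_messages(
        [TextMessage(content="Write a function that reverses a string.", source="user")],
        CancellationToken(),
    )
    if isinstance(reply, StopMessage):
        print("Stop requested:", reply.content)
    else:
        print("Assistant replied:", reply.content)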
Empty file.
@@ -3,8 +3,8 @@

from autogen_core.base import MessageContext
from autogen_core.components import DefaultTopicId, RoutedAgent, event

from ...agents import BaseChatAgent, ChatMessage
from ._messages import ContentPublishEvent, ContentRequestEvent
from ...agents import BaseChatAgent, MultiModalMessage, StopMessage, TextMessage
from ._events import ContentPublishEvent, ContentRequestEvent


class BaseChatAgentContainer(RoutedAgent):
@@ -21,20 +21,26 @@ def __init__(self, parent_topic_type: str, agent: BaseChatAgent) -> None:
        super().__init__(description=agent.description)
        self._parent_topic_type = parent_topic_type
        self._agent = agent
        self._message_buffer: List[ChatMessage] = []
        self._message_buffer: List[TextMessage | MultiModalMessage | StopMessage] = []

    @event
    async def handle_content_publish(self, message: ContentPublishEvent, ctx: MessageContext) -> None:
        """Handle a content publish event by appending the content to the buffer."""
        self._message_buffer.append(ChatMessage(content=message.content, request_pause=message.request_pause))
        if not isinstance(message.agent_message, TextMessage | MultiModalMessage | StopMessage):
            raise ValueError(
                f"Unexpected message type: {type(message.agent_message)}. "
                "The message must be a text, multimodal, or stop message."
            )
        self._message_buffer.append(message.agent_message)

    @event
    async def handle_content_request(self, message: ContentRequestEvent, ctx: MessageContext) -> None:
        """Handle a content request event by passing the messages in the buffer
        to the delegate agent and publish the response."""
        response = await self._agent.on_messages(self._message_buffer, ctx.cancellation_token)
        # TODO: handle tool call messages.
        assert isinstance(response, TextMessage | MultiModalMessage | StopMessage)
        self._message_buffer.clear()
        await self.publish_message(
            ContentPublishEvent(content=response.content, request_pause=response.request_pause),
            topic_id=DefaultTopicId(type=self._parent_topic_type),
            ContentPublishEvent(agent_message=response), topic_id=DefaultTopicId(type=self._parent_topic_type)
        )
@@ -3,9 +3,9 @@

from autogen_core.base import MessageContext, TopicId
from autogen_core.components import RoutedAgent, event
from autogen_core.components.models import AssistantMessage, UserMessage

from ._messages import ContentPublishEvent, ContentRequestEvent
from ...agents import MultiModalMessage, StopMessage, TextMessage
from ._events import ContentPublishEvent, ContentRequestEvent


class BaseGroupChatManager(RoutedAgent):
@@ -47,7 +47,7 @@ def __init__(
            raise ValueError("The group topic type must not be the same as the parent topic type.")
        self._participant_topic_types = participant_topic_types
        self._participant_descriptions = participant_descriptions
        self._message_thread: List[UserMessage | AssistantMessage] = []
        self._message_thread: List[TextMessage | MultiModalMessage | StopMessage] = []

    @event
    async def handle_content_publish(self, message: ContentPublishEvent, ctx: MessageContext) -> None:
@@ -61,23 +61,27 @@ async def handle_content_publish(self, message: ContentPublishEvent, ctx: Messag
        group_chat_topic_id = TopicId(type=self._group_topic_type, source=ctx.topic_id.source)

        # TODO: use something else other than print.
        assert isinstance(message.content, UserMessage) or isinstance(message.content, AssistantMessage)
        sys.stdout.write(f"{'-'*80}\n{message.content.source}:\n{message.content.content}\n")
        sys.stdout.write(f"{'-'*80}\n{message.agent_message.source}:\n{message.agent_message.content}\n")

        # Process event from parent.
        if ctx.topic_id.type == self._parent_topic_type:
            self._message_thread.append(message.content)
            self._message_thread.append(message.agent_message)
            await self.publish_message(message, topic_id=group_chat_topic_id)
            return

        # Process event from the group chat this agent manages.
        assert ctx.topic_id.type == self._group_topic_type
        self._message_thread.append(message.content)
        self._message_thread.append(message.agent_message)

        if message.request_pause:
        # If the message is a stop message, publish the last message as a TextMessage to the parent topic.
        # TODO: custom handling the final message.
        if isinstance(message.agent_message, StopMessage):
            parent_topic_id = TopicId(type=self._parent_topic_type, source=ctx.topic_id.source)
            await self.publish_message(
                ContentPublishEvent(content=message.content, request_pause=True), topic_id=parent_topic_id
                ContentPublishEvent(
                    agent_message=TextMessage(content=message.agent_message.content, source=self.metadata["type"])
                ),
                topic_id=parent_topic_id,
            )
            return

@@ -100,7 +104,7 @@ async def handle_content_request(self, message: ContentRequestEvent, ctx: Messag
        participant_topic_id = TopicId(type=speaker_topic_type, source=ctx.topic_id.source)
        await self.publish_message(ContentRequestEvent(), topic_id=participant_topic_id)

    async def select_speaker(self, thread: List[UserMessage | AssistantMessage]) -> str:
    async def select_speaker(self, thread: List[TextMessage | MultiModalMessage | StopMessage]) -> str:
        """Select a speaker from the participants and return the
        topic type of the selected speaker."""
        raise NotImplementedError("Method not implemented")
@@ -0,0 +1,21 @@
from pydantic import BaseModel

from ...agents import MultiModalMessage, StopMessage, TextMessage


class ContentPublishEvent(BaseModel):
    """An event for sharing some data. Agents that receive this event should
    update their internal state (e.g., append to message history) with the
    content of the event.
    """

    agent_message: TextMessage | MultiModalMessage | StopMessage
    """The message published by the agent."""


class ContentRequestEvent(BaseModel):
    """An event for requesting to publish a content event.
    Upon receiving this event, the agent should publish a ContentPublishEvent.
    """

    ...
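For illustration, a short sketch of constructing the two events defined above; the TextMessage import path is an assumption, and the event classes are internal to the teams implementation.

# Sketch only: exercising the event models defined above.
from autogen_agentchat.agents import TextMessage  # assumed package name

publish = ContentPublishEvent(
    agent_message=TextMessage(content="The plan looks good; proceed.", source="planner")
)
request = ContentRequestEvent()
print(publish.agent_message.source)  # prints "planner"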

This file was deleted.

@@ -5,12 +5,11 @@

from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core.base import AgentId, AgentInstantiationContext, AgentRuntime, AgentType, MessageContext, TopicId
from autogen_core.components import ClosureAgent, TypeSubscription
from autogen_core.components.models import UserMessage

from ...agents import BaseChatAgent
from ...agents import BaseChatAgent, TextMessage
from .._base_team import BaseTeam, TeamRunResult
from ._base_chat_agent_container import BaseChatAgentContainer
from ._messages import ContentPublishEvent, ContentRequestEvent
from ._events import ContentPublishEvent, ContentRequestEvent
from ._round_robin_group_chat_manager import RoundRobinGroupChatManager


@@ -106,7 +105,7 @@ async def output_result(
        team_topic_id = TopicId(type=team_topic_type, source=self._team_id)
        group_chat_manager_topic_id = TopicId(type=group_chat_manager_topic_type, source=self._team_id)
        await runtime.publish_message(
            ContentPublishEvent(content=UserMessage(content=task, source="user"), request_pause=False),
            ContentPublishEvent(agent_message=TextMessage(content=task, source="user")),
            topic_id=team_topic_id,
        )
        await runtime.publish_message(ContentRequestEvent(), topic_id=group_chat_manager_topic_id)
@@ -121,7 +120,7 @@

        assert (
            last_message is not None
            and isinstance(last_message.content, UserMessage)
            and isinstance(last_message.content.content, str)
            and isinstance(last_message.agent_message, TextMessage)
            and isinstance(last_message.agent_message.content, str)
        )
        return TeamRunResult(last_message.content.content)
        return TeamRunResult(last_message.agent_message.content)
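Putting the pieces together, a heavily hedged sketch of driving a team with the new message types. RoundRobinGroupChat, its participants argument, its export path, and the run(task) coroutine are assumptions suggested by this file's imports rather than things shown in the diff; the concrete model client and code executor are left as parameters.

# Sketch only, under the assumptions stated above.
from autogen_agentchat.agents import CodeExecutorAgent, CodingAssistantAgent  # assumed package name
from autogen_agentchat.teams import RoundRobinGroupChat  # assumed export path
from autogen_core.components.code_executor import CodeExecutor
from autogen_core.components.models import ChatCompletionClient


async def run_team(model_client: ChatCompletionClient, code_executor: CodeExecutor) -> None:
    assistant = CodingAssistantAgent("coding_assistant", model_client=model_client)
    executor = CodeExecutorAgent("code_executor", code_executor=code_executor)  # assumed constructor keyword
    team = RoundRobinGroupChat(participants=[assistant, executor])
    # run() wraps the task into TextMessage(content=task, source="user") and publishes it
    # as a ContentPublishEvent, as shown in the diff above.
    result = await team.run("Plot y = x**2 to plot.png. Reply TERMINATE when done.")
    print(result)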
@@ -1,7 +1,6 @@
from typing import List

from autogen_core.components.models import AssistantMessage, UserMessage

from ...agents import MultiModalMessage, StopMessage, TextMessage
from ._base_group_chat_manager import BaseGroupChatManager


@@ -23,7 +22,7 @@ def __init__(
        )
        self._next_speaker_index = 0

    async def select_speaker(self, thread: List[UserMessage | AssistantMessage]) -> str:
    async def select_speaker(self, thread: List[TextMessage | MultiModalMessage | StopMessage]) -> str:
        """Select a speaker from the participants in a round-robin fashion."""
        current_speaker_index = self._next_speaker_index
        self._next_speaker_index = (current_speaker_index + 1) % len(self._participant_topic_types)
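The selection rule is a simple modular rotation over the participant topic types. A standalone illustration with made-up participant names:

# Standalone sketch of the round-robin order; the names are placeholders.
participants = ["coding_assistant", "code_executor", "critic"]

next_index = 0
order = []
for _ in range(5):
    current = next_index
    next_index = (current + 1) % len(participants)
    order.append(participants[current])

print(order)  # ['coding_assistant', 'code_executor', 'critic', 'coding_assistant', 'code_executor']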
