Skip to content

Commit

Permalink
Sequential processing for group chat participant using SequentialRout…
Browse files Browse the repository at this point in the history
…edAgent (#663)
  • Loading branch information
ekzhu authored Sep 28, 2024
1 parent 18efc23 commit 0fa6805
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from typing import List

from autogen_core.base import MessageContext
from autogen_core.components import DefaultTopicId, RoutedAgent, event
from autogen_core.components import DefaultTopicId, event

from ...agents import BaseChatAgent, MultiModalMessage, StopMessage, TextMessage
from ._events import ContentPublishEvent, ContentRequestEvent
from ._sequential_routed_agent import SequentialRoutedAgent


class BaseChatAgentContainer(RoutedAgent):
class BaseChatAgentContainer(SequentialRoutedAgent):
"""A core agent class that delegates message handling to an
:class:`autogen_agentchat.agents.BaseChatAgent` so that it can be used in a
group chat team.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
from typing import List

from autogen_core.base import MessageContext, TopicId
from autogen_core.components import RoutedAgent, event
from autogen_core.components import event

from ...agents import MultiModalMessage, StopMessage, TextMessage
from ._events import ContentPublishEvent, ContentRequestEvent
from ._sequential_routed_agent import SequentialRoutedAgent


class BaseGroupChatManager(RoutedAgent):
class BaseGroupChatManager(SequentialRoutedAgent):
"""Base class for a group chat manager that manages a group chat with multiple participants.
It is the responsibility of the caller to ensure:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import asyncio
from typing import Any

from autogen_core.base import MessageContext
from autogen_core.components import RoutedAgent


class FIFOLock:
"""A lock that ensures coroutines acquire the lock in the order they request it."""

def __init__(self) -> None:
self._queue = asyncio.Queue[asyncio.Event]()
self._locked = False

async def acquire(self) -> None:
# If the lock is not held by any coroutine, set the lock to be held
# by the current coroutine.
if not self._locked:
self._locked = True
return

# If the lock is held by another coroutine, create an event and put it
# in the queue. Wait for the event to be set.
event = asyncio.Event()
await self._queue.put(event)
await event.wait()

def release(self) -> None:
if not self._queue.empty():
# If there are events in the queue, get the next event and set it.
next_event = self._queue.get_nowait()
next_event.set()
else:
# If there are no events in the queue, release the lock.
self._locked = False


class SequentialRoutedAgent(RoutedAgent):
"""A subclass of :class:`autogen_core.components.RoutedAgent` that ensures
messages are handled sequentially in the order they arrive."""

def __init__(self, description: str) -> None:
super().__init__(description=description)
self._fifo_lock = FIFOLock()

async def on_message(self, message: Any, ctx: MessageContext) -> Any | None:
await self._fifo_lock.acquire()
try:
return await super().on_message(message, ctx)
finally:
self._fifo_lock.release()
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import asyncio
import random
from dataclasses import dataclass
from typing import List

import pytest
from autogen_agentchat.teams.group_chat._sequential_routed_agent import SequentialRoutedAgent
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core.base import AgentId, MessageContext
from autogen_core.components import DefaultTopicId, default_subscription, message_handler


@dataclass
class Message:
content: str


@default_subscription
class TestAgent(SequentialRoutedAgent):
def __init__(self, description: str) -> None:
super().__init__(description=description)
self.messages: List[Message] = []

@message_handler
async def handle_content_publish(self, message: Message, ctx: MessageContext) -> None:
# Sleep a random amount of time to simulate processing time.
await asyncio.sleep(random.random() / 100)
self.messages.append(message)


@pytest.mark.asyncio
async def test_sequential_routed_agent() -> None:
runtime = SingleThreadedAgentRuntime()
runtime.start()
await TestAgent.register(runtime, type="test_agent", factory=lambda: TestAgent(description="Test Agent"))
test_agent_id = AgentId(type="test_agent", key="default")
for i in range(100):
await runtime.publish_message(Message(content=f"{i}"), topic_id=DefaultTopicId())
await runtime.stop_when_idle()
test_agent = await runtime.try_get_underlying_agent_instance(test_agent_id, TestAgent)
for i in range(100):
assert test_agent.messages[i].content == f"{i}"

0 comments on commit 0fa6805

Please sign in to comment.