-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Sequential processing for group chat participant using SequentialRout…
…edAgent (#663)
- Loading branch information
Showing
4 changed files
with
99 additions
and
4 deletions.
There are no files selected for viewing
5 changes: 3 additions & 2 deletions
5
...es/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_base_chat_agent_container.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
51 changes: 51 additions & 0 deletions
51
...ages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_sequential_routed_agent.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import asyncio | ||
from typing import Any | ||
|
||
from autogen_core.base import MessageContext | ||
from autogen_core.components import RoutedAgent | ||
|
||
|
||
class FIFOLock: | ||
"""A lock that ensures coroutines acquire the lock in the order they request it.""" | ||
|
||
def __init__(self) -> None: | ||
self._queue = asyncio.Queue[asyncio.Event]() | ||
self._locked = False | ||
|
||
async def acquire(self) -> None: | ||
# If the lock is not held by any coroutine, set the lock to be held | ||
# by the current coroutine. | ||
if not self._locked: | ||
self._locked = True | ||
return | ||
|
||
# If the lock is held by another coroutine, create an event and put it | ||
# in the queue. Wait for the event to be set. | ||
event = asyncio.Event() | ||
await self._queue.put(event) | ||
await event.wait() | ||
|
||
def release(self) -> None: | ||
if not self._queue.empty(): | ||
# If there are events in the queue, get the next event and set it. | ||
next_event = self._queue.get_nowait() | ||
next_event.set() | ||
else: | ||
# If there are no events in the queue, release the lock. | ||
self._locked = False | ||
|
||
|
||
class SequentialRoutedAgent(RoutedAgent): | ||
"""A subclass of :class:`autogen_core.components.RoutedAgent` that ensures | ||
messages are handled sequentially in the order they arrive.""" | ||
|
||
def __init__(self, description: str) -> None: | ||
super().__init__(description=description) | ||
self._fifo_lock = FIFOLock() | ||
|
||
async def on_message(self, message: Any, ctx: MessageContext) -> Any | None: | ||
await self._fifo_lock.acquire() | ||
try: | ||
return await super().on_message(message, ctx) | ||
finally: | ||
self._fifo_lock.release() |
42 changes: 42 additions & 0 deletions
42
python/packages/autogen-agentchat/tests/test_sequential_routed_agent.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import asyncio | ||
import random | ||
from dataclasses import dataclass | ||
from typing import List | ||
|
||
import pytest | ||
from autogen_agentchat.teams.group_chat._sequential_routed_agent import SequentialRoutedAgent | ||
from autogen_core.application import SingleThreadedAgentRuntime | ||
from autogen_core.base import AgentId, MessageContext | ||
from autogen_core.components import DefaultTopicId, default_subscription, message_handler | ||
|
||
|
||
@dataclass | ||
class Message: | ||
content: str | ||
|
||
|
||
@default_subscription | ||
class TestAgent(SequentialRoutedAgent): | ||
def __init__(self, description: str) -> None: | ||
super().__init__(description=description) | ||
self.messages: List[Message] = [] | ||
|
||
@message_handler | ||
async def handle_content_publish(self, message: Message, ctx: MessageContext) -> None: | ||
# Sleep a random amount of time to simulate processing time. | ||
await asyncio.sleep(random.random() / 100) | ||
self.messages.append(message) | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_sequential_routed_agent() -> None: | ||
runtime = SingleThreadedAgentRuntime() | ||
runtime.start() | ||
await TestAgent.register(runtime, type="test_agent", factory=lambda: TestAgent(description="Test Agent")) | ||
test_agent_id = AgentId(type="test_agent", key="default") | ||
for i in range(100): | ||
await runtime.publish_message(Message(content=f"{i}"), topic_id=DefaultTopicId()) | ||
await runtime.stop_when_idle() | ||
test_agent = await runtime.try_get_underlying_agent_instance(test_agent_id, TestAgent) | ||
for i in range(100): | ||
assert test_agent.messages[i].content == f"{i}" |