Skip to content

Commit

Permalink
Feature: Get Nested Agents in a GroupChat (#1636)
Browse files Browse the repository at this point in the history
* implements features

* fix docstring

* adds test

* resolve some comments

* remove unused group chat manager from test

* list implementation

* better naming

* resolve comments

* adds one more test

* checks case when agent doesnt exist

* clean up

---------

Co-authored-by: Chi Wang <[email protected]>
  • Loading branch information
WaelKarkoub and sonichi authored Feb 15, 2024
1 parent b270a2e commit a52f52a
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 7 deletions.
1 change: 1 addition & 0 deletions autogen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .version import __version__
from .oai import *
from .agentchat import *
from .exception_utils import *
from .code_utils import DEFAULT_MODEL, FAST_MODEL


Expand Down
35 changes: 28 additions & 7 deletions autogen/agentchat/groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@


from ..code_utils import content_str
from ..exception_utils import AgentNameConflict
from .agent import Agent
from .conversable_agent import ConversableAgent
from ..runtime_logging import logging_enabled, log_new_agent
Expand Down Expand Up @@ -174,9 +175,26 @@ def append(self, message: Dict, speaker: Agent):
message["content"] = content_str(message["content"])
self.messages.append(message)

def agent_by_name(self, name: str) -> Agent:
"""Returns the agent with a given name."""
return self.agents[self.agent_names.index(name)]
def agent_by_name(
self, name: str, recursive: bool = False, raise_on_name_conflict: bool = False
) -> Optional[Agent]:
"""Returns the agent with a given name. If recursive is True, it will search in nested teams."""
agents = self.nested_agents() if recursive else self.agents
filtered_agents = [agent for agent in agents if agent.name == name]

if raise_on_name_conflict and len(filtered_agents) > 1:
raise AgentNameConflict()

return filtered_agents[0] if filtered_agents else None

def nested_agents(self) -> List[Agent]:
"""Returns all agents in the group chat manager."""
agents = self.agents.copy()
for agent in agents:
if isinstance(agent, GroupChatManager):
# Recursive call for nested teams
agents.extend(agent.groupchat.nested_agents())
return agents

def next_agent(self, agent: Agent, agents: Optional[List[Agent]] = None) -> Agent:
"""Return the next agent in the list."""
Expand Down Expand Up @@ -390,10 +408,8 @@ def _finalize_speaker(self, last_speaker: Agent, final: bool, name: str, agents:
)

# Return the result
try:
return self.agent_by_name(name)
except ValueError:
return self.next_agent(last_speaker, agents)
agent = self.agent_by_name(name)
return agent if agent else self.next_agent(last_speaker, agents)

def _participant_roles(self, agents: List[Agent] = None) -> str:
# Default to all agents registered
Expand Down Expand Up @@ -480,6 +496,11 @@ def __init__(
ignore_async_in_sync_chat=True,
)

@property
def groupchat(self) -> GroupChat:
"""Returns the group chat managed by the group chat manager."""
return self._groupchat

def chat_messages_for_summary(self, agent: Agent) -> List[Dict]:
"""The list of messages in the group chat as a conversation to summarize.
The agent is ignored.
Expand Down
3 changes: 3 additions & 0 deletions autogen/exception_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
class AgentNameConflict(Exception):
def __init__(self, msg="Found multiple agents with the same name.", *args, **kwargs):
super().__init__(msg, *args, **kwargs)
132 changes: 132 additions & 0 deletions test/agentchat/test_groupchat.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any, Dict, List, Optional, Type
from autogen import AgentNameConflict
import pytest
from unittest import mock
import builtins
Expand Down Expand Up @@ -672,6 +674,136 @@ def test_clear_agents_history():
]


def test_get_agent_by_name():
def agent(name: str) -> autogen.ConversableAgent:
return autogen.ConversableAgent(
name=name,
max_consecutive_auto_reply=10,
human_input_mode="NEVER",
llm_config=False,
)

def team(members: List[autogen.Agent], name: str) -> autogen.Agent:
gc = autogen.GroupChat(agents=members, messages=[])

return autogen.GroupChatManager(groupchat=gc, name=name, llm_config=False)

team_member1 = agent("team1_member1")
team_member2 = agent("team1_member2")
team_dup_member1 = agent("team1_member1")
team_dup_member2 = agent("team1_member2")

user = agent("user")
team1 = team([team_member1, team_member2], "team1")
team1_duplicate = team([team_dup_member1, team_dup_member2], "team1")

gc = autogen.GroupChat(agents=[user, team1, team1_duplicate], messages=[])

# Testing default arguments
assert gc.agent_by_name("user") == user
assert gc.agent_by_name("team1") == team1 or gc.agent_by_name("team1") == team1_duplicate

# Testing recursive search
assert gc.agent_by_name("user", recursive=True) == user
assert (
gc.agent_by_name("team1_member1", recursive=True) == team_member1
or gc.agent_by_name("team1_member1", recursive=True) == team_dup_member1
)

# Get agent that does not exist
assert gc.agent_by_name("team2") is None
assert gc.agent_by_name("team2", recursive=True) is None
assert gc.agent_by_name("team2", raise_on_name_conflict=True) is None
assert gc.agent_by_name("team2", recursive=True, raise_on_name_conflict=True) is None

# Testing naming conflict
with pytest.raises(AgentNameConflict):
gc.agent_by_name("team1", raise_on_name_conflict=True)

# Testing name conflict with recursive search
with pytest.raises(AgentNameConflict):
gc.agent_by_name("team1_member1", recursive=True, raise_on_name_conflict=True)


def test_get_nested_agents_in_groupchat():
def agent(name: str) -> autogen.ConversableAgent:
return autogen.ConversableAgent(
name=name,
max_consecutive_auto_reply=10,
human_input_mode="NEVER",
llm_config=False,
)

def team(name: str) -> autogen.ConversableAgent:
member1 = agent(f"member1_{name}")
member2 = agent(f"member2_{name}")

gc = autogen.GroupChat(agents=[member1, member2], messages=[])

return autogen.GroupChatManager(groupchat=gc, name=name, llm_config=False)

user = agent("user")
team1 = team("team1")
team2 = team("team2")

gc = autogen.GroupChat(agents=[user, team1, team2], messages=[])

agents = gc.nested_agents()
assert len(agents) == 7


def test_nested_teams_chat():
"""Tests chat capabilities of nested teams"""
team1_msg = {"content": "Hello from team 1"}
team2_msg = {"content": "Hello from team 2"}

def agent(name: str, auto_reply: Optional[Dict[str, Any]] = None) -> autogen.ConversableAgent:
return autogen.ConversableAgent(
name=name,
max_consecutive_auto_reply=10,
human_input_mode="NEVER",
llm_config=False,
default_auto_reply=auto_reply,
)

def team(name: str, auto_reply: Optional[Dict[str, Any]] = None) -> autogen.ConversableAgent:
member1 = agent(f"member1_{name}", auto_reply=auto_reply)
member2 = agent(f"member2_{name}", auto_reply=auto_reply)

gc = autogen.GroupChat(agents=[member1, member2], messages=[])

return autogen.GroupChatManager(groupchat=gc, name=name, llm_config=False)

def chat(gc_manager: autogen.GroupChatManager):
team1_member1 = gc_manager.groupchat.agent_by_name("member1_team1", recursive=True)
team2_member2 = gc_manager.groupchat.agent_by_name("member2_team2", recursive=True)

assert team1_member1 is not None
assert team2_member2 is not None

team1_member1.send(team1_msg, team2_member2, request_reply=True)

user = agent("user")
team1 = team("team1", auto_reply=team1_msg)
team2 = team("team2", auto_reply=team2_msg)

gc = autogen.GroupChat(agents=[user, team1, team2], messages=[])
gc_manager = autogen.GroupChatManager(groupchat=gc, llm_config=False)

chat(gc_manager)

team1_member1 = gc.agent_by_name("member1_team1", recursive=True)
team2_member2 = gc.agent_by_name("member2_team2", recursive=True)

assert team1_member1 and team2_member2

msg = team1_member1.chat_messages[team2_member2][0]
reply = team1_member1.chat_messages[team2_member2][1]

assert msg["content"] == team1_msg["content"]
assert reply["content"] == team2_msg["content"]


if __name__ == "__main__":
# test_func_call_groupchat()
# test_broadcast()
Expand Down

0 comments on commit a52f52a

Please sign in to comment.