From 59e392cd0f5e6075d2f6f5b527a0914825049286 Mon Sep 17 00:00:00 2001 From: afourney Date: Thu, 6 Feb 2025 16:03:17 -0800 Subject: [PATCH] Get SelectorGroupChat working for Llama models. (#5409) Gets SelectorGroupChat working for Llama by: 1. Using a UserMessage rather than a SystemMessage 2. Normalizing how roles are presented (one agent per line) 3. Normalizing how the transcript is constructed (a blank line between every message) --- .../teams/_group_chat/_selector_group_chat.py | 34 ++++++++++++------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index bf4bc95946ad..de0ef3247c69 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -3,7 +3,7 @@ from typing import Any, Callable, Dict, List, Mapping, Sequence from autogen_core import Component, ComponentModel -from autogen_core.models import ChatCompletionClient, SystemMessage, UserMessage +from autogen_core.models import ChatCompletionClient, ModelFamily, SystemMessage, UserMessage from pydantic import BaseModel from typing_extensions import Self @@ -110,18 +110,17 @@ async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str: message += " [Image]" else: raise ValueError(f"Unexpected message type in selector: {type(msg)}") - history_messages.append(message) + history_messages.append( + message.rstrip() + "\n\n" + ) # Create some consistency for how messages are separated in the transcript history = "\n".join(history_messages) # Construct agent roles, we are using the participant topic type as the agent name. 
- roles = "\n".join( - [ - f"{topic_type}: {description}".strip() - for topic_type, description in zip( - self._participant_topic_types, self._participant_descriptions, strict=True - ) - ] - ) + # Each agent should appear on a single line. + roles = "" + for topic_type, description in zip(self._participant_topic_types, self._participant_descriptions, strict=True): + roles += re.sub(r"\s+", " ", f"{topic_type}: {description}").strip() + "\n" + roles = roles.strip() # Construct agent list to be selected, skip the previous speaker if not allowed. if self._previous_speaker is not None and not self._allow_repeated_speaker: @@ -136,11 +135,20 @@ async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str: roles=roles, participants=str(participants), history=history ) select_speaker_messages: List[SystemMessage | UserMessage] - if self._model_client.model_info["family"].startswith("gemini"): - select_speaker_messages = [UserMessage(content=select_speaker_prompt, source="selector")] - else: + if self._model_client.model_info["family"] in [ + ModelFamily.GPT_4, + ModelFamily.GPT_4O, + ModelFamily.GPT_35, + ModelFamily.O1, + ModelFamily.O3, + ]: select_speaker_messages = [SystemMessage(content=select_speaker_prompt)] + else: + # Many other models need a UserMessage to respond to + select_speaker_messages = [UserMessage(content=select_speaker_prompt, source="selector")] + response = await self._model_client.create(messages=select_speaker_messages) + assert isinstance(response.content, str) mentions = self._mentioned_agents(response.content, self._participant_topic_types) if len(mentions) != 1: