Skip to content

Commit 59e392c

Browse files
authored
Get SelectorGroupChat working for Llama models. (#5409)
Gets SelectorGroupChat working for llama by: 1. Using a UserMessage rather than a SystemMessage 2. Normalizing how roles are presented (one agent per line) 3. Normalizing how the transcript is constructed (a blank line between every message)
1 parent c8e4ad8 commit 59e392c

File tree

1 file changed

+21
-13
lines changed

1 file changed

+21
-13
lines changed

python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py

+21-13
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
from typing import Any, Callable, Dict, List, Mapping, Sequence
44

55
from autogen_core import Component, ComponentModel
6-
from autogen_core.models import ChatCompletionClient, SystemMessage, UserMessage
6+
from autogen_core.models import ChatCompletionClient, ModelFamily, SystemMessage, UserMessage
77
from pydantic import BaseModel
88
from typing_extensions import Self
99

@@ -110,18 +110,17 @@ async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str:
110110
message += " [Image]"
111111
else:
112112
raise ValueError(f"Unexpected message type in selector: {type(msg)}")
113-
history_messages.append(message)
113+
history_messages.append(
114+
message.rstrip() + "\n\n"
115+
) # Create some consistency for how messages are separated in the transcript
114116
history = "\n".join(history_messages)
115117

116118
# Construct agent roles, we are using the participant topic type as the agent name.
117-
roles = "\n".join(
118-
[
119-
f"{topic_type}: {description}".strip()
120-
for topic_type, description in zip(
121-
self._participant_topic_types, self._participant_descriptions, strict=True
122-
)
123-
]
124-
)
119+
# Each agent should appear on a single line.
120+
roles = ""
121+
for topic_type, description in zip(self._participant_topic_types, self._participant_descriptions, strict=True):
122+
roles += re.sub(r"\s+", " ", f"{topic_type}: {description}").strip() + "\n"
123+
roles = roles.strip()
125124

126125
# Construct agent list to be selected, skip the previous speaker if not allowed.
127126
if self._previous_speaker is not None and not self._allow_repeated_speaker:
@@ -136,11 +135,20 @@ async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str:
136135
roles=roles, participants=str(participants), history=history
137136
)
138137
select_speaker_messages: List[SystemMessage | UserMessage]
139-
if self._model_client.model_info["family"].startswith("gemini"):
140-
select_speaker_messages = [UserMessage(content=select_speaker_prompt, source="selector")]
141-
else:
138+
if self._model_client.model_info["family"] in [
139+
ModelFamily.GPT_4,
140+
ModelFamily.GPT_4O,
141+
ModelFamily.GPT_35,
142+
ModelFamily.O1,
143+
ModelFamily.O3,
144+
]:
142145
select_speaker_messages = [SystemMessage(content=select_speaker_prompt)]
146+
else:
147+
# Many other models need a UserMessage to respond to
148+
select_speaker_messages = [UserMessage(content=select_speaker_prompt, source="selector")]
149+
143150
response = await self._model_client.create(messages=select_speaker_messages)
151+
144152
assert isinstance(response.content, str)
145153
mentions = self._mentioned_agents(response.content, self._participant_topic_types)
146154
if len(mentions) != 1:

0 commit comments

Comments
 (0)