Skip to content

Commit

Permalink
Get SelectorGroupChat working for Llama models. (#5409)
Browse files Browse the repository at this point in the history
Gets SelectorGroupChat working for llama by:

1. Using a UserMessage rather than a SystemMessage
2. Normalizing how roles are presented (one agent per line)
3. Normalizing how the transcript is constructed (a blank line between
every message)
  • Loading branch information
afourney authored Feb 7, 2025
1 parent c8e4ad8 commit 59e392c
Showing 1 changed file with 21 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, Callable, Dict, List, Mapping, Sequence

from autogen_core import Component, ComponentModel
from autogen_core.models import ChatCompletionClient, SystemMessage, UserMessage
from autogen_core.models import ChatCompletionClient, ModelFamily, SystemMessage, UserMessage
from pydantic import BaseModel
from typing_extensions import Self

Expand Down Expand Up @@ -110,18 +110,17 @@ async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str:
message += " [Image]"
else:
raise ValueError(f"Unexpected message type in selector: {type(msg)}")
history_messages.append(message)
history_messages.append(
message.rstrip() + "\n\n"
) # Create some consistency for how messages are separated in the transcript
history = "\n".join(history_messages)

# Construct agent roles, we are using the participant topic type as the agent name.
roles = "\n".join(
[
f"{topic_type}: {description}".strip()
for topic_type, description in zip(
self._participant_topic_types, self._participant_descriptions, strict=True
)
]
)
    # Each agent should appear on a single line.
roles = ""
for topic_type, description in zip(self._participant_topic_types, self._participant_descriptions, strict=True):
roles += re.sub(r"\s+", " ", f"{topic_type}: {description}").strip() + "\n"
roles = roles.strip()

# Construct agent list to be selected, skip the previous speaker if not allowed.
if self._previous_speaker is not None and not self._allow_repeated_speaker:
Expand All @@ -136,11 +135,20 @@ async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str:
roles=roles, participants=str(participants), history=history
)
select_speaker_messages: List[SystemMessage | UserMessage]
if self._model_client.model_info["family"].startswith("gemini"):
select_speaker_messages = [UserMessage(content=select_speaker_prompt, source="selector")]
else:
if self._model_client.model_info["family"] in [
ModelFamily.GPT_4,
ModelFamily.GPT_4O,
ModelFamily.GPT_35,
ModelFamily.O1,
ModelFamily.O3,
]:
select_speaker_messages = [SystemMessage(content=select_speaker_prompt)]
else:
# Many other models need a UserMessage to respond to
select_speaker_messages = [UserMessage(content=select_speaker_prompt, source="selector")]

response = await self._model_client.create(messages=select_speaker_messages)

assert isinstance(response.content, str)
mentions = self._mentioned_agents(response.content, self._participant_topic_types)
if len(mentions) != 1:
Expand Down

0 comments on commit 59e392c

Please sign in to comment.