Skip to content

Commit

Permalink
Mitigates #5401 by optionally prepending names to messages.
Browse files Browse the repository at this point in the history
  • Loading branch information
afourney committed Feb 8, 2025
1 parent edbd201 commit b233a83
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -137,22 +137,30 @@ def type_to_role(message: LLMMessage) -> ChatCompletionRole:
return "tool"


def user_message_to_oai(message: UserMessage) -> ChatCompletionUserMessageParam:
def user_message_to_oai(message: UserMessage, prepend_name: bool = False) -> ChatCompletionUserMessageParam:
assert_valid_name(message.source)
if isinstance(message.content, str):
return ChatCompletionUserMessageParam(
content=message.content,
content=(f"{message.source} said:\n" if prepend_name else "") + message.content,
role="user",
name=message.source,
)
else:
parts: List[ChatCompletionContentPartParam] = []
for part in message.content:
if isinstance(part, str):
oai_part = ChatCompletionContentPartTextParam(
text=part,
type="text",
)
if prepend_name:
# Append the name to the first text part
oai_part = ChatCompletionContentPartTextParam(
text=f"{message.source} said:\n" + part,
type="text",
)
prepend_name = False
else:
oai_part = ChatCompletionContentPartTextParam(
text=part,
type="text",
)
parts.append(oai_part)
elif isinstance(part, Image):
# TODO: support url based images
Expand Down Expand Up @@ -211,11 +219,11 @@ def assistant_message_to_oai(
)


def to_oai_type(message: LLMMessage) -> Sequence[ChatCompletionMessageParam]:
def to_oai_type(message: LLMMessage, prepend_name: bool = False) -> Sequence[ChatCompletionMessageParam]:
if isinstance(message, SystemMessage):
return [system_message_to_oai(message)]
elif isinstance(message, UserMessage):
return [user_message_to_oai(message)]
return [user_message_to_oai(message, prepend_name)]
elif isinstance(message, AssistantMessage):
return [assistant_message_to_oai(message)]
else:
Expand Down Expand Up @@ -356,8 +364,10 @@ def __init__(
create_args: Dict[str, Any],
model_capabilities: Optional[ModelCapabilities] = None, # type: ignore
model_info: Optional[ModelInfo] = None,
add_name_prefixes: bool = False,
):
self._client = client
self._add_name_prefixes = add_name_prefixes
if model_capabilities is None and model_info is None:
try:
self._model_info = _model_info.get_info(create_args["model"])
Expand Down Expand Up @@ -451,7 +461,7 @@ async def create(
if self.model_info["json_output"] is False and json_output is True:
raise ValueError("Model does not support JSON output.")

oai_messages_nested = [to_oai_type(m) for m in messages]
oai_messages_nested = [to_oai_type(m, prepend_name=self._add_name_prefixes) for m in messages]
oai_messages = [item for sublist in oai_messages_nested for item in sublist]

if self.model_info["function_calling"] is False and len(tools) > 0:
Expand Down Expand Up @@ -672,7 +682,7 @@ async def create_stream(
create_args = self._create_args.copy()
create_args.update(extra_create_args)

oai_messages_nested = [to_oai_type(m) for m in messages]
oai_messages_nested = [to_oai_type(m, prepend_name=self._add_name_prefixes) for m in messages]
oai_messages = [item for sublist in oai_messages_nested for item in sublist]

# TODO: allow custom handling.
Expand Down Expand Up @@ -874,7 +884,7 @@ def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool |
# Message tokens.
for message in messages:
num_tokens += tokens_per_message
oai_message = to_oai_type(message)
oai_message = to_oai_type(message, prepend_name=self._add_name_prefixes)
for oai_message_part in oai_message:
for key, value in oai_message_part.items():
if value is None:
Expand Down Expand Up @@ -1074,11 +1084,19 @@ def __init__(self, **kwargs: Unpack[OpenAIClientConfiguration]):
model_info = kwargs["model_info"]
del copied_args["model_info"]

add_name_prefixes: bool = False
if "add_name_prefixes" in kwargs:
add_name_prefixes = kwargs["add_name_prefixes"]

Check warning on line 1089 in python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py#L1089

Added line #L1089 was not covered by tests

client = _openai_client_from_config(copied_args)
create_args = _create_args_from_config(copied_args)

super().__init__(
client=client, create_args=create_args, model_capabilities=model_capabilities, model_info=model_info
client=client,
create_args=create_args,
model_capabilities=model_capabilities,
model_info=model_info,
add_name_prefixes=add_name_prefixes,
)

def __getstate__(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -1215,11 +1233,19 @@ def __init__(self, **kwargs: Unpack[AzureOpenAIClientConfiguration]):
model_info = kwargs["model_info"]
del copied_args["model_info"]

add_name_prefixes: bool = False
if "add_name_prefixes" in kwargs:
add_name_prefixes = kwargs["add_name_prefixes"]

Check warning on line 1238 in python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py#L1238

Added line #L1238 was not covered by tests

client = _azure_openai_client_from_config(copied_args)
create_args = _create_args_from_config(copied_args)
self._raw_config: Dict[str, Any] = copied_args
super().__init__(
client=client, create_args=create_args, model_capabilities=model_capabilities, model_info=model_info
client=client,
create_args=create_args,
model_capabilities=model_capabilities,
model_info=model_info,
add_name_prefixes=add_name_prefixes,
)

def __getstate__(self) -> Dict[str, Any]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class BaseOpenAIClientConfiguration(CreateArguments, total=False):
max_retries: int
model_capabilities: ModelCapabilities # type: ignore
model_info: ModelInfo
add_name_prefixes: bool
"""What functionality the model supports, determined by default from model name but is overriden if value passed."""
default_headers: Dict[str, str] | None

Expand Down Expand Up @@ -75,6 +76,7 @@ class BaseOpenAIClientConfigurationConfigModel(CreateArgumentsConfigModel):
max_retries: int | None = None
model_capabilities: ModelCapabilities | None = None # type: ignore
model_info: ModelInfo | None = None
add_name_prefixes: bool | None = None
default_headers: Dict[str, str] | None = None


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from autogen_core.tools import BaseTool, FunctionTool
from autogen_ext.models.openai import AzureOpenAIChatCompletionClient, OpenAIChatCompletionClient
from autogen_ext.models.openai._model_info import resolve_model
from autogen_ext.models.openai._openai_client import calculate_vision_tokens, convert_tools
from autogen_ext.models.openai._openai_client import calculate_vision_tokens, convert_tools, to_oai_type
from openai.resources.beta.chat.completions import AsyncCompletions as BetaAsyncCompletions
from openai.resources.chat.completions import AsyncCompletions
from openai.types.chat.chat_completion import ChatCompletion, Choice
Expand Down Expand Up @@ -1050,4 +1050,56 @@ async def test_ollama() -> None:
assert chunks[-1].thought is not None


@pytest.mark.asyncio
async def test_add_name_prefixes(monkeypatch: pytest.MonkeyPatch) -> None:
sys_message = SystemMessage(content="You are a helpful AI agent, and you answer questions in a friendly way.")
assistant_message = AssistantMessage(content="Hello, how can I help you?", source="Assistant")
user_text_message = UserMessage(content="Hello, I am from Seattle.", source="Adam")
user_mm_message = UserMessage(
content=[
"Here is a postcard from Seattle:",
Image.from_base64(
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4z8AAAAMBAQDJ/pLvAAAAAElFTkSuQmCC"
),
],
source="Adam",
)

# Default conversion
oai_sys = to_oai_type(sys_message)[0]
oai_asst = to_oai_type(assistant_message)[0]
oai_text = to_oai_type(user_text_message)[0]
oai_mm = to_oai_type(user_mm_message)[0]

converted_sys = to_oai_type(sys_message, prepend_name=True)[0]
converted_asst = to_oai_type(assistant_message, prepend_name=True)[0]
converted_text = to_oai_type(user_text_message, prepend_name=True)[0]
converted_mm = to_oai_type(user_mm_message, prepend_name=True)[0]

# Invariants
assert "content" in oai_sys
assert "content" in oai_asst
assert "content" in oai_text
assert "content" in oai_mm
assert "content" in converted_sys
assert "content" in converted_asst
assert "content" in converted_text
assert "content" in converted_mm
assert oai_sys["role"] == converted_sys["role"]
assert oai_sys["content"] == converted_sys["content"]
assert oai_asst["role"] == converted_asst["role"]
assert oai_asst["content"] == converted_asst["content"]
assert oai_text["role"] == converted_text["role"]
assert oai_mm["role"] == converted_mm["role"]
assert isinstance(oai_mm["content"], list)
assert isinstance(converted_mm["content"], list)
assert len(oai_mm["content"]) == len(converted_mm["content"])
assert "text" in converted_mm["content"][0]
assert "text" in oai_mm["content"][0]

# Name prepended
assert str(converted_text["content"]) == "Adam said:\n" + str(oai_text["content"])
assert str(converted_mm["content"][0]["text"]) == "Adam said:\n" + str(oai_mm["content"][0]["text"])


# TODO: add integration tests for Azure OpenAI using AAD token.

0 comments on commit b233a83

Please sign in to comment.