From 1e947830277bd8a48cebc3101a58943f8c413d25 Mon Sep 17 00:00:00 2001
From: gagb
Date: Sun, 26 Jan 2025 18:13:56 -0800
Subject: [PATCH 1/2] feat: add extra_create_args to AssistantAgent for model client customization

---
 .../src/autogen_agentchat/agents/_assistant_agent.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
index b1144d9c466..6744e725d91 100644
--- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
+++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
@@ -64,6 +64,7 @@ class AssistantAgentConfig(BaseModel):
     system_message: str | None = None
     reflect_on_tool_use: bool
     tool_call_summary_format: str
+    extra_create_args: Mapping[str, Any] | None = None
 
 
 class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
@@ -147,6 +148,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
             Available variables: `{tool_name}`, `{arguments}`, `{result}`.
             For example, `"{tool_name}: {result}"` will create a summary like `"tool_name: result"`.
         memory (Sequence[Memory] | None, optional): The memory store to use for the agent. Defaults to `None`.
+        extra_create_args (Mapping[str, Any] | None, optional): Additional arguments to pass to the model client during the create method call. Defaults to `None`.
 
     Raises:
         ValueError: If tool names are not unique.
@@ -271,6 +273,7 @@ def __init__(
         reflect_on_tool_use: bool = False,
         tool_call_summary_format: str = "{result}",
         memory: Sequence[Memory] | None = None,
+        extra_create_args: Mapping[str, Any] | None = None,
     ):
         super().__init__(name=name, description=description)
         self._model_client = model_client
@@ -337,6 +340,7 @@ def __init__(
         self._reflect_on_tool_use = reflect_on_tool_use
         self._tool_call_summary_format = tool_call_summary_format
         self._is_running = False
+        self._extra_create_args = extra_create_args or {}
 
     @property
     def produced_message_types(self) -> Sequence[type[ChatMessage]]:
@@ -384,7 +388,7 @@ async def on_messages_stream(
         # Generate an inference result based on the current model context.
         llm_messages = self._system_messages + await self._model_context.get_messages()
         model_result = await self._model_client.create(
-            llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
+            llm_messages, tools=self._tools + self._handoff_tools, extra_create_args=self._extra_create_args, cancellation_token=cancellation_token
         )
 
         # Add the response to the model context.
@@ -465,7 +469,9 @@ async def on_messages_stream(
         if self._reflect_on_tool_use:
             # Generate another inference result based on the tool call and result.
             llm_messages = self._system_messages + await self._model_context.get_messages()
-            model_result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
+            model_result = await self._model_client.create(
+                llm_messages, extra_create_args=self._extra_create_args, cancellation_token=cancellation_token
+            )
             assert isinstance(model_result.content, str)
             # Add the response to the model context.
             await self._model_context.add_message(AssistantMessage(content=model_result.content, source=self.name))
@@ -540,6 +546,7 @@ def _to_config(self) -> AssistantAgentConfig:
             else None,
             reflect_on_tool_use=self._reflect_on_tool_use,
             tool_call_summary_format=self._tool_call_summary_format,
+            extra_create_args=self._extra_create_args,
         )
 
     @classmethod
@@ -555,4 +562,5 @@ def _from_config(cls, config: AssistantAgentConfig) -> Self:
             system_message=config.system_message,
             reflect_on_tool_use=config.reflect_on_tool_use,
             tool_call_summary_format=config.tool_call_summary_format,
+            extra_create_args=config.extra_create_args,
         )

From a7ee98b9084fe575a9ef10fbbc73cec7c9eb3772 Mon Sep 17 00:00:00 2001
From: gagb
Date: Sun, 26 Jan 2025 22:26:52 -0800
Subject: [PATCH 2/2] Run poe check

---
 .../src/autogen_agentchat/agents/_assistant_agent.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
index 6744e725d91..384dcf55fc2 100644
--- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
+++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
@@ -28,7 +28,7 @@
     SystemMessage,
     UserMessage,
 )
-from autogen_core.tools import FunctionTool, BaseTool
+from autogen_core.tools import BaseTool, FunctionTool
 from pydantic import BaseModel
 from typing_extensions import Self
 
@@ -388,7 +388,10 @@ async def on_messages_stream(
         # Generate an inference result based on the current model context.
         llm_messages = self._system_messages + await self._model_context.get_messages()
         model_result = await self._model_client.create(
-            llm_messages, tools=self._tools + self._handoff_tools, extra_create_args=self._extra_create_args, cancellation_token=cancellation_token
+            llm_messages,
+            tools=self._tools + self._handoff_tools,
+            extra_create_args=self._extra_create_args,
+            cancellation_token=cancellation_token,
         )
 
         # Add the response to the model context.