From 58d789a2498abc8528aeda5275fbdaa1a3fc2052 Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Thu, 23 Jan 2025 19:49:43 -0800 Subject: [PATCH] Make FunctionTools Serializable (Declarative) (#5052) * vi1 for declarative tools * make functtools declarative * add tests * update imports * update formatting * move tests, format fixes * format updates * update test * add user warning to _from_config * add warning on load_component to docs --------- Co-authored-by: Ryan Sweet --- .../serialize-components.ipynb | 17 ++--- .../code_executor/_func_with_reqs.py | 10 ++- .../src/autogen_core/tools/_base.py | 9 ++- .../src/autogen_core/tools/_function_tool.py | 76 ++++++++++++++++++- .../tests/test_component_config.py | 69 ++++++++++++++++- 5 files changed, 165 insertions(+), 16 deletions(-) diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/serialize-components.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/serialize-components.ipynb index 5a3855f48080..ff29efa9100a 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/serialize-components.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/serialize-components.ipynb @@ -6,21 +6,20 @@ "source": [ "# Serializing Components \n", "\n", - "AutoGen provides a {py:class}`~autogen_core.Component` configuration class that defines behaviours for to serialize/deserialize component into declarative specifications. This is useful for debugging, visualizing, and even for sharing your work with others. In this notebook, we will demonstrate how to serialize multiple components to a declarative specification like a JSON file. \n", + "AutoGen provides a {py:class}`~autogen_core.Component` configuration class that defines behaviours to serialize/deserialize component into declarative specifications. We can accomplish this by calling `.dump_component()` and `.load_component()` respectively. This is useful for debugging, visualizing, and even for sharing your work with others. In this notebook, we will demonstrate how to serialize multiple components to a declarative specification like a JSON file. \n", "\n", "\n", - "```{note}\n", - "This is work in progress\n", - "``` \n", + "```{warning}\n", "\n", - "We will be implementing declarative support for the following components:\n", + "ONLY LOAD COMPONENTS FROM TRUSTED SOURCES.\n", "\n", - "- Termination conditions ✔️\n", - "- Tools \n", - "- Agents \n", - "- Teams \n", + "With serilized components, each component implements the logic for how it is serialized and deserialized - i.e., how the declarative specification is generated and how it is converted back to an object. \n", "\n", + "In some cases, creating an object may include executing code (e.g., a serialized function). ONLY LOAD COMPONENTS FROM TRUSTED SOURCES. \n", + " \n", + "```\n", "\n", + " \n", "### Termination Condition Example \n", "\n", "In the example below, we will define termination conditions (a part of an agent team) in python, export this to a dictionary/json and also demonstrate how the termination condition object can be loaded from the dictionary/json. \n", diff --git a/python/packages/autogen-core/src/autogen_core/code_executor/_func_with_reqs.py b/python/packages/autogen-core/src/autogen_core/code_executor/_func_with_reqs.py index 77fc0b831427..50a8c5280935 100644 --- a/python/packages/autogen-core/src/autogen_core/code_executor/_func_with_reqs.py +++ b/python/packages/autogen-core/src/autogen_core/code_executor/_func_with_reqs.py @@ -42,7 +42,7 @@ class ImportFromModule: module: str imports: Tuple[Union[str, Alias], ...] - ## backward compatibility + # backward compatibility def __init__( self, module: str, @@ -214,3 +214,11 @@ def to_stub(func: Union[Callable[..., Any], FunctionWithRequirementsStr]) -> str content += " ..." return content + + +def to_code(func: Union[FunctionWithRequirements[T, P], Callable[P, T], FunctionWithRequirementsStr]) -> str: + return _to_code(func) + + +def import_to_str(im: Import) -> str: + return _import_to_str(im) diff --git a/python/packages/autogen-core/src/autogen_core/tools/_base.py b/python/packages/autogen-core/src/autogen_core/tools/_base.py index 7c4042e9afd6..b484ef84f3e9 100644 --- a/python/packages/autogen-core/src/autogen_core/tools/_base.py +++ b/python/packages/autogen-core/src/autogen_core/tools/_base.py @@ -8,6 +8,7 @@ from typing_extensions import NotRequired from .. import CancellationToken +from .._component_config import ComponentBase from .._function_utils import normalize_annotated_type T = TypeVar("T", bound=BaseModel, contravariant=True) @@ -56,7 +57,9 @@ def load_state_json(self, state: Mapping[str, Any]) -> None: ... StateT = TypeVar("StateT", bound=BaseModel) -class BaseTool(ABC, Tool, Generic[ArgsT, ReturnT]): +class BaseTool(ABC, Tool, Generic[ArgsT, ReturnT], ComponentBase[BaseModel]): + component_type = "tool" + def __init__( self, args_type: Type[ArgsT], @@ -132,7 +135,7 @@ def load_state_json(self, state: Mapping[str, Any]) -> None: pass -class BaseToolWithState(BaseTool[ArgsT, ReturnT], ABC, Generic[ArgsT, ReturnT, StateT]): +class BaseToolWithState(BaseTool[ArgsT, ReturnT], ABC, Generic[ArgsT, ReturnT, StateT], ComponentBase[BaseModel]): def __init__( self, args_type: Type[ArgsT], @@ -144,6 +147,8 @@ def __init__( super().__init__(args_type, return_type, name, description) self._state_type = state_type + component_type = "tool" + @abstractmethod def save_state(self) -> StateT: ... diff --git a/python/packages/autogen-core/src/autogen_core/tools/_function_tool.py b/python/packages/autogen-core/src/autogen_core/tools/_function_tool.py index 026fc845e9c2..b43e061350e5 100644 --- a/python/packages/autogen-core/src/autogen_core/tools/_function_tool.py +++ b/python/packages/autogen-core/src/autogen_core/tools/_function_tool.py @@ -1,18 +1,33 @@ import asyncio import functools -from typing import Any, Callable +from textwrap import dedent +from typing import Any, Callable, Sequence +import warnings from pydantic import BaseModel +from typing_extensions import Self from .. import CancellationToken +from .._component_config import Component from .._function_utils import ( args_base_model_from_signature, get_typed_signature, ) +from ..code_executor._func_with_reqs import Import, import_to_str, to_code from ._base import BaseTool -class FunctionTool(BaseTool[BaseModel, BaseModel]): +class FunctionToolConfig(BaseModel): + """Configuration for a function tool.""" + + source_code: str + name: str + description: str + global_imports: Sequence[Import] + has_cancellation_support: bool + + +class FunctionTool(BaseTool[BaseModel, BaseModel], Component[FunctionToolConfig]): """ Create custom tools by wrapping standard Python functions. @@ -64,8 +79,14 @@ async def example(): asyncio.run(example()) """ - def __init__(self, func: Callable[..., Any], description: str, name: str | None = None) -> None: + component_provider_override = "autogen_core.tools.FunctionTool" + component_config_schema = FunctionToolConfig + + def __init__( + self, func: Callable[..., Any], description: str, name: str | None = None, global_imports: Sequence[Import] = [] + ) -> None: self._func = func + self._global_imports = global_imports signature = get_typed_signature(func) func_name = name or func.__name__ args_model = args_base_model_from_signature(func_name + "args", signature) @@ -98,3 +119,52 @@ async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> A result = await future return result + + def _to_config(self) -> FunctionToolConfig: + return FunctionToolConfig( + source_code=dedent(to_code(self._func)), + global_imports=self._global_imports, + name=self.name, + description=self.description, + has_cancellation_support=self._has_cancellation_support, + ) + + @classmethod + def _from_config(cls, config: FunctionToolConfig) -> Self: + warnings.warn( + "\n⚠️ SECURITY WARNING ⚠️\n" + "Loading a FunctionTool from config will execute code to import the provided global imports and and function code.\n" + "Only load configs from TRUSTED sources to prevent arbitrary code execution.", + UserWarning, + stacklevel=2, + ) + + exec_globals: dict[str, Any] = {} + + # Execute imports first + for import_stmt in config.global_imports: + import_code = import_to_str(import_stmt) + try: + exec(import_code, exec_globals) + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + f"Failed to import {import_code}: Module not found. Please ensure the module is installed." + ) from e + except ImportError as e: + raise ImportError(f"Failed to import {import_code}: {str(e)}") from e + except Exception as e: + raise RuntimeError(f"Unexpected error while importing {import_code}: {str(e)}") from e + + # Execute function code + try: + exec(config.source_code, exec_globals) + func_name = config.source_code.split("def ")[1].split("(")[0] + except Exception as e: + raise ValueError(f"Could not compile and load function: {e}") from e + + # Get function and verify it's callable + func: Callable[..., Any] = exec_globals[func_name] + if not callable(func): + raise TypeError(f"Expected function but got {type(func)}") + + return cls(func, "", None) diff --git a/python/packages/autogen-core/tests/test_component_config.py b/python/packages/autogen-core/tests/test_component_config.py index d59fde59c1b6..1f78e907a447 100644 --- a/python/packages/autogen-core/tests/test_component_config.py +++ b/python/packages/autogen-core/tests/test_component_config.py @@ -4,9 +4,11 @@ from typing import Any, Dict import pytest -from autogen_core import Component, ComponentBase, ComponentLoader, ComponentModel +from autogen_core import CancellationToken, Component, ComponentBase, ComponentLoader, ComponentModel from autogen_core._component_config import _type_to_provider_str # type: ignore +from autogen_core.code_executor import ImportFromModule from autogen_core.models import ChatCompletionClient +from autogen_core.tools import FunctionTool from autogen_test_utils import MyInnerComponent, MyOuterComponent from pydantic import BaseModel, ValidationError from typing_extensions import Self @@ -283,3 +285,68 @@ def test_component_version_from_dict() -> None: assert comp.info == "test" assert comp.__class__ == ComponentNonOneVersionWithUpgrade assert comp.dump_component().version == 2 + + +@pytest.mark.asyncio +async def test_function_tool() -> None: + """Test FunctionTool with different function types and features.""" + + # Test sync and async functions + def sync_func(x: int, y: str) -> str: + return y * x + + async def async_func(x: float, y: float, cancellation_token: CancellationToken) -> float: + if cancellation_token.is_cancelled(): + raise Exception("Cancelled") + return x + y + + # Create tools with different configurations + sync_tool = FunctionTool( + func=sync_func, description="Multiply string", global_imports=[ImportFromModule("typing", ("Dict",))] + ) + invalid_import_sync_tool = FunctionTool( + func=sync_func, description="Multiply string", global_imports=[ImportFromModule("invalid_module (", ("Dict",))] + ) + + invalid_import_config = invalid_import_sync_tool.dump_component() + # check that invalid import raises an error + with pytest.raises(RuntimeError): + _ = FunctionTool.load_component(invalid_import_config, FunctionTool) + + async_tool = FunctionTool( + func=async_func, + description="Add numbers", + name="custom_adder", + global_imports=[ImportFromModule("autogen_core", ("CancellationToken",))], + ) + + # Test serialization and config + + sync_config = sync_tool.dump_component() + assert isinstance(sync_config, ComponentModel) + assert sync_config.config["name"] == "sync_func" + assert len(sync_config.config["global_imports"]) == 1 + assert not sync_config.config["has_cancellation_support"] + + async_config = async_tool.dump_component() + assert async_config.config["name"] == "custom_adder" + assert async_config.config["has_cancellation_support"] + + # Test deserialization and execution + loaded_sync = FunctionTool.load_component(sync_config, FunctionTool) + loaded_async = FunctionTool.load_component(async_config, FunctionTool) + + # Test execution and validation + token = CancellationToken() + assert await loaded_sync.run_json({"x": 2, "y": "test"}, token) == "testtest" + assert await loaded_async.run_json({"x": 1.5, "y": 2.5}, token) == 4.0 + + # Test error cases + with pytest.raises(ValueError): + # Type error + await loaded_sync.run_json({"x": "invalid", "y": "test"}, token) + + cancelled_token = CancellationToken() + cancelled_token.cancel() + with pytest.raises(Exception, match="Cancelled"): + await loaded_async.run_json({"x": 1.0, "y": 2.0}, cancelled_token)