Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make FunctionTools Serializable (Declarative) #5052

Merged
merged 14 commits into from
Jan 24, 2025
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
TimeoutTermination,
TokenUsageTermination,
)
from autogen_core import ComponentLoader, ComponentModel
from autogen_core import ComponentLoader, ComponentModel, CancellationToken
from autogen_core.tools import FunctionTool
from autogen_core.code_executor import ImportFromModule


@pytest.mark.asyncio
Expand Down Expand Up @@ -92,3 +94,59 @@ async def test_termination_declarative() -> None:
# Test loading complex composition
loaded_composite = ComponentLoader.load_component(composite_config)
assert isinstance(loaded_composite, AndTerminationCondition)


@pytest.mark.asyncio
async def test_function_tool() -> None:
"""Test FunctionTool with different function types and features."""

# Test sync and async functions
def sync_func(x: int, y: str) -> str:
return y * x

async def async_func(x: float, y: float, cancellation_token: CancellationToken) -> float:
if cancellation_token.is_cancelled():
raise Exception("Cancelled")
return x + y

# Create tools with different configurations
sync_tool = FunctionTool(
func=sync_func, description="Multiply string", global_imports=[ImportFromModule("typing", ("Dict",))]
)
async_tool = FunctionTool(
func=async_func,
description="Add numbers",
name="custom_adder",
global_imports=[ImportFromModule("autogen_core", ("CancellationToken",))],
)

# Test serialization and config

sync_config = sync_tool.dump_component()
assert isinstance(sync_config, ComponentModel)
assert sync_config.config["name"] == "sync_func"
assert len(sync_config.config["global_imports"]) == 1
assert not sync_config.config["has_cancellation_support"]

async_config = async_tool.dump_component()
assert async_config.config["name"] == "custom_adder"
assert async_config.config["has_cancellation_support"]

# Test deserialization and execution
loaded_sync = FunctionTool.load_component(sync_config, FunctionTool)
loaded_async = FunctionTool.load_component(async_config, FunctionTool)

# Test execution and validation
token = CancellationToken()
assert await loaded_sync.run_json({"x": 2, "y": "test"}, token) == "testtest"
assert await loaded_async.run_json({"x": 1.5, "y": 2.5}, token) == 4.0

# Test error cases
with pytest.raises(ValueError):
# Type error
await loaded_sync.run_json({"x": "invalid", "y": "test"}, token)

cancelled_token = CancellationToken()
cancelled_token.cancel()
with pytest.raises(Exception, match="Cancelled"):
await loaded_async.run_json({"x": 1.0, "y": 2.0}, cancelled_token)
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
module: str
imports: Tuple[Union[str, Alias], ...]

## backward compatibility
# backward compatibility
def __init__(
self,
module: str,
Expand Down Expand Up @@ -214,3 +214,11 @@

content += " ..."
return content


def to_code(func: Union[FunctionWithRequirements[T, P], Callable[P, T], FunctionWithRequirementsStr]) -> str:
return _to_code(func)

Check warning on line 220 in python/packages/autogen-core/src/autogen_core/code_executor/_func_with_reqs.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/code_executor/_func_with_reqs.py#L220

Added line #L220 was not covered by tests


def import_to_str(im: Import) -> str:
return _import_to_str(im)

Check warning on line 224 in python/packages/autogen-core/src/autogen_core/code_executor/_func_with_reqs.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/code_executor/_func_with_reqs.py#L224

Added line #L224 was not covered by tests
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

from ._base import BaseTool, BaseToolWithState, ParametersSchema, Tool, ToolSchema
from ._function_tool import FunctionTool

Expand Down
9 changes: 7 additions & 2 deletions python/packages/autogen-core/src/autogen_core/tools/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing_extensions import NotRequired

from .. import CancellationToken
from .._component_config import ComponentBase
from .._function_utils import normalize_annotated_type

T = TypeVar("T", bound=BaseModel, contravariant=True)
Expand Down Expand Up @@ -56,7 +57,9 @@ def load_state_json(self, state: Mapping[str, Any]) -> None: ...
StateT = TypeVar("StateT", bound=BaseModel)


class BaseTool(ABC, Tool, Generic[ArgsT, ReturnT]):
class BaseTool(ABC, Tool, Generic[ArgsT, ReturnT], ComponentBase[BaseModel]):
component_type = "tool"

def __init__(
self,
args_type: Type[ArgsT],
Expand Down Expand Up @@ -132,7 +135,7 @@ def load_state_json(self, state: Mapping[str, Any]) -> None:
pass


class BaseToolWithState(BaseTool[ArgsT, ReturnT], ABC, Generic[ArgsT, ReturnT, StateT]):
class BaseToolWithState(BaseTool[ArgsT, ReturnT], ABC, Generic[ArgsT, ReturnT, StateT], ComponentBase[BaseModel]):
def __init__(
self,
args_type: Type[ArgsT],
Expand All @@ -144,6 +147,8 @@ def __init__(
super().__init__(args_type, return_type, name, description)
self._state_type = state_type

component_type = "tool"

@abstractmethod
def save_state(self) -> StateT: ...

Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,32 @@
import asyncio
import functools
from typing import Any, Callable
from textwrap import dedent
from typing import Any, Callable, Sequence

from pydantic import BaseModel
from typing_extensions import Self

from .. import CancellationToken
from .._component_config import Component
from .._function_utils import (
args_base_model_from_signature,
get_typed_signature,
)
from ..code_executor._func_with_reqs import Import, import_to_str, to_code
from ._base import BaseTool


class FunctionTool(BaseTool[BaseModel, BaseModel]):
class FunctionToolConfig(BaseModel):
"""Configuration for a function tool."""

source_code: str
name: str
description: str
global_imports: Sequence[Import]
ekzhu marked this conversation as resolved.
Show resolved Hide resolved
has_cancellation_support: bool


class FunctionTool(BaseTool[BaseModel, BaseModel], Component[FunctionToolConfig]):
"""
Create custom tools by wrapping standard Python functions.

Expand Down Expand Up @@ -64,8 +78,14 @@
asyncio.run(example())
"""

def __init__(self, func: Callable[..., Any], description: str, name: str | None = None) -> None:
component_provider_override = "autogen_core.tools.FunctionTool"
component_config_schema = FunctionToolConfig

def __init__(
self, func: Callable[..., Any], description: str, name: str | None = None, global_imports: Sequence[Import] = []
) -> None:
self._func = func
self._global_imports = global_imports
signature = get_typed_signature(func)
func_name = name or func.__name__
args_model = args_base_model_from_signature(func_name + "args", signature)
Expand Down Expand Up @@ -98,3 +118,44 @@
result = await future

return result

def _to_config(self) -> FunctionToolConfig:
return FunctionToolConfig(

Check warning on line 123 in python/packages/autogen-core/src/autogen_core/tools/_function_tool.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/tools/_function_tool.py#L123

Added line #L123 was not covered by tests
source_code=dedent(to_code(self._func)),
global_imports=self._global_imports,
name=self.name,
description=self.description,
has_cancellation_support=self._has_cancellation_support,
)

@classmethod
def _from_config(cls, config: FunctionToolConfig) -> Self:
exec_globals: dict[str, Any] = {}

Check warning on line 133 in python/packages/autogen-core/src/autogen_core/tools/_function_tool.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/tools/_function_tool.py#L133

Added line #L133 was not covered by tests

# Execute imports first
for import_stmt in config.global_imports:
import_code = import_to_str(import_stmt)
try:
exec(import_code, exec_globals)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(

Check warning on line 141 in python/packages/autogen-core/src/autogen_core/tools/_function_tool.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/tools/_function_tool.py#L136-L141

Added lines #L136 - L141 were not covered by tests
f"Failed to import {import_code}: Module not found. Please ensure the module is installed."
) from e
except ImportError as e:
raise ImportError(f"Failed to import {import_code}: {str(e)}") from e
except Exception as e:
raise RuntimeError(f"Unexpected error while importing {import_code}: {str(e)}") from e

Check warning on line 147 in python/packages/autogen-core/src/autogen_core/tools/_function_tool.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/tools/_function_tool.py#L144-L147

Added lines #L144 - L147 were not covered by tests

# Execute function code
try:
exec(config.source_code, exec_globals)
func_name = config.source_code.split("def ")[1].split("(")[0]
except Exception as e:
raise ValueError(f"Could not compile and load function: {e}") from e

Check warning on line 154 in python/packages/autogen-core/src/autogen_core/tools/_function_tool.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/tools/_function_tool.py#L150-L154

Added lines #L150 - L154 were not covered by tests

# Get function and verify it's callable
func: Callable[..., Any] = exec_globals[func_name]
if not callable(func):
raise TypeError(f"Expected function but got {type(func)}")

Check warning on line 159 in python/packages/autogen-core/src/autogen_core/tools/_function_tool.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/tools/_function_tool.py#L157-L159

Added lines #L157 - L159 were not covered by tests

return cls(func, "", None)

Check warning on line 161 in python/packages/autogen-core/src/autogen_core/tools/_function_tool.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/tools/_function_tool.py#L161

Added line #L161 was not covered by tests
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@
Image,
MessageHandlerContext,
)
from autogen_core.models import FinishReasons
from autogen_core.logging import LLMCallEvent
from autogen_core.models import (
AssistantMessage,
ChatCompletionClient,
ChatCompletionTokenLogprob,
CreateResult,
FinishReasons,
FunctionExecutionResultMessage,
LLMMessage,
ModelCapabilities, # type: ignore
Expand Down
Loading