-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add langchain tool adapter in autogen-ext (#570)
* add langhcain tool adapter * remove langchain package * fix type errors * test type fixes * fix imports * install extras in CI * improve typing and use to_thread * pin min langchain version * install all extras in ci test * update to langchain 0.3.1 * install extras in CI * ignore pyright errors * add missing uv sync extra reqs --------- Co-authored-by: Leonardo Pinheiro <[email protected]> Co-authored-by: Eric Zhu <[email protected]> Co-authored-by: Ryan Sweet <[email protected]> Co-authored-by: Jack Gerrits <[email protected]>
- Loading branch information
1 parent
18d52f6
commit 6cfa29b
Showing
7 changed files
with
298 additions
and
63 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from ._langchain_adapter import LangChainToolAdapter | ||
|
||
__all__ = ["LangChainToolAdapter"] |
75 changes: 75 additions & 0 deletions
75
python/packages/autogen-ext/src/autogen_ext/tools/langchain/_langchain_adapter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import asyncio | ||
import inspect | ||
from typing import Any, Callable, Dict, Tuple, Type, cast | ||
|
||
from autogen_core.base import CancellationToken | ||
from autogen_core.components.tools import BaseTool | ||
from pydantic import BaseModel, Field, create_model | ||
from pydantic.fields import FieldInfo | ||
|
||
from langchain.tools import Tool as LangChainTool | ||
|
||
FieldDefinition = Tuple[Type[Any], FieldInfo] | ||
FieldsDict = Dict[str, FieldDefinition] | ||
|
||
|
||
class LangChainToolAdapter(BaseTool[BaseModel, Any]): | ||
langchain_tool: LangChainTool | ||
_callable: Callable[..., Any] | ||
|
||
def __init__(self, langchain_tool: LangChainTool): | ||
self.langchain_tool = langchain_tool | ||
|
||
# Extract name and description | ||
name = langchain_tool.name | ||
description = langchain_tool.description or "" | ||
|
||
# Determine the callable method | ||
if hasattr(langchain_tool, "func") and callable(langchain_tool.func): | ||
assert langchain_tool.func is not None | ||
self._callable = langchain_tool.func | ||
elif hasattr(langchain_tool, "_run") and callable(langchain_tool._run): # pyright: ignore | ||
self._callable = langchain_tool._run # type: ignore | ||
else: | ||
raise AttributeError( | ||
f"The provided LangChain tool '{name}' does not have a callable 'func' or '_run' method." | ||
) | ||
|
||
# Determine args_type | ||
if langchain_tool.args_schema: # pyright: ignore | ||
args_type = langchain_tool.args_schema # pyright: ignore | ||
else: | ||
# Infer args_type from the callable's signature | ||
sig = inspect.signature(cast(Callable[..., Any], self._callable)) | ||
fields = { | ||
k: (v.annotation, Field(...)) | ||
for k, v in sig.parameters.items() | ||
if k != "self" and v.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD) | ||
} | ||
args_type = create_model(f"{name}Args", **fields) # type: ignore | ||
# Note: type ignore is used due to a LangChain typing limitation | ||
|
||
# Ensure args_type is a subclass of BaseModel | ||
if not issubclass(args_type, BaseModel): | ||
raise ValueError(f"Failed to create a valid Pydantic v2 model for {name}") | ||
|
||
# Assume return_type as Any if not specified | ||
return_type: Type[Any] = object | ||
|
||
super().__init__(args_type, return_type, name, description) | ||
|
||
async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any: | ||
# Prepare arguments | ||
kwargs = args.model_dump() | ||
|
||
# Determine if the callable is asynchronous | ||
if inspect.iscoroutinefunction(self._callable): | ||
result = await self._callable(**kwargs) | ||
else: | ||
# Run in a thread to avoid blocking the event loop | ||
result = await asyncio.to_thread(self._call_sync, kwargs) | ||
|
||
return result | ||
|
||
def _call_sync(self, kwargs: Dict[str, Any]) -> Any: | ||
return self._callable(**kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
from typing import Optional, Type | ||
|
||
import pytest | ||
from autogen_core.base import CancellationToken | ||
from autogen_ext.tools.langchain import LangChainToolAdapter # type: ignore | ||
from langchain.tools import BaseTool as LangChainTool # type: ignore | ||
from langchain.tools import tool # pyright: ignore | ||
from langchain_core.callbacks.manager import AsyncCallbackManagerForToolRun, CallbackManagerForToolRun | ||
from pydantic import BaseModel, Field | ||
|
||
|
||
@tool # type: ignore | ||
def add(a: int, b: int) -> int: | ||
"""Add two numbers""" | ||
return a + b | ||
|
||
|
||
class CalculatorInput(BaseModel): | ||
a: int = Field(description="first number") | ||
b: int = Field(description="second number") | ||
|
||
|
||
class CustomCalculatorTool(LangChainTool): | ||
name: str = "Calculator" | ||
description: str = "useful for when you need to answer questions about math" | ||
args_schema: Type[BaseModel] = CalculatorInput | ||
return_direct: bool = True | ||
|
||
def _run(self, a: int, b: int, run_manager: Optional[CallbackManagerForToolRun] = None) -> int: | ||
"""Use the tool.""" | ||
return a * b | ||
|
||
async def _arun( | ||
self, | ||
a: int, | ||
b: int, | ||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None, | ||
) -> int: | ||
"""Use the tool asynchronously.""" | ||
return self._run(a, b, run_manager=run_manager.get_sync() if run_manager else None) | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_langchain_tool_adapter() -> None: | ||
# Create a LangChain tool | ||
langchain_tool = add # type: ignore | ||
|
||
# Create an adapter | ||
adapter = LangChainToolAdapter(langchain_tool) # pyright: ignore | ||
|
||
# Test schema generation | ||
schema = adapter.schema | ||
|
||
assert schema["name"] == "add" | ||
assert "description" in schema | ||
assert schema["description"] == "Add two numbers" | ||
assert "parameters" in schema | ||
assert schema["parameters"]["type"] == "object" | ||
assert "properties" in schema["parameters"] | ||
assert "a" in schema["parameters"]["properties"] | ||
assert "b" in schema["parameters"]["properties"] | ||
assert schema["parameters"]["properties"]["a"]["type"] == "integer" | ||
assert schema["parameters"]["properties"]["b"]["type"] == "integer" | ||
assert "required" in schema["parameters"] | ||
assert set(schema["parameters"]["required"]) == {"a", "b"} | ||
assert len(schema["parameters"]["properties"]) == 2 | ||
|
||
# Test run method | ||
result = await adapter.run_json({"a": 2, "b": 3}, CancellationToken()) | ||
assert result == 5 | ||
|
||
# Test that the adapter's run method can be called multiple times | ||
result = await adapter.run_json({"a": 5, "b": 7}, CancellationToken()) | ||
assert result == 12 | ||
|
||
# Test CustomCalculatorTool | ||
custom_langchain_tool = CustomCalculatorTool() | ||
custom_adapter = LangChainToolAdapter(custom_langchain_tool) # pyright: ignore | ||
|
||
# Test schema generation for CustomCalculatorTool | ||
custom_schema = custom_adapter.schema | ||
|
||
assert custom_schema["name"] == "Calculator" | ||
assert custom_schema["description"] == "useful for when you need to answer questions about math" # type: ignore | ||
assert "parameters" in custom_schema | ||
assert custom_schema["parameters"]["type"] == "object" | ||
assert "properties" in custom_schema["parameters"] | ||
assert "a" in custom_schema["parameters"]["properties"] | ||
assert "b" in custom_schema["parameters"]["properties"] | ||
assert custom_schema["parameters"]["properties"]["a"]["type"] == "integer" | ||
assert custom_schema["parameters"]["properties"]["b"]["type"] == "integer" | ||
assert "required" in custom_schema["parameters"] | ||
assert set(custom_schema["parameters"]["required"]) == {"a", "b"} | ||
|
||
# Test run method for CustomCalculatorTool | ||
custom_result = await custom_adapter.run_json({"a": 3, "b": 4}, CancellationToken()) | ||
assert custom_result == 12 |
Oops, something went wrong.