Skip to content

Commit

Permalink
Add langchain tool adapter in autogen-ext (#570)
Browse files Browse the repository at this point in the history
* add langhcain tool adapter

* remove langchain package

* fix type errors

* test type fixes

* fix imports

* install extras in CI

* improve typing and use to_thread

* pin min langchain version

* install all extras in ci test

* update to langchain 0.3.1

* install extras in CI

* ignore pyright errors

* add missing uv sync extra reqs

---------

Co-authored-by: Leonardo Pinheiro <[email protected]>
Co-authored-by: Eric Zhu <[email protected]>
Co-authored-by: Ryan Sweet <[email protected]>
Co-authored-by: Jack Gerrits <[email protected]>
  • Loading branch information
5 people committed Sep 30, 2024
1 parent 18d52f6 commit 6cfa29b
Show file tree
Hide file tree
Showing 7 changed files with 298 additions and 63 deletions.
18 changes: 10 additions & 8 deletions .github/workflows/checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- run: uv sync --locked
- run: uv sync --locked --all-extras
working-directory: ./python
- name: Run task
run: |
Expand All @@ -33,7 +33,7 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- run: uv sync --locked
- run: uv sync --locked --all-extras
working-directory: ./python
- name: Run task
run: |
Expand All @@ -59,7 +59,7 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- run: uv sync --locked
- run: uv sync --locked --all-extras
working-directory: ./python
- name: Run task
run: |
Expand All @@ -85,14 +85,13 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- run: uv sync --locked
- run: uv sync --locked --all-extras
working-directory: ./python
- name: Run task
run: |
source ${{ github.workspace }}/python/.venv/bin/activate
poe --directory ${{ matrix.package }} pyright
working-directory: ./python

test:
runs-on: ubuntu-latest
strategy:
Expand All @@ -110,7 +109,10 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- run: uv sync --locked
- name: Run uv sync
run: |
uv sync --locked --all-extras
working-directory: ./python
- name: Run task
run: |
Expand All @@ -129,7 +131,7 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- run: uv sync --locked
- run: uv sync --locked --all-extras
working-directory: ./python
- name: Run task
run: |
Expand All @@ -145,7 +147,7 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- run: uv sync --locked
- run: uv sync --locked --all-extras
working-directory: ./python
- name: Run task
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
with:
python-version: '3.11'
- run: |
uv sync --locked
uv sync --locked --all-extras
source .venv/bin/activate
poe --directory ./packages/autogen-core docs-build
mkdir -p docs-staging/autogen/dev/
Expand Down
8 changes: 6 additions & 2 deletions python/packages/autogen-ext/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@ classifiers = [
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]
dependencies = ["autogen-core",

dependencies = [
"autogen-core",
]


[project.optional-dependencies]
langchain = ["langchain >= 0.3.1"]

[tool.hatch.build.targets.wheel]
packages = ["src/autogen_ext"]

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from ._langchain_adapter import LangChainToolAdapter

__all__ = ["LangChainToolAdapter"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import asyncio
import inspect
from typing import Any, Callable, Dict, Tuple, Type, cast

from autogen_core.base import CancellationToken
from autogen_core.components.tools import BaseTool
from pydantic import BaseModel, Field, create_model
from pydantic.fields import FieldInfo

from langchain.tools import Tool as LangChainTool

FieldDefinition = Tuple[Type[Any], FieldInfo]
FieldsDict = Dict[str, FieldDefinition]


class LangChainToolAdapter(BaseTool[BaseModel, Any]):
langchain_tool: LangChainTool
_callable: Callable[..., Any]

def __init__(self, langchain_tool: LangChainTool):
self.langchain_tool = langchain_tool

# Extract name and description
name = langchain_tool.name
description = langchain_tool.description or ""

# Determine the callable method
if hasattr(langchain_tool, "func") and callable(langchain_tool.func):
assert langchain_tool.func is not None
self._callable = langchain_tool.func
elif hasattr(langchain_tool, "_run") and callable(langchain_tool._run): # pyright: ignore
self._callable = langchain_tool._run # type: ignore
else:
raise AttributeError(
f"The provided LangChain tool '{name}' does not have a callable 'func' or '_run' method."
)

# Determine args_type
if langchain_tool.args_schema: # pyright: ignore
args_type = langchain_tool.args_schema # pyright: ignore
else:
# Infer args_type from the callable's signature
sig = inspect.signature(cast(Callable[..., Any], self._callable))
fields = {
k: (v.annotation, Field(...))
for k, v in sig.parameters.items()
if k != "self" and v.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
}
args_type = create_model(f"{name}Args", **fields) # type: ignore
# Note: type ignore is used due to a LangChain typing limitation

# Ensure args_type is a subclass of BaseModel
if not issubclass(args_type, BaseModel):
raise ValueError(f"Failed to create a valid Pydantic v2 model for {name}")

# Assume return_type as Any if not specified
return_type: Type[Any] = object

super().__init__(args_type, return_type, name, description)

async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any:
# Prepare arguments
kwargs = args.model_dump()

# Determine if the callable is asynchronous
if inspect.iscoroutinefunction(self._callable):
result = await self._callable(**kwargs)
else:
# Run in a thread to avoid blocking the event loop
result = await asyncio.to_thread(self._call_sync, kwargs)

return result

def _call_sync(self, kwargs: Dict[str, Any]) -> Any:
return self._callable(**kwargs)
97 changes: 97 additions & 0 deletions python/packages/autogen-ext/tests/test_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from typing import Optional, Type

import pytest
from autogen_core.base import CancellationToken
from autogen_ext.tools.langchain import LangChainToolAdapter # type: ignore
from langchain.tools import BaseTool as LangChainTool # type: ignore
from langchain.tools import tool # pyright: ignore
from langchain_core.callbacks.manager import AsyncCallbackManagerForToolRun, CallbackManagerForToolRun
from pydantic import BaseModel, Field


@tool # type: ignore
def add(a: int, b: int) -> int:
"""Add two numbers"""
return a + b


class CalculatorInput(BaseModel):
a: int = Field(description="first number")
b: int = Field(description="second number")


class CustomCalculatorTool(LangChainTool):
name: str = "Calculator"
description: str = "useful for when you need to answer questions about math"
args_schema: Type[BaseModel] = CalculatorInput
return_direct: bool = True

def _run(self, a: int, b: int, run_manager: Optional[CallbackManagerForToolRun] = None) -> int:
"""Use the tool."""
return a * b

async def _arun(
self,
a: int,
b: int,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> int:
"""Use the tool asynchronously."""
return self._run(a, b, run_manager=run_manager.get_sync() if run_manager else None)


@pytest.mark.asyncio
async def test_langchain_tool_adapter() -> None:
# Create a LangChain tool
langchain_tool = add # type: ignore

# Create an adapter
adapter = LangChainToolAdapter(langchain_tool) # pyright: ignore

# Test schema generation
schema = adapter.schema

assert schema["name"] == "add"
assert "description" in schema
assert schema["description"] == "Add two numbers"
assert "parameters" in schema
assert schema["parameters"]["type"] == "object"
assert "properties" in schema["parameters"]
assert "a" in schema["parameters"]["properties"]
assert "b" in schema["parameters"]["properties"]
assert schema["parameters"]["properties"]["a"]["type"] == "integer"
assert schema["parameters"]["properties"]["b"]["type"] == "integer"
assert "required" in schema["parameters"]
assert set(schema["parameters"]["required"]) == {"a", "b"}
assert len(schema["parameters"]["properties"]) == 2

# Test run method
result = await adapter.run_json({"a": 2, "b": 3}, CancellationToken())
assert result == 5

# Test that the adapter's run method can be called multiple times
result = await adapter.run_json({"a": 5, "b": 7}, CancellationToken())
assert result == 12

# Test CustomCalculatorTool
custom_langchain_tool = CustomCalculatorTool()
custom_adapter = LangChainToolAdapter(custom_langchain_tool) # pyright: ignore

# Test schema generation for CustomCalculatorTool
custom_schema = custom_adapter.schema

assert custom_schema["name"] == "Calculator"
assert custom_schema["description"] == "useful for when you need to answer questions about math" # type: ignore
assert "parameters" in custom_schema
assert custom_schema["parameters"]["type"] == "object"
assert "properties" in custom_schema["parameters"]
assert "a" in custom_schema["parameters"]["properties"]
assert "b" in custom_schema["parameters"]["properties"]
assert custom_schema["parameters"]["properties"]["a"]["type"] == "integer"
assert custom_schema["parameters"]["properties"]["b"]["type"] == "integer"
assert "required" in custom_schema["parameters"]
assert set(custom_schema["parameters"]["required"]) == {"a", "b"}

# Test run method for CustomCalculatorTool
custom_result = await custom_adapter.run_json({"a": 3, "b": 4}, CancellationToken())
assert custom_result == 12
Loading

0 comments on commit 6cfa29b

Please sign in to comment.