Skip to content

Commit 0a3a97d

Browse files
committed
extracted FunctionSchema to public module tools
1 parent 17df4fc commit 0a3a97d

File tree

3 files changed

+28
-27
lines changed

3 files changed

+28
-27
lines changed

pydantic_ai_slim/pydantic_ai/_pydantic.py

+10-22
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from __future__ import annotations as _annotations
77

88
from inspect import Parameter, signature
9-
from typing import TYPE_CHECKING, Any, Callable, TypedDict, cast, get_origin
9+
from typing import TYPE_CHECKING, Any, Callable, cast, get_origin
1010

1111
from pydantic import ConfigDict
1212
from pydantic._internal import _decorators, _generate_schema, _typing_extra
@@ -20,24 +20,12 @@
2020
from ._utils import check_object_json_schema, is_model_like
2121

2222
if TYPE_CHECKING:
23-
from .tools import DocstringFormat, ObjectJsonSchema
23+
from .tools import DocstringFormat, FunctionSchema
2424

2525

2626
__all__ = ('function_schema',)
2727

2828

29-
class FunctionSchema(TypedDict):
30-
"""Internal information about a function schema."""
31-
32-
description: str
33-
validator: SchemaValidator
34-
json_schema: ObjectJsonSchema
35-
# if not None, the function takes a single by that name (besides potentially `info`)
36-
single_arg_name: str | None
37-
positional_fields: list[str]
38-
var_positional_field: str | None
39-
40-
4129
def function_schema( # noqa: C901
4230
function: Callable[..., Any],
4331
takes_ctx: bool,
@@ -161,14 +149,14 @@ def function_schema( # noqa: C901
161149
# and set it on the tool
162150
description = json_schema.pop('description', None)
163151

164-
return FunctionSchema(
165-
description=description,
166-
validator=schema_validator,
167-
json_schema=check_object_json_schema(json_schema),
168-
single_arg_name=single_arg_name,
169-
positional_fields=positional_fields,
170-
var_positional_field=var_positional_field,
171-
)
152+
return {
153+
'description': description,
154+
'validator': schema_validator,
155+
'json_schema': check_object_json_schema(json_schema),
156+
'single_arg_name': single_arg_name,
157+
'positional_fields': positional_fields,
158+
'var_positional_field': var_positional_field,
159+
}
172160

173161

174162
def takes_ctx(function: Callable[..., Any]) -> bool:

pydantic_ai_slim/pydantic_ai/tools.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import inspect
55
from collections.abc import Awaitable
66
from dataclasses import dataclass, field
7-
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast
7+
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypedDict, Union, cast
88

99
from pydantic import ValidationError
1010
from pydantic_core import SchemaValidator
@@ -19,6 +19,7 @@
1919
__all__ = (
2020
'AgentDepsT',
2121
'DocstringFormat',
22+
'FunctionSchema',
2223
'RunContext',
2324
'SystemPromptFunc',
2425
'ToolFuncContext',
@@ -35,6 +36,18 @@
3536
"""Type variable for agent dependencies."""
3637

3738

39+
class FunctionSchema(TypedDict):
40+
"""Internal information about a function schema."""
41+
42+
description: str
43+
validator: SchemaValidator
44+
json_schema: ObjectJsonSchema
45+
# if not None, the function takes a single by that name (besides potentially `info`)
46+
single_arg_name: str | None
47+
positional_fields: list[str]
48+
var_positional_field: str | None
49+
50+
3851
@dataclasses.dataclass
3952
class RunContext(Generic[AgentDepsT]):
4053
"""Information about the current call."""

tests/test_tools.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -347,10 +347,10 @@ def plain_tool(x: int) -> int:
347347

348348
def test_init_tool_with_function_schema():
349349
def x_tool(x: int) -> None:
350-
pass
350+
raise NotImplementedError
351351

352352
def y_tool(y: str) -> None:
353-
pass
353+
raise NotImplementedError
354354

355355
y_fs = _pydantic.function_schema(
356356
y_tool, takes_ctx=False, docstring_format='auto', require_parameter_descriptions=False
@@ -370,10 +370,10 @@ def y_tool(y: str) -> None:
370370

371371
def test_init_tool_ctx_with_function_schema():
372372
def x_tool(ctx: RunContext[int], x: int) -> None:
373-
pass
373+
raise NotImplementedError
374374

375375
def y_tool(ctx: RunContext[int], y: str) -> None:
376-
pass
376+
raise NotImplementedError
377377

378378
y_fs = _pydantic.function_schema(
379379
y_tool, takes_ctx=True, docstring_format='auto', require_parameter_descriptions=False

0 commit comments

Comments
 (0)