diff --git a/docs/agents.md b/docs/agents.md
index 223c1a124..26e041ac9 100644
--- a/docs/agents.md
+++ b/docs/agents.md
@@ -220,6 +220,146 @@ Once the run finishes, `agent_run.final_result` becomes a [`AgentRunResult`][pyd
 
 ---
 
+### Streaming
+
+Here is an example of streaming an agent run node by node, in combination with `async for` iteration:
+
+```python {title="streaming.py"}
+import asyncio
+from dataclasses import dataclass
+from datetime import date
+
+from pydantic_ai import Agent
+from pydantic_ai.messages import (
+    FinalResultEvent,
+    FunctionToolCallEvent,
+    FunctionToolResultEvent,
+    PartDeltaEvent,
+    PartStartEvent,
+    TextPartDelta,
+    ToolCallPartDelta,
+)
+from pydantic_ai.tools import RunContext
+
+
+@dataclass
+class WeatherService:
+    async def get_forecast(self, location: str, forecast_date: date) -> str:
+        # In real code: call weather API, DB queries, etc.
+        return f'The forecast in {location} on {forecast_date} is 24°C and sunny.'
+
+    async def get_historic_weather(self, location: str, forecast_date: date) -> str:
+        # In real code: call a historical weather API or DB
+        return (
+            f'The weather in {location} on {forecast_date} was 18°C and partly cloudy.'
+        )
+
+
+weather_agent = Agent[WeatherService, str](
+    'openai:gpt-4o',
+    deps_type=WeatherService,
+    result_type=str,  # We'll produce a final answer as plain text
+    system_prompt='Provide a weather forecast at the locations the user provides.',
+)
+
+
+@weather_agent.tool
+async def weather_forecast(
+    ctx: RunContext[WeatherService],
+    location: str,
+    forecast_date: date,
+) -> str:
+    if forecast_date >= date.today():
+        return await ctx.deps.get_forecast(location, forecast_date)
+    else:
+        return await ctx.deps.get_historic_weather(location, forecast_date)
+
+
+output_messages: list[str] = []
+
+
+async def main():
+    user_prompt = 'What will the weather be like in Paris on Tuesday?'
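+    # The run yields graph nodes in order: UserPromptNode -> ModelRequestNode
+    # -> HandleResponseNode (repeating as tools are called) -> End. The
+    # `Agent.is_*_node` helpers narrow each node's type with generics intact.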
+
+    # Begin a node-by-node, streaming iteration
+    with weather_agent.iter(user_prompt, deps=WeatherService()) as run:
+        async for node in run:
+            if Agent.is_user_prompt_node(node):
+                # A user prompt node => The user has provided input
+                output_messages.append(f'=== UserPromptNode: {node.user_prompt} ===')
+            elif Agent.is_model_request_node(node):
+                # A model request node => We can stream tokens from the model's request
+                output_messages.append(
+                    '=== ModelRequestNode: streaming partial request tokens ==='
+                )
+                async with node.stream(run.ctx) as request_stream:
+                    async for event in request_stream:
+                        if isinstance(event, PartStartEvent):
+                            output_messages.append(
+                                f'[Request] Starting part {event.index}: {event.part!r}'
+                            )
+                        elif isinstance(event, PartDeltaEvent):
+                            if isinstance(event.delta, TextPartDelta):
+                                output_messages.append(
+                                    f'[Request] Part {event.index} text delta: {event.delta.content_delta!r}'
+                                )
+                            elif isinstance(event.delta, ToolCallPartDelta):
+                                output_messages.append(
+                                    f'[Request] Part {event.index} args_delta={event.delta.args_delta}'
+                                )
+                        elif isinstance(event, FinalResultEvent):
+                            output_messages.append(
+                                f'[Result] The model produced a final result (tool_name={event.tool_name})'
+                            )
+            elif Agent.is_handle_response_node(node):
+                # A handle-response node => The model returned data, potentially calling a tool
+                output_messages.append(
+                    '=== HandleResponseNode: streaming partial response & tool usage ==='
+                )
+                async with node.stream(run.ctx) as handle_stream:
+                    async for event in handle_stream:
+                        if isinstance(event, FunctionToolCallEvent):
+                            output_messages.append(
+                                f'[Tools] The LLM calls tool={event.part.tool_name!r} with args={event.part.args} (tool_call_id={event.part.tool_call_id!r})'
+                            )
+                        elif isinstance(event, FunctionToolResultEvent):
+                            output_messages.append(
+                                f'[Tools] Tool call {event.tool_call_id!r} returned => {event.result.content}'
+                            )
+            elif Agent.is_end_node(node):
+                assert run.result.data == node.data.data
+                # Once an End node is reached, the agent run is complete
+                output_messages.append(f'=== Final Agent Output: {run.result.data} ===')
+
+
+if __name__ == '__main__':
+    asyncio.run(main())
+
+    print(output_messages)
+    """
+    [
+        '=== UserPromptNode: What will the weather be like in Paris on Tuesday? ===',
+        '=== ModelRequestNode: streaming partial request tokens ===',
+        '[Request] Starting part 0: ToolCallPart(tool_name=\'weather_forecast\', args=\'{"location":"Pa\', tool_call_id=\'0001\', part_kind=\'tool-call\')',
+        '[Request] Part 0 args_delta=ris","forecast_',
+        '[Request] Part 0 args_delta=date":"2030-01-',
+        '[Request] Part 0 args_delta=01"}',
+        '=== HandleResponseNode: streaming partial response & tool usage ===',
+        '[Tools] The LLM calls tool=\'weather_forecast\' with args={"location":"Paris","forecast_date":"2030-01-01"} (tool_call_id=\'0001\')',
+        "[Tools] Tool call '0001' returned => The forecast in Paris on 2030-01-01 is 24°C and sunny.",
+        '=== ModelRequestNode: streaming partial request tokens ===',
+        "[Request] Starting part 0: TextPart(content='It will be ', part_kind='text')",
+        '[Result] The model produced a final result (tool_name=None)',
+        "[Request] Part 0 text delta: 'warm and sunny '",
+        "[Request] Part 0 text delta: 'in Paris on '",
+        "[Request] Part 0 text delta: 'Tuesday.'",
+        '=== HandleResponseNode: streaming partial response & tool usage ===',
+        '=== Final Agent Output: It will be warm and sunny in Paris on Tuesday. ===',
+    ]
+    """
+```
+
+---
+
 ### Additional Configuration
 
 #### Usage Limits
diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
index ec9f16568..d4a7ac14f 100644
--- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py
+++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -2,7 +2,6 @@
 
 import asyncio
 import dataclasses
-from abc import ABC
 from collections.abc import AsyncIterator, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
@@ -10,7 +9,7 @@
 from typing import Any, Generic, Literal, Union, cast
 
 import logfire_api
-from typing_extensions import TypeVar, assert_never
+from typing_extensions import TypeGuard, TypeVar, assert_never
 
 from pydantic_graph import BaseNode, Graph, GraphRunContext
 from pydantic_graph.nodes import End, NodeRunEndT
@@ -55,6 +54,7 @@
 logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
 
 T = TypeVar('T')
+S = TypeVar('S')
 NoneType = type(None)
 EndStrategy = Literal['early', 'exhaustive']
 """The strategy for handling multiple tool calls when a final result is found.
@@ -107,8 +107,31 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
     run_span: logfire_api.LogfireSpan
 
 
+class AgentNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
+    """The base class for all agent nodes.
+
+    Using a common subclass of `BaseNode` for all nodes reduces the boilerplate of repeating the generic
+    parameters everywhere.
+    """
+
+
+def is_agent_node(
+    node: BaseNode[GraphAgentState, GraphAgentDeps[T, Any], result.FinalResult[S]] | End[result.FinalResult[S]],
+) -> TypeGuard[AgentNode[T, S]]:
+    """Check if the provided node is an instance of `AgentNode`.
+
+    Usage:
+
+        if is_agent_node(node):
+            # `node` is an AgentNode
+            ...
+
+    This function preserves the generic parameters on the narrowed type, unlike `isinstance(node, AgentNode)`.
+    """
+    return isinstance(node, AgentNode)
+
+
 @dataclasses.dataclass
-class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]], ABC):
+class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
     user_prompt: str | Sequence[_messages.UserContent]
     system_prompts: tuple[str, ...]
@@ -215,7 +238,7 @@ async def add_tool(tool: Tool[DepsT]) -> None:
 
 
 @dataclasses.dataclass
-class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
+class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
     """Make a request to the model using the last message in state.message_history."""
 
     request: _messages.ModelRequest
@@ -236,12 +259,30 @@ async def run(
 
         return await self._make_request(ctx)
 
+    @asynccontextmanager
+    async def stream(
+        self,
+        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
+    ) -> AsyncIterator[result.AgentStream[DepsT, T]]:
+        """Stream the model's response for this request, yielding an [`AgentStream`][pydantic_ai.result.AgentStream]."""
+        async with self._stream(ctx) as streamed_response:
+            agent_stream = result.AgentStream[DepsT, T](
+                streamed_response,
+                ctx.deps.result_schema,
+                ctx.deps.result_validators,
+                build_run_context(ctx),
+                ctx.deps.usage_limits,
+            )
+            yield agent_stream
+            # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
+            # otherwise usage won't be properly counted:
+            async for _ in agent_stream:
+                pass
+
     @asynccontextmanager
     async def _stream(
         self,
         ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
     ) -> AsyncIterator[models.StreamedResponse]:
-        # TODO: Consider changing this to return something more similar to a `StreamedRunResult`, then make it public
         assert not self._did_stream, 'stream() should only be called once per node'
 
         model_settings, model_request_parameters = await self._prepare_request(ctx)
@@ -319,7 +360,7 @@ def _finish_handling(
 
 
 @dataclasses.dataclass
-class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
+class HandleResponseNode(AgentNode[DepsT, NodeRunEndT]):
     """Process a model response, and decide whether to end the run or make a new request."""
 
     model_response: _messages.ModelResponse
@@ -575,7 +616,7 @@ async def process_function_tools(
         for task in done:
             index = tasks.index(task)
             result = task.result()
-            yield _messages.FunctionToolResultEvent(result, call_id=call_index_to_event_id[index])
+            yield _messages.FunctionToolResultEvent(result, tool_call_id=call_index_to_event_id[index])
             if isinstance(result, (_messages.ToolReturnPart, _messages.RetryPromptPart)):
                 results_by_index[index] = result
             else:
diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py
index 84e4c00df..041aeb99d 100644
--- a/pydantic_ai_slim/pydantic_ai/agent.py
+++ b/pydantic_ai_slim/pydantic_ai/agent.py
@@ -9,9 +9,9 @@
 from typing import Any, Callable, Generic, cast, final, overload
 
 import logfire_api
-from typing_extensions import TypeVar, deprecated
+from typing_extensions import TypeGuard, TypeVar, deprecated
 
-from pydantic_graph import BaseNode, End, Graph, GraphRun, GraphRunContext
+from pydantic_graph import End, Graph, GraphRun, GraphRunContext
 from pydantic_graph._utils import get_event_loop
 
 from . import (
@@ -46,7 +46,6 @@
 ModelRequestNode = _agent_graph.ModelRequestNode
 UserPromptNode = _agent_graph.UserPromptNode
 
-
 __all__ = (
     'Agent',
     'AgentRun',
@@ -71,6 +70,7 @@
 logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
 
 T = TypeVar('T')
+S = TypeVar('S')
 NoneType = type(None)
 RunResultDataT = TypeVar('RunResultDataT')
 """Type variable for the result data of a run where `result_type` was customized on the run call."""
@@ -646,10 +646,9 @@ async def main():
         ) as agent_run:
             first_node = agent_run.next_node  # start with the first node
             assert isinstance(first_node, _agent_graph.UserPromptNode)  # the first node should be a user prompt node
-            node: BaseNode[Any, Any, Any] = cast(BaseNode[Any, Any, Any], first_node)
+            node = first_node
             while True:
-                if isinstance(node, _agent_graph.ModelRequestNode):
-                    node = cast(_agent_graph.ModelRequestNode[AgentDepsT, Any], node)
+                if self.is_model_request_node(node):
                     graph_ctx = agent_run.ctx
                     async with node._stream(graph_ctx) as streamed_response:  # pyright: ignore[reportPrivateUsage]
@@ -717,9 +716,9 @@ async def on_complete() -> None:
                     )
                     break
                 next_node = await agent_run.next(node)
-                if not isinstance(next_node, BaseNode):
+                if not isinstance(next_node, _agent_graph.AgentNode):
                     raise exceptions.AgentRunError('Should have produced a StreamedRunResult before getting here')
-                node = cast(BaseNode[Any, Any, Any], next_node)
+                node = cast(_agent_graph.AgentNode[Any, Any], next_node)
 
             if not yielded:
                 raise exceptions.AgentRunError('Agent run finished without producing a final result')
@@ -1173,6 +1172,46 @@ def _prepare_result_schema(
         else:
             return self._result_schema  # pyright: ignore[reportReturnType]
 
+    @staticmethod
+    def is_model_request_node(
+        node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]],
+    ) -> TypeGuard[_agent_graph.ModelRequestNode[T, S]]:
+        """Check if the node is a `ModelRequestNode`, narrowing the type if it is.
+
+        This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`.
+        """
+        return isinstance(node, _agent_graph.ModelRequestNode)
+
+    @staticmethod
+    def is_handle_response_node(
+        node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]],
+    ) -> TypeGuard[_agent_graph.HandleResponseNode[T, S]]:
+        """Check if the node is a `HandleResponseNode`, narrowing the type if it is.
+
+        This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`.
+        """
+        return isinstance(node, _agent_graph.HandleResponseNode)
+
+    @staticmethod
+    def is_user_prompt_node(
+        node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]],
+    ) -> TypeGuard[_agent_graph.UserPromptNode[T, S]]:
+        """Check if the node is a `UserPromptNode`, narrowing the type if it is.
+
+        This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`.
+        """
+        return isinstance(node, _agent_graph.UserPromptNode)
+
+    @staticmethod
+    def is_end_node(
+        node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]],
+    ) -> TypeGuard[End[result.FinalResult[S]]]:
+        """Check if the node is an `End` node, narrowing the type if it is.
+
+        This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`.
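+
+        Usage:
+
+            if Agent.is_end_node(node):
+                # `node` is an `End`, and the run is complete
+                ...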
+ """ + return isinstance(node, End) + @dataclasses.dataclass(repr=False) class AgentRun(Generic[AgentDepsT, ResultDataT]): @@ -1244,15 +1283,17 @@ def ctx(self) -> GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.Grap @property def next_node( self, - ) -> ( - BaseNode[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[ResultDataT]] - | End[FinalResult[ResultDataT]] - ): + ) -> _agent_graph.AgentNode[AgentDepsT, ResultDataT] | End[FinalResult[ResultDataT]]: """The next node that will be run in the agent graph. This is the next node that will be used during async iteration, or if a node is not passed to `self.next(...)`. """ - return self._graph_run.next_node + next_node = self._graph_run.next_node + if isinstance(next_node, End): + return next_node + if _agent_graph.is_agent_node(next_node): + return next_node + raise exceptions.AgentRunError(f'Unexpected node type: {type(next_node)}') # pragma: no cover @property def result(self) -> AgentRunResult[ResultDataT] | None: @@ -1273,45 +1314,24 @@ def result(self) -> AgentRunResult[ResultDataT] | None: def __aiter__( self, - ) -> AsyncIterator[ - BaseNode[ - _agent_graph.GraphAgentState, - _agent_graph.GraphAgentDeps[AgentDepsT, Any], - FinalResult[ResultDataT], - ] - | End[FinalResult[ResultDataT]] - ]: + ) -> AsyncIterator[_agent_graph.AgentNode[AgentDepsT, ResultDataT] | End[FinalResult[ResultDataT]]]: """Provide async-iteration over the nodes in the agent run.""" return self async def __anext__( self, - ) -> ( - BaseNode[ - _agent_graph.GraphAgentState, - _agent_graph.GraphAgentDeps[AgentDepsT, Any], - FinalResult[ResultDataT], - ] - | End[FinalResult[ResultDataT]] - ): + ) -> _agent_graph.AgentNode[AgentDepsT, ResultDataT] | End[FinalResult[ResultDataT]]: """Advance to the next node automatically based on the last returned node.""" - return await self._graph_run.__anext__() + next_node = await self._graph_run.__anext__() + if _agent_graph.is_agent_node(next_node): + return next_node + assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}' + return next_node async def next( self, - node: BaseNode[ - _agent_graph.GraphAgentState, - _agent_graph.GraphAgentDeps[AgentDepsT, Any], - FinalResult[ResultDataT], - ], - ) -> ( - BaseNode[ - _agent_graph.GraphAgentState, - _agent_graph.GraphAgentDeps[AgentDepsT, Any], - FinalResult[ResultDataT], - ] - | End[FinalResult[ResultDataT]] - ): + node: _agent_graph.AgentNode[AgentDepsT, ResultDataT], + ) -> _agent_graph.AgentNode[AgentDepsT, ResultDataT] | End[FinalResult[ResultDataT]]: """Manually drive the agent run by passing in the node you want to run next. This lets you inspect or mutate the node before continuing execution, or skip certain nodes @@ -1378,7 +1398,11 @@ async def main(): """ # Note: It might be nice to expose a synchronous interface for iteration, but we shouldn't do it # on this class, or else IDEs won't warn you if you accidentally use `for` instead of `async for` to iterate. 
- return await self._graph_run.next(node) + next_node = await self._graph_run.next(node) + if _agent_graph.is_agent_node(next_node): + return next_node + assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}' + return next_node def usage(self) -> _usage.Usage: """Get usage statistics for the run so far, including token usage, model requests, and so on.""" diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 661ef089d..bda8c6201 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -533,9 +533,24 @@ class PartDeltaEvent: """Event type identifier, used as a discriminator.""" +@dataclass +class FinalResultEvent: + """An event indicating the response to the current model request matches the result schema.""" + + tool_name: str | None + """The name of the result tool that was called. `None` if the result is from text content and not from a tool.""" + event_kind: Literal['final_result'] = 'final_result' + """Event type identifier, used as a discriminator.""" + + ModelResponseStreamEvent = Annotated[Union[PartStartEvent, PartDeltaEvent], pydantic.Discriminator('event_kind')] """An event in the model response stream, either starting a new part or applying a delta to an existing one.""" +AgentStreamEvent = Annotated[ + Union[PartStartEvent, PartDeltaEvent, FinalResultEvent], pydantic.Discriminator('event_kind') +] +"""An event in the agent stream.""" + @dataclass class FunctionToolCallEvent: @@ -558,7 +573,7 @@ class FunctionToolResultEvent: result: ToolReturnPart | RetryPromptPart """The result of the call to the function tool.""" - call_id: str + tool_call_id: str """An ID used to match the result to its original call.""" event_kind: Literal['function_tool_result'] = 'function_tool_result' """Event type identifier, used as a discriminator.""" diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index 084f18df8..1d6d25085 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -177,6 +177,8 @@ class DeltaToolCall: """Incremental change to the name of the tool.""" json_args: str | None = None """Incremental change to the arguments as JSON""" + tool_call_id: str | None = None + """Incremental change to the tool call ID.""" DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall] @@ -224,7 +226,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: vendor_part_id=dtc_index, tool_name=delta_tool_call.name, args=delta_tool_call.json_args, - tool_call_id=None, + tool_call_id=delta_tool_call.tool_call_id, ) if maybe_event is not None: yield maybe_event diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 7646de5bf..140e5a025 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -7,9 +7,10 @@ from typing import Generic, Union, cast import logfire_api -from typing_extensions import TypeVar +from typing_extensions import TypeVar, assert_type from . 
import _result, _utils, exceptions, messages as _messages, models +from .messages import AgentStreamEvent, FinalResultEvent from .tools import AgentDepsT, RunContext from .usage import Usage, UsageLimits @@ -51,6 +52,125 @@ _logfire = logfire_api.Logfire(otel_scope='pydantic-ai') +@dataclass +class AgentStream(Generic[AgentDepsT, ResultDataT]): + _raw_stream_response: models.StreamedResponse + _result_schema: _result.ResultSchema[ResultDataT] | None + _result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]] + _run_ctx: RunContext[AgentDepsT] + _usage_limits: UsageLimits | None + + _agent_stream_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False) + _final_result_event: FinalResultEvent | None = field(default=None, init=False) + _initial_run_ctx_usage: Usage = field(init=False) + + def __post_init__(self): + self._initial_run_ctx_usage = copy(self._run_ctx.usage) + + async def stream_output(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[ResultDataT]: + """Asynchronously stream the (validated) agent outputs.""" + async for response in self.stream_responses(debounce_by=debounce_by): + if self._final_result_event is not None: + yield await self._validate_response(response, self._final_result_event.tool_name, allow_partial=True) + if self._final_result_event is not None: + yield await self._validate_response( + self._raw_stream_response.get(), self._final_result_event.tool_name, allow_partial=False + ) + + async def stream_responses(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[_messages.ModelResponse]: + """Asynchronously stream the (unvalidated) model responses for the agent.""" + # if the message currently has any parts with content, yield before streaming + msg = self._raw_stream_response.get() + for part in msg.parts: + if part.has_content(): + yield msg + break + + async with _utils.group_by_temporal(self, debounce_by) as group_iter: + async for _items in group_iter: + yield self._raw_stream_response.get() # current state of the response + + def usage(self) -> Usage: + """Return the usage of the whole run. + + !!! note + This won't return the full usage until the stream is finished. + """ + return self._initial_run_ctx_usage + self._raw_stream_response.usage() + + async def _validate_response( + self, message: _messages.ModelResponse, result_tool_name: str | None, *, allow_partial: bool = False + ) -> ResultDataT: + """Validate a structured result message.""" + if self._result_schema is not None and result_tool_name is not None: + match = self._result_schema.find_named_tool(message.parts, result_tool_name) + if match is None: + raise exceptions.UnexpectedModelBehavior( + f'Invalid response, unable to find tool: {self._result_schema.tool_names()}' + ) + + call, result_tool = match + result_data = result_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False) + + for validator in self._result_validators: + result_data = await validator.validate(result_data, call, self._run_ctx) + return result_data + else: + text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) + for validator in self._result_validators: + text = await validator.validate( + text, + None, + self._run_ctx, + ) + # Since there is no result tool, we can assume that str is compatible with ResultDataT + return cast(ResultDataT, text) + + def __aiter__(self) -> AsyncIterator[AgentStreamEvent]: + """Stream [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s. 
+ + This proxies the _raw_stream_response and sends all events to the agent stream, while also checking for matches + on the result schema and emitting a [`FinalResultEvent`][pydantic_ai.messages.FinalResultEvent] if/when the + first match is found. + """ + if self._agent_stream_iterator is not None: + return self._agent_stream_iterator + + async def aiter(): + result_schema = self._result_schema + allow_text_result = result_schema is None or result_schema.allow_text_result + + def _get_final_result_event(e: _messages.ModelResponseStreamEvent) -> _messages.FinalResultEvent | None: + """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result.""" + if isinstance(e, _messages.PartStartEvent): + new_part = e.part + if isinstance(new_part, _messages.ToolCallPart): + if result_schema is not None and (match := result_schema.find_tool([new_part])): + call, _ = match + return _messages.FinalResultEvent(tool_name=call.tool_name) + elif allow_text_result: + assert_type(e, _messages.PartStartEvent) + return _messages.FinalResultEvent(tool_name=None) + + usage_checking_stream = _get_usage_checking_stream_response( + self._raw_stream_response, self._usage_limits, self.usage + ) + async for event in usage_checking_stream: + yield event + if (final_result_event := _get_final_result_event(event)) is not None: + self._final_result_event = final_result_event + yield final_result_event + break + + # If we broke out of the above loop, we need to yield the rest of the events + # If we didn't, this will just be a no-op + async for event in usage_checking_stream: + yield event + + self._agent_stream_iterator = aiter() + return self._agent_stream_iterator + + @dataclass class StreamedRunResult(Generic[AgentDepsT, ResultDataT]): """Result of a streamed run that returns structured data via a tool call.""" diff --git a/tests/test_examples.py b/tests/test_examples.py index 6974d319a..3499a1918 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -182,6 +182,9 @@ def rich_prompt_ask(prompt: str, *_args: Any, **_kwargs: Any) -> str: 'What is the weather like in West London and in Wiltshire?': ( 'The weather in West London is raining, while in Wiltshire it is sunny.' ), + 'What will the weather be like in Paris on Tuesday?': ToolCallPart( + tool_name='weather_forecast', args={'location': 'Paris', 'forecast_date': '2030-01-01'}, tool_call_id='0001' + ), 'Tell me a joke.': 'Did you hear about the toothpaste scandal? 
They called it Colgate.', 'Explain?': 'This is an excellent joke invented by Samuel Colvin, it needs no explanation.', 'What is the capital of France?': 'Paris', @@ -270,6 +273,13 @@ def rich_prompt_ask(prompt: str, *_args: Any, **_kwargs: Any) -> str: ), } +tool_responses: dict[tuple[str, str], str] = { + ( + 'weather_forecast', + 'The forecast in Paris on 2030-01-01 is 24°C and sunny.', + ): 'It will be warm and sunny in Paris on Tuesday.', +} + async def model_logic(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: # pragma: no cover # noqa: C901 m = messages[-1].parts[-1] @@ -348,35 +358,53 @@ async def model_logic(messages: list[ModelMessage], info: AgentInfo) -> ModelRes raise RuntimeError(f'Unexpected message: {m}') -async def stream_model_logic( +async def stream_model_logic( # noqa C901 messages: list[ModelMessage], info: AgentInfo ) -> AsyncIterator[str | DeltaToolCalls]: # pragma: no cover - m = messages[-1].parts[-1] - if isinstance(m, UserPromptPart): - assert isinstance(m.content, str) - if response := text_responses.get(m.content): - if isinstance(response, str): - words = response.split(' ') - chunk: list[str] = [] - for work in words: - chunk.append(work) - if len(chunk) == 3: - yield ' '.join(chunk) + ' ' - chunk.clear() - if chunk: - yield ' '.join(chunk) - return - else: - json_text = response.args_as_json_str() - - yield {1: DeltaToolCall(name=response.tool_name)} - for chunk_index in range(0, len(json_text), 15): - text_chunk = json_text[chunk_index : chunk_index + 15] - yield {1: DeltaToolCall(json_args=text_chunk)} - return + async def stream_text_response(r: str) -> AsyncIterator[str]: + if isinstance(r, str): + words = r.split(' ') + chunk: list[str] = [] + for word in words: + chunk.append(word) + if len(chunk) == 3: + yield ' '.join(chunk) + ' ' + chunk.clear() + if chunk: + yield ' '.join(chunk) + + async def stream_tool_call_response(r: ToolCallPart) -> AsyncIterator[DeltaToolCalls]: + json_text = r.args_as_json_str() + + yield {1: DeltaToolCall(name=r.tool_name, tool_call_id=r.tool_call_id)} + for chunk_index in range(0, len(json_text), 15): + text_chunk = json_text[chunk_index : chunk_index + 15] + yield {1: DeltaToolCall(json_args=text_chunk)} + + async def stream_part_response(r: str | ToolCallPart) -> AsyncIterator[str | DeltaToolCalls]: + if isinstance(r, str): + async for chunk in stream_text_response(r): + yield chunk + else: + async for chunk in stream_tool_call_response(r): + yield chunk + + last_part = messages[-1].parts[-1] + if isinstance(last_part, UserPromptPart): + assert isinstance(last_part.content, str) + if response := text_responses.get(last_part.content): + async for chunk in stream_part_response(response): + yield chunk + return + elif isinstance(last_part, ToolReturnPart): + assert isinstance(last_part.content, str) + if response := tool_responses.get((last_part.tool_name, last_part.content)): + async for chunk in stream_part_response(response): + yield chunk + return sys.stdout.write(str(debug.format(messages, info))) - raise RuntimeError(f'Unexpected message: {m}') + raise RuntimeError(f'Unexpected message: {last_part}') def mock_infer_model(model: Model | KnownModelName) -> Model: diff --git a/tests/test_streaming.py b/tests/test_streaming.py index b4612e6c8..92bcc4a73 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -2,14 +2,18 @@ import datetime import json +import re from collections.abc import AsyncIterator +from copy import deepcopy from datetime import timezone +from typing import Union 
import pytest from inline_snapshot import snapshot from pydantic import BaseModel from pydantic_ai import Agent, UnexpectedModelBehavior, UserError, capture_run_messages +from pydantic_ai.agent import AgentRun from pydantic_ai.messages import ( ModelMessage, ModelRequest, @@ -22,7 +26,8 @@ ) from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel from pydantic_ai.models.test import TestModel -from pydantic_ai.result import Usage +from pydantic_ai.result import AgentStream, FinalResult, Usage +from pydantic_graph import End from .conftest import IsNow @@ -739,3 +744,109 @@ async def test_custom_result_type_default_structured() -> None: async with agent.run_stream('test', result_type=str) as result: response = await result.get_data() assert response == snapshot('success (no tool calls)') + + +async def test_iter_stream_output(): + m = TestModel(custom_result_text='The cat sat on the mat.') + + agent = Agent(m) + + @agent.result_validator + def result_validator_simple(data: str) -> str: + # Make a substitution in the validated results + return re.sub('cat sat', 'bat sat', data) + + run: AgentRun + stream: AgentStream + messages: list[str] = [] + + stream_usage: Usage | None = None + with agent.iter('Hello') as run: + async for node in run: + if agent.is_model_request_node(node): + async with node.stream(run.ctx) as stream: + async for chunk in stream.stream_output(debounce_by=None): + messages.append(chunk) + stream_usage = deepcopy(stream.usage()) + assert run.next_node == End(data=FinalResult(data='The bat sat on the mat.', tool_name=None)) + assert ( + run.usage() + == stream_usage + == Usage(requests=1, request_tokens=51, response_tokens=7, total_tokens=58, details=None) + ) + + assert messages == [ + '', + 'The ', + 'The cat ', + 'The bat sat ', + 'The bat sat on ', + 'The bat sat on the ', + 'The bat sat on the mat.', + 'The bat sat on the mat.', + ] + + +async def test_iter_stream_responses(): + m = TestModel(custom_result_text='The cat sat on the mat.') + + agent = Agent(m) + + @agent.result_validator + def result_validator_simple(data: str) -> str: + # Make a substitution in the validated results + return re.sub('cat sat', 'bat sat', data) + + run: AgentRun + stream: AgentStream + messages: list[ModelResponse] = [] + with agent.iter('Hello') as run: + async for node in run: + if agent.is_model_request_node(node): + async with node.stream(run.ctx) as stream: + async for chunk in stream.stream_responses(debounce_by=None): + messages.append(chunk) + + assert messages == [ + ModelResponse( + parts=[TextPart(content=text, part_kind='text')], + model_name='test', + timestamp=IsNow(tz=timezone.utc), + kind='response', + ) + for text in [ + '', + '', + 'The ', + 'The cat ', + 'The cat sat ', + 'The cat sat on ', + 'The cat sat on the ', + 'The cat sat on the mat.', + ] + ] + + # Note: as you can see above, the result validator is not applied to the streamed responses, just the final result: + assert run.result is not None + assert run.result.data == 'The bat sat on the mat.' 
+ + +async def test_stream_iter_structured_validator() -> None: + class NotResultType(BaseModel): + not_value: str + + agent = Agent[None, Union[ResultType, NotResultType]]('test', result_type=Union[ResultType, NotResultType]) # pyright: ignore[reportArgumentType] + + @agent.result_validator + def result_validator(data: ResultType | NotResultType) -> ResultType | NotResultType: + assert isinstance(data, ResultType) + return ResultType(value=data.value + ' (validated)') + + outputs: list[ResultType] = [] + with agent.iter('test') as run: + async for node in run: + if agent.is_model_request_node(node): + async with node.stream(run.ctx) as stream: + async for output in stream.stream_output(debounce_by=None): + outputs.append(output) + assert outputs == [ResultType(value='a (validated)'), ResultType(value='a (validated)')]