Commit 1443b49

Add an API for streaming tool calls performed by HandleResponseNode
1 parent 6c00d1d commit 1443b49

2 files changed: +157 -62 lines
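
The point of the change: callers can now observe tool calls and their results as events while a HandleResponseNode runs, instead of only seeing the node's final output. A minimal sketch of how the new API might be consumed (obtaining `node` and `ctx` from a graph run is outside this diff, so treat the surrounding names here as assumptions, not a documented public API):

    # Hypothetical consumer of the new HandleResponseNode.run_stream API (sketch only).
    from pydantic_ai import messages as _messages

    async def watch_tool_calls(node, ctx) -> None:
        async with node.run_stream(ctx) as stream:
            async for event in stream:
                if isinstance(event, _messages.FunctionToolCallEvent):
                    print(f'tool call started: {event.part.tool_name}')
                elif isinstance(event, _messages.FunctionToolResultEvent):
                    print(f'tool call finished: {event.result.part_kind}')
        # On exit, run_stream drains any unconsumed events, so the node still
        # reaches its next-node decision even if the caller stops iterating early.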

pydantic_ai_slim/pydantic_ai/_agent_graph.py (+129 -62)
@@ -2,6 +2,7 @@

 import asyncio
 import dataclasses
+import uuid
 from abc import ABC
 from collections.abc import AsyncIterator, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
@@ -267,43 +268,93 @@ async def run(

 @dataclasses.dataclass
 class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
-    """Process e response from a model, decide whether to end the run or make a new request."""
+    """Process the response from a model, decide whether to end the run or make a new request."""

     model_response: _messages.ModelResponse

+    _stream: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, repr=False)
+    _next_node: ModelRequestNode[DepsT, NodeRunEndT] | FinalResultNode[DepsT, NodeRunEndT] | None = field(
+        default=None, repr=False
+    )
+    _tool_responses: list[_messages.ModelRequestPart] = field(default_factory=list, repr=False)
+
     async def run(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
     ) -> Union[ModelRequestNode[DepsT, NodeRunEndT], FinalResultNode[DepsT, NodeRunEndT]]: # noqa UP007
+        async with self.run_stream(ctx):
+            pass
+
+        # the stream should set `self._next_node` before it ends:
+        assert (next_node := self._next_node) is not None
+        return next_node
+
+    @asynccontextmanager
+    async def run_stream(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
+    ) -> AsyncIterator[AsyncIterator[_messages.HandleResponseEvent]]:
         with _logfire.span('handle model response', run_step=ctx.state.run_step) as handle_span:
-            texts: list[str] = []
-            tool_calls: list[_messages.ToolCallPart] = []
-            for part in self.model_response.parts:
-                if isinstance(part, _messages.TextPart):
-                    # ignore empty content for text parts, see #437
-                    if part.content:
-                        texts.append(part.content)
-                elif isinstance(part, _messages.ToolCallPart):
-                    tool_calls.append(part)
+            stream = self._run_stream(ctx)
+            yield stream
+
+            # Run the stream to completion if it was not finished:
+            async for _event in stream:
+                pass
+
+            # Set the next node based on the final state of the stream
+            next_node = self._next_node
+            if isinstance(next_node, FinalResultNode):
+                handle_span.set_attribute('result', next_node.data)
+                handle_span.message = 'handle model response -> final result'
+            elif tool_responses := self._tool_responses:
+                # TODO: We could drop `self._tool_responses` if we drop this set_attribute
+                # I'm thinking it might be better to just create a span for the handling of each tool
+                # than to set an attribute here.
+                handle_span.set_attribute('tool_responses', tool_responses)
+                tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
+                handle_span.message = f'handle model response -> {tool_responses_str}'
+
+    async def _run_stream(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
+    ) -> AsyncIterator[_messages.HandleResponseEvent]:
+        if self._stream is None:
+            # Ensure that the stream is only run once
+
+            async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
+                texts: list[str] = []
+                tool_calls: list[_messages.ToolCallPart] = []
+                for part in self.model_response.parts:
+                    if isinstance(part, _messages.TextPart):
+                        # ignore empty content for text parts, see #437
+                        if part.content:
+                            texts.append(part.content)
+                    elif isinstance(part, _messages.ToolCallPart):
+                        tool_calls.append(part)
+                    else:
+                        assert_never(part)
+
+                # At the moment, we prioritize at least executing tool calls if they are present.
+                # In the future, we'd consider making this configurable at the agent or run level.
+                # This accounts for cases like anthropic returns that might contain a text response
+                # and a tool call response, where the text response just indicates the tool call will happen.
+                if tool_calls:
+                    async for event in self._handle_tool_calls(ctx, tool_calls):
+                        yield event
+                elif texts:
+                    # No events are emitted during the handling of text responses, so we don't need to yield anything
+                    self._next_node = await self._handle_text_response(ctx, texts)
                 else:
-                    assert_never(part)
-
-            # At the moment, we prioritize at least executing tool calls if they are present.
-            # In the future, we'd consider making this configurable at the agent or run level.
-            # This accounts for cases like anthropic returns that might contain a text response
-            # and a tool call response, where the text response just indicates the tool call will happen.
-            if tool_calls:
-                return await self._handle_tool_calls_response(ctx, tool_calls, handle_span)
-            elif texts:
-                return await self._handle_text_response(ctx, texts, handle_span)
-            else:
-                raise exceptions.UnexpectedModelBehavior('Received empty model response')
+                    raise exceptions.UnexpectedModelBehavior('Received empty model response')
+
+            self._stream = _run_stream()
+
+        async for event in self._stream:
+            yield event

-    async def _handle_tool_calls_response(
+    async def _handle_tool_calls(
         self,
         ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
         tool_calls: list[_messages.ToolCallPart],
-        handle_span: logfire_api.LogfireSpan,
-    ):
+    ) -> AsyncIterator[_messages.HandleResponseEvent]:
         result_schema = ctx.deps.result_schema

         # first look for the result tool call
@@ -324,26 +375,24 @@ async def _handle_tool_calls_response(
                     final_result = MarkFinalResult(result_data, call.tool_name)

         # Then build the other request parts based on end strategy
-        tool_responses = await _process_function_tools(tool_calls, final_result and final_result.tool_name, ctx)
+        tool_responses: list[_messages.ModelRequestPart] = self._tool_responses
+        async for event in _process_function_tools(
+            tool_calls, final_result and final_result.tool_name, ctx, tool_responses
+        ):
+            yield event

         if final_result:
-            handle_span.set_attribute('result', final_result.data)
-            handle_span.message = 'handle model response -> final result'
-            return FinalResultNode[DepsT, NodeRunEndT](final_result, tool_responses)
+            self._next_node = FinalResultNode[DepsT, NodeRunEndT](final_result, tool_responses)
         else:
             if tool_responses:
-                handle_span.set_attribute('tool_responses', tool_responses)
-                tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
-                handle_span.message = f'handle model response -> {tool_responses_str}'
                 parts.extend(tool_responses)
-            return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=parts))
+            self._next_node = ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=parts))

     async def _handle_text_response(
         self,
         ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
         texts: list[str],
-        handle_span: logfire_api.LogfireSpan,
-    ):
+    ) -> ModelRequestNode[DepsT, NodeRunEndT] | FinalResultNode[DepsT, NodeRunEndT]:
         result_schema = ctx.deps.result_schema

         text = '\n\n'.join(texts)
@@ -355,8 +404,6 @@ async def _handle_text_response(
                 ctx.state.increment_retries(ctx.deps.max_result_retries)
                 return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
             else:
-                handle_span.set_attribute('result', result_data)
-                handle_span.message = 'handle model response -> final result'
                 return FinalResultNode[DepsT, NodeRunEndT](MarkFinalResult(result_data, None))
         else:
             ctx.state.increment_retries(ctx.deps.max_result_retries)
@@ -560,11 +607,15 @@ async def on_complete():
        last_message = messages[-1]
        assert isinstance(last_message, _messages.ModelResponse)
        tool_calls = [part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)]
-        parts = await _process_function_tools(
+
+        parts: list[_messages.ModelRequestPart] = []
+        async for _event in _process_function_tools(
            tool_calls,
            result_tool_name,
            ctx,
-        )
+            parts,
+        ):
+            pass
        # TODO: Should we do something here related to the retry count?
        # Maybe we should move the incrementing of the retry count to where we actually make a request?
        # if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
@@ -590,25 +641,27 @@ async def _process_function_tools(
     tool_calls: list[_messages.ToolCallPart],
     result_tool_name: str | None,
     ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
-) -> list[_messages.ModelRequestPart]:
+    output_parts: list[_messages.ModelRequestPart],
+) -> AsyncIterator[_messages.HandleResponseEvent]:
     """Process function (non-result) tool calls in parallel.

     Also add stub return parts for any other tools that need it.
-    """
-    parts: list[_messages.ModelRequestPart] = []
-    tasks: list[asyncio.Task[_messages.ToolReturnPart | _messages.RetryPromptPart]] = []

+    Because async iterators can't have return values, we use `output_parts` as an output argument.
+    """
     stub_function_tools = bool(result_tool_name) and ctx.deps.end_strategy == 'early'
     result_schema = ctx.deps.result_schema

     # we rely on the fact that if we found a result, it's the first result tool in the last
     found_used_result_tool = False
     run_context = _build_run_context(ctx)

+    calls_to_run: list[tuple[Tool[DepsT], _messages.ToolCallPart]] = []
+    call_index_to_event_id: dict[int, uuid.UUID] = {}
     for call in tool_calls:
         if call.tool_name == result_tool_name and not found_used_result_tool:
             found_used_result_tool = True
-            parts.append(
+            output_parts.append(
                 _messages.ToolReturnPart(
                     tool_name=call.tool_name,
                     content='Final result processed.',
@@ -617,41 +670,55 @@ async def _process_function_tools(
             )
         elif tool := ctx.deps.function_tools.get(call.tool_name):
             if stub_function_tools:
-                parts.append(
+                output_parts.append(
                     _messages.ToolReturnPart(
                         tool_name=call.tool_name,
                         content='Tool not executed - a final result was already processed.',
                         tool_call_id=call.tool_call_id,
                     )
                 )
             else:
-                tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
+                event = _messages.FunctionToolCallEvent(call)
+                yield event
+                call_index_to_event_id[len(calls_to_run)] = event.call_id
+                calls_to_run.append((tool, call))
         elif result_schema is not None and call.tool_name in result_schema.tools:
             # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
             # validation, we don't add another part here
             if result_tool_name is not None:
-                parts.append(
-                    _messages.ToolReturnPart(
-                        tool_name=call.tool_name,
-                        content='Result tool not used - a final result was already processed.',
-                        tool_call_id=call.tool_call_id,
-                    )
+                part = _messages.ToolReturnPart(
+                    tool_name=call.tool_name,
+                    content='Result tool not used - a final result was already processed.',
+                    tool_call_id=call.tool_call_id,
                 )
+                output_parts.append(part)
             else:
-                parts.append(_unknown_tool(call.tool_name, ctx))
+                output_parts.append(_unknown_tool(call.tool_name, ctx))
+
+    if not calls_to_run:
+        return

     # Run all tool tasks in parallel
-    if tasks:
-        with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
-            task_results: Sequence[_messages.ToolReturnPart | _messages.RetryPromptPart] = await asyncio.gather(*tasks)
-            for result in task_results:
-                if isinstance(result, _messages.ToolReturnPart):
-                    parts.append(result)
-                elif isinstance(result, _messages.RetryPromptPart):
-                    parts.append(result)
+    results_by_index: dict[int, _messages.ModelRequestPart] = {}
+    with _logfire.span('running {tools=}', tools=[call.tool_name for _, call in calls_to_run]):
+        # TODO: Should we wrap each individual tool call in a dedicated span?
+        tasks = [asyncio.create_task(tool.run(call, run_context), name=call.tool_name) for tool, call in calls_to_run]
+        pending = tasks
+        while pending:
+            done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
+            for task in done:
+                index = tasks.index(task)
+                result = task.result()
+                yield _messages.FunctionToolResultEvent(result, call_id=call_index_to_event_id[index])
+                if isinstance(result, (_messages.ToolReturnPart, _messages.RetryPromptPart)):
+                    results_by_index[index] = result
                 else:
                     assert_never(result)
-    return parts
+
+    # We append the results at the end, rather than as they are received, to retain a consistent ordering
+    # This is mostly just to simplify testing
+    for k in sorted(results_by_index):
+        output_parts.append(results_by_index[k])


 def _unknown_tool(
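
A note on the `_process_function_tools` rewrite above: an async generator cannot return a value, so the function now yields an event as each tool task finishes (via `asyncio.wait` with FIRST_COMPLETED) and writes the collected parts into the caller-supplied `output_parts` list. The same pattern in isolation, with placeholder string results rather than pydantic-ai tool parts (a sketch, not library code):

    import asyncio
    from collections.abc import AsyncIterator, Coroutine
    from typing import Any

    async def run_streamed(
        coros: list[Coroutine[Any, Any, str]], output: list[str]
    ) -> AsyncIterator[str]:
        """Yield each result as it completes; fill `output` in input order at the end."""
        tasks = [asyncio.create_task(c) for c in coros]
        results: dict[int, str] = {}
        pending = set(tasks)
        while pending:
            done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
            for task in done:
                index = tasks.index(task)
                results[index] = task.result()
                yield task.result()  # the "event": this call just finished
        # Appending at the end, in index order, keeps `output` deterministic
        # regardless of completion order (the same choice the commit makes).
        for k in sorted(results):
            output.append(results[k])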

pydantic_ai_slim/pydantic_ai/messages.py (+28 -0)
@@ -1,5 +1,6 @@
 from __future__ import annotations as _annotations

+import uuid
 from dataclasses import dataclass, field, replace
 from datetime import datetime
 from typing import Annotated, Any, Literal, Union, cast, overload
@@ -445,3 +446,30 @@ class PartDeltaEvent:

 ModelResponseStreamEvent = Annotated[Union[PartStartEvent, PartDeltaEvent], pydantic.Discriminator('event_kind')]
 """An event in the model response stream, either starting a new part or applying a delta to an existing one."""
+
+
+@dataclass
+class FunctionToolCallEvent:
+    """An event indicating the start of a call to a function tool."""
+
+    part: ToolCallPart
+    """The (function) tool call to make."""
+    call_id: uuid.UUID = field(default_factory=uuid.uuid4, repr=False)
+    """An ID used to match the call to its result."""
+    event_kind: Literal['function_tool_call'] = field(default='function_tool_call', repr=False)
+    """Event type identifier, used as a discriminator."""
+
+
+@dataclass
+class FunctionToolResultEvent:
+    """An event indicating the result of a function tool call."""
+
+    result: ToolReturnPart | RetryPromptPart
+    """The result of the call to the function tool."""
+    call_id: uuid.UUID = field(default_factory=uuid.uuid4, repr=False)
+    """An ID used to match the result to its original call."""
+    event_kind: Literal['function_tool_result'] = field(default='function_tool_result', repr=False)
+    """Event type identifier, used as a discriminator."""
+
+
+HandleResponseEvent = Annotated[Union[FunctionToolCallEvent, FunctionToolResultEvent], pydantic.Discriminator('event_kind')]
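
Because each event class carries a distinct `event_kind` literal, the `HandleResponseEvent` union can round-trip through pydantic's discriminated-union machinery. A minimal sketch, assuming the classes above import cleanly from pydantic_ai.messages and that a `ToolReturnPart` can be built from just a tool name and content:

    from pydantic import TypeAdapter
    from pydantic_ai.messages import (
        FunctionToolResultEvent,
        HandleResponseEvent,
        ToolReturnPart,
    )

    adapter: TypeAdapter[HandleResponseEvent] = TypeAdapter(HandleResponseEvent)
    event = FunctionToolResultEvent(ToolReturnPart(tool_name='my_tool', content='done'))
    # `event_kind` discriminates, so validation restores the concrete class.
    restored = adapter.validate_python(adapter.dump_python(event))
    assert isinstance(restored, FunctionToolResultEvent)
    assert restored.call_id == event.call_id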
