2
2
3
3
import asyncio
4
4
import dataclasses
5
+ import uuid
5
6
from abc import ABC
6
7
from collections .abc import AsyncIterator , Iterator , Sequence
7
8
from contextlib import asynccontextmanager , contextmanager
@@ -267,43 +268,93 @@ async def run(
267
268
268
269
@dataclasses .dataclass
269
270
class HandleResponseNode (BaseNode [GraphAgentState , GraphAgentDeps [DepsT , Any ], NodeRunEndT ]):
270
- """Process e response from a model, decide whether to end the run or make a new request."""
271
+ """Process the response from a model, decide whether to end the run or make a new request."""
271
272
272
273
model_response : _messages .ModelResponse
273
274
275
+ _stream : AsyncIterator [_messages .HandleResponseEvent ] | None = field (default = None , repr = False )
276
+ _next_node : ModelRequestNode [DepsT , NodeRunEndT ] | FinalResultNode [DepsT , NodeRunEndT ] | None = field (
277
+ default = None , repr = False
278
+ )
279
+ _tool_responses : list [_messages .ModelRequestPart ] = field (default_factory = list , repr = False )
280
+
274
281
async def run (
275
282
self , ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , NodeRunEndT ]]
276
283
) -> Union [ModelRequestNode [DepsT , NodeRunEndT ], FinalResultNode [DepsT , NodeRunEndT ]]: # noqa UP007
284
+ async with self .run_stream (ctx ):
285
+ pass
286
+
287
+ # the stream should set `self._next_node` before it ends:
288
+ assert (next_node := self ._next_node ) is not None
289
+ return next_node
290
+
291
+ @asynccontextmanager
292
+ async def run_stream (
293
+ self , ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , Any ]]
294
+ ) -> AsyncIterator [AsyncIterator [_messages .HandleResponseEvent ]]:
277
295
with _logfire .span ('handle model response' , run_step = ctx .state .run_step ) as handle_span :
278
- texts : list [str ] = []
279
- tool_calls : list [_messages .ToolCallPart ] = []
280
- for part in self .model_response .parts :
281
- if isinstance (part , _messages .TextPart ):
282
- # ignore empty content for text parts, see #437
283
- if part .content :
284
- texts .append (part .content )
285
- elif isinstance (part , _messages .ToolCallPart ):
286
- tool_calls .append (part )
296
+ stream = self ._run_stream (ctx )
297
+ yield stream
298
+
299
+ # Run the stream to completion if it was not finished:
300
+ async for _event in stream :
301
+ pass
302
+
303
+ # Set the next node based on the final state of the stream
304
+ next_node = self ._next_node
305
+ if isinstance (next_node , FinalResultNode ):
306
+ handle_span .set_attribute ('result' , next_node .data )
307
+ handle_span .message = 'handle model response -> final result'
308
+ elif tool_responses := self ._tool_responses :
309
+ # TODO: We could drop `self._tool_responses` if we drop this set_attribute
310
+ # I'm thinking it might be better to just create a span for the handling of each tool
311
+ # than to set an attribute here.
312
+ handle_span .set_attribute ('tool_responses' , tool_responses )
313
+ tool_responses_str = ' ' .join (r .part_kind for r in tool_responses )
314
+ handle_span .message = f'handle model response -> { tool_responses_str } '
315
+
316
+ async def _run_stream (
317
+ self , ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , Any ]]
318
+ ) -> AsyncIterator [_messages .HandleResponseEvent ]:
319
+ if self ._stream is None :
320
+ # Ensure that the stream is only run once
321
+
322
+ async def _run_stream () -> AsyncIterator [_messages .HandleResponseEvent ]:
323
+ texts : list [str ] = []
324
+ tool_calls : list [_messages .ToolCallPart ] = []
325
+ for part in self .model_response .parts :
326
+ if isinstance (part , _messages .TextPart ):
327
+ # ignore empty content for text parts, see #437
328
+ if part .content :
329
+ texts .append (part .content )
330
+ elif isinstance (part , _messages .ToolCallPart ):
331
+ tool_calls .append (part )
332
+ else :
333
+ assert_never (part )
334
+
335
+ # At the moment, we prioritize at least executing tool calls if they are present.
336
+ # In the future, we'd consider making this configurable at the agent or run level.
337
+ # This accounts for cases like anthropic returns that might contain a text response
338
+ # and a tool call response, where the text response just indicates the tool call will happen.
339
+ if tool_calls :
340
+ async for event in self ._handle_tool_calls (ctx , tool_calls ):
341
+ yield event
342
+ elif texts :
343
+ # No events are emitted during the handling of text responses, so we don't need to yield anything
344
+ self ._next_node = await self ._handle_text_response (ctx , texts )
287
345
else :
288
- assert_never (part )
289
-
290
- # At the moment, we prioritize at least executing tool calls if they are present.
291
- # In the future, we'd consider making this configurable at the agent or run level.
292
- # This accounts for cases like anthropic returns that might contain a text response
293
- # and a tool call response, where the text response just indicates the tool call will happen.
294
- if tool_calls :
295
- return await self ._handle_tool_calls_response (ctx , tool_calls , handle_span )
296
- elif texts :
297
- return await self ._handle_text_response (ctx , texts , handle_span )
298
- else :
299
- raise exceptions .UnexpectedModelBehavior ('Received empty model response' )
346
+ raise exceptions .UnexpectedModelBehavior ('Received empty model response' )
347
+
348
+ self ._stream = _run_stream ()
349
+
350
+ async for event in self ._stream :
351
+ yield event
300
352
301
- async def _handle_tool_calls_response (
353
+ async def _handle_tool_calls (
302
354
self ,
303
355
ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , NodeRunEndT ]],
304
356
tool_calls : list [_messages .ToolCallPart ],
305
- handle_span : logfire_api .LogfireSpan ,
306
- ):
357
+ ) -> AsyncIterator [_messages .HandleResponseEvent ]:
307
358
result_schema = ctx .deps .result_schema
308
359
309
360
# first look for the result tool call
@@ -324,26 +375,24 @@ async def _handle_tool_calls_response(
324
375
final_result = MarkFinalResult (result_data , call .tool_name )
325
376
326
377
# Then build the other request parts based on end strategy
327
- tool_responses = await _process_function_tools (tool_calls , final_result and final_result .tool_name , ctx )
378
+ tool_responses : list [_messages .ModelRequestPart ] = self ._tool_responses
379
+ async for event in _process_function_tools (
380
+ tool_calls , final_result and final_result .tool_name , ctx , tool_responses
381
+ ):
382
+ yield event
328
383
329
384
if final_result :
330
- handle_span .set_attribute ('result' , final_result .data )
331
- handle_span .message = 'handle model response -> final result'
332
- return FinalResultNode [DepsT , NodeRunEndT ](final_result , tool_responses )
385
+ self ._next_node = FinalResultNode [DepsT , NodeRunEndT ](final_result , tool_responses )
333
386
else :
334
387
if tool_responses :
335
- handle_span .set_attribute ('tool_responses' , tool_responses )
336
- tool_responses_str = ' ' .join (r .part_kind for r in tool_responses )
337
- handle_span .message = f'handle model response -> { tool_responses_str } '
338
388
parts .extend (tool_responses )
339
- return ModelRequestNode [DepsT , NodeRunEndT ](_messages .ModelRequest (parts = parts ))
389
+ self . _next_node = ModelRequestNode [DepsT , NodeRunEndT ](_messages .ModelRequest (parts = parts ))
340
390
341
391
async def _handle_text_response (
342
392
self ,
343
393
ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , NodeRunEndT ]],
344
394
texts : list [str ],
345
- handle_span : logfire_api .LogfireSpan ,
346
- ):
395
+ ) -> ModelRequestNode [DepsT , NodeRunEndT ] | FinalResultNode [DepsT , NodeRunEndT ]:
347
396
result_schema = ctx .deps .result_schema
348
397
349
398
text = '\n \n ' .join (texts )
@@ -355,8 +404,6 @@ async def _handle_text_response(
355
404
ctx .state .increment_retries (ctx .deps .max_result_retries )
356
405
return ModelRequestNode [DepsT , NodeRunEndT ](_messages .ModelRequest (parts = [e .tool_retry ]))
357
406
else :
358
- handle_span .set_attribute ('result' , result_data )
359
- handle_span .message = 'handle model response -> final result'
360
407
return FinalResultNode [DepsT , NodeRunEndT ](MarkFinalResult (result_data , None ))
361
408
else :
362
409
ctx .state .increment_retries (ctx .deps .max_result_retries )
@@ -560,11 +607,15 @@ async def on_complete():
560
607
last_message = messages [- 1 ]
561
608
assert isinstance (last_message , _messages .ModelResponse )
562
609
tool_calls = [part for part in last_message .parts if isinstance (part , _messages .ToolCallPart )]
563
- parts = await _process_function_tools (
610
+
611
+ parts : list [_messages .ModelRequestPart ] = []
612
+ async for _event in _process_function_tools (
564
613
tool_calls ,
565
614
result_tool_name ,
566
615
ctx ,
567
- )
616
+ parts ,
617
+ ):
618
+ pass
568
619
# TODO: Should we do something here related to the retry count?
569
620
# Maybe we should move the incrementing of the retry count to where we actually make a request?
570
621
# if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
@@ -590,25 +641,27 @@ async def _process_function_tools(
590
641
tool_calls : list [_messages .ToolCallPart ],
591
642
result_tool_name : str | None ,
592
643
ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , NodeRunEndT ]],
593
- ) -> list [_messages .ModelRequestPart ]:
644
+ output_parts : list [_messages .ModelRequestPart ],
645
+ ) -> AsyncIterator [_messages .HandleResponseEvent ]:
594
646
"""Process function (non-result) tool calls in parallel.
595
647
596
648
Also add stub return parts for any other tools that need it.
597
- """
598
- parts : list [_messages .ModelRequestPart ] = []
599
- tasks : list [asyncio .Task [_messages .ToolReturnPart | _messages .RetryPromptPart ]] = []
600
649
650
+ Because async iterators can't have return values, we use `parts` as an output argument.
651
+ """
601
652
stub_function_tools = bool (result_tool_name ) and ctx .deps .end_strategy == 'early'
602
653
result_schema = ctx .deps .result_schema
603
654
604
655
# we rely on the fact that if we found a result, it's the first result tool in the last
605
656
found_used_result_tool = False
606
657
run_context = _build_run_context (ctx )
607
658
659
+ calls_to_run : list [tuple [Tool [DepsT ], _messages .ToolCallPart ]] = []
660
+ call_index_to_event_id : dict [int , uuid .UUID ] = {}
608
661
for call in tool_calls :
609
662
if call .tool_name == result_tool_name and not found_used_result_tool :
610
663
found_used_result_tool = True
611
- parts .append (
664
+ output_parts .append (
612
665
_messages .ToolReturnPart (
613
666
tool_name = call .tool_name ,
614
667
content = 'Final result processed.' ,
@@ -617,41 +670,55 @@ async def _process_function_tools(
617
670
)
618
671
elif tool := ctx .deps .function_tools .get (call .tool_name ):
619
672
if stub_function_tools :
620
- parts .append (
673
+ output_parts .append (
621
674
_messages .ToolReturnPart (
622
675
tool_name = call .tool_name ,
623
676
content = 'Tool not executed - a final result was already processed.' ,
624
677
tool_call_id = call .tool_call_id ,
625
678
)
626
679
)
627
680
else :
628
- tasks .append (asyncio .create_task (tool .run (call , run_context ), name = call .tool_name ))
681
+ event = _messages .FunctionToolCallEvent (call )
682
+ yield event
683
+ call_index_to_event_id [len (calls_to_run )] = event .call_id
684
+ calls_to_run .append ((tool , call ))
629
685
elif result_schema is not None and call .tool_name in result_schema .tools :
630
686
# if tool_name is in _result_schema, it means we found a result tool but an error occurred in
631
687
# validation, we don't add another part here
632
688
if result_tool_name is not None :
633
- parts .append (
634
- _messages .ToolReturnPart (
635
- tool_name = call .tool_name ,
636
- content = 'Result tool not used - a final result was already processed.' ,
637
- tool_call_id = call .tool_call_id ,
638
- )
689
+ part = _messages .ToolReturnPart (
690
+ tool_name = call .tool_name ,
691
+ content = 'Result tool not used - a final result was already processed.' ,
692
+ tool_call_id = call .tool_call_id ,
639
693
)
694
+ output_parts .append (part )
640
695
else :
641
- parts .append (_unknown_tool (call .tool_name , ctx ))
696
+ output_parts .append (_unknown_tool (call .tool_name , ctx ))
697
+
698
+ if not calls_to_run :
699
+ return
642
700
643
701
# Run all tool tasks in parallel
644
- if tasks :
645
- with _logfire .span ('running {tools=}' , tools = [t .get_name () for t in tasks ]):
646
- task_results : Sequence [_messages .ToolReturnPart | _messages .RetryPromptPart ] = await asyncio .gather (* tasks )
647
- for result in task_results :
648
- if isinstance (result , _messages .ToolReturnPart ):
649
- parts .append (result )
650
- elif isinstance (result , _messages .RetryPromptPart ):
651
- parts .append (result )
702
+ results_by_index : dict [int , _messages .ModelRequestPart ] = {}
703
+ with _logfire .span ('running {tools=}' , tools = [call .tool_name for _ , call in calls_to_run ]):
704
+ # TODO: Should we wrap each individual tool call in a dedicated span?
705
+ tasks = [asyncio .create_task (tool .run (call , run_context ), name = call .tool_name ) for tool , call in calls_to_run ]
706
+ pending = tasks
707
+ while pending :
708
+ done , pending = await asyncio .wait (pending , return_when = asyncio .FIRST_COMPLETED )
709
+ for task in done :
710
+ index = tasks .index (task )
711
+ result = task .result ()
712
+ yield _messages .FunctionToolResultEvent (result , call_id = call_index_to_event_id [index ])
713
+ if isinstance (result , (_messages .ToolReturnPart , _messages .RetryPromptPart )):
714
+ results_by_index [index ] = result
652
715
else :
653
716
assert_never (result )
654
- return parts
717
+
718
+ # We append the results at the end, rather than as they are received, to retain a consistent ordering
719
+ # This is mostly just to simplify testing
720
+ for k in sorted (results_by_index ):
721
+ output_parts .append (results_by_index [k ])
655
722
656
723
657
724
def _unknown_tool (
0 commit comments