3
3
import asyncio
4
4
import inspect
5
5
import types
6
- from collections .abc import Sequence
6
+ from collections .abc import AsyncGenerator , Sequence
7
7
from contextlib import ExitStack
8
8
from dataclasses import dataclass , field
9
9
from functools import cached_property
@@ -170,7 +170,7 @@ async def main():
170
170
if infer_name and self .name is None :
171
171
self ._infer_name (inspect .currentframe ())
172
172
173
- history : list [ HistoryStep [ StateT , T ]] = []
173
+ graph_run = GraphRun [ StateT , DepsT , T ]( self , state = state , deps = deps )
174
174
with ExitStack () as stack :
175
175
run_span : logfire_api .LogfireSpan | None = None
176
176
if self ._auto_instrument :
@@ -184,19 +184,12 @@ async def main():
184
184
185
185
next_node = start_node
186
186
while True :
187
- next_node = await self .next (next_node , history , state = state , deps = deps , infer_name = False )
187
+ next_node = await graph_run .next (next_node )
188
188
if isinstance (next_node , End ):
189
- history . append ( EndStep ( result = next_node ))
189
+ history = graph_run . history
190
190
if run_span is not None :
191
191
run_span .set_attribute ('history' , history )
192
192
return next_node .data , history
193
- elif not isinstance (next_node , BaseNode ):
194
- if TYPE_CHECKING :
195
- typing_extensions .assert_never (next_node )
196
- else :
197
- raise exceptions .GraphRuntimeError (
198
- f'Invalid node return type: `{ type (next_node ).__name__ } `. Expected `BaseNode` or `End`.'
199
- )
200
193
201
194
def run_sync (
202
195
self : Graph [StateT , DepsT , T ],
@@ -510,3 +503,64 @@ def _infer_name(self, function_frame: types.FrameType | None) -> None:
510
503
if item is self :
511
504
self .name = name
512
505
return
506
+
507
+
508
+ class GraphRun (Generic [StateT , DepsT , RunEndT ]):
509
+ def __init__ (
510
+ self ,
511
+ graph : Graph [StateT , DepsT , RunEndT ],
512
+ * ,
513
+ state : StateT = None ,
514
+ deps : DepsT = None ,
515
+ ):
516
+ self .graph = graph
517
+ self .state = state
518
+ self .deps = deps
519
+
520
+ self .history : list [HistoryStep [StateT , RunEndT ]] = []
521
+ self .final_result : End [RunEndT ] | None = None
522
+
523
+ self ._agen : (
524
+ AsyncGenerator [BaseNode [StateT , DepsT , RunEndT ] | End [RunEndT ], BaseNode [StateT , DepsT , RunEndT ]] | None
525
+ ) = None
526
+
527
+ async def next (
528
+ self : GraphRun [StateT , DepsT , T ], node : BaseNode [StateT , DepsT , T ]
529
+ ) -> BaseNode [StateT , DepsT , Any ] | End [T ]:
530
+ agen = await self ._get_primed_agen ()
531
+ return await agen .asend (node )
532
+
533
+ async def _get_primed_agen (
534
+ self : GraphRun [StateT , DepsT , T ],
535
+ ) -> AsyncGenerator [BaseNode [StateT , DepsT , T ] | End [T ], BaseNode [StateT , DepsT , T ]]:
536
+ graph = self .graph
537
+ state = self .state
538
+ deps = self .deps
539
+ history = self .history
540
+
541
+ if self ._agen is None :
542
+
543
+ async def _agen () -> AsyncGenerator [BaseNode [StateT , DepsT , T ] | End [T ], BaseNode [StateT , DepsT , T ]]:
544
+ next_node = yield # pyright: ignore[reportReturnType] # we prime the generator immediately below
545
+ while True :
546
+ next_node = await graph .next (next_node , history , state = state , deps = deps , infer_name = False )
547
+ if isinstance (next_node , End ):
548
+ history .append (EndStep (result = next_node ))
549
+ self .final_result = next_node
550
+ yield next_node
551
+ return
552
+ elif isinstance (next_node , BaseNode ):
553
+ next_node = yield next_node # Give user a chance to modify the next node
554
+ else :
555
+ if TYPE_CHECKING :
556
+ typing_extensions .assert_never (next_node )
557
+ else :
558
+ raise exceptions .GraphRuntimeError (
559
+ f'Invalid node return type: `{ type (next_node ).__name__ } `. Expected `BaseNode` or `End`.'
560
+ )
561
+
562
+ agen = _agen ()
563
+ await agen .__anext__ () # prime the generator
564
+
565
+ self ._agen = agen
566
+ return self ._agen
0 commit comments