Skip to content

Commit 571d805

Browse files
committed
Add GraphRun object
1 parent 739f179 commit 571d805

File tree

2 files changed

+66
-12
lines changed

2 files changed

+66
-12
lines changed

pydantic_graph/pydantic_graph/graph.py

+65-11
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import asyncio
44
import inspect
55
import types
6-
from collections.abc import Sequence
6+
from collections.abc import AsyncGenerator, Sequence
77
from contextlib import ExitStack
88
from dataclasses import dataclass, field
99
from functools import cached_property
@@ -170,7 +170,7 @@ async def main():
170170
if infer_name and self.name is None:
171171
self._infer_name(inspect.currentframe())
172172

173-
history: list[HistoryStep[StateT, T]] = []
173+
graph_run = GraphRun[StateT, DepsT, T](self, state=state, deps=deps)
174174
with ExitStack() as stack:
175175
run_span: logfire_api.LogfireSpan | None = None
176176
if self._auto_instrument:
@@ -184,19 +184,12 @@ async def main():
184184

185185
next_node = start_node
186186
while True:
187-
next_node = await self.next(next_node, history, state=state, deps=deps, infer_name=False)
187+
next_node = await graph_run.next(next_node)
188188
if isinstance(next_node, End):
189-
history.append(EndStep(result=next_node))
189+
history = graph_run.history
190190
if run_span is not None:
191191
run_span.set_attribute('history', history)
192192
return next_node.data, history
193-
elif not isinstance(next_node, BaseNode):
194-
if TYPE_CHECKING:
195-
typing_extensions.assert_never(next_node)
196-
else:
197-
raise exceptions.GraphRuntimeError(
198-
f'Invalid node return type: `{type(next_node).__name__}`. Expected `BaseNode` or `End`.'
199-
)
200193

201194
def run_sync(
202195
self: Graph[StateT, DepsT, T],
@@ -510,3 +503,64 @@ def _infer_name(self, function_frame: types.FrameType | None) -> None:
510503
if item is self:
511504
self.name = name
512505
return
506+
507+
508+
class GraphRun(Generic[StateT, DepsT, RunEndT]):
509+
def __init__(
510+
self,
511+
graph: Graph[StateT, DepsT, RunEndT],
512+
*,
513+
state: StateT = None,
514+
deps: DepsT = None,
515+
):
516+
self.graph = graph
517+
self.state = state
518+
self.deps = deps
519+
520+
self.history: list[HistoryStep[StateT, RunEndT]] = []
521+
self.final_result: End[RunEndT] | None = None
522+
523+
self._agen: (
524+
AsyncGenerator[BaseNode[StateT, DepsT, RunEndT] | End[RunEndT], BaseNode[StateT, DepsT, RunEndT]] | None
525+
) = None
526+
527+
async def next(
528+
self: GraphRun[StateT, DepsT, T], node: BaseNode[StateT, DepsT, T]
529+
) -> BaseNode[StateT, DepsT, Any] | End[T]:
530+
agen = await self._get_primed_agen()
531+
return await agen.asend(node)
532+
533+
async def _get_primed_agen(
534+
self: GraphRun[StateT, DepsT, T],
535+
) -> AsyncGenerator[BaseNode[StateT, DepsT, T] | End[T], BaseNode[StateT, DepsT, T]]:
536+
graph = self.graph
537+
state = self.state
538+
deps = self.deps
539+
history = self.history
540+
541+
if self._agen is None:
542+
543+
async def _agen() -> AsyncGenerator[BaseNode[StateT, DepsT, T] | End[T], BaseNode[StateT, DepsT, T]]:
544+
next_node = yield # pyright: ignore[reportReturnType] # we prime the generator immediately below
545+
while True:
546+
next_node = await graph.next(next_node, history, state=state, deps=deps, infer_name=False)
547+
if isinstance(next_node, End):
548+
history.append(EndStep(result=next_node))
549+
self.final_result = next_node
550+
yield next_node
551+
return
552+
elif isinstance(next_node, BaseNode):
553+
next_node = yield next_node # Give user a chance to modify the next node
554+
else:
555+
if TYPE_CHECKING:
556+
typing_extensions.assert_never(next_node)
557+
else:
558+
raise exceptions.GraphRuntimeError(
559+
f'Invalid node return type: `{type(next_node).__name__}`. Expected `BaseNode` or `End`.'
560+
)
561+
562+
agen = _agen()
563+
await agen.__anext__() # prime the generator
564+
565+
self._agen = agen
566+
return self._agen

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -193,4 +193,4 @@ skip = '.git*,*.svg,*.lock,*.css'
193193
check-hidden = true
194194
# Ignore "formatting" like **L**anguage
195195
ignore-regex = '\*\*[A-Z]\*\*[a-z]+\b'
196-
# ignore-words-list = ''
196+
ignore-words-list = 'asend'

0 commit comments

Comments
 (0)