Skip to content

Commit 2cfa693

Browse files
committed
Merge branch 'main' into dmontagu/graph-run-streaming
2 parents 5008c9e + bb41987 commit 2cfa693

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -110,20 +110,20 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
110110

111111

112112
@dataclasses.dataclass
113-
class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC):
113+
class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]], ABC):
114114
user_prompt: str | Sequence[_messages.UserContent]
115115

116116
system_prompts: tuple[str, ...]
117117
system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
118118
system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]]
119119

120120
async def run(
121-
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
121+
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
122122
) -> ModelRequestNode[DepsT, NodeRunEndT]:
123123
return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
124124

125125
async def _get_first_message(
126-
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
126+
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
127127
) -> _messages.ModelRequest:
128128
run_context = build_run_context(ctx)
129129
history, next_message = await self._prepare_messages(self.user_prompt, ctx.state.message_history, run_context)
@@ -217,7 +217,7 @@ async def add_tool(tool: Tool[DepsT]) -> None:
217217

218218

219219
@dataclasses.dataclass
220-
class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
220+
class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
221221
"""Make a request to the model using the last message in state.message_history."""
222222

223223
request: _messages.ModelRequest
@@ -339,7 +339,7 @@ def _finish_handling(
339339

340340

341341
@dataclasses.dataclass
342-
class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
342+
class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
343343
"""Process a model response, and decide whether to end the run or make a new request."""
344344

345345
model_response: _messages.ModelResponse
@@ -361,7 +361,7 @@ async def run(
361361

362362
@asynccontextmanager
363363
async def stream(
364-
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
364+
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
365365
) -> AsyncIterator[AsyncIterator[_messages.HandleResponseEvent]]:
366366
"""Process the model response and yield events for the start and end of each function tool call."""
367367
with _logfire.span('handle model response', run_step=ctx.state.run_step) as handle_span:
@@ -386,7 +386,7 @@ async def stream(
386386
handle_span.message = f'handle model response -> {tool_responses_str}'
387387

388388
async def _run_stream(
389-
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
389+
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
390390
) -> AsyncIterator[_messages.HandleResponseEvent]:
391391
if self._events_iterator is None:
392392
# Ensure that the stream is only run once
@@ -690,7 +690,7 @@ def get_captured_run_messages() -> _RunMessages:
690690

691691
def build_agent_graph(
692692
name: str | None, deps_type: type[DepsT], result_type: type[ResultT]
693-
) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[ResultT]]:
693+
) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[ResultT]], result.FinalResult[ResultT]]:
694694
"""Build the execution [Graph][pydantic_graph.Graph] for a given agent."""
695695
nodes = (
696696
UserPromptNode[DepsT],

0 commit comments

Comments
 (0)