@@ -110,20 +110,20 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
110
110
111
111
112
112
@dataclasses .dataclass
113
- class UserPromptNode (BaseNode [GraphAgentState , GraphAgentDeps [DepsT , Any ], NodeRunEndT ], ABC ):
113
+ class UserPromptNode (BaseNode [GraphAgentState , GraphAgentDeps [DepsT , Any ], result . FinalResult [ NodeRunEndT ] ], ABC ):
114
114
user_prompt : str | Sequence [_messages .UserContent ]
115
115
116
116
system_prompts : tuple [str , ...]
117
117
system_prompt_functions : list [_system_prompt .SystemPromptRunner [DepsT ]]
118
118
system_prompt_dynamic_functions : dict [str , _system_prompt .SystemPromptRunner [DepsT ]]
119
119
120
120
async def run (
121
- self , ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , Any ]]
121
+ self , ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , NodeRunEndT ]]
122
122
) -> ModelRequestNode [DepsT , NodeRunEndT ]:
123
123
return ModelRequestNode [DepsT , NodeRunEndT ](request = await self ._get_first_message (ctx ))
124
124
125
125
async def _get_first_message (
126
- self , ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , Any ]]
126
+ self , ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , NodeRunEndT ]]
127
127
) -> _messages .ModelRequest :
128
128
run_context = build_run_context (ctx )
129
129
history , next_message = await self ._prepare_messages (self .user_prompt , ctx .state .message_history , run_context )
@@ -217,7 +217,7 @@ async def add_tool(tool: Tool[DepsT]) -> None:
217
217
218
218
219
219
@dataclasses .dataclass
220
- class ModelRequestNode (BaseNode [GraphAgentState , GraphAgentDeps [DepsT , Any ], NodeRunEndT ]):
220
+ class ModelRequestNode (BaseNode [GraphAgentState , GraphAgentDeps [DepsT , Any ], result . FinalResult [ NodeRunEndT ] ]):
221
221
"""Make a request to the model using the last message in state.message_history."""
222
222
223
223
request : _messages .ModelRequest
@@ -339,7 +339,7 @@ def _finish_handling(
339
339
340
340
341
341
@dataclasses .dataclass
342
- class HandleResponseNode (BaseNode [GraphAgentState , GraphAgentDeps [DepsT , Any ], NodeRunEndT ]):
342
+ class HandleResponseNode (BaseNode [GraphAgentState , GraphAgentDeps [DepsT , Any ], result . FinalResult [ NodeRunEndT ] ]):
343
343
"""Process a model response, and decide whether to end the run or make a new request."""
344
344
345
345
model_response : _messages .ModelResponse
@@ -361,7 +361,7 @@ async def run(
361
361
362
362
@asynccontextmanager
363
363
async def stream (
364
- self , ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , Any ]]
364
+ self , ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , NodeRunEndT ]]
365
365
) -> AsyncIterator [AsyncIterator [_messages .HandleResponseEvent ]]:
366
366
"""Process the model response and yield events for the start and end of each function tool call."""
367
367
with _logfire .span ('handle model response' , run_step = ctx .state .run_step ) as handle_span :
@@ -386,7 +386,7 @@ async def stream(
386
386
handle_span .message = f'handle model response -> { tool_responses_str } '
387
387
388
388
async def _run_stream (
389
- self , ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , Any ]]
389
+ self , ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , NodeRunEndT ]]
390
390
) -> AsyncIterator [_messages .HandleResponseEvent ]:
391
391
if self ._events_iterator is None :
392
392
# Ensure that the stream is only run once
@@ -690,7 +690,7 @@ def get_captured_run_messages() -> _RunMessages:
690
690
691
691
def build_agent_graph (
692
692
name : str | None , deps_type : type [DepsT ], result_type : type [ResultT ]
693
- ) -> Graph [GraphAgentState , GraphAgentDeps [DepsT , Any ], result .FinalResult [ResultT ]]:
693
+ ) -> Graph [GraphAgentState , GraphAgentDeps [DepsT , result . FinalResult [ ResultT ] ], result .FinalResult [ResultT ]]:
694
694
"""Build the execution [Graph][pydantic_graph.Graph] for a given agent."""
695
695
nodes = (
696
696
UserPromptNode [DepsT ],
0 commit comments