@@ -108,20 +108,20 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
108
108
109
109
110
110
@dataclasses .dataclass
111
- class UserPromptNode (BaseNode [GraphAgentState , GraphAgentDeps [DepsT , Any ], NodeRunEndT ], ABC ):
111
+ class UserPromptNode (BaseNode [GraphAgentState , GraphAgentDeps [DepsT , Any ], result . FinalResult [ NodeRunEndT ] ], ABC ):
112
112
user_prompt : str | Sequence [_messages .UserContent ]
113
113
114
114
system_prompts : tuple [str , ...]
115
115
system_prompt_functions : list [_system_prompt .SystemPromptRunner [DepsT ]]
116
116
system_prompt_dynamic_functions : dict [str , _system_prompt .SystemPromptRunner [DepsT ]]
117
117
118
118
async def run (
119
- self , ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , Any ]]
119
+ self , ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , NodeRunEndT ]]
120
120
) -> ModelRequestNode [DepsT , NodeRunEndT ]:
121
121
return ModelRequestNode [DepsT , NodeRunEndT ](request = await self ._get_first_message (ctx ))
122
122
123
123
async def _get_first_message (
124
- self , ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , Any ]]
124
+ self , ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , NodeRunEndT ]]
125
125
) -> _messages .ModelRequest :
126
126
run_context = build_run_context (ctx )
127
127
history , next_message = await self ._prepare_messages (self .user_prompt , ctx .state .message_history , run_context )
@@ -215,7 +215,7 @@ async def add_tool(tool: Tool[DepsT]) -> None:
215
215
216
216
217
217
@dataclasses .dataclass
218
- class ModelRequestNode (BaseNode [GraphAgentState , GraphAgentDeps [DepsT , Any ], NodeRunEndT ]):
218
+ class ModelRequestNode (BaseNode [GraphAgentState , GraphAgentDeps [DepsT , Any ], result . FinalResult [ NodeRunEndT ] ]):
219
219
"""Make a request to the model using the last message in state.message_history."""
220
220
221
221
request : _messages .ModelRequest
@@ -319,7 +319,7 @@ def _finish_handling(
319
319
320
320
321
321
@dataclasses .dataclass
322
- class HandleResponseNode (BaseNode [GraphAgentState , GraphAgentDeps [DepsT , Any ], NodeRunEndT ]):
322
+ class HandleResponseNode (BaseNode [GraphAgentState , GraphAgentDeps [DepsT , Any ], result . FinalResult [ NodeRunEndT ] ]):
323
323
"""Process a model response, and decide whether to end the run or make a new request."""
324
324
325
325
model_response : _messages .ModelResponse
@@ -341,7 +341,7 @@ async def run(
341
341
342
342
@asynccontextmanager
343
343
async def stream (
344
- self , ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , Any ]]
344
+ self , ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , NodeRunEndT ]]
345
345
) -> AsyncIterator [AsyncIterator [_messages .HandleResponseEvent ]]:
346
346
"""Process the model response and yield events for the start and end of each function tool call."""
347
347
with _logfire .span ('handle model response' , run_step = ctx .state .run_step ) as handle_span :
@@ -366,7 +366,7 @@ async def stream(
366
366
handle_span .message = f'handle model response -> { tool_responses_str } '
367
367
368
368
async def _run_stream (
369
- self , ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , Any ]]
369
+ self , ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , NodeRunEndT ]]
370
370
) -> AsyncIterator [_messages .HandleResponseEvent ]:
371
371
if self ._events_iterator is None :
372
372
# Ensure that the stream is only run once
@@ -670,7 +670,7 @@ def get_captured_run_messages() -> _RunMessages:
670
670
671
671
def build_agent_graph (
672
672
name : str | None , deps_type : type [DepsT ], result_type : type [ResultT ]
673
- ) -> Graph [GraphAgentState , GraphAgentDeps [DepsT , Any ], result .FinalResult [ResultT ]]:
673
+ ) -> Graph [GraphAgentState , GraphAgentDeps [DepsT , result . FinalResult [ ResultT ] ], result .FinalResult [ResultT ]]:
674
674
"""Build the execution [Graph][pydantic_graph.Graph] for a given agent."""
675
675
nodes = (
676
676
UserPromptNode [DepsT ],
0 commit comments