Skip to content

Commit bb41987

Browse files
authored
Fix agent graph types (#983)
1 parent 96be03d commit bb41987

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -108,20 +108,20 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
108108

109109

110110
@dataclasses.dataclass
111-
class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC):
111+
class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]], ABC):
112112
user_prompt: str | Sequence[_messages.UserContent]
113113

114114
system_prompts: tuple[str, ...]
115115
system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
116116
system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]]
117117

118118
async def run(
119-
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
119+
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
120120
) -> ModelRequestNode[DepsT, NodeRunEndT]:
121121
return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
122122

123123
async def _get_first_message(
124-
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
124+
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
125125
) -> _messages.ModelRequest:
126126
run_context = build_run_context(ctx)
127127
history, next_message = await self._prepare_messages(self.user_prompt, ctx.state.message_history, run_context)
@@ -215,7 +215,7 @@ async def add_tool(tool: Tool[DepsT]) -> None:
215215

216216

217217
@dataclasses.dataclass
218-
class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
218+
class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
219219
"""Make a request to the model using the last message in state.message_history."""
220220

221221
request: _messages.ModelRequest
@@ -319,7 +319,7 @@ def _finish_handling(
319319

320320

321321
@dataclasses.dataclass
322-
class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
322+
class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
323323
"""Process a model response, and decide whether to end the run or make a new request."""
324324

325325
model_response: _messages.ModelResponse
@@ -341,7 +341,7 @@ async def run(
341341

342342
@asynccontextmanager
343343
async def stream(
344-
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
344+
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
345345
) -> AsyncIterator[AsyncIterator[_messages.HandleResponseEvent]]:
346346
"""Process the model response and yield events for the start and end of each function tool call."""
347347
with _logfire.span('handle model response', run_step=ctx.state.run_step) as handle_span:
@@ -366,7 +366,7 @@ async def stream(
366366
handle_span.message = f'handle model response -> {tool_responses_str}'
367367

368368
async def _run_stream(
369-
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
369+
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
370370
) -> AsyncIterator[_messages.HandleResponseEvent]:
371371
if self._events_iterator is None:
372372
# Ensure that the stream is only run once
@@ -670,7 +670,7 @@ def get_captured_run_messages() -> _RunMessages:
670670

671671
def build_agent_graph(
672672
name: str | None, deps_type: type[DepsT], result_type: type[ResultT]
673-
) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[ResultT]]:
673+
) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[ResultT]], result.FinalResult[ResultT]]:
674674
"""Build the execution [Graph][pydantic_graph.Graph] for a given agent."""
675675
nodes = (
676676
UserPromptNode[DepsT],

0 commit comments

Comments
 (0)