Use .iter() API to fully replace existing streaming implementation #951

Merged · 10 commits · Feb 27, 2025
docs/agents.md (140 additions, 0 deletions)

@@ -220,6 +220,146 @@ Once the run finishes, `agent_run.final_result` becomes a [`AgentRunResult`][pyd

---

### Streaming

Here is an example of streaming an agent run in combination with `async for` iteration:

```python {title="streaming.py"}
import asyncio
from dataclasses import dataclass
from datetime import date

from pydantic_ai import Agent
from pydantic_ai.messages import (
FinalResultEvent,
FunctionToolCallEvent,
FunctionToolResultEvent,
PartDeltaEvent,
PartStartEvent,
TextPartDelta,
ToolCallPartDelta,
)
from pydantic_ai.tools import RunContext


@dataclass
class WeatherService:
async def get_forecast(self, location: str, forecast_date: date) -> str:
# In real code: call weather API, DB queries, etc.
return f'The forecast in {location} on {forecast_date} is 24°C and sunny.'

async def get_historic_weather(self, location: str, forecast_date: date) -> str:
# In real code: call a historical weather API or DB
return (
f'The weather in {location} on {forecast_date} was 18°C and partly cloudy.'
)


weather_agent = Agent[WeatherService, str](
'openai:gpt-4o',
deps_type=WeatherService,
result_type=str, # We'll produce a final answer as plain text
    system_prompt='Provide a weather forecast at the locations the user provides.',
)


@weather_agent.tool
async def weather_forecast(
ctx: RunContext[WeatherService],
location: str,
forecast_date: date,
) -> str:
if forecast_date >= date.today():
return await ctx.deps.get_forecast(location, forecast_date)
else:
return await ctx.deps.get_historic_weather(location, forecast_date)


output_messages: list[str] = []


async def main():
user_prompt = 'What will the weather be like in Paris on Tuesday?'

# Begin a node-by-node, streaming iteration
    async with weather_agent.iter(user_prompt, deps=WeatherService()) as run:
async for node in run:
if Agent.is_user_prompt_node(node):
# A user prompt node => The user has provided input
output_messages.append(f'=== UserPromptNode: {node.user_prompt} ===')
elif Agent.is_model_request_node(node):
# A model request node => We can stream tokens from the model's request
output_messages.append(
'=== ModelRequestNode: streaming partial request tokens ==='
)
async with node.stream(run.ctx) as request_stream:
async for event in request_stream:
if isinstance(event, PartStartEvent):
output_messages.append(
f'[Request] Starting part {event.index}: {event.part!r}'
)
elif isinstance(event, PartDeltaEvent):
if isinstance(event.delta, TextPartDelta):
output_messages.append(
f'[Request] Part {event.index} text delta: {event.delta.content_delta!r}'
)
elif isinstance(event.delta, ToolCallPartDelta):
output_messages.append(
f'[Request] Part {event.index} args_delta={event.delta.args_delta}'
)
elif isinstance(event, FinalResultEvent):
output_messages.append(
f'[Result] The model produced a final result (tool_name={event.tool_name})'
)
elif Agent.is_handle_response_node(node):
# A handle-response node => The model returned some data, potentially calls a tool
output_messages.append(
'=== HandleResponseNode: streaming partial response & tool usage ==='
)
async with node.stream(run.ctx) as handle_stream:
async for event in handle_stream:
if isinstance(event, FunctionToolCallEvent):
output_messages.append(
f'[Tools] The LLM calls tool={event.part.tool_name!r} with args={event.part.args} (tool_call_id={event.part.tool_call_id!r})'
)
elif isinstance(event, FunctionToolResultEvent):
output_messages.append(
f'[Tools] Tool call {event.tool_call_id!r} returned => {event.result.content}'
)
elif Agent.is_end_node(node):
assert run.result.data == node.data.data
# Once an End node is reached, the agent run is complete
output_messages.append(f'=== Final Agent Output: {run.result.data} ===')


if __name__ == '__main__':
asyncio.run(main())

print(output_messages)
"""
[
'=== ModelRequestNode: streaming partial request tokens ===',
'[Request] Starting part 0: ToolCallPart(tool_name=\'weather_forecast\', args=\'{"location":"Pa\', tool_call_id=\'0001\', part_kind=\'tool-call\')',
'[Request] Part 0 args_delta=ris","forecast_',
'[Request] Part 0 args_delta=date":"2030-01-',
'[Request] Part 0 args_delta=01"}',
'=== HandleResponseNode: streaming partial response & tool usage ===',
'[Tools] The LLM calls tool=\'weather_forecast\' with args={"location":"Paris","forecast_date":"2030-01-01"} (tool_call_id=\'0001\')',
"[Tools] Tool call '0001' returned => The forecast in Paris on 2030-01-01 is 24°C and sunny.",
'=== ModelRequestNode: streaming partial request tokens ===',
"[Request] Starting part 0: TextPart(content='It will be ', part_kind='text')",
'[Result] The model produced a final result (tool_name=None)',
"[Request] Part 0 text delta: 'warm and sunny '",
"[Request] Part 0 text delta: 'in Paris on '",
"[Request] Part 0 text delta: 'Tuesday.'",
'=== HandleResponseNode: streaming partial response & tool usage ===',
'=== Final Agent Output: It will be warm and sunny in Paris on Tuesday. ===',
]
"""
```
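If you only need the node-by-node view and not token-level events, the same `.iter()` loop can be driven without opening the per-node streams. A minimal sketch, reusing `weather_agent` and `WeatherService` from the example above (the `main_nodes_only` name is illustrative):

```python
import asyncio


async def main_nodes_only():
    async with weather_agent.iter(
        'What will the weather be like in Paris on Tuesday?',
        deps=WeatherService(),
    ) as run:
        async for node in run:
            # UserPromptNode, ModelRequestNode, HandleResponseNode, End, ...
            print(type(node).__name__)
        print(run.result.data)


asyncio.run(main_nodes_only())
```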

---

### Additional Configuration

#### Usage Limits
pydantic_ai_slim/pydantic_ai/_agent_graph.py (48 additions, 7 deletions)

@@ -2,15 +2,14 @@

import asyncio
import dataclasses
-from abc import ABC
from collections.abc import AsyncIterator, Iterator, Sequence
from contextlib import asynccontextmanager, contextmanager
from contextvars import ContextVar
from dataclasses import field
from typing import Any, Generic, Literal, Union, cast

import logfire_api
-from typing_extensions import TypeVar, assert_never
+from typing_extensions import TypeGuard, TypeVar, assert_never

from pydantic_graph import BaseNode, Graph, GraphRunContext
from pydantic_graph.nodes import End, NodeRunEndT
@@ -55,6 +54,7 @@
logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)

T = TypeVar('T')
S = TypeVar('S')
NoneType = type(None)
EndStrategy = Literal['early', 'exhaustive']
"""The strategy for handling multiple tool calls when a final result is found.
Expand Down Expand Up @@ -107,8 +107,31 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
run_span: logfire_api.LogfireSpan


class AgentNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
    """The base class for all agent nodes.

    Using a subclass of `BaseNode` for all nodes reduces the amount of generics boilerplate everywhere.
    """

Comment from the PR author:

I tried using a type alias instead of a class here, and it caused some issues because I couldn't inherit from it.

I think it may be possible to drop the `AgentNode` class and use a type alias (just inheriting from the non-aliased parametrized `BaseNode`); either way, this type isn't currently public. But having it allows us to remove most references to `BaseNode` in `agent.py`, which ultimately makes a lot of types more readable.

Unless one of you really objects, I would prefer to keep it around for now rather than wrestle with the consequences of removing it (in terms of verbosity and/or type-checking challenges), especially considering it isn't a public API anyway.


def is_agent_node(
node: BaseNode[GraphAgentState, GraphAgentDeps[T, Any], result.FinalResult[S]] | End[result.FinalResult[S]],
) -> TypeGuard[AgentNode[T, S]]:
"""Check if the provided node is an instance of `AgentNode`.

Usage:

if is_agent_node(node):
# `node` is an AgentNode
...

This method preserves the generic parameters on the narrowed type, unlike `isinstance(node, AgentNode)`.
"""
return isinstance(node, AgentNode)
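To make the docstring's narrowing claim concrete, here is a self-contained sketch of the same `TypeGuard` pattern (the `Node`/`Stop`/`MyNode` names are hypothetical stand-ins, not part of the codebase):

```python
from dataclasses import dataclass
from typing import Generic, TypeVar, Union

from typing_extensions import TypeGuard

T = TypeVar('T')


class Node(Generic[T]):
    """Stand-in for the parametrized BaseNode."""


@dataclass
class Stop(Generic[T]):
    """Stand-in for End."""

    value: T


class MyNode(Node[T]):
    """Stand-in for AgentNode."""


def is_my_node(node: Union[Node[T], Stop[T]]) -> TypeGuard[MyNode[T]]:
    # Like is_agent_node: the TypeGuard keeps T bound on the narrowed type,
    # whereas a bare isinstance() check can degrade the parameters to Any.
    return isinstance(node, MyNode)


def demo(node: Union[Node[int], Stop[int]]) -> None:
    if is_my_node(node):
        ...  # type checkers see MyNode[int] here, not MyNode[Any]
```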


@dataclasses.dataclass
-class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]], ABC):
+class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
user_prompt: str | Sequence[_messages.UserContent]

system_prompts: tuple[str, ...]
@@ -215,7 +238,7 @@ async def add_tool(tool: Tool[DepsT]) -> None:


@dataclasses.dataclass
-class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
+class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
"""Make a request to the model using the last message in state.message_history."""

request: _messages.ModelRequest
@@ -236,12 +259,30 @@ async def run(

return await self._make_request(ctx)

@asynccontextmanager
async def stream(
self,
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
) -> AsyncIterator[result.AgentStream[DepsT, T]]:
async with self._stream(ctx) as streamed_response:
agent_stream = result.AgentStream[DepsT, T](
streamed_response,
ctx.deps.result_schema,
ctx.deps.result_validators,
build_run_context(ctx),
ctx.deps.usage_limits,
)
yield agent_stream
# In case the user didn't manually consume the full stream, ensure it is fully consumed here,
# otherwise usage won't be properly counted:
async for _ in agent_stream:
pass
Comment from the PR author on lines +276 to +279:

@Kludex not 100% sure if we should do this force-consume-the-stream (when there isn't an exception), but I think it's what we should do. (Unless we confirm that you don't get charged for streams you don't consume, but that seems unlikely.)

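A hedged illustration of why the trailing drain matters, borrowing `node` and `run` from the docs example above: even if a caller exits the stream early, the context manager still consumes the remaining events so usage is counted for the whole response.

```python
async with node.stream(run.ctx) as request_stream:
    async for event in request_stream:
        break  # bail out after the first event
# On exit, stream() resumes past its `yield` and runs
# `async for _ in agent_stream: pass`, consuming the rest of the
# response so token usage is still recorded in full.
```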

@asynccontextmanager
async def _stream(
self,
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
) -> AsyncIterator[models.StreamedResponse]:
# TODO: Consider changing this to return something more similar to a `StreamedRunResult`, then make it public
assert not self._did_stream, 'stream() should only be called once per node'

model_settings, model_request_parameters = await self._prepare_request(ctx)
@@ -319,7 +360,7 @@ def _finish_handling(


@dataclasses.dataclass
-class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
+class HandleResponseNode(AgentNode[DepsT, NodeRunEndT]):
"""Process a model response, and decide whether to end the run or make a new request."""

model_response: _messages.ModelResponse
@@ -575,7 +616,7 @@ async def process_function_tools(
for task in done:
index = tasks.index(task)
result = task.result()
yield _messages.FunctionToolResultEvent(result, call_id=call_index_to_event_id[index])
yield _messages.FunctionToolResultEvent(result, tool_call_id=call_index_to_event_id[index])
if isinstance(result, (_messages.ToolReturnPart, _messages.RetryPromptPart)):
results_by_index[index] = result
else: