
Commit 6586af3 (parent: bb41987)

Use .iter() API to fully replace existing streaming implementation

File tree: 5 files changed (+351, -6 lines changed)


docs/agents.md (+175)

@@ -220,6 +220,181 @@ Once the run finishes, `agent_run.final_result` becomes a [`AgentRunResult`][pyd
 
 ---
 
+### Streaming
+
+Here is an example of streaming an agent run node by node, in combination with `async for`:
+
+```python {title="streaming.py"}
+import asyncio
+from dataclasses import dataclass
+from datetime import date
+
+from pydantic_ai import (
+    Agent,
+    capture_run_messages,
+)
+from pydantic_ai.agent import is_handle_response_node, is_model_request_node
+from pydantic_ai.messages import (
+    PartStartEvent,
+    PartDeltaEvent,
+    FunctionToolCallEvent,
+    FunctionToolResultEvent,
+    FinalResultEvent,
+    TextPartDelta,
+    ToolCallPartDelta,
+)
+from pydantic_ai.tools import RunContext
+from pydantic_graph import End
+
+
+@dataclass
+class WeatherService:
+    async def get_forecast(self, location: str, forecast_date: date) -> str:
+        # In real code: call a weather API, DB queries, etc.
+        return f"The forecast in {location} on {forecast_date} is 24°C and sunny."
+
+    async def get_historic_weather(self, location: str, forecast_date: date) -> str:
+        # In real code: call a historical weather API or DB
+        return f"The weather in {location} on {forecast_date} was 18°C and partly cloudy."
+
+
+weather_agent = Agent[WeatherService, str](
+    "openai:gpt-4o",
+    deps_type=WeatherService,
+    result_type=str,  # We'll produce a final answer as plain text
+    system_prompt="Providing a weather forecast at the locations the user provides.",
+)
+
+
+@weather_agent.tool
+async def weather_forecast(
+    ctx: RunContext[WeatherService],
+    location: str,
+    forecast_date: date,
+) -> str:
+    if forecast_date >= date.today():
+        return await ctx.deps.get_forecast(location, forecast_date)
+    else:
+        return await ctx.deps.get_historic_weather(location, forecast_date)
+
+
+async def main():
+    # The user asks for tomorrow's weather in Paris
+    user_prompt = "What will the weather be like in Paris tomorrow?"
+
+    # We'll capture raw messages for debugging
+    with capture_run_messages() as messages:
+        # Provide a WeatherService instance as the agent's dependencies
+        deps = WeatherService()
+
+        # Begin a node-by-node, streaming iteration
+        with weather_agent.iter(user_prompt, deps=deps) as run:
+            node = run.next_node  # The first node to run
+            while not isinstance(node, End):
+                if is_model_request_node(node):
+                    # A model request node => we can stream tokens from the model's response as they arrive
+                    print("=== ModelRequestNode: streaming partial request tokens ===")
+                    async with node.stream(run.ctx) as request_stream:
+                        async for event in request_stream:
+                            if isinstance(event, PartStartEvent):
+                                print(f"[Request] Starting part {event.index}: {event.part!r}")
+                            elif isinstance(event, PartDeltaEvent):
+                                if isinstance(event.delta, TextPartDelta):
+                                    print(f"[Request] Part {event.index} text delta: {event.delta.content_delta!r}")
+                                elif isinstance(event.delta, ToolCallPartDelta):
+                                    print(f"[Request] Part {event.index} args_delta={event.delta.args_delta}")
+                            elif isinstance(event, FinalResultEvent):
+                                print(f"[Result] The model produced a final result (tool_name={event.tool_name})")
+
+                elif is_handle_response_node(node):
+                    # A handle-response node => the model returned some data, and may call a tool
+                    print("=== HandleResponseNode: streaming partial response & tool usage ===")
+                    async with node.stream(run.ctx) as handle_stream:
+                        async for event in handle_stream:
+                            if isinstance(event, FunctionToolCallEvent):
+                                print(f"[Tools] The LLM calls tool={event.part.tool_name!r} with args={event.part.args} (tool_call_id={event.part.tool_call_id!r})")
+                            elif isinstance(event, FunctionToolResultEvent):
+                                print(f"[Tools] Tool call {event.tool_call_id!r} returned => {event.result.content}")
+
+                node = await run.next(node)
+
+            # Once an End node is reached, the agent run is complete
+            assert run.result is not None
+            print("\n=== Final Agent Output ===")
+            print("Forecast:", run.result.data)
+
+        # Show the raw messages exchanged
+        print("\n=== Raw Messages Captured ===")
+        for m in messages:
+            print(" -", m)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
+
+"""
+=== ModelRequestNode: streaming partial request tokens ===
+[Request] Starting part 0: ToolCallPart(tool_name='weather_forecast', args='', tool_call_id='call_Q0QqiZfIhHyNViiLG7jT0G9R', part_kind='tool-call')
+[Request] Part 0 args_delta={"
+[Request] Part 0 args_delta=location
+[Request] Part 0 args_delta=":"
+[Request] Part 0 args_delta=Paris
+[Request] Part 0 args_delta=","
+[Request] Part 0 args_delta=forecast
+[Request] Part 0 args_delta=_date
+[Request] Part 0 args_delta=":"
+[Request] Part 0 args_delta=202
+[Request] Part 0 args_delta=3
+[Request] Part 0 args_delta=-
+[Request] Part 0 args_delta=11
+[Request] Part 0 args_delta=-
+[Request] Part 0 args_delta=02
+[Request] Part 0 args_delta="}
+=== HandleResponseNode: streaming partial response & tool usage ===
+[Tools] The LLM calls tool='weather_forecast' with args={"location":"Paris","forecast_date":"2023-11-02"} (tool_call_id='call_Q0QqiZfIhHyNViiLG7jT0G9R')
+[Tools] Tool call 'call_Q0QqiZfIhHyNViiLG7jT0G9R' returned => The weather in Paris on 2023-11-02 was 18°C and partly cloudy.
+=== ModelRequestNode: streaming partial request tokens ===
+[Request] Starting part 0: TextPart(content='', part_kind='text')
+[Result] The model produced a final result (tool_name=None)
+[Request] Part 0 text delta: 'The'
+[Request] Part 0 text delta: ' weather'
+[Request] Part 0 text delta: ' forecast'
+[Request] Part 0 text delta: ' for'
+[Request] Part 0 text delta: ' Paris'
+[Request] Part 0 text delta: ' tomorrow'
+[Request] Part 0 text delta: ','
+[Request] Part 0 text delta: ' November'
+[Request] Part 0 text delta: ' '
+[Request] Part 0 text delta: '2'
+[Request] Part 0 text delta: ','
+[Request] Part 0 text delta: ' '
+[Request] Part 0 text delta: '202'
+[Request] Part 0 text delta: '3'
+[Request] Part 0 text delta: ','
+[Request] Part 0 text delta: ' is'
+[Request] Part 0 text delta: ' expected'
+[Request] Part 0 text delta: ' to'
+[Request] Part 0 text delta: ' be'
+[Request] Part 0 text delta: ' '
+[Request] Part 0 text delta: '18'
+[Request] Part 0 text delta: '°C'
+[Request] Part 0 text delta: ' and'
+[Request] Part 0 text delta: ' partly'
+[Request] Part 0 text delta: ' cloudy'
+[Request] Part 0 text delta: '.'
+=== HandleResponseNode: streaming partial response & tool usage ===
+
+=== Final Agent Output ===
+Forecast: The weather forecast for Paris tomorrow, November 2, 2023, is expected to be 18°C and partly cloudy.
+
+=== Raw Messages Captured ===
+ - ModelRequest(parts=[SystemPromptPart(content='Providing a weather forecast at the locations the user provides.', dynamic_ref=None, part_kind='system-prompt'), UserPromptPart(content='What will the weather be like in Paris tomorrow?', timestamp=datetime.datetime(2025, 2, 25, 7, 16, 4, 867863, tzinfo=datetime.timezone.utc), part_kind='user-prompt')], kind='request')
+ - ModelResponse(parts=[ToolCallPart(tool_name='weather_forecast', args='{"location":"Paris","forecast_date":"2023-11-02"}', tool_call_id='call_Q0QqiZfIhHyNViiLG7jT0G9R', part_kind='tool-call')], model_name='gpt-4o', timestamp=datetime.datetime(2025, 2, 25, 7, 16, 8, tzinfo=datetime.timezone.utc), kind='response')
+ - ModelRequest(parts=[ToolReturnPart(tool_name='weather_forecast', content='The weather in Paris on 2023-11-02 was 18°C and partly cloudy.', tool_call_id='call_Q0QqiZfIhHyNViiLG7jT0G9R', timestamp=datetime.datetime(2025, 2, 25, 7, 16, 9, 150432, tzinfo=datetime.timezone.utc), part_kind='tool-return')], kind='request')
+ - ModelResponse(parts=[TextPart(content='The weather forecast for Paris tomorrow, November 2, 2023, is expected to be 18°C and partly cloudy.', part_kind='text')], model_name='gpt-4o', timestamp=datetime.datetime(2025, 2, 25, 7, 16, 9, tzinfo=datetime.timezone.utc), kind='response')
+"""
+```
+
+---
+
 ### Additional Configuration
 
 #### Usage Limits
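
If you only want the streamed text and not the full event-by-event trace, the same `.iter()` loop collapses to a few lines. A minimal sketch, assuming the `agent.iter()` / `node.stream()` API from the example above (`stream_text_only` is a hypothetical helper, not part of the library):

```python
import asyncio

from pydantic_ai.agent import is_model_request_node
from pydantic_ai.messages import PartDeltaEvent, TextPartDelta
from pydantic_graph import End


async def stream_text_only(agent, user_prompt: str, deps) -> None:
    # Hypothetical helper: walk the run node by node, printing only text deltas
    with agent.iter(user_prompt, deps=deps) as run:
        node = run.next_node
        while not isinstance(node, End):
            if is_model_request_node(node):
                async with node.stream(run.ctx) as request_stream:
                    async for event in request_stream:
                        # Ignore tool-call deltas; emit plain text as it arrives
                        if isinstance(event, PartDeltaEvent) and isinstance(event.delta, TextPartDelta):
                            print(event.delta.content_delta, end="", flush=True)
            node = await run.next(node)
```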

pydantic_ai_slim/pydantic_ai/_agent_graph.py (+35, -3)

@@ -10,7 +10,7 @@
 from typing import Any, Generic, Literal, Union, cast
 
 import logfire_api
-from typing_extensions import TypeVar, assert_never
+from typing_extensions import TypeGuard, TypeVar, assert_never
 
 from pydantic_graph import BaseNode, Graph, GraphRunContext
 from pydantic_graph.nodes import End, NodeRunEndT
@@ -40,6 +40,8 @@
     'HandleResponseNode',
     'build_run_context',
     'capture_run_messages',
+    'is_model_request_node',
+    'is_handle_response_node',
 )
 
 _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
@@ -236,12 +238,30 @@ async def run(
 
         return await self._make_request(ctx)
 
+    @asynccontextmanager
+    async def stream(
+        self,
+        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
+    ) -> AsyncIterator[result.AgentStream[DepsT, T]]:
+        async with self._stream(ctx) as streamed_response:
+            agent_stream = result.AgentStream[DepsT, T](
+                streamed_response,
+                ctx.deps.result_schema,
+                ctx.deps.result_validators,
+                build_run_context(ctx),
+                ctx.deps.usage_limits,
+            )
+            yield agent_stream
+            # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
+            # otherwise usage won't be properly counted:
+            async for _ in agent_stream:
+                pass
+
     @asynccontextmanager
     async def _stream(
         self,
         ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
     ) -> AsyncIterator[models.StreamedResponse]:
-        # TODO: Consider changing this to return something more similar to a `StreamedRunResult`, then make it public
         assert not self._did_stream, 'stream() should only be called once per node'
 
         model_settings, model_request_parameters = await self._prepare_request(ctx)
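
The trailing `async for` in the new `stream()` method is what makes early exits safe: even if a caller stops consuming events partway through, the context manager drains the remainder on exit, so token usage is still counted. A sketch of such a caller, under the node API added in this commit (`first_event` is a hypothetical function, not part of the change):

```python
from pydantic_graph import GraphRunContext


async def first_event(node, ctx: GraphRunContext):
    # Hypothetical caller that abandons the stream after one event.
    # `stream()` drains the remaining events when the context exits,
    # so usage for the rest of the response is still recorded.
    async with node.stream(ctx) as request_stream:
        async for event in request_stream:
            return event
```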
@@ -575,7 +595,7 @@ async def process_function_tools(
             for task in done:
                 index = tasks.index(task)
                 result = task.result()
-                yield _messages.FunctionToolResultEvent(result, call_id=call_index_to_event_id[index])
+                yield _messages.FunctionToolResultEvent(result, tool_call_id=call_index_to_event_id[index])
                 if isinstance(result, (_messages.ToolReturnPart, _messages.RetryPromptPart)):
                     results_by_index[index] = result
                 else:
@@ -685,3 +705,15 @@ def build_agent_graph(
         auto_instrument=False,
     )
     return graph
+
+
+def is_model_request_node(
+    node: BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]],
+) -> TypeGuard[ModelRequestNode[DepsT, NodeRunEndT]]:
+    return isinstance(node, ModelRequestNode)
+
+
+def is_handle_response_node(
+    node: BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]],
+) -> TypeGuard[HandleResponseNode[DepsT, NodeRunEndT]]:
+    return isinstance(node, HandleResponseNode)
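
These helpers return `TypeGuard`s rather than plain `bool`s so that static type checkers narrow the node union at the call site, which is what lets the docs example call `node.stream(...)` without a cast. A self-contained sketch of the same pattern with toy classes (not the library's real node types):

```python
from typing_extensions import TypeGuard


class Node:
    """Toy base class, standing in for pydantic_graph's BaseNode."""


class RequestNode(Node):
    def stream(self) -> str:
        return "streaming..."


def is_request_node(node: Node) -> TypeGuard[RequestNode]:
    # Returning True tells the type checker that `node` is a RequestNode
    return isinstance(node, RequestNode)


def handle(node: Node) -> None:
    if is_request_node(node):
        node.stream()  # narrowed: no cast or `type: ignore` needed
```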

pydantic_ai_slim/pydantic_ai/agent.py (+4, -1)

@@ -45,7 +45,8 @@
 HandleResponseNode = _agent_graph.HandleResponseNode
 ModelRequestNode = _agent_graph.ModelRequestNode
 UserPromptNode = _agent_graph.UserPromptNode
-
+is_handle_response_node = _agent_graph.is_handle_response_node
+is_model_request_node = _agent_graph.is_model_request_node
 
 __all__ = (
     'Agent',
@@ -56,6 +57,8 @@
     'HandleResponseNode',
     'ModelRequestNode',
     'UserPromptNode',
+    'is_handle_response_node',
+    'is_model_request_node',
 )
 
 _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')

pydantic_ai_slim/pydantic_ai/messages.py (+16, -1)

@@ -533,9 +533,24 @@ class PartDeltaEvent:
     """Event type identifier, used as a discriminator."""
 
 
+@dataclass
+class FinalResultEvent:
+    """An event indicating the response to the current model request matches the result schema."""
+
+    tool_name: str | None
+    """The name of the result tool that was called. `None` if the result is from text content and not from a tool."""
+    event_kind: Literal['final_result'] = 'final_result'
+    """Event type identifier, used as a discriminator."""
+
+
 ModelResponseStreamEvent = Annotated[Union[PartStartEvent, PartDeltaEvent], pydantic.Discriminator('event_kind')]
 """An event in the model response stream, either starting a new part or applying a delta to an existing one."""
 
+AgentStreamEvent = Annotated[
+    Union[PartStartEvent, PartDeltaEvent, FinalResultEvent], pydantic.Discriminator('event_kind')
+]
+"""An event in the agent stream."""
+
 
 @dataclass
 class FunctionToolCallEvent:
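
Because every event carries an `event_kind` literal, the `AgentStreamEvent` union can be validated with a pydantic `TypeAdapter`, the discriminator selecting the concrete event class. A quick sketch, assuming these event dataclasses validate like ordinary pydantic dataclasses:

```python
import pydantic

from pydantic_ai.messages import AgentStreamEvent, FinalResultEvent

adapter = pydantic.TypeAdapter(AgentStreamEvent)

# 'final_result' routes validation to FinalResultEvent via the discriminator
event = adapter.validate_python({'event_kind': 'final_result', 'tool_name': None})
assert isinstance(event, FinalResultEvent)
```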
@@ -558,7 +573,7 @@ class FunctionToolResultEvent:
 
     result: ToolReturnPart | RetryPromptPart
     """The result of the call to the function tool."""
-    call_id: str
+    tool_call_id: str
     """An ID used to match the result to its original call."""
     event_kind: Literal['function_tool_result'] = 'function_tool_result'
     """Event type identifier, used as a discriminator."""
