Skip to content

Commit 55ab7f0

Browse files
committed
Use .iter() API to fully replace existing streaming implementation
1 parent 8bf5bee commit 55ab7f0

File tree

3 files changed

+155
-2
lines changed

3 files changed

+155
-2
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -233,12 +233,30 @@ async def run(
233233

234234
return await self._make_request(ctx)
235235

236+
@asynccontextmanager
async def stream(
    self,
    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
) -> AsyncIterator[result.AgentStream[DepsT, T]]:
    """Open the streamed model response for this node and expose it as a `result.AgentStream`.

    The `AgentStream` is yielded to the caller's `async with` body. Once that body finishes,
    whatever remains of the stream is drained before the underlying response is released.
    """
    async with self._stream(ctx) as streamed_response:
        deps = ctx.deps
        agent_stream = result.AgentStream[DepsT, T](
            streamed_response,
            deps.result_schema,
            deps.result_validators,
            build_run_context(ctx),
            deps.usage_limits,
        )
        yield agent_stream
        # The caller may have stopped reading early; exhaust the remaining events here,
        # otherwise usage won't be properly counted.
        async for _event in agent_stream:
            continue
254+
236255
@asynccontextmanager
237256
async def _stream(
238257
self,
239258
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
240259
) -> AsyncIterator[models.StreamedResponse]:
241-
# TODO: Consider changing this to return something more similar to a `StreamedRunResult`, then make it public
242260
assert not self._did_stream, 'stream() should only be called once per node'
243261

244262
model_settings, model_request_parameters = await self._prepare_request(ctx)

pydantic_ai_slim/pydantic_ai/messages.py

+15
Original file line numberDiff line numberDiff line change
@@ -444,9 +444,24 @@ class PartDeltaEvent:
444444
"""Event type identifier, used as a discriminator."""
445445

446446

447+
@dataclass
class FinalResultEvent:
    """An event indicating the response to the current model request matches the result schema.

    `tool_name` has no default and must be supplied; `event_kind` defaults to the fixed
    literal `'final_result'`, which is the value the `event_kind` discriminator of the
    event unions in this module matches on.
    """

    tool_name: str | None
    """The name of the result tool that was called. `None` if the result is from text content and not from a tool."""
    event_kind: Literal['final_result'] = 'final_result'
    """Event type identifier, used as a discriminator."""
455+
456+
447457
# Union of the two part-level events, discriminated by each member's `event_kind` field.
ModelResponseStreamEvent = Annotated[Union[PartStartEvent, PartDeltaEvent], pydantic.Discriminator('event_kind')]
"""An event in the model response stream, either starting a new part or applying a delta to an existing one."""

# Same two part-level events plus `FinalResultEvent`, again discriminated by `event_kind`.
AgentStreamEvent = Annotated[
    Union[PartStartEvent, PartDeltaEvent, FinalResultEvent], pydantic.Discriminator('event_kind')
]
"""An event in the agent stream."""
464+
450465

451466
@dataclass
452467
class FunctionToolCallEvent:

pydantic_ai_slim/pydantic_ai/result.py

+121-1
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
from typing import Generic, Union, cast
88

99
import logfire_api
10-
from typing_extensions import TypeVar
10+
from typing_extensions import TypeVar, assert_type
1111

1212
from . import _result, _utils, exceptions, messages as _messages, models
13+
from .messages import AgentStreamEvent, FinalResultEvent
1314
from .tools import AgentDepsT, RunContext
1415
from .usage import Usage, UsageLimits
1516

@@ -51,6 +52,125 @@
5152
_logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
5253

5354

55+
@dataclass
class AgentStream(Generic[AgentDepsT, ResultDataT]):
    """Event stream over a single streamed model response.

    Async-iterating an instance yields every event of the wrapped response and, at most once,
    a `FinalResultEvent` right after the first part that matches the result schema starts.
    """

    # Positional constructor arguments -- order matters to callers.
    _raw_stream_response: models.StreamedResponse
    _result_schema: _result.ResultSchema[ResultDataT] | None
    _result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]]
    _run_ctx: RunContext[AgentDepsT]
    _usage_limits: UsageLimits | None

    # Internal state, never passed to the constructor.
    _agent_stream_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False)
    _final_result_event: FinalResultEvent | None = field(default=None, init=False)
    _initial_run_ctx_usage: Usage = field(init=False)

    def __post_init__(self):
        # Snapshot the run context's usage at construction time; `usage()` adds this response's usage to it.
        self._initial_run_ctx_usage = copy(self._run_ctx.usage)

    def __aiter__(self) -> AsyncIterator[AgentStreamEvent]:
        """Stream [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s.

        Every event of the raw stream response is forwarded; in addition, a
        [`FinalResultEvent`][pydantic_ai.messages.FinalResultEvent] is emitted if/when the first match on the
        result schema is found. The iterator is created once and the same object is returned on later calls.
        """
        if self._agent_stream_iterator is None:
            self._agent_stream_iterator = self._iter_events()
        return self._agent_stream_iterator

    async def _iter_events(self) -> AsyncIterator[AgentStreamEvent]:
        """Generate the events served by `__aiter__`."""
        schema = self._result_schema
        text_allowed = schema is None or schema.allow_text_result

        def detect_final_result(event: _messages.ModelResponseStreamEvent) -> FinalResultEvent | None:
            """Return a `FinalResultEvent` if `event` starts a part that will produce a final result."""
            if not isinstance(event, _messages.PartStartEvent):
                return None
            part = event.part
            if isinstance(part, _messages.ToolCallPart):
                if schema is not None and (match := schema.find_tool([part])):
                    call, _ = match
                    return FinalResultEvent(tool_name=call.tool_name)
                return None
            if text_allowed:
                assert_type(event, _messages.PartStartEvent)
                return FinalResultEvent(tool_name=None)
            return None

        events = _get_usage_checking_stream_response(self._raw_stream_response, self._usage_limits, self.usage)

        # Phase one: forward events until the first one that yields a final result.
        async for event in events:
            yield event
            final_result = detect_final_result(event)
            if final_result is None:
                continue
            self._final_result_event = final_result
            yield final_result
            break

        # Phase two: forward whatever is left. If phase one ran to exhaustion this is a no-op.
        async for event in events:
            yield event

    async def stream_output(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[ResultDataT]:
        """Asynchronously stream the (validated) agent outputs.

        Responses seen before a final result has been detected are skipped. Intermediate responses are
        validated with `allow_partial=True`; the last item is a full validation of the finished response.
        """
        async for response in self.stream_responses(debounce_by=debounce_by):
            final_event = self._final_result_event
            if final_event is None:
                continue
            yield await self._validate_response(response, final_event.tool_name, allow_partial=True)

        final_event = self._final_result_event
        if final_event is not None:
            yield await self._validate_response(
                self._raw_stream_response.get(), final_event.tool_name, allow_partial=False
            )

    async def stream_responses(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[_messages.ModelResponse]:
        """Asynchronously stream the (unvalidated) model responses for the agent."""
        # Emit the current message up front when any of its parts already has content.
        current = self._raw_stream_response.get()
        if any(part.has_content() for part in current.parts):
            yield current

        async with _utils.group_by_temporal(self, debounce_by) as grouped:
            async for _ in grouped:
                # The grouped events themselves are ignored; yield the response's current state instead.
                yield self._raw_stream_response.get()

    def usage(self) -> Usage:
        """Return the usage of the whole run.

        !!! note
            This won't return the full usage until the stream is finished.
        """
        response_usage = self._raw_stream_response.usage()
        return self._initial_run_ctx_usage + response_usage

    async def _validate_response(
        self, message: _messages.ModelResponse, result_tool_name: str | None, *, allow_partial: bool = False
    ) -> ResultDataT:
        """Validate a structured result message.

        With a result schema and a tool name, the named tool call is validated and then passed through each
        result validator in turn. Otherwise the text parts are joined and passed through the validators.

        Raises:
            exceptions.UnexpectedModelBehavior: if the named result tool is not found among the message parts.
        """
        schema = self._result_schema
        if schema is None or result_tool_name is None:
            text = '\n\n'.join(part.content for part in message.parts if isinstance(part, _messages.TextPart))
            for validator in self._result_validators:
                text = await validator.validate(text, None, self._run_ctx)
            # Since there is no result tool, we can assume that str is compatible with ResultDataT
            return cast(ResultDataT, text)

        found = schema.find_named_tool(message.parts, result_tool_name)
        if found is None:
            raise exceptions.UnexpectedModelBehavior(f'Invalid response, unable to find tool: {schema.tool_names()}')

        call, result_tool = found
        result_data = result_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False)
        for validator in self._result_validators:
            result_data = await validator.validate(result_data, call, self._run_ctx)
        return result_data
172+
173+
54174
@dataclass
55175
class StreamedRunResult(Generic[AgentDepsT, ResultDataT]):
56176
"""Result of a streamed run that returns structured data via a tool call."""

0 commit comments

Comments
 (0)