|
7 | 7 | from typing import Generic, Union, cast
|
8 | 8 |
|
9 | 9 | import logfire_api
|
10 |
| -from typing_extensions import TypeVar |
| 10 | +from typing_extensions import TypeVar, assert_type |
11 | 11 |
|
12 | 12 | from . import _result, _utils, exceptions, messages as _messages, models
|
| 13 | +from .messages import AgentStreamEvent, FinalResultEvent |
13 | 14 | from .tools import AgentDepsT, RunContext
|
14 | 15 | from .usage import Usage, UsageLimits
|
15 | 16 |
|
|
51 | 52 | _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
|
52 | 53 |
|
53 | 54 |
|
@dataclass
class AgentStream(Generic[AgentDepsT, ResultDataT]):
    """Streaming view over a single streamed model response within an agent run.

    Iterating an instance yields agent stream events; `stream_responses` and `stream_output`
    build on that iteration to yield raw model responses and validated outputs respectively.
    """

    # State supplied by the constructor.
    _raw_stream_response: models.StreamedResponse
    _result_schema: _result.ResultSchema[ResultDataT] | None
    _result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]]
    _run_ctx: RunContext[AgentDepsT]
    _usage_limits: UsageLimits | None

    # Internal state, excluded from `__init__`.
    _agent_stream_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False)
    _final_result_event: FinalResultEvent | None = field(default=None, init=False)
    _initial_run_ctx_usage: Usage = field(init=False)

    def __post_init__(self):
        """Snapshot the run context's usage as it stands when the stream is created."""
        # A copy is stored so `usage()` adds this stream's usage to a fixed baseline.
        self._initial_run_ctx_usage = copy(self._run_ctx.usage)

    async def stream_output(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[ResultDataT]:
        """Asynchronously stream the validated agent outputs.

        Responses seen before a final result event has been recorded are skipped. Intermediate
        responses are validated with `allow_partial=True`; once the underlying stream ends, the
        complete response is validated one more time with `allow_partial=False`.

        Args:
            debounce_by: Forwarded unchanged to `stream_responses`.
        """
        async for partial_response in self.stream_responses(debounce_by=debounce_by):
            # Re-read on every iteration: the event is set as a side effect of iterating `self`.
            final_event = self._final_result_event
            if final_event is not None:
                yield await self._validate_response(partial_response, final_event.tool_name, allow_partial=True)

        final_event = self._final_result_event
        if final_event is not None:
            yield await self._validate_response(
                self._raw_stream_response.get(), final_event.tool_name, allow_partial=False
            )

    async def stream_responses(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[_messages.ModelResponse]:
        """Asynchronously stream the unvalidated model responses for the agent.

        Args:
            debounce_by: Forwarded unchanged to `_utils.group_by_temporal`, which groups the
                events produced by iterating this stream.
        """
        # If the response already holds a part with content, emit it once before streaming.
        snapshot = self._raw_stream_response.get()
        if any(part.has_content() for part in snapshot.parts):
            yield snapshot

        async with _utils.group_by_temporal(self, debounce_by) as grouped_events:
            async for _group in grouped_events:
                # The grouped events themselves are not used; yield the response's current state.
                yield self._raw_stream_response.get()

    def usage(self) -> Usage:
        """Return the usage of the whole run.

        !!! note
            This won't return the full usage until the stream is finished.
        """
        baseline = self._initial_run_ctx_usage
        return baseline + self._raw_stream_response.usage()

    async def _validate_response(
        self, message: _messages.ModelResponse, result_tool_name: str | None, *, allow_partial: bool = False
    ) -> ResultDataT:
        """Validate a response as either plain text or a structured result tool call.

        Args:
            message: The model response to validate.
            result_tool_name: Name of the result tool to look for, or `None` for a text result.
            allow_partial: Passed to the result tool's `validate` for structured results.

        Raises:
            exceptions.UnexpectedModelBehavior: If a result tool name is given but no matching
                tool call is found in the message parts.
        """
        schema = self._result_schema
        if schema is None or result_tool_name is None:
            # Text result: join all text parts, then run every result validator over the text.
            text_chunks = [p.content for p in message.parts if isinstance(p, _messages.TextPart)]
            text = '\n\n'.join(text_chunks)
            for validator in self._result_validators:
                text = await validator.validate(text, None, self._run_ctx)
            # Since there is no result tool, we can assume that str is compatible with ResultDataT
            return cast(ResultDataT, text)

        found = schema.find_named_tool(message.parts, result_tool_name)
        if found is None:
            raise exceptions.UnexpectedModelBehavior(
                f'Invalid response, unable to find tool: {schema.tool_names()}'
            )

        call, result_tool = found
        result_data = result_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False)
        for validator in self._result_validators:
            result_data = await validator.validate(result_data, call, self._run_ctx)
        return result_data

    def __aiter__(self) -> AsyncIterator[AgentStreamEvent]:
        """Stream [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s.

        This proxies the _raw_stream_response and sends all events to the agent stream, while also checking for matches
        on the result schema and emitting a [`FinalResultEvent`][pydantic_ai.messages.FinalResultEvent] if/when the
        first match is found.
        """
        # The iterator is created once and reused by every later call to `__aiter__`.
        if self._agent_stream_iterator is not None:
            return self._agent_stream_iterator

        async def _iterate_events():
            schema = self._result_schema
            text_allowed = schema is None or schema.allow_text_result

            def _get_final_result_event(e: _messages.ModelResponseStreamEvent) -> _messages.FinalResultEvent | None:
                """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result."""
                if not isinstance(e, _messages.PartStartEvent):
                    return None

                started_part = e.part
                if isinstance(started_part, _messages.ToolCallPart):
                    # A tool call is only final when it matches a tool on the result schema.
                    if schema is None:
                        return None
                    tool_match = schema.find_tool([started_part])
                    if not tool_match:
                        return None
                    matched_call, _ = tool_match
                    return _messages.FinalResultEvent(tool_name=matched_call.tool_name)

                if text_allowed:
                    assert_type(e, _messages.PartStartEvent)
                    return _messages.FinalResultEvent(tool_name=None)
                return None

            checked_stream = _get_usage_checking_stream_response(
                self._raw_stream_response, self._usage_limits, self.usage
            )

            # Phase 1: forward events until the first one that yields a final result event.
            async for event in checked_stream:
                yield event
                detected = _get_final_result_event(event)
                if detected is not None:
                    self._final_result_event = detected
                    yield detected
                    break

            # Phase 2: forward whatever remains without checking again.
            # If phase 1 already exhausted the stream, this loop is a no-op.
            async for event in checked_stream:
                yield event

        self._agent_stream_iterator = _iterate_events()
        return self._agent_stream_iterator
| 172 | + |
| 173 | + |
54 | 174 | @dataclass
|
55 | 175 | class StreamedRunResult(Generic[AgentDepsT, ResultDataT]):
|
56 | 176 | """Result of a streamed run that returns structured data via a tool call."""
|
|
0 commit comments