
Commit 6460a73: More cleanup

1 parent e799024

5 files changed: +100 -84 lines


pydantic_ai_slim/pydantic_ai/__init__.py

+5 -1

@@ -1,11 +1,15 @@
 from importlib.metadata import version
 
-from .agent import Agent, capture_run_messages
+from .agent import Agent, EndStrategy, HandleResponseNode, ModelRequestNode, UserPromptNode, capture_run_messages
 from .exceptions import AgentRunError, ModelRetry, UnexpectedModelBehavior, UsageLimitExceeded, UserError
 from .tools import RunContext, Tool
 
 __all__ = (
     'Agent',
+    'EndStrategy',
+    'HandleResponseNode',
+    'ModelRequestNode',
+    'UserPromptNode',
     'capture_run_messages',
     'RunContext',
     'Tool',
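With this change the graph node types resolve straight from the package root instead of the private `pydantic_ai._agent_graph` module. A minimal sketch of what now imports (the commented `isinstance` line is an illustrative assumption about how callers might use the node types, not part of this diff):

```python
# The node and strategy names now resolve at the package root.
from pydantic_ai import (
    Agent,
    EndStrategy,
    HandleResponseNode,
    ModelRequestNode,
    UserPromptNode,
    capture_run_messages,
)

# e.g. callers can type-check graph nodes without touching private modules:
# if isinstance(node, ModelRequestNode): ...
```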

pydantic_ai_slim/pydantic_ai/_agent_graph.py

+16 -9

@@ -32,6 +32,16 @@
     ToolDefinition,
 )
 
+__all__ = (
+    'GraphAgentState',
+    'GraphAgentDeps',
+    'UserPromptNode',
+    'ModelRequestNode',
+    'HandleResponseNode',
+    'build_run_context',
+    'capture_run_messages',
+)
+
 _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
 
 # while waiting for https://github.com/pydantic/logfire/issues/745

@@ -98,13 +108,18 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
 
 
 @dataclasses.dataclass
-class BaseUserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC):
+class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC):
     user_prompt: str
 
     system_prompts: tuple[str, ...]
     system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
     system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]]
 
+    async def run(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
+    ) -> ModelRequestNode[DepsT, NodeRunEndT]:
+        return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
+
     async def _get_first_message(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
    ) -> _messages.ModelRequest:

@@ -173,14 +188,6 @@ async def _sys_parts(self, run_context: RunContext[DepsT]) -> list[_messages.Mod
         return messages
 
 
-@dataclasses.dataclass
-class UserPromptNode(BaseUserPromptNode[DepsT, NodeRunEndT]):
-    async def run(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]
-    ) -> ModelRequestNode[DepsT, NodeRunEndT]:
-        return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
-
-
 async def _prepare_request_parameters(
     ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
 ) -> models.ModelRequestParameters:
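`capture_run_messages`, now listed in this module's `__all__`, is a context manager that records the messages exchanged during a run. A minimal usage sketch for reference; the `'test'` model string (which selects the built-in test model) and the prompt are illustrative:

```python
from pydantic_ai import Agent, capture_run_messages

agent = Agent('test')  # illustrative: 'test' selects the built-in test model

with capture_run_messages() as messages:
    result = agent.run_sync('What is 2 + 2?')

# `messages` now holds the ModelRequest/ModelResponse history of the run
print(messages)
```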

pydantic_ai_slim/pydantic_ai/agent.py

+18 -2

@@ -25,7 +25,6 @@
     result,
     usage as _usage,
 )
-from ._agent_graph import EndStrategy, capture_run_messages  # imported for re-export
 from .result import FinalResult, ResultDataT, StreamedRunResult
 from .settings import ModelSettings, merge_model_settings
 from .tools import (

@@ -40,7 +39,24 @@
     ToolPrepareFunc,
 )
 
-__all__ = 'Agent', 'AgentRun', 'AgentRunResult', 'capture_run_messages', 'EndStrategy'
+# Re-exporting like this improves auto-import behavior in PyCharm
+capture_run_messages = _agent_graph.capture_run_messages
+EndStrategy = _agent_graph.EndStrategy
+HandleResponseNode = _agent_graph.HandleResponseNode
+ModelRequestNode = _agent_graph.ModelRequestNode
+UserPromptNode = _agent_graph.UserPromptNode
+
+
+__all__ = (
+    'Agent',
+    'AgentRun',
+    'AgentRunResult',
+    'capture_run_messages',
+    'EndStrategy',
+    'HandleResponseNode',
+    'ModelRequestNode',
+    'UserPromptNode',
+)
 
 _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
 
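The switch from an import-based re-export to plain attribute assignment is purely a tooling concern: PyCharm's auto-import resolves names bound by assignment more reliably than names that merely pass through an `import` statement. Since `pydantic_ai/__init__.py` (above) re-imports these names from `.agent`, the same objects are reachable from both namespaces. A quick sanity check, assuming this commit's tree:

```python
import pydantic_ai
import pydantic_ai.agent

# The assignment-based re-exports bind the same objects in both namespaces.
assert pydantic_ai.EndStrategy is pydantic_ai.agent.EndStrategy
assert pydantic_ai.UserPromptNode is pydantic_ai.agent.UserPromptNode
assert pydantic_ai.capture_run_messages is pydantic_ai.agent.capture_run_messages
```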
pydantic_ai_slim/pydantic_ai/models/__init__.py

+2 -67

@@ -17,7 +17,6 @@
 import httpx
 from typing_extensions import Literal
 
-from .. import _utils, messages as _messages
 from .._parts_manager import ModelResponsePartsManager
 from ..exceptions import UserError
 from ..messages import ModelMessage, ModelResponse, ModelResponseStreamEvent

@@ -235,6 +234,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
 
         This method should be implemented by subclasses to translate the vendor-specific stream of events into
         pydantic_ai-format events.
+
+        It should use the `_parts_manager` to handle deltas, and should update the `_usage` attributes as it goes.
         """
         raise NotImplementedError()
         # noinspection PyUnreachableCode

@@ -262,72 +263,6 @@ def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
         raise NotImplementedError()
 
-    async def stream_debounced_events(
-        self, *, debounce_by: float | None = 0.1
-    ) -> AsyncIterator[list[ModelResponseStreamEvent]]:
-        """Stream the response as an async iterable of debounced lists of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s."""
-        async with _utils.group_by_temporal(self, debounce_by) as group_iter:
-            async for items in group_iter:
-                yield items
-
-    async def stream_structured(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[_messages.ModelResponse]:
-        """Stream the response as an async iterable of [`ModelResponse`][pydantic_ai.messages.ModelResponse]s."""
-
-        async def _stream_structured_ungrouped() -> AsyncIterator[None]:
-            # yield None  # TODO: Might want to yield right away to ensure we can eagerly emit a ModelResponse even if we are waiting
-            async for _event in self:
-                yield None
-
-        async with _utils.group_by_temporal(_stream_structured_ungrouped(), debounce_by) as group_iter:
-            async for _items in group_iter:
-                yield self.get()  # current state of the response
-
-    async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]:
-        """Stream the response as an async iterable of text."""
-
-        # Define a "merged" version of the iterator that will yield items that have already been retrieved
-        # and items that we receive while streaming. We define a dedicated async iterator for this so we can
-        # pass the combined stream to the group_by_temporal function within `_stream_text_deltas` below.
-        async def _stream_text_deltas_ungrouped() -> AsyncIterator[tuple[str, int]]:
-            # yields tuples of (text_content, part_index)
-            # we don't currently make use of the part_index, but in principle this may be useful
-            # so we retain it here for now to make possible future refactors simpler
-            msg = self.get()
-            for i, part in enumerate(msg.parts):
-                if isinstance(part, _messages.TextPart) and part.content:
-                    yield part.content, i
-
-            async for event in self:
-                if (
-                    isinstance(event, _messages.PartStartEvent)
-                    and isinstance(event.part, _messages.TextPart)
-                    and event.part.content
-                ):
-                    yield event.part.content, event.index
-                elif (
-                    isinstance(event, _messages.PartDeltaEvent)
-                    and isinstance(event.delta, _messages.TextPartDelta)
-                    and event.delta.content_delta
-                ):
-                    yield event.delta.content_delta, event.index
-
-        async def _stream_text_deltas() -> AsyncIterator[str]:
-            async with _utils.group_by_temporal(_stream_text_deltas_ungrouped(), debounce_by) as group_iter:
-                async for items in group_iter:
-                    # Note: we are currently just dropping the part index on the group here
-                    yield ''.join([content for content, _ in items])
-
-        if delta:
-            async for text in _stream_text_deltas():
-                yield text
-        else:
-            # a quick benchmark shows it's faster to build up a string with concat when we're
-            # yielding at each step
-            deltas: list[str] = []
-            async for text in _stream_text_deltas():
-                deltas.append(text)
-            yield ''.join(deltas)
-
 
 ALLOW_MODEL_REQUESTS = True
 """Whether to allow requests to models.

pydantic_ai_slim/pydantic_ai/result.py

+59 -5

@@ -9,7 +9,7 @@
 import logfire_api
 from typing_extensions import TypeVar
 
-from . import _result, exceptions, messages as _messages, models
+from . import _result, _utils, exceptions, messages as _messages, models
 from .tools import AgentDepsT, RunContext
 from .usage import Usage, UsageLimits
 

@@ -160,7 +160,6 @@ async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[Resu
         Returns:
             An async iterable of the response data.
         """
-        self._stream_response.stream_structured(debounce_by=debounce_by)
         async for structured_message, is_last in self.stream_structured(debounce_by=debounce_by):
             result = await self.validate_structured_result(structured_message, allow_partial=not is_last)
             yield result

@@ -183,11 +182,11 @@ async def stream_text(self, *, delta: bool = False, debounce_by: float | None =
 
         with _logfire.span('response stream text') as lf_span:
             if delta:
-                async for text in self._stream_response.stream_text(delta=delta, debounce_by=debounce_by):
+                async for text in self._stream_response_text(delta=delta, debounce_by=debounce_by):
                     yield text
             else:
                 combined_validated_text = ''
-                async for text in self._stream_response.stream_text(delta=delta, debounce_by=debounce_by):
+                async for text in self._stream_response_text(delta=delta, debounce_by=debounce_by):
                     combined_validated_text = await self._validate_text_result(text)
                     yield combined_validated_text
             lf_span.set_attribute('combined_text', combined_validated_text)

@@ -214,7 +213,7 @@ async def stream_structured(
                     yield msg, False
                     break
 
-            async for msg in self._stream_response.stream_structured(debounce_by=debounce_by):
+            async for msg in self._stream_response_structured(debounce_by=debounce_by):
                 yield msg, False
 
             msg = self._stream_response.get()

@@ -289,6 +288,61 @@ async def _marked_completed(self, message: _messages.ModelResponse) -> None:
         self._all_messages.append(message)
         await self._on_complete()
 
+    async def _stream_response_structured(
+        self, *, debounce_by: float | None = 0.1
+    ) -> AsyncIterator[_messages.ModelResponse]:
+        async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter:
+            async for _items in group_iter:
+                yield self._stream_response.get()
+
+    async def _stream_response_text(
+        self, *, delta: bool = False, debounce_by: float | None = 0.1
+    ) -> AsyncIterator[str]:
+        """Stream the response as an async iterable of text."""
+
+        # Define a "merged" version of the iterator that will yield items that have already been retrieved
+        # and items that we receive while streaming. We define a dedicated async iterator for this so we can
+        # pass the combined stream to the group_by_temporal function within `_stream_text_deltas` below.
+        async def _stream_text_deltas_ungrouped() -> AsyncIterator[tuple[str, int]]:
+            # yields tuples of (text_content, part_index)
+            # we don't currently make use of the part_index, but in principle this may be useful
+            # so we retain it here for now to make possible future refactors simpler
+            msg = self._stream_response.get()
+            for i, part in enumerate(msg.parts):
+                if isinstance(part, _messages.TextPart) and part.content:
+                    yield part.content, i
+
+            async for event in self._stream_response:
+                if (
+                    isinstance(event, _messages.PartStartEvent)
+                    and isinstance(event.part, _messages.TextPart)
+                    and event.part.content
+                ):
+                    yield event.part.content, event.index
+                elif (
+                    isinstance(event, _messages.PartDeltaEvent)
+                    and isinstance(event.delta, _messages.TextPartDelta)
+                    and event.delta.content_delta
+                ):
+                    yield event.delta.content_delta, event.index
+
+        async def _stream_text_deltas() -> AsyncIterator[str]:
+            async with _utils.group_by_temporal(_stream_text_deltas_ungrouped(), debounce_by) as group_iter:
+                async for items in group_iter:
+                    # Note: we are currently just dropping the part index on the group here
+                    yield ''.join([content for content, _ in items])
+
+        if delta:
+            async for text in _stream_text_deltas():
+                yield text
+        else:
+            # a quick benchmark shows it's faster to build up a string with concat when we're
+            # yielding at each step
+            deltas: list[str] = []
+            async for text in _stream_text_deltas():
+                deltas.append(text)
+            yield ''.join(deltas)
+
 
 @dataclass
 class FinalResult(Generic[ResultDataT]):
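End to end, the public behavior is unchanged: `StreamedRunResult.stream_text` still debounces via `group_by_temporal` (grouping events that arrive within `debounce_by` seconds; `None` disables grouping), it just now owns the loop instead of delegating to the removed `StreamedResponse.stream_text`. A usage sketch, with the `'test'` model string and prompt illustrative:

```python
import asyncio

from pydantic_ai import Agent

agent = Agent('test')  # illustrative: 'test' selects the built-in test model


async def main() -> None:
    async with agent.run_stream('Tell me a joke.') as result:
        # delta=True yields only the newly received text per debounced group;
        # the default (delta=False) yields the accumulated, validated text.
        async for chunk in result.stream_text(delta=True, debounce_by=0.1):
            print(chunk, end='', flush=True)


asyncio.run(main())
```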
