9
9
import logfire_api
10
10
from typing_extensions import TypeVar
11
11
12
- from . import _result , exceptions , messages as _messages , models
12
+ from . import _result , _utils , exceptions , messages as _messages , models
13
13
from .tools import AgentDepsT , RunContext
14
14
from .usage import Usage , UsageLimits
15
15
@@ -160,7 +160,6 @@ async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[Resu
160
160
Returns:
161
161
An async iterable of the response data.
162
162
"""
163
- self ._stream_response .stream_structured (debounce_by = debounce_by )
164
163
async for structured_message , is_last in self .stream_structured (debounce_by = debounce_by ):
165
164
result = await self .validate_structured_result (structured_message , allow_partial = not is_last )
166
165
yield result
@@ -183,11 +182,11 @@ async def stream_text(self, *, delta: bool = False, debounce_by: float | None =
183
182
184
183
with _logfire .span ('response stream text' ) as lf_span :
185
184
if delta :
186
- async for text in self ._stream_response . stream_text (delta = delta , debounce_by = debounce_by ):
185
+ async for text in self ._stream_response_text (delta = delta , debounce_by = debounce_by ):
187
186
yield text
188
187
else :
189
188
combined_validated_text = ''
190
- async for text in self ._stream_response . stream_text (delta = delta , debounce_by = debounce_by ):
189
+ async for text in self ._stream_response_text (delta = delta , debounce_by = debounce_by ):
191
190
combined_validated_text = await self ._validate_text_result (text )
192
191
yield combined_validated_text
193
192
lf_span .set_attribute ('combined_text' , combined_validated_text )
@@ -214,7 +213,7 @@ async def stream_structured(
214
213
yield msg , False
215
214
break
216
215
217
- async for msg in self ._stream_response . stream_structured (debounce_by = debounce_by ):
216
+ async for msg in self ._stream_response_structured (debounce_by = debounce_by ):
218
217
yield msg , False
219
218
220
219
msg = self ._stream_response .get ()
@@ -289,6 +288,61 @@ async def _marked_completed(self, message: _messages.ModelResponse) -> None:
289
288
self ._all_messages .append (message )
290
289
await self ._on_complete ()
291
290
291
+ async def _stream_response_structured (
292
+ self , * , debounce_by : float | None = 0.1
293
+ ) -> AsyncIterator [_messages .ModelResponse ]:
294
+ async with _utils .group_by_temporal (self ._stream_response , debounce_by ) as group_iter :
295
+ async for _items in group_iter :
296
+ yield self ._stream_response .get ()
297
+
298
+ async def _stream_response_text (
299
+ self , * , delta : bool = False , debounce_by : float | None = 0.1
300
+ ) -> AsyncIterator [str ]:
301
+ """Stream the response as an async iterable of text."""
302
+
303
+ # Define a "merged" version of the iterator that will yield items that have already been retrieved
304
+ # and items that we receive while streaming. We define a dedicated async iterator for this so we can
305
+ # pass the combined stream to the group_by_temporal function within `_stream_text_deltas` below.
306
+ async def _stream_text_deltas_ungrouped () -> AsyncIterator [tuple [str , int ]]:
307
+ # yields tuples of (text_content, part_index)
308
+ # we don't currently make use of the part_index, but in principle this may be useful
309
+ # so we retain it here for now to make possible future refactors simpler
310
+ msg = self ._stream_response .get ()
311
+ for i , part in enumerate (msg .parts ):
312
+ if isinstance (part , _messages .TextPart ) and part .content :
313
+ yield part .content , i
314
+
315
+ async for event in self ._stream_response :
316
+ if (
317
+ isinstance (event , _messages .PartStartEvent )
318
+ and isinstance (event .part , _messages .TextPart )
319
+ and event .part .content
320
+ ):
321
+ yield event .part .content , event .index
322
+ elif (
323
+ isinstance (event , _messages .PartDeltaEvent )
324
+ and isinstance (event .delta , _messages .TextPartDelta )
325
+ and event .delta .content_delta
326
+ ):
327
+ yield event .delta .content_delta , event .index
328
+
329
+ async def _stream_text_deltas () -> AsyncIterator [str ]:
330
+ async with _utils .group_by_temporal (_stream_text_deltas_ungrouped (), debounce_by ) as group_iter :
331
+ async for items in group_iter :
332
+ # Note: we are currently just dropping the part index on the group here
333
+ yield '' .join ([content for content , _ in items ])
334
+
335
+ if delta :
336
+ async for text in _stream_text_deltas ():
337
+ yield text
338
+ else :
339
+ # a quick benchmark shows it's faster to build up a string with concat when we're
340
+ # yielding at each step
341
+ deltas : list [str ] = []
342
+ async for text in _stream_text_deltas ():
343
+ deltas .append (text )
344
+ yield '' .join (deltas )
345
+
292
346
293
347
@dataclass
294
348
class FinalResult (Generic [ResultDataT ]):
0 commit comments