Skip to content

Commit 54aff03

Browse files
committed
Update coverage
1 parent 09b4319 commit 54aff03

File tree

2 files changed

+113
-2
lines changed

2 files changed

+113
-2
lines changed

pydantic_ai_slim/pydantic_ai/agent.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1293,7 +1293,7 @@ def next_node(
12931293
return next_node
12941294
if _agent_graph.is_agent_node(next_node):
12951295
return next_node
1296-
raise exceptions.AgentRunError(f'Unexpected node type: {type(next_node)}')
1296+
raise exceptions.AgentRunError(f'Unexpected node type: {type(next_node)}') # pragma: no cover
12971297

12981298
@property
12991299
def result(self) -> AgentRunResult[ResultDataT] | None:

tests/test_streaming.py

+112-1
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,18 @@
22

33
import datetime
44
import json
5+
import re
56
from collections.abc import AsyncIterator
7+
from copy import deepcopy
68
from datetime import timezone
9+
from typing import Union
710

811
import pytest
912
from inline_snapshot import snapshot
1013
from pydantic import BaseModel
1114

1215
from pydantic_ai import Agent, UnexpectedModelBehavior, UserError, capture_run_messages
16+
from pydantic_ai.agent import AgentRun
1317
from pydantic_ai.messages import (
1418
ModelMessage,
1519
ModelRequest,
@@ -22,7 +26,8 @@
2226
)
2327
from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel
2428
from pydantic_ai.models.test import TestModel
25-
from pydantic_ai.result import Usage
29+
from pydantic_ai.result import AgentStream, FinalResult, Usage
30+
from pydantic_graph import End
2631

2732
from .conftest import IsNow
2833

@@ -739,3 +744,109 @@ async def test_custom_result_type_default_structured() -> None:
739744
async with agent.run_stream('test', result_type=str) as result:
740745
response = await result.get_data()
741746
assert response == snapshot('success (no tool calls)')
747+
748+
749+
async def test_iter_stream_output():
750+
m = TestModel(custom_result_text='The cat sat on the mat.')
751+
752+
agent = Agent(m)
753+
754+
@agent.result_validator
755+
def result_validator_simple(data: str) -> str:
756+
# Make a substitution in the validated results
757+
return re.sub('cat sat', 'bat sat', data)
758+
759+
run: AgentRun
760+
stream: AgentStream
761+
messages: list[str] = []
762+
763+
stream_usage: Usage | None = None
764+
with agent.iter('Hello') as run:
765+
async for node in run:
766+
if agent.is_model_request_node(node):
767+
async with node.stream(run.ctx) as stream:
768+
async for chunk in stream.stream_output(debounce_by=None):
769+
messages.append(chunk)
770+
stream_usage = deepcopy(stream.usage())
771+
assert run.next_node == End(data=FinalResult(data='The bat sat on the mat.', tool_name=None))
772+
assert (
773+
run.usage()
774+
== stream_usage
775+
== Usage(requests=1, request_tokens=51, response_tokens=7, total_tokens=58, details=None)
776+
)
777+
778+
assert messages == [
779+
'',
780+
'The ',
781+
'The cat ',
782+
'The bat sat ',
783+
'The bat sat on ',
784+
'The bat sat on the ',
785+
'The bat sat on the mat.',
786+
'The bat sat on the mat.',
787+
]
788+
789+
790+
async def test_iter_stream_responses():
791+
m = TestModel(custom_result_text='The cat sat on the mat.')
792+
793+
agent = Agent(m)
794+
795+
@agent.result_validator
796+
def result_validator_simple(data: str) -> str:
797+
# Make a substitution in the validated results
798+
return re.sub('cat sat', 'bat sat', data)
799+
800+
run: AgentRun
801+
stream: AgentStream
802+
messages: list[ModelResponse] = []
803+
with agent.iter('Hello') as run:
804+
async for node in run:
805+
if agent.is_model_request_node(node):
806+
async with node.stream(run.ctx) as stream:
807+
async for chunk in stream.stream_responses(debounce_by=None):
808+
messages.append(chunk)
809+
810+
assert messages == [
811+
ModelResponse(
812+
parts=[TextPart(content=text, part_kind='text')],
813+
model_name='test',
814+
timestamp=IsNow(tz=timezone.utc),
815+
kind='response',
816+
)
817+
for text in [
818+
'',
819+
'',
820+
'The ',
821+
'The cat ',
822+
'The cat sat ',
823+
'The cat sat on ',
824+
'The cat sat on the ',
825+
'The cat sat on the mat.',
826+
]
827+
]
828+
829+
# Note: as you can see above, the result validator is not applied to the streamed responses, just the final result:
830+
assert run.result is not None
831+
assert run.result.data == 'The bat sat on the mat.'
832+
833+
834+
async def test_stream_iter_structured_validator() -> None:
835+
class NotResultType(BaseModel):
836+
not_value: str
837+
838+
agent = Agent[None, ResultType | NotResultType]('test', result_type=Union[ResultType, NotResultType]) # pyright: ignore[reportArgumentType]
839+
840+
@agent.result_validator
841+
def result_validator(data: ResultType | NotResultType) -> ResultType | NotResultType:
842+
assert isinstance(data, ResultType)
843+
return ResultType(value=data.value + ' (validated)')
844+
845+
outputs: list[ResultType] = []
846+
with agent.iter('test') as run:
847+
async for node in run:
848+
if agent.is_model_request_node(node):
849+
async with node.stream(run.ctx) as stream:
850+
async for output in stream.stream_output(debounce_by=None):
851+
outputs.append(output)
852+
assert outputs == [ResultType(value='a (validated)'), ResultType(value='a (validated)')]

0 commit comments

Comments
 (0)