|
2 | 2 |
|
3 | 3 | import datetime
|
4 | 4 | import json
|
| 5 | +import re |
5 | 6 | from collections.abc import AsyncIterator
|
| 7 | +from copy import deepcopy |
6 | 8 | from datetime import timezone
|
| 9 | +from typing import Union |
7 | 10 |
|
8 | 11 | import pytest
|
9 | 12 | from inline_snapshot import snapshot
|
10 | 13 | from pydantic import BaseModel
|
11 | 14 |
|
12 | 15 | from pydantic_ai import Agent, UnexpectedModelBehavior, UserError, capture_run_messages
|
| 16 | +from pydantic_ai.agent import AgentRun |
13 | 17 | from pydantic_ai.messages import (
|
14 | 18 | ModelMessage,
|
15 | 19 | ModelRequest,
|
|
22 | 26 | )
|
23 | 27 | from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel
|
24 | 28 | from pydantic_ai.models.test import TestModel
|
25 |
| -from pydantic_ai.result import Usage |
| 29 | +from pydantic_ai.result import AgentStream, FinalResult, Usage |
| 30 | +from pydantic_graph import End |
26 | 31 |
|
27 | 32 | from .conftest import IsNow
|
28 | 33 |
|
@@ -739,3 +744,109 @@ async def test_custom_result_type_default_structured() -> None:
|
739 | 744 | async with agent.run_stream('test', result_type=str) as result:
|
740 | 745 | response = await result.get_data()
|
741 | 746 | assert response == snapshot('success (no tool calls)')
|
| 747 | + |
| 748 | + |
| 749 | +async def test_iter_stream_output(): |
| 750 | + m = TestModel(custom_result_text='The cat sat on the mat.') |
| 751 | + |
| 752 | + agent = Agent(m) |
| 753 | + |
| 754 | + @agent.result_validator |
| 755 | + def result_validator_simple(data: str) -> str: |
| 756 | + # Make a substitution in the validated results |
| 757 | + return re.sub('cat sat', 'bat sat', data) |
| 758 | + |
| 759 | + run: AgentRun |
| 760 | + stream: AgentStream |
| 761 | + messages: list[str] = [] |
| 762 | + |
| 763 | + stream_usage: Usage | None = None |
| 764 | + with agent.iter('Hello') as run: |
| 765 | + async for node in run: |
| 766 | + if agent.is_model_request_node(node): |
| 767 | + async with node.stream(run.ctx) as stream: |
| 768 | + async for chunk in stream.stream_output(debounce_by=None): |
| 769 | + messages.append(chunk) |
| 770 | + stream_usage = deepcopy(stream.usage()) |
| 771 | + assert run.next_node == End(data=FinalResult(data='The bat sat on the mat.', tool_name=None)) |
| 772 | + assert ( |
| 773 | + run.usage() |
| 774 | + == stream_usage |
| 775 | + == Usage(requests=1, request_tokens=51, response_tokens=7, total_tokens=58, details=None) |
| 776 | + ) |
| 777 | + |
| 778 | + assert messages == [ |
| 779 | + '', |
| 780 | + 'The ', |
| 781 | + 'The cat ', |
| 782 | + 'The bat sat ', |
| 783 | + 'The bat sat on ', |
| 784 | + 'The bat sat on the ', |
| 785 | + 'The bat sat on the mat.', |
| 786 | + 'The bat sat on the mat.', |
| 787 | + ] |
| 788 | + |
| 789 | + |
| 790 | +async def test_iter_stream_responses(): |
| 791 | + m = TestModel(custom_result_text='The cat sat on the mat.') |
| 792 | + |
| 793 | + agent = Agent(m) |
| 794 | + |
| 795 | + @agent.result_validator |
| 796 | + def result_validator_simple(data: str) -> str: |
| 797 | + # Make a substitution in the validated results |
| 798 | + return re.sub('cat sat', 'bat sat', data) |
| 799 | + |
| 800 | + run: AgentRun |
| 801 | + stream: AgentStream |
| 802 | + messages: list[ModelResponse] = [] |
| 803 | + with agent.iter('Hello') as run: |
| 804 | + async for node in run: |
| 805 | + if agent.is_model_request_node(node): |
| 806 | + async with node.stream(run.ctx) as stream: |
| 807 | + async for chunk in stream.stream_responses(debounce_by=None): |
| 808 | + messages.append(chunk) |
| 809 | + |
| 810 | + assert messages == [ |
| 811 | + ModelResponse( |
| 812 | + parts=[TextPart(content=text, part_kind='text')], |
| 813 | + model_name='test', |
| 814 | + timestamp=IsNow(tz=timezone.utc), |
| 815 | + kind='response', |
| 816 | + ) |
| 817 | + for text in [ |
| 818 | + '', |
| 819 | + '', |
| 820 | + 'The ', |
| 821 | + 'The cat ', |
| 822 | + 'The cat sat ', |
| 823 | + 'The cat sat on ', |
| 824 | + 'The cat sat on the ', |
| 825 | + 'The cat sat on the mat.', |
| 826 | + ] |
| 827 | + ] |
| 828 | + |
| 829 | + # Note: as you can see above, the result validator is not applied to the streamed responses, just the final result: |
| 830 | + assert run.result is not None |
| 831 | + assert run.result.data == 'The bat sat on the mat.' |
| 832 | + |
| 833 | + |
| 834 | +async def test_stream_iter_structured_validator() -> None: |
| 835 | + class NotResultType(BaseModel): |
| 836 | + not_value: str |
| 837 | + |
| 838 | + agent = Agent[None, ResultType | NotResultType]('test', result_type=Union[ResultType, NotResultType]) # pyright: ignore[reportArgumentType] |
| 839 | + |
| 840 | + @agent.result_validator |
| 841 | + def result_validator(data: ResultType | NotResultType) -> ResultType | NotResultType: |
| 842 | + assert isinstance(data, ResultType) |
| 843 | + return ResultType(value=data.value + ' (validated)') |
| 844 | + |
| 845 | + outputs: list[ResultType] = [] |
| 846 | + with agent.iter('test') as run: |
| 847 | + async for node in run: |
| 848 | + if agent.is_model_request_node(node): |
| 849 | + async with node.stream(run.ctx) as stream: |
| 850 | + async for output in stream.stream_output(debounce_by=None): |
| 851 | + outputs.append(output) |
| 852 | + assert outputs == [ResultType(value='a (validated)'), ResultType(value='a (validated)')] |
0 commit comments