Skip to content

Commit f5d2a5c

Browse files
committed
Improve model names in tests
1 parent 57df786 commit f5d2a5c

File tree

7 files changed

+60
-27
lines changed

7 files changed

+60
-27
lines changed

Diff for: docs/agents.md

+4-4
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ async def main():
136136
HandleResponseNode(
137137
model_response=ModelResponse(
138138
parts=[TextPart(content='Paris', part_kind='text')],
139-
model_name='function:model_logic',
139+
model_name='gpt-4o',
140140
timestamp=datetime.datetime(...),
141141
kind='response',
142142
)
@@ -197,7 +197,7 @@ async def main():
197197
HandleResponseNode(
198198
model_response=ModelResponse(
199199
parts=[TextPart(content='Paris', part_kind='text')],
200-
model_name='function:model_logic',
200+
model_name='gpt-4o',
201201
timestamp=datetime.datetime(...),
202202
kind='response',
203203
)
@@ -612,7 +612,7 @@ with capture_run_messages() as messages: # (2)!
612612
part_kind='tool-call',
613613
)
614614
],
615-
model_name='function:model_logic',
615+
model_name='gpt-4o',
616616
timestamp=datetime.datetime(...),
617617
kind='response',
618618
),
@@ -637,7 +637,7 @@ with capture_run_messages() as messages: # (2)!
637637
part_kind='tool-call',
638638
)
639639
],
640-
model_name='function:model_logic',
640+
model_name='gpt-4o',
641641
timestamp=datetime.datetime(...),
642642
kind='response',
643643
),

Diff for: docs/message-history.md

+5-5
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ print(result.all_messages())
6262
part_kind='text',
6363
)
6464
],
65-
model_name='function:model_logic',
65+
model_name='gpt-4o',
6666
timestamp=datetime.datetime(...),
6767
kind='response',
6868
),
@@ -193,7 +193,7 @@ print(result2.all_messages())
193193
part_kind='text',
194194
)
195195
],
196-
model_name='function:model_logic',
196+
model_name='gpt-4o',
197197
timestamp=datetime.datetime(...),
198198
kind='response',
199199
),
@@ -214,7 +214,7 @@ print(result2.all_messages())
214214
part_kind='text',
215215
)
216216
],
217-
model_name='function:model_logic',
217+
model_name='gpt-4o',
218218
timestamp=datetime.datetime(...),
219219
kind='response',
220220
),
@@ -273,7 +273,7 @@ print(result2.all_messages())
273273
part_kind='text',
274274
)
275275
],
276-
model_name='function:model_logic',
276+
model_name='gpt-4o',
277277
timestamp=datetime.datetime(...),
278278
kind='response',
279279
),
@@ -294,7 +294,7 @@ print(result2.all_messages())
294294
part_kind='text',
295295
)
296296
],
297-
model_name='function:model_logic',
297+
model_name='gemini-1.5-pro',
298298
timestamp=datetime.datetime(...),
299299
kind='response',
300300
),

Diff for: docs/models.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -694,15 +694,15 @@ print(response.all_messages())
694694
),
695695
ModelResponse(
696696
parts=[TextPart(content='Paris', part_kind='text')],
697-
model_name='function:model_logic',
697+
model_name='claude-3-5-sonnet-latest',
698698
timestamp=datetime.datetime(...),
699699
kind='response',
700700
),
701701
]
702702
"""
703703
```
704704

705-
The `ModelResponse` message above indicates that the result was returned by the Anthropic model, which is the second model specified in the `FallbackModel`.
705+
The `ModelResponse` message above indicates in the `model_name` field that the result was returned by the Anthropic model, which is the second model specified in the `FallbackModel`.
706706

707707
!!! note
708708
Each model's options should be configured individually. For example, `base_url`, `api_key`, and custom clients should be set on each model itself, not on the `FallbackModel`.

Diff for: docs/tools.md

+3-3
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ print(dice_result.all_messages())
8989
tool_name='roll_die', args={}, tool_call_id=None, part_kind='tool-call'
9090
)
9191
],
92-
model_name='function:model_logic',
92+
model_name='gemini-1.5-flash',
9393
timestamp=datetime.datetime(...),
9494
kind='response',
9595
),
@@ -114,7 +114,7 @@ print(dice_result.all_messages())
114114
part_kind='tool-call',
115115
)
116116
],
117-
model_name='function:model_logic',
117+
model_name='gemini-1.5-flash',
118118
timestamp=datetime.datetime(...),
119119
kind='response',
120120
),
@@ -137,7 +137,7 @@ print(dice_result.all_messages())
137137
part_kind='text',
138138
)
139139
],
140-
model_name='function:model_logic',
140+
model_name='gemini-1.5-flash',
141141
timestamp=datetime.datetime(...),
142142
kind='response',
143143
),

Diff for: pydantic_ai_slim/pydantic_ai/agent.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ async def main():
365365
HandleResponseNode(
366366
model_response=ModelResponse(
367367
parts=[TextPart(content='Paris', part_kind='text')],
368-
model_name='function:model_logic',
368+
model_name='gpt-4o',
369369
timestamp=datetime.datetime(...),
370370
kind='response',
371371
)
@@ -1214,7 +1214,7 @@ async def main():
12141214
HandleResponseNode(
12151215
model_response=ModelResponse(
12161216
parts=[TextPart(content='Paris', part_kind='text')],
1217-
model_name='function:model_logic',
1217+
model_name='gpt-4o',
12181218
timestamp=datetime.datetime(...),
12191219
kind='response',
12201220
)
@@ -1357,7 +1357,7 @@ async def main():
13571357
HandleResponseNode(
13581358
model_response=ModelResponse(
13591359
parts=[TextPart(content='Paris', part_kind='text')],
1360-
model_name='function:model_logic',
1360+
model_name='gpt-4o',
13611361
timestamp=datetime.datetime(...),
13621362
kind='response',
13631363
)

Diff for: pydantic_ai_slim/pydantic_ai/models/function.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -48,22 +48,31 @@ class FunctionModel(Model):
4848
_system: str | None = field(default=None, repr=False)
4949

5050
@overload
51-
def __init__(self, function: FunctionDef) -> None: ...
51+
def __init__(self, function: FunctionDef, *, model_name: str | None = None) -> None: ...
5252

5353
@overload
54-
def __init__(self, *, stream_function: StreamFunctionDef) -> None: ...
54+
def __init__(self, *, stream_function: StreamFunctionDef, model_name: str | None = None) -> None: ...
5555

5656
@overload
57-
def __init__(self, function: FunctionDef, *, stream_function: StreamFunctionDef) -> None: ...
57+
def __init__(
58+
self, function: FunctionDef, *, stream_function: StreamFunctionDef, model_name: str | None = None
59+
) -> None: ...
5860

59-
def __init__(self, function: FunctionDef | None = None, *, stream_function: StreamFunctionDef | None = None):
61+
def __init__(
62+
self,
63+
function: FunctionDef | None = None,
64+
*,
65+
stream_function: StreamFunctionDef | None = None,
66+
model_name: str | None = None,
67+
):
6068
"""Initialize a `FunctionModel`.
6169
6270
Either `function` or `stream_function` must be provided, providing both is allowed.
6371
6472
Args:
6573
function: The function to call for non-streamed requests.
6674
stream_function: The function to call for streamed requests.
75+
model_name: The name of the model. If not provided, a name is generated from the function names.
6776
"""
6877
if function is None and stream_function is None:
6978
raise TypeError('Either `function` or `stream_function` must be provided')
@@ -72,7 +81,7 @@ def __init__(self, function: FunctionDef | None = None, *, stream_function: Stre
7281

7382
function_name = self.function.__name__ if self.function is not None else ''
7483
stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
75-
self._model_name = f'function:{function_name}:{stream_function_name}'
84+
self._model_name = model_name or f'function:{function_name}:{stream_function_name}'
7685

7786
async def request(
7887
self,
@@ -95,7 +104,7 @@ async def request(
95104
response_ = await _utils.run_in_executor(self.function, messages, agent_info)
96105
assert isinstance(response_, ModelResponse), response_
97106
response = response_
98-
response.model_name = f'function:{self.function.__name__}'
107+
response.model_name = self._model_name
99108
# TODO is `messages` right here? Should it just be new messages?
100109
return response, _estimate_usage(chain(messages, [response]))
101110

Diff for: tests/test_examples.py

+28-4
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from pytest_examples import CodeExample, EvalExample, find_examples
1717
from pytest_mock import MockerFixture
1818

19+
from pydantic_ai import ModelHTTPError
1920
from pydantic_ai._utils import group_by_temporal
2021
from pydantic_ai.exceptions import UnexpectedModelBehavior
2122
from pydantic_ai.messages import (
@@ -27,8 +28,10 @@
2728
ToolReturnPart,
2829
UserPromptPart,
2930
)
30-
from pydantic_ai.models import KnownModelName, Model
31+
from pydantic_ai.models import KnownModelName, Model, infer_model
32+
from pydantic_ai.models.fallback import FallbackModel
3133
from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel
34+
from pydantic_ai.models.openai import OpenAIModel
3235
from pydantic_ai.models.test import TestModel
3336

3437
from .conftest import ClientWithHandler, TestEnv
@@ -77,6 +80,7 @@ def test_docs_examples(
7780
env.set('OPENAI_API_KEY', 'testing')
7881
env.set('GEMINI_API_KEY', 'testing')
7982
env.set('GROQ_API_KEY', 'testing')
83+
env.set('CO_API_KEY', 'testing')
8084

8185
sys.path.append('tests/example_modules')
8286

@@ -377,12 +381,32 @@ async def stream_model_logic(
377381

378382

379383
def mock_infer_model(model: Model | KnownModelName) -> Model:
384+
if model == 'test':
385+
return TestModel()
386+
387+
if isinstance(model, str):
388+
# Use the non-mocked model inference to ensure we get the same model name the user would
389+
model = infer_model(model)
390+
391+
if isinstance(model, FallbackModel):
392+
# When a fallback model is encountered, replace any OpenAIModel with a model that will raise a ModelHTTPError.
393+
# Otherwise, do the usual inference.
394+
def raise_http_error(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: # pragma: no cover
395+
raise ModelHTTPError(401, 'Invalid API Key')
396+
397+
mock_fallback_models: list[Model] = []
398+
for m in model.models:
399+
if isinstance(m, OpenAIModel):
400+
# Raise an HTTP error for OpenAIModel
401+
mock_fallback_models.append(FunctionModel(raise_http_error, model_name=m.model_name))
402+
else:
403+
mock_fallback_models.append(mock_infer_model(m))
404+
return FallbackModel(*mock_fallback_models)
380405
if isinstance(model, (FunctionModel, TestModel)):
381406
return model
382-
elif model == 'test':
383-
return TestModel()
384407
else:
385-
return FunctionModel(model_logic, stream_function=stream_model_logic)
408+
model_name = model if isinstance(model, str) else model.model_name
409+
return FunctionModel(model_logic, stream_function=stream_model_logic, model_name=model_name)
386410

387411

388412
def mock_group_by_temporal(aiter: Any, soft_max_interval: float | None) -> Any:

0 commit comments

Comments (0)