Add FallbackModel support #894

Merged
merged 37 commits into from
Feb 25, 2025

Commits (37)
d66b3a2
fallback proof of concept
sydney-runkle Feb 11, 2025
9956c96
remove model name updates
sydney-runkle Feb 11, 2025
a821380
catching 4xx and 5xx business
sydney-runkle Feb 11, 2025
a4b4ebb
fixing test
sydney-runkle Feb 11, 2025
b4a6b1a
move groq import
sydney-runkle Feb 11, 2025
4e2b089
return after yield
sydney-runkle Feb 11, 2025
9e7f53b
intro docs
sydney-runkle Feb 11, 2025
737cf2e
non streaming testing
sydney-runkle Feb 11, 2025
3113ae5
initial tests
sydney-runkle Feb 11, 2025
bf477e2
Comprehension golfing, thanks @alexmojaki
sydney-runkle Feb 12, 2025
47b4901
fix linting issue
sydney-runkle Feb 13, 2025
3b95211
openai test with exceptions
sydney-runkle Feb 13, 2025
992d540
Merge branch 'main' into fallback-model-updated
sydney-runkle Feb 13, 2025
25afb86
adding model_name and system abstract methods
sydney-runkle Feb 13, 2025
7a23eb1
Merge branch 'main' into fallback-model-updated
sydney-runkle Feb 13, 2025
31150ff
using sequence and fixing 3.9 tests
sydney-runkle Feb 13, 2025
b83bc5e
using list with type hinting
sydney-runkle Feb 13, 2025
154d907
tests
sydney-runkle Feb 13, 2025
0f49228
streaming tests
sydney-runkle Feb 13, 2025
7cb8267
Get tests passing
dmontagu Feb 13, 2025
ff7b596
Minor cleanup
dmontagu Feb 13, 2025
68b0f0b
type alias testing cleanup
sydney-runkle Feb 14, 2025
d614df2
exception group like
sydney-runkle Feb 14, 2025
9ac1f8e
adding fallback failure example
sydney-runkle Feb 14, 2025
82b6580
fix f string issue
sydney-runkle Feb 14, 2025
538ac12
try moving type alias definitions into protected import block
sydney-runkle Feb 14, 2025
ba5f778
docs updates + fixing 3.9 tests
sydney-runkle Feb 14, 2025
b11d971
Merge branch 'main' into fallback-model-updated
dmontagu Feb 18, 2025
b2f499a
Merge remote-tracking branch 'origin/main' into fallback-model-updated
Kludex Feb 25, 2025
0088e4c
Fix test
Kludex Feb 25, 2025
0b025e5
Push exceptiongroup
Kludex Feb 25, 2025
9c32626
tests passing
Kludex Feb 25, 2025
59c246e
Apply David's comments
Kludex Feb 25, 2025
6b5cdba
Update docs/models.md
Kludex Feb 25, 2025
57df786
Update docs
dmontagu Feb 25, 2025
f5d2a5c
Improve model names in tests
dmontagu Feb 25, 2025
da98197
Fix tests
dmontagu Feb 25, 2025
8 changes: 4 additions & 4 deletions docs/agents.md
@@ -136,7 +136,7 @@ async def main():
HandleResponseNode(
model_response=ModelResponse(
parts=[TextPart(content='Paris', part_kind='text')],
model_name='function:model_logic',
model_name='gpt-4o',
timestamp=datetime.datetime(...),
kind='response',
)
@@ -197,7 +197,7 @@ async def main():
HandleResponseNode(
model_response=ModelResponse(
parts=[TextPart(content='Paris', part_kind='text')],
model_name='function:model_logic',
model_name='gpt-4o',
timestamp=datetime.datetime(...),
kind='response',
)
@@ -612,7 +612,7 @@ with capture_run_messages() as messages:  # (2)!
part_kind='tool-call',
)
],
model_name='function:model_logic',
model_name='gpt-4o',
timestamp=datetime.datetime(...),
kind='response',
),
@@ -637,7 +637,7 @@ with capture_run_messages() as messages:  # (2)!
part_kind='tool-call',
)
],
model_name='function:model_logic',
model_name='gpt-4o',
timestamp=datetime.datetime(...),
kind='response',
),
3 changes: 3 additions & 0 deletions docs/api/models/fallback.md
@@ -0,0 +1,3 @@
# pydantic_ai.models.fallback

::: pydantic_ai.models.fallback
12 changes: 6 additions & 6 deletions docs/message-history.md
@@ -62,7 +62,7 @@ print(result.all_messages())
part_kind='text',
)
],
model_name='function:model_logic',
model_name='gpt-4o',
timestamp=datetime.datetime(...),
kind='response',
),
@@ -136,7 +136,7 @@ async def main():
part_kind='text',
)
],
model_name='function:stream_model_logic',
model_name='gpt-4o',
timestamp=datetime.datetime(...),
kind='response',
),
@@ -193,7 +193,7 @@ print(result2.all_messages())
part_kind='text',
)
],
model_name='function:model_logic',
model_name='gpt-4o',
timestamp=datetime.datetime(...),
kind='response',
),
@@ -214,7 +214,7 @@ print(result2.all_messages())
part_kind='text',
)
],
model_name='function:model_logic',
model_name='gpt-4o',
timestamp=datetime.datetime(...),
kind='response',
),
@@ -273,7 +273,7 @@ print(result2.all_messages())
part_kind='text',
)
],
model_name='function:model_logic',
model_name='gpt-4o',
timestamp=datetime.datetime(...),
kind='response',
),
@@ -294,7 +294,7 @@ print(result2.all_messages())
part_kind='text',
)
],
model_name='function:model_logic',
model_name='gemini-1.5-pro',
timestamp=datetime.datetime(...),
kind='response',
),
112 changes: 112 additions & 0 deletions docs/models.md
@@ -653,3 +653,115 @@ For streaming, you'll also need to implement the following abstract base class:
The best place to start is to review the source code for existing implementations, e.g. [`OpenAIModel`](https://github.com/pydantic/pydantic-ai/blob/main/pydantic_ai_slim/pydantic_ai/models/openai.py).

For details on when we'll accept contributions adding new models to PydanticAI, see the [contributing guidelines](contributing.md#new-model-rules).


## Fallback

You can use [`FallbackModel`][pydantic_ai.models.fallback.FallbackModel] to attempt multiple models
in sequence until one returns a successful result. Under the hood, PydanticAI automatically switches
from one model to the next if the current model returns a 4xx or 5xx status code.

In the following example, the agent first makes a request to the OpenAI model (which fails due to an invalid API key),
and then falls back to the Anthropic model.

```python {title="fallback_model.py"}
from pydantic_ai import Agent
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.models.fallback import FallbackModel
from pydantic_ai.models.openai import OpenAIModel

openai_model = OpenAIModel('gpt-4o', api_key='not-valid')
anthropic_model = AnthropicModel('claude-3-5-sonnet-latest')
fallback_model = FallbackModel(openai_model, anthropic_model)

agent = Agent(fallback_model)
response = agent.run_sync('What is the capital of France?')
print(response.data)
#> Paris

print(response.all_messages())
"""
[
ModelRequest(
parts=[
UserPromptPart(
content='What is the capital of France?',
timestamp=datetime.datetime(...),
part_kind='user-prompt',
)
],
kind='request',
),
ModelResponse(
parts=[TextPart(content='Paris', part_kind='text')],
model_name='claude-3-5-sonnet-latest',
timestamp=datetime.datetime(...),
kind='response',
),
]
"""
```

The `ModelResponse` message above indicates in the `model_name` field that the result was returned by the Anthropic model, which is the second model specified in the `FallbackModel`.

!!! note
Each model's options should be configured individually. For example, `base_url`, `api_key`, and custom clients should be set on each model itself, not on the `FallbackModel`.
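
For illustration, here's a minimal sketch of that per-model configuration, assuming the individual model classes accept `api_key` and `base_url` keyword arguments as in the examples above (the exact parameter names for a custom HTTP client vary by model class, so check each model's API reference):

```python
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.models.fallback import FallbackModel
from pydantic_ai.models.openai import OpenAIModel

# Each model carries its own credentials and endpoint configuration...
openai_model = OpenAIModel(
    'gpt-4o',
    api_key='your-openai-key',
    base_url='https://api.openai.com/v1',  # assumed keyword, mirroring the note above
)
anthropic_model = AnthropicModel('claude-3-5-sonnet-latest', api_key='your-anthropic-key')

# ...while the FallbackModel only receives the already-configured models.
fallback_model = FallbackModel(openai_model, anthropic_model)
```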

In this next example, we demonstrate the exception-handling capabilities of `FallbackModel`.
If all models fail, a [`FallbackExceptionGroup`][pydantic_ai.exceptions.FallbackExceptionGroup] is raised, which
contains all the exceptions encountered during the `run` execution.

=== "Python >=3.11"

```python {title="fallback_model_failure.py" py="3.11"}
from pydantic_ai import Agent
from pydantic_ai.exceptions import ModelHTTPError
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.models.fallback import FallbackModel
from pydantic_ai.models.openai import OpenAIModel

openai_model = OpenAIModel('gpt-4o', api_key='not-valid')
anthropic_model = AnthropicModel('claude-3-5-sonnet-latest', api_key='not-valid')
fallback_model = FallbackModel(openai_model, anthropic_model)

agent = Agent(fallback_model)
try:
response = agent.run_sync('What is the capital of France?')
except* ModelHTTPError as exc_group:
for exc in exc_group.exceptions:
print(exc)
```

=== "Python <3.11"

Since [`except*`](https://docs.python.org/3/reference/compound_stmts.html#except-star) is only supported
in Python 3.11+, we use the [`exceptiongroup`](https://github.com/agronholm/exceptiongroup) backport
package for earlier Python versions:

```python {title="fallback_model_failure.py" noqa="F821" test="skip"}
from exceptiongroup import catch

from pydantic_ai import Agent
from pydantic_ai.exceptions import ModelHTTPError
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.models.fallback import FallbackModel
from pydantic_ai.models.openai import OpenAIModel


def model_status_error_handler(exc_group: BaseExceptionGroup) -> None:
for exc in exc_group.exceptions:
print(exc)


openai_model = OpenAIModel('gpt-4o', api_key='not-valid')
anthropic_model = AnthropicModel('claude-3-5-sonnet-latest', api_key='not-valid')
fallback_model = FallbackModel(openai_model, anthropic_model)

agent = Agent(fallback_model)
with catch({ModelHTTPError: model_status_error_handler}):
response = agent.run_sync('What is the capital of France?')
```

By default, the `FallbackModel` only moves on to the next model if the current model raises a
[`ModelHTTPError`][pydantic_ai.exceptions.ModelHTTPError]. You can customize this behavior by
passing a custom `fallback_on` argument to the `FallbackModel` constructor.
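
As a rough sketch of what that customization could look like — the exact form accepted by `fallback_on` (shown here as a tuple of exception types) and the use of `TimeoutError` are assumptions for illustration, so consult the [`FallbackModel`][pydantic_ai.models.fallback.FallbackModel] API reference:

```python
from pydantic_ai.exceptions import ModelHTTPError
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.models.fallback import FallbackModel
from pydantic_ai.models.openai import OpenAIModel

openai_model = OpenAIModel('gpt-4o')
anthropic_model = AnthropicModel('claude-3-5-sonnet-latest')

# Also fall back on timeouts, not just 4xx/5xx responses
# (assumes `fallback_on` accepts a tuple of exception types).
fallback_model = FallbackModel(
    openai_model,
    anthropic_model,
    fallback_on=(ModelHTTPError, TimeoutError),
)
```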
6 changes: 3 additions & 3 deletions docs/tools.md
@@ -89,7 +89,7 @@ print(dice_result.all_messages())
tool_name='roll_die', args={}, tool_call_id=None, part_kind='tool-call'
)
],
model_name='function:model_logic',
model_name='gemini-1.5-flash',
timestamp=datetime.datetime(...),
kind='response',
),
@@ -114,7 +114,7 @@ print(dice_result.all_messages())
part_kind='tool-call',
)
],
model_name='function:model_logic',
model_name='gemini-1.5-flash',
timestamp=datetime.datetime(...),
kind='response',
),
@@ -137,7 +137,7 @@ print(dice_result.all_messages())
part_kind='text',
)
],
model_name='function:model_logic',
model_name='gemini-1.5-flash',
timestamp=datetime.datetime(...),
kind='response',
),
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -58,6 +58,7 @@ nav:
- api/models/mistral.md
- api/models/test.md
- api/models/function.md
- api/models/fallback.md
- api/pydantic_graph/graph.md
- api/pydantic_graph/nodes.md
- api/pydantic_graph/state.md
12 changes: 11 additions & 1 deletion pydantic_ai_slim/pydantic_ai/__init__.py
@@ -1,7 +1,15 @@
from importlib.metadata import version

from .agent import Agent, EndStrategy, HandleResponseNode, ModelRequestNode, UserPromptNode, capture_run_messages
from .exceptions import AgentRunError, ModelRetry, UnexpectedModelBehavior, UsageLimitExceeded, UserError
from .exceptions import (
AgentRunError,
FallbackExceptionGroup,
ModelHTTPError,
ModelRetry,
UnexpectedModelBehavior,
UsageLimitExceeded,
UserError,
)
from .messages import AudioUrl, BinaryContent, ImageUrl
from .tools import RunContext, Tool

@@ -17,6 +25,8 @@
# exceptions
'AgentRunError',
'ModelRetry',
'ModelHTTPError',
'FallbackExceptionGroup',
'UnexpectedModelBehavior',
'UsageLimitExceeded',
'UserError',
6 changes: 3 additions & 3 deletions pydantic_ai_slim/pydantic_ai/agent.py
@@ -365,7 +365,7 @@ async def main():
HandleResponseNode(
model_response=ModelResponse(
parts=[TextPart(content='Paris', part_kind='text')],
model_name='function:model_logic',
model_name='gpt-4o',
timestamp=datetime.datetime(...),
kind='response',
)
@@ -1214,7 +1214,7 @@ async def main():
HandleResponseNode(
model_response=ModelResponse(
parts=[TextPart(content='Paris', part_kind='text')],
model_name='function:model_logic',
model_name='gpt-4o',
timestamp=datetime.datetime(...),
kind='response',
)
@@ -1357,7 +1357,7 @@ async def main():
HandleResponseNode(
model_response=ModelResponse(
parts=[TextPart(content='Paris', part_kind='text')],
model_name='function:model_logic',
model_name='gpt-4o',
timestamp=datetime.datetime(...),
kind='response',
)
43 changes: 42 additions & 1 deletion pydantic_ai_slim/pydantic_ai/exceptions.py
@@ -1,8 +1,22 @@
from __future__ import annotations as _annotations

import json
import sys

__all__ = 'ModelRetry', 'UserError', 'AgentRunError', 'UnexpectedModelBehavior', 'UsageLimitExceeded'
if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup
else:
ExceptionGroup = ExceptionGroup

__all__ = (
'ModelRetry',
'UserError',
'AgentRunError',
'UnexpectedModelBehavior',
'UsageLimitExceeded',
'ModelHTTPError',
'FallbackExceptionGroup',
)


class ModelRetry(Exception):
@@ -72,3 +86,30 @@ def __str__(self) -> str:
return f'{self.message}, body:\n{self.body}'
else:
return self.message


class ModelHTTPError(AgentRunError):
"""Raised when a model provider response has a status code of 4xx or 5xx."""

status_code: int
"""The HTTP status code returned by the API."""

model_name: str
"""The name of the model associated with the error."""

body: object | None
"""The body of the response, if available."""

message: str
"""The error message with the status code and response body, if available."""

def __init__(self, status_code: int, model_name: str, body: object | None = None):
self.status_code = status_code
self.model_name = model_name
self.body = body
message = f'status_code: {status_code}, model_name: {model_name}, body: {body}'
super().__init__(message)


class FallbackExceptionGroup(ExceptionGroup):
"""A group of exceptions that can be raised when all fallback models fail."""
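
To make the relationship between these two exceptions concrete, here is a minimal sketch — not the PR's actual implementation — of how a fallback loop could collect per-model failures and raise them as a group; `model.request(...)` is a hypothetical placeholder for the per-model request call:

```python
from pydantic_ai.exceptions import FallbackExceptionGroup, ModelHTTPError


def request_with_fallback(models, prompt):
    """Try each model in order; raise a FallbackExceptionGroup if every model fails."""
    exceptions: list[Exception] = []
    for model in models:
        try:
            return model.request(prompt)  # hypothetical per-model request call
        except ModelHTTPError as exc:
            exceptions.append(exc)  # record the 4xx/5xx failure and move on to the next model
    # ExceptionGroup (and therefore FallbackExceptionGroup) takes a message and the collected exceptions
    raise FallbackExceptionGroup('All models failed', exceptions)
```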