feat: Rework error handling (#48)
baxen committed Sep 15, 2024
1 parent c4365bf commit ed8bbbf
Showing 11 changed files with 115 additions and 299 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,10 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [0.9.1] - 2024-09-15
+
+- fix: retry only some 400s and raise error details
+
 ## [0.9.0] - 2024-09-09
 
 - chore: add just command for releases and update pyproject for changelog (#43)
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "ai-exchange"
-version = "0.9.0"
+version = "0.9.1"
 description = "a uniform python SDK for message generation with LLMs"
 readme = "README.md"
 requires-python = ">=3.10"
@@ -12,6 +12,7 @@ dependencies = [
     "jinja2>=3.1.4",
     "tiktoken>=0.7.0",
     "httpx>=0.27.0",
+    "tenacity>=9.0.0",
 ]
 
 [tool.hatch.build.targets.wheel]
25 changes: 16 additions & 9 deletions src/exchange/providers/anthropic.py
@@ -6,11 +6,19 @@
 from exchange import Message, Tool
 from exchange.content import Text, ToolResult, ToolUse
 from exchange.providers.base import Provider, Usage
-from exchange.providers.retry_with_back_off_decorator import retry_httpx_request
+from tenacity import retry, wait_fixed, stop_after_attempt
+from exchange.providers.utils import retry_if_status
 from exchange.providers.utils import raise_for_status
 
 ANTHROPIC_HOST = "https://api.anthropic.com/v1/messages"
 
+retry_procedure = retry(
+    wait=wait_fixed(2),
+    stop=stop_after_attempt(2),
+    retry=retry_if_status(codes=[429], above=500),
+    reraise=True,
+)
+
 
 class AnthropicProvider(Provider):
     def __init__(self, client: httpx.Client) -> None:
@@ -138,14 +146,13 @@ def complete(
         )
         payload = {k: v for k, v in payload.items() if v}
 
-        response = self._send_request(payload)
-
-        response_data = raise_for_status(response).json()
-        message = self.anthropic_response_to_message(response_data)
-        usage = self.get_usage(response_data)
+        response = self._post(payload)
+        message = self.anthropic_response_to_message(response)
+        usage = self.get_usage(response)
 
         return message, usage
 
-    @retry_httpx_request()
-    def _send_request(self, payload: Dict[str, Any]) -> httpx.Response:
-        return self.client.post(ANTHROPIC_HOST, json=payload)
+    @retry_procedure
+    def _post(self, payload: dict) -> dict:
+        response = self.client.post(ANTHROPIC_HOST, json=payload)
+        return raise_for_status(response).json()
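Every provider in this commit moves to the same pattern: a module-level retry_procedure built from tenacity primitives wraps a small _post helper, and because raise_for_status now runs inside the retried call, the httpx.HTTPStatusError it raises is what the retry predicate inspects. A minimal standalone sketch of that flow, assuming a placeholder URL and an inline predicate in place of the retry_if_status helper added to utils.py below:

import httpx
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_fixed

API_URL = "https://example.invalid/v1/messages"  # placeholder endpoint, not a real API


def _should_retry(exc: Exception) -> bool:
    # Retry only rate limits (429) and server errors (>= 500); other 4xx fail fast.
    return isinstance(exc, httpx.HTTPStatusError) and (
        exc.response.status_code == 429 or exc.response.status_code >= 500
    )


@retry(wait=wait_fixed(2), stop=stop_after_attempt(2), retry=retry_if_exception(_should_retry), reraise=True)
def post(client: httpx.Client, payload: dict) -> dict:
    response = client.post(API_URL, json=payload)
    response.raise_for_status()  # raises httpx.HTTPStatusError on 4xx/5xx
    return response.json()

With reraise=True, the original HTTPStatusError surfaces after the final attempt instead of a tenacity RetryError, which is the "raise error details" half of the changelog entry.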
29 changes: 18 additions & 11 deletions src/exchange/providers/azure.py
@@ -5,7 +5,8 @@
 
 from exchange.message import Message
 from exchange.providers.base import Provider, Usage
-from exchange.providers.retry_with_back_off_decorator import retry_httpx_request
+from tenacity import retry, wait_fixed, stop_after_attempt
+from exchange.providers.utils import retry_if_status
 from exchange.providers.utils import (
     messages_to_openai_spec,
     openai_response_to_message,
@@ -15,6 +16,13 @@
 )
 from exchange.tool import Tool
 
+retry_procedure = retry(
+    wait=wait_fixed(2),
+    stop=stop_after_attempt(2),
+    retry=retry_if_status(codes=[429], above=500),
+    reraise=True,
+)
+
 
 class AzureProvider(Provider):
     """Provides chat completions for models hosted directly by OpenAI"""
@@ -91,18 +99,17 @@ def complete(
 
         payload = {k: v for k, v in payload.items() if v}
         request_url = f"{self.client.base_url}/chat/completions?api-version={self.api_version}"
-        response = self._send_request(payload, request_url)
+        response = self._post(payload, request_url)
 
         # Check for context_length_exceeded error for single, long input message
-        if "error" in response.json() and len(messages) == 1:
-            openai_single_message_context_length_exceeded(response.json()["error"])
-
-        data = raise_for_status(response).json()
+        if "error" in response and len(messages) == 1:
+            openai_single_message_context_length_exceeded(response["error"])
 
-        message = openai_response_to_message(data)
-        usage = self.get_usage(data)
+        message = openai_response_to_message(response)
+        usage = self.get_usage(response)
         return message, usage
 
-    @retry_httpx_request()
-    def _send_request(self, payload: Any, request_url: str) -> httpx.Response:  # noqa: ANN401
-        return self.client.post(request_url, json=payload)
+    @retry_procedure
+    def _post(self, payload: dict, request_url: str) -> dict:
+        response = self.client.post(request_url, json=payload)
+        return raise_for_status(response).json()
26 changes: 17 additions & 9 deletions src/exchange/providers/bedrock.py
@@ -12,7 +12,8 @@
 from exchange.content import Text, ToolResult, ToolUse
 from exchange.message import Message
 from exchange.providers import Provider, Usage
-from exchange.providers.retry_with_back_off_decorator import retry_httpx_request
+from tenacity import retry, wait_fixed, stop_after_attempt
+from exchange.providers.utils import retry_if_status
 from exchange.providers.utils import raise_for_status
 from exchange.tool import Tool
 
@@ -21,6 +22,13 @@
 
 logger = logging.getLogger(__name__)
 
+retry_procedure = retry(
+    wait=wait_fixed(2),
+    stop=stop_after_attempt(2),
+    retry=retry_if_status(codes=[429], above=500),
+    reraise=True,
+)
+
 
 class AwsClient(httpx.Client):
     def __init__(
@@ -110,7 +118,7 @@ def get_signature_key(key: str, date_stamp: str, region_name: str, service_name:
         algorithm = "AWS4-HMAC-SHA256"
         credential_scope = f"{date_stamp}/{self.region}/{service}/aws4_request"
         string_to_sign = (
-            f'{algorithm}\n{amz_date}\n{credential_scope}\n'
+            f"{algorithm}\n{amz_date}\n{credential_scope}\n"
             f'{hashlib.sha256(canonical_request.encode("utf-8")).hexdigest()}'
         )
 
@@ -204,11 +212,10 @@ def complete(
         payload = {k: v for k, v in payload.items() if v}
 
         path = f"{self.client.host}model/{model}/converse"
-        response = self._send_request(payload, path)
-        raise_for_status(response)
-        response_message = response.json()["output"]["message"]
+        response = self._post(payload, path)
+        response_message = response["output"]["message"]
 
-        usage_data = response.json()["usage"]
+        usage_data = response["usage"]
         usage = Usage(
             input_tokens=usage_data.get("inputTokens"),
             output_tokens=usage_data.get("outputTokens"),
@@ -217,9 +224,10 @@
 
         return self.response_to_message(response_message), usage
 
-    @retry_httpx_request()
-    def _send_request(self, payload: Any, path: str) -> httpx.Response:  # noqa: ANN401
-        return self.client.post(path, json=payload)
+    @retry_procedure
+    def _post(self, payload: Any, path: str) -> dict:  # noqa: ANN401
+        response = self.client.post(path, json=payload)
+        return raise_for_status(response).json()
 
     @staticmethod
     def message_to_bedrock_spec(message: Message) -> dict:
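The quoting fix above lands inside AwsClient's request signing. For orientation: get_signature_key is the standard AWS Signature Version 4 key derivation, chaining HMAC-SHA256 from the secret key through date, region, and service. A compact sketch of that standard derivation (illustrative, not this file's exact code):

import hashlib
import hmac


def _hmac_sha256(key: bytes, msg: str) -> bytes:
    return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest()


def get_signature_key(secret_key: str, date_stamp: str, region: str, service: str) -> bytes:
    # kSecret -> kDate -> kRegion -> kService -> kSigning, per AWS SigV4
    k_date = _hmac_sha256(("AWS4" + secret_key).encode("utf-8"), date_stamp)
    k_region = _hmac_sha256(k_date, region)
    k_service = _hmac_sha256(k_region, service)
    return _hmac_sha256(k_service, "aws4_request")

The derived key then signs the string_to_sign assembled in the diff: algorithm, timestamp, credential scope, and the SHA-256 hex digest of the canonical request, joined by newlines.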
26 changes: 17 additions & 9 deletions src/exchange/providers/databricks.py
@@ -5,16 +5,24 @@
 
 from exchange.message import Message
 from exchange.providers.base import Provider, Usage
-from exchange.providers.retry_with_back_off_decorator import retry_httpx_request
+from tenacity import retry, wait_fixed, stop_after_attempt
+from exchange.providers.utils import raise_for_status, retry_if_status
 from exchange.providers.utils import (
     messages_to_openai_spec,
     openai_response_to_message,
-    raise_for_status,
     tools_to_openai_spec,
 )
 from exchange.tool import Tool
 
 
+retry_procedure = retry(
+    wait=wait_fixed(2),
+    stop=stop_after_attempt(2),
+    retry=retry_if_status(codes=[429], above=500),
+    reraise=True,
+)
+
+
 class DatabricksProvider(Provider):
     """Provides chat completions for models on Databricks serving endpoints
@@ -80,15 +88,15 @@ def complete(
             **kwargs,
         )
         payload = {k: v for k, v in payload.items() if v}
-        response = self._send_request(model, payload)
-        data = raise_for_status(response).json()
-        message = openai_response_to_message(data)
-        usage = self.get_usage(data)
+        response = self._post(model, payload)
+        message = openai_response_to_message(response)
+        usage = self.get_usage(response)
         return message, usage
 
-    @retry_httpx_request()
-    def _send_request(self, model: str, payload: Any) -> httpx.Response:  # noqa: ANN401
-        return self.client.post(
+    @retry_procedure
+    def _post(self, model: str, payload: dict) -> dict:
+        response = self.client.post(
             f"serving-endpoints/{model}/invocations",
             json=payload,
         )
+        return raise_for_status(response).json()
35 changes: 20 additions & 15 deletions src/exchange/providers/openai.py
@@ -5,7 +5,6 @@
 
 from exchange.message import Message
 from exchange.providers.base import Provider, Usage
-from exchange.providers.retry_with_back_off_decorator import retry_httpx_request
 from exchange.providers.utils import (
     messages_to_openai_spec,
     openai_response_to_message,
@@ -14,9 +13,18 @@
     tools_to_openai_spec,
 )
 from exchange.tool import Tool
+from tenacity import retry, wait_fixed, stop_after_attempt
+from exchange.providers.utils import retry_if_status
 
 OPENAI_HOST = "https://api.openai.com/"
 
+retry_procedure = retry(
+    wait=wait_fixed(2),
+    stop=stop_after_attempt(2),
+    retry=retry_if_status(codes=[429], above=500),
+    reraise=True,
+)
+
 
 class OpenAiProvider(Provider):
     """Provides chat completions for models hosted directly by OpenAI"""
@@ -65,28 +73,25 @@ def complete(
         tools: Tuple[Tool],
         **kwargs: Dict[str, Any],
     ) -> Tuple[Message, Usage]:
+        system_message = [] if model.startswith("o1") else [{"role": "system", "content": system}]
         payload = dict(
-            messages=[
-                {"role": "system", "content": system},
-                *messages_to_openai_spec(messages),
-            ],
+            messages=system_message + messages_to_openai_spec(messages),
             model=model,
             tools=tools_to_openai_spec(tools) if tools else [],
             **kwargs,
         )
         payload = {k: v for k, v in payload.items() if v}
-        response = self._send_request(payload)
+        response = self._post(payload)
 
         # Check for context_length_exceeded error for single, long input message
-        if "error" in response.json() and len(messages) == 1:
-            openai_single_message_context_length_exceeded(response.json()["error"])
-
-        data = raise_for_status(response).json()
+        if "error" in response and len(messages) == 1:
+            openai_single_message_context_length_exceeded(response["error"])
 
-        message = openai_response_to_message(data)
-        usage = self.get_usage(data)
+        message = openai_response_to_message(response)
+        usage = self.get_usage(response)
         return message, usage
 
-    @retry_httpx_request()
-    def _send_request(self, payload: Any) -> httpx.Response:  # noqa: ANN401
-        return self.client.post("v1/chat/completions", json=payload)
+    @retry_procedure
+    def _post(self, payload: dict) -> dict:
+        response = self.client.post("v1/chat/completions", json=payload)
+        return raise_for_status(response).json()
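Alongside the retry rework, this hunk folds in a small compatibility change: o1-series models did not accept a "system" role at the time, so the system message is dropped for them, and building the messages list by concatenation keeps the payload well formed either way. A behavior sketch (hypothetical helper, not code from the diff):

def build_messages(model: str, system: str, user_messages: list) -> list:
    # o1 models reject the "system" role, so omit it for them.
    system_message = [] if model.startswith("o1") else [{"role": "system", "content": system}]
    return system_message + user_messages


msgs = [{"role": "user", "content": "hi"}]
assert build_messages("o1-preview", "Be brief.", msgs) == msgs
assert build_messages("gpt-4o", "Be brief.", msgs)[0]["role"] == "system"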
61 changes: 0 additions & 61 deletions src/exchange/providers/retry_with_back_off_decorator.py

This file was deleted.

17 changes: 16 additions & 1 deletion src/exchange/providers/utils.py
@@ -1,12 +1,27 @@
 import base64
 import json
 import re
-from typing import Any, Dict, List, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import httpx
 from exchange.content import Text, ToolResult, ToolUse
 from exchange.message import Message
 from exchange.tool import Tool
+from tenacity import retry_if_exception
+
+
+def retry_if_status(codes: Optional[List[int]] = None, above: Optional[int] = None) -> Callable:
+    codes = codes or []
+
+    def predicate(exc: Exception) -> bool:
+        if isinstance(exc, httpx.HTTPStatusError):
+            if exc.response.status_code in codes:
+                return True
+            if above and exc.response.status_code >= above:
+                return True
+        return False
+
+    return retry_if_exception(predicate)
 
 
 def raise_for_status(response: httpx.Response) -> httpx.Response:
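retry_if_status is the bridge between tenacity and httpx: it returns a retry_if_exception predicate that matches httpx.HTTPStatusError by exact status code (codes) or by threshold (above, inclusive, so above=500 also retries 500 itself). A usage sketch, assuming raise_for_status raises httpx.HTTPStatusError as the predicate expects, with a stand-in URL:

import httpx
from tenacity import retry, stop_after_attempt, wait_fixed

from exchange.providers.utils import raise_for_status, retry_if_status


@retry(
    wait=wait_fixed(2),
    stop=stop_after_attempt(2),
    retry=retry_if_status(codes=[429], above=500),  # retry 429 and anything >= 500
    reraise=True,
)
def fetch(client: httpx.Client, url: str) -> dict:
    # A plain 400 is neither listed in codes nor >= above, so it propagates
    # immediately with its error details instead of being retried.
    return raise_for_status(client.get(url)).json()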