Skip to content

Commit 3edecbe

Browse files
hallacyfps7806hponderachellim
authored
Hallacy/v23 (#116)
* overload output type depending on stream literal (#142) * Bump to v22 * [numpy] change version (#143) * [numpy] change version * update comments * no version for numpy * Fix timeouts (#137) * Fix timeouts * Rename to request_timeout and add to readme * Dev/hallacy/request timeout takes tuples (#144) * Add tuple typing for request_timeout * imports * [api_requestor] Log request_id with response (#145) * Only import wandb as needed (#146) Co-authored-by: Felipe Petroski Such <[email protected]> Co-authored-by: Henrique Oliveira Pinto <[email protected]> Co-authored-by: Rachel Lim <[email protected]>
1 parent f3e3083 commit 3edecbe

11 files changed

+126
-13
lines changed

Diff for: README.md

+4
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ completion = openai.Completion.create(engine="ada", prompt="Hello world")
5252
print(completion.choices[0].text)
5353
```
5454

55+
56+
### Params
57+
All endpoints have a `.create` method that support a `request_timeout` param. This param takes a `Union[float, Tuple[float, float]]` and will raise a `openai.error.TimeoutError` error if the request exceeds that time in seconds (See: https://requests.readthedocs.io/en/latest/user/quickstart/#timeouts).
58+
5559
### Microsoft Azure Endpoints
5660

5761
In order to use the library with Microsoft Azure endpoints, you need to set the api_type, api_base and api_version in addition to the api_key. The api_type must be set to 'azure' and the others correspond to the properties of your endpoint.

Diff for: openai/api_requestor.py

+72-4
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
import threading
44
import warnings
55
from json import JSONDecodeError
6-
from typing import Dict, Iterator, Optional, Tuple, Union
6+
from typing import Dict, Iterator, Optional, Tuple, Union, overload
77
from urllib.parse import urlencode, urlsplit, urlunsplit
88

99
import requests
10+
from typing_extensions import Literal
1011

1112
import openai
1213
from openai import error, util, version
@@ -99,15 +100,73 @@ def format_app_info(cls, info):
99100
str += " (%s)" % (info["url"],)
100101
return str
101102

103+
@overload
104+
def request(
105+
self,
106+
method,
107+
url,
108+
params,
109+
headers,
110+
files,
111+
stream: Literal[True],
112+
request_id: Optional[str] = ...,
113+
request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
114+
) -> Tuple[Iterator[OpenAIResponse], bool, str]:
115+
pass
116+
117+
@overload
118+
def request(
119+
self,
120+
method,
121+
url,
122+
params=...,
123+
headers=...,
124+
files=...,
125+
*,
126+
stream: Literal[True],
127+
request_id: Optional[str] = ...,
128+
request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
129+
) -> Tuple[Iterator[OpenAIResponse], bool, str]:
130+
pass
131+
132+
@overload
133+
def request(
134+
self,
135+
method,
136+
url,
137+
params=...,
138+
headers=...,
139+
files=...,
140+
stream: Literal[False] = ...,
141+
request_id: Optional[str] = ...,
142+
request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
143+
) -> Tuple[OpenAIResponse, bool, str]:
144+
pass
145+
146+
@overload
147+
def request(
148+
self,
149+
method,
150+
url,
151+
params=...,
152+
headers=...,
153+
files=...,
154+
stream: bool = ...,
155+
request_id: Optional[str] = ...,
156+
request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
157+
) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool, str]:
158+
pass
159+
102160
def request(
103161
self,
104162
method,
105163
url,
106164
params=None,
107165
headers=None,
108166
files=None,
109-
stream=False,
167+
stream: bool = False,
110168
request_id: Optional[str] = None,
169+
request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
111170
) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool, str]:
112171
result = self.request_raw(
113172
method.lower(),
@@ -117,6 +176,7 @@ def request(
117176
files=files,
118177
stream=stream,
119178
request_id=request_id,
179+
request_timeout=request_timeout,
120180
)
121181
resp, got_stream = self._interpret_response(result, stream)
122182
return resp, got_stream, self.api_key
@@ -179,7 +239,11 @@ def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False
179239
return error.APIError(message, rbody, rcode, resp, rheaders)
180240
else:
181241
return error.APIError(
182-
error_data.get("message"), rbody, rcode, resp, rheaders
242+
f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}",
243+
rbody,
244+
rcode,
245+
resp,
246+
rheaders,
183247
)
184248

185249
def request_headers(
@@ -256,6 +320,7 @@ def request_raw(
256320
files=None,
257321
stream: bool = False,
258322
request_id: Optional[str] = None,
323+
request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
259324
) -> requests.Response:
260325
abs_url = "%s%s" % (self.api_base, url)
261326
headers = self._validate_headers(supplied_headers)
@@ -295,15 +360,18 @@ def request_raw(
295360
data=data,
296361
files=files,
297362
stream=stream,
298-
timeout=TIMEOUT_SECS,
363+
timeout=request_timeout if request_timeout else TIMEOUT_SECS,
299364
)
365+
except requests.exceptions.Timeout as e:
366+
raise error.Timeout("Request timed out") from e
300367
except requests.exceptions.RequestException as e:
301368
raise error.APIConnectionError("Error communicating with OpenAI") from e
302369
util.log_info(
303370
"OpenAI API response",
304371
path=abs_url,
305372
response_code=result.status_code,
306373
processing_ms=result.headers.get("OpenAI-Processing-Ms"),
374+
request_id=result.headers.get("X-Request-Id"),
307375
)
308376
# Don't read the whole stream for debug logging unless necessary.
309377
if openai.log == "debug":

Diff for: openai/api_resources/abstract/api_resource.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,21 @@ class APIResource(OpenAIObject):
1313
azure_deployments_prefix = "deployments"
1414

1515
@classmethod
16-
def retrieve(cls, id, api_key=None, request_id=None, **params):
16+
def retrieve(
17+
cls, id, api_key=None, request_id=None, request_timeout=None, **params
18+
):
1719
instance = cls(id, api_key, **params)
18-
instance.refresh(request_id=request_id)
20+
instance.refresh(request_id=request_id, request_timeout=request_timeout)
1921
return instance
2022

21-
def refresh(self, request_id=None):
23+
def refresh(self, request_id=None, request_timeout=None):
2224
self.refresh_from(
23-
self.request("get", self.instance_url(), request_id=request_id)
25+
self.request(
26+
"get",
27+
self.instance_url(),
28+
request_id=request_id,
29+
request_timeout=request_timeout,
30+
)
2431
)
2532
return self
2633

Diff for: openai/api_resources/abstract/engine_api_resource.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def create(
7777
timeout = params.pop("timeout", None)
7878
stream = params.get("stream", False)
7979
headers = params.pop("headers", None)
80-
80+
request_timeout = params.pop("request_timeout", None)
8181
typed_api_type = cls._get_api_type_and_version(api_type=api_type)[0]
8282
if typed_api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
8383
if deployment_id is None and engine is None:
@@ -119,6 +119,7 @@ def create(
119119
headers=headers,
120120
stream=stream,
121121
request_id=request_id,
122+
request_timeout=request_timeout,
122123
)
123124

124125
if stream:

Diff for: openai/cli.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import requests
1010

1111
import openai
12-
import openai.wandb_logger
1312
from openai.upload_progress import BufferReader
1413
from openai.validators import (
1514
apply_necessary_remediation,
@@ -542,6 +541,8 @@ def prepare_data(cls, args):
542541
class WandbLogger:
543542
@classmethod
544543
def sync(cls, args):
544+
import openai.wandb_logger
545+
545546
resp = openai.wandb_logger.WandbLogger.sync(
546547
id=args.id,
547548
n_fine_tunes=args.n_fine_tunes,

Diff for: openai/error.py

+4
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@ class TryAgain(OpenAIError):
7676
pass
7777

7878

79+
class Timeout(OpenAIError):
80+
pass
81+
82+
7983
class APIConnectionError(OpenAIError):
8084
def __init__(
8185
self,

Diff for: openai/openai_object.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
from copy import deepcopy
3-
from typing import Optional
3+
from typing import Optional, Tuple, Union
44

55
import openai
66
from openai import api_requestor, util
@@ -165,6 +165,7 @@ def request(
165165
stream=False,
166166
plain_old_data=False,
167167
request_id: Optional[str] = None,
168+
request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
168169
):
169170
if params is None:
170171
params = self._retrieve_params
@@ -182,6 +183,7 @@ def request(
182183
stream=stream,
183184
headers=headers,
184185
request_id=request_id,
186+
request_timeout=request_timeout,
185187
)
186188

187189
if stream:

Diff for: openai/tests/test_endpoints.py

+24
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import io
22
import json
33

4+
import pytest
5+
46
import openai
7+
from openai import error
58

69

710
# FILE TESTS
@@ -34,3 +37,24 @@ def test_completions_model():
3437
result = openai.Completion.create(prompt="This was a test", n=5, model="ada")
3538
assert len(result.choices) == 5
3639
assert result.model.startswith("ada")
40+
41+
42+
def test_timeout_raises_error():
43+
# A query that should take awhile to return
44+
with pytest.raises(error.Timeout):
45+
openai.Completion.create(
46+
prompt="test" * 1000,
47+
n=10,
48+
model="ada",
49+
max_tokens=100,
50+
request_timeout=0.01,
51+
)
52+
53+
54+
def test_timeout_does_not_error():
55+
# A query that should be fast
56+
openai.Completion.create(
57+
prompt="test",
58+
model="ada",
59+
request_timeout=10,
60+
)

Diff for: openai/tests/test_exceptions.py

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
openai.error.SignatureVerificationError("message", "sig_header?"),
2222
openai.error.APIConnectionError("message!", should_retry=True),
2323
openai.error.TryAgain(),
24+
openai.error.Timeout(),
2425
openai.error.APIError(
2526
message="message",
2627
code=400,

Diff for: openai/version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
VERSION = "0.22.1"
1+
VERSION = "0.23.0"

Diff for: setup.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
"pandas>=1.2.3", # Needed for CLI fine-tuning data preparation tool
2020
"pandas-stubs>=1.1.0.11", # Needed for type hints for mypy
2121
"openpyxl>=3.0.7", # Needed for CLI fine-tuning data preparation tool xlsx format
22-
"numpy>=1.22.0", # To address a vuln in <1.21.6
22+
"numpy",
23+
"typing_extensions", # Needed for type hints for mypy
2324
],
2425
extras_require={
2526
"dev": ["black~=21.6b0", "pytest==6.*"],

0 commit comments

Comments
 (0)