Commit c91dc77

Authored Sep 25, 2024
Merge branch 'main' into api_validation
2 parents: ed2bec7 + feef9d4

6 files changed: +160 −4 lines changed

autogen/oai/client.py (+26 −1)

@@ -16,6 +16,8 @@
 from autogen.runtime_logging import log_chat_completion, log_new_client, log_new_wrapper, logging_enabled
 from autogen.token_count_utils import count_token
 
+from .rate_limiters import RateLimiter, TimeRateLimiter
+
 TOOL_ENABLED = False
 try:
     import openai
@@ -203,7 +205,9 @@ def create(self, params: Dict[str, Any]) -> ChatCompletion:
         """
         iostream = IOStream.get_default()
 
-        completions: Completions = self._oai_client.chat.completions if "messages" in params else self._oai_client.completions  # type: ignore [attr-defined]
+        completions: Completions = (
+            self._oai_client.chat.completions if "messages" in params else self._oai_client.completions
+        )  # type: ignore [attr-defined]
         # If streaming is enabled and has messages, then iterate over the chunks of the response.
         if params.get("stream", False) and "messages" in params:
             response_contents = [""] * params.get("n", 1)
@@ -423,8 +427,11 @@ def __init__(self, *, config_list: Optional[List[Dict[str, Any]]] = None, **base
 
         self._clients: List[ModelClient] = []
         self._config_list: List[Dict[str, Any]] = []
+        self._rate_limiters: List[Optional[RateLimiter]] = []
 
         if config_list:
+            self._initialize_rate_limiters(config_list)
+
             config_list = [config.copy() for config in config_list]  # make a copy before modifying
             for config in config_list:
                 self._register_default_client(config, openai_config)  # could modify the config
@@ -745,6 +752,7 @@ def yes_or_no_filter(context, response):
                     return response
                 continue  # filter is not passed; try the next config
             try:
+                self._throttle_api_calls(i)
                 request_ts = get_current_ts()
                 response = client.create(params)
             except APITimeoutError as err:
@@ -1038,3 +1046,20 @@ def extract_text_or_completion_object(
             A list of text, or a list of ChatCompletion objects if function_call/tool_calls are present.
         """
         return response.message_retrieval_function(response)
+
+    def _throttle_api_calls(self, idx: int) -> None:
+        """Rate limit api calls."""
+        if self._rate_limiters[idx]:
+            limiter = self._rate_limiters[idx]
+
+            assert limiter is not None
+            limiter.sleep()
+
+    def _initialize_rate_limiters(self, config_list: List[Dict[str, Any]]) -> None:
+        for config in config_list:
+            # Instantiate the rate limiter
+            if "api_rate_limit" in config:
+                self._rate_limiters.append(TimeRateLimiter(config["api_rate_limit"]))
+                del config["api_rate_limit"]
+            else:
+                self._rate_limiters.append(None)
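
Taken together, the client.py changes strip `api_rate_limit` out of each config entry at construction time and pause before each request via `_throttle_api_calls`. A minimal sketch of how a caller would opt in (the model name, key lookup, and 0.5 requests/second rate are illustrative placeholders, not values from this commit; the top-level `OpenAIWrapper` import is assumed from autogen's public API):

    import os

    from autogen import OpenAIWrapper

    config_list = [
        {
            "model": "gpt-4",
            "api_key": os.environ["OPENAI_API_KEY"],  # placeholder credential lookup
            "api_rate_limit": 0.5,  # at most one request every 2 seconds for this client
        }
    ]

    client = OpenAIWrapper(config_list=config_list)
    # Each create() call is preceded by _throttle_api_calls(i), which sleeps
    # only if the previous call through this client was too recent.
    response = client.create(messages=[{"role": "user", "content": "hello"}])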

autogen/oai/rate_limiters.py (new file, +36)

@@ -0,0 +1,36 @@
+import time
+from typing import Protocol
+
+
+class RateLimiter(Protocol):
+    def sleep(self, *args, **kwargs): ...
+
+
+class TimeRateLimiter:
+    """A class to implement a time-based rate limiter.
+
+    This rate limiter ensures that a certain operation does not exceed a specified frequency.
+    It can be used to limit the rate of requests sent to a server or the rate of any repeated action.
+    """
+
+    def __init__(self, rate: float):
+        """
+        Args:
+            rate (float): The frequency of the time-based rate limiter (NOT the time interval).
+        """
+        self._time_interval_seconds = 1.0 / rate
+        self._last_time_called = 0.0
+
+    def sleep(self, *args, **kwargs):
+        """Synchronously waits until enough time has passed to allow the next operation.
+
+        If the elapsed time since the last operation is less than the required time interval,
+        this method will block the execution by sleeping for the remaining time.
+        """
+        if self._elapsed_time() < self._time_interval_seconds:
+            time.sleep(self._time_interval_seconds - self._elapsed_time())
+
+        self._last_time_called = time.perf_counter()
+
+    def _elapsed_time(self):
+        return time.perf_counter() - self._last_time_called
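
Since `TimeRateLimiter` depends only on the standard library, it can be exercised on its own. A minimal usage sketch (the 2 requests/second rate is arbitrary):

    import time

    from autogen.oai.rate_limiters import TimeRateLimiter

    limiter = TimeRateLimiter(rate=2)  # allow at most 2 operations per second

    start = time.perf_counter()
    for _ in range(4):
        limiter.sleep()  # blocks until 0.5 s has passed since the previous call
    # Expect roughly 1.5 s of enforced waiting across the three later iterations.
    print(f"elapsed: {time.perf_counter() - start:.2f} s")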

test/oai/test_client.py (+64 −2)

@@ -4,6 +4,7 @@
 import shutil
 import sys
 import time
+from types import SimpleNamespace
 
 import pytest
 
@@ -31,6 +32,40 @@
 OAI_CONFIG_LIST = "OAI_CONFIG_LIST"
 
 
+class _MockClient:
+    def __init__(self, config, **kwargs):
+        pass
+
+    def create(self, params):
+        # We could define our own response data class here; SimpleNamespace
+        # is used for simplicity, as long as the result adheres to the
+        # ModelClientResponseProtocol.
+
+        response = SimpleNamespace()
+        response.choices = []
+        response.model = "mock_model"
+
+        text = "this is a dummy text response"
+        choice = SimpleNamespace()
+        choice.message = SimpleNamespace()
+        choice.message.content = text
+        choice.message.function_call = None
+        response.choices.append(choice)
+        return response
+
+    def message_retrieval(self, response):
+        choices = response.choices
+        return [choice.message.content for choice in choices]
+
+    def cost(self, response) -> float:
+        response.cost = 0
+        return 0
+
+    @staticmethod
+    def get_usage(response):
+        return {}
+
+
 @pytest.mark.skipif(skip, reason="openai>=1 not installed")
 def test_aoai_chat_completion():
     config_list = config_list_from_json(
@@ -322,12 +357,39 @@ def test_cache():
     assert not os.path.exists(os.path.join(cache_dir, str(LEGACY_DEFAULT_CACHE_SEED)))
 
 
+def test_throttled_api_calls():
+    # API calls limited to 0.2 requests per second, i.e. 1 request per 5 seconds
+    rate = 1 / 5.0
+
+    config_list = [
+        {
+            "model": "mock_model",
+            "model_client_cls": "_MockClient",
+            # Adding a timeout to catch false positives
+            "timeout": 1 / rate,
+            "api_rate_limit": rate,
+        }
+    ]
+
+    client = OpenAIWrapper(config_list=config_list, cache_seed=None)
+    client.register_model_client(_MockClient)
+
+    n_loops = 2
+    current_time = time.time()
+    for _ in range(n_loops):
+        client.create(messages=[{"role": "user", "content": "hello"}])
+
+    min_expected_time = (n_loops - 1) / rate
+    assert time.time() - current_time > min_expected_time
+
+
 if __name__ == "__main__":
     # test_aoai_chat_completion()
     # test_oai_tool_calling_extraction()
     # test_chat_completion()
     test_completion()
     # # test_cost()
     # test_usage_summary()
-    # test_legacy_cache()
-    # test_cache()
+    test_legacy_cache()
+    test_cache()
+    test_throttled_api_calls()
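
For the arithmetic behind the assertion: with rate = 0.2 requests per second and n_loops = 2, only the second create() has to wait out a full window, so the elapsed time must exceed (2 - 1) / 0.2 = 5 seconds. The per-request timeout of 1 / rate = 5 seconds is presumably there so a hung request fails fast instead of making the elapsed-time check pass for the wrong reason.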

test/oai/test_rate_limiters.py (new file, +21)

@@ -0,0 +1,21 @@
+import time
+
+import pytest
+
+from autogen.oai.rate_limiters import TimeRateLimiter
+
+
+@pytest.mark.parametrize("execute_n_times", range(5))
+def test_time_rate_limiter(execute_n_times):
+    current_time_seconds = time.time()
+
+    rate = 1
+    rate_limiter = TimeRateLimiter(rate)
+
+    n_loops = 2
+    for _ in range(n_loops):
+        rate_limiter.sleep()
+
+    total_time = time.time() - current_time_seconds
+    min_expected_time = (n_loops - 1) / rate
+    assert total_time >= min_expected_time

website/docs/FAQ.mdx (+9 −1)

@@ -37,7 +37,15 @@ Yes. You currently have two options:
 - Autogen can work with any API endpoint which complies with OpenAI-compatible RESTful APIs - e.g. serving local LLM via FastChat or LM Studio. Please check https://microsoft.github.io/autogen/blog/2023/07/14/Local-LLMs for an example.
 - You can supply your own custom model implementation and use it with Autogen. Please check https://microsoft.github.io/autogen/blog/2024/01/26/Custom-Models for more information.
 
-## Handle Rate Limit Error and Timeout Error
+## Handling API Rate Limits
+
+### Setting the API Rate Limit
+
+You can set the `api_rate_limit` in a `config_list` for an agent, which will be used to control the rate at which API requests are sent.
+
+- `api_rate_limit` (float): the maximum number of API requests allowed per second.
+
+### Handle Rate Limit Error and Timeout Error
 
 You can set `max_retries` to handle rate limit error. And you can set `timeout` to handle timeout error. They can all be specified in `llm_config` for an agent, which will be used in the OpenAI client for LLM inference. They can be set differently for different clients if they are set in the `config_list`.
website/docs/topics/llm_configuration.ipynb (+4)

@@ -63,6 +63,7 @@
     " <TabItem value=\"openai\" label=\"OpenAI\" default>\n",
     " - `model` (str, required): The identifier of the model to be used, such as 'gpt-4', 'gpt-3.5-turbo'.\n",
     " - `api_key` (str, optional): The API key required for authenticating requests to the model's API endpoint.\n",
+    " - `api_rate_limit` (float, optional): Specifies the maximum number of API requests permitted per second.\n",
     " - `base_url` (str, optional): The base URL of the API endpoint. This is the root address where API calls are directed.\n",
     " - `tags` (List[str], optional): Tags which can be used for filtering.\n",
     "\n",
@@ -72,6 +73,7 @@
     " {\n",
     " \"model\": \"gpt-4\",\n",
     " \"api_key\": os.environ['OPENAI_API_KEY']\n",
+    " \"api_rate_limit\": 60.0, // Set to allow up to 60 API requests per second.\n",
     " }\n",
     " ]\n",
     " ```\n",
@@ -80,6 +82,7 @@
     " - `model` (str, required): The deployment to be used. The model corresponds to the deployment name on Azure OpenAI.\n",
     " - `api_key` (str, optional): The API key required for authenticating requests to the model's API endpoint.\n",
     " - `api_type`: `azure`\n",
+    " - `api_rate_limit` (float, optional): Specifies the maximum number of API requests permitted per second.\n",
     " - `base_url` (str, optional): The base URL of the API endpoint. This is the root address where API calls are directed.\n",
     " - `api_version` (str, optional): The version of the Azure API you wish to use.\n",
     " - `tags` (List[str], optional): Tags which can be used for filtering.\n",
@@ -100,6 +103,7 @@
     " <TabItem value=\"other\" label=\"Other OpenAI compatible\">\n",
     " - `model` (str, required): The identifier of the model to be used, such as 'llama-7B'.\n",
     " - `api_key` (str, optional): The API key required for authenticating requests to the model's API endpoint.\n",
+    " - `api_rate_limit` (float, optional): Specifies the maximum number of API requests permitted per second.\n",
     " - `base_url` (str, optional): The base URL of the API endpoint. This is the root address where API calls are directed.\n",
     " - `tags` (List[str], optional): Tags which can be used for filtering.\n",
     "\n",
