|
16 | 16 | from autogen.runtime_logging import log_chat_completion, log_new_client, log_new_wrapper, logging_enabled
|
17 | 17 | from autogen.token_count_utils import count_token
|
18 | 18 |
|
| 19 | +from .rate_limiters import RateLimiter, TimeRateLimiter |
| 20 | + |
19 | 21 | TOOL_ENABLED = False
|
20 | 22 | try:
|
21 | 23 | import openai
|
@@ -203,7 +205,9 @@ def create(self, params: Dict[str, Any]) -> ChatCompletion:
|
203 | 205 | """
|
204 | 206 | iostream = IOStream.get_default()
|
205 | 207 |
|
206 |
| - completions: Completions = self._oai_client.chat.completions if "messages" in params else self._oai_client.completions # type: ignore [attr-defined] |
| 208 | + completions: Completions = ( |
| 209 | + self._oai_client.chat.completions if "messages" in params else self._oai_client.completions |
| 210 | + ) # type: ignore [attr-defined] |
207 | 211 | # If streaming is enabled and has messages, then iterate over the chunks of the response.
|
208 | 212 | if params.get("stream", False) and "messages" in params:
|
209 | 213 | response_contents = [""] * params.get("n", 1)
|
@@ -423,8 +427,11 @@ def __init__(self, *, config_list: Optional[List[Dict[str, Any]]] = None, **base
|
423 | 427 |
|
424 | 428 | self._clients: List[ModelClient] = []
|
425 | 429 | self._config_list: List[Dict[str, Any]] = []
|
| 430 | + self._rate_limiters: List[Optional[RateLimiter]] = [] |
426 | 431 |
|
427 | 432 | if config_list:
|
| 433 | + self._initialize_rate_limiters(config_list) |
| 434 | + |
428 | 435 | config_list = [config.copy() for config in config_list] # make a copy before modifying
|
429 | 436 | for config in config_list:
|
430 | 437 | self._register_default_client(config, openai_config) # could modify the config
|
@@ -745,6 +752,7 @@ def yes_or_no_filter(context, response):
|
745 | 752 | return response
|
746 | 753 | continue # filter is not passed; try the next config
|
747 | 754 | try:
|
| 755 | + self._throttle_api_calls(i) |
748 | 756 | request_ts = get_current_ts()
|
749 | 757 | response = client.create(params)
|
750 | 758 | except APITimeoutError as err:
|
@@ -1038,3 +1046,20 @@ def extract_text_or_completion_object(
|
1038 | 1046 | A list of text, or a list of ChatCompletion objects if function_call/tool_calls are present.
|
1039 | 1047 | """
|
1040 | 1048 | return response.message_retrieval_function(response)
|
| 1049 | + |
| 1050 | + def _throttle_api_calls(self, idx: int) -> None: |
| 1051 | + """Rate limit api calls.""" |
| 1052 | + if self._rate_limiters[idx]: |
| 1053 | + limiter = self._rate_limiters[idx] |
| 1054 | + |
| 1055 | + assert limiter is not None |
| 1056 | + limiter.sleep() |
| 1057 | + |
| 1058 | + def _initialize_rate_limiters(self, config_list: List[Dict[str, Any]]) -> None: |
| 1059 | + for config in config_list: |
| 1060 | + # Instantiate the rate limiter |
| 1061 | + if "api_rate_limit" in config: |
| 1062 | + self._rate_limiters.append(TimeRateLimiter(config["api_rate_limit"])) |
| 1063 | + del config["api_rate_limit"] |
| 1064 | + else: |
| 1065 | + self._rate_limiters.append(None) |
0 commit comments