diff --git a/autogen/oai/client.py b/autogen/oai/client.py
index 481e55728c54..515226861f3e 100644
--- a/autogen/oai/client.py
+++ b/autogen/oai/client.py
@@ -374,6 +374,30 @@ def get_usage(response: Union[ChatCompletion, Completion]) -> Dict:
     }
 
 
+class OpenAI_O1(OpenAIClient):
+    """OpenAI client for the o1 model family.
+
+    The o1 models reject messages with the "system" role, so such
+    messages are rewritten to the "assistant" role before the request
+    is forwarded to the regular OpenAI client.
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(OpenAI(**kwargs))
+
+    def create(self, params: Dict[str, Any]) -> ChatCompletion:
+        # Copy the messages so the caller's list is not mutated, and
+        # downgrade any "system" message to "assistant" to avoid API errors.
+        params = {
+            **params,
+            "messages": [
+                {**m, "role": "assistant"} if m["role"] == "system" else m
+                for m in params["messages"]
+            ],
+        }
+        return super().create(params)
+
+
 class OpenAIWrapper:
     """A wrapper class for openai client."""
 
@@ -532,6 +556,9 @@ def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[s
                 raise ImportError("Please install `anthropic` to use Anthropic API.")
             client = AnthropicClient(**openai_config)
             self._clients.append(client)
+        elif api_type is not None and api_type.startswith("openai-o1"):
+            client = OpenAI_O1(**openai_config)
+            self._clients.append(client)
         elif api_type is not None and api_type.startswith("mistral"):
             if mistral_import_exception:
                 raise ImportError("Please install `mistralai` to use the Mistral.AI API.")
diff --git a/autogen/oai/openai_utils.py b/autogen/oai/openai_utils.py
index ceb7ef90c933..dccee47f3275 100644
--- a/autogen/oai/openai_utils.py
+++ b/autogen/oai/openai_utils.py
@@ -30,6 +30,12 @@
     "gpt-4o": (0.005, 0.015),
     "gpt-4o-2024-05-13": (0.005, 0.015),
     "gpt-4o-2024-08-06": (0.0025, 0.01),
+    # o1 models
+    "o1-preview": (0.015, 0.060),
+    "o1-preview-2024-09-12": (0.015, 0.060),
+    # o1-mini models
+    "o1-mini": (0.003, 0.012),
+    "o1-mini-2024-09-12": (0.003, 0.012),
     # gpt-4-turbo
     "gpt-4-turbo-2024-04-09": (0.01, 0.03),
     # gpt-4
diff --git a/autogen/token_count_utils.py b/autogen/token_count_utils.py
index 8552a8f16536..41290e741bde 100644
--- a/autogen/token_count_utils.py
+++ b/autogen/token_count_utils.py
@@ -39,6 +39,12 @@ def get_max_token_limit(model: str = "gpt-3.5-turbo-0613") -> int:
         "gpt-4o-2024-08-06": 128000,
         "gpt-4o-mini": 128000,
         "gpt-4o-mini-2024-07-18": 128000,
+        # o1 models
+        "o1-preview": 128000,
+        "o1-preview-2024-09-12": 128000,
+        # o1-mini models
+        "o1-mini": 128000,
+        "o1-mini-2024-09-12": 128000,
     }
     return max_token_limit[model]
diff --git a/test/twoagent-o1.py b/test/twoagent-o1.py
new file mode 100644
index 000000000000..538c3398f66a
--- /dev/null
+++ b/test/twoagent-o1.py
@@ -0,0 +1,8 @@
+from autogen import AssistantAgent, UserProxyAgent
+
+config_list = [{"api_type": "openai-o1", "model": "o1-mini"}]
+assistant = AssistantAgent("assistant", llm_config={"config_list": config_list})
+user_proxy = UserProxyAgent(
+    "user_proxy", code_execution_config={"work_dir": "coding", "use_docker": False}
+)  # IMPORTANT: set to True to run code in docker, recommended
+user_proxy.initiate_chat(assistant, message="Save a chart of NVDA and TESLA stock price change YTD.")