diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
index 131e7288a65..c743debf992 100644
--- a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
+++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
@@ -571,14 +571,24 @@ async def create(
                     stacklevel=2,
                 )
             # NOTE: If OAI response type changes, this will need to be updated
-            content = [
-                FunctionCall(
-                    id=x.id,
-                    arguments=x.function.arguments,
-                    name=normalize_name(x.function.name),
+            content = []
+            for tool_call in choice.message.tool_calls:
+                if not isinstance(tool_call.function.arguments, str):
+                    warnings.warn(
+                        f"Tool call function arguments field is not a string: {tool_call.function.arguments}. "
+                        "This is unexpected and may be due to the API used not returning the correct type. "
+                        "Attempting to convert it to string.",
+                        stacklevel=2,
+                    )
+                    if isinstance(tool_call.function.arguments, dict):
+                        tool_call.function.arguments = json.dumps(tool_call.function.arguments)
+                content.append(
+                    FunctionCall(
+                        id=tool_call.id,
+                        arguments=tool_call.function.arguments,
+                        name=normalize_name(tool_call.function.name),
+                    )
                 )
-                for x in choice.message.tool_calls
-            ]
             finish_reason = "tool_calls"
         else:
             finish_reason = choice.finish_reason
diff --git a/python/packages/autogen-ext/tests/models/test_openai_model_client.py b/python/packages/autogen-ext/tests/models/test_openai_model_client.py
index 0e568d45a93..ab67917998a 100644
--- a/python/packages/autogen-ext/tests/models/test_openai_model_client.py
+++ b/python/packages/autogen-ext/tests/models/test_openai_model_client.py
@@ -619,6 +619,31 @@ async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None:
             object="chat.completion",
             usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
         ),
+        # Should raise a warning when the function arguments field is not a string.
+        ChatCompletion(
+            id="id6",
+            choices=[
+                Choice(
+                    finish_reason="tool_calls",
+                    index=0,
+                    message=ChatCompletionMessage(
+                        content=None,
+                        tool_calls=[
+                            ChatCompletionMessageToolCall(
+                                id="1",
+                                type="function",
+                                function=Function.construct(name="_pass_function", arguments={"input": "task"}),  # type: ignore
+                            )
+                        ],
+                        role="assistant",
+                    ),
+                )
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
+        ),
     ]
     mock = _MockChatCompletion(chat_completions)
     monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
@@ -676,8 +701,16 @@ async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None:
     assert create_result.content == "I should make a tool call."
     assert create_result.finish_reason == "stop"
 
+    # Should raise a warning when the function arguments field is not a string.
+    with pytest.warns(UserWarning, match="Tool call function arguments field is not a string"):
+        create_result = await model_client.create(
+            messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool]
+        )
+    assert create_result.content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")]
+    assert create_result.finish_reason == "function_calls"
+
 
-async def _test_model_client(model_client: OpenAIChatCompletionClient) -> None:
+async def _test_model_client_basic_completion(model_client: OpenAIChatCompletionClient) -> None:
     # Test basic completion
     create_result = await model_client.create(
         messages=[
@@ -688,6 +721,8 @@ async def _test_model_client(model_client: OpenAIChatCompletionClient) -> None:
     assert isinstance(create_result.content, str)
     assert len(create_result.content) > 0
 
+
+async def _test_model_client_with_function_calling(model_client: OpenAIChatCompletionClient) -> None:
     # Test tool calling
     pass_tool = FunctionTool(_pass_function, name="pass_tool", description="pass session.")
     fail_tool = FunctionTool(_fail_function, name="fail_tool", description="fail session.")
@@ -755,7 +790,8 @@ async def test_openai() -> None:
         model="gpt-4o-mini",
         api_key=api_key,
     )
-    await _test_model_client(model_client)
+    await _test_model_client_basic_completion(model_client)
+    await _test_model_client_with_function_calling(model_client)
 
 
 @pytest.mark.asyncio
@@ -775,7 +811,29 @@ async def test_gemini() -> None:
             "family": ModelFamily.UNKNOWN,
         },
     )
-    await _test_model_client(model_client)
+    await _test_model_client_basic_completion(model_client)
+    await _test_model_client_with_function_calling(model_client)
+
+
+@pytest.mark.asyncio
+async def test_hugging_face() -> None:
+    api_key = os.getenv("HF_TOKEN")
+    if not api_key:
+        pytest.skip("HF_TOKEN not found in environment variables")
+
+    model_client = OpenAIChatCompletionClient(
+        model="microsoft/Phi-3.5-mini-instruct",
+        api_key=api_key,
+        base_url="https://api-inference.huggingface.co/v1/",
+        model_info={
+            "function_calling": False,
+            "json_output": False,
+            "vision": False,
+            "family": ModelFamily.UNKNOWN,
+        },
+    )
+
+    await _test_model_client_basic_completion(model_client)
 
 
 # TODO: add integration tests for Azure OpenAI using AAD token.