Skip to content

Commit

Permalink
fix: handle non-string function arguments in tool calls and add corre…
Browse files Browse the repository at this point in the history
…sponding warnings (#5260)
  • Loading branch information
ekzhu authored Jan 30, 2025
1 parent aa23093 commit 44db2cc
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -571,14 +571,24 @@ async def create(
stacklevel=2,
)
# NOTE: If OAI response type changes, this will need to be updated
content = [
FunctionCall(
id=x.id,
arguments=x.function.arguments,
name=normalize_name(x.function.name),
content = []
for tool_call in choice.message.tool_calls:
if not isinstance(tool_call.function.arguments, str):
warnings.warn(
f"Tool call function arguments field is not a string: {tool_call.function.arguments}."
"This is unexpected and may due to the API used not returning the correct type. "
"Attempting to convert it to string.",
stacklevel=2,
)
if isinstance(tool_call.function.arguments, dict):
tool_call.function.arguments = json.dumps(tool_call.function.arguments)
content.append(
FunctionCall(
id=tool_call.id,
arguments=tool_call.function.arguments,
name=normalize_name(tool_call.function.name),
)
)
for x in choice.message.tool_calls
]
finish_reason = "tool_calls"
else:
finish_reason = choice.finish_reason
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,31 @@ async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None:
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
# Should raise warning when function arguments is not a string.
ChatCompletion(
id="id6",
choices=[
Choice(
finish_reason="tool_calls",
index=0,
message=ChatCompletionMessage(
content=None,
tool_calls=[
ChatCompletionMessageToolCall(
id="1",
type="function",
function=Function.construct(name="_pass_function", arguments={"input": "task"}), # type: ignore
)
],
role="assistant",
),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
]
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
Expand Down Expand Up @@ -676,8 +701,16 @@ async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None:
assert create_result.content == "I should make a tool call."
assert create_result.finish_reason == "stop"

# Should raise warning when function arguments is not a string.
with pytest.warns(UserWarning, match="Tool call function arguments field is not a string"):
create_result = await model_client.create(
messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool]
)
assert create_result.content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")]
assert create_result.finish_reason == "function_calls"


async def _test_model_client(model_client: OpenAIChatCompletionClient) -> None:
async def _test_model_client_basic_completion(model_client: OpenAIChatCompletionClient) -> None:
# Test basic completion
create_result = await model_client.create(
messages=[
Expand All @@ -688,6 +721,8 @@ async def _test_model_client(model_client: OpenAIChatCompletionClient) -> None:
assert isinstance(create_result.content, str)
assert len(create_result.content) > 0


async def _test_model_client_with_function_calling(model_client: OpenAIChatCompletionClient) -> None:
# Test tool calling
pass_tool = FunctionTool(_pass_function, name="pass_tool", description="pass session.")
fail_tool = FunctionTool(_fail_function, name="fail_tool", description="fail session.")
Expand Down Expand Up @@ -755,7 +790,8 @@ async def test_openai() -> None:
model="gpt-4o-mini",
api_key=api_key,
)
await _test_model_client(model_client)
await _test_model_client_basic_completion(model_client)
await _test_model_client_with_function_calling(model_client)


@pytest.mark.asyncio
Expand All @@ -775,7 +811,29 @@ async def test_gemini() -> None:
"family": ModelFamily.UNKNOWN,
},
)
await _test_model_client(model_client)
await _test_model_client_basic_completion(model_client)
await _test_model_client_with_function_calling(model_client)


@pytest.mark.asyncio
async def test_hugging_face() -> None:
api_key = os.getenv("HF_TOKEN")
if not api_key:
pytest.skip("HF_TOKEN not found in environment variables")

model_client = OpenAIChatCompletionClient(
model="microsoft/Phi-3.5-mini-instruct",
api_key=api_key,
base_url="https://api-inference.huggingface.co/v1/",
model_info={
"function_calling": False,
"json_output": False,
"vision": False,
"family": ModelFamily.UNKNOWN,
},
)

await _test_model_client_basic_completion(model_client)


# TODO: add integration tests for Azure OpenAI using AAD token.

0 comments on commit 44db2cc

Please sign in to comment.