Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: handle non-string function arguments in tool calls and add corresponding warnings #5260

Merged
merged 3 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -571,14 +571,24 @@ async def create(
stacklevel=2,
)
# NOTE: If OAI response type changes, this will need to be updated
content = [
FunctionCall(
id=x.id,
arguments=x.function.arguments,
name=normalize_name(x.function.name),
content = []
for tool_call in choice.message.tool_calls:
if not isinstance(tool_call.function.arguments, str):
warnings.warn(
f"Tool call function arguments field is not a string: {tool_call.function.arguments}."
"This is unexpected and may due to the API used not returning the correct type. "
"Attempting to convert it to string.",
stacklevel=2,
)
if isinstance(tool_call.function.arguments, dict):
tool_call.function.arguments = json.dumps(tool_call.function.arguments)
content.append(
FunctionCall(
id=tool_call.id,
arguments=tool_call.function.arguments,
name=normalize_name(tool_call.function.name),
)
)
for x in choice.message.tool_calls
]
finish_reason = "tool_calls"
else:
finish_reason = choice.finish_reason
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,31 @@ async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None:
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
# Should raise warning when function arguments is not a string.
ChatCompletion(
id="id6",
choices=[
Choice(
finish_reason="tool_calls",
index=0,
message=ChatCompletionMessage(
content=None,
tool_calls=[
ChatCompletionMessageToolCall(
id="1",
type="function",
function=Function.construct(name="_pass_function", arguments={"input": "task"}), # type: ignore
)
],
role="assistant",
),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
]
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
Expand Down Expand Up @@ -676,8 +701,16 @@ async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None:
assert create_result.content == "I should make a tool call."
assert create_result.finish_reason == "stop"

# Should raise warning when function arguments is not a string.
with pytest.warns(UserWarning, match="Tool call function arguments field is not a string"):
create_result = await model_client.create(
messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool]
)
assert create_result.content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")]
assert create_result.finish_reason == "function_calls"


async def _test_model_client(model_client: OpenAIChatCompletionClient) -> None:
async def _test_model_client_basic_completion(model_client: OpenAIChatCompletionClient) -> None:
# Test basic completion
create_result = await model_client.create(
messages=[
Expand All @@ -688,6 +721,8 @@ async def _test_model_client(model_client: OpenAIChatCompletionClient) -> None:
assert isinstance(create_result.content, str)
assert len(create_result.content) > 0


async def _test_model_client_with_function_calling(model_client: OpenAIChatCompletionClient) -> None:
# Test tool calling
pass_tool = FunctionTool(_pass_function, name="pass_tool", description="pass session.")
fail_tool = FunctionTool(_fail_function, name="fail_tool", description="fail session.")
Expand Down Expand Up @@ -755,7 +790,8 @@ async def test_openai() -> None:
model="gpt-4o-mini",
api_key=api_key,
)
await _test_model_client(model_client)
await _test_model_client_basic_completion(model_client)
await _test_model_client_with_function_calling(model_client)


@pytest.mark.asyncio
Expand All @@ -775,7 +811,29 @@ async def test_gemini() -> None:
"family": ModelFamily.UNKNOWN,
},
)
await _test_model_client(model_client)
await _test_model_client_basic_completion(model_client)
await _test_model_client_with_function_calling(model_client)


@pytest.mark.asyncio
async def test_hugging_face() -> None:
api_key = os.getenv("HF_TOKEN")
if not api_key:
pytest.skip("HF_TOKEN not found in environment variables")

model_client = OpenAIChatCompletionClient(
model="microsoft/Phi-3.5-mini-instruct",
api_key=api_key,
base_url="https://api-inference.huggingface.co/v1/",
model_info={
"function_calling": False,
"json_output": False,
"vision": False,
"family": ModelFamily.UNKNOWN,
},
)

await _test_model_client_basic_completion(model_client)


# TODO: add integration tests for Azure OpenAI using AAD token.
Loading