Skip to content

Commit 8bf5bee

Browse files
iogadmontagu
andauthored
fix: anthropic parallel tool call results. (#653)
Co-authored-by: David Montague <[email protected]>
1 parent 9b4de86 commit 8bf5bee

File tree

2 files changed

+71
-49
lines changed

2 files changed

+71
-49
lines changed

pydantic_ai_slim/pydantic_ai/models/anthropic.py

+41-49
Original file line numberDiff line numberDiff line change
@@ -272,64 +272,56 @@ def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageP
272272
anthropic_messages: list[MessageParam] = []
273273
for m in messages:
274274
if isinstance(m, ModelRequest):
275-
for part in m.parts:
276-
if isinstance(part, SystemPromptPart):
277-
system_prompt += part.content
278-
elif isinstance(part, UserPromptPart):
279-
anthropic_messages.append(MessageParam(role='user', content=part.content))
280-
elif isinstance(part, ToolReturnPart):
281-
anthropic_messages.append(
282-
MessageParam(
283-
role='user',
284-
content=[
285-
ToolResultBlockParam(
286-
tool_use_id=_guard_tool_call_id(t=part, model_source='Anthropic'),
287-
type='tool_result',
288-
content=part.model_response_str(),
289-
is_error=False,
290-
)
291-
],
292-
)
275+
user_content_params: list[ToolResultBlockParam | TextBlockParam] = []
276+
for request_part in m.parts:
277+
if isinstance(request_part, SystemPromptPart):
278+
system_prompt += request_part.content
279+
elif isinstance(request_part, UserPromptPart):
280+
text_block_param = TextBlockParam(type='text', text=request_part.content)
281+
user_content_params.append(text_block_param)
282+
elif isinstance(request_part, ToolReturnPart):
283+
tool_result_block_param = ToolResultBlockParam(
284+
tool_use_id=_guard_tool_call_id(t=request_part, model_source='Anthropic'),
285+
type='tool_result',
286+
content=request_part.model_response_str(),
287+
is_error=False,
293288
)
294-
elif isinstance(part, RetryPromptPart):
295-
if part.tool_name is None:
296-
anthropic_messages.append(MessageParam(role='user', content=part.model_response()))
289+
user_content_params.append(tool_result_block_param)
290+
elif isinstance(request_part, RetryPromptPart):
291+
if request_part.tool_name is None:
292+
retry_param = TextBlockParam(type='text', text=request_part.model_response())
297293
else:
298-
anthropic_messages.append(
299-
MessageParam(
300-
role='user',
301-
content=[
302-
ToolResultBlockParam(
303-
tool_use_id=_guard_tool_call_id(t=part, model_source='Anthropic'),
304-
type='tool_result',
305-
content=part.model_response(),
306-
is_error=True,
307-
),
308-
],
309-
)
294+
retry_param = ToolResultBlockParam(
295+
tool_use_id=_guard_tool_call_id(t=request_part, model_source='Anthropic'),
296+
type='tool_result',
297+
content=request_part.model_response(),
298+
is_error=True,
310299
)
300+
user_content_params.append(retry_param)
301+
anthropic_messages.append(
302+
MessageParam(
303+
role='user',
304+
content=user_content_params,
305+
)
306+
)
311307
elif isinstance(m, ModelResponse):
312-
content: list[TextBlockParam | ToolUseBlockParam] = []
313-
for item in m.parts:
314-
if isinstance(item, TextPart):
315-
content.append(TextBlockParam(text=item.content, type='text'))
308+
assistant_content_params: list[TextBlockParam | ToolUseBlockParam] = []
309+
for response_part in m.parts:
310+
if isinstance(response_part, TextPart):
311+
assistant_content_params.append(TextBlockParam(text=response_part.content, type='text'))
316312
else:
317-
assert isinstance(item, ToolCallPart)
318-
content.append(self._map_tool_call(item))
319-
anthropic_messages.append(MessageParam(role='assistant', content=content))
313+
tool_use_block_param = ToolUseBlockParam(
314+
id=_guard_tool_call_id(t=response_part, model_source='Anthropic'),
315+
type='tool_use',
316+
name=response_part.tool_name,
317+
input=response_part.args_as_dict(),
318+
)
319+
assistant_content_params.append(tool_use_block_param)
320+
anthropic_messages.append(MessageParam(role='assistant', content=assistant_content_params))
320321
else:
321322
assert_never(m)
322323
return system_prompt, anthropic_messages
323324

324-
@staticmethod
325-
def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
326-
return ToolUseBlockParam(
327-
id=_guard_tool_call_id(t=t, model_source='Anthropic'),
328-
type='tool_use',
329-
name=t.tool_name,
330-
input=t.args_as_dict(),
331-
)
332-
333325
@staticmethod
334326
def _map_tool_definition(f: ToolDefinition) -> ToolParam:
335327
return {

tests/models/test_anthropic.py

+30
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,36 @@ async def get_location(loc_name: str) -> str:
332332
)
333333

334334

335+
@pytest.mark.skip(reason='Need to finish implementing this test') # TODO: David to finish implementing this test
336+
async def test_multiple_parallel_tool_calls(): # pragma: no cover # Need to finish the test..
337+
async def retrieve_entity_info(name: str) -> str:
338+
"""Get the knowledge about the given entity."""
339+
data = {
340+
'alice': "alice is bob's wife",
341+
'bob': "bob is alice's husband",
342+
'charlie': "charlie is alice's son",
343+
'daisy': "daisy is bob's daughter and charlie's younger sister",
344+
}
345+
return data[name.lower()]
346+
347+
system_prompt = """
348+
Use the `retrieve_entity_info` tool to get information about a specific person.
349+
If you need to use `retrieve_entity_info` to get information about multiple people, try
350+
to call them in parallel as much as possible.
351+
Think step by step and then provide a single most probable concise answer.
352+
"""
353+
354+
mock_client = MockAnthropic.create_mock([])
355+
agent = Agent(
356+
AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client),
357+
system_prompt=system_prompt,
358+
tools=[retrieve_entity_info],
359+
)
360+
361+
_result = agent.run_sync('Alice, Bob, Charlie and Daisy are a family. Who is the youngest?')
362+
# assert ...
363+
364+
335365
async def test_anthropic_specific_metadata(allow_model_requests: None) -> None:
336366
c = completion_message([TextBlock(text='world', type='text')], AnthropicUsage(input_tokens=5, output_tokens=10))
337367
mock_client = MockAnthropic.create_mock(c)

0 commit comments

Comments
 (0)