diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py
index e477ed3f1245..7b281e39f6f5 100644
--- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py
+++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py
@@ -17,10 +17,10 @@
     ToolCallRequestEvent,
     ToolCallSummaryMessage,
 )
-from autogen_core import Image
+from autogen_core import FunctionCall, Image
 from autogen_core.memory import ListMemory, Memory, MemoryContent, MemoryMimeType, MemoryQueryResult
 from autogen_core.model_context import BufferedChatCompletionContext
-from autogen_core.models import LLMMessage
+from autogen_core.models import FunctionExecutionResult, LLMMessage
 from autogen_core.models._model_client import ModelFamily
 from autogen_core.tools import FunctionTool
 from autogen_ext.models.openai import OpenAIChatCompletionClient
@@ -281,6 +281,142 @@ async def test_run_with_tools_and_reflection(monkeypatch: pytest.MonkeyPatch) ->
     assert state == state2
 
 
+@pytest.mark.asyncio
+async def test_run_with_parallel_tools(monkeypatch: pytest.MonkeyPatch) -> None:
+    model = "gpt-4o-2024-05-13"
+    chat_completions = [
+        ChatCompletion(
+            id="id1",
+            choices=[
+                Choice(
+                    finish_reason="tool_calls",
+                    index=0,
+                    message=ChatCompletionMessage(
+                        content=None,
+                        tool_calls=[
+                            ChatCompletionMessageToolCall(
+                                id="1",
+                                type="function",
+                                function=Function(
+                                    name="_pass_function",
+                                    arguments=json.dumps({"input": "task1"}),
+                                ),
+                            ),
+                            ChatCompletionMessageToolCall(
+                                id="2",
+                                type="function",
+                                function=Function(
+                                    name="_pass_function",
+                                    arguments=json.dumps({"input": "task2"}),
+                                ),
+                            ),
+                            ChatCompletionMessageToolCall(
+                                id="3",
+                                type="function",
+                                function=Function(
+                                    name="_echo_function",
+                                    arguments=json.dumps({"input": "task3"}),
+                                ),
+                            ),
+                        ],
+                        role="assistant",
+                    ),
+                )
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
+        ),
+        ChatCompletion(
+            id="id2",
+            choices=[
+                Choice(
+                    finish_reason="stop",
+                    index=0,
+                    message=ChatCompletionMessage(content="pass", role="assistant"),
+                )
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
+        ),
+        ChatCompletion(
+            id="id2",
+            choices=[
+                Choice(
+                    finish_reason="stop",
+                    index=0,
+                    message=ChatCompletionMessage(content="TERMINATE", role="assistant"),
+                )
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
+        ),
+    ]
+    mock = _MockChatCompletion(chat_completions)
+    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
+    agent = AssistantAgent(
+        "tool_use_agent",
+        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
+        tools=[
+            _pass_function,
+            _fail_function,
+            FunctionTool(_echo_function, description="Echo"),
+        ],
+    )
+    result = await agent.run(task="task")
+
+    assert len(result.messages) == 4
+    assert isinstance(result.messages[0], TextMessage)
+    assert result.messages[0].models_usage is None
+    assert isinstance(result.messages[1], ToolCallRequestEvent)
+    assert result.messages[1].content == [
+        FunctionCall(id="1", arguments=r'{"input": "task1"}', name="_pass_function"),
+        FunctionCall(id="2", arguments=r'{"input": "task2"}', name="_pass_function"),
+        FunctionCall(id="3", arguments=r'{"input": "task3"}', name="_echo_function"),
+    ]
+    assert result.messages[1].models_usage is not None
+    assert result.messages[1].models_usage.completion_tokens == 5
+    assert result.messages[1].models_usage.prompt_tokens == 10
+    assert isinstance(result.messages[2], ToolCallExecutionEvent)
+    expected_content = [
+        FunctionExecutionResult(call_id="1", content="pass"),
+        FunctionExecutionResult(call_id="2", content="pass"),
+        FunctionExecutionResult(call_id="3", content="task3"),
+    ]
+    for expected in expected_content:
+        assert expected in result.messages[2].content
+    assert result.messages[2].models_usage is None
+    assert isinstance(result.messages[3], ToolCallSummaryMessage)
+    assert result.messages[3].content == "pass\npass\ntask3"
+    assert result.messages[3].models_usage is None
+
+    # Test streaming.
+    mock.curr_index = 0  # Reset the mock
+    index = 0
+    async for message in agent.run_stream(task="task"):
+        if isinstance(message, TaskResult):
+            assert message == result
+        else:
+            assert message == result.messages[index]
+            index += 1
+
+    # Test state saving and loading.
+    state = await agent.save_state()
+    agent2 = AssistantAgent(
+        "tool_use_agent",
+        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
+        tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
+    )
+    await agent2.load_state(state)
+    state2 = await agent2.save_state()
+    assert state == state2
+
+
 @pytest.mark.asyncio
 async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
     handoff = Handoff(target="agent2")