-
Notifications
You must be signed in to change notification settings - Fork 6.5k
/
Copy path test_tool_agent.py
164 lines (140 loc) · 5.68 KB
/
test_tool_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import asyncio
import json
from typing import Any, AsyncGenerator, List, Mapping, Optional, Sequence, Union
import pytest
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core.base import AgentId, CancellationToken
from autogen_core.components import FunctionCall
from autogen_core.components.models import (
AssistantMessage,
ChatCompletionClient,
CreateResult,
FunctionExecutionResult,
FunctionExecutionResultMessage,
LLMMessage,
ModelCapabilities,
RequestUsage,
UserMessage,
)
from autogen_core.components.tool_agent import (
InvalidToolArgumentsException,
ToolAgent,
ToolExecutionException,
ToolNotFoundException,
tool_agent_caller_loop,
)
from autogen_core.components.tools import FunctionTool, Tool, ToolSchema
def _pass_function(input: str) -> str:
    """Tool body that ignores its argument and always reports success.

    The parameter keeps the name ``input`` (even though it shadows the
    builtin) because the tests send tool arguments as ``{"input": ...}``,
    which presumably map onto this parameter by name via FunctionTool.
    """
    outcome = "pass"
    return outcome
def _raise_function(input: str) -> str:
    """Tool body that always fails with a plain ``Exception("raise")``.

    Used to check that an error raised inside a tool is reported back to
    the caller rather than returned as a normal result.
    """
    failure = Exception("raise")
    raise failure
async def _async_sleep_function(input: str) -> str:
    """Async tool body that blocks for a long time, then reports success.

    The 10-second sleep is long enough that a test can cancel the call
    while it is pending; the argument is ignored.
    """
    delay_seconds = 10
    await asyncio.sleep(delay_seconds)
    return "pass"
@pytest.mark.asyncio
async def test_tool_agent() -> None:
    """Exercise ToolAgent's success, error and cancellation paths.

    Registers a ToolAgent exposing three tools ("pass", "raise", "sleep") on a
    SingleThreadedAgentRuntime and sends FunctionCall messages to it directly.
    Each call uses its own id so a failure points at a single scenario.
    """
    runtime = SingleThreadedAgentRuntime()
    await ToolAgent.register(
        runtime,
        "tool_agent",
        lambda: ToolAgent(
            description="Tool agent",
            tools=[
                FunctionTool(_pass_function, name="pass", description="Pass function"),
                FunctionTool(_raise_function, name="raise", description="Raise function"),
                FunctionTool(_async_sleep_function, name="sleep", description="Sleep function"),
            ],
        ),
    )
    agent = AgentId("tool_agent", "default")
    runtime.start()
    # Stop the runtime even when an assertion below fails, so a failing test
    # does not leave the runtime running.
    try:
        # A successful call comes back as a FunctionExecutionResult carrying the call id.
        result = await runtime.send_message(
            FunctionCall(id="1", arguments=json.dumps({"input": "pass"}), name="pass"), agent
        )
        assert result == FunctionExecutionResult(call_id="1", content="pass")

        # An exception raised inside the tool surfaces as ToolExecutionException.
        with pytest.raises(ToolExecutionException):
            await runtime.send_message(
                FunctionCall(id="2", arguments=json.dumps({"input": "raise"}), name="raise"), agent
            )

        # A call naming an unregistered tool surfaces as ToolNotFoundException.
        with pytest.raises(ToolNotFoundException):
            await runtime.send_message(
                FunctionCall(id="3", arguments=json.dumps({"input": "pass"}), name="invalid"), agent
            )

        # Arguments that are not valid JSON surface as InvalidToolArgumentsException.
        with pytest.raises(InvalidToolArgumentsException):
            await runtime.send_message(FunctionCall(id="4", arguments="invalid json /xd", name="pass"), agent)

        # Cancelling the token makes the pending send raise CancelledError.
        # NOTE(review): the send is only awaited after token.cancel(), so this most
        # likely exercises an already-cancelled token rather than interrupting the
        # 10 s sleep mid-flight -- confirm against SingleThreadedAgentRuntime.send_message.
        token = CancellationToken()
        result_future = runtime.send_message(
            FunctionCall(id="5", arguments=json.dumps({"input": "sleep"}), name="sleep"),
            agent,
            cancellation_token=token,
        )
        token.cancel()
        with pytest.raises(asyncio.CancelledError):
            await result_future
    finally:
        await runtime.stop()
@pytest.mark.asyncio
async def test_caller_loop() -> None:
    """Drive tool_agent_caller_loop end to end with a scripted model client.

    The mock client asks for the "pass" tool when it sees only the initial user
    message and answers "Done" otherwise. The loop is therefore expected to
    return exactly three messages:
      1. the tool-call AssistantMessage;
      2. the FunctionExecutionResultMessage;
      3. the final AssistantMessage.
    """

    class MockChatCompletionClient(ChatCompletionClient):
        """Scripted ChatCompletionClient: one tool call, then a plain-text answer."""

        # The list/dict defaults below are never mutated here, so sharing them is harmless.
        async def create(
            self,
            messages: Sequence[LLMMessage],
            tools: Sequence[Tool | ToolSchema] = [],
            json_output: Optional[bool] = None,
            extra_create_args: Mapping[str, Any] = {},
            cancellation_token: Optional[CancellationToken] = None,
        ) -> CreateResult:
            # Only the user's message is present on the first turn: request a tool call.
            if len(messages) == 1:
                return CreateResult(
                    content=[FunctionCall(id="1", name="pass", arguments=json.dumps({"input": "test"}))],
                    # This response is a tool call, so report it as such rather than "stop".
                    finish_reason="function_calls",
                    usage=RequestUsage(prompt_tokens=0, completion_tokens=0),
                    cached=False,
                    logprobs=None,
                )
            # Any longer history is treated as the follow-up turn: finish with text.
            return CreateResult(
                content="Done",
                finish_reason="stop",
                usage=RequestUsage(prompt_tokens=0, completion_tokens=0),
                cached=False,
                logprobs=None,
            )

        def create_stream(
            self,
            messages: Sequence[LLMMessage],
            tools: Sequence[Tool | ToolSchema] = [],
            json_output: Optional[bool] = None,
            extra_create_args: Mapping[str, Any] = {},
            cancellation_token: Optional[CancellationToken] = None,
        ) -> AsyncGenerator[Union[str, CreateResult], None]:
            # Streaming is not exercised by this test.
            raise NotImplementedError()

        def actual_usage(self) -> RequestUsage:
            return RequestUsage(prompt_tokens=0, completion_tokens=0)

        def total_usage(self) -> RequestUsage:
            return RequestUsage(prompt_tokens=0, completion_tokens=0)

        def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
            return 0

        def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
            return 0

        @property
        def capabilities(self) -> ModelCapabilities:
            return ModelCapabilities(vision=False, function_calling=True, json_output=False)

    client = MockChatCompletionClient()
    tools: List[Tool] = [FunctionTool(_pass_function, name="pass", description="Pass function")]
    runtime = SingleThreadedAgentRuntime()
    await ToolAgent.register(
        runtime,
        "tool_agent",
        lambda: ToolAgent(
            description="Tool agent",
            tools=tools,
        ),
    )
    agent = AgentId("tool_agent", "default")
    runtime.start()
    # Stop the runtime even if an assertion fails.
    try:
        messages = await tool_agent_caller_loop(
            runtime, agent, client, [UserMessage(content="Hello", source="user")], tool_schema=tools
        )
        assert len(messages) == 3
        assert isinstance(messages[0], AssistantMessage)
        assert isinstance(messages[1], FunctionExecutionResultMessage)
        # The "pass" tool always returns "pass", echoed back under the model's call id.
        assert messages[1].content == [FunctionExecutionResult(call_id="1", content="pass")]
        assert isinstance(messages[2], AssistantMessage)
        assert messages[2].content == "Done"
    finally:
        await runtime.stop()