-
Notifications
You must be signed in to change notification settings - Fork 6.4k
/
Copy pathtest_runtime.py
228 lines (169 loc) · 8.21 KB
/
test_runtime.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
import logging
import pytest
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core.base import (
AgentId,
AgentInstantiationContext,
AgentType,
TopicId,
try_get_known_serializers_for_type,
)
from autogen_core.components import DefaultTopicId, TypeSubscription, type_subscription
from opentelemetry.sdk.trace import TracerProvider
from test_utils import (
CascadingAgent,
CascadingMessageType,
LoopbackAgent,
LoopbackAgentWithDefaultSubscription,
MessageType,
NoopAgent,
)
from test_utils.telemetry_test_utils import TestExporter, get_test_tracer_provider
# Single in-memory span exporter for this module; the fixture below clears it
# before each use, and tests read captured spans back from it.
test_exporter = TestExporter()


@pytest.fixture
def tracer_provider() -> TracerProvider:
    """Return a test tracer provider bound to a freshly cleared ``test_exporter``."""
    test_exporter.clear()
    provider = get_test_tracer_provider(test_exporter)
    return provider
@pytest.mark.asyncio
async def test_agent_type_must_be_unique() -> None:
    """Registering an agent type name twice raises ``ValueError``.

    After the duplicate registration is rejected, the same factory can still
    be registered under a different, unused type name.
    """
    runtime = SingleThreadedAgentRuntime()

    def agent_factory() -> NoopAgent:
        # The id of the agent being built is available through the
        # instantiation context, and the constructed agent must adopt it.
        # Named ``agent_id`` (not ``id``) to avoid shadowing the builtin.
        agent_id = AgentInstantiationContext.current_agent_id()
        # NOTE(review): this pins the "name1" type, so the factory would fail
        # if the "name2" registration below were ever instantiated.
        assert agent_id == AgentId("name1", "default")
        agent = NoopAgent()
        assert agent.id == agent_id
        return agent

    await NoopAgent.register(runtime, "name1", agent_factory)

    # Re-registering an already-registered type name must be rejected.
    with pytest.raises(ValueError):
        await runtime.register_factory(type=AgentType("name1"), agent_factory=agent_factory, expected_class=NoopAgent)

    # A fresh type name is accepted.
    await runtime.register_factory(type=AgentType("name2"), agent_factory=agent_factory, expected_class=NoopAgent)
@pytest.mark.asyncio
async def test_register_receives_publish(tracer_provider: TracerProvider) -> None:
    """A publish to ``TopicId("default", "default")`` reaches only the "default"-keyed agent.

    Also checks the exact spans emitted through the traced runtime.
    """
    runtime = SingleThreadedAgentRuntime(tracer_provider=tracer_provider)
    runtime.add_message_serializer(try_get_known_serializers_for_type(MessageType))
    await runtime.register_factory(
        type=AgentType("name"), agent_factory=lambda: LoopbackAgent(), expected_class=LoopbackAgent
    )
    await runtime.add_subscription(TypeSubscription("default", "name"))

    runtime.start()
    await runtime.publish_message(MessageType(), topic_id=TopicId("default", "default"))
    await runtime.stop_when_idle()

    # The agent keyed "default" handled exactly one message.
    receiver = await runtime.try_get_underlying_agent_instance(AgentId("name", "default"), type=LoopbackAgent)
    assert receiver.num_calls == 1

    # An agent of the same type under a different key saw nothing.
    bystander: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
        AgentId("name", key="other"), type=LoopbackAgent
    )
    assert bystander.num_calls == 0

    spans = test_exporter.get_exported_spans()
    assert len(spans) == 3
    expected_names = [
        "autogen create default.(default)-T",
        "autogen process name.(default)-A",
        "autogen publish default.(default)-T",
    ]
    assert [span.name for span in spans] == expected_names
@pytest.mark.asyncio
async def test_register_receives_publish_with_exception(caplog: pytest.LogCaptureFixture) -> None:
    """A failing agent factory is logged at ERROR level during publish delivery.

    The factory raises ``ValueError``. The test expects publish/stop_when_idle
    to complete (no ``pytest.raises``) and an error log to be recorded.
    """
    runtime = SingleThreadedAgentRuntime()
    runtime.add_message_serializer(try_get_known_serializers_for_type(MessageType))

    async def agent_factory() -> LoopbackAgent:
        # Always fail, so instantiating the subscriber errors out.
        raise ValueError("test")

    await runtime.register_factory(type=AgentType("name"), agent_factory=agent_factory, expected_class=LoopbackAgent)
    await runtime.add_subscription(TypeSubscription("default", "name"))

    with caplog.at_level(logging.ERROR):
        runtime.start()
        await runtime.publish_message(MessageType(), topic_id=TopicId("default", "default"))
        await runtime.stop_when_idle()

    # Use getMessage(): LogRecord.message is only set once a Formatter has
    # processed the record, whereas getMessage() always renders msg % args.
    assert any("Error processing publish message" in record.getMessage() for record in caplog.records)
@pytest.mark.asyncio
async def test_register_receives_publish_cascade() -> None:
    """Each cascading agent's call count matches the expected fan-out total."""
    num_agents = 5
    num_initial_messages = 5
    max_rounds = 5
    # Sum over rounds 0..max_rounds-1 of num_initial_messages * (num_agents - 1) ** round.
    # NOTE(review): presumably each message is re-published by every other agent
    # per round — confirm against CascadingAgent.
    total_num_calls_expected = sum(
        num_initial_messages * (num_agents - 1) ** round_index for round_index in range(max_rounds)
    )

    runtime = SingleThreadedAgentRuntime()

    # One CascadingAgent type per index, all capped at max_rounds.
    for agent_index in range(num_agents):
        await CascadingAgent.register(runtime, f"name{agent_index}", lambda: CascadingAgent(max_rounds))

    runtime.start()

    # Seed the cascade with round-1 messages on the default topic.
    for _ in range(num_initial_messages):
        await runtime.publish_message(CascadingMessageType(round=1), DefaultTopicId())

    # Drain all follow-up publishes.
    await runtime.stop_when_idle()

    for agent_index in range(num_agents):
        cascading = await runtime.try_get_underlying_agent_instance(
            AgentId(f"name{agent_index}", "default"), CascadingAgent
        )
        assert cascading.num_calls == total_num_calls_expected
@pytest.mark.asyncio
async def test_register_factory_explicit_name() -> None:
    """Registering via ``LoopbackAgent.register`` under type "name" delivers topic publishes."""
    runtime = SingleThreadedAgentRuntime()
    await LoopbackAgent.register(runtime, "name", LoopbackAgent)
    await runtime.add_subscription(TypeSubscription("default", "name"))
    runtime.start()

    await runtime.publish_message(MessageType(), topic_id=TopicId("default", "default"))
    await runtime.stop_when_idle()

    # The "default"-keyed instance handled the single publish.
    target = await runtime.try_get_underlying_agent_instance(AgentId("name", key="default"), type=LoopbackAgent)
    assert target.num_calls == 1

    # An instance under a different key did not.
    bystander: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
        AgentId("name", key="other"), type=LoopbackAgent
    )
    assert bystander.num_calls == 0
@pytest.mark.asyncio
async def test_default_subscription() -> None:
    """A publish to ``DefaultTopicId()`` reaches the "default"-keyed agent only."""
    runtime = SingleThreadedAgentRuntime()
    runtime.start()
    await LoopbackAgentWithDefaultSubscription.register(runtime, "name", LoopbackAgentWithDefaultSubscription)

    await runtime.publish_message(MessageType(), topic_id=DefaultTopicId())
    await runtime.stop_when_idle()

    subscribed = await runtime.try_get_underlying_agent_instance(
        AgentId("name", key="default"), type=LoopbackAgentWithDefaultSubscription
    )
    assert subscribed.num_calls == 1

    unrelated = await runtime.try_get_underlying_agent_instance(
        AgentId("name", key="other"), type=LoopbackAgentWithDefaultSubscription
    )
    assert unrelated.num_calls == 0
@pytest.mark.asyncio
async def test_type_subscription() -> None:
    """``@type_subscription(topic_type="Other")`` routes "Other" publishes to the matching key."""
    runtime = SingleThreadedAgentRuntime()
    runtime.start()

    # Local subclass whose only difference is the class-level subscription.
    @type_subscription(topic_type="Other")
    class LoopbackAgentWithSubscription(LoopbackAgent): ...

    await LoopbackAgentWithSubscription.register(runtime, "name", LoopbackAgentWithSubscription)

    await runtime.publish_message(MessageType(), topic_id=TopicId("Other", "default"))
    await runtime.stop_when_idle()

    subscribed = await runtime.try_get_underlying_agent_instance(
        AgentId("name", key="default"), type=LoopbackAgentWithSubscription
    )
    assert subscribed.num_calls == 1

    unrelated = await runtime.try_get_underlying_agent_instance(
        AgentId("name", key="other"), type=LoopbackAgentWithSubscription
    )
    assert unrelated.num_calls == 0
@pytest.mark.asyncio
async def test_default_subscription_publish_to_other_source() -> None:
    """A publish to ``DefaultTopicId(source="other")`` reaches the "other"-keyed agent only."""
    runtime = SingleThreadedAgentRuntime()
    runtime.start()
    await LoopbackAgentWithDefaultSubscription.register(runtime, "name", LoopbackAgentWithDefaultSubscription)

    await runtime.publish_message(MessageType(), topic_id=DefaultTopicId(source="other"))
    await runtime.stop_when_idle()

    # The "default"-keyed agent is not addressed by this source.
    default_keyed = await runtime.try_get_underlying_agent_instance(
        AgentId("name", key="default"), type=LoopbackAgentWithDefaultSubscription
    )
    assert default_keyed.num_calls == 0

    # The agent keyed by the publish source received it.
    other_keyed = await runtime.try_get_underlying_agent_instance(
        AgentId("name", key="other"), type=LoopbackAgentWithDefaultSubscription
    )
    assert other_keyed.num_calls == 1