-
Notifications
You must be signed in to change notification settings - Fork 6.4k
/
Copy pathtest_cancellation.py
154 lines (118 loc) · 5.21 KB
/
test_cancellation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import asyncio
from dataclasses import dataclass
import pytest
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core.base import AgentId, AgentInstantiationContext, CancellationToken, MessageContext
from autogen_core.components import RoutedAgent, message_handler
@dataclass
class MessageType:
    """Empty message payload sent to, and returned by, the test agents below."""
# Note for future readers:
# To cancel a request, interact only with its CancellationToken (call `token.cancel()`).
# Cancelling the underlying future directly may not behave as you expect.
class LongRunningAgent(RoutedAgent):
    """Agent whose only handler blocks on a long sleep until it is cancelled.

    ``called`` records that the handler was entered. ``cancelled`` records that the
    sleep was interrupted by ``asyncio.CancelledError``, which is then re-raised.
    """

    def __init__(self) -> None:
        super().__init__("A long running agent")
        # Flags the tests inspect after the runtime has delivered (or cancelled) a message.
        self.called = False
        self.cancelled = False

    @message_handler
    async def on_new_message(self, message: MessageType, ctx: MessageContext) -> MessageType:
        self.called = True
        # Run the sleep as a task linked to the request's token, so the token (not the
        # future itself) is what the tests use to interrupt it.
        pending_sleep = asyncio.create_task(asyncio.sleep(100))
        ctx.cancellation_token.link_future(pending_sleep)
        try:
            await pending_sleep
        except asyncio.CancelledError:
            self.cancelled = True
            raise
        return MessageType()
class NestingLongRunningAgent(RoutedAgent):
    """Agent that forwards each message to ``nested_agent`` and awaits the reply.

    The incoming request's cancellation token is passed along with the forwarded
    message, so cancelling the outer request can also cancel the nested one.
    """

    def __init__(self, nested_agent: AgentId) -> None:
        super().__init__("A nesting long running agent")
        self._nested_agent = nested_agent
        # Flags the tests inspect after the runtime has delivered (or cancelled) a message.
        self.called = False
        self.cancelled = False

    @message_handler
    async def on_new_message(self, message: MessageType, ctx: MessageContext) -> MessageType:
        self.called = True
        forwarded = self.send_message(message, self._nested_agent, cancellation_token=ctx.cancellation_token)
        try:
            reply = await forwarded
        except asyncio.CancelledError:
            self.cancelled = True
            raise
        assert isinstance(reply, MessageType)
        return reply
async def _wait_for_enqueued_message(
    runtime: SingleThreadedAgentRuntime,
    response: asyncio.Task[object],
    timeout: float = 5.0,
) -> None:
    """Wait until ``runtime`` has at least one unprocessed message, with a bounded wait.

    The tests start ``runtime.send_message`` as a background task. Its message only
    appears in the queue after the event loop lets that task run, so the caller has to
    poll. An unbounded poll would hang forever if the send never enqueued anything;
    this helper fails the test instead.

    Args:
        runtime: Runtime whose ``unprocessed_messages`` queue is polled.
        response: Task wrapping ``runtime.send_message``. If it finishes before anything
            is queued, its outcome is surfaced immediately rather than after ``timeout``.
        timeout: Maximum number of seconds to wait.

    Raises:
        asyncio.TimeoutError: If no message is queued within ``timeout`` seconds.
        AssertionError: If ``response`` completes normally without queuing a message.
        Exception: Whatever ``response`` raised, if it failed before queuing a message.
    """

    async def _poll() -> None:
        while not runtime.unprocessed_messages:
            if response.done():
                # The send finished without leaving anything queued; awaiting it
                # re-raises its exception (if any) instead of spinning until timeout.
                await response
                raise AssertionError("send_message completed without enqueuing a message")
            await asyncio.sleep(0.01)

    await asyncio.wait_for(_poll(), timeout=timeout)


@pytest.mark.asyncio
async def test_cancellation_with_token() -> None:
    """Cancelling the token mid-handler cancels both the handler and the caller's await."""
    runtime = SingleThreadedAgentRuntime()
    await LongRunningAgent.register(runtime, "long_running", LongRunningAgent)
    agent_id = AgentId("long_running", key="default")
    token = CancellationToken()

    response = asyncio.create_task(runtime.send_message(MessageType(), recipient=agent_id, cancellation_token=token))
    assert not response.done()

    await _wait_for_enqueued_message(runtime, response)
    # Deliver the queued message to LongRunningAgent so its handler starts sleeping.
    await runtime.process_next()

    token.cancel()
    with pytest.raises(asyncio.CancelledError):
        await response
    assert response.done()

    long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LongRunningAgent)
    assert long_running_agent.called
    assert long_running_agent.cancelled


@pytest.mark.asyncio
async def test_nested_cancellation_only_outer_called() -> None:
    """Cancelling before the inner message is processed leaves the inner agent untouched."""
    runtime = SingleThreadedAgentRuntime()
    await LongRunningAgent.register(runtime, "long_running", LongRunningAgent)
    await NestingLongRunningAgent.register(
        runtime,
        "nested",
        lambda: NestingLongRunningAgent(AgentId("long_running", key=AgentInstantiationContext.current_agent_id().key)),
    )

    long_running_id = AgentId("long_running", key="default")
    nested_id = AgentId("nested", key="default")
    token = CancellationToken()

    response = asyncio.create_task(runtime.send_message(MessageType(), nested_id, cancellation_token=token))
    assert not response.done()

    await _wait_for_enqueued_message(runtime, response)
    # Process only the outer message; the inner one is never processed before cancel.
    await runtime.process_next()

    token.cancel()
    with pytest.raises(asyncio.CancelledError):
        await response
    assert response.done()

    nested_agent = await runtime.try_get_underlying_agent_instance(nested_id, type=NestingLongRunningAgent)
    assert nested_agent.called
    assert nested_agent.cancelled

    long_running_agent = await runtime.try_get_underlying_agent_instance(long_running_id, type=LongRunningAgent)
    assert long_running_agent.called is False
    assert long_running_agent.cancelled is False


@pytest.mark.asyncio
async def test_nested_cancellation_inner_called() -> None:
    """Cancelling after the inner message is processed cancels both outer and inner agents."""
    runtime = SingleThreadedAgentRuntime()
    await LongRunningAgent.register(runtime, "long_running", LongRunningAgent)
    await NestingLongRunningAgent.register(
        runtime,
        "nested",
        lambda: NestingLongRunningAgent(AgentId("long_running", key=AgentInstantiationContext.current_agent_id().key)),
    )

    long_running_id = AgentId("long_running", key="default")
    nested_id = AgentId("nested", key="default")
    token = CancellationToken()

    response = asyncio.create_task(runtime.send_message(MessageType(), nested_id, cancellation_token=token))
    assert not response.done()

    await _wait_for_enqueued_message(runtime, response)
    await runtime.process_next()
    # allow the inner agent to process
    await runtime.process_next()

    token.cancel()
    with pytest.raises(asyncio.CancelledError):
        await response
    assert response.done()

    nested_agent = await runtime.try_get_underlying_agent_instance(nested_id, type=NestingLongRunningAgent)
    assert nested_agent.called
    assert nested_agent.cancelled

    long_running_agent = await runtime.try_get_underlying_agent_instance(long_running_id, type=LongRunningAgent)
    assert long_running_agent.called
    assert long_running_agent.cancelled