-
Notifications
You must be signed in to change notification settings - Fork 6.4k
/
Copy pathtest_intervention.py
129 lines (94 loc) · 4.76 KB
/
test_intervention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import pytest
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core.base import AgentId
from autogen_core.base.exceptions import MessageDroppedException
from autogen_core.base.intervention import DefaultInterventionHandler, DropMessage
from test_utils import LoopbackAgent, MessageType
@pytest.mark.asyncio
async def test_intervention_count_messages() -> None:
    """A single direct send is seen once by ``on_send`` and delivered once.

    The handler counts ``on_send`` invocations and passes every message
    through unchanged, so the loopback agent should also be called once.
    """

    class CountingInterventionHandler(DefaultInterventionHandler):
        def __init__(self) -> None:
            # Number of times on_send has been invoked on this handler.
            self.num_messages = 0

        async def on_send(self, message: MessageType, *, sender: AgentId | None, recipient: AgentId) -> MessageType:
            self.num_messages += 1
            # Return the message untouched so delivery proceeds normally.
            return message

    counter = CountingInterventionHandler()
    agent_runtime = SingleThreadedAgentRuntime(intervention_handlers=[counter])
    recipient_id = AgentId("name", key="default")
    await LoopbackAgent.register(agent_runtime, "name", LoopbackAgent)

    agent_runtime.start()
    await agent_runtime.send_message(MessageType(), recipient=recipient_id)
    await agent_runtime.stop()

    assert counter.num_messages == 1
    agent = await agent_runtime.try_get_underlying_agent_instance(recipient_id, type=LoopbackAgent)
    assert agent.num_calls == 1
@pytest.mark.asyncio
async def test_intervention_drop_send() -> None:
    """Returning ``DropMessage`` from ``on_send`` aborts delivery.

    The caller's ``send_message`` must raise ``MessageDroppedException`` and
    the loopback agent must not have handled the message.
    """

    class DropOnSendHandler(DefaultInterventionHandler):
        async def on_send(
            self, message: MessageType, *, sender: AgentId | None, recipient: AgentId
        ) -> MessageType | type[DropMessage]:
            # Unconditionally veto every outgoing message.
            return DropMessage

    agent_runtime = SingleThreadedAgentRuntime(intervention_handlers=[DropOnSendHandler()])
    recipient_id = AgentId("name", key="default")
    await LoopbackAgent.register(agent_runtime, "name", LoopbackAgent)

    agent_runtime.start()
    with pytest.raises(MessageDroppedException):
        await agent_runtime.send_message(MessageType(), recipient=recipient_id)
    await agent_runtime.stop()

    agent = await agent_runtime.try_get_underlying_agent_instance(recipient_id, type=LoopbackAgent)
    assert agent.num_calls == 0
@pytest.mark.asyncio
async def test_intervention_drop_response() -> None:
    """Returning ``DropMessage`` from ``on_response`` discards the reply.

    The recipient still handles the request, because only the response is
    intercepted. The caller's ``send_message`` then raises
    ``MessageDroppedException`` instead of returning that response.
    """

    class DropResponseInterventionHandler(DefaultInterventionHandler):
        async def on_response(
            self, message: MessageType, *, sender: AgentId, recipient: AgentId | None
        ) -> MessageType | type[DropMessage]:
            # Veto every response; on_send is inherited unchanged.
            return DropMessage

    handler = DropResponseInterventionHandler()
    runtime = SingleThreadedAgentRuntime(intervention_handlers=[handler])
    await LoopbackAgent.register(runtime, "name", LoopbackAgent)
    loopback = AgentId("name", key="default")
    runtime.start()
    with pytest.raises(MessageDroppedException):
        _response = await runtime.send_message(MessageType(), recipient=loopback)
    await runtime.stop()

    # The send path is not intercepted, so the agent must have processed the
    # request exactly once before its response was dropped. Without this check
    # the test would also pass if the message were (wrongly) dropped on send.
    loopback_agent = await runtime.try_get_underlying_agent_instance(loopback, type=LoopbackAgent)
    assert loopback_agent.num_calls == 1
@pytest.mark.asyncio
async def test_intervention_raise_exception_on_send() -> None:
    """An exception raised by ``on_send`` propagates to the ``send_message`` caller.

    The handler fails before delivery, so the loopback agent is expected to
    have zero recorded calls.
    """

    class InterventionException(Exception):
        pass

    class RaiseOnSendHandler(DefaultInterventionHandler):  # type: ignore
        async def on_send(
            self, message: MessageType, *, sender: AgentId | None, recipient: AgentId
        ) -> MessageType | type[DropMessage]:  # type: ignore
            raise InterventionException

    agent_runtime = SingleThreadedAgentRuntime(intervention_handlers=[RaiseOnSendHandler()])
    recipient_id = AgentId("name", key="default")
    await LoopbackAgent.register(agent_runtime, "name", LoopbackAgent)

    agent_runtime.start()
    with pytest.raises(InterventionException):
        await agent_runtime.send_message(MessageType(), recipient=recipient_id)
    await agent_runtime.stop()

    agent = await agent_runtime.try_get_underlying_agent_instance(recipient_id, type=LoopbackAgent)
    assert agent.num_calls == 0
@pytest.mark.asyncio
async def test_intervention_raise_exception_on_respond() -> None:
    """An exception raised by ``on_response`` propagates to the ``send_message`` caller.

    Only the response path is intercepted, so the loopback agent is expected
    to have handled the request exactly once.
    """

    class InterventionException(Exception):
        pass

    class RaiseOnResponseHandler(DefaultInterventionHandler):  # type: ignore
        async def on_response(
            self, message: MessageType, *, sender: AgentId, recipient: AgentId | None
        ) -> MessageType | type[DropMessage]:  # type: ignore
            raise InterventionException

    agent_runtime = SingleThreadedAgentRuntime(intervention_handlers=[RaiseOnResponseHandler()])
    recipient_id = AgentId("name", key="default")
    await LoopbackAgent.register(agent_runtime, "name", LoopbackAgent)

    agent_runtime.start()
    with pytest.raises(InterventionException):
        await agent_runtime.send_message(MessageType(), recipient=recipient_id)
    await agent_runtime.stop()

    agent = await agent_runtime.try_get_underlying_agent_instance(recipient_id, type=LoopbackAgent)
    assert agent.num_calls == 1