1
1
import json
2
2
from typing import Any , List
3
3
4
- from autogen_core .base import MessageContext
4
+ from autogen_core .base import AgentId , CancellationToken , MessageContext
5
5
from autogen_core .components import DefaultTopicId , Image , event , rpc
6
6
from autogen_core .components .models import (
7
7
AssistantMessage ,
@@ -120,7 +120,7 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No
120
120
planning_conversation .append (
121
121
UserMessage (content = self ._get_task_ledger_facts_prompt (self ._task ), source = self ._name )
122
122
)
123
- response = await self ._model_client .create (planning_conversation )
123
+ response = await self ._model_client .create (planning_conversation , cancellation_token = ctx . cancellation_token )
124
124
125
125
assert isinstance (response .content , str )
126
126
self ._facts = response .content
@@ -131,19 +131,19 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No
131
131
planning_conversation .append (
132
132
UserMessage (content = self ._get_task_ledger_plan_prompt (self ._team_description ), source = self ._name )
133
133
)
134
- response = await self ._model_client .create (planning_conversation )
134
+ response = await self ._model_client .create (planning_conversation , cancellation_token = ctx . cancellation_token )
135
135
136
136
assert isinstance (response .content , str )
137
137
self ._plan = response .content
138
138
139
139
# Kick things off
140
140
self ._n_stalls = 0
141
- await self ._reenter_inner_loop ()
141
+ await self ._reenter_inner_loop (ctx . cancellation_token )
142
142
143
143
@event
144
144
async def handle_agent_response (self , message : GroupChatAgentResponse , ctx : MessageContext ) -> None :
145
145
self ._message_thread .append (message .agent_response .chat_message )
146
- await self ._orchestrate_step ()
146
+ await self ._orchestrate_step (ctx . cancellation_token )
147
147
148
148
@rpc
149
149
async def handle_reset (self , message : GroupChatReset , ctx : MessageContext ) -> None :
@@ -162,12 +162,16 @@ async def reset(self) -> None:
162
162
async def on_unhandled_message (self , message : Any , ctx : MessageContext ) -> None :
163
163
raise ValueError (f"Unhandled message in group chat manager: { type (message )} " )
164
164
165
- async def _reenter_inner_loop (self ) -> None :
165
+ async def _reenter_inner_loop (self , cancellation_token : CancellationToken ) -> None :
166
166
# Reset the agents
167
- await self .publish_message (
168
- GroupChatReset (),
169
- topic_id = DefaultTopicId (type = self ._group_topic_type ),
170
- )
167
+ for participant_topic_type in self ._participant_topic_types :
168
+ await self ._runtime .send_message (
169
+ GroupChatReset (),
170
+ recipient = AgentId (type = participant_topic_type , key = self .id .key ),
171
+ cancellation_token = cancellation_token ,
172
+ )
173
+ # Reset the group chat manager
174
+ await self .reset ()
171
175
self ._message_thread .clear ()
172
176
173
177
# Prepare the ledger
@@ -192,12 +196,12 @@ async def _reenter_inner_loop(self) -> None:
192
196
)
193
197
194
198
# Restart the inner loop
195
- await self ._orchestrate_step ()
199
+ await self ._orchestrate_step (cancellation_token = cancellation_token )
196
200
197
- async def _orchestrate_step (self ) -> None :
201
+ async def _orchestrate_step (self , cancellation_token : CancellationToken ) -> None :
198
202
# Check if we reached the maximum number of rounds
199
203
if self ._max_turns is not None and self ._n_rounds > self ._max_turns :
200
- await self ._prepare_final_answer ("Max rounds reached." )
204
+ await self ._prepare_final_answer ("Max rounds reached." , cancellation_token )
201
205
return
202
206
self ._n_rounds += 1
203
207
@@ -216,7 +220,7 @@ async def _orchestrate_step(self) -> None:
216
220
217
221
# Check for task completion
218
222
if progress_ledger ["is_request_satisfied" ]["answer" ]:
219
- await self ._prepare_final_answer (progress_ledger ["is_request_satisfied" ]["reason" ])
223
+ await self ._prepare_final_answer (progress_ledger ["is_request_satisfied" ]["reason" ], cancellation_token )
220
224
return
221
225
222
226
# Check for stalling
@@ -229,8 +233,8 @@ async def _orchestrate_step(self) -> None:
229
233
230
234
# Too much stalling
231
235
if self ._n_stalls >= self ._max_stalls :
232
- await self ._update_task_ledger ()
233
- await self ._reenter_inner_loop ()
236
+ await self ._update_task_ledger (cancellation_token )
237
+ await self ._reenter_inner_loop (cancellation_token )
234
238
return
235
239
236
240
# Broadcst the next step
@@ -247,20 +251,23 @@ async def _orchestrate_step(self) -> None:
247
251
await self .publish_message ( # Broadcast
248
252
GroupChatAgentResponse (agent_response = Response (chat_message = message )),
249
253
topic_id = DefaultTopicId (type = self ._group_topic_type ),
254
+ cancellation_token = cancellation_token ,
250
255
)
251
256
252
257
# Request that the step be completed
253
258
next_speaker = progress_ledger ["next_speaker" ]["answer" ]
254
- await self .publish_message (GroupChatRequestPublish (), topic_id = DefaultTopicId (type = next_speaker ))
259
+ await self .publish_message (
260
+ GroupChatRequestPublish (), topic_id = DefaultTopicId (type = next_speaker ), cancellation_token = cancellation_token
261
+ )
255
262
256
- async def _update_task_ledger (self ) -> None :
263
+ async def _update_task_ledger (self , cancellation_token : CancellationToken ) -> None :
257
264
context = self ._thread_to_context ()
258
265
259
266
# Update the facts
260
267
update_facts_prompt = self ._get_task_ledger_facts_update_prompt (self ._task , self ._facts )
261
268
context .append (UserMessage (content = update_facts_prompt , source = self ._name ))
262
269
263
- response = await self ._model_client .create (context )
270
+ response = await self ._model_client .create (context , cancellation_token = cancellation_token )
264
271
265
272
assert isinstance (response .content , str )
266
273
self ._facts = response .content
@@ -270,19 +277,19 @@ async def _update_task_ledger(self) -> None:
270
277
update_plan_prompt = self ._get_task_ledger_plan_update_prompt (self ._team_description )
271
278
context .append (UserMessage (content = update_plan_prompt , source = self ._name ))
272
279
273
- response = await self ._model_client .create (context )
280
+ response = await self ._model_client .create (context , cancellation_token = cancellation_token )
274
281
275
282
assert isinstance (response .content , str )
276
283
self ._plan = response .content
277
284
278
- async def _prepare_final_answer (self , reason : str ) -> None :
285
+ async def _prepare_final_answer (self , reason : str , cancellation_token : CancellationToken ) -> None :
279
286
context = self ._thread_to_context ()
280
287
281
288
# Get the final answer
282
289
final_answer_prompt = self ._get_final_answer_prompt (self ._task )
283
290
context .append (UserMessage (content = final_answer_prompt , source = self ._name ))
284
291
285
- response = await self ._model_client .create (context )
292
+ response = await self ._model_client .create (context , cancellation_token = cancellation_token )
286
293
assert isinstance (response .content , str )
287
294
message = TextMessage (content = response .content , source = self ._name )
288
295
@@ -298,6 +305,7 @@ async def _prepare_final_answer(self, reason: str) -> None:
298
305
await self .publish_message (
299
306
GroupChatAgentResponse (agent_response = Response (chat_message = message )),
300
307
topic_id = DefaultTopicId (type = self ._group_topic_type ),
308
+ cancellation_token = cancellation_token ,
301
309
)
302
310
303
311
# Signal termination
0 commit comments