@@ -15,11 +15,14 @@ class AgenticMemoryController:
15
15
Manages memory-based learning, testing, and the flow of information to and from the memory bank.
16
16
17
17
Args:
18
- settings: Settings for the memory controller.
19
18
reset: True to clear the memory bank before starting.
20
19
client: The client to call the model.
21
20
task_assignment_callback: The callback to assign a task to the agent.
22
- logger: The logger to log the model calls.
21
+ - config: An optional dict that can be used to override the following values:
22
+ - max_train_trials: The maximum number of trials to attempt when training on a task.
23
+ - max_test_trials: The maximum number of trials to attempt when testing on a task.
24
+ - AgenticMemoryBank: A config dict passed to AgenticMemoryBank.
25
+ logger: An optional logger. If None, a default logger will be created.
23
26
24
27
Methods:
25
28
reset_memory: Resets the memory bank.
@@ -34,19 +37,38 @@ class AgenticMemoryController:
34
37
35
38
def __init__ (
36
39
self ,
37
- settings : Dict [str , Any ],
38
40
reset : bool ,
39
41
client : ChatCompletionClient ,
40
42
task_assignment_callback : Callable [[str ], Awaitable [Tuple [str , str ]]],
41
- logger : PageLogger ,
43
+ config : Dict [str , Any ] | None = None ,
44
+ logger : PageLogger | None = None ,
42
45
) -> None :
46
+ if logger is None :
47
+ logger = PageLogger ({"level" : "INFO" })
43
48
self .logger = logger
44
49
self .logger .enter_function ()
45
- self .settings = settings
50
+
51
+ # Assign default values that can be overridden by config.
52
+ self .max_train_trials = 10
53
+ self .max_test_trials = 3
54
+ agentic_memory_bank_config = None
55
+
56
+ if config is not None :
57
+ # Apply any overrides from the config.
58
+ for key in config :
59
+ if key == "max_train_trials" :
60
+ self .max_train_trials = config [key ]
61
+ elif key == "max_test_trials" :
62
+ self .max_test_trials = config [key ]
63
+ elif key == "AgenticMemoryBank" :
64
+ agentic_memory_bank_config = config [key ]
65
+ else :
66
+ self .logger .error ('Unexpected item in config: ["{}"] = {}' .format (key , config [key ]))
67
+
46
68
self .client = client
47
69
self .task_assignment_callback = task_assignment_callback
48
70
self .prompter = Prompter (client , logger )
49
- self .memory_bank = AgenticMemoryBank (self . settings [ "AgenticMemoryBank" ], reset = reset , logger = logger )
71
+ self .memory_bank = AgenticMemoryBank (reset = reset , config = agentic_memory_bank_config , logger = logger )
50
72
self .grader = Grader (client , logger )
51
73
self .logger .leave_function ()
52
74
@@ -62,9 +84,7 @@ async def train_on_task(self, task: str, expected_answer: str) -> None:
62
84
"""
63
85
self .logger .enter_function ()
64
86
self .logger .info ("Iterate on the task, possibly discovering a useful new insight.\n " )
65
- _ , insight = await self ._iterate_on_task (
66
- task , expected_answer , self .settings ["max_train_trials" ], self .settings ["max_test_trials" ]
67
- )
87
+ _ , insight = await self ._iterate_on_task (task , expected_answer )
68
88
if insight is None :
69
89
self .logger .info ("No useful insight was discovered.\n " )
70
90
else :
@@ -219,7 +239,7 @@ def _format_memory_section(self, memories: List[str]) -> str:
219
239
return memory_section
220
240
221
241
async def _test_for_failure (
222
- self , task : str , task_plus_insights : str , expected_answer : str , num_trials : int
242
+ self , task : str , task_plus_insights : str , expected_answer : str
223
243
) -> Tuple [bool , str , str ]:
224
244
"""
225
245
Attempts to solve the given task multiple times to find a failure case to learn from.
@@ -231,7 +251,7 @@ async def _test_for_failure(
231
251
failure_found = False
232
252
response , work_history = "" , ""
233
253
234
- for trial in range (num_trials ):
254
+ for trial in range (self . max_test_trials ):
235
255
self .logger .info ("\n ----- TRIAL {} -----\n " .format (trial + 1 ))
236
256
237
257
# Attempt to solve the task.
@@ -252,9 +272,7 @@ async def _test_for_failure(
252
272
self .logger .leave_function ()
253
273
return failure_found , response , work_history
254
274
255
- async def _iterate_on_task (
256
- self , task : str , expected_answer : str , max_train_trials : int , max_test_trials : int
257
- ) -> Tuple [str , None | str ]:
275
+ async def _iterate_on_task (self , task : str , expected_answer : str ) -> Tuple [str , None | str ]:
258
276
"""
259
277
Repeatedly assigns a task to the agent, and tries to learn from failures by creating useful insights as memories.
260
278
"""
@@ -270,7 +288,7 @@ async def _iterate_on_task(
270
288
successful_insight = None
271
289
272
290
# Loop until success (or timeout) while learning from failures.
273
- for trial in range (1 , max_train_trials + 1 ):
291
+ for trial in range (1 , self . max_train_trials + 1 ):
274
292
self .logger .info ("\n ----- TRAIN TRIAL {} -----\n " .format (trial ))
275
293
task_plus_insights = task
276
294
@@ -284,7 +302,7 @@ async def _iterate_on_task(
284
302
285
303
# Can we find a failure case to learn from?
286
304
failure_found , response , work_history = await self ._test_for_failure (
287
- task , task_plus_insights , expected_answer , max_test_trials
305
+ task , task_plus_insights , expected_answer
288
306
)
289
307
if not failure_found :
290
308
# No. Time to exit the loop.
@@ -299,7 +317,7 @@ async def _iterate_on_task(
299
317
break
300
318
301
319
# Will we try again?
302
- if trial == max_train_trials :
320
+ if trial == self . max_train_trials :
303
321
# No. We're out of training trials.
304
322
self .logger .info ("\n No more trials will be attempted.\n " )
305
323
break
0 commit comments