@@ -364,10 +364,12 @@ def _update_psgd_cache(cached, Q_cache, q):
364
364
return Q_cache
365
365
366
366
367
- def _cached_psgd_precond_grad (group , cache_expr , exprs , update , Q_mat , Q_cache ):
367
+ def _cached_psgd_precond_grad (group , cache_expr , exprs , update , Q_mat , Q_cache , grad ):
368
368
if group .get ('is_cached' , False ):
369
- return utils .precond_grad_cached_ (cache_expr , update , * Q_cache )
370
- return utils .psgd_precond_grad (exprs [- 1 ], update , * Q_mat )
369
+ out = utils .precond_grad_cached_ (cache_expr , update , * Q_cache , caution = group ['caution' ], grad = grad )
370
+ out = utils .psgd_precond_grad (exprs [- 1 ], update , * Q_mat , caution = group ['caution' ], grad = grad )
371
+ group ['caution' ] = False # we already cautioned here - shouldn't do it again
372
+ return out
371
373
372
374
373
375
def _fused_cached_psgd_precond_grad (group , grad , param , cache_expr , exprs , update , Q_mat , Q_cache ):
@@ -387,15 +389,15 @@ def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str
387
389
Q_mat = utils .line_to_triu (Q ) if group ['store_triu_as_line' ] else Q
388
390
Q_mat = _update_psgd_precond (cached , Q_cache , group , param ,
389
391
update if group ['momentum_into_precond_update' ] else grad , Q_mat , Q , exprs , prob )
390
- return _cached_psgd_precond_grad (group , cache_expr , exprs , update , Q_mat , Q_cache )
392
+ return _cached_psgd_precond_grad (group , cache_expr , exprs , update , Q_mat , Q_cache , grad )
391
393
392
394
393
395
@general_guard ("Q" , "exprs" , ("Q_cache" , None ), ("cache_expr" , None ), init_fn = _init_psgd , skip_first = False )
394
396
@no_state_no_foreach
395
397
def scale_by_delayed_psgd (group , update , grad , param , Q , exprs , Q_cache , cache_expr : str , cached : bool = False ,
396
398
prob : Optional [callable ] = None ):
397
399
Q_mat = utils .line_to_triu (Q ) if group ['store_triu_as_line' ] else Q
398
- precond = _cached_psgd_precond_grad (group , cache_expr , exprs , update , Q_mat , Q_cache )
400
+ precond = _cached_psgd_precond_grad (group , cache_expr , exprs , update , Q_mat , Q_cache , grad )
399
401
_ = _update_psgd_precond (cached , Q_cache , group , param , update if group ['momentum_into_precond_update' ] else grad ,
400
402
Q_mat , Q , exprs , prob )
401
403
return precond
@@ -467,6 +469,8 @@ def _step(self, group):
467
469
f'only supported with foreach=True (currently foreach={ group ["foreach" ]} ).' )
468
470
group ['base_lr' ] = group ['lr' ]
469
471
472
+ caution = group ['caution' ]
473
+
470
474
vals = list (self .split_p_and_g_in_group (group , should_promote = self .promote , beta1 = utils .get_beta1 (group )))
471
475
472
476
if not vals :
@@ -492,6 +496,7 @@ def _step(self, group):
492
496
else :
493
497
chain (self .state_ , group , g , p , * self .fns )
494
498
499
+ group ['caution' ] = caution
495
500
group ['lr' ] = group ['prev_lr' ]
496
501
group ['step' ] = None
497
502
0 commit comments