
Commit 512ffd0

feat(chainable): caution momentum, not update
1 parent baea766 commit 512ffd0


2 files changed (+21, -11 lines)


heavyball/chainable.py (+10, -5)
@@ -364,10 +364,12 @@ def _update_psgd_cache(cached, Q_cache, q):
     return Q_cache
 
 
-def _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache):
+def _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad):
     if group.get('is_cached', False):
-        return utils.precond_grad_cached_(cache_expr, update, *Q_cache)
-    return utils.psgd_precond_grad(exprs[-1], update, *Q_mat)
+        out = utils.precond_grad_cached_(cache_expr, update, *Q_cache, caution=group['caution'], grad=grad)
+    out = utils.psgd_precond_grad(exprs[-1], update, *Q_mat, caution=group['caution'], grad=grad)
+    group['caution'] = False  # we already cautioned here - shouldn't do it again
+    return out
 
 
 def _fused_cached_psgd_precond_grad(group, grad, param, cache_expr, exprs, update, Q_mat, Q_cache):
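The commit title, "caution momentum, not update", describes what this hunk changes: the caution mask is now applied to the buffer that feeds the preconditioner (the momentum/update passed as `update`), and group['caution'] is cleared so the later parameter write does not apply the mask a second time. Below is a minimal, self-contained sketch of that control flow, assuming the mask simply zeroes entries whose sign disagrees with the gradient; the names are illustrative, not heavyball's API, and a diagonal preconditioner stands in for the PSGD einsum.

    import torch

    def caution_mask(grad: torch.Tensor, buf: torch.Tensor) -> torch.Tensor:
        # zero the entries of buf whose sign disagrees with grad (the usual "cautious" rule)
        return buf * (grad * buf > 0)

    def precondition_sketch(group: dict, momentum: torch.Tensor, grad: torch.Tensor,
                            precond_diag: torch.Tensor) -> torch.Tensor:
        if group['caution']:
            momentum = caution_mask(grad, momentum)  # caution the momentum, not the final update
            group['caution'] = False                 # consumed here; the parameter update must skip it
        return precond_diag * momentum               # stand-in for the cached PSGD einsum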
@@ -387,15 +389,15 @@ def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
     Q_mat = _update_psgd_precond(cached, Q_cache, group, param,
                                  update if group['momentum_into_precond_update'] else grad, Q_mat, Q, exprs, prob)
-    return _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache)
+    return _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad)
 
 
 @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
 @no_state_no_foreach
 def scale_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                           prob: Optional[callable] = None):
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
-    precond = _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache)
+    precond = _cached_psgd_precond_grad(group, cache_expr, exprs, update, Q_mat, Q_cache, grad)
     _ = _update_psgd_precond(cached, Q_cache, group, param, update if group['momentum_into_precond_update'] else grad,
                              Q_mat, Q, exprs, prob)
     return precond
@@ -467,6 +469,8 @@ def _step(self, group):
                              f'only supported with foreach=True (currently foreach={group["foreach"]}).')
         group['base_lr'] = group['lr']
 
+        caution = group['caution']
+
         vals = list(self.split_p_and_g_in_group(group, should_promote=self.promote, beta1=utils.get_beta1(group)))
 
         if not vals:
@@ -492,6 +496,7 @@ def _step(self, group):
         else:
             chain(self.state_, group, g, p, *self.fns)
 
+        group['caution'] = caution
         group['lr'] = group['prev_lr']
         group['step'] = None
 
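The two _step hunks bracket the per-parameter chain: group['caution'] is snapshotted before split_p_and_g_in_group and restored once chain has run, because the preconditioning step above clears the flag after applying the mask. A small illustration of why the restore matters, using hypothetical names rather than the optimizer's real loop:

    def step_sketch(group: dict, params_and_grads, run_chain) -> None:
        caution = group['caution']      # snapshot the user's setting
        for p, g in params_and_grads:
            run_chain(group, g, p)      # may set group['caution'] = False after cautioning once
        group['caution'] = caution      # restore it so later steps see the original value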
heavyball/utils.py (+11, -6)
@@ -1300,7 +1300,10 @@ def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random
 
 
 @decorator_knowngood
-def precond_grad_cached_(expr: str, ea: Tensor, *cached_q: Tensor, cast: bool = True):
+def precond_grad_cached_(expr: str, ea: Tensor, *cached_q: Tensor, caution: bool = False, grad: Optional[Tensor] = None,
+                         cast: bool = True):
+    if caution:
+        ea = _compilable_cautioning(grad, ea)
     md = min_dtype(list(cached_q) + [ea])
     args = [q.to(md) for q in cached_q]
     args = args + [ea.to(md)]
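Both new parameters are keyword-only with defaults, so existing positional callers keep their behaviour; when caution=True the raw gradient has to be supplied as well, since the mask is computed against it before the dtype promotion and the cached einsum. A hypothetical call under those assumptions (update, raw_grad, and cached_q are placeholder names, not fixed API):

    out = precond_grad_cached_(cache_expr, update, *cached_q, caution=True, grad=raw_grad)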
@@ -1312,8 +1315,8 @@ def precond_grad_cached_(expr: str, ea: Tensor, *cached_q: Tensor, cast: bool =
 
 @decorator_knowngood
 def _compilable_fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
-    precond = precond_grad_cached_(expr, ea, *cached_q, cast=False)
-    update_param_(param, precond, lr, decay, caution=caution, grad=grad)
+    precond = precond_grad_cached_(expr, ea, *cached_q, caution=caution, grad=grad, cast=False)
+    update_param_(param, precond, lr, decay, caution=False)
 
 
 def fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, caution, *cached_q: Tensor):
@@ -1322,7 +1325,9 @@ def fused_precond_grad_cached_(expr: str, ea: Tensor, param, lr, grad, decay, ca
 
 
 @decorator_knowngood
-def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor):
+def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor, caution: bool = False, grad: Optional[Tensor] = None):
+    if caution:
+        ea = _compilable_cautioning(grad, ea)
     md = min_dtype(list(preconds) + [ea])
     args = [q.to(md) for q in preconds]
     args = args + args + [ea.to(md)]
@@ -1332,8 +1337,8 @@ def psgd_precond_grad(expr: str, ea: Tensor, *preconds: Tensor):
 
 @decorator_knowngood
 def _compilable_fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
-    precond = psgd_precond_grad(expr, ea, *preconds)
-    update_param_(param, precond, lr, decay, caution=caution, grad=grad)
+    precond = psgd_precond_grad(expr, ea, *preconds, caution=caution, grad=grad)
+    update_param_(param, precond, lr, decay, caution=False, grad=grad)
 
 
 def fused_psgd_precond_grad(expr: str, ea: Tensor, param, lr, grad, decay, caution, *preconds: Tensor):
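After this change both fused paths follow the same invariant: the mask is applied at most once, inside the preconditioning call, and update_param_ is always invoked with caution=False. A compact restatement of that invariant, assuming the functions are importable from heavyball.utils as defined in this file (a sketch mirroring _compilable_fused_psgd_precond_grad, not heavyball's public API):

    from heavyball.utils import psgd_precond_grad, update_param_

    def fused_step_sketch(expr, ea, param, lr, grad, decay, caution, *preconds):
        precond = psgd_precond_grad(expr, ea, *preconds, caution=caution, grad=grad)  # mask applied here, at most once
        update_param_(param, precond, lr, decay, caution=False, grad=grad)            # never re-applies the mask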
