Skip to content

Commit a7791f0

Browse files
committed
fix(chainable): correct nesterov momentum
1 parent dfc3299 commit a7791f0

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

heavyball/chainable.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -277,13 +277,13 @@ def orthogonalize_update(group, update, grad, param):
277277
@zero_guard("momentum")
278278
@no_state
279279
def nesterov_momentum(group, updates, grads, params, momentum):
280-
utils.nesterov_momentum(momentum, updates, utils.get_beta1(group))
280+
return utils.nesterov_momentum(momentum, updates, utils.get_beta1(group))
281281

282282

283283
@zero_guard("momentum")
284284
@no_state
285285
def heavyball_momentum(group, updates, grads, params, momentum):
286-
utils.heavyball_momentum(momentum, updates, utils.get_beta1(group))
286+
return utils.heavyball_momentum(momentum, updates, utils.get_beta1(group))
287287

288288

289289
@zero_guard("exp_avg", "exp_avg_sq")

0 commit comments

Comments
 (0)