
Commit 01a0326

Optimizer torch optimizer performance (#482)
* add torch optimizers
* addressing comments

Co-authored-by: Haifeng Jin <[email protected]>
1 parent c9bce12 commit 01a0326

File tree: 6 files changed, +83 -10 lines

benchmarks/torch_ctl_benchmark/README.md (+5 -4)

@@ -1,9 +1,10 @@
 # Benchmark the performance of torch custom training loop

-This directory contains benchmarks to compare the performance between Keras and
-Torch while using Torch custom training loop. The benchmark purpose is to
-understand the performance diff resulting from the modeling API choice (Keras
-or Torch).
+This directory contains benchmarks to compare the performance of a Keras model
+and an equivalent Torch model while using the same Torch custom training loop.
+
+The benchmark purpose is to understand the performance diff resulting from the
+modeling API choice (Keras or Torch).

 To run the benchmark, use the command below and change to your target:

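For orientation, here is a minimal sketch (not the benchmark's actual code; train, model, and dataloader are hypothetical placeholders) of the kind of shared torch custom training loop the README describes: the same loop is timed once with a Keras model and once with an equivalent torch.nn.Module, assuming the Keras model exposes its weights as torch parameters on the torch backend.

import torch


def train(model, dataloader, num_epochs=1, lr=1e-3):
    # The loop is identical for both candidates; only how `model` was built differs.
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    for _ in range(num_epochs):
        for x, y in dataloader:
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()
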
keras_core/backend/torch/optimizers/__init__.py (+1)

@@ -0,0 +1 @@
+from keras_core.backend.torch.optimizers.torch_optimizer import TorchOptimizer
keras_core/backend/torch/optimizers/torch_optimizer.py (+24)

@@ -0,0 +1,24 @@
+import torch
+
+from keras_core.optimizers.base_optimizer import BaseOptimizer
+
+
+class TorchOptimizer(BaseOptimizer):
+    def __new__(cls, *args, **kwargs):
+        # Import locally to avoid circular imports.
+        from keras_core import optimizers
+        from keras_core.backend.torch.optimizers import torch_sgd
+
+        OPTIMIZERS = {optimizers.SGD: torch_sgd.SGD}
+        if cls in OPTIMIZERS:
+            return OPTIMIZERS[cls](*args, **kwargs)
+        return super().__new__(cls)
+
+    def _apply_weight_decay(self, variables):
+        if self.weight_decay is None:
+            return
+
+        torch._foreach_mul_(
+            [v.value for v in variables if self._use_weight_decay(v)],
+            1 - self.weight_decay * self._get_current_learning_rate(),
+        )
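
For readers unfamiliar with the private torch._foreach API, a small illustration (mine, not part of the commit) of what the fused weight-decay call above does: torch._foreach_mul_ scales every tensor in the list in place by the same scalar, replacing a Python loop over the variables with a single fused dispatch.

import torch

# Stand-ins for the variables' underlying torch tensors.
params = [torch.ones(3), torch.full((2,), 2.0)]

weight_decay, lr = 0.01, 0.1
torch._foreach_mul_(params, 1 - weight_decay * lr)  # in place, all tensors at once

# Equivalent, but one op dispatch per tensor:
# for p in params:
#     p.mul_(1 - weight_decay * lr)
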
keras_core/backend/torch/optimizers/torch_sgd.py (+43)

@@ -0,0 +1,43 @@
+import torch
+
+from keras_core import optimizers
+
+
+class SGD(optimizers.SGD):
+    def _internal_apply_gradients(self, grads_and_vars):
+        grads, trainable_variables = zip(*grads_and_vars)
+
+        self._parallel_update_step(
+            grads,
+            [v.value for v in trainable_variables],
+            self._get_current_learning_rate(),
+        )
+        self.iterations.assign(self.iterations + 1)
+
+    def _parallel_update_step(
+        self,
+        grads,
+        variables,
+        learning_rate,
+    ):
+        if self.momentum != 0:
+            bufs = [
+                self.momentums[self._get_variable_index(variable.value)]
+                for variable in variables
+            ]
+
+            for i in range(len(bufs)):
+                if bufs[i] is None:
+                    bufs[i] = torch.clone(grads[i]).detach()
+
+            torch._foreach_mul_(bufs, self.momentum)
+            torch._foreach_add_(bufs, grads, alpha=-learning_rate)
+
+            if self.nesterov:
+                torch._foreach_add_(variables, grads, alpha=-learning_rate)
+                torch._foreach_add_(variables, bufs, alpha=self.momentum)
+            else:
+                torch._foreach_add_(variables, bufs)
+
+        else:
+            torch._foreach_add_(variables, grads, alpha=-learning_rate)

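A short sketch (mine, not the commit's) checking that the fused momentum branch above matches the per-variable formulation it replaces, buf = momentum * buf - lr * grad followed by var += buf (the non-Nesterov path):

import torch

momentum, lr = 0.9, 0.1
variables = [torch.tensor([1.0, 2.0]), torch.tensor([3.0])]
grads = [torch.tensor([0.5, 0.5]), torch.tensor([1.0])]
bufs = [torch.full_like(v, 0.2) for v in variables]

# Reference: plain per-variable update.
expected = [v + (momentum * b - lr * g) for v, g, b in zip(variables, grads, bufs)]

# Fused version, mirroring SGD._parallel_update_step above.
torch._foreach_mul_(bufs, momentum)
torch._foreach_add_(bufs, grads, alpha=-lr)
torch._foreach_add_(variables, bufs)

assert all(torch.allclose(v, e) for v, e in zip(variables, expected))
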
keras_core/optimizers/optimizer.py (+6 -2)

@@ -3,9 +3,13 @@
 from keras_core.optimizers import base_optimizer

 if backend.backend() == "tensorflow":
-    from keras_core.backend.tensorflow import optimizer as tf_optimizer
+    from keras_core.backend.tensorflow.optimizer import TFOptimizer

-    BackendOptimizer = tf_optimizer.TFOptimizer
+    BackendOptimizer = TFOptimizer
+elif backend.backend() == "torch":
+    from keras_core.backend.torch.optimizers import TorchOptimizer
+
+    BackendOptimizer = TorchOptimizer
 else:
     BackendOptimizer = base_optimizer.BaseOptimizer

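Together with the __new__ hook in TorchOptimizer, this dispatch should mean that constructing the public SGD class on the torch backend hands back the fused torch_sgd.SGD implementation. A hedged sanity check (assumes the torch backend is selected via the KERAS_BACKEND environment variable; not part of the commit):

import os

os.environ["KERAS_BACKEND"] = "torch"  # must be set before importing keras_core

from keras_core import optimizers
from keras_core.backend.torch.optimizers import torch_sgd

opt = optimizers.SGD(learning_rate=0.5)
# TorchOptimizer.__new__ maps optimizers.SGD to torch_sgd.SGD, so we expect:
print(isinstance(opt, torch_sgd.SGD))  # expected: True
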
keras_core/optimizers/sgd_test.py (+4 -4)

@@ -21,7 +21,7 @@ def test_config(self):
     def test_single_step(self):
         optimizer = SGD(learning_rate=0.5)
         self.assertEqual(len(optimizer.variables), 2)
-        grads = np.array([1.0, 6.0, 7.0, 2.0])
+        grads = ops.array([1.0, 6.0, 7.0, 2.0])
         vars = backend.Variable([1.0, 2.0, 3.0, 4.0])
         optimizer.build([vars])
         optimizer.apply_gradients(zip([grads], [vars]))

@@ -32,7 +32,7 @@ def test_single_step(self):

     def test_weight_decay(self):
         grads, var1, var2, var3 = (
-            np.zeros(()),
+            ops.zeros(()),
             backend.Variable(2.0),
             backend.Variable(2.0, name="exclude"),
             backend.Variable(2.0),

@@ -56,8 +56,8 @@ def test_correctness_with_golden(self):
         optimizer = SGD(nesterov=True)

         x = backend.Variable(np.ones([10]))
-        grads = np.arange(0.1, 1.1, 0.1)
-        first_grads = np.full((10,), 0.01)
+        grads = ops.arange(0.1, 1.1, 0.1)
+        first_grads = ops.full((10,), 0.01)

         # fmt: off
         golden = np.array(

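The gradients in these tests move from np to ops, presumably because the fused torch._foreach_* kernels in the new optimizer act on backend-native tensors rather than NumPy arrays. A minimal sketch of that reading (assumes the torch backend; not from the commit):

import os

os.environ["KERAS_BACKEND"] = "torch"  # select the backend before importing keras_core

import torch

from keras_core import ops

grads = ops.array([1.0, 6.0, 7.0, 2.0])
print(isinstance(grads, torch.Tensor))  # expected: True on the torch backend
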
0 commit comments
