Skip to content

Commit f1652d1

Browse files
authoredJul 30, 2023
Add hyperbolic ops (#634)
* add hyperbolic numpy functions * update backend modules + tests + black * docstrings * fix * remove ops.nn.tanh
1 parent c153bac commit f1652d1

File tree

9 files changed

+353
-26
lines changed

9 files changed

+353
-26
lines changed
 

‎guides/distributed_training_with_jax.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ def get_model():
5858
# Make a simple convnet with batch normalization and dropout.
5959
inputs = keras.Input(shape=(28, 28, 1))
6060
x = keras.layers.Rescaling(1.0 / 255.0)(inputs)
61-
x = keras.layers.Conv2D(filters=12, kernel_size=3, padding="same", use_bias=False)(
62-
x
63-
)
61+
x = keras.layers.Conv2D(
62+
filters=12, kernel_size=3, padding="same", use_bias=False
63+
)(x)
6464
x = keras.layers.BatchNormalization(scale=False, center=True)(x)
6565
x = keras.layers.ReLU()(x)
6666
x = keras.layers.Conv2D(
@@ -187,7 +187,11 @@ def compute_loss(trainable_variables, non_trainable_variables, x, y):
187187
# Training step, Keras provides a pure functional optimizer.stateless_apply
188188
@jax.jit
189189
def train_step(train_state, x, y):
190-
trainable_variables, non_trainable_variables, optimizer_variables = train_state
190+
(
191+
trainable_variables,
192+
non_trainable_variables,
193+
optimizer_variables,
194+
) = train_state
191195
(loss_value, non_trainable_variables), grads = compute_gradients(
192196
trainable_variables, non_trainable_variables, x, y
193197
)
@@ -211,7 +215,9 @@ def get_replicated_train_state(devices):
211215
var_replication = NamedSharding(var_mesh, P())
212216

213217
# Apply the distribution settings to the model variables
214-
trainable_variables = jax.device_put(model.trainable_variables, var_replication)
218+
trainable_variables = jax.device_put(
219+
model.trainable_variables, var_replication
220+
)
215221
non_trainable_variables = jax.device_put(
216222
model.non_trainable_variables, var_replication
217223
)
@@ -255,7 +261,9 @@ def get_replicated_train_state(devices):
255261
trainable_variables, non_trainable_variables, optimizer_variables = train_state
256262
for variable, value in zip(model.trainable_variables, trainable_variables):
257263
variable.assign(value)
258-
for variable, value in zip(model.non_trainable_variables, non_trainable_variables):
264+
for variable, value in zip(
265+
model.non_trainable_variables, non_trainable_variables
266+
):
259267
variable.assign(value)
260268

261269
"""

‎guides/distributed_training_with_torch.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ def get_model():
5353
# Make a simple convnet with batch normalization and dropout.
5454
inputs = keras.Input(shape=(28, 28, 1))
5555
x = keras.layers.Rescaling(1.0 / 255.0)(inputs)
56-
x = keras.layers.Conv2D(filters=12, kernel_size=3, padding="same", use_bias=False)(
57-
x
58-
)
56+
x = keras.layers.Conv2D(
57+
filters=12, kernel_size=3, padding="same", use_bias=False
58+
)(x)
5959
x = keras.layers.BatchNormalization(scale=False, center=True)(x)
6060
x = keras.layers.ReLU()(x)
6161
x = keras.layers.Conv2D(
@@ -231,7 +231,9 @@ def per_device_launch_fn(current_gpu_index, num_gpu):
231231
model = get_model()
232232

233233
# prepare the dataloader
234-
dataloader = prepare_dataloader(dataset, current_gpu_index, num_gpu, batch_size)
234+
dataloader = prepare_dataloader(
235+
dataset, current_gpu_index, num_gpu, batch_size
236+
)
235237

236238
# Instantiate the torch optimizer
237239
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

‎keras_core/backend/jax/numpy.py

+24
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,18 @@ def arccos(x):
107107
return jnp.arccos(x)
108108

109109

110+
def arccosh(x):
111+
return jnp.arccosh(x)
112+
113+
110114
def arcsin(x):
111115
return jnp.arcsin(x)
112116

113117

118+
def arcsinh(x):
119+
return jnp.arcsinh(x)
120+
121+
114122
def arctan(x):
115123
return jnp.arctan(x)
116124

@@ -119,6 +127,10 @@ def arctan2(x1, x2):
119127
return jnp.arctan2(x1, x2)
120128

121129

130+
def arctanh(x):
131+
return jnp.arctanh(x)
132+
133+
122134
def argmax(x, axis=None):
123135
return jnp.argmax(x, axis=axis)
124136

@@ -171,6 +183,10 @@ def cos(x):
171183
return jnp.cos(x)
172184

173185

186+
def cosh(x):
187+
return jnp.cosh(x)
188+
189+
174190
def count_nonzero(x, axis=None):
175191
return jnp.count_nonzero(x, axis=axis)
176192

@@ -441,6 +457,10 @@ def sin(x):
441457
return jnp.sin(x)
442458

443459

460+
def sinh(x):
461+
return jnp.sinh(x)
462+
463+
444464
def size(x):
445465
return jnp.size(x)
446466

@@ -479,6 +499,10 @@ def tan(x):
479499
return jnp.tan(x)
480500

481501

502+
def tanh(x):
503+
return jnp.tanh(x)
504+
505+
482506
def tensordot(x1, x2, axes=2):
483507
x1 = convert_to_tensor(x1)
484508
x2 = convert_to_tensor(x2)

‎keras_core/backend/numpy/numpy.py

+24
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,18 @@ def arccos(x):
8484
return np.arccos(x)
8585

8686

87+
def arccosh(x):
88+
return np.arccosh(x)
89+
90+
8791
def arcsin(x):
8892
return np.arcsin(x)
8993

9094

95+
def arcsinh(x):
96+
return np.arcsinh(x)
97+
98+
9199
def arctan(x):
92100
return np.arctan(x)
93101

@@ -96,6 +104,10 @@ def arctan2(x1, x2):
96104
return np.arctan2(x1, x2)
97105

98106

107+
def arctanh(x):
108+
return np.arctanh(x)
109+
110+
99111
def argmax(x, axis=None):
100112
axis = tuple(axis) if isinstance(axis, list) else axis
101113
return np.argmax(x, axis=axis)
@@ -157,6 +169,10 @@ def cos(x):
157169
return np.cos(x)
158170

159171

172+
def cosh(x):
173+
return np.cosh(x)
174+
175+
160176
def count_nonzero(x, axis=None):
161177
axis = tuple(axis) if isinstance(axis, list) else axis
162178
return np.count_nonzero(x, axis=axis)
@@ -438,6 +454,10 @@ def sin(x):
438454
return np.sin(x)
439455

440456

457+
def sinh(x):
458+
return np.sinh(x)
459+
460+
441461
def size(x):
442462
return np.size(x)
443463

@@ -480,6 +500,10 @@ def tan(x):
480500
return np.tan(x)
481501

482502

503+
def tanh(x):
504+
return np.tanh(x)
505+
506+
483507
def tensordot(x1, x2, axes=2):
484508
axes = tuple(axes) if isinstance(axes, list) else axes
485509
return np.tensordot(x1, x2, axes=axes)

‎keras_core/backend/tensorflow/numpy.py

+24
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,18 @@ def arccos(x):
105105
return tfnp.arccos(x)
106106

107107

108+
def arccosh(x):
109+
return tfnp.arccosh(x)
110+
111+
108112
def arcsin(x):
109113
return tfnp.arcsin(x)
110114

111115

116+
def arcsinh(x):
117+
return tfnp.arcsinh(x)
118+
119+
112120
def arctan(x):
113121
return tfnp.arctan(x)
114122

@@ -117,6 +125,10 @@ def arctan2(x1, x2):
117125
return tfnp.arctan2(x1, x2)
118126

119127

128+
def arctanh(x):
129+
return tfnp.arctanh(x)
130+
131+
120132
def argmax(x, axis=None):
121133
return tfnp.argmax(x, axis=axis)
122134

@@ -174,6 +186,10 @@ def cos(x):
174186
return tfnp.cos(x)
175187

176188

189+
def cosh(x):
190+
return tfnp.cosh(x)
191+
192+
177193
def count_nonzero(x, axis=None):
178194
return tfnp.count_nonzero(x, axis=axis)
179195

@@ -472,6 +488,10 @@ def sin(x):
472488
return tfnp.sin(x)
473489

474490

491+
def sinh(x):
492+
return tfnp.sinh(x)
493+
494+
475495
def size(x):
476496
return tfnp.size(x)
477497

@@ -508,6 +528,10 @@ def tan(x):
508528
return tfnp.tan(x)
509529

510530

531+
def tanh(x):
532+
return tfnp.tanh(x)
533+
534+
511535
def tensordot(x1, x2, axes=2):
512536
return tfnp.tensordot(x1, x2, axes=axes)
513537

‎keras_core/backend/torch/numpy.py

+30
Original file line numberDiff line numberDiff line change
@@ -173,11 +173,21 @@ def arccos(x):
173173
return torch.arccos(x)
174174

175175

176+
def arccosh(x):
177+
x = convert_to_tensor(x)
178+
return torch.arccosh(x)
179+
180+
176181
def arcsin(x):
177182
x = convert_to_tensor(x)
178183
return torch.arcsin(x)
179184

180185

186+
def arcsinh(x):
187+
x = convert_to_tensor(x)
188+
return torch.arcsinh(x)
189+
190+
181191
def arctan(x):
182192
x = convert_to_tensor(x)
183193
return torch.arctan(x)
@@ -188,6 +198,11 @@ def arctan2(x1, x2):
188198
return torch.arctan2(x1, x2)
189199

190200

201+
def arctanh(x):
202+
x = convert_to_tensor(x)
203+
return torch.arctanh(x)
204+
205+
191206
def argmax(x, axis=None):
192207
x = convert_to_tensor(x)
193208
return torch.argmax(x, dim=axis)
@@ -277,6 +292,11 @@ def cos(x):
277292
return torch.cos(x)
278293

279294

295+
def cosh(x):
296+
x = convert_to_tensor(x)
297+
return torch.cosh(x)
298+
299+
280300
def count_nonzero(x, axis=None):
281301
x = convert_to_tensor(x)
282302
if axis == () or axis == []:
@@ -729,6 +749,11 @@ def sin(x):
729749
return torch.sin(x)
730750

731751

752+
def sinh(x):
753+
x = convert_to_tensor(x)
754+
return torch.sinh(x)
755+
756+
732757
def size(x):
733758
x_shape = convert_to_tensor(tuple(x.shape))
734759
return torch.prod(x_shape)
@@ -806,6 +831,11 @@ def tan(x):
806831
return torch.tan(x)
807832

808833

834+
def tanh(x):
835+
x = convert_to_tensor(x)
836+
return torch.tanh(x)
837+
838+
809839
def tensordot(x1, x2, axes=2):
810840
x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
811841
# Conversion to long necessary for `torch.tensordot`

‎keras_core/ops/nn.py

-15
Original file line numberDiff line numberDiff line change
@@ -57,21 +57,6 @@ def sigmoid(x):
5757
return backend.nn.sigmoid(x)
5858

5959

60-
class Tanh(Operation):
61-
def call(self, x):
62-
return backend.nn.tanh(x)
63-
64-
def compute_output_spec(self, x):
65-
return KerasTensor(x.shape, dtype=x.dtype)
66-
67-
68-
@keras_core_export(["keras_core.ops.tanh", "keras_core.ops.nn.tanh"])
69-
def tanh(x):
70-
if any_symbolic_tensors((x,)):
71-
return Tanh().symbolic_call(x)
72-
return backend.nn.tanh(x)
73-
74-
7560
class Softplus(Operation):
7661
def call(self, x):
7762
return backend.nn.softplus(x)

‎keras_core/ops/numpy.py

+143
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,12 @@
1010
append
1111
arange
1212
arccos
13+
arccosh
1314
arcsin
15+
arcsinh
1416
arctan
1517
arctan2
18+
arctanh
1619
argmax
1720
argmin
1821
argsort
@@ -27,6 +30,7 @@
2730
conjugate
2831
copy
2932
cos
33+
cosh
3034
count_nonzero
3135
cross
3236
cumprod
@@ -102,6 +106,7 @@
102106
round
103107
sign
104108
sin
109+
sinh
105110
size
106111
sort
107112
split
@@ -116,6 +121,7 @@
116121
take
117122
take_along_axis
118123
tan
124+
tanh
119125
tensordot
120126
tile
121127
trace
@@ -713,6 +719,28 @@ def arccos(x):
713719
return backend.numpy.arccos(x)
714720

715721

722+
class Arccosh(Operation):
723+
def call(self, x):
724+
return backend.numpy.arccosh(x)
725+
726+
def compute_output_spec(self, x):
727+
return KerasTensor(x.shape, dtype=x.dtype)
728+
729+
730+
def arccosh(x):
731+
"""Inverse hyperbolic cosine, element-wise.
732+
733+
Arguments:
734+
x: Input tensor.
735+
736+
Returns:
737+
Output tensor of same shape as x.
738+
"""
739+
if any_symbolic_tensors((x,)):
740+
return Arccosh().symbolic_call(x)
741+
return backend.numpy.arccosh(x)
742+
743+
716744
class Arcsin(Operation):
717745
def call(self, x):
718746
return backend.numpy.arcsin(x)
@@ -742,6 +770,29 @@ def arcsin(x):
742770
return backend.numpy.arcsin(x)
743771

744772

773+
class Arcsinh(Operation):
774+
def call(self, x):
775+
return backend.numpy.arcsinh(x)
776+
777+
def compute_output_spec(self, x):
778+
return KerasTensor(x.shape, dtype=x.dtype)
779+
780+
781+
@keras_core_export(["keras_core.ops.arcsinh", "keras_core.ops.numpy.arcsinh"])
782+
def arcsinh(x):
783+
"""Inverse hyperbolic sine, element-wise.
784+
785+
Arguments:
786+
x: Input tensor.
787+
788+
Returns:
789+
Output tensor of same shape as x.
790+
"""
791+
if any_symbolic_tensors((x,)):
792+
return Arcsinh().symbolic_call(x)
793+
return backend.numpy.arcsinh(x)
794+
795+
745796
class Arctan(Operation):
746797
def call(self, x):
747798
return backend.numpy.arctan(x)
@@ -826,6 +877,29 @@ def arctan2(x1, x2):
826877
return backend.numpy.arctan2(x1, x2)
827878

828879

880+
class Arctanh(Operation):
881+
def call(self, x):
882+
return backend.numpy.arctanh(x)
883+
884+
def compute_output_spec(self, x):
885+
return KerasTensor(x.shape, dtype=x.dtype)
886+
887+
888+
@keras_core_export(["keras_core.ops.arctanh", "keras_core.ops.numpy.arctanh"])
889+
def arctanh(x):
890+
"""Inverse hyperbolic tangent, element-wise.
891+
892+
Arguments:
893+
x: Input tensor.
894+
895+
Returns:
896+
Output tensor of same shape as x.
897+
"""
898+
if any_symbolic_tensors((x,)):
899+
return Arctanh().symbolic_call(x)
900+
return backend.numpy.arctanh(x)
901+
902+
829903
class Argmax(Operation):
830904
def __init__(self, axis=None):
831905
super().__init__()
@@ -1289,6 +1363,29 @@ def cos(x):
12891363
return backend.numpy.cos(x)
12901364

12911365

1366+
class Cosh(Operation):
1367+
def call(self, x):
1368+
return backend.numpy.cosh(x)
1369+
1370+
def compute_output_spec(self, x):
1371+
return KerasTensor(x.shape, dtype=x.dtype)
1372+
1373+
1374+
@keras_core_export(["keras_core.ops.cosh", "keras_core.ops.numpy.cosh"])
1375+
def cosh(x):
1376+
"""Hyperbolic cosine, element-wise.
1377+
1378+
Arguments:
1379+
x: Input tensor.
1380+
1381+
Returns:
1382+
Output tensor of same shape as x.
1383+
"""
1384+
if any_symbolic_tensors((x,)):
1385+
return Cosh().symbolic_call(x)
1386+
return backend.numpy.cosh(x)
1387+
1388+
12921389
class CountNonzero(Operation):
12931390
def __init__(self, axis=None):
12941391
super().__init__()
@@ -3115,6 +3212,29 @@ def sin(x):
31153212
return backend.numpy.sin(x)
31163213

31173214

3215+
class Sinh(Operation):
3216+
def call(self, x):
3217+
return backend.numpy.sinh(x)
3218+
3219+
def compute_output_spec(self, x):
3220+
return KerasTensor(x.shape, dtype=x.dtype)
3221+
3222+
3223+
@keras_core_export(["keras_core.ops.sinh", "keras_core.ops.numpy.sinh"])
3224+
def sinh(x):
3225+
"""Hyperbolic sine, element-wise.
3226+
3227+
Arguments:
3228+
x: Input tensor.
3229+
3230+
Returns:
3231+
Output tensor of same shape as x.
3232+
"""
3233+
if any_symbolic_tensors((x,)):
3234+
return Sinh().symbolic_call(x)
3235+
return backend.numpy.sinh(x)
3236+
3237+
31183238
class Size(Operation):
31193239
def call(self, x):
31203240
return backend.numpy.size(x)
@@ -3376,6 +3496,29 @@ def tan(x):
33763496
return backend.numpy.tan(x)
33773497

33783498

3499+
class Tanh(Operation):
3500+
def call(self, x):
3501+
return backend.numpy.tanh(x)
3502+
3503+
def compute_output_spec(self, x):
3504+
return KerasTensor(x.shape, dtype=x.dtype)
3505+
3506+
3507+
@keras_core_export(["keras_core.ops.tanh", "keras_core.ops.numpy.tanh"])
3508+
def tanh(x):
3509+
"""Hyperbolic tangent, element-wise.
3510+
3511+
Arguments:
3512+
x: Input tensor.
3513+
3514+
Returns:
3515+
Output tensor of same shape as x.
3516+
"""
3517+
if any_symbolic_tensors((x,)):
3518+
return Tanh().symbolic_call(x)
3519+
return backend.numpy.tanh(x)
3520+
3521+
33793522
class Tensordot(Operation):
33803523
def __init__(self, axes=2):
33813524
super().__init__()

‎keras_core/ops/numpy_test.py

+88-1
Original file line numberDiff line numberDiff line change
@@ -754,14 +754,26 @@ def test_arccos(self):
754754
x = KerasTensor([None, 3])
755755
self.assertEqual(knp.arccos(x).shape, (None, 3))
756756

757+
def test_arccosh(self):
758+
x = KerasTensor([None, 3])
759+
self.assertEqual(knp.arccosh(x).shape, (None, 3))
760+
757761
def test_arcsin(self):
758762
x = KerasTensor([None, 3])
759763
self.assertEqual(knp.arcsin(x).shape, (None, 3))
760764

765+
def test_arcsinh(self):
766+
x = KerasTensor([None, 3])
767+
self.assertEqual(knp.arcsinh(x).shape, (None, 3))
768+
761769
def test_arctan(self):
762770
x = KerasTensor([None, 3])
763771
self.assertEqual(knp.arctan(x).shape, (None, 3))
764772

773+
def test_arctanh(self):
774+
x = KerasTensor([None, 3])
775+
self.assertEqual(knp.arctanh(x).shape, (None, 3))
776+
765777
def test_argmax(self):
766778
x = KerasTensor([None, 3])
767779
self.assertEqual(knp.argmax(x).shape, ())
@@ -855,6 +867,10 @@ def test_cos(self):
855867
x = KerasTensor([None, 3])
856868
self.assertEqual(knp.cos(x).shape, (None, 3))
857869

870+
def test_cosh(self):
871+
x = KerasTensor([None, 3])
872+
self.assertEqual(knp.cosh(x).shape, (None, 3))
873+
858874
def test_count_nonzero(self):
859875
x = KerasTensor([None, 3])
860876
self.assertEqual(knp.count_nonzero(x).shape, ())
@@ -1095,6 +1111,10 @@ def test_sin(self):
10951111
x = KerasTensor([None, 3])
10961112
self.assertEqual(knp.sin(x).shape, (None, 3))
10971113

1114+
def test_sinh(self):
1115+
x = KerasTensor([None, 3])
1116+
self.assertEqual(knp.sinh(x).shape, (None, 3))
1117+
10981118
def test_size(self):
10991119
x = KerasTensor([None, 3])
11001120
self.assertEqual(knp.size(x).shape, ())
@@ -1137,6 +1157,10 @@ def test_tan(self):
11371157
x = KerasTensor([None, 3])
11381158
self.assertEqual(knp.tan(x).shape, (None, 3))
11391159

1160+
def test_tanh(self):
1161+
x = KerasTensor([None, 3])
1162+
self.assertEqual(knp.tanh(x).shape, (None, 3))
1163+
11401164
def test_tile(self):
11411165
x = KerasTensor([None, 3])
11421166
self.assertEqual(knp.tile(x, [2]).shape, (None, 6))
@@ -1227,14 +1251,26 @@ def test_arccos(self):
12271251
x = KerasTensor([2, 3])
12281252
self.assertEqual(knp.arccos(x).shape, (2, 3))
12291253

1254+
def test_arccosh(self):
1255+
x = KerasTensor([2, 3])
1256+
self.assertEqual(knp.arccosh(x).shape, (2, 3))
1257+
12301258
def test_arcsin(self):
12311259
x = KerasTensor([2, 3])
12321260
self.assertEqual(knp.arcsin(x).shape, (2, 3))
12331261

1262+
def test_arcsinh(self):
1263+
x = KerasTensor([2, 3])
1264+
self.assertEqual(knp.arcsinh(x).shape, (2, 3))
1265+
12341266
def test_arctan(self):
12351267
x = KerasTensor([2, 3])
12361268
self.assertEqual(knp.arctan(x).shape, (2, 3))
12371269

1270+
def test_arctanh(self):
1271+
x = KerasTensor([2, 3])
1272+
self.assertEqual(knp.arctanh(x).shape, (2, 3))
1273+
12381274
def test_argmax(self):
12391275
x = KerasTensor([2, 3])
12401276
self.assertEqual(knp.argmax(x).shape, ())
@@ -1297,6 +1333,10 @@ def test_cos(self):
12971333
x = KerasTensor([2, 3])
12981334
self.assertEqual(knp.cos(x).shape, (2, 3))
12991335

1336+
def test_cosh(self):
1337+
x = KerasTensor([2, 3])
1338+
self.assertEqual(knp.cosh(x).shape, (2, 3))
1339+
13001340
def test_count_nonzero(self):
13011341
x = KerasTensor([2, 3])
13021342
self.assertEqual(knp.count_nonzero(x).shape, ())
@@ -1532,6 +1572,10 @@ def test_sin(self):
15321572
x = KerasTensor([2, 3])
15331573
self.assertEqual(knp.sin(x).shape, (2, 3))
15341574

1575+
def test_sinh(self):
1576+
x = KerasTensor([2, 3])
1577+
self.assertEqual(knp.sinh(x).shape, (2, 3))
1578+
15351579
def test_size(self):
15361580
x = KerasTensor([2, 3])
15371581
self.assertEqual(knp.size(x).shape, ())
@@ -1579,6 +1623,10 @@ def test_tan(self):
15791623
x = KerasTensor([2, 3])
15801624
self.assertEqual(knp.tan(x).shape, (2, 3))
15811625

1626+
def test_tanh(self):
1627+
x = KerasTensor([2, 3])
1628+
self.assertEqual(knp.tanh(x).shape, (2, 3))
1629+
15821630
def test_tile(self):
15831631
x = KerasTensor([2, 3])
15841632
self.assertEqual(knp.tile(x, [2]).shape, (2, 6))
@@ -2266,18 +2314,42 @@ def test_transpose(self):
22662314
np.transpose(x, axes=(1, 0, 3, 2, 4)),
22672315
)
22682316

2269-
def test_arcos(self):
2317+
def test_arccos(self):
22702318
x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]])
22712319
self.assertAllClose(knp.arccos(x), np.arccos(x))
22722320

22732321
self.assertAllClose(knp.Arccos()(x), np.arccos(x))
22742322

2323+
def test_arccosh(self):
2324+
x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]])
2325+
self.assertAllClose(knp.arccosh(x), np.arccosh(x))
2326+
2327+
self.assertAllClose(knp.Arccosh()(x), np.arccosh(x))
2328+
22752329
def test_arcsin(self):
22762330
x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]])
22772331
self.assertAllClose(knp.arcsin(x), np.arcsin(x))
22782332

22792333
self.assertAllClose(knp.Arcsin()(x), np.arcsin(x))
22802334

2335+
def test_arcsinh(self):
2336+
x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]])
2337+
self.assertAllClose(knp.arcsinh(x), np.arcsinh(x))
2338+
2339+
self.assertAllClose(knp.Arcsinh()(x), np.arcsinh(x))
2340+
2341+
def test_arctan(self):
2342+
x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]])
2343+
self.assertAllClose(knp.arctan(x), np.arctan(x))
2344+
2345+
self.assertAllClose(knp.Arctan()(x), np.arctan(x))
2346+
2347+
def test_arctanh(self):
2348+
x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]])
2349+
self.assertAllClose(knp.arctanh(x), np.arctanh(x))
2350+
2351+
self.assertAllClose(knp.Arctanh()(x), np.arctanh(x))
2352+
22812353
def test_argmax(self):
22822354
x = np.array([[1, 2, 3], [3, 2, 1]])
22832355
self.assertAllClose(knp.argmax(x), np.argmax(x))
@@ -2433,6 +2505,11 @@ def test_cos(self):
24332505
self.assertAllClose(knp.cos(x), np.cos(x))
24342506
self.assertAllClose(knp.Cos()(x), np.cos(x))
24352507

2508+
def test_cosh(self):
2509+
x = np.array([[1, 2, 3], [3, 2, 1]])
2510+
self.assertAllClose(knp.cosh(x), np.cosh(x))
2511+
self.assertAllClose(knp.Cosh()(x), np.cosh(x))
2512+
24362513
def test_count_nonzero(self):
24372514
x = np.array([[0, 2, 3], [3, 2, 0]])
24382515
self.assertAllClose(knp.count_nonzero(x), np.count_nonzero(x))
@@ -2899,6 +2976,11 @@ def test_sin(self):
28992976
self.assertAllClose(knp.sin(x), np.sin(x))
29002977
self.assertAllClose(knp.Sin()(x), np.sin(x))
29012978

2979+
def test_sinh(self):
2980+
x = np.array([[1, -2, 3], [-3, 2, -1]])
2981+
self.assertAllClose(knp.sinh(x), np.sinh(x))
2982+
self.assertAllClose(knp.Sinh()(x), np.sinh(x))
2983+
29022984
def test_size(self):
29032985
x = np.array([[1, 2, 3], [3, 2, 1]])
29042986
self.assertAllClose(knp.size(x), np.size(x))
@@ -2993,6 +3075,11 @@ def test_tan(self):
29933075
self.assertAllClose(knp.tan(x), np.tan(x))
29943076
self.assertAllClose(knp.Tan()(x), np.tan(x))
29953077

3078+
def test_tanh(self):
3079+
x = np.array([[1, -2, 3], [-3, 2, -1]])
3080+
self.assertAllClose(knp.tanh(x), np.tanh(x))
3081+
self.assertAllClose(knp.Tanh()(x), np.tanh(x))
3082+
29963083
def test_tile(self):
29973084
x = np.array([[1, 2, 3], [3, 2, 1]])
29983085
self.assertAllClose(knp.tile(x, [2, 3]), np.tile(x, [2, 3]))

0 commit comments

Comments
 (0)
Please sign in to comment.