Skip to content

Commit

Permalink
simplify householder reshape computation + correct test
Browse files Browse the repository at this point in the history
  • Loading branch information
franckma31 committed Oct 22, 2024
1 parent 3ea5d10 commit 3c61caf
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
5 changes: 2 additions & 3 deletions deel/torchlip/modules/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,9 +250,8 @@ def __init__(self, channels, k_coef_lip: float = 1.0, theta_initializer=None):
raise ValueError(f"Unknown initializer {theta_initializer}")

def forward(self, z, axis=1):
theta = self.theta.to(z.device).view(1, -1)
for _ in range(len(z.shape) - len(theta.shape)):
theta = theta.unsqueeze(-1)
theta_shape = (1, -1) + (1,) * (len(z.shape) - 2)
theta = self.theta.to(z.device).view(theta_shape)
x, y = z.split(z.shape[axis] // 2, axis)
selector = (x * torch.sin(0.5 * theta)) - (y * torch.cos(0.5 * theta))

Expand Down
2 changes: 1 addition & 1 deletion tests/test_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def check_serialization(layer_type, layer_params):
m = uft.generate_k_lip_model(layer_type, layer_params, input_shape=(10,), k=1)
if m is None:
return
optimizer, loss, _ = uft.compile_model(
loss, optimizer, _ = uft.compile_model(
m,
optimizer=uft.get_instance_framework(uft.SGD, inst_params={"model": m}),
loss=CategoricalCrossentropy(from_logits=True),
Expand Down

0 comments on commit 3c61caf

Please sign in to comment.