Skip to content

Commit

Permalink
cleaning
Browse files Browse the repository at this point in the history
  • Loading branch information
Franck Mamalet committed Oct 21, 2024
1 parent 9d514c1 commit cd56dff
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 18 deletions.
1 change: 0 additions & 1 deletion deel/torchlip/modules/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,6 @@ def forward(self, z, axis=1):
for _ in range(len(z.shape) - len(theta.shape)):
theta = theta.unsqueeze(-1)
x, y = z.split(z.shape[axis] // 2, axis)
print("aaaa", x.shape, y.shape, theta.shape)
selector = (x * torch.sin(0.5 * theta)) - (y * torch.cos(0.5 * theta))

a_2 = x * torch.cos(theta) + y * torch.sin(theta)
Expand Down
2 changes: 1 addition & 1 deletion deel/torchlip/modules/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def __init__(
) -> None:
if dilation != 1:
raise ValueError("SpectralConvTranspose2d does not support dilation rate")
if not output_padding in [0, None]:
if output_padding not in [0, None]:
raise ValueError("SpectralConvTranspose2d only supports output_padding=0")
torch.nn.ConvTranspose2d.__init__(
self,
Expand Down
29 changes: 13 additions & 16 deletions tests/utils_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,15 +143,16 @@ def __call__(self, **kwargs):
return None


TauCategoricalCrossentropyLoss = TauCrossEntropyLoss
TauSparseCategoricalCrossentropyLoss = TauCrossEntropyLoss
TauBinaryCrossentropyLoss = TauBCEWithLogitsLoss

tInput = module_Unavailable_foo
AutoWeightClipConstraint = module_Unavailable_class
SpectralConstraint = module_Unavailable_class
FrobeniusConstraint = module_Unavailable_class
CondenseCallback = module_Unavailable_class
MonitorCallback = module_Unavailable_class
TauCategoricalCrossentropyLoss = TauCrossEntropyLoss
TauSparseCategoricalCrossentropyLoss = TauCrossEntropyLoss
TauBinaryCrossentropyLoss = TauBCEWithLogitsLoss
CategoricalProvableRobustAccuracy = module_Unavailable_class
BinaryProvableRobustAccuracy = module_Unavailable_class
CategoricalProvableAvgRobustness = module_Unavailable_class
Expand Down Expand Up @@ -236,35 +237,31 @@ def get_instance_withcheck(
get_instance_withreplacement, dict_keys_replace={"data_format": None}
),
KRLoss: partial(
get_instance_withcheck,
get_instance_withreplacement,
dict_keys_replace={"name": None},
list_keys_notimplemented=[],
),
HingeMarginLoss: partial(get_instance_withcheck, dict_keys_replace={"name": None}),
HingeMarginLoss: partial(
get_instance_withreplacement, dict_keys_replace={"name": None}
),
HKRLoss: partial(
get_instance_withcheck,
get_instance_withreplacement,
dict_keys_replace={"name": None},
list_keys_notimplemented=[],
),
HingeMulticlassLoss: partial(
get_instance_withcheck,
get_instance_withreplacement,
dict_keys_replace={"name": None},
list_keys_notimplemented=[],
),
HKRMulticlassLoss: partial(
get_instance_withcheck,
get_instance_withreplacement,
dict_keys_replace={"name": None},
list_keys_notimplemented=[],
),
KRMulticlassLoss: partial(
get_instance_withcheck,
get_instance_withreplacement,
dict_keys_replace={"name": None},
list_keys_notimplemented=[],
),
SoftHKRMulticlassLoss: partial(
get_instance_withcheck,
get_instance_withreplacement,
dict_keys_replace={"name": None},
list_keys_notimplemented=[],
),
tLinear: partial(
get_instance_withcheck,
Expand Down

0 comments on commit cd56dff

Please sign in to comment.