Skip to content

Commit cd56dff

Browse files
author
Franck Mamalet
committed
cleaning
1 parent 9d514c1 commit cd56dff

File tree

3 files changed

+14
-18
lines changed

3 files changed

+14
-18
lines changed

deel/torchlip/modules/activation.py

-1
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,6 @@ def forward(self, z, axis=1):
254254
for _ in range(len(z.shape) - len(theta.shape)):
255255
theta = theta.unsqueeze(-1)
256256
x, y = z.split(z.shape[axis] // 2, axis)
257-
print("aaaa", x.shape, y.shape, theta.shape)
258257
selector = (x * torch.sin(0.5 * theta)) - (y * torch.cos(0.5 * theta))
259258

260259
a_2 = x * torch.cos(theta) + y * torch.sin(theta)

deel/torchlip/modules/conv.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def __init__(
215215
) -> None:
216216
if dilation != 1:
217217
raise ValueError("SpectralConvTranspose2d does not support dilation rate")
218-
if not output_padding in [0, None]:
218+
if output_padding not in [0, None]:
219219
raise ValueError("SpectralConvTranspose2d only supports output_padding=0")
220220
torch.nn.ConvTranspose2d.__init__(
221221
self,

tests/utils_framework.py

+13-16
Original file line numberDiff line numberDiff line change
@@ -143,15 +143,16 @@ def __call__(self, **kwargs):
143143
return None
144144

145145

146+
TauCategoricalCrossentropyLoss = TauCrossEntropyLoss
147+
TauSparseCategoricalCrossentropyLoss = TauCrossEntropyLoss
148+
TauBinaryCrossentropyLoss = TauBCEWithLogitsLoss
149+
146150
tInput = module_Unavailable_foo
147151
AutoWeightClipConstraint = module_Unavailable_class
148152
SpectralConstraint = module_Unavailable_class
149153
FrobeniusConstraint = module_Unavailable_class
150154
CondenseCallback = module_Unavailable_class
151155
MonitorCallback = module_Unavailable_class
152-
TauCategoricalCrossentropyLoss = TauCrossEntropyLoss
153-
TauSparseCategoricalCrossentropyLoss = TauCrossEntropyLoss
154-
TauBinaryCrossentropyLoss = TauBCEWithLogitsLoss
155156
CategoricalProvableRobustAccuracy = module_Unavailable_class
156157
BinaryProvableRobustAccuracy = module_Unavailable_class
157158
CategoricalProvableAvgRobustness = module_Unavailable_class
@@ -236,35 +237,31 @@ def get_instance_withcheck(
236237
get_instance_withreplacement, dict_keys_replace={"data_format": None}
237238
),
238239
KRLoss: partial(
239-
get_instance_withcheck,
240+
get_instance_withreplacement,
240241
dict_keys_replace={"name": None},
241-
list_keys_notimplemented=[],
242242
),
243-
HingeMarginLoss: partial(get_instance_withcheck, dict_keys_replace={"name": None}),
243+
HingeMarginLoss: partial(
244+
get_instance_withreplacement, dict_keys_replace={"name": None}
245+
),
244246
HKRLoss: partial(
245-
get_instance_withcheck,
247+
get_instance_withreplacement,
246248
dict_keys_replace={"name": None},
247-
list_keys_notimplemented=[],
248249
),
249250
HingeMulticlassLoss: partial(
250-
get_instance_withcheck,
251+
get_instance_withreplacement,
251252
dict_keys_replace={"name": None},
252-
list_keys_notimplemented=[],
253253
),
254254
HKRMulticlassLoss: partial(
255-
get_instance_withcheck,
255+
get_instance_withreplacement,
256256
dict_keys_replace={"name": None},
257-
list_keys_notimplemented=[],
258257
),
259258
KRMulticlassLoss: partial(
260-
get_instance_withcheck,
259+
get_instance_withreplacement,
261260
dict_keys_replace={"name": None},
262-
list_keys_notimplemented=[],
263261
),
264262
SoftHKRMulticlassLoss: partial(
265-
get_instance_withcheck,
263+
get_instance_withreplacement,
266264
dict_keys_replace={"name": None},
267-
list_keys_notimplemented=[],
268265
),
269266
tLinear: partial(
270267
get_instance_withcheck,

0 commit comments

Comments
 (0)