Skip to content

Commit 0015925

Browse files
committed
clean normalization
1 parent 3c61caf commit 0015925

File tree

2 files changed

+1
-26
lines changed

2 files changed

+1
-26
lines changed

deel/torchlip/modules/normalization.py

-24
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,6 @@ def forward(self, x):
2323

2424

2525
LayerCentering2d = LayerCentering
26-
# class LayerCentering2D(LayerCentering):
27-
# def __init__(self, size = 1, dim=[-2,-1]):
28-
# super(LayerCentering2D, self).__init__(size = size,dim=[-2,-1])
2926

3027

3128
class BatchCentering(nn.Module):
@@ -72,25 +69,4 @@ def forward(self, x):
7269
return x - mean.view(mean_shape)
7370

7471

75-
# class BatchCenteringBiases(BatchCentering):
76-
# def __init__(self, size =1, dim=[0,-2,-1], momentum=0.05):
77-
# super(BatchCenteringBiases, self).__init__(size = size, dim = dim, momentum = momentum)
78-
# if isinstance(size, tuple):
79-
# self.alpha = nn.Parameter(torch.zeros(size), requires_grad=True)
80-
# else:
81-
# self.alpha = nn.Parameter(torch.zeros(1,size,1,1), requires_grad=True)
82-
83-
# def forward(self, x):
84-
# #print(x.mean(dim=self.dim, keepdim=True).abs().mean().cpu().numpy(), self.running_mean.abs().cpu().mean().numpy(), self.alpha.abs().mean().cpu().numpy())
85-
# #print(x.mean(dim=self.dim, keepdim=True).abs().mean().cpu().numpy(),(x.mean(dim=self.dim, keepdim=True)-self.running_mean).abs().mean().cpu().numpy())
86-
# return super().forward(x) + self.alpha
87-
8872
BatchCentering2d = BatchCentering
89-
90-
# class BatchCenteringBiases2D(BatchCenteringBiases):
91-
# def __init__(self, size =1, momentum=0.05):
92-
# super(BatchCenteringBiases2D, self).__init__(size = size, dim=[0,-2,-1],momentum=momentum)
93-
94-
# class BatchCentering2D(BatchCentering):
95-
# def __init__(self, size =1, momentum=0.05):
96-
# super(BatchCentering2D, self).__init__(size = size, dim=[0,-2,-1],momentum=momentum)

tests/test_normalization.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
# =====================================================================================
2727
import os
2828
import pytest
29-
from functools import partial
3029

3130
import numpy as np
3231

@@ -249,6 +248,6 @@ def test_BatchCentering_runningmean(size, input_shape, bias):
249248
mean_x = np.mean(x, axis=(0, 2, 3))
250249
x = uft.to_tensor(x)
251250
for _ in range(1000):
252-
y = bn(x)
251+
y = bn(x) # noqa: F841
253252

254253
np.testing.assert_allclose(bn.running_mean, mean_x, atol=1e-5)

0 commit comments

Comments
 (0)