Commit a8721bb

use lp_pool2d in ScaledAdaptativeL2NormPool2d (#23 (comment)) + name change Global to Adaptative
1 parent f5e28ce commit a8721bb

File tree: 7 files changed (+22 -33 lines)

7 files changed

+22
-33
lines changed

deel/torchlip/modules/__init__.py (+1 -1)

@@ -75,7 +75,7 @@
 from .pooling import ScaledAdaptiveAvgPool2d
 from .pooling import ScaledAvgPool2d
 from .pooling import ScaledL2NormPool2d
-from .pooling import ScaledGlobalL2NormPool2d
+from .pooling import ScaledAdaptativeL2NormPool2d
 from .upsampling import InvertibleUpSampling
 from .normalization import LayerCentering
 from .normalization import BatchCentering

deel/torchlip/modules/pooling.py (+7 -18)

@@ -29,6 +29,7 @@
 
 import numpy as np
 import torch
+import torch.nn.functional as F
 from torch.nn.common_types import _size_2_t
 
 from ..utils import sqrt_with_gradeps
@@ -193,12 +194,13 @@ def vanilla_export(self):
         return self
 
 
-class ScaledGlobalL2NormPool2d(torch.nn.AdaptiveAvgPool2d, LipschitzModule):
+class ScaledAdaptativeL2NormPool2d(
+    torch.nn.modules.pooling._AdaptiveAvgPoolNd, LipschitzModule
+):
     def __init__(
         self,
         output_size: _size_2_t = (1, 1),
         k_coef_lip: float = 1.0,
-        eps_grad_sqrt: float = 1e-6,
     ):
         """
         Average pooling operation for spatial data, with a lipschitz bound. This
@@ -210,37 +212,24 @@ def __init__(
         Arguments:
             output_size: the target output size has to be (1,1)
             k_coef_lip: the lipschitz factor to ensure
-            eps_grad_sqrt: Epsilon value to avoid numerical instability
-                due to non-defined gradient at 0 in the sqrt function
 
         Input shape:
             4D tensor with shape `(batch_size, channels, rows, cols)`.
 
         Output shape:
             4D tensor with shape `(batch_size, channels, 1, 1)`.
         """
-        if eps_grad_sqrt < 0.0:
-            raise RuntimeError("eps_grad_sqrt must be positive")
         if not isinstance(output_size, tuple) or len(output_size) != 2:
             raise RuntimeError("output_size must be a tuple of 2 integers")
         else:
             if output_size[0] != 1 or output_size[1] != 1:
                 raise RuntimeError("output_size must be (1, 1)")
-        torch.nn.AdaptiveAvgPool2d.__init__(self, output_size)
+        torch.nn.modules.pooling._AdaptiveAvgPoolNd.__init__(self, output_size)
         LipschitzModule.__init__(self, k_coef_lip)
-        self.eps_grad_sqrt = eps_grad_sqrt
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        # coeff = computePoolScalingFactor(input.shape[-2:]) * self._coefficient_lip
-        # avg = torch.nn.AdaptiveAvgPool2d.forward(self, torch.square(input))
-        # return sqrt_with_gradeps(avg,self.eps_grad_sqrt)* coeff
-        return (  # type: ignore
-            sqrt_with_gradeps(
-                torch.square(input).sum(axis=(2, 3), keepdim=True),
-                self.eps_grad_sqrt,
-            )
-            * self._coefficient_lip
-        )
+        return F.lp_pool2d(input, 2, input.shape[-2:]) * self._coefficient_lip
 
     def vanilla_export(self):
         return self
+
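For reference, the new forward should be numerically equivalent to the removed one: F.lp_pool2d with norm type 2 and a kernel covering the whole spatial extent returns the square root of the per-channel sum of squares, which is what the deleted sqrt_with_gradeps expression computed (minus the dropped eps_grad_sqrt smoothing). A minimal sketch of that check, with an arbitrarily chosen input shape and tolerance:

import torch
import torch.nn.functional as F

x = torch.randn(4, 3, 5, 5)  # arbitrary (batch, channels, rows, cols)

# New forward: LP pooling with p=2 over the full spatial window yields the
# per-channel L2 norm, with output shape (batch, channels, 1, 1).
new_out = F.lp_pool2d(x, norm_type=2, kernel_size=x.shape[-2:])

# Old forward, without the eps_grad_sqrt smoothing this commit removes.
old_out = torch.sqrt(torch.square(x).sum(dim=(2, 3), keepdim=True))

assert new_out.shape == (4, 3, 1, 1)
assert torch.allclose(new_out, old_out, atol=1e-6)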

docs/source/basic_example.rst (+1 -1)

@@ -54,7 +54,7 @@ The following table indicates which module are safe to use in a Lipschitz networ
       -
     * - :class:`torch.nn.AvgPool2d`\ :raw-html-m2r:`<br>`\ :class:`torch.nn.AdaptiveAvgPool2d`
       - no
-      - :class:`.ScaledAvgPool2d`\ :raw-html-m2r:`<br>`\ :class:`.ScaledAdaptiveAvgPool2d` \ :raw-html-m2r:`<br>` \ :class:`.ScaledL2NormPool2d` \ :raw-html-m2r:`<br>` \ :class:`.ScaledGlobalL2NormPool2d`
+      - :class:`.ScaledAvgPool2d`\ :raw-html-m2r:`<br>`\ :class:`.ScaledAdaptiveAvgPool2d` \ :raw-html-m2r:`<br>` \ :class:`.ScaledL2NormPool2d` \ :raw-html-m2r:`<br>` \ :class:`.ScaledAdaptativeL2NormPool2d`
       - The Lipschitz constant is bounded by ``sqrt(pool_h * pool_w)``.
     * - :class:`Flatten`
       - yes
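The table row above credits the scaled pooling layers with a bounded Lipschitz constant; for the L2-norm variant with the default k_coef_lip=1.0 the mapping is 1-Lipschitz (reverse triangle inequality applied per channel), which can be sanity-checked empirically. A hypothetical sketch with arbitrary shapes, assuming the renamed class is importable as shown:

import torch
from deel.torchlip.modules import ScaledAdaptativeL2NormPool2d

pool = ScaledAdaptativeL2NormPool2d(output_size=(1, 1))
x = torch.randn(8, 3, 7, 7)
y = torch.randn(8, 3, 7, 7)

# Per-sample check: pooled outputs never move farther apart than the inputs.
gap_out = (pool(x) - pool(y)).flatten(1).norm(dim=1)
gap_in = (x - y).flatten(1).norm(dim=1)
assert torch.all(gap_out <= gap_in + 1e-5)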

docs/source/deel.torchlip.rst (+1 -1)

@@ -39,7 +39,7 @@ Pooling Layers
 .. autoclass:: ScaledAdaptiveAvgPool2d
 .. autoclass:: ScaledAvgPool2d
 .. autoclass:: ScaledL2NormPool2d
-.. autoclass:: ScaledGlobalL2NormPool2d
+.. autoclass:: ScaledAdaptativeL2NormPool2d
 .. autoclass:: InvertibleDownSampling
 .. autoclass:: InvertibleUpSampling
 

tests/test_layers.py (+5 -5)

@@ -44,7 +44,7 @@
     ScaledL2NormPool2d,
     InvertibleDownSampling,
     InvertibleUpSampling,
-    ScaledGlobalL2NormPool2d,
+    ScaledAdaptativeL2NormPool2d,
     Flatten,
     Sequential,
 )
@@ -877,7 +877,7 @@ def test_scaledl2normPool2d(test_params):
 
 
 @pytest.mark.skipif(
-    hasattr(ScaledGlobalL2NormPool2d, "unavailable_class"),
+    hasattr(ScaledAdaptativeL2NormPool2d, "unavailable_class"),
     reason="compute_layer_sv not available",
 )
 @pytest.mark.parametrize(
@@ -890,7 +890,7 @@ def test_scaledl2normPool2d(test_params):
             "layers": [
                 tInput(uft.to_framework_channel((1, 5, 5))),
                 uft.get_instance_framework(
-                    ScaledGlobalL2NormPool2d, {"data_format": "channels_last"}
+                    ScaledAdaptativeL2NormPool2d, {"data_format": "channels_last"}
                 ),
             ]
         },
@@ -908,7 +908,7 @@ def test_scaledl2normPool2d(test_params):
             "layers": [
                 tInput(uft.to_framework_channel((1, 5, 5))),
                 uft.get_instance_framework(
-                    ScaledGlobalL2NormPool2d, {"data_format": "channels_last"}
+                    ScaledAdaptativeL2NormPool2d, {"data_format": "channels_last"}
                 ),
             ]
         },
@@ -926,7 +926,7 @@ def test_scaledl2normPool2d(test_params):
             "layers": [
                 tInput(uft.to_framework_channel((1, 5, 5))),
                 uft.get_instance_framework(
-                    ScaledGlobalL2NormPool2d, {"data_format": "channels_last"}
+                    ScaledAdaptativeL2NormPool2d, {"data_format": "channels_last"}
                ),
            ]
        },

tests/test_pooling.py (+5 -5)

@@ -35,7 +35,7 @@
     ScaledAvgPool2d,
     ScaledAdaptiveAvgPool2d,
     ScaledL2NormPool2d,
-    ScaledGlobalL2NormPool2d,
+    ScaledAdaptativeL2NormPool2d,
 )
 
 
@@ -92,7 +92,7 @@ def test_pooling_simple(layer_type, layer_params):
     "layer_type",
     [
         ScaledAdaptiveAvgPool2d,
-        ScaledGlobalL2NormPool2d,
+        ScaledAdaptativeL2NormPool2d,
     ],
 )
 @pytest.mark.parametrize(
@@ -137,8 +137,8 @@ def test_pooling_global(layer_type, layer_params):
         (ScaledL2NormPool2d, {"kernel_size": (2, 2), "stride": (2, 2)}),
         (ScaledAdaptiveAvgPool2d, {"output_size": (1, 1)}),
         (ScaledAdaptiveAvgPool2d, {"output_size": (1, 1), "k_coef_lip": 2.5}),
-        (ScaledGlobalL2NormPool2d, {"output_size": (1, 1)}),
-        (ScaledGlobalL2NormPool2d, {"output_size": (1, 1), "k_coef_lip": 2.5}),
+        (ScaledAdaptativeL2NormPool2d, {"output_size": (1, 1)}),
+        (ScaledAdaptativeL2NormPool2d, {"output_size": (1, 1), "k_coef_lip": 2.5}),
     ],
 )
 def test_pool_vanilla_export(layer_type, layer_params):
@@ -206,7 +206,7 @@ def test_pool_vanilla_export(layer_type, layer_params):
             [[[[40.0 / math.sqrt(4.0 * 4.0)]], [[18.0 / math.sqrt(4.0 * 4.0)]]]],
         ),
         (
-            ScaledGlobalL2NormPool2d,
+            ScaledAdaptativeL2NormPool2d,
             {"output_size": (1, 1)},
             [[[[math.sqrt(120.0)]], [[math.sqrt(320.0)]]]],
         ),
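The expected values in the last parametrization reflect the layer's behaviour: each channel collapses to its spatial L2 norm. A hypothetical standalone usage sketch (the input tensor below is made up, not taken from the test file):

import math
import torch
from deel.torchlip.modules import ScaledAdaptativeL2NormPool2d

pool = ScaledAdaptativeL2NormPool2d(output_size=(1, 1))  # default k_coef_lip=1.0
x = torch.tensor([[[[1.0, 1.0], [3.0, 3.0]]]])  # shape (1, 1, 2, 2)

out = pool(x)  # single channel -> sqrt(1 + 1 + 9 + 9) = sqrt(20)
assert out.shape == (1, 1, 1, 1)
assert torch.isclose(out.squeeze(), torch.tensor(math.sqrt(20.0)))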

tests/utils_framework.py (+2 -2)

@@ -42,7 +42,7 @@
 from deel.torchlip.modules import ScaledAvgPool2d
 from deel.torchlip.modules import ScaledAdaptiveAvgPool2d
 from deel.torchlip.modules import ScaledL2NormPool2d
-from deel.torchlip.modules import ScaledGlobalL2NormPool2d
+from deel.torchlip.modules import ScaledAdaptativeL2NormPool2d
 from deel.torchlip.modules import InvertibleDownSampling
 from deel.torchlip.modules import InvertibleUpSampling
 from deel.torchlip.modules import LayerCentering
@@ -231,7 +231,7 @@ def get_instance_withcheck(
     ScaledL2NormPool2d: partial(
         get_instance_withreplacement, dict_keys_replace={"data_format": None}
     ),
-    ScaledGlobalL2NormPool2d: partial(
+    ScaledAdaptativeL2NormPool2d: partial(
         get_instance_withreplacement, dict_keys_replace={"data_format": None}
     ),
     SpectralConv2d: partial(
