29
29
30
30
import numpy as np
31
31
import torch
32
+ import torch .nn .functional as F
32
33
from torch .nn .common_types import _size_2_t
33
34
34
35
from ..utils import sqrt_with_gradeps
@@ -193,12 +194,13 @@ def vanilla_export(self):
193
194
return self
194
195
195
196
196
- class ScaledGlobalL2NormPool2d (torch .nn .AdaptiveAvgPool2d , LipschitzModule ):
197
+ class ScaledAdaptativeL2NormPool2d (
198
+ torch .nn .modules .pooling ._AdaptiveAvgPoolNd , LipschitzModule
199
+ ):
197
200
def __init__ (
198
201
self ,
199
202
output_size : _size_2_t = (1 , 1 ),
200
203
k_coef_lip : float = 1.0 ,
201
- eps_grad_sqrt : float = 1e-6 ,
202
204
):
203
205
"""
204
206
Average pooling operation for spatial data, with a lipschitz bound. This
@@ -210,37 +212,24 @@ def __init__(
210
212
Arguments:
211
213
output_size: the target output size has to be (1,1)
212
214
k_coef_lip: the lipschitz factor to ensure
213
- eps_grad_sqrt: Epsilon value to avoid numerical instability
214
- due to non-defined gradient at 0 in the sqrt function
215
215
216
216
Input shape:
217
217
4D tensor with shape `(batch_size, channels, rows, cols)`.
218
218
219
219
Output shape:
220
220
4D tensor with shape `(batch_size, channels, 1, 1)`.
221
221
"""
222
- if eps_grad_sqrt < 0.0 :
223
- raise RuntimeError ("eps_grad_sqrt must be positive" )
224
222
if not isinstance (output_size , tuple ) or len (output_size ) != 2 :
225
223
raise RuntimeError ("output_size must be a tuple of 2 integers" )
226
224
else :
227
225
if output_size [0 ] != 1 or output_size [1 ] != 1 :
228
226
raise RuntimeError ("output_size must be (1, 1)" )
229
- torch .nn .AdaptiveAvgPool2d .__init__ (self , output_size )
227
+ torch .nn .modules . pooling . _AdaptiveAvgPoolNd .__init__ (self , output_size )
230
228
LipschitzModule .__init__ (self , k_coef_lip )
231
- self .eps_grad_sqrt = eps_grad_sqrt
232
229
233
230
def forward (self , input : torch .Tensor ) -> torch .Tensor :
234
- # coeff = computePoolScalingFactor(input.shape[-2:]) * self._coefficient_lip
235
- # avg = torch.nn.AdaptiveAvgPool2d.forward(self, torch.square(input))
236
- # return sqrt_with_gradeps(avg,self.eps_grad_sqrt)* coeff
237
- return ( # type: ignore
238
- sqrt_with_gradeps (
239
- torch .square (input ).sum (axis = (2 , 3 ), keepdim = True ),
240
- self .eps_grad_sqrt ,
241
- )
242
- * self ._coefficient_lip
243
- )
231
+ return F .lp_pool2d (input , 2 , input .shape [- 2 :]) * self ._coefficient_lip
244
232
245
233
def vanilla_export (self ):
246
234
return self
235
+
0 commit comments