Skip to content

Commit 8a8ed6b

Browse files
committed
deprecated spectral_ and bjorck_ initialization
1 parent baa4ee2 commit 8a8ed6b

File tree

2 files changed

+4
-7
lines changed

2 files changed

+4
-7
lines changed

deel/torchlip/init.py

+3
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
# =====================================================================================
2727
"""
2828
"""
29+
import warnings
2930
import torch
3031

3132
from .normalizers import bjorck_normalization
@@ -57,6 +58,7 @@ def spectral_(
5758
eps_spectral (float): stopping criterion of iterative power method
5859
maxiter_spectral (int): maximum number of iterations for the power iteration
5960
"""
61+
warnings.warn("spectral_ initialization is deprecated, use torch.nn.init.orthogonal_ instead")
6062
with torch.no_grad():
6163
tensor.copy_(
6264
spectral_normalization(
@@ -91,6 +93,7 @@ def bjorck_(
9193
maxiter_bjorck (int): maximum number of iterations for bjorck algorithm
9294
beta: Value to use for the :math:`\beta` parameter.
9395
"""
96+
warnings.warn("bjorck_ initialization is deprecated, use torch.nn.init.orthogonal_ instead")
9497
with torch.no_grad():
9598
spectral_tensor = spectral_normalization(
9699
tensor, None, eps=eps_spectral, maxiter=maxiter_spectral

tests/utils_framework.py

+1-7
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ def __call__(self, **kwargs):
174174
Model = module_Unavailable_class
175175
compute_layer_sv = module_Unavailable_class
176176
OrthLinearRegularizer = module_Unavailable_class
177+
SpectralInitializer = module_Unavailable_class
177178

178179
MODEL_PATH = "model.h5"
179180
LIP_LAYERS = "torchlip_layers"
@@ -591,13 +592,6 @@ def scaleDivAlpha(alpha):
591592
return 1.0 / (1 + 1.0 / alpha)
592593

593594

594-
def SpectralInitializer(eps_spectral, eps_bjorck):
595-
if eps_bjorck is None:
596-
return partial(spectral_, eps_spectral=eps_spectral)
597-
else:
598-
return partial(bjorck_, eps_spectral=eps_spectral, eps_bjorck=eps_bjorck)
599-
600-
601595
class tAdd(torch.nn.Module):
602596
def __init__(self):
603597
super(tAdd, self).__init__()

0 commit comments

Comments
 (0)