Skip to content

Commit a058441

Browse files
committed
deprecated spectral_ and bjorck_ initialization
1 parent baa4ee2 commit a058441

File tree

4 files changed

+14
-12
lines changed

4 files changed

+14
-12
lines changed

deel/torchlip/init.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
2525
# CRIAQ and ANITI - https://www.deel.ai/
2626
# =====================================================================================
27-
"""
28-
"""
27+
""" """
28+
import warnings
2929
import torch
3030

3131
from .normalizers import bjorck_normalization
@@ -57,6 +57,9 @@ def spectral_(
5757
eps_spectral (float): stopping criterion of iterative power method
5858
maxiter_spectral (int): maximum number of iterations for the power iteration
5959
"""
60+
warnings.warn(
61+
"spectral_ initialization is deprecated, use torch.nn.init.orthogonal_ instead"
62+
)
6063
with torch.no_grad():
6164
tensor.copy_(
6265
spectral_normalization(
@@ -91,6 +94,9 @@ def bjorck_(
9194
maxiter_bjorck (int): maximum number of iterations for bjorck algorithm
9295
beta: Value to use for the :math:`\beta` parameter.
9396
"""
97+
warnings.warn(
98+
"bjorck_ initialization is deprecated, use torch.nn.init.orthogonal_ instead"
99+
)
94100
with torch.no_grad():
95101
spectral_tensor = spectral_normalization(
96102
tensor, None, eps=eps_spectral, maxiter=maxiter_spectral

tests/test_compute_layer_sv.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@
2424
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
2525
# CRIAQ and ANITI - https://www.deel.ai/
2626
# =====================================================================================
27-
"""Tests for singular value computation (in compute_layer_sv.py)
28-
"""
27+
"""Tests for singular value computation (in compute_layer_sv.py)"""
2928
import os
3029
import pprint
3130
import pytest

tests/test_initializers.py

+4
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@
3535
)
3636

3737

38+
@pytest.mark.skipif(
39+
hasattr(SpectralInitializer, "unavailable_class"),
40+
reason="SpectralInitializer not available",
41+
)
3842
@pytest.mark.parametrize(
3943
"layer_type, layer_params,input_shape, orthogonal_test",
4044
[

tests/utils_framework.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@
6464
CategoricalHingeLoss,
6565
)
6666

67-
from deel.torchlip.init import spectral_, bjorck_
6867
from deel.torchlip.normalizers import spectral_normalization
6968
from deel.torchlip.normalizers import bjorck_normalization
7069
from deel.torchlip.normalizers import DEFAULT_EPS_SPECTRAL
@@ -174,6 +173,7 @@ def __call__(self, **kwargs):
174173
Model = module_Unavailable_class
175174
compute_layer_sv = module_Unavailable_class
176175
OrthLinearRegularizer = module_Unavailable_class
176+
SpectralInitializer = module_Unavailable_class
177177

178178
MODEL_PATH = "model.h5"
179179
LIP_LAYERS = "torchlip_layers"
@@ -591,13 +591,6 @@ def scaleDivAlpha(alpha):
591591
return 1.0 / (1 + 1.0 / alpha)
592592

593593

594-
def SpectralInitializer(eps_spectral, eps_bjorck):
595-
if eps_bjorck is None:
596-
return partial(spectral_, eps_spectral=eps_spectral)
597-
else:
598-
return partial(bjorck_, eps_spectral=eps_spectral, eps_bjorck=eps_bjorck)
599-
600-
601594
class tAdd(torch.nn.Module):
602595
def __init__(self):
603596
super(tAdd, self).__init__()

0 commit comments

Comments
 (0)