Skip to content

Commit bb08d3e

Browse files
committed
feat: add regularizer base class
1 parent 2190cde commit bb08d3e

File tree

2 files changed

+58
-0
lines changed

2 files changed

+58
-0
lines changed

neural_nets/regularizers/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .regularizers import *
+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
"""
2+
Built-in regularizers.
3+
"""
4+
5+
import numpy as np
6+
7+
8+
class Regularizer(object):
9+
"""
10+
Regularizer base class.
11+
"""
12+
13+
def __call__(self, x):
14+
return 0.
15+
16+
@classmethod
17+
def from_config(cls, config):
18+
return cls(**config)
19+
20+
21+
class L1L2(Regularizer):
22+
"""
23+
Regularizer for L1 and L2 regularization.
24+
25+
Arguments
26+
--------
27+
l1: Float; L1 regularization factor.
28+
l2: Float; L2 regularization factor.
29+
"""
30+
31+
def __init__(self, l1=0., l2=0.):
32+
self.l1 = l1
33+
self.l2 = l2
34+
35+
def __call__(self, x):
36+
regularization = 0.
37+
if self.l1:
38+
regularization += np.sum(self.l1 * np.abs(x))
39+
if self.l2:
40+
regularization += np.sum(self.l2 * np.square(x))
41+
return regularization
42+
43+
def get_config(self):
44+
return {'l1': float(self.l1),
45+
'l2': float(self.l2)}
46+
47+
48+
def l1(l=0.01):
49+
return L1L2(l1=l)
50+
51+
52+
def l2(l=0.01):
53+
return L1L2(l2=l)
54+
55+
56+
def l1_l2(l1=0.01, l2=0.01):
57+
return L1L2(l1=l1, l2=l2)

0 commit comments

Comments
 (0)