Commit e1d8355

Author: JAXopt authors
Merge pull request #394 from mblondel:release_0.6
PiperOrigin-RevId: 508371158
2 parents 0c8b25b + 730b5a6, commit e1d8355

File tree

6 files changed: +67 -20 lines changed

docs/api.rst (+1 -1)

@@ -135,7 +135,7 @@ Line search
    :toctree: _autosummary

    jaxopt.BacktrackingLineSearch
-
+   jaxopt.HagerZhangLineSearch


 Perturbed optimizers

docs/changelog.rst (+38 -0)

@@ -1,6 +1,44 @@
 Changelog
 =========

+Version 0.6
+-----------
+
+New features
+~~~~~~~~~~~~
+
+- Added new Hager-Zhang linesearch in LBFGS, by Srinivas Vasudevan (code review by Emily Fertig).
+- Added perceptron and hinge losses, by Quentin Berthet.
+- Added binary sparsemax loss, sparse_plus and sparse_sigmoid, by Vincent Roulet.
+- Added isotonic regression, by Michael Sander.
+
+Bug fixes and enhancements
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+- Added TPU support to notebooks, by Ayush Shridhar.
+- Allowed users to restart from a previous optimizer state in LBFGS, by Zaccharie Ramzi.
+- Added faster error computation in gradient descent algorithm, by Zaccharie Ramzi.
+- Got rid of extra function call in BFGS and LBFGS, by Zaccharie Ramzi.
+- Improved dtype consistency between input and output of update method, by Mathieu Blondel.
+- Added perturbed optimizers notebook and narrative documentation, by Quentin Berthet and Fabian Pedregosa.
+- Enabled auxiliary value returned by linesearch methods, by Zaccharie Ramzi.
+- Added distributed examples to the website, by Fabian Pedregosa.
+- Added custom loop pjit example, by Felipe Llinares.
+- Fixed wrong LaTeX in maml.ipynb, by Fabian Pedregosa.
+- Fixed bug in backtracking line search, by Srinivas Vasudevan (code review by Emily Fertig).
+- Added pylintrc to top level directory, by Fabian Pedregosa.
+- Corrected the condition function in LBFGS, by Zaccharie Ramzi.
+- Added custom loop pmap example, by Felipe Llinares.
+- Fixed pytree support in IterativeRefinement, by Louis Béthune.
+- Fixed has_aux support in ArmijoSGD, by Louis Béthune.
+- Documentation improvements, by Fabian Pedregosa and Mathieu Blondel.
+
+Contributors
+~~~~~~~~~~~~
+
+Ayush Shridhar, Fabian Pedregosa, Felipe Llinares, Louis Bethune,
+Mathieu Blondel, Michael Sander, Quentin Berthet, Srinivas Vasudevan, Vincent Roulet, Zaccharie Ramzi.
+
 Version 0.5.5
 -------------

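For context on the first changelog entry, the sketch below shows roughly how the new linesearch might be selected from LBFGS. It is not part of this commit; the linesearch="hager-zhang" option string and the state fields printed are assumptions based on the changelog, not code taken from this diff.

    # Hedged sketch: selecting the new Hager-Zhang linesearch in LBFGS.
    # The linesearch="hager-zhang" option is assumed from the changelog entry.
    import jax.numpy as jnp
    import jaxopt

    def rosenbrock(w):
      # Classic test objective, minimized at w = (1, ..., 1).
      return jnp.sum(100.0 * (w[1:] - w[:-1] ** 2) ** 2 + (1.0 - w[:-1]) ** 2)

    solver = jaxopt.LBFGS(fun=rosenbrock, maxiter=100, linesearch="hager-zhang")
    params, state = solver.run(jnp.zeros(5))
    print(params, state.iter_num)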

docs/line_search.rst (+1 -0)

@@ -51,6 +51,7 @@ Algorithms
    :toctree: _autosummary

    jaxopt.BacktrackingLineSearch
+   jaxopt.HagerZhangLineSearch

 The :class:`BacktrackingLineSearch <jaxopt.BacktrackingLineSearch>` algorithm
 iteratively reduces the step size by some decrease factor until the conditions
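As a rough illustration of the pattern this page documents, the sketch below calls a linesearch object directly to pick a step size for a single gradient step. The constructor and run() arguments are assumptions based on the documented BacktrackingLineSearch interface, not code from this commit.

    # Hedged sketch: one manual step, with BacktrackingLineSearch choosing the
    # step size along the negative gradient.  Argument names are assumptions.
    import jax
    import jax.numpy as jnp
    from jaxopt import BacktrackingLineSearch

    def fun(w):
      return jnp.sum((w - 1.0) ** 2)

    w = jnp.zeros(3)
    ls = BacktrackingLineSearch(fun=fun, maxiter=20, condition="strong-wolfe",
                                decrease_factor=0.8)
    stepsize, ls_state = ls.run(init_stepsize=1.0, params=w)
    w_next = w - stepsize * jax.grad(fun)(w)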

docs/objective_and_loss.rst (+8 -2)

@@ -29,6 +29,14 @@ Binary classification
 Binary classification losses are of the form ``loss(int: label, float: score) -> float``,
 where ``label`` is the ground-truth (``0`` or ``1``) and ``score`` is the model's output.

+The following utility functions are useful for the binary sparsemax loss.
+
+.. autosummary::
+   :toctree: _autosummary
+
+   jaxopt.loss.sparse_plus
+   jaxopt.loss.sparse_sigmoid
+
 Multiclass classification
 ~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -79,5 +87,3 @@ Other functions
   jaxopt.objective.multiclass_logreg_with_intercept
   jaxopt.objective.l2_multiclass_logreg
   jaxopt.objective.l2_multiclass_logreg_with_intercept
-  jaxopt.loss.sparse_plus
-  jaxopt.loss.sparse_sigmoid
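As a small illustration of the convention described in this file, where ``loss(int: label, float: score) -> float`` operates on a single example and jax.vmap handles batches, the sketch below uses the binary sparsemax loss together with the utility function relocated above. The example values are made up; this is not code from the commit.

    # Hedged sketch of the documented single-example convention: vmap maps the
    # scalar loss over a batch; sparse_sigmoid maps scores into [0, 1].
    import jax
    import jax.numpy as jnp
    from jaxopt.loss import binary_sparsemax_loss, sparse_sigmoid

    labels = jnp.array([0, 1, 1])
    scores = jnp.array([-0.3, 2.1, 0.4])

    losses = jax.vmap(binary_sparsemax_loss)(labels, scores)
    probs = sparse_sigmoid(scores)  # exactly 0 below -1 and exactly 1 above 1
    print(losses, probs)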

jaxopt/_src/loss.py (+18 -16)

@@ -74,58 +74,60 @@ def binary_sparsemax_loss(label: int, logit: float) -> float:
     loss value

   References:
-    Learning with Fenchel-Young Losses. Mathieu Blondel, André F. T. Martins,
+    Learning with Fenchel-Young Losses. Mathieu Blondel, André F. T. Martins,
     Vlad Niculae. JMLR 2020. (Sec. 4.4)
   """
   return sparse_plus(jnp.where(label, -logit, logit))


 def sparse_plus(x: float) -> float:
-  """Sparse plus function.
+  r"""Sparse plus function.

   Computes the function:

-  .. math:
-    \mathrm{sparseplus}(x) = \begin{cases}
+  .. math::
+
+    \mathrm{sparse\_plus}(x) = \begin{cases}
     0, & x \leq -1\\
-    \frac{1}{4}(x+1)^2, & -1 < x < 1 \\
+    \frac{1}{4}(x+1)^2, & -1 < x < 1 \\
     x, & 1 \leq x
   \end{cases}

-  This is the twin function of the softplus activation ensuring a zero output
-  for inputs less than -1 and a linear output for inputs greater than 1,
-  while remaining smooth, convex, monotonic by an adequate definition between
+  This is the twin function of the softplus activation ensuring a zero output
+  for inputs less than -1 and a linear output for inputs greater than 1,
+  while remaining smooth, convex, monotonic by an adequate definition between
   -1 and 1.

   Args:
     x: input (float)
   Returns:
-    sparseplus(x) as defined above
+    sparse_plus(x) as defined above
   """
   return jnp.where(x <= -1.0, 0.0, jnp.where(x >= 1.0, x, (x + 1.0)**2/4))


 def sparse_sigmoid(x: float) -> float:
-  """Sparse sigmoid function.
+  r"""Sparse sigmoid function.
+
+  Computes the function:

-  Computes the function:
+  .. math::

-  .. math:
-    \mathrm{sparsesigmoid}(x) = \begin{cases}
+    \mathrm{sparse\_sigmoid}(x) = \begin{cases}
     0, & x \leq -1\\
-    \frac{1}{2}(x+1), & -1 < x < 1 \\
+    \frac{1}{2}(x+1), & -1 < x < 1 \\
     1, & 1 \leq x
   \end{cases}

   This is the twin function of the sigmoid activation ensuring a zero output
   for inputs less than -1, a 1 ouput for inputs greater than 1, and a linear
-  output for inputs between -1 and 1. This is the derivative of the sparse
+  output for inputs between -1 and 1. This is the derivative of the sparse
   plus function.

   Args:
     x: input (float)
   Returns:
-    sparsesigmoid(x) as defined above
+    sparse_sigmoid(x) as defined above
   """
   return 0.5 * projection_hypercube(x + 1.0, 2.0)

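The corrected docstrings above state that sparse_plus is zero below -1, quadratic on (-1, 1), linear above 1, and that sparse_sigmoid is its derivative. The sketch below is a quick numerical check of those claims; it is not part of the commit, and the expected outputs in the comments are worked out from the formulas rather than captured from a run.

    # Hedged sketch: checking the piecewise definitions and the derivative
    # relationship documented above.
    import jax
    import jax.numpy as jnp
    from jaxopt.loss import sparse_plus, sparse_sigmoid

    xs = jnp.array([-2.0, -0.5, 0.0, 0.5, 2.0])
    print(sparse_plus(xs))                      # [0., 0.0625, 0.25, 0.5625, 2.]
    print(jax.vmap(jax.grad(sparse_plus))(xs))  # should match sparse_sigmoid(xs)
    print(sparse_sigmoid(xs))                   # [0., 0.25, 0.5, 0.75, 1.]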

jaxopt/version.py (+1 -1)

@@ -14,4 +14,4 @@

 """JAXopt version."""

-__version__ = "0.5.5"
+__version__ = "0.6"
