[MRG] Uniformize initialization for all algorithms (#195)
* initiate PR

* Revert "initiate PR"

This reverts commit a2ae9e1.

* FEAT: uniformize init for NCA and RCA

* Let the check of num_dims be done in the other PR

* Add metric initialization for algorithms that learn a mahalanobis matrix

* Add initialization for MLKR

* FIX: fix error message for dimension

* FIX fix StringRepr for MLKR

* FIX tests by reshaping to the right dataset size

* Remove lda in docstring of MLKR

* MAINT: Add deprecation for previous initializations

* Update tests with new initialization

* Make random init for mahalanobis metric generate an SPD matrix

* Ensure the input mahalanobis metric initialization is symmetric, and say it should be SPD

* various fixes

* MAINT: various refactoring
- MLKR: update default test init
- SDML: refactor prior_inv

* FIX fix default covariance for SDML in tests

* Enhance docstring

* Set random state for SDML

* Fix merge remove_spaces that was forgotten

* Fix indent

* XP: try to change the way we choose n_components to see if it fixes the test

* Revert "XP: try to change the way we choose n_components to see if it fixes the test"

This reverts commit e86b61b.

* Be more tolerant in test

* Add test for singular covariance matrix

* Fix test_singular_covariance_init

* DOC: update docstring saying pseudo-inverse

* Revert "Fix test_singular_covariance_init"

This reverts commit d2cc7ce.

* Ensure definiteness before returning the inverse

* wip deal with non definiteness

* Rename init to prior for SDML and LSML

* Update error messages with either prior or init

* Remove message

* A few nitpicks

* PEP8 errors + change init in test

* STY: PEP8 fixes

* Address and remove TODOs

* Replace init by prior for ITML

* TST: fix ITML test with init changed into prior

* Add precision for MMC

* Add ChangedBehaviorWarning for the algorithms that changed

* Address #195 (review)

* Remove the warnings check since we now have a ChangedBehaviorWarning

* Be more precise: it should not raise any ConvergenceWarningError

* Address #195 (review)

* FIX remaining comment

* TST: update test error message

* Improve readability

* Address #195 (review)

* TST: Fix docstring lmnn

* Fix warning messages

* Fix warnings messages changed
wdevazelhes authored and perimosocordiae committed Jun 7, 2019
1 parent 3899653 commit 130cbad
Showing 18 changed files with 1,626 additions and 223 deletions.
2 changes: 1 addition & 1 deletion bench/benchmarks/iris.py
@@ -10,7 +10,7 @@
'LMNN': metric_learn.LMNN(k=5, learn_rate=1e-6, verbose=False),
'LSML_Supervised': metric_learn.LSML_Supervised(num_constraints=200),
'MLKR': metric_learn.MLKR(),
'NCA': metric_learn.NCA(max_iter=700, num_dims=2),
'NCA': metric_learn.NCA(max_iter=700, n_components=2),
'RCA_Supervised': metric_learn.RCA_Supervised(dim=2, num_chunks=30,
chunk_size=2),
'SDML_Supervised': metric_learn.SDML_Supervised(num_constraints=1500),
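For context, an illustrative usage sketch (not from this commit's diff; it assumes the scikit-learn-style fit/transform API that metric-learn estimators expose and uses the iris dataset named above): the benchmark's NCA is now configured with `n_components` rather than the deprecated `num_dims`.

import metric_learn
from sklearn.datasets import load_iris

X, y = load_iris(return_X_y=True)
nca = metric_learn.NCA(max_iter=700, n_components=2)  # formerly num_dims=2
nca.fit(X, y)
X_2d = nca.transform(X)  # learned embedding of shape (n_samples, 2)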
335 changes: 328 additions & 7 deletions metric_learn/_util.py
@@ -1,10 +1,16 @@
import warnings
import numpy as np
import scipy
import six
from numpy.linalg import LinAlgError
from sklearn.datasets import make_spd_matrix
from sklearn.decomposition import PCA
from sklearn.utils import check_array
from sklearn.utils.validation import check_X_y
from metric_learn.exceptions import PreprocessorError
from sklearn.utils.validation import check_X_y, check_random_state
from .exceptions import PreprocessorError, NonPSDError
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from scipy.linalg import pinvh
import sys
import time

# hack around lack of axis kwarg in older numpy versions
try:
@@ -335,28 +341,38 @@ def check_collapsed_pairs(pairs):
def _check_sdp_from_eigen(w, tol=None):
"""Checks if some of the eigenvalues given are negative, up to a tolerance
level, with a default value of the tolerance depending on the eigenvalues.
It also returns whether the matrix is positive definite, up to the above
tolerance.
Parameters
----------
w : array-like, shape=(n_eigenvalues,)
Eigenvalues to check for positive semidefiniteness.
tol : positive `float`, optional
Negative eigenvalues above - tol are considered zero. If
Absolute eigenvalues below tol are considered zero. If
tol is None, and eps is the epsilon value for datatype of w, then tol
is set to w.max() * len(w) * eps.
is set to abs(w).max() * len(w) * eps.
Returns
-------
is_definite : bool
Whether the matrix is positive definite or not.
See Also
--------
np.linalg.matrix_rank for more details on the choice of tolerance (the same
strategy is applied here)
"""
if tol is None:
tol = w.max() * len(w) * np.finfo(w.dtype).eps
tol = np.abs(w).max() * len(w) * np.finfo(w.dtype).eps
if tol < 0:
raise ValueError("tol should be positive.")
if any(w < - tol):
raise ValueError("Matrix is not positive semidefinite (PSD).")
raise NonPSDError()
if any(abs(w) < tol):
return False
return True
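An illustrative sketch of how `_check_sdp_from_eigen` is meant to be used (not part of the diff; the matrices below are made-up examples): the caller computes the eigenvalues first, typically with `scipy.linalg.eigh`.

import numpy as np
import scipy.linalg

M = np.array([[2., 0.],
              [0., 0.]])           # PSD but singular
s, _ = scipy.linalg.eigh(M)
_check_sdp_from_eigen(s)           # -> False (PSD, but not strictly PD)

s_bad, _ = scipy.linalg.eigh(np.array([[1., 2.], [2., 1.]]))
_check_sdp_from_eigen(s_bad)       # raises NonPSDError (negative eigenvalue)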


def transformer_from_metric(metric, tol=None):
@@ -413,6 +429,311 @@ def validate_vector(u, dtype=None):
return u


def _initialize_transformer(n_components, input, y=None, init='auto',
verbose=False, random_state=None,
has_classes=True):
"""Returns the initial transformer to be used depending on the arguments.
Parameters
----------
n_components : int
The number of components to take. (Note: it should have been checked
before, meaning it should not be None and it should be a value in
[1, X.shape[1]])
input : array-like
The input samples (can be tuples or regular samples).
y : array-like or None
The input labels (or not if there are no labels).
init : string or numpy array, optional (default='auto')
Initialization of the linear transformation. Possible options are
'auto', 'pca', 'lda', 'identity', 'random', and a numpy array of shape
(n_features_a, n_features_b).
'auto'
Depending on ``n_components``, the most reasonable initialization
will be chosen. If ``n_components <= min(n_features, n_classes - 1)``,
we use 'lda' (see the description of the 'lda' init), as it uses label
information. If not, but ``n_components < min(n_features, n_samples)``,
we use 'pca',
as it projects data onto meaningful directions (those of higher
variance). Otherwise, we just use 'identity'.
'pca'
``n_components`` principal components of the inputs passed
to :meth:`fit` will be used to initialize the transformation.
(See `sklearn.decomposition.PCA`)
'lda'
``min(n_components, n_classes)`` most discriminative
components of the inputs passed to :meth:`fit` will be used to
initialize the transformation. (If ``n_components > n_classes``,
the rest of the components will be zero.) (See
`sklearn.discriminant_analysis.LinearDiscriminantAnalysis`).
This initialization is possible only if `has_classes == True`.
'identity'
The identity matrix. If ``n_components`` is strictly smaller than the
dimensionality of the inputs passed to :meth:`fit`, the identity
matrix will be truncated to the first ``n_components`` rows.
'random'
The initial transformation will be a random array of shape
`(n_components, n_features)`. Each value is sampled from the
standard normal distribution.
numpy array
n_features_b must match the dimensionality of the inputs passed to
:meth:`fit` and n_features_a must be less than or equal to that.
If ``n_components`` is not None, n_features_a must match it.
verbose : bool
Whether to print the details of the initialization or not.
random_state : int or `numpy.RandomState` or None, optional (default=None)
A pseudo random number generator object or a seed for it if int. If
``init='random'``, ``random_state`` is used to initialize the random
transformation. If ``init='pca'``, ``random_state`` is passed as an
argument to PCA when initializing the transformation.
has_classes : bool (default=True)
Whether the labels are in fact classes. If true, this allows using the
'lda' initialization.
Returns
-------
init_transformer : `numpy.ndarray`
The initial transformer to use.
"""
# if we are doing a regression we cannot use lda:
n_features = input.shape[-1]
authorized_inits = ['auto', 'pca', 'identity', 'random']
if has_classes:
authorized_inits.append('lda')

if isinstance(init, np.ndarray):
# we copy the array so that if we later update the metric, we do not
# also update the init
init = check_array(init, copy=True)

# Assert that init.shape[1] = X.shape[1]
if init.shape[1] != n_features:
raise ValueError('The input dimensionality ({}) of the given '
'linear transformation `init` must match the '
'dimensionality of the given inputs `X` ({}).'
.format(init.shape[1], n_features))

# Assert that init.shape[0] <= init.shape[1]
if init.shape[0] > init.shape[1]:
raise ValueError('The output dimensionality ({}) of the given '
'linear transformation `init` cannot be '
'greater than its input dimensionality ({}).'
.format(init.shape[0], init.shape[1]))

# Assert that self.n_components = init.shape[0]
if n_components != init.shape[0]:
raise ValueError('The preferred dimensionality of the '
'projected space `n_components` ({}) does'
' not match the output dimensionality of '
'the given linear transformation '
'`init` ({})!'
.format(n_components,
init.shape[0]))
elif init not in authorized_inits:
raise ValueError(
"`init` must be '{}' "
"or a numpy array of shape (n_components, n_features)."
.format("', '".join(authorized_inits)))

random_state = check_random_state(random_state)
if isinstance(init, np.ndarray):
return init
n_samples = input.shape[0]
if init == 'auto':
if has_classes:
n_classes = len(np.unique(y))
else:
n_classes = -1
init = _auto_select_init(has_classes, n_features, n_samples, n_components,
n_classes)
if init == 'identity':
return np.eye(n_components, input.shape[-1])
elif init == 'random':
return random_state.randn(n_components, input.shape[-1])
elif init in {'pca', 'lda'}:
init_time = time.time()
if init == 'pca':
pca = PCA(n_components=n_components,
random_state=random_state)
if verbose:
print('Finding principal components... ')
sys.stdout.flush()
pca.fit(input)
transformation = pca.components_
elif init == 'lda':
lda = LinearDiscriminantAnalysis(n_components=n_components)
if verbose:
print('Finding most discriminative components... ')
sys.stdout.flush()
lda.fit(input, y)
transformation = lda.scalings_.T[:n_components]
if verbose:
print('done in {:5.2f}s'.format(time.time() - init_time))
return transformation
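An illustrative sketch of calling `_initialize_transformer` (not part of the diff; the random data and shapes are assumptions):

import numpy as np

rng = np.random.RandomState(42)
X = rng.randn(100, 5)              # 100 samples, 5 features
y = rng.randint(0, 3, size=100)    # 3 classes
L = _initialize_transformer(n_components=2, input=X, y=y, init='auto',
                            random_state=rng, has_classes=True)
# n_components=2 <= n_classes - 1 = 2, so 'auto' selects 'lda' and the
# returned transformer L has shape (2, 5)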


def _auto_select_init(has_classes, n_features, n_samples, n_components,
n_classes):
if has_classes and n_components <= min(n_features, n_classes - 1):
init = 'lda'
elif n_components < min(n_features, n_samples):
init = 'pca'
else:
init = 'identity'
return init
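A quick sketch of the selection heuristic (illustration only; the values are assumptions):

_auto_select_init(True, n_features=10, n_samples=100,
                  n_components=2, n_classes=5)     # -> 'lda'
_auto_select_init(False, n_features=10, n_samples=100,
                  n_components=3, n_classes=-1)    # -> 'pca'
_auto_select_init(False, n_features=4, n_samples=100,
                  n_components=4, n_classes=-1)    # -> 'identity'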


def _initialize_metric_mahalanobis(input, init='identity', random_state=None,
return_inverse=False, strict_pd=False,
matrix_name='matrix'):
"""Returns a PSD matrix that can be used as a prior or an initialization
for the Mahalanobis distance.
Parameters
----------
input : array-like
The input samples (can be tuples or regular samples).
init : string or numpy array, optional (default='identity')
Specification for the matrix to initialize. Possible options are
'identity', 'covariance', 'random', and a numpy array of shape
(n_features, n_features).
'identity'
An identity matrix of shape (n_features, n_features).
'covariance'
The (pseudo-)inverse covariance matrix (raises an error if the
covariance matrix is not definite and `strict_pd == True`)
'random'
A random positive definite (PD) matrix of shape
`(n_features, n_features)`, generated using
`sklearn.datasets.make_spd_matrix`.
numpy array
A PSD matrix (or strictly PD if strict_pd==True) of
shape (n_features, n_features), that will be used as such to
initialize the metric, or set the prior.
random_state : int or `numpy.RandomState` or None, optional (default=None)
A pseudo random number generator object or a seed for it if int. If
``init='random'``, ``random_state`` is used to set the random Mahalanobis
matrix.
return_inverse : bool, optional (default=False)
Whether to return the inverse of the specified matrix. This can
sometimes be useful. The pseudo-inverse is returned, which equals the
inverse when the matrix is definite (i.e. invertible). If
`strict_pd == True` and the matrix is not definite, an error is raised
instead.
strict_pd : bool, optional (default=False)
Whether to enforce that the provided matrix is definite (in addition to
being PSD).
matrix_name : str, optional (default='matrix')
The name of the matrix used (example: 'init', 'prior'). Will be used in
error messages.
Returns
-------
M, or (M, M_inv) : `numpy.ndarray`
The initial matrix M to use, and its inverse if `return_inverse=True`.
"""
n_features = input.shape[-1]
if isinstance(init, np.ndarray):
# we copy the array so that if we later update the metric, we do not
# also update the init
init = check_array(init, copy=True)

# Assert that init is of shape (n_features, n_features)
if init.shape != (n_features,) * 2:
raise ValueError('The input dimensionality {} of the given '
'mahalanobis matrix `{}` must match the '
'dimensionality of the given inputs ({}).'
.format(init.shape, matrix_name, n_features))

# Assert that the matrix is symmetric
if not np.allclose(init, init.T):
raise ValueError("`{}` is not symmetric.".format(matrix_name))

elif init not in ['identity', 'covariance', 'random']:
raise ValueError(
"`{}` must be 'identity', 'covariance', 'random' "
"or a numpy array of shape (n_features, n_features)."
.format(matrix_name))

random_state = check_random_state(random_state)
M = init
if isinstance(init, np.ndarray):
s, u = scipy.linalg.eigh(init)
init_is_definite = _check_sdp_from_eigen(s)
if strict_pd and not init_is_definite:
raise LinAlgError("You should provide a strictly positive definite "
"matrix as `{}`. This one is not definite. Try another"
" {}, or an algorithm that does not "
"require the {} to be strictly positive definite."
.format(*((matrix_name,) * 3)))
if return_inverse:
M_inv = np.dot(u / s, u.T)
return M, M_inv
else:
return M
elif init == 'identity':
M = np.eye(n_features, n_features)
if return_inverse:
M_inv = M.copy()
return M, M_inv
else:
return M
elif init == 'covariance':
if input.ndim == 3:
# if the inputs are tuples, we need to form an X by deduplication
X = np.vstack({tuple(row) for row in input.reshape(-1, n_features)})
else:
X = input
# atleast2d is necessary to deal with scalar covariance matrices
M_inv = np.atleast_2d(np.cov(X, rowvar=False))
s, u = scipy.linalg.eigh(M_inv)
cov_is_definite = _check_sdp_from_eigen(s)
if strict_pd and not cov_is_definite:
raise LinAlgError("Unable to get a true inverse of the covariance "
"matrix since it is not definite. Try another "
"`{}`, or an algorithm that does not "
"require the `{}` to be strictly positive definite."
.format(*((matrix_name,) * 2)))
M = np.dot(u / s, u.T)
if return_inverse:
return M, M_inv
else:
return M
elif init == 'random':
# we need to create a random symmetric matrix
M = make_spd_matrix(n_features, random_state=random_state)
if return_inverse:
# we use pinvh even if we know the matrix is definite, just because
# we need the returned matrix to be symmetric (and sometimes
# np.linalg.inv returns non-symmetric inverses of symmetric matrices)
# TODO: there might be a more efficient method to do so
M_inv = pinvh(M)
return M, M_inv
else:
return M
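An illustrative sketch of building a covariance-based Mahalanobis prior together with its pseudo-inverse (not part of the diff; the random data is an assumption):

import numpy as np

rng = np.random.RandomState(0)
X = rng.randn(200, 3)
M, M_inv = _initialize_metric_mahalanobis(X, init='covariance',
                                          return_inverse=True,
                                          strict_pd=True,
                                          matrix_name='prior')
# M is the (pseudo-)inverse of the covariance of X, i.e. the Mahalanobis
# prior; M_inv is the covariance itself. With strict_pd=True a singular
# covariance raises a LinAlgError instead.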


def _check_n_components(n_features, n_components):
"""Checks that n_components is less than n_features and deal with the None
case"""
2 changes: 1 addition & 1 deletion metric_learn/covariance.py
@@ -22,7 +22,7 @@ class Covariance(MahalanobisMixin, TransformerMixin):
Attributes
----------
transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
transformer_ : `numpy.ndarray`, shape=(n_features, n_features)
The linear transformation ``L`` deduced from the learned Mahalanobis
metric (See function `transformer_from_metric`.)
"""