diff --git a/bench/benchmarks/iris.py b/bench/benchmarks/iris.py
index 305c3a0f..e3390930 100644
--- a/bench/benchmarks/iris.py
+++ b/bench/benchmarks/iris.py
@@ -10,7 +10,7 @@
     'LMNN': metric_learn.LMNN(k=5, learn_rate=1e-6, verbose=False),
     'LSML_Supervised': metric_learn.LSML_Supervised(num_constraints=200),
     'MLKR': metric_learn.MLKR(),
-    'NCA': metric_learn.NCA(max_iter=700, num_dims=2),
+    'NCA': metric_learn.NCA(max_iter=700, n_components=2),
     'RCA_Supervised': metric_learn.RCA_Supervised(dim=2, num_chunks=30,
                                                   chunk_size=2),
     'SDML_Supervised': metric_learn.SDML_Supervised(num_constraints=1500),
diff --git a/metric_learn/_util.py b/metric_learn/_util.py
index 583f1105..9cf6d7c6 100644
--- a/metric_learn/_util.py
+++ b/metric_learn/_util.py
@@ -1,10 +1,16 @@
-import warnings
 import numpy as np
+import scipy
 import six
 from numpy.linalg import LinAlgError
+from sklearn.datasets import make_spd_matrix
+from sklearn.decomposition import PCA
 from sklearn.utils import check_array
-from sklearn.utils.validation import check_X_y
-from metric_learn.exceptions import PreprocessorError
+from sklearn.utils.validation import check_X_y, check_random_state
+from .exceptions import PreprocessorError, NonPSDError
+from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
+from scipy.linalg import pinvh
+import sys
+import time

 # hack around lack of axis kwarg in older numpy versions
 try:
@@ -335,6 +341,8 @@ def check_collapsed_pairs(pairs):
 def _check_sdp_from_eigen(w, tol=None):
   """Checks if some of the eigenvalues given are negative, up to a tolerance
   level, with a default value of the tolerance depending on the eigenvalues.
+  It also returns whether the matrix is positive definite, up to the above
+  tolerance.

   Parameters
   ----------
@@ -342,9 +350,14 @@ def _check_sdp_from_eigen(w, tol=None):
       Eigenvalues to check for non semidefinite positiveness.

   tol : positive `float`, optional
-      Negative eigenvalues above - tol are considered zero. If
+      Absolute eigenvalues below tol are considered zero. If
       tol is None, and eps is the epsilon value for datatype of w, then tol
-      is set to w.max() * len(w) * eps.
+      is set to abs(w).max() * len(w) * eps.
+
+  Returns
+  -------
+  is_definite : bool
+      Whether the matrix is positive definite or not.

   See Also
   --------
@@ -352,11 +365,14 @@ def _check_sdp_from_eigen(w, tol=None):
       strategy is applied here)
   """
   if tol is None:
-    tol = w.max() * len(w) * np.finfo(w.dtype).eps
+    tol = np.abs(w).max() * len(w) * np.finfo(w.dtype).eps
   if tol < 0:
     raise ValueError("tol should be positive.")
   if any(w < - tol):
-    raise ValueError("Matrix is not positive semidefinite (PSD).")
+    raise NonPSDError()
+  if any(abs(w) < tol):
+    return False
+  return True


 def transformer_from_metric(metric, tol=None):
@@ -413,6 +429,311 @@ def validate_vector(u, dtype=None):
   return u


+def _initialize_transformer(n_components, input, y=None, init='auto',
+                            verbose=False, random_state=None,
+                            has_classes=True):
+  """Returns the initial transformer to be used depending on the arguments.
+
+  Parameters
+  ----------
+  n_components : int
+      The number of components to take. (Note: it should have been checked
+      before, meaning it should not be None and it should be a value in
+      [1, X.shape[1]])
+
+  input : array-like
+      The input samples (can be tuples or regular samples).
+
+  y : array-like or None
+      The input labels (or not if there are no labels).
+
+  init : string or numpy array, optional (default='auto')
+      Initialization of the linear transformation. Possible options are
+      'auto', 'pca', 'lda', 'identity', 'random', and a numpy array of shape
+      (n_features_a, n_features_b).
+
+      'auto'
+          Depending on ``n_components``, the most reasonable initialization
+          will be chosen. If ``n_components <= n_classes`` we use 'lda' (see
+          the description of 'lda' init), as it uses labels information. If
+          not, but ``n_components < min(n_features, n_samples)``, we use
+          'pca', as it projects data onto meaningful directions (those of
+          higher variance). Otherwise, we just use 'identity'.
+
+      'pca'
+          ``n_components`` principal components of the inputs passed
+          to :meth:`fit` will be used to initialize the transformation.
+          (See `sklearn.decomposition.PCA`)
+
+      'lda'
+          ``min(n_components, n_classes)`` most discriminative
+          components of the inputs passed to :meth:`fit` will be used to
+          initialize the transformation. (If ``n_components > n_classes``,
+          the rest of the components will be zero.) (See
+          `sklearn.discriminant_analysis.LinearDiscriminantAnalysis`).
+          This initialization is possible only if `has_classes == True`.
+
+      'identity'
+          The identity matrix. If ``n_components`` is strictly smaller than
+          the dimensionality of the inputs passed to :meth:`fit`, the identity
+          matrix will be truncated to the first ``n_components`` rows.
+
+      'random'
+          The initial transformation will be a random array of shape
+          `(n_components, n_features)`. Each value is sampled from the
+          standard normal distribution.
+
+      numpy array
+          n_features_b must match the dimensionality of the inputs passed to
+          :meth:`fit` and n_features_a must be less than or equal to that.
+          If ``n_components`` is not None, n_features_a must match it.
+
+  verbose : bool
+      Whether to print the details of the initialization or not.
+
+  random_state : int or `numpy.RandomState` or None, optional (default=None)
+      A pseudo random number generator object or a seed for it if int. If
+      ``init='random'``, ``random_state`` is used to initialize the random
+      transformation. If ``init='pca'``, ``random_state`` is passed as an
+      argument to PCA when initializing the transformation.
+
+  has_classes : bool (default=True)
+      Whether the labels are in fact classes. If True, this makes it possible
+      to use the 'lda' initialization.
+
+  Returns
+  -------
+  init_transformer : `numpy.ndarray`
+      The initial transformer to use.
+  """
+  # if we are doing a regression we cannot use lda:
+  n_features = input.shape[-1]
+  authorized_inits = ['auto', 'pca', 'identity', 'random']
+  if has_classes:
+    authorized_inits.append('lda')
+
+  if isinstance(init, np.ndarray):
+    # we copy the array so that if we update the metric, we don't
+    # also update the init
+    init = check_array(init, copy=True)
+
+    # Assert that init.shape[1] = X.shape[1]
+    if init.shape[1] != n_features:
+      raise ValueError('The input dimensionality ({}) of the given '
+                       'linear transformation `init` must match the '
+                       'dimensionality of the given inputs `X` ({}).'
+                       .format(init.shape[1], n_features))
+
+    # Assert that init.shape[0] <= init.shape[1]
+    if init.shape[0] > init.shape[1]:
+      raise ValueError('The output dimensionality ({}) of the given '
+                       'linear transformation `init` cannot be '
+                       'greater than its input dimensionality ({}).'
+                       .format(init.shape[0], init.shape[1]))
+
+    # Assert that self.n_components = init.shape[0]
+    if n_components != init.shape[0]:
+      raise ValueError('The preferred dimensionality of the '
+                       'projected space `n_components` ({}) does'
+                       ' not match the output dimensionality of '
+                       'the given linear transformation '
+                       '`init` ({})!'
+                       .format(n_components,
+                               init.shape[0]))
+  elif init not in authorized_inits:
+    raise ValueError(
+        "`init` must be '{}' "
+        "or a numpy array of shape (n_components, n_features)."
+        .format("', '".join(authorized_inits)))
+
+  random_state = check_random_state(random_state)
+  if isinstance(init, np.ndarray):
+    return init
+  n_samples = input.shape[0]
+  if init == 'auto':
+    if has_classes:
+      n_classes = len(np.unique(y))
+    else:
+      n_classes = -1
+    init = _auto_select_init(has_classes, n_features, n_samples,
+                             n_components, n_classes)
+  if init == 'identity':
+    return np.eye(n_components, input.shape[-1])
+  elif init == 'random':
+    return random_state.randn(n_components, input.shape[-1])
+  elif init in {'pca', 'lda'}:
+    init_time = time.time()
+    if init == 'pca':
+      pca = PCA(n_components=n_components,
+                random_state=random_state)
+      if verbose:
+        print('Finding principal components... ')
+        sys.stdout.flush()
+      pca.fit(input)
+      transformation = pca.components_
+    elif init == 'lda':
+      lda = LinearDiscriminantAnalysis(n_components=n_components)
+      if verbose:
+        print('Finding most discriminative components... ')
+        sys.stdout.flush()
+      lda.fit(input, y)
+      transformation = lda.scalings_.T[:n_components]
+    if verbose:
+      print('done in {:5.2f}s'.format(time.time() - init_time))
+    return transformation
+
+
+def _auto_select_init(has_classes, n_features, n_samples, n_components,
+                      n_classes):
+  if has_classes and n_components <= min(n_features, n_classes - 1):
+    init = 'lda'
+  elif n_components < min(n_features, n_samples):
+    init = 'pca'
+  else:
+    init = 'identity'
+  return init
+
+
+def _initialize_metric_mahalanobis(input, init='identity', random_state=None,
+                                   return_inverse=False, strict_pd=False,
+                                   matrix_name='matrix'):
+  """Returns a PSD matrix that can be used as a prior or an initialization
+  for the Mahalanobis distance.
+
+  Parameters
+  ----------
+  input : array-like
+      The input samples (can be tuples or regular samples).
+
+  init : string or numpy array, optional (default='identity')
+      Specification for the matrix to initialize. Possible options are
+      'identity', 'covariance', 'random', and a numpy array of shape
+      (n_features, n_features).
+
+      'identity'
+          An identity matrix of shape (n_features, n_features).
+
+      'covariance'
+          The (pseudo-)inverse covariance matrix (raises an error if the
+          covariance matrix is not definite and `strict_pd == True`)
+
+      'random'
+          A random positive definite (PD) matrix of shape
+          `(n_features, n_features)`, generated using
+          `sklearn.datasets.make_spd_matrix`.
+
+      numpy array
+          A PSD matrix (or strictly PD if strict_pd==True) of
+          shape (n_features, n_features), that will be used as such to
+          initialize the metric, or set the prior.
+
+  random_state : int or `numpy.RandomState` or None, optional (default=None)
+      A pseudo random number generator object or a seed for it if int. If
+      ``init='random'``, ``random_state`` is used to set the random
+      Mahalanobis matrix.
+
+  return_inverse : bool, optional (default=False)
+      Whether to return the inverse of the specified matrix. This
+      can be sometimes useful. It will return the pseudo-inverse (which is the
+      same as the inverse if the matrix is definite (i.e. invertible)). If
+      `strict_pd == True` and the matrix is not definite, it will return an
+      error.
+
+  strict_pd : bool, optional (default=False)
+      Whether to enforce that the provided matrix is definite (in addition to
+      being PSD).
+
+  matrix_name : str, optional (default='matrix')
+      The name of the matrix used (example: 'init', 'prior'). Will be used in
+      error messages.
+
+  Returns
+  -------
+  M, or (M, M_inv) : `numpy.ndarray`
+      The initial matrix to use M, and its inverse if `return_inverse=True`.
+  """
+  n_features = input.shape[-1]
+  if isinstance(init, np.ndarray):
+    # we copy the array so that if we update the metric, we don't
+    # also update the init
+    init = check_array(init, copy=True)
+
+    # Assert that init.shape[1] = n_features
+    if init.shape != (n_features,) * 2:
+      raise ValueError('The input dimensionality {} of the given '
+                       'mahalanobis matrix `{}` must match the '
+                       'dimensionality of the given inputs ({}).'
+                       .format(init.shape, matrix_name, n_features))
+
+    # Assert that the matrix is symmetric
+    if not np.allclose(init, init.T):
+      raise ValueError("`{}` is not symmetric.".format(matrix_name))
+
+  elif init not in ['identity', 'covariance', 'random']:
+    raise ValueError(
+        "`{}` must be 'identity', 'covariance', 'random' "
+        "or a numpy array of shape (n_features, n_features)."
+        .format(matrix_name))
+
+  random_state = check_random_state(random_state)
+  M = init
+  if isinstance(init, np.ndarray):
+    s, u = scipy.linalg.eigh(init)
+    init_is_definite = _check_sdp_from_eigen(s)
+    if strict_pd and not init_is_definite:
+      raise LinAlgError("You should provide a strictly positive definite "
+                        "matrix as `{}`. This one is not definite. Try "
+                        "another {}, or an algorithm that does not "
+                        "require the {} to be strictly positive definite."
+                        .format(*((matrix_name,) * 3)))
+    if return_inverse:
+      M_inv = np.dot(u / s, u.T)
+      return M, M_inv
+    else:
+      return M
+  elif init == 'identity':
+    M = np.eye(n_features, n_features)
+    if return_inverse:
+      M_inv = M.copy()
+      return M, M_inv
+    else:
+      return M
+  elif init == 'covariance':
+    if input.ndim == 3:
+      # if the input are tuples, we need to form an X by deduplication
+      X = np.vstack({tuple(row) for row in input.reshape(-1, n_features)})
+    else:
+      X = input
+    # atleast2d is necessary to deal with scalar covariance matrices
+    M_inv = np.atleast_2d(np.cov(X, rowvar=False))
+    s, u = scipy.linalg.eigh(M_inv)
+    cov_is_definite = _check_sdp_from_eigen(s)
+    if strict_pd and not cov_is_definite:
+      raise LinAlgError("Unable to get a true inverse of the covariance "
+                        "matrix since it is not definite. Try another "
+                        "`{}`, or an algorithm that does not "
+                        "require the `{}` to be strictly positive definite."
+                        .format(*((matrix_name,) * 2)))
+    M = np.dot(u / s, u.T)
+    if return_inverse:
+      return M, M_inv
+    else:
+      return M
+  elif init == 'random':
+    # we need to create a random symmetric matrix
+    M = make_spd_matrix(n_features, random_state=random_state)
+    if return_inverse:
+      # we use pinvh even if we know the matrix is definite, just because
+      # we need the returned matrix to be symmetric (and sometimes
+      # np.linalg.inv returns non-symmetric inverses of symmetric matrices)
+      # TODO: there might be a more efficient method to do so
+      M_inv = pinvh(M)
+      return M, M_inv
+    else:
+      return M
+
+
 def _check_n_components(n_features, n_components):
   """Checks that n_components is less than n_features and deal with the
   None case"""
diff --git a/metric_learn/covariance.py b/metric_learn/covariance.py
index 7f606921..19dad5d8 100644
--- a/metric_learn/covariance.py
+++ b/metric_learn/covariance.py
@@ -22,7 +22,7 @@ class Covariance(MahalanobisMixin, TransformerMixin):

   Attributes
   ----------
-  transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
+  transformer_ : `numpy.ndarray`, shape=(n_features, n_features)
       The linear transformation ``L`` deduced from the learned Mahalanobis
       metric (See function `transformer_from_metric`.)
   """
diff --git a/metric_learn/exceptions.py b/metric_learn/exceptions.py
index 424d2c4f..76f09778 100644
--- a/metric_learn/exceptions.py
+++ b/metric_learn/exceptions.py
@@ -2,6 +2,7 @@
 The :mod:`metric_learn.exceptions` module includes all custom warnings and
 error classes used across metric-learn.
 """
+from numpy.linalg import LinAlgError


 class PreprocessorError(Exception):
@@ -10,3 +11,10 @@ def __init__(self, original_error):
     err_msg = ("An error occurred when trying to use the "
                "preprocessor: {}").format(repr(original_error))
     super(PreprocessorError, self).__init__(err_msg)
+
+
+class NonPSDError(LinAlgError):
+
+  def __init__(self):
+    err_msg = "Matrix is not positive semidefinite (PSD)."
+    super(NonPSDError, self).__init__(err_msg)
diff --git a/metric_learn/itml.py b/metric_learn/itml.py
index 25518bf6..21303c18 100644
--- a/metric_learn/itml.py
+++ b/metric_learn/itml.py
@@ -23,7 +23,7 @@
 from sklearn.base import TransformerMixin
 from .base_metric import _PairsClassifierMixin, MahalanobisMixin
 from .constraints import Constraints, wrap_pairs
-from ._util import vector_norm, transformer_from_metric
+from ._util import transformer_from_metric, _initialize_metric_mahalanobis


 class _BaseITML(MahalanobisMixin):
@@ -32,7 +32,8 @@ class _BaseITML(MahalanobisMixin):
   _tuple_size = 2  # constraints are pairs

   def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
-               A0=None, verbose=False, preprocessor=None):
+               prior='identity', A0='deprecated', verbose=False,
+               preprocessor=None, random_state=None):
     """Initialize ITML.

     Parameters
@@ -44,8 +45,32 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,

     convergence_threshold : float, optional

-    A0 : (d x d) matrix, optional
-        initial regularization matrix, defaults to identity
+    prior : string or numpy array, optional (default='identity')
+        The Mahalanobis matrix to use as a prior. Possible options are
+        'identity', 'covariance', 'random', and a numpy array of shape
+        (n_features, n_features). For ITML, the prior should be strictly
+        positive definite (PD).
+
+        'identity'
+            An identity matrix of shape (n_features, n_features).
+
+        'covariance'
+            The inverse covariance matrix.
+
+        'random'
+            The prior will be a random SPD matrix of shape
+            `(n_features, n_features)`, generated using
+            `sklearn.datasets.make_spd_matrix`.
+
+        numpy array
+            A positive definite (PD) matrix of shape
+            (n_features, n_features), that will be used as such to set the
+            prior.
+
+    A0 : Not used
+        .. deprecated:: 0.5.0
+            `A0` was deprecated in version 0.5.0 and will
+            be removed in 0.6.0. Use 'prior' instead.

     verbose : bool, optional
         if True, prints information while learning

@@ -53,15 +78,26 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
     preprocessor : array-like, shape=(n_samples, n_features) or callable
         The preprocessor to call to get tuples from indices. If array-like,
         tuples will be formed like this: X[indices].
+
+    random_state : int or numpy.RandomState or None, optional (default=None)
+        A pseudo random number generator object or a seed for it if int. If
+        ``prior='random'``, ``random_state`` is used to set the prior.
     """
     self.gamma = gamma
     self.max_iter = max_iter
     self.convergence_threshold = convergence_threshold
+    self.prior = prior
     self.A0 = A0
     self.verbose = verbose
+    self.random_state = random_state
     super(_BaseITML, self).__init__(preprocessor)

   def _fit(self, pairs, y, bounds=None):
+    if self.A0 != 'deprecated':
+      warnings.warn('"A0" parameter is not used.'
+                    ' It has been deprecated in version 0.5.0 and will be'
+                    ' removed in 0.6.0. Use "prior" instead.',
+                    DeprecationWarning)
     pairs, y = self._prepare_inputs(pairs, y,
                                     type_of_inputs='tuples')
     # init bounds
@@ -76,11 +112,11 @@ def _fit(self, pairs, y, bounds=None):
       raise ValueError("`bounds` should be an array-like of two elements.")
     self.bounds_ = bounds
     self.bounds_[self.bounds_ == 0] = 1e-9
-    # init metric
-    if self.A0 is None:
-      A = np.identity(pairs.shape[2])
-    else:
-      A = check_array(self.A0, copy=True)
+    # set the prior
+    # pairs will be deduplicated into X two times, TODO: avoid that
+    A = _initialize_metric_mahalanobis(pairs, self.prior, self.random_state,
+                                       strict_pd=True,
+                                       matrix_name='prior')
     gamma = self.gamma
     pos_pairs, neg_pairs = pairs[y == 1], pairs[y == -1]
     num_pos = len(pos_pairs)
@@ -150,7 +186,7 @@ class ITML(_BaseITML, _PairsClassifierMixin):
   n_iter_ : `int`
       The number of iterations the solver has run.

-  transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
+  transformer_ : `numpy.ndarray`, shape=(n_features, n_features)
       The linear transformation ``L`` deduced from the learned Mahalanobis
       metric (See function `transformer_from_metric`.)

@@ -218,14 +254,15 @@ class ITML_Supervised(_BaseITML, TransformerMixin):
   n_iter_ : `int`
       The number of iterations the solver has run.

-  transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
+  transformer_ : `numpy.ndarray`, shape=(n_features, n_features)
       The linear transformation ``L`` deduced from the learned Mahalanobis
       metric (See function `transformer_from_metric`.)
   """

   def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
                num_labeled='deprecated', num_constraints=None,
-               bounds='deprecated', A0=None, verbose=False, preprocessor=None):
+               bounds='deprecated', prior='identity', A0='deprecated',
+               verbose=False, preprocessor=None, random_state=None):
     """Initialize the supervised version of `ITML`.

     `ITML_Supervised` creates pairs of similar sample by taking same class
@@ -249,17 +286,46 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
             `bounds` was deprecated in version 0.5.0 and will
             be removed in 0.6.0. Set `bounds` at fit time instead:
             `itml_supervised.fit(X, y, bounds=...)`

+    prior : string or numpy array, optional (default='identity')
+        Initialization of the Mahalanobis matrix. Possible options are
+        'identity', 'covariance', 'random', and a numpy array of shape
+        (n_features, n_features). For ITML, the prior should be strictly
+        positive definite (PD).
+
+        'identity'
+            An identity matrix of shape (n_features, n_features).
+
+        'covariance'
+            The inverse covariance matrix.
+
+        'random'
+            The prior will be a random SPD matrix of shape
+            `(n_features, n_features)`, generated using
+            `sklearn.datasets.make_spd_matrix`.
+
+        numpy array
+            A positive definite (PD) matrix of shape
+            (n_features, n_features), that will be used as such to set the
+            prior.
+
+    A0 : Not used
+        .. deprecated:: 0.5.0
+            `A0` was deprecated in version 0.5.0 and will
+            be removed in 0.6.0. Use 'prior' instead.

     verbose : bool, optional
         if True, prints information while learning

     preprocessor : array-like, shape=(n_samples, n_features) or callable
         The preprocessor to call to get tuples from indices. If array-like,
         tuples will be formed like this: X[indices].
+
+    random_state : int or numpy.RandomState or None, optional (default=None)
+        A pseudo random number generator object or a seed for it if int. If
+        ``prior='random'``, ``random_state`` is used to set the prior.
     """
     _BaseITML.__init__(self, gamma=gamma, max_iter=max_iter,
                        convergence_threshold=convergence_threshold,
-                       A0=A0, verbose=verbose, preprocessor=preprocessor)
+                       A0=A0, prior=prior, verbose=verbose,
+                       preprocessor=preprocessor, random_state=random_state)
     self.num_labeled = num_labeled
     self.num_constraints = num_constraints
     self.bounds = bounds
diff --git a/metric_learn/lmnn.py b/metric_learn/lmnn.py
index 1ba87684..c2437b86 100644
--- a/metric_learn/lmnn.py
+++ b/metric_learn/lmnn.py
@@ -20,20 +20,60 @@
 from sklearn.metrics import euclidean_distances
 from sklearn.base import TransformerMixin

-from ._util import _check_n_components
+from ._util import _initialize_transformer, _check_n_components
 from .base_metric import MahalanobisMixin


 # commonality between LMNN implementations
 class _base_LMNN(MahalanobisMixin, TransformerMixin):
-  def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7,
-               regularization=0.5, convergence_tol=0.001, use_pca=True,
-               verbose=False, preprocessor=None, n_components=None,
-               num_dims='deprecated'):
+  def __init__(self, init='auto', k=3, min_iter=50, max_iter=1000,
+               learn_rate=1e-7, regularization=0.5, convergence_tol=0.001,
+               use_pca=True, verbose=False, preprocessor=None,
+               n_components=None, num_dims='deprecated', random_state=None):
     """Initialize the LMNN object.

     Parameters
     ----------
+    init : string or numpy array, optional (default='auto')
+        Initialization of the linear transformation. Possible options are
+        'auto', 'pca', 'lda', 'identity', 'random', and a numpy array of shape
+        (n_features_a, n_features_b).
+
+        'auto'
+            Depending on ``n_components``, the most reasonable initialization
+            will be chosen. If ``n_components <= n_classes`` we use 'lda', as
+            it uses labels information. If not, but
+            ``n_components < min(n_features, n_samples)``, we use 'pca', as
+            it projects data in meaningful directions (those of higher
+            variance). Otherwise, we just use 'identity'.
+
+        'pca'
+            ``n_components`` principal components of the inputs passed
+            to :meth:`fit` will be used to initialize the transformation.
+            (See `sklearn.decomposition.PCA`)
+
+        'lda'
+            ``min(n_components, n_classes)`` most discriminative
+            components of the inputs passed to :meth:`fit` will be used to
+            initialize the transformation. (If ``n_components > n_classes``,
+            the rest of the components will be zero.) (See
+            `sklearn.discriminant_analysis.LinearDiscriminantAnalysis`)
+
+        'identity'
+            If ``n_components`` is strictly smaller than the
+            dimensionality of the inputs passed to :meth:`fit`, the identity
+            matrix will be truncated to the first ``n_components`` rows.
+
+        'random'
+            The initial transformation will be a random array of shape
+            `(n_components, n_features)`. Each value is sampled from the
+            standard normal distribution.
+
+        numpy array
+            n_features_b must match the dimensionality of the inputs passed to
+            :meth:`fit` and n_features_a must be less than or equal to that.
+            If ``n_components`` is not None, n_features_a must match it.
+
     k : int, optional
         Number of neighbors to consider, not including self-edges.

@@ -52,7 +92,14 @@ def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7,
         .. deprecated:: 0.5.0
             `num_dims` was deprecated in version 0.5.0 and will
             be removed in 0.6.0. Use `n_components` instead.
+
+    random_state : int or numpy.RandomState or None, optional (default=None)
+        A pseudo random number generator object or a seed for it if int. If
+        ``init='random'``, ``random_state`` is used to initialize the random
+        transformation. If ``init='pca'``, ``random_state`` is passed as an
+        argument to PCA when initializing the transformation.
     """
+    self.init = init
     self.k = k
     self.min_iter = min_iter
     self.max_iter = max_iter
@@ -63,6 +110,7 @@ def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7,
     self.verbose = verbose
     self.n_components = n_components
     self.num_dims = num_dims
+    self.random_state = random_state
     super(_base_LMNN, self).__init__(preprocessor)

@@ -87,9 +135,9 @@ def fit(self, X, y):
     if len(label_inds) != num_pts:
       raise ValueError('Must have one label per point.')
     self.labels_ = np.arange(len(unique_labels))
-    if self.use_pca:
-      warnings.warn('use_pca does nothing for the python_LMNN implementation')
-    self.transformer_ = np.eye(output_dim, d)
+    self.transformer_ = _initialize_transformer(output_dim, X, y, self.init,
+                                                self.verbose,
+                                                self.random_state)
     required_k = np.bincount(label_inds).min()
     if self.k > required_k:
       raise ValueError('not enough class labels for specified k'
@@ -122,6 +170,8 @@ def fit(self, X, y):
                       self._loss_grad(X, L, dfG, impostors, 1, k, reg,
                                       target_neighbors, df, a1, a2))

+    it = 1  # we already made one iteration
+
     # main loop
     for it in xrange(2, self.max_iter):
       # then at each iteration, we try to find a value of L that has better
diff --git a/metric_learn/lsml.py b/metric_learn/lsml.py
index 94366b88..4350b003 100644
--- a/metric_learn/lsml.py
+++ b/metric_learn/lsml.py
@@ -20,36 +20,64 @@
 import scipy.linalg
 from six.moves import xrange
 from sklearn.base import TransformerMixin
+from sklearn.exceptions import ChangedBehaviorWarning

 from .base_metric import _QuadrupletsClassifierMixin, MahalanobisMixin
 from .constraints import Constraints
-from ._util import transformer_from_metric
+from ._util import transformer_from_metric, _initialize_metric_mahalanobis


 class _BaseLSML(MahalanobisMixin):

   _tuple_size = 4  # constraints are quadruplets

-  def __init__(self, tol=1e-3, max_iter=1000, prior=None, verbose=False,
-               preprocessor=None):
+  def __init__(self, tol=1e-3, max_iter=1000, prior=None,
+               verbose=False, preprocessor=None, random_state=None):
"""Initialize LSML. Parameters ---------- + prior : None, string or numpy array, optional (default=None) + Prior to set for the metric. Possible options are + 'identity', 'covariance', 'random', and a numpy array of + shape (n_features, n_features). For LSML, the prior should be strictly + positive definite (PD). If `None`, will be set + automatically to 'identity' (this is to raise a warning if + `prior` is not set, and stays to its default value (None), in v0.5.0). + + 'identity' + An identity matrix of shape (n_features, n_features). + + 'covariance' + The inverse covariance matrix. + + 'random' + The initial Mahalanobis matrix will be a random positive definite + (PD) matrix of shape `(n_features, n_features)`, generated using + `sklearn.datasets.make_spd_matrix`. + + numpy array + A positive definite (PD) matrix of shape + (n_features, n_features), that will be used as such to set the + prior. + tol : float, optional max_iter : int, optional - prior : (d x d) matrix, optional - guess at a metric [default: inv(covariance(X))] verbose : bool, optional if True, prints information while learning preprocessor : array-like, shape=(n_samples, n_features) or callable The preprocessor to call to get tuples from indices. If array-like, tuples will be formed like this: X[indices]. + random_state : int or numpy.RandomState or None, optional (default=None) + A pseudo random number generator object or a seed for it if int. If + ``init='random'``, ``random_state`` is used to set the random + prior. """ self.prior = prior self.tol = tol self.max_iter = max_iter self.verbose = verbose + self.random_state = random_state super(_BaseLSML, self).__init__(preprocessor) def _fit(self, quadruplets, weights=None): @@ -66,14 +94,23 @@ def _fit(self, quadruplets, weights=None): else: self.w_ = weights self.w_ /= self.w_.sum() # weights must sum to 1 + # if the prior is the default (identity), we raise a warning just in case if self.prior is None: - X = np.vstack({tuple(row) for row in - quadruplets.reshape(-1, quadruplets.shape[2])}) - prior_inv = np.atleast_2d(np.cov(X, rowvar=False)) - M = np.linalg.inv(prior_inv) + msg = ("Warning, no prior was set (`prior=None`). As of version 0.5.0, " + "the default prior will now be set to " + "'identity', instead of 'covariance'. If you still want to use " + "the inverse of the covariance matrix as a prior, " + "set prior='covariance'. This warning will disappear in " + "v0.6.0, and `prior` parameter's default value will be set to " + "'identity'.") + warnings.warn(msg, ChangedBehaviorWarning) + prior = 'identity' else: - M = self.prior - prior_inv = np.linalg.inv(self.prior) + prior = self.prior + M, prior_inv = _initialize_metric_mahalanobis(quadruplets, prior, + return_inverse=True, + strict_pd=True, + matrix_name='prior') step_sizes = np.logspace(-10, 0, 10) # Keep track of the best step size and the loss at that step. @@ -146,7 +183,7 @@ class LSML(_BaseLSML, _QuadrupletsClassifierMixin): n_iter_ : `int` The number of iterations the solver has run. - transformer_ : `numpy.ndarray`, shape=(n_components, n_features) + transformer_ : `numpy.ndarray`, shape=(n_features, n_features) The linear transformation ``L`` deduced from the learned Mahalanobis metric (See function `transformer_from_metric`.) """ @@ -182,15 +219,14 @@ class LSML_Supervised(_BaseLSML, TransformerMixin): n_iter_ : `int` The number of iterations the solver has run. 
-  transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
+  transformer_ : `numpy.ndarray`, shape=(n_features, n_features)
       The linear transformation ``L`` deduced from the learned Mahalanobis
       metric (See function `transformer_from_metric`.)
   """

   def __init__(self, tol=1e-3, max_iter=1000, prior=None,
                num_labeled='deprecated', num_constraints=None, weights=None,
-               verbose=False,
-               preprocessor=None):
+               verbose=False, preprocessor=None, random_state=None):
     """Initialize the supervised version of `LSML`.

     `LSML_Supervised` creates quadruplets from labeled samples by taking two
@@ -202,8 +238,29 @@ def __init__(self, tol=1e-3, max_iter=1000, prior=None,
     ----------
     tol : float, optional
     max_iter : int, optional
-    prior : (d x d) matrix, optional
-        guess at a metric [default: covariance(X)]
+    prior : None, string or numpy array, optional (default=None)
+        Prior to set for the metric. Possible options are
+        'identity', 'covariance', 'random', and a numpy array of
+        shape (n_features, n_features). For LSML, the prior should be strictly
+        positive definite (PD). If `None`, will be set
+        automatically to 'identity' (this is to raise a warning if
+        `prior` is not set, and stays to its default value (None), in v0.5.0).
+
+        'identity'
+            An identity matrix of shape (n_features, n_features).
+
+        'covariance'
+            The inverse covariance matrix.
+
+        'random'
+            The initial Mahalanobis matrix will be a random positive definite
+            (PD) matrix of shape `(n_features, n_features)`, generated using
+            `sklearn.datasets.make_spd_matrix`.
+
+        numpy array
+            A positive definite (PD) matrix of shape
+            (n_features, n_features), that will be used as such to set the
+            prior.
     num_labeled : Not used
         .. deprecated:: 0.5.0
             `num_labeled` was deprecated in version 0.5.0 and will
@@ -217,9 +274,14 @@ def __init__(self, tol=1e-3, max_iter=1000, prior=None,
     preprocessor : array-like, shape=(n_samples, n_features) or callable
         The preprocessor to call to get tuples from indices. If array-like,
         tuples will be formed like this: X[indices].
+    random_state : int or numpy.RandomState or None, optional (default=None)
+        A pseudo random number generator object or a seed for it if int. If
+        ``prior='random'``, ``random_state`` is used to set the random
+        prior.
     """
     _BaseLSML.__init__(self, tol=tol, max_iter=max_iter, prior=prior,
-                       verbose=verbose, preprocessor=preprocessor)
+                       verbose=verbose, preprocessor=preprocessor,
+                       random_state=random_state)
     self.num_labeled = num_labeled
     self.num_constraints = num_constraints
     self.weights = weights
diff --git a/metric_learn/mlkr.py b/metric_learn/mlkr.py
index 762317b9..9e9cf433 100644
--- a/metric_learn/mlkr.py
+++ b/metric_learn/mlkr.py
@@ -14,18 +14,16 @@
 import sys
 import warnings
 import numpy as np
-from sklearn.exceptions import ConvergenceWarning
+from sklearn.exceptions import ConvergenceWarning, ChangedBehaviorWarning
 from sklearn.utils.fixes import logsumexp
 from scipy.optimize import minimize
-from scipy.spatial.distance import pdist, squareform
 from sklearn.base import TransformerMixin
-from sklearn.decomposition import PCA
-
 from sklearn.metrics import pairwise_distances

 from metric_learn._util import _check_n_components
 from .base_metric import MahalanobisMixin
+from ._util import _initialize_transformer

 EPS = np.finfo(float).eps

@@ -42,8 +40,9 @@ class MLKR(MahalanobisMixin, TransformerMixin):
       The learned linear transformation ``L``.
""" - def __init__(self, n_components=None, num_dims='deprecated', A0=None, - tol=None, max_iter=1000, verbose=False, preprocessor=None): + def __init__(self, n_components=None, num_dims='deprecated', init=None, + A0='deprecated', tol=None, max_iter=1000, verbose=False, + preprocessor=None, random_state=None): """ Initialize MLKR. @@ -58,14 +57,49 @@ def __init__(self, n_components=None, num_dims='deprecated', A0=None, `num_dims` was deprecated in version 0.5.0 and will be removed in 0.6.0. Use `n_components` instead. - A0: array-like, optional - Initialization of transformation matrix. Defaults to PCA loadings. + init : None, string or numpy array, optional (default=None) + Initialization of the linear transformation. Possible options are + 'auto', 'pca', 'identity', 'random', and a numpy array of shape + (n_features_a, n_features_b). If None, will be set automatically to + 'auto' (this option is to raise a warning if 'init' is not set, + and stays to its default value None, in v0.5.0). + + 'auto' + Depending on ``n_components``, the most reasonable initialization + will be chosen. If ``n_components < min(n_features, n_samples)``, + we use 'pca', as it projects data in meaningful directions (those + of higher variance). Otherwise, we just use 'identity'. + + 'pca' + ``n_components`` principal components of the inputs passed + to :meth:`fit` will be used to initialize the transformation. + (See `sklearn.decomposition.PCA`) + + 'identity' + If ``n_components`` is strictly smaller than the + dimensionality of the inputs passed to :meth:`fit`, the identity + matrix will be truncated to the first ``n_components`` rows. + + 'random' + The initial transformation will be a random array of shape + `(n_components, n_features)`. Each value is sampled from the + standard normal distribution. + + numpy array + n_features_b must match the dimensionality of the inputs passed to + :meth:`fit` and n_features_a must be less than or equal to that. + If ``num_dims`` is not None, n_features_a must match it. + + A0: Not used. + .. deprecated:: 0.5.0 + `A0` was deprecated in version 0.5.0 and will + be removed in 0.6.0. Use 'init' instead. tol: float, optional (default=None) Convergence tolerance for the optimization. max_iter: int, optional - Cap on number of congugate gradient iterations. + Cap on number of conjugate gradient iterations. verbose : bool, optional (default=False) Whether to print progress messages or not. @@ -73,13 +107,21 @@ def __init__(self, n_components=None, num_dims='deprecated', A0=None, preprocessor : array-like, shape=(n_samples, n_features) or callable The preprocessor to call to get tuples from indices. If array-like, tuples will be formed like this: X[indices]. + + random_state : int or numpy.RandomState or None, optional (default=None) + A pseudo random number generator object or a seed for it if int. If + ``init='random'``, ``random_state`` is used to initialize the random + transformation. If ``init='pca'``, ``random_state`` is passed as an + argument to PCA when initializing the transformation. """ self.n_components = n_components self.num_dims = num_dims + self.init = init self.A0 = A0 self.tol = tol self.max_iter = max_iter self.verbose = verbose + self.random_state = random_state super(MLKR, self).__init__(preprocessor) def fit(self, X, y): @@ -91,11 +133,18 @@ def fit(self, X, y): X : (n x d) array of samples y : (n) data labels """ + if self.A0 != 'deprecated': + warnings.warn('"A0" parameter is not used.' 
+                    ' It has been deprecated in version 0.5.0 and will be'
+                    ' removed in 0.6.0. Use "init" instead.',
+                    DeprecationWarning)
+
     if self.num_dims != 'deprecated':
       warnings.warn('"num_dims" parameter is not used.'
                     ' It has been deprecated in version 0.5.0 and will be'
                     ' removed in 0.6.0. Use "n_components" instead',
                     DeprecationWarning)
+
     X, y = self._prepare_inputs(X, y, y_numeric=True,
                                 ensure_min_samples=2)
     n, d = X.shape
@@ -103,18 +152,27 @@ def fit(self, X, y):
       raise ValueError('Data and label lengths mismatch: %d != %d'
                        % (n, y.shape[0]))

-    A = self.A0
     m = _check_n_components(d, self.n_components)
     m = self.n_components
     if m is None:
       m = d
-    if A is None:
-      # initialize to PCA transformation matrix
-      # note: not the same as n_components=m !
-      A = PCA().fit(X).components_.T[:m]
-    elif A.shape != (m, d):
-      raise ValueError('A0 needs shape (%d,%d) but got %s' % (
-          m, d, A.shape))
+    # if the init is the default (identity), we raise a warning just in case
+    if self.init is None:
+      # TODO:
+      #  replace init=None by init='auto' in v0.6.0 and remove the warning
+      msg = ("Warning, no init was set (`init=None`). As of version 0.5.0, "
+             "the default init will now be set to 'auto', instead of 'pca'. "
+             "If you still want to use PCA as an init, set init='pca'. "
+             "This warning will disappear in v0.6.0, and `init` parameter's"
+             " default value will be set to 'auto'.")
+      warnings.warn(msg, ChangedBehaviorWarning)
+      init = 'auto'
+    else:
+      init = self.init
+    A = _initialize_transformer(m, X, y, init=init,
+                                random_state=self.random_state,
+                                # MLKR works on regression targets:
+                                has_classes=False)

     # Measure the total training time
     train_time = time.time()
diff --git a/metric_learn/mmc.py b/metric_learn/mmc.py
index 0e6cd5cb..b3e6c203 100644
--- a/metric_learn/mmc.py
+++ b/metric_learn/mmc.py
@@ -21,11 +21,12 @@
 import numpy as np
 from six.moves import xrange
 from sklearn.base import TransformerMixin
-from sklearn.utils.validation import check_array, assert_all_finite
+from sklearn.utils.validation import assert_all_finite
+from sklearn.exceptions import ChangedBehaviorWarning

 from .base_metric import _PairsClassifierMixin, MahalanobisMixin
 from .constraints import Constraints, wrap_pairs
-from ._util import vector_norm, transformer_from_metric
+from ._util import transformer_from_metric, _initialize_metric_mahalanobis


 class _BaseMMC(MahalanobisMixin):
@@ -34,20 +35,51 @@ class _BaseMMC(MahalanobisMixin):

   _tuple_size = 2  # constraints are pairs

   def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-3,
-               A0=None, diagonal=False, diagonal_c=1.0, verbose=False,
-               preprocessor=None):
+               init=None, A0='deprecated', diagonal=False,
+               diagonal_c=1.0, verbose=False, preprocessor=None,
+               random_state=None):
     """Initialize MMC.

     Parameters
     ----------
     max_iter : int, optional
     max_proj : int, optional
     convergence_threshold : float, optional
-    A0 : (d x d) matrix, optional
-        initial metric, defaults to identity
-        only the main diagonal is taken if `diagonal == True`
+    init : None, string or numpy array, optional (default=None)
+        Initialization of the Mahalanobis matrix. Possible options are
+        'identity', 'covariance', 'random', and a numpy array of
+        shape (n_features, n_features). If None, will be set
+        automatically to 'identity' (this is to raise a warning if
+        'init' is not set, and stays to its default value (None), in v0.5.0).
+
+        'identity'
+            An identity matrix of shape (n_features, n_features).
+
+        'covariance'
+            The (pseudo-)inverse of the covariance matrix.
+
+        'random'
+            The initial Mahalanobis matrix will be a random SPD matrix of
+            shape `(n_features, n_features)`, generated using
+            `sklearn.datasets.make_spd_matrix`.
+
+        numpy array
+            An SPD matrix of shape (n_features, n_features), that will
+            be used as such to initialize the metric.
+
+    A0 : Not used.
+        .. deprecated:: 0.5.0
+            `A0` was deprecated in version 0.5.0 and will
+            be removed in 0.6.0. Use 'init' instead.
     diagonal : bool, optional
         if True, a diagonal metric will be learned,
-        i.e., a simple scaling of dimensions
+        i.e., a simple scaling of dimensions. The initialization will then
+        be the diagonal coefficients of the matrix given as 'init'.
     diagonal_c : float, optional
         weight of the dissimilarity constraint for diagonal
         metric learning
@@ -56,29 +88,49 @@ def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-3,
     preprocessor : array-like, shape=(n_samples, n_features) or callable
         The preprocessor to call to get tuples from indices. If array-like,
         tuples will be gotten like this: X[indices].
+    random_state : int or numpy.RandomState or None, optional (default=None)
+        A pseudo random number generator object or a seed for it if int. If
+        ``init='random'``, ``random_state`` is used to initialize the random
+        transformation.
     """
     self.max_iter = max_iter
     self.max_proj = max_proj
     self.convergence_threshold = convergence_threshold
+    self.init = init
     self.A0 = A0
     self.diagonal = diagonal
     self.diagonal_c = diagonal_c
     self.verbose = verbose
+    self.random_state = random_state
     super(_BaseMMC, self).__init__(preprocessor)

   def _fit(self, pairs, y):
+    if self.A0 != 'deprecated':
+      warnings.warn('"A0" parameter is not used.'
+                    ' It has been deprecated in version 0.5.0 and will be'
+                    ' removed in 0.6.0. Use "init" instead.',
+                    DeprecationWarning)
     pairs, y = self._prepare_inputs(pairs, y,
                                     type_of_inputs='tuples')

-    # init metric
-    if self.A0 is None:
-      self.A_ = np.identity(pairs.shape[2])
-      if not self.diagonal:
-        # Don't know why division by 10... it's in the original code
-        # and seems to affect the overall scale of the learned metric.
-        self.A_ /= 10.0
+    if self.init is None:
+      # TODO: replace init=None by init='auto' in v0.6.0 and remove the
+      # warning
+      msg = ("Warning, no init was set (`init=None`). As of version 0.5.0, "
+             "the default init will now be set to 'identity', instead of the "
+             "identity divided by a scaling factor of 10. "
+             "If you still want to use the same init as in previous "
+             "versions, set init=np.eye(d)/10, where d is the dimension "
+             "of your input space (d=pairs.shape[2]). "
+             "This warning will disappear in v0.6.0, and `init` parameter's"
+             " default value will be set to 'auto'.")
+      warnings.warn(msg, ChangedBehaviorWarning)
+      init = 'identity'
     else:
-      self.A_ = check_array(self.A0)
+      init = self.init
+
+    self.A_ = _initialize_metric_mahalanobis(pairs, init,
+                                             random_state=self.random_state,
+                                             matrix_name='init')

     if self.diagonal:
       return self._fit_diag(pairs, y)
@@ -356,7 +408,7 @@ class MMC(_BaseMMC, _PairsClassifierMixin):
   n_iter_ : `int`
       The number of iterations the solver has run.
-  transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
+  transformer_ : `numpy.ndarray`, shape=(n_features, n_features)
      The linear transformation ``L`` deduced from the learned Mahalanobis
      metric (See function `transformer_from_metric`.)

@@ -406,15 +458,15 @@ class MMC_Supervised(_BaseMMC, TransformerMixin):
   n_iter_ : `int`
       The number of iterations the solver has run.

-  transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
+  transformer_ : `numpy.ndarray`, shape=(n_features, n_features)
       The linear transformation ``L`` deduced from the learned Mahalanobis
       metric (See function `transformer_from_metric`.)
   """

   def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-6,
-               num_labeled='deprecated', num_constraints=None, A0=None,
-               diagonal=False, diagonal_c=1.0, verbose=False,
-               preprocessor=None):
+               num_labeled='deprecated', num_constraints=None, init=None,
+               A0='deprecated', diagonal=False, diagonal_c=1.0, verbose=False,
+               preprocessor=None, random_state=None):
     """Initialize the supervised version of `MMC`.

     `MMC_Supervised` creates pairs of similar sample by taking same class
@@ -432,9 +484,38 @@ def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-6,
         be removed in 0.6.0.
     num_constraints: int, optional
         number of constraints to generate
-    A0 : (d x d) matrix, optional
-        initial metric, defaults to identity
-        only the main diagonal is taken if `diagonal == True`
+    init : None, string or numpy array, optional (default=None)
+        Initialization of the Mahalanobis matrix. Possible options are
+        'identity', 'covariance', 'random', and a numpy array of
+        shape (n_features, n_features). If None, will be set
+        automatically to 'identity' (this is to raise a warning if
+        'init' is not set, and stays to its default value (None), in v0.5.0).
+
+        'identity'
+            An identity matrix of shape (n_features, n_features).
+
+        'covariance'
+            The (pseudo-)inverse of the covariance matrix.
+
+        'random'
+            The initial Mahalanobis matrix will be a random SPD matrix of
+            shape `(n_features, n_features)`, generated using
+            `sklearn.datasets.make_spd_matrix`.
+
+        numpy array
+            A numpy array of shape (n_features, n_features), that will
+            be used as such to initialize the metric.
+
+    A0 : Not used.
+        .. deprecated:: 0.5.0
+            `A0` was deprecated in version 0.5.0 and will
+            be removed in 0.6.0. Use 'init' instead.
     diagonal : bool, optional
         if True, a diagonal metric will be learned,
         i.e., a simple scaling of dimensions
@@ -446,11 +527,16 @@ def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-6,
     preprocessor : array-like, shape=(n_samples, n_features) or callable
         The preprocessor to call to get tuples from indices. If array-like,
         tuples will be formed like this: X[indices].
+    random_state : int or numpy.RandomState or None, optional (default=None)
+        A pseudo random number generator object or a seed for it if int. If
+        ``init='random'``, ``random_state`` is used to initialize the random
+        Mahalanobis matrix.
""" _BaseMMC.__init__(self, max_iter=max_iter, max_proj=max_proj, convergence_threshold=convergence_threshold, - A0=A0, diagonal=diagonal, diagonal_c=diagonal_c, - verbose=verbose, preprocessor=preprocessor) + init=init, A0=A0, diagonal=diagonal, + diagonal_c=diagonal_c, verbose=verbose, + preprocessor=preprocessor, random_state=random_state) self.num_labeled = num_labeled self.num_constraints = num_constraints diff --git a/metric_learn/nca.py b/metric_learn/nca.py index 3545aa89..1626e02f 100644 --- a/metric_learn/nca.py +++ b/metric_learn/nca.py @@ -19,11 +19,11 @@ import numpy as np from scipy.optimize import minimize from sklearn.metrics import pairwise_distances -from sklearn.exceptions import ConvergenceWarning +from sklearn.exceptions import ConvergenceWarning, ChangedBehaviorWarning from sklearn.utils.fixes import logsumexp from sklearn.base import TransformerMixin -from ._util import _check_n_components +from ._util import _initialize_transformer, _check_n_components from .base_metric import MahalanobisMixin EPS = np.finfo(float).eps @@ -41,12 +41,55 @@ class NCA(MahalanobisMixin, TransformerMixin): The learned linear transformation ``L``. """ - def __init__(self, n_components=None, num_dims='deprecated', max_iter=100, - tol=None, verbose=False, preprocessor=None): + def __init__(self, init=None, n_components=None, num_dims='deprecated', + max_iter=100, tol=None, verbose=False, preprocessor=None, + random_state=None): """Neighborhood Components Analysis Parameters ---------- + init : None, string or numpy array, optional (default=None) + Initialization of the linear transformation. Possible options are + 'auto', 'pca', 'identity', 'random', and a numpy array of shape + (n_features_a, n_features_b). If None, will be set automatically to + 'auto' (this option is to raise a warning if 'init' is not set, + and stays to its default value None, in v0.5.0). + + 'auto' + Depending on ``n_components``, the most reasonable initialization + will be chosen. If ``n_components <= n_classes`` we use 'lda', as + it uses labels information. If not, but + ``n_components < min(n_features, n_samples)``, we use 'pca', as + it projects data in meaningful directions (those of higher + variance). Otherwise, we just use 'identity'. + + 'pca' + ``n_components`` principal components of the inputs passed + to :meth:`fit` will be used to initialize the transformation. + (See `sklearn.decomposition.PCA`) + + 'lda' + ``min(n_components, n_classes)`` most discriminative + components of the inputs passed to :meth:`fit` will be used to + initialize the transformation. (If ``n_components > n_classes``, + the rest of the components will be zero.) (See + `sklearn.discriminant_analysis.LinearDiscriminantAnalysis`) + + 'identity' + If ``n_components`` is strictly smaller than the + dimensionality of the inputs passed to :meth:`fit`, the identity + matrix will be truncated to the first ``n_components`` rows. + + 'random' + The initial transformation will be a random array of shape + `(n_components, n_features)`. Each value is sampled from the + standard normal distribution. + + numpy array + n_features_b must match the dimensionality of the inputs passed to + :meth:`fit` and n_features_a must be less than or equal to that. + If ``n_components`` is not None, n_features_a must match it. + n_components : int or None, optional (default=None) Dimensionality of reduced space (if None, defaults to dimension of X). 
@@ -64,12 +107,20 @@ def __init__(self, n_components=None, num_dims='deprecated', max_iter=100,

     verbose : bool, optional (default=False)
         Whether to print progress messages or not.
+
+    random_state : int or numpy.RandomState or None, optional (default=None)
+        A pseudo random number generator object or a seed for it if int. If
+        ``init='random'``, ``random_state`` is used to initialize the random
+        transformation. If ``init='pca'``, ``random_state`` is passed as an
+        argument to PCA when initializing the transformation.
     """
     self.n_components = n_components
+    self.init = init
     self.num_dims = num_dims
     self.max_iter = max_iter
     self.tol = tol
     self.verbose = verbose
+    self.random_state = random_state
     super(NCA, self).__init__(preprocessor)

   def fit(self, X, y):
@@ -89,9 +140,22 @@ def fit(self, X, y):
     # Measure the total training time
     train_time = time.time()

-    # Initialize A to a scaling matrix
-    A = np.zeros((n_components, d))
-    np.fill_diagonal(A,
-                     1. / (np.maximum(X.max(axis=0) - X.min(axis=0), EPS)))
+    # Initialize A
+    # if the init is the default (auto), we raise a warning just in case
+    if self.init is None:
+      # TODO: replace init=None by init='auto' in v0.6.0 and remove the
+      # warning
+      msg = ("Warning, no init was set (`init=None`). As of version 0.5.0, "
+             "the default init will now be set to 'auto', instead of the "
+             "previous scaling matrix. If you still want to use the same "
+             "scaling matrix as before as an init, set "
+             "init=np.eye(X.shape[1])/"
+             "(np.maximum(X.max(axis=0)-X.min(axis=0), EPS)). This warning "
+             "will disappear in v0.6.0, and `init` parameter's default value "
+             "will be set to 'auto'.")
+      warnings.warn(msg, ChangedBehaviorWarning)
+      init = 'auto'
+    else:
+      init = self.init
+    A = _initialize_transformer(n_components, X, labels, init, self.verbose,
+                                self.random_state)

     # Run NCA
     mask = labels[:, np.newaxis] == labels[np.newaxis, :]
diff --git a/metric_learn/sdml.py b/metric_learn/sdml.py
index 73eeefb7..b83c553d 100644
--- a/metric_learn/sdml.py
+++ b/metric_learn/sdml.py
@@ -18,11 +18,11 @@
 from sklearn.base import TransformerMixin
 from scipy.linalg import pinvh
 from sklearn.covariance import graphical_lasso
-from sklearn.exceptions import ConvergenceWarning
+from sklearn.exceptions import ConvergenceWarning, ChangedBehaviorWarning

 from .base_metric import MahalanobisMixin, _PairsClassifierMixin
 from .constraints import Constraints, wrap_pairs
-from ._util import transformer_from_metric
+from ._util import transformer_from_metric, _initialize_metric_mahalanobis
 try:
   from inverse_covariance import quic
 except ImportError:
@@ -35,8 +35,9 @@ class _BaseSDML(MahalanobisMixin):

   _tuple_size = 2  # constraints are pairs

-  def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True,
-               verbose=False, preprocessor=None):
+  def __init__(self, balance_param=0.5, sparsity_param=0.01, prior=None,
+               use_cov='deprecated', verbose=False, preprocessor=None,
+               random_state=None):
     """
     Parameters
     ----------
@@ -46,8 +47,34 @@ def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True,

     sparsity_param : float, optional
         trade off between optimizer and sparseness (see graph_lasso)

-    use_cov : bool, optional
-        controls prior matrix, will use the identity if use_cov=False
+    prior : None, string or numpy array, optional (default=None)
+        Prior to set for the metric. Possible options are
+        'identity', 'covariance', 'random', and a numpy array of
+        shape (n_features, n_features). For SDML, the prior should be strictly
+        positive definite (PD).
+        If `None`, will be set
+        automatically to 'identity' (this is to raise a warning if
+        `prior` is not set, and stays to its default value (None), in v0.5.0).
+
+        'identity'
+            An identity matrix of shape (n_features, n_features).
+
+        'covariance'
+            The inverse covariance matrix.
+
+        'random'
+            The prior will be a random positive definite (PD) matrix of shape
+            `(n_features, n_features)`, generated using
+            `sklearn.datasets.make_spd_matrix`.
+
+        numpy array
+            A positive definite (PD) matrix of shape
+            (n_features, n_features), that will be used as such to set the
+            prior.
+
+    use_cov : Not used.
+        .. deprecated:: 0.5.0
+            `use_cov` was deprecated in version 0.5.0 and will
+            be removed in 0.6.0. Use 'prior' instead.

     verbose : bool, optional
         if True, prints information while learning

@@ -55,14 +82,25 @@ def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True,
     preprocessor : array-like, shape=(n_samples, n_features) or callable
         The preprocessor to call to get tuples from indices. If array-like,
         tuples will be gotten like this: X[indices].
+
+    random_state : int or numpy.RandomState or None, optional (default=None)
+        A pseudo random number generator object or a seed for it if int. If
+        ``prior='random'``, ``random_state`` is used to set the prior.
     """
     self.balance_param = balance_param
     self.sparsity_param = sparsity_param
+    self.prior = prior
     self.use_cov = use_cov
     self.verbose = verbose
+    self.random_state = random_state
     super(_BaseSDML, self).__init__(preprocessor)

   def _fit(self, pairs, y):
+    if self.use_cov != 'deprecated':
+      warnings.warn('"use_cov" parameter is not used.'
+                    ' It has been deprecated in version 0.5.0 and will be'
+                    ' removed in 0.6.0. Use "prior" instead.',
+                    DeprecationWarning)
     if not HAS_SKGGM:
       if self.verbose:
         print("SDML will use scikit-learn's graphical lasso solver.")
@@ -73,11 +111,26 @@ def _fit(self, pairs, y):
                                     type_of_inputs='tuples')

     # set up (the inverse of) the prior M
-    if self.use_cov:
-      X = np.vstack({tuple(row) for row in pairs.reshape(-1, pairs.shape[2])})
-      prior_inv = np.atleast_2d(np.cov(X, rowvar=False))
+    # if the prior is the default (identity), we raise a warning just in case
+    if self.prior is None:
+      # TODO:
+      #  replace prior=None by prior='identity' in v0.6.0 and remove the
+      #  warning
+      msg = ("Warning, no prior was set (`prior=None`). As of version 0.5.0, "
+             "the default prior will now be set to "
+             "'identity', instead of 'covariance'. If you still want to use "
+             "the inverse of the covariance matrix as a prior, "
+             "set prior='covariance'. This warning will disappear in "
+             "v0.6.0, and `prior` parameter's default value will be set to "
+             "'identity'.")
+      warnings.warn(msg, ChangedBehaviorWarning)
+      prior = 'identity'
     else:
-      prior_inv = np.identity(pairs.shape[2])
+      prior = self.prior
+    _, prior_inv = _initialize_metric_mahalanobis(
+        pairs, prior, return_inverse=True, strict_pd=True,
+        matrix_name='prior', random_state=self.random_state)
     diff = pairs[:, 0] - pairs[:, 1]
     loss_matrix = (diff.T * y).dot(diff)
     emp_cov = prior_inv + self.balance_param * loss_matrix
@@ -92,7 +145,7 @@ def _fit(self, pairs, y):
                     "positive semi-definite (PSD). The algorithm may diverge, "
                     "and lead to degenerate solutions. "
" "To prevent that, try to decrease the balance parameter " - "`balance_param` and/or to set use_cov=False.", + "`balance_param` and/or to set prior='identity'.", ConvergenceWarning) w -= min_eigval # we translate the eigenvalues to make them all positive w += 1e-10 # we add a small offset to avoid definiteness problems @@ -139,7 +192,7 @@ class SDML(_BaseSDML, _PairsClassifierMixin): Attributes ---------- - transformer_ : `numpy.ndarray`, shape=(n_components, n_features) + transformer_ : `numpy.ndarray`, shape=(n_features, n_features) The linear transformation ``L`` deduced from the learned Mahalanobis metric (See function `transformer_from_metric`.) @@ -187,27 +240,56 @@ class SDML_Supervised(_BaseSDML, TransformerMixin): Attributes ---------- - transformer_ : `numpy.ndarray`, shape=(n_components, n_features) + transformer_ : `numpy.ndarray`, shape=(n_features, n_features) The linear transformation ``L`` deduced from the learned Mahalanobis metric (See function `transformer_from_metric`.) """ - def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True, - num_labeled='deprecated', num_constraints=None, verbose=False, - preprocessor=None): + def __init__(self, balance_param=0.5, sparsity_param=0.01, prior=None, + use_cov='deprecated', num_labeled='deprecated', + num_constraints=None, verbose=False, preprocessor=None, + random_state=None): """Initialize the supervised version of `SDML`. `SDML_Supervised` creates pairs of similar sample by taking same class samples, and pairs of dissimilar samples by taking different class samples. It then passes these pairs to `SDML` for training. + Parameters ---------- balance_param : float, optional trade off between sparsity and M0 prior sparsity_param : float, optional trade off between optimizer and sparseness (see graph_lasso) - use_cov : bool, optional - controls prior matrix, will use the identity if use_cov=False + prior : None, string or numpy array, optional (default=None) + Prior to set for the metric. Possible options are + 'identity', 'covariance', 'random', and a numpy array of + shape (n_features, n_features). For SDML, the prior should be strictly + positive definite (PD). If `None`, will be set + automatically to 'identity' (this is to raise a warning if + `prior` is not set, and stays to its default value (None), in v0.5.0). + + 'identity' + An identity matrix of shape (n_features, n_features). + + 'covariance' + The inverse covariance matrix. + + 'random' + The prior will be a random SPD matrix of shape + `(n_features, n_features)`, generated using + `sklearn.datasets.make_spd_matrix`. + + numpy array + A positive definite (PD) matrix of shape + (n_features, n_features), that will be used as such to set the + prior. + + use_cov : Not used. + .. deprecated:: 0.5.0 + `A0` was deprecated in version 0.5.0 and will + be removed in 0.6.0. Use 'prior' instead. + num_labeled : Not used .. deprecated:: 0.5.0 `num_labeled` was deprecated in version 0.5.0 and will @@ -219,10 +301,15 @@ def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True, preprocessor : array-like, shape=(n_samples, n_features) or callable The preprocessor to call to get tuples from indices. If array-like, tuples will be formed like this: X[indices]. + random_state : int or numpy.RandomState or None, optional (default=None) + A pseudo random number generator object or a seed for it if int. If + ``init='random'``, ``random_state`` is used to set the random + prior. 
""" _BaseSDML.__init__(self, balance_param=balance_param, - sparsity_param=sparsity_param, use_cov=use_cov, - verbose=verbose, preprocessor=preprocessor) + sparsity_param=sparsity_param, prior=prior, + use_cov=use_cov, verbose=verbose, + preprocessor=preprocessor, random_state=random_state) self.num_labeled = num_labeled self.num_constraints = num_constraints diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py index 969bd7e5..18643363 100644 --- a/test/metric_learn_test.py +++ b/test/metric_learn_test.py @@ -5,11 +5,12 @@ from scipy.optimize import check_grad, approx_fprime from six.moves import xrange from sklearn.metrics import pairwise_distances -from sklearn.datasets import load_iris, make_classification, make_regression +from sklearn.datasets import (load_iris, make_classification, make_regression, + make_spd_matrix) from numpy.testing import (assert_array_almost_equal, assert_array_equal, assert_allclose) from sklearn.utils.testing import assert_warns_message -from sklearn.exceptions import ConvergenceWarning +from sklearn.exceptions import ConvergenceWarning, ChangedBehaviorWarning from sklearn.utils.validation import check_X_y try: from inverse_covariance import quic @@ -19,7 +20,7 @@ HAS_SKGGM = True from metric_learn import (LMNN, NCA, LFDA, Covariance, MLKR, MMC, RCA, LSML_Supervised, ITML_Supervised, SDML_Supervised, - RCA_Supervised, MMC_Supervised, SDML, ITML) + RCA_Supervised, MMC_Supervised, SDML, ITML, LSML) # Import this specially for testing. from metric_learn.constraints import wrap_pairs from metric_learn.lmnn import python_LMNN, _sum_outer_products @@ -92,6 +93,31 @@ def test_deprecation_num_labeled(self): ' removed in 0.6.0') assert_warns_message(DeprecationWarning, msg, lsml_supervised.fit, X, y) + def test_changed_behaviour_warning(self): + # test that a ChangedBehavior warning is thrown about the init, if the + # default parameters are used. + # TODO: remove in v.0.6 + X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]]) + y = np.array([1, 0, 1, 0]) + lsml_supervised = LSML_Supervised() + msg = ("Warning, no prior was set (`prior=None`). As of version 0.5.0, " + "the default prior will now be set to " + "'identity', instead of 'covariance'. If you still want to use " + "the inverse of the covariance matrix as a prior, " + "set prior='covariance'. This warning will disappear in " + "v0.6.0, and `prior` parameter's default value will be set to " + "'identity'.") + with pytest.warns(ChangedBehaviorWarning) as raised_warning: + lsml_supervised.fit(X, y) + assert any(msg == str(wrn.message) for wrn in raised_warning) + + pairs = np.array([[[-10., 0.], [10., 0.], [-5., 3.], [5., 0.]], + [[0., 50.], [0., -60], [-10., 0.], [10., 0.]]]) + lsml = LSML() + with pytest.warns(ChangedBehaviorWarning) as raised_warning: + lsml.fit(pairs) + assert any(msg == str(wrn.message) for wrn in raised_warning) + class TestITML(MetricTestCase): def test_iris(self): @@ -126,6 +152,27 @@ def test_deprecation_bounds(self): 'fit method instead.') assert_warns_message(DeprecationWarning, msg, itml_supervised.fit, X, y) + def test_deprecation_A0(self): + # test that a deprecation message is thrown if A0 is set at + # initialization + # TODO: remove in v.0.6 + X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]]) + y = np.array([1, 0, 1, 0]) + itml_supervised = ITML_Supervised(A0=np.ones_like(X)) + msg = ('"A0" parameter is not used.' + ' It has been deprecated in version 0.5.0 and will be' + 'removed in 0.6.0. 
 class TestITML(MetricTestCase):
   def test_iris(self):
@@ -126,6 +152,27 @@ def test_deprecation_bounds(self):
            'fit method instead.')
     assert_warns_message(DeprecationWarning, msg, itml_supervised.fit, X, y)

+  def test_deprecation_A0(self):
+    # test that a deprecation message is thrown if A0 is set at
+    # initialization
+    # TODO: remove in v.0.6
+    X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
+    y = np.array([1, 0, 1, 0])
+    itml_supervised = ITML_Supervised(A0=np.ones_like(X))
+    msg = ('"A0" parameter is not used.'
+           ' It has been deprecated in version 0.5.0 and will be'
+           'removed in 0.6.0. Use "prior" instead.')
+    with pytest.warns(DeprecationWarning) as raised_warning:
+      itml_supervised.fit(X, y)
+    assert any(msg == str(wrn.message) for wrn in raised_warning)
+
+    pairs = np.array([[[-10., 0.], [10., 0.]], [[0., 50.], [0., -60]]])
+    y_pairs = [1, -1]
+    itml = ITML(A0=np.ones_like(X))
+    with pytest.warns(DeprecationWarning) as raised_warning:
+      itml.fit(pairs, y_pairs)
+    assert any(msg == str(wrn.message) for wrn in raised_warning)
+
 @pytest.mark.parametrize('bounds', [None, (20., 100.), [20., 100.],
                                     np.array([20., 100.]),
@@ -190,10 +237,10 @@ def test_loss_grad_lbfgs(self):
     X, y = lmnn._prepare_inputs(X, y, dtype=float,
                                 ensure_min_samples=2)
-    num_pts, num_dims = X.shape
+    num_pts, n_features = X.shape
     unique_labels, label_inds = np.unique(y, return_inverse=True)
     lmnn.labels_ = np.arange(len(unique_labels))
-    lmnn.transformer_ = np.eye(num_dims)
+    lmnn.transformer_ = np.eye(n_features)
     target_neighbors = lmnn._select_targets(X, label_inds)
     impostors = lmnn._find_impostors(target_neighbors[:, -1], X, label_inds)
@@ -243,10 +290,10 @@ def test_toy_ex_lmnn(X, y, loss):
     X, y = lmnn._prepare_inputs(X, y, dtype=float,
                                 ensure_min_samples=2)
-    num_pts, num_dims = X.shape
+    num_pts, n_features = X.shape
     unique_labels, label_inds = np.unique(y, return_inverse=True)
     lmnn.labels_ = np.arange(len(unique_labels))
-    lmnn.transformer_ = np.eye(num_dims)
+    lmnn.transformer_ = np.eye(n_features)
     target_neighbors = lmnn._select_targets(X, label_inds)
     impostors = lmnn._find_impostors(target_neighbors[:, -1], X, label_inds)
@@ -336,7 +383,7 @@ def test_sdml_raises_warning_msg_not_installed_skggm(self):
     # because it will return a non SPD matrix
     pairs = np.array([[[-10., 0.], [10., 0.]], [[0., 50.], [0., -60]]])
     y_pairs = [1, -1]
-    sdml = SDML(use_cov=False, balance_param=100, verbose=True)
+    sdml = SDML(prior='identity', balance_param=100, verbose=True)

     msg = ("There was a problem in SDML when using scikit-learn's graphical "
            "lasso solver. 
skggm's graphical lasso can sometimes converge on " @@ -352,14 +399,14 @@ def test_sdml_raises_warning_msg_not_installed_skggm(self): "installed.") def test_sdml_raises_warning_msg_installed_skggm(self): """Tests that the right warning message is raised if someone tries to - use SDML but has not installed skggm, and that the algorithm fails to + use SDML and has installed skggm, and that the algorithm fails to converge""" # TODO: remove if we don't need skggm anymore # case on which we know that skggm's graphical lasso fails # because it will return non finite values pairs = np.array([[[-10., 0.], [10., 0.]], [[0., 50.], [0., -60]]]) y_pairs = [1, -1] - sdml = SDML(use_cov=False, balance_param=100, verbose=True) + sdml = SDML(prior='identity', balance_param=100, verbose=True) msg = ("There was a problem in SDML when using skggm's graphical " "lasso solver.") @@ -382,7 +429,7 @@ def test_sdml_supervised_raises_warning_msg_installed_skggm(self): # pathological case) X = np.array([[-10., 0.], [10., 0.], [5., 0.], [3., 0.]]) y = [0, 0, 1, 1] - sdml_supervised = SDML_Supervised(balance_param=0.5, use_cov=False, + sdml_supervised = SDML_Supervised(balance_param=0.5, prior='identity', sparsity_param=0.01) msg = ("There was a problem in SDML when using skggm's graphical " "lasso solver.") @@ -395,25 +442,27 @@ def test_sdml_supervised_raises_warning_msg_installed_skggm(self): "that no warning should be thrown.") def test_raises_no_warning_installed_skggm(self): # otherwise we should be able to instantiate and fit SDML and it - # should raise no warning + # should raise no error and no ConvergenceWarning pairs = np.array([[[-10., 0.], [10., 0.]], [[0., -55.], [0., -60]]]) y_pairs = [1, -1] X, y = make_classification(random_state=42) - with pytest.warns(None) as record: - sdml = SDML() + with pytest.warns(None) as records: + sdml = SDML(prior='covariance') sdml.fit(pairs, y_pairs) - assert len(record) == 0 - with pytest.warns(None) as record: - sdml = SDML_Supervised(use_cov=False, balance_param=1e-5) - sdml.fit(X, y) - assert len(record) == 0 + for record in records: + assert record.category is not ConvergenceWarning + with pytest.warns(None) as records: + sdml_supervised = SDML_Supervised(prior='identity', balance_param=1e-5) + sdml_supervised.fit(X, y) + for record in records: + assert record.category is not ConvergenceWarning def test_iris(self): # Note: this is a flaky test, which fails for certain seeds. # TODO: un-flake it! rs = np.random.RandomState(5555) - sdml = SDML_Supervised(num_constraints=1500, use_cov=False, + sdml = SDML_Supervised(num_constraints=1500, prior='identity', balance_param=5e-5) sdml.fit(self.iris_points, self.iris_labels, random_state=rs) csep = class_separation(sdml.transform(self.iris_points), @@ -425,7 +474,7 @@ def test_deprecation_num_labeled(self): # initialization # TODO: remove in v.0.6 X, y = make_classification(random_state=42) - sdml_supervised = SDML_Supervised(num_labeled=np.inf, use_cov=False, + sdml_supervised = SDML_Supervised(num_labeled=np.inf, prior='identity', balance_param=5e-5) msg = ('"num_labeled" parameter is not used.' 
' It has been deprecated in version 0.5.0 and will be'
@@ -437,12 +486,12 @@ def test_sdml_raises_warning_non_psd(self):
     pseudo-covariance matrix is not PSD"""
     pairs = np.array([[[-10., 0.], [10., 0.]], [[0., 50.], [0., -60]]])
     y = [1, -1]
-    sdml = SDML(use_cov=True, sparsity_param=0.01, balance_param=0.5)
+    sdml = SDML(prior='covariance', sparsity_param=0.01, balance_param=0.5)
     msg = ("Warning, the input matrix of graphical lasso is not "
            "positive semi-definite (PSD). The algorithm may diverge, "
            "and lead to degenerate solutions. "
            "To prevent that, try to decrease the balance parameter "
-           "`balance_param` and/or to set use_cov=False.")
+           "`balance_param` and/or to set prior='identity'.")
     with pytest.warns(ConvergenceWarning) as raised_warning:
       try:
         sdml.fit(pairs, y)
@@ -457,7 +506,7 @@ def test_sdml_converges_if_psd(self):
     pseudo-covariance matrix is PSD"""
     pairs = np.array([[[-10., 0.], [10., 0.]], [[0., -55.], [0., -60]]])
     y = [1, -1]
-    sdml = SDML(use_cov=True, sparsity_param=0.01, balance_param=0.5)
+    sdml = SDML(prior='covariance', sparsity_param=0.01, balance_param=0.5)
     sdml.fit(pairs, y)
     assert np.isfinite(sdml.get_mahalanobis_matrix()).all()
@@ -470,8 +519,56 @@ def test_sdml_works_on_non_spd_pb_with_skggm(self):
     it should work, but scikit-learn's graphical_lasso does not work"""
     X, y = load_iris(return_X_y=True)
     sdml = SDML_Supervised(balance_param=0.5, sparsity_param=0.01,
-                           use_cov=True)
-    sdml.fit(X, y)
+                           prior='covariance')
+    sdml.fit(X, y, random_state=np.random.RandomState(42))
+
+  def test_deprecation_use_cov(self):
+    # test that a deprecation message is thrown if use_cov is set at
+    # initialization
+    # TODO: remove in v.0.6
+    X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
+    y = np.array([1, 0, 1, 0])
+    sdml_supervised = SDML_Supervised(use_cov=np.ones_like(X),
+                                      balance_param=1e-5)
+    msg = ('"use_cov" parameter is not used.'
+           ' It has been deprecated in version 0.5.0 and will be'
+           ' removed in 0.6.0. Use "prior" instead.')
+    with pytest.warns(DeprecationWarning) as raised_warning:
+      sdml_supervised.fit(X, y)
+    assert any(msg == str(wrn.message) for wrn in raised_warning)
+
+    pairs = np.array([[[-10., 0.], [10., 0.]], [[0., 50.], [0., -60]]])
+    y_pairs = [1, -1]
+    sdml = SDML(use_cov=np.ones_like(X), balance_param=1e-5)
+    with pytest.warns(DeprecationWarning) as raised_warning:
+      sdml.fit(pairs, y_pairs)
+    assert any(msg == str(wrn.message) for wrn in raised_warning)
+
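
In short, the migration path these tests pin down (a sketch assuming this branch):

from metric_learn import SDML

# Before (0.4.x):       SDML(use_cov=False, balance_param=1e-5)
# After (this branch):  use_cov is deprecated and ignored; pass a prior.
sdml_identity = SDML(prior='identity', balance_param=1e-5)
# use_cov=True corresponds to prior='covariance':
sdml_covariance = SDML(prior='covariance', balance_param=1e-5)
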
+  def test_changed_behaviour_warning(self):
+    # test that a ChangedBehavior warning is thrown about the prior, if the
+    # default parameters are used (except for the balance_param that we need
+    # to set for the algorithm to not diverge)
+    # TODO: remove in v.0.6
+    X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
+    y = np.array([1, 0, 1, 0])
+    sdml_supervised = SDML_Supervised(balance_param=1e-5)
+    msg = ("Warning, no prior was set (`prior=None`). As of version 0.5.0, "
+           "the default prior will now be set to "
+           "'identity', instead of 'covariance'. If you still want to use "
+           "the inverse of the covariance matrix as a prior, "
+           "set prior='covariance'. This warning will disappear in "
+           "v0.6.0, and `prior` parameter's default value will be set to "
+           "'identity'.")
+    with pytest.warns(ChangedBehaviorWarning) as raised_warning:
+      sdml_supervised.fit(X, y)
+    assert any(msg == str(wrn.message) for wrn in raised_warning)
+
+    pairs = np.array([[[-10., 0.], [10., 0.]], [[0., 50.], [0., -60]]])
+    y_pairs = [1, -1]
+    sdml = SDML(balance_param=1e-5)
+    with pytest.warns(ChangedBehaviorWarning) as raised_warning:
+      sdml.fit(pairs, y_pairs)
+    assert any(msg == str(wrn.message) for wrn in raised_warning)


 @pytest.mark.skipif(not HAS_SKGGM,
@@ -483,7 +580,7 @@ def test_verbose_has_installed_skggm_sdml(capsys):
   # TODO: remove if we don't need skggm anymore
   pairs = np.array([[[-10., 0.], [10., 0.]], [[0., -55.], [0., -60]]])
   y_pairs = [1, -1]
-  sdml = SDML(verbose=True)
+  sdml = SDML(verbose=True, prior='covariance')
   sdml.fit(pairs, y_pairs)
   out, _ = capsys.readouterr()
   assert "SDML will use skggm's graphical lasso solver." in out
@@ -496,8 +593,8 @@ def test_verbose_has_installed_skggm_sdml_supervised(capsys):
   # Test that if users have installed skggm, a message is printed telling them
   # skggm's solver is used (when they use SDML_Supervised)
   # TODO: remove if we don't need skggm anymore
-  X, y = make_classification(random_state=42)
-  sdml = SDML_Supervised(verbose=True)
+  X, y = load_iris(return_X_y=True)
+  sdml = SDML_Supervised(verbose=True, prior='identity', balance_param=1e-5)
   sdml.fit(X, y)
   out, _ = capsys.readouterr()
   assert "SDML will use skggm's graphical lasso solver." in out
@@ -512,7 +609,7 @@ def test_verbose_has_not_installed_skggm_sdml(capsys):
   # TODO: remove if we don't need skggm anymore
   pairs = np.array([[[-10., 0.], [10., 0.]], [[0., -55.], [0., -60]]])
   y_pairs = [1, -1]
-  sdml = SDML(verbose=True)
+  sdml = SDML(verbose=True, prior='covariance')
   sdml.fit(pairs, y_pairs)
   out, _ = capsys.readouterr()
   assert "SDML will use scikit-learn's graphical lasso solver." in out
@@ -526,7 +623,7 @@ def test_verbose_has_not_installed_skggm_sdml_supervised(capsys):
   # skggm's solver is used (when they use SDML_Supervised)
   # TODO: remove if we don't need skggm anymore
   X, y = make_classification(random_state=42)
-  sdml = SDML_Supervised(verbose=True, balance_param=1e-5, use_cov=False)
+  sdml = SDML_Supervised(verbose=True, balance_param=1e-5, prior='identity')
   sdml.fit(X, y)
   out, _ = capsys.readouterr()
   assert "SDML will use scikit-learn's graphical lasso solver." in out
@@ -622,11 +719,8 @@ def test_singleton_class(self):
       X = X[[ind_0[0], ind_1[0], ind_2[0]]]
       y = y[[ind_0[0], ind_1[0], ind_2[0]]]

-    EPS = np.finfo(float).eps
-    A = np.zeros((X.shape[1], X.shape[1]))
-    np.fill_diagonal(A,
-                     1. / (np.maximum(X.max(axis=0) - X.min(axis=0), EPS)))
-    nca = NCA(max_iter=30, n_components=X.shape[1])
+    A = make_spd_matrix(X.shape[1], random_state=42)
+    nca = NCA(init=A, max_iter=30, n_components=X.shape[1])
     nca.fit(X, y)
     assert_array_equal(nca.transformer_, A)
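
The singleton-class test above pins NCA to an explicit SPD init; the same pattern works for any user-supplied matrix (a sketch; the data and seed are arbitrary):

import numpy as np
from sklearn.datasets import make_spd_matrix
from metric_learn import NCA

rng = np.random.RandomState(0)
X = rng.randn(30, 4)
y = rng.randint(0, 3, 30)

A = make_spd_matrix(X.shape[1], random_state=42)  # random 4 x 4 SPD matrix
nca = NCA(init=A, max_iter=30, n_components=X.shape[1])
nca.fit(X, y)
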
@@ -635,18 +729,34 @@ def test_one_class(self):
     # must stay like the initialization
     X = self.iris_points[self.iris_labels == 0]
     y = self.iris_labels[self.iris_labels == 0]
-    EPS = np.finfo(float).eps
-    A = np.zeros((X.shape[1], X.shape[1]))
-    np.fill_diagonal(A,
-                     1. / (np.maximum(X.max(axis=0) - X.min(axis=0), EPS)))
-    nca = NCA(max_iter=30, n_components=X.shape[1])
+
+    A = make_spd_matrix(X.shape[1], random_state=42)
+    nca = NCA(init=A, max_iter=30, n_components=X.shape[1])
     nca.fit(X, y)
     assert_array_equal(nca.transformer_, A)

+  def test_changed_behaviour_warning(self):
+    # test that a ChangedBehavior warning is thrown about the init, if the
+    # default parameters are used.
+    # TODO: remove in v.0.6
+    X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
+    y = np.array([1, 0, 1, 0])
+    nca = NCA()
+    msg = ("Warning, no init was set (`init=None`). As of version 0.5.0, "
+           "the default init will now be set to 'auto', instead of the "
+           "previous scaling matrix. If you still want to use the same "
+           "scaling matrix as before as an init, set init=np.eye(X.shape[1])/"
+           "(np.maximum(X.max(axis=0)-X.min(axis=0), EPS)). This warning "
+           "will disappear in v0.6.0, and `init` parameter's default value "
+           "will be set to 'auto'.")
+    with pytest.warns(ChangedBehaviorWarning) as raised_warning:
+      nca.fit(X, y)
+    assert any(msg == str(wrn.message) for wrn in raised_warning)
+

 @pytest.mark.parametrize('num_dims', [None, 2])
 def test_deprecation_num_dims_nca(num_dims):
-  # test that a deprecation message is thrown if num_labeled is set at
+  # test that a deprecation message is thrown if num_dims is set at
   # initialization
   # TODO: remove in v.0.6
   X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
@@ -674,7 +784,7 @@ def test_iris(self):

 @pytest.mark.parametrize('num_dims', [None, 2])
 def test_deprecation_num_dims_lfda(num_dims):
-  # test that a deprecation message is thrown if num_labeled is set at
+  # test that a deprecation message is thrown if num_dims is set at
   # initialization
   # TODO: remove in v.0.6
   X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
@@ -715,7 +825,7 @@ def test_feature_null_variance(self):

 @pytest.mark.parametrize('num_dims', [None, 2])
 def test_deprecation_num_dims_rca(num_dims):
-  # test that a deprecation message is thrown if num_labeled is set at
+  # test that a deprecation message is thrown if num_dims is set at
   # initialization
   # TODO: remove in v.0.6
   X, y = load_iris(return_X_y=True)
@@ -767,10 +877,40 @@ def grad_fn(M):
     rel_diff = check_grad(fun, grad_fn, M.ravel()) / np.linalg.norm(grad_fn(M))
     np.testing.assert_almost_equal(rel_diff, 0.)

+  def test_deprecation_A0(self):
+    # test that a deprecation message is thrown if A0 is set at
+    # initialization
+    # TODO: remove in v.0.6
+    X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
+    y = np.array([1, 0, 1, 0])
+    mlkr = MLKR(A0=np.ones_like(X))
+    msg = ('"A0" parameter is not used.'
+           ' It has been deprecated in version 0.5.0 and will be'
+           'removed in 0.6.0. Use "init" instead.')
+    with pytest.warns(DeprecationWarning) as raised_warning:
+      mlkr.fit(X, y)
+    assert any(msg == str(wrn.message) for wrn in raised_warning)
+
+  def test_changed_behaviour_warning(self):
+    # test that a ChangedBehavior warning is thrown about the init, if the
+    # default parameters are used.
+    # TODO: remove in v.0.6
+    X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
+    y = np.array([0.1, 0.2, 0.3, 0.4])
+    mlkr = MLKR()
+    msg = ("Warning, no init was set (`init=None`). As of version 0.5.0, "
+           "the default init will now be set to 'auto', instead of 'pca'. "
+           "If you still want to use PCA as an init, set init='pca'. 
" + "This warning will disappear in v0.6.0, and `init` parameter's" + " default value will be set to 'auto'.") + with pytest.warns(ChangedBehaviorWarning) as raised_warning: + mlkr.fit(X, y) + assert any(msg == str(wrn.message) for wrn in raised_warning) + @pytest.mark.parametrize('num_dims', [None, 2]) def test_deprecation_num_dims_mlkr(num_dims): - # test that a deprecation message is thrown if num_labeled is set at + # test that a deprecation message is thrown if num_dims is set at # initialization # TODO: remove in v.0.6 X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]]) @@ -794,8 +934,9 @@ def test_iris(self): c, d = np.nonzero(np.triu(~mask, k=1)) # Full metric - mmc = MMC(convergence_threshold=0.01) - mmc.fit(*wrap_pairs(self.iris_points, [a,b,c,d])) + n_features = self.iris_points.shape[1] + mmc = MMC(convergence_threshold=0.01, init=np.eye(n_features) / 10) + mmc.fit(*wrap_pairs(self.iris_points, [a, b, c, d])) expected = [[+0.000514, +0.000868, -0.001195, -0.001703], [+0.000868, +0.001468, -0.002021, -0.002879], [-0.001195, -0.002021, +0.002782, +0.003964], @@ -834,6 +975,53 @@ def test_deprecation_num_labeled(self): ' removed in 0.6.0') assert_warns_message(DeprecationWarning, msg, mmc_supervised.fit, X, y) + def test_deprecation_A0(self): + # test that a deprecation message is thrown if A0 is set at + # initialization + # TODO: remove in v.0.6 + X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]]) + y = np.array([1, 0, 1, 0]) + mmc_supervised = MMC_Supervised(A0=np.ones_like(X)) + msg = ('"A0" parameter is not used.' + ' It has been deprecated in version 0.5.0 and will be' + 'removed in 0.6.0. Use "init" instead.') + with pytest.warns(DeprecationWarning) as raised_warning: + mmc_supervised.fit(X, y) + assert any(msg == str(wrn.message) for wrn in raised_warning) + + pairs = np.array([[[-10., 0.], [10., 0.]], [[0., 50.], [0., -60]]]) + y_pairs = [1, -1] + mmc = MMC(A0=np.ones_like(X)) + with pytest.warns(DeprecationWarning) as raised_warning: + mmc.fit(pairs, y_pairs) + assert any(msg == str(wrn.message) for wrn in raised_warning) + + def test_changed_behaviour_warning(self): + # test that a ChangedBehavior warning is thrown about the init, if the + # default parameters are used. + # TODO: remove in v.0.6 + X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]]) + y = np.array([1, 0, 1, 0]) + mmc_supervised = MMC_Supervised() + msg = ("Warning, no init was set (`init=None`). As of version 0.5.0, " + "the default init will now be set to 'identity', instead of the " + "identity divided by a scaling factor of 10. " + "If you still want to use the same init as in previous " + "versions, set init=np.eye(d)/10, where d is the dimension " + "of your input space (d=pairs.shape[1]). 
" + "This warning will disappear in v0.6.0, and `init` parameter's" + " default value will be set to 'auto'.") + with pytest.warns(ChangedBehaviorWarning) as raised_warning: + mmc_supervised.fit(X, y) + assert any(msg == str(wrn.message) for wrn in raised_warning) + + pairs = np.array([[[-10., 0.], [10., 0.]], [[0., 50.], [0., -60]]]) + y_pairs = [1, -1] + mmc = MMC() + with pytest.warns(ChangedBehaviorWarning) as raised_warning: + mmc.fit(pairs, y_pairs) + assert any(msg == str(wrn.message) for wrn in raised_warning) + @pytest.mark.parametrize(('algo_class', 'dataset'), [(NCA, make_classification()), diff --git a/test/test_base_metric.py b/test/test_base_metric.py index 7706b1e4..1b312b35 100644 --- a/test/test_base_metric.py +++ b/test/test_base_metric.py @@ -21,62 +21,67 @@ def test_covariance(self): def test_lmnn(self): self.assertRegexpMatches( str(metric_learn.LMNN()), - r"(python_)?LMNN\(convergence_tol=0.001, k=3, learn_rate=1e-07, " - r"max_iter=1000,\s+min_iter=50, n_components=None, " - r"num_dims='deprecated',\s+preprocessor=None, " - r"regularization=0.5, use_pca=True, verbose=False\)") + r"(python_)?LMNN\(convergence_tol=0.001, init='auto', k=3, " + r"learn_rate=1e-07,\s+" + r"max_iter=1000, min_iter=50, n_components=None, " + r"num_dims='deprecated',\s+preprocessor=None, random_state=None, " + r"regularization=0.5,\s+use_pca=True, verbose=False\)") def test_nca(self): self.assertEqual(remove_spaces(str(metric_learn.NCA())), - remove_spaces( - "NCA(max_iter=100, n_components=None, " - "num_dims='deprecated', preprocessor=None, " - "tol=None, verbose=False)")) + remove_spaces("NCA(init=None, max_iter=100," + "n_components=None, " + "num_dims='deprecated', " + "preprocessor=None, random_state=None, " + "tol=None, verbose=False)")) def test_lfda(self): self.assertEqual(remove_spaces(str(metric_learn.LFDA())), remove_spaces( - "LFDA(embedding_type='weighted', k=None, " - "n_components=None, num_dims='deprecated'," - "preprocessor=None)")) + "LFDA(embedding_type='weighted', k=None, " + "n_components=None, num_dims='deprecated'," + "preprocessor=None)")) def test_itml(self): self.assertEqual(remove_spaces(str(metric_learn.ITML())), remove_spaces(""" -ITML(A0=None, convergence_threshold=0.001, gamma=1.0, max_iter=1000, - preprocessor=None, verbose=False) +ITML(A0='deprecated', convergence_threshold=0.001, gamma=1.0, + max_iter=1000, preprocessor=None, prior='identity', random_state=None, + verbose=False) """)) self.assertEqual(remove_spaces(str(metric_learn.ITML_Supervised())), remove_spaces(""" -ITML_Supervised(A0=None, bounds='deprecated', convergence_threshold=0.001, - gamma=1.0, max_iter=1000, num_constraints=None, - num_labeled='deprecated', preprocessor=None, verbose=False) +ITML_Supervised(A0='deprecated', bounds='deprecated', + convergence_threshold=0.001, gamma=1.0, + max_iter=1000, num_constraints=None, num_labeled='deprecated', + preprocessor=None, prior='identity', random_state=None, verbose=False) """)) def test_lsml(self): - self.assertEqual( - remove_spaces(str(metric_learn.LSML())), - remove_spaces( - "LSML(max_iter=1000, preprocessor=None, prior=None, tol=0.001, " - "verbose=False)")) + self.assertEqual(remove_spaces(str(metric_learn.LSML())), + remove_spaces(""" +LSML(max_iter=1000, preprocessor=None, prior=None, + random_state=None, tol=0.001, verbose=False) +""")) self.assertEqual(remove_spaces(str(metric_learn.LSML_Supervised())), remove_spaces(""" -LSML_Supervised(max_iter=1000, num_constraints=None, num_labeled='deprecated', - preprocessor=None, prior=None, 
tol=0.001, verbose=False, - weights=None) +LSML_Supervised(max_iter=1000, num_constraints=None, + num_labeled='deprecated', preprocessor=None, prior=None, + random_state=None, tol=0.001, verbose=False, weights=None) """)) def test_sdml(self): self.assertEqual(remove_spaces(str(metric_learn.SDML())), - remove_spaces( - "SDML(balance_param=0.5, preprocessor=None, " - "sparsity_param=0.01, use_cov=True," - "\n verbose=False)")) + remove_spaces(""" +SDML(balance_param=0.5, preprocessor=None, prior=None, random_state=None, + sparsity_param=0.01, use_cov='deprecated', verbose=False) +""")) self.assertEqual(remove_spaces(str(metric_learn.SDML_Supervised())), remove_spaces(""" SDML_Supervised(balance_param=0.5, num_constraints=None, - num_labeled='deprecated', preprocessor=None, sparsity_param=0.01, - use_cov=True, verbose=False) + num_labeled='deprecated', preprocessor=None, prior=None, + random_state=None, sparsity_param=0.01, use_cov='deprecated', + verbose=False) """)) def test_rca(self): @@ -94,22 +99,26 @@ def test_rca(self): def test_mlkr(self): self.assertEqual(remove_spaces(str(metric_learn.MLKR())), - remove_spaces( - "MLKR(A0=None, max_iter=1000, n_components=None, " - "num_dims='deprecated', " - "preprocessor=None, tol=None, verbose=False)")) + remove_spaces("MLKR(A0='deprecated', init=None," + "max_iter=1000, n_components=None," + "num_dims='deprecated', preprocessor=None," + "random_state=None, tol=None, " + "verbose=False)" + )) def test_mmc(self): self.assertEqual(remove_spaces(str(metric_learn.MMC())), remove_spaces(""" -MMC(A0=None, convergence_threshold=0.001, diagonal=False, diagonal_c=1.0, - max_iter=100, max_proj=10000, preprocessor=None, verbose=False) +MMC(A0='deprecated', convergence_threshold=0.001, diagonal=False, + diagonal_c=1.0, init=None, max_iter=100, max_proj=10000, + preprocessor=None, random_state=None, verbose=False) """)) self.assertEqual(remove_spaces(str(metric_learn.MMC_Supervised())), remove_spaces(""" -MMC_Supervised(A0=None, convergence_threshold=1e-06, diagonal=False, - diagonal_c=1.0, max_iter=100, max_proj=10000, num_constraints=None, - num_labeled='deprecated', preprocessor=None, verbose=False) +MMC_Supervised(A0='deprecated', convergence_threshold=1e-06, diagonal=False, + diagonal_c=1.0, init=None, max_iter=100, max_proj=10000, + num_constraints=None, num_labeled='deprecated', preprocessor=None, + random_state=None, verbose=False) """)) diff --git a/test/test_fit_transform.py b/test/test_fit_transform.py index 5e8a87f4..b7255ea9 100644 --- a/test/test_fit_transform.py +++ b/test/test_fit_transform.py @@ -65,13 +65,13 @@ def test_lmnn(self): def test_sdml_supervised(self): seed = np.random.RandomState(1234) sdml = SDML_Supervised(num_constraints=1500, balance_param=1e-5, - use_cov=False) + prior='identity') sdml.fit(self.X, self.y, random_state=seed) res_1 = sdml.transform(self.X) seed = np.random.RandomState(1234) sdml = SDML_Supervised(num_constraints=1500, balance_param=1e-5, - use_cov=False) + prior='identity') res_2 = sdml.fit_transform(self.X, self.y, random_state=seed) assert_array_almost_equal(res_1, res_2) diff --git a/test/test_mahalanobis_mixin.py b/test/test_mahalanobis_mixin.py index e7fa5b17..54c37936 100644 --- a/test/test_mahalanobis_mixin.py +++ b/test/test_mahalanobis_mixin.py @@ -2,19 +2,24 @@ import pytest import numpy as np +from numpy.linalg import LinAlgError from numpy.testing import assert_array_almost_equal, assert_allclose from scipy.spatial.distance import pdist, squareform, mahalanobis +from scipy.stats import ortho_group 
from sklearn import clone from sklearn.cluster import DBSCAN +from sklearn.datasets import make_spd_matrix from sklearn.utils import check_random_state +from sklearn.utils.multiclass import type_of_target from sklearn.utils.testing import set_random_state from metric_learn._util import make_context from metric_learn.base_metric import (_QuadrupletsClassifierMixin, _PairsClassifierMixin) +from metric_learn.exceptions import NonPSDError from test.test_utils import (ids_metric_learners, metric_learners, - remove_y_quadruplets) + remove_y_quadruplets, ids_classifiers) RNG = check_random_state(0) @@ -56,7 +61,7 @@ def test_score_pairs_toy_example(estimator, build_dataset): pairs = np.stack([X[:10], X[10:20]], axis=1) embedded_pairs = pairs.dot(model.transformer_.T) distances = np.sqrt(np.sum((embedded_pairs[:, 1] - - embedded_pairs[:, 0])**2, + embedded_pairs[:, 0])**2, axis=-1)) assert_array_almost_equal(model.score_pairs(pairs), distances) @@ -190,7 +195,7 @@ def test_get_metric_equivalent_to_explicit_mahalanobis(estimator, a, b = (rng.randn(n_features), rng.randn(n_features)) expected_dist = mahalanobis(a[None], b[None], VI=model.get_mahalanobis_matrix()) - assert_allclose(metric(a, b), expected_dist, rtol=1e-15) + assert_allclose(metric(a, b), expected_dist, rtol=1e-13) @pytest.mark.parametrize('estimator, build_dataset', metric_learners, @@ -300,3 +305,349 @@ def test_transformer_is_2D(estimator, build_dataset): labels = labels[to_keep] model.fit(*remove_y_quadruplets(estimator, trunc_data, labels)) assert model.transformer_.shape == (1, 1) # the transformer must be 2D + + +@pytest.mark.parametrize('estimator, build_dataset', + [(ml, bd) for idml, (ml, bd) + in zip(ids_metric_learners, + metric_learners) + if hasattr(ml, 'n_components') and + hasattr(ml, 'init')], + ids=[idml for idml, (ml, _) + in zip(ids_metric_learners, + metric_learners) + if hasattr(ml, 'n_components') and + hasattr(ml, 'init')]) +def test_init_transformation(estimator, build_dataset): + input_data, labels, _, X = build_dataset() + is_classification = (type_of_target(labels) in ['multiclass', 'binary']) + model = clone(estimator) + rng = np.random.RandomState(42) + + # Start learning from scratch + model.set_params(init='identity') + model.fit(input_data, labels) + + # Initialize with random + model.set_params(init='random') + model.fit(input_data, labels) + + # Initialize with auto + model.set_params(init='auto') + model.fit(input_data, labels) + + # Initialize with PCA + model.set_params(init='pca') + model.fit(input_data, labels) + + # Initialize with LDA + if is_classification: + model.set_params(init='lda') + model.fit(input_data, labels) + + # Initialize with a numpy array + init = rng.rand(X.shape[1], X.shape[1]) + model.set_params(init=init) + model.fit(input_data, labels) + + # init.shape[1] must match X.shape[1] + init = rng.rand(X.shape[1], X.shape[1] + 1) + model.set_params(init=init) + msg = ('The input dimensionality ({}) of the given ' + 'linear transformation `init` must match the ' + 'dimensionality of the given inputs `X` ({}).' + .format(init.shape[1], X.shape[1])) + with pytest.raises(ValueError) as raised_error: + model.fit(input_data, labels) + assert str(raised_error.value) == msg + + # init.shape[0] must be <= init.shape[1] + init = rng.rand(X.shape[1] + 1, X.shape[1]) + model.set_params(init=init) + msg = ('The output dimensionality ({}) of the given ' + 'linear transformation `init` cannot be ' + 'greater than its input dimensionality ({}).' 
+ .format(init.shape[0], init.shape[1])) + with pytest.raises(ValueError) as raised_error: + model.fit(input_data, labels) + assert str(raised_error.value) == msg + + # init.shape[0] must match n_components + init = rng.rand(X.shape[1], X.shape[1]) + n_components = X.shape[1] - 1 + model.set_params(init=init, n_components=n_components) + msg = ('The preferred dimensionality of the ' + 'projected space `n_components` ({}) does not match ' + 'the output dimensionality of the given ' + 'linear transformation `init` ({})!' + .format(n_components, init.shape[0])) + with pytest.raises(ValueError) as raised_error: + model.fit(input_data, labels) + assert str(raised_error.value) == msg + + # init must be as specified in the docstring + model.set_params(init=1) + msg = ("`init` must be 'auto', 'pca', 'identity', " + "'random'{} or a numpy array of shape " + "(n_components, n_features)." + .format(", 'lda'" if is_classification else '')) + with pytest.raises(ValueError) as raised_error: + model.fit(input_data, labels) + assert str(raised_error.value) == msg + + +@pytest.mark.parametrize('n_samples', [3, 5, 7, 11]) +@pytest.mark.parametrize('n_features', [3, 5, 7, 11]) +@pytest.mark.parametrize('n_classes', [5, 7, 11]) +@pytest.mark.parametrize('n_components', [3, 5, 7, 11]) +@pytest.mark.parametrize('estimator, build_dataset', + [(ml, bd) for idml, (ml, bd) + in zip(ids_metric_learners, + metric_learners) + if hasattr(ml, 'n_components') and + hasattr(ml, 'init')], + ids=[idml for idml, (ml, _) + in zip(ids_metric_learners, + metric_learners) + if hasattr(ml, 'n_components') and + hasattr(ml, 'init')]) +def test_auto_init_transformation(n_samples, n_features, n_classes, + n_components, estimator, build_dataset): + # Test that auto choose the init transformation as expected with every + # configuration of order of n_samples, n_features, n_classes and + # n_components, for all metric learners that learn a transformation. 
+ if n_classes >= n_samples: + pass + # n_classes > n_samples is impossible, and n_classes == n_samples + # throws an error from lda but is an absurd case + else: + input_data, labels, _, X = build_dataset() + model_base = clone(estimator) + rng = np.random.RandomState(42) + model_base.set_params(init='auto', + n_components=n_components, + random_state=rng) + # To make the test work for LMNN: + if 'LMNN' in model_base.__class__.__name__: + model_base.set_params(k=1) + # To make the test faster for estimators that have a max_iter: + if hasattr(model_base, 'max_iter'): + model_base.set_params(max_iter=1) + if n_components > n_features: + # this would return a ValueError, which is tested in + # test_init_transformation + pass + else: + # We need to build a dataset of the right shape: + num_to_pad_n_samples = ((n_samples // input_data.shape[0] + 1)) + num_to_pad_n_features = ((n_features // input_data.shape[-1] + 1)) + if input_data.ndim == 3: + input_data = np.tile(input_data, + (num_to_pad_n_samples, input_data.shape[1], + num_to_pad_n_features)) + else: + input_data = np.tile(input_data, + (num_to_pad_n_samples, num_to_pad_n_features)) + input_data = input_data[:n_samples, ..., :n_features] + assert input_data.shape[0] == n_samples + assert input_data.shape[-1] == n_features + has_classes = model_base.__class__.__name__ in ids_classifiers + if has_classes: + labels = np.tile(range(n_classes), n_samples // + n_classes + 1)[:n_samples] + else: + labels = np.tile(labels, n_samples // labels.shape[0] + 1)[:n_samples] + model = clone(model_base) + model.fit(input_data, labels) + if n_components <= min(n_classes - 1, n_features) and has_classes: + model_other = clone(model_base).set_params(init='lda') + elif n_components < min(n_features, n_samples): + model_other = clone(model_base).set_params(init='pca') + else: + model_other = clone(model_base).set_params(init='identity') + model_other.fit(input_data, labels) + assert_array_almost_equal(model.transformer_, + model_other.transformer_) + + +@pytest.mark.parametrize('estimator, build_dataset', + [(ml, bd) for idml, (ml, bd) + in zip(ids_metric_learners, + metric_learners) + if not hasattr(ml, 'n_components') and + hasattr(ml, 'init')], + ids=[idml for idml, (ml, _) + in zip(ids_metric_learners, + metric_learners) + if not hasattr(ml, 'n_components') and + hasattr(ml, 'init')]) +def test_init_mahalanobis(estimator, build_dataset): + """Tests that for estimators that learn a mahalanobis matrix + instead of a transformer, i.e. those that are mahalanobis metric learners + where we can change the init, but not choose the n_components, + (TODO: be more explicit on this characterization, for instance with + safe_flags like in scikit-learn) that the init has an expected behaviour. 
+ """ + input_data, labels, _, X = build_dataset() + + matrices_to_set = [] + if hasattr(estimator, 'init'): + matrices_to_set.append('init') + if hasattr(estimator, 'prior'): + matrices_to_set.append('prior') + + for param in matrices_to_set: + model = clone(estimator) + set_random_state(model) + rng = np.random.RandomState(42) + + # Start learning from scratch + model.set_params(**{param: 'identity'}) + model.fit(input_data, labels) + + # Initialize with random + model.set_params(**{param: 'random'}) + model.fit(input_data, labels) + + # Initialize with covariance + model.set_params(**{param: 'covariance'}) + model.fit(input_data, labels) + + # Initialize with a random spd matrix + init = make_spd_matrix(X.shape[1], random_state=rng) + model.set_params(**{param: init}) + model.fit(input_data, labels) + + # init.shape[1] must match X.shape[1] + init = make_spd_matrix(X.shape[1] + 1, X.shape[1] + 1) + model.set_params(**{param: init}) + msg = ('The input dimensionality {} of the given ' + 'mahalanobis matrix `{}` must match the ' + 'dimensionality of the given inputs ({}).' + .format(init.shape, param, input_data.shape[-1])) + + with pytest.raises(ValueError) as raised_error: + model.fit(input_data, labels) + assert str(raised_error.value) == msg + + # The input matrix must be symmetric + init = rng.rand(X.shape[1], X.shape[1]) + model.set_params(**{param: init}) + msg = ("`{}` is not symmetric.".format(param)) + with pytest.raises(ValueError) as raised_error: + model.fit(input_data, labels) + assert str(raised_error.value) == msg + + # The input matrix must be SPD + P = ortho_group.rvs(X.shape[1], random_state=rng) + w = np.abs(rng.randn(X.shape[1])) + w[0] = -10. + M = P.dot(np.diag(w)).dot(P.T) + model.set_params(**{param: M}) + msg = ("Matrix is not positive semidefinite (PSD).") + with pytest.raises(NonPSDError) as raised_err: + model.fit(input_data, labels) + assert str(raised_err.value) == msg + + # init must be as specified in the docstring + model.set_params(**{param: 1}) + msg = ("`{}` must be 'identity', 'covariance', " + "'random' or a numpy array of shape " + "(n_features, n_features).".format(param)) + with pytest.raises(ValueError) as raised_error: + model.fit(input_data, labels) + assert str(raised_error.value) == msg + + +@pytest.mark.parametrize('estimator, build_dataset', + [(ml, bd) for idml, (ml, bd) + in zip(ids_metric_learners, + metric_learners) + if idml[:4] in ['ITML', 'SDML', 'LSML']], + ids=[idml for idml, (ml, _) + in zip(ids_metric_learners, + metric_learners) + if idml[:4] in ['ITML', 'SDML', 'LSML']]) +def test_singular_covariance_init_or_prior(estimator, build_dataset): + """Tests that when using the 'covariance' init or prior, it returns the + appropriate error if the covariance matrix is singular, for algorithms + that need a strictly PD prior or init (see + https://github.com/metric-learn/metric-learn/issues/202 and + https://github.com/metric-learn/metric-learn/pull/195#issuecomment + -492332451) + """ + matrices_to_set = [] + if hasattr(estimator, 'init'): + matrices_to_set.append('init') + if hasattr(estimator, 'prior'): + matrices_to_set.append('prior') + + input_data, labels, _, X = build_dataset() + for param in matrices_to_set: + model = clone(estimator) + set_random_state(model) + # We create a feature that is a linear combination of the first two + # features: + input_data = np.concatenate([input_data, input_data[:, ..., :2] + .dot([[2], [3]])], + axis=-1) + model.set_params(**{param: 'covariance'}) + msg = ("Unable to get a true inverse of the 
covariance " + "matrix since it is not definite. Try another " + "`{}`, or an algorithm that does not " + "require the `{}` to be strictly positive definite." + .format(param, param)) + with pytest.raises(LinAlgError) as raised_err: + model.fit(input_data, labels) + assert str(raised_err.value) == msg + + +@pytest.mark.integration +@pytest.mark.parametrize('estimator, build_dataset', + [(ml, bd) for idml, (ml, bd) + in zip(ids_metric_learners, + metric_learners) + if idml[:4] in ['ITML', 'SDML', 'LSML']], + ids=[idml for idml, (ml, _) + in zip(ids_metric_learners, + metric_learners) + if idml[:4] in ['ITML', 'SDML', 'LSML']]) +@pytest.mark.parametrize('w0', [1e-20, 0., -1e-20]) +def test_singular_array_init_or_prior(estimator, build_dataset, w0): + """Tests that when using a custom array init (or prior), it returns the + appropriate error if it is singular, for algorithms + that need a strictly PD prior or init (see + https://github.com/metric-learn/metric-learn/issues/202 and + https://github.com/metric-learn/metric-learn/pull/195#issuecomment + -492332451) + """ + matrices_to_set = [] + if hasattr(estimator, 'init'): + matrices_to_set.append('init') + if hasattr(estimator, 'prior'): + matrices_to_set.append('prior') + + rng = np.random.RandomState(42) + input_data, labels, _, X = build_dataset() + for param in matrices_to_set: + model = clone(estimator) + set_random_state(model) + + P = ortho_group.rvs(X.shape[1], random_state=rng) + w = np.abs(rng.randn(X.shape[1])) + w[0] = w0 + M = P.dot(np.diag(w)).dot(P.T) + if hasattr(model, 'init'): + model.set_params(init=M) + if hasattr(model, 'prior'): + model.set_params(prior=M) + if not hasattr(model, 'prior') and not hasattr(model, 'init'): + raise RuntimeError("Neither prior or init could be set in the model.") + msg = ("You should provide a strictly positive definite " + "matrix as `{}`. This one is not definite. Try another" + " {}, or an algorithm that does not " + "require the {} to be strictly positive definite." + .format(*(param,) * 3)) + with pytest.raises(LinAlgError) as raised_err: + model.fit(input_data, labels) + assert str(raised_err.value) == msg diff --git a/test/test_sklearn_compat.py b/test/test_sklearn_compat.py index 6b451aee..0c0f098d 100644 --- a/test/test_sklearn_compat.py +++ b/test/test_sklearn_compat.py @@ -85,15 +85,15 @@ def stable_init(self, sparsity_param=0.01, num_labeled='deprecated', num_constraints=num_constraints, verbose=verbose, preprocessor=preprocessor, - balance_param=1e-5, use_cov=False) + balance_param=1e-5, prior='identity') dSDML.__init__ = stable_init check_estimator(dSDML) def test_rca(self): - def stable_init(self, num_dims=None, pca_comps=None, + def stable_init(self, n_components=None, pca_comps=None, chunk_size=2, preprocessor=None): # this init makes RCA stable for scikit-learn examples. 
- RCA_Supervised.__init__(self, num_chunks=2, num_dims=num_dims, + RCA_Supervised.__init__(self, num_chunks=2, n_components=n_components, pca_comps=pca_comps, chunk_size=chunk_size, preprocessor=preprocessor) dRCA.__init__ = stable_init diff --git a/test/test_transformer_metric_conversion.py b/test/test_transformer_metric_conversion.py index 0139f632..651f60ea 100644 --- a/test/test_transformer_metric_conversion.py +++ b/test/test_transformer_metric_conversion.py @@ -11,6 +11,7 @@ LMNN, NCA, LFDA, Covariance, MLKR, LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised) from metric_learn._util import transformer_from_metric +from metric_learn.exceptions import NonPSDError class TestTransformerMetricConversion(unittest.TestCase): @@ -49,7 +50,7 @@ def test_lmnn(self): def test_sdml_supervised(self): seed = np.random.RandomState(1234) - sdml = SDML_Supervised(num_constraints=1500, use_cov=False, + sdml = SDML_Supervised(num_constraints=1500, prior='identity', balance_param=1e-5) sdml.fit(self.X, self.y, random_state=seed) L = sdml.transformer_ @@ -162,10 +163,10 @@ def test_non_psd_raises(self): P = ortho_group.rvs(7, random_state=rng) M = P.dot(D).dot(P.T) msg = ("Matrix is not positive semidefinite (PSD).") - with pytest.raises(ValueError) as raised_error: + with pytest.raises(NonPSDError) as raised_error: transformer_from_metric(M) assert str(raised_error.value) == msg - with pytest.raises(ValueError) as raised_error: + with pytest.raises(NonPSDError) as raised_error: transformer_from_metric(D) assert str(raised_error.value) == msg diff --git a/test/test_utils.py b/test/test_utils.py index 08415a76..2e57f489 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -11,7 +11,8 @@ make_name, preprocess_points, check_collapsed_pairs, validate_vector, _check_sdp_from_eigen, _check_n_components, - check_y_valid_values_for_pairs) + check_y_valid_values_for_pairs, + _auto_select_init) from metric_learn import (ITML, LSML, MMC, RCA, SDML, Covariance, LFDA, LMNN, MLKR, NCA, ITML_Supervised, LSML_Supervised, MMC_Supervised, RCA_Supervised, SDML_Supervised, @@ -19,7 +20,7 @@ from metric_learn.base_metric import (ArrayIndexer, MahalanobisMixin, _PairsClassifierMixin, _QuadrupletsClassifierMixin) -from metric_learn.exceptions import PreprocessorError +from metric_learn.exceptions import PreprocessorError, NonPSDError from sklearn.datasets import make_regression, make_blobs, load_iris @@ -104,7 +105,7 @@ def build_quadruplets(with_preprocessor=False): pairs_learners = [(ITML(max_iter=2), build_pairs), # max_iter=2 to be faster (MMC(max_iter=2), build_pairs), # max_iter=2 to be faster - (SDML(use_cov=False, balance_param=1e-5), build_pairs)] + (SDML(prior='identity', balance_param=1e-5), build_pairs)] ids_pairs_learners = list(map(lambda x: x.__class__.__name__, [learner for (learner, _) in pairs_learners])) @@ -118,13 +119,13 @@ def build_quadruplets(with_preprocessor=False): (LSML_Supervised(), build_classification), (MMC_Supervised(max_iter=5), build_classification), (RCA_Supervised(num_chunks=10), build_classification), - (SDML_Supervised(use_cov=False, balance_param=1e-5), + (SDML_Supervised(prior='identity', balance_param=1e-5), build_classification)] ids_classifiers = list(map(lambda x: x.__class__.__name__, [learner for (learner, _) in classifiers])) -regressors = [(MLKR(), build_regression)] +regressors = [(MLKR(init='pca'), build_regression)] ids_regressors = list(map(lambda x: x.__class__.__name__, [learner for (learner, _) in regressors])) @@ -993,7 +994,7 @@ def 
test__validate_vector():
     validate_vector(x)


-def test_check_sdp_from_eigen_positive_err_messages():
+def test__check_sdp_from_eigen_positive_err_messages():
   """Tests that if _check_sdp_from_eigen is given a negative tol it returns
   an error, and if positive (or None) it does not"""
   w = np.abs(np.random.RandomState(42).randn(10)) + 1
@@ -1008,6 +1009,37 @@
     _check_sdp_from_eigen(w, None)


+@pytest.mark.unit
+@pytest.mark.parametrize('w', [np.array([-1.2, 5.5, 6.6]),
+                               np.array([-1.2, -5.6])])
+def test__check_sdp_from_eigen_negative_eigenvalues(w):
+  """Tests that _check_sdp_from_eigen raises a NonPSDError when some
+  eigenvalues are negative."""
+  with pytest.raises(NonPSDError):
+    _check_sdp_from_eigen(w)
+
+
+@pytest.mark.unit
+@pytest.mark.parametrize('w', [np.array([0., 2.3, 5.3]),
+                               np.array([1e-20, 3.5]),
+                               np.array([1.5, 2.4, 4.6])])
+def test__check_sdp_from_eigen_positive_eigenvalues(w):
+  """Tests that _check_sdp_from_eigen raises no error when all
+  eigenvalues are non-negative."""
+  _check_sdp_from_eigen(w)
+
+
+@pytest.mark.unit
+@pytest.mark.parametrize('w, is_definite', [(np.array([1e-15, 5.6]), False),
+                                            (np.array([-1e-15, 5.6]), False),
+                                            (np.array([3.2, 5.6, 0.01]), True),
+                                            ])
+def test__check_sdp_from_eigen_returns_definiteness(w, is_definite):
+  """Tests that _check_sdp_from_eigen returns the definiteness of the
+  matrix (when it is PSD), based on the given eigenvalues"""
+  assert _check_sdp_from_eigen(w) == is_definite
+
+
 def test__check_n_components():
   """Checks that n_components returns what is expected
   (including the errors)"""
@@ -1094,3 +1126,23 @@ def test_check_input_pairs_learners_invalid_y(estimator, build_dataset,
     with pytest.raises(ValueError) as raised_error:
       model.fit(input_data, wrong_labels)
     assert str(raised_error.value) == expected_msg
+
+
+@pytest.mark.parametrize('has_classes, n_features, n_samples, n_components, '
+                         'n_classes, result',
+                         [(False, 3, 20, 3, 0, 'identity'),
+                          (False, 3, 2, 3, 0, 'identity'),
+                          (False, 5, 3, 4, 0, 'identity'),
+                          (False, 4, 5, 3, 0, 'pca'),
+                          (True, 5, 6, 3, 4, 'lda'),
+                          (True, 6, 3, 3, 3, 'identity'),
+                          (True, 5, 6, 4, 2, 'pca'),
+                          (True, 2, 6, 2, 10, 'lda'),
+                          (True, 4, 6, 2, 3, 'lda')
+                          ])
+def test__auto_select_init(has_classes, n_features, n_samples, n_components,
+                           n_classes, result):
+  """Checks that the auto selection of the init works as expected"""
+  assert (_auto_select_init(has_classes, n_features,
+                            n_samples, n_components, n_classes) == result)
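
For reference, the decision rule this parametrization encodes can be written out directly (a sketch of the expected behaviour that the assertions above describe, not the library's implementation):

def expected_auto_init(has_classes, n_features, n_samples, n_components,
                       n_classes):
    # Mirrors the cases asserted in test__auto_select_init: LDA when labels
    # are classes and the target dimension fits in n_classes - 1, else PCA
    # when reducing below both n_features and n_samples, else identity.
    if has_classes and n_components <= min(n_classes - 1, n_features):
        return 'lda'
    elif n_components < min(n_features, n_samples):
        return 'pca'
    else:
        return 'identity'
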