1
1
import warnings
2
- from typing import Callable , Union
2
+ from typing import Callable , Optional , Union
3
3
4
4
import numpy as np
5
- from numpy .typing import ArrayLike
6
5
from scipy .special import logsumexp as LSE
7
6
from sklearn .base import BaseEstimator
8
7
from sklearn .utils .validation import check_is_fitted , check_random_state
@@ -33,19 +32,18 @@ class SparseKDE(BaseEstimator):
33
32
weights: numpy.ndarray, default=None
34
33
Weights of the descriptors.
35
34
If None, all weights are set to `1/n_descriptors`.
36
- metric : Callable[[ArrayLike, ArrayLike, bool, dict], ArrayLike],
37
- default=:func:`skmatter.metrics.pairwise_euclidean_distances()`
35
+ metric : Callable, default=None
38
36
The metric to use. Your metric should be able to take at least three arguments
39
37
in secquence: `X`, `Y`, and `squared=True`. Here, `X` and `Y` are two array-like
40
38
of shape (n_samples, n_components). The return of the metric is an array-like of
41
- shape (n_samples, n_samples). If you want to use periodic boundary
42
- conditions, be sure to provide the cell size in the metric_params and
43
- provide a metric that can take the cell argument.
39
+ shape (n_samples, n_samples). If you want to use periodic boundary conditions,
40
+ be sure to provide the cell size in the metric_params and provide a metric that
41
+ can take the cell argument. If :obj:`None`, the
42
+ :func:`skmatter.metrics.periodic_pairwise_euclidean_distances()` is used.
44
43
metric_params : dict, default=None
45
- Additional parameters to be passed to the use of
46
- metric. i.e. the cell dimension for
47
- :func:`skmatter.metrics.pairwise_euclidean_distances()`
48
- `{'cell_length': [side_length_1, ..., side_length_n]}`
44
+ Additional parameters to be passed to the use of metric. i.e. the cell
45
+ dimension for :func:`skmatter.metrics.periodic_pairwise_euclidean_distances()`
46
+ ``{'cell_length': [side_length_1, ..., side_length_n]}``
49
47
fspread : float, default=-1.0
50
48
The fractional "space" occupied by the voronoi cell of each grid. Use this when
51
49
each cell is of a similar size.
@@ -106,11 +104,9 @@ class SparseKDE(BaseEstimator):
106
104
def __init__ (
107
105
self ,
108
106
descriptors : np .ndarray ,
109
- weights : Union [np .ndarray , None ] = None ,
110
- metric : Callable [
111
- [ArrayLike , ArrayLike , bool , dict ], ArrayLike
112
- ] = periodic_pairwise_euclidean_distances ,
113
- metric_params : Union [dict , None ] = None ,
107
+ weights : Optional [np .ndarray ] = None ,
108
+ metric : Optional [Callable ] = None ,
109
+ metric_params : Optional [dict ] = None ,
114
110
fspread : float = - 1.0 ,
115
111
fpoints : float = 0.15 ,
116
112
kernel : str = "gaussian" ,
@@ -119,6 +115,10 @@ def __init__(
119
115
self .metric_params = (
120
116
metric_params if metric_params is not None else {"cell_length" : None }
121
117
)
118
+
119
+ if metric is None :
120
+ metric = periodic_pairwise_euclidean_distances
121
+
122
122
self .metric = lambda X , Y : metric (X , Y , squared = True , ** self .metric_params )
123
123
self .cell = metric_params ["cell_length" ] if metric_params is not None else None
124
124
self ._check_dimension (descriptors )
0 commit comments