Skip to content

Commit 71fa54c

Browse files
authored
Essential traits (rust-ml#221)
* Add traits
* Add traits and autotraits to bayes crate
* Add traits and autotrait test to clustering crate
* Add traits and autotraits test to elasticnet crate
* Add traits and autotraits test to ftrl crate
* Add traits and autotraits test to hierarchical crate
* Add traits and autotraits test to ica crate
* Add traits and autotraits test to kernel crate
* Add traits and autotraits test to linear crate
* Add traits and autotraits test to logistic crate
* Add traits and autotraits test to nn crate
* Add traits and autotraits test to pls crate
* Add traits and autotraits test to preprocessing crate
* Add traits and autotraits test to reduction crate
* Add traits and autotraits test to svm crate
* Add traits and autotraits test to trees crate
* Add traits and autotraits test to tsne crate
* Remove PartialOrd in SvmValidParams
* Remove extra traits in errors. Leave Hash for simple enums. Fix some traits
* Remove PartialOrd. Minor changes
* Add bounds to Iter, GaussianNb, MultinomialNb, OpticsAnalysis
1 parent 8b54b04 commit 71fa54c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

71 files changed

+507
-114
lines changed

algorithms/linfa-bayes/src/gaussian_nb.rs

+16-3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use linfa::{Float, Label};
44
use ndarray::{Array1, ArrayBase, ArrayView2, Axis, Data, Ix2};
55
use ndarray_stats::QuantileExt;
66
use std::collections::HashMap;
7+
use std::hash::Hash;
78

89
use crate::base_nb::{filter, NaiveBayes, NaiveBayesValidParams};
910
use crate::error::{NaiveBayesError, Result};
@@ -225,12 +226,12 @@ where
225226
/// let model = checked_params.fit_with(Some(model), &ds)?;
226227
/// # Result::Ok(())
227228
/// ```
228-
#[derive(Debug, Clone)]
229-
pub struct GaussianNb<F, L> {
229+
#[derive(Debug, Clone, PartialEq)]
230+
pub struct GaussianNb<F: PartialEq, L: Eq + Hash> {
230231
class_info: HashMap<L, GaussianClassInfo<F>>,
231232
}
232233

233-
#[derive(Debug, Default, Clone)]
234+
#[derive(Debug, Default, Clone, PartialEq)]
234235
struct GaussianClassInfo<F> {
235236
class_count: usize,
236237
prior: F,
@@ -284,10 +285,22 @@ mod tests {
284285
DatasetView,
285286
};
286287

288+
use crate::gaussian_nb::GaussianClassInfo;
289+
use crate::{GaussianNbParams, GaussianNbValidParams, NaiveBayesError};
287290
use approx::assert_abs_diff_eq;
288291
use ndarray::{array, Axis};
289292
use std::collections::HashMap;
290293

294+
#[test]
295+
fn autotraits() {
296+
fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
297+
has_autotraits::<GaussianNb<f64, usize>>();
298+
has_autotraits::<GaussianClassInfo<f64>>();
299+
has_autotraits::<GaussianNbParams<f64, usize>>();
300+
has_autotraits::<GaussianNbValidParams<f64, usize>>();
301+
has_autotraits::<NaiveBayesError>();
302+
}
303+
291304
#[test]
292305
fn test_gaussian_nb() -> Result<()> {
293306
let x = array![

algorithms/linfa-bayes/src/hyperparams.rs

+4-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use std::marker::PhantomData;
55
/// A verified hyper-parameter set ready for the estimation of a [Gaussian Naive Bayes model](crate::gaussian_nb::GaussianNb).
66
///
77
/// See [`GaussianNb`](crate::gaussian_nb::GaussianNb) for information on the model and [`GaussianNbParams`](crate::hyperparams::GaussianNbParams) for information on hyperparameters.
8-
#[derive(Debug)]
8+
#[derive(Debug, Clone, PartialEq)]
99
pub struct GaussianNbValidParams<F, L> {
1010
// Required for calculation stability
1111
var_smoothing: F,
@@ -43,6 +43,7 @@ impl<F: Float, L> GaussianNbValidParams<F, L> {
4343
/// Returns [`InvalidSmoothing`](NaiveBayesError::InvalidSmoothing) if the smoothing
4444
/// parameter is negative.
4545
///
46+
#[derive(Debug, Clone, PartialEq)]
4647
pub struct GaussianNbParams<F, L>(GaussianNbValidParams<F, L>);
4748

4849
impl<F: Float, L> Default for GaussianNbParams<F, L> {
@@ -91,7 +92,7 @@ impl<F: Float, L> ParamGuard for GaussianNbParams<F, L> {
9192
/// A verified hyper-parameter set ready for the estimation of a [Multinomial Naive Bayes model](crate::multinomial_nb::MultinomialNb).
9293
///
9394
/// See [`MultinomialNb`](crate::multinomial_nb::MultinomialNb) for information on the model and [`MultinomialNbParams`](crate::hyperparams::MultinomialNbParams) for information on hyperparameters.
94-
#[derive(Debug)]
95+
#[derive(Debug, Clone, PartialEq)]
9596
pub struct MultinomialNbValidParams<F, L> {
9697
// Required for calculation stability
9798
alpha: F,
@@ -129,6 +130,7 @@ impl<F: Float, L> MultinomialNbValidParams<F, L> {
129130
/// Returns [`InvalidSmoothing`](NaiveBayesError::InvalidSmoothing) if the smoothing
130131
/// parameter is negative.
131132
///
133+
#[derive(Debug, Clone, PartialEq)]
132134
pub struct MultinomialNbParams<F, L>(MultinomialNbValidParams<F, L>);
133135

134136
impl<F: Float, L> Default for MultinomialNbParams<F, L> {

algorithms/linfa-bayes/src/multinomial_nb.rs

+15-3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use linfa::traits::{Fit, FitWith, PredictInplace};
33
use linfa::{Float, Label};
44
use ndarray::{Array1, ArrayBase, ArrayView2, Axis, Data, Ix2};
55
use std::collections::HashMap;
6+
use std::hash::Hash;
67

78
use crate::base_nb::{filter, NaiveBayes, NaiveBayesValidParams};
89
use crate::error::{NaiveBayesError, Result};
@@ -195,12 +196,12 @@ where
195196
/// let model = checked_params.fit_with(Some(model), &ds)?;
196197
/// # Result::Ok(())
197198
/// ```
198-
#[derive(Debug, Clone)]
199-
pub struct MultinomialNb<F, L> {
199+
#[derive(Debug, Clone, PartialEq)]
200+
pub struct MultinomialNb<F: PartialEq, L: Eq + Hash> {
200201
class_info: HashMap<L, MultinomialClassInfo<F>>,
201202
}
202203

203-
#[derive(Debug, Default, Clone)]
204+
#[derive(Debug, Default, Clone, PartialEq)]
204205
struct MultinomialClassInfo<F> {
205206
class_count: usize,
206207
prior: F,
@@ -242,10 +243,21 @@ mod tests {
242243
DatasetView,
243244
};
244245

246+
use crate::multinomial_nb::MultinomialClassInfo;
247+
use crate::{MultinomialNbParams, MultinomialNbValidParams};
245248
use approx::assert_abs_diff_eq;
246249
use ndarray::{array, Axis};
247250
use std::collections::HashMap;
248251

252+
#[test]
253+
fn autotraits() {
254+
fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
255+
has_autotraits::<MultinomialNb<f64, usize>>();
256+
has_autotraits::<MultinomialClassInfo<f64>>();
257+
has_autotraits::<MultinomialNbValidParams<f64, usize>>();
258+
has_autotraits::<MultinomialNbParams<f64, usize>>();
259+
}
260+
249261
#[test]
250262
fn test_multinomial_nb() -> Result<()> {
251263
let x = array![[1., 0.], [2., 0.], [3., 0.], [0., 1.], [0., 2.], [0., 3.]];

algorithms/linfa-clustering/src/appx_dbscan/algorithm.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use super::cells_grid::CellsGrid;
1212
derive(Serialize, Deserialize),
1313
serde(crate = "serde_crate")
1414
)]
15-
#[derive(Clone, Debug, PartialEq)]
15+
#[derive(Clone, Debug, PartialEq, Eq)]
1616
/// DBSCAN (Density-based Spatial Clustering of Applications with Noise)
1717
/// clusters together neighbouring points, while points in sparse regions are labelled
1818
/// as noise. Since points may be part of a cluster or noise the transform method returns

algorithms/linfa-clustering/src/appx_dbscan/cells_grid/cell.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use linfa_nn::distance::{Distance, L2Dist};
55
use ndarray::{Array1, ArrayView1, ArrayView2, ArrayViewMut1};
66
use partitions::PartitionVec;
77

8-
#[derive(Clone)]
8+
#[derive(Clone, Debug, PartialEq, Eq)]
99
/// A point in a D dimensional euclidean space that memorizes its
1010
/// status: 'core' or 'non core'
1111
pub struct StatusPoint {
@@ -31,7 +31,7 @@ impl StatusPoint {
3131
}
3232
}
3333

34-
#[derive(Clone)]
34+
#[derive(Clone, Debug, PartialEq)]
3535
/// Information regarding the cell used in various stages of the approximate DBSCAN
3636
/// algorithm if it is a core cell
3737
pub struct CoreCellInfo<F: Float> {
@@ -41,7 +41,7 @@ pub struct CoreCellInfo<F: Float> {
4141
i_cluster: usize,
4242
}
4343

44-
#[derive(Clone)]
44+
#[derive(Clone, Debug, PartialEq)]
4545
/// A cell from a grid that partitions the D dimensional euclidean space.
4646
pub struct Cell<F: Float> {
4747
/// The index of the intervals of the D dimensional axes where this cell lies

algorithms/linfa-clustering/src/appx_dbscan/cells_grid/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ pub type CellVector<F> = PartitionVec<Cell<F>>;
1515
/// A structure that memorizes all non empty cells by their index's hash
1616
pub type CellTable = HashMap<Array1<i64>, usize>;
1717

18+
#[derive(Debug, Clone, PartialEq)]
1819
pub struct CellsGrid<F: Float> {
1920
table: CellTable,
2021
cells: CellVector<F>,

algorithms/linfa-clustering/src/appx_dbscan/cells_grid/tests.rs

+10
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
11
use crate::AppxDbscan;
22

33
use super::*;
4+
use crate::appx_dbscan::cells_grid::cell::CoreCellInfo;
45
use linfa::prelude::ParamGuard;
56
use ndarray::{arr2, Array2};
67

8+
#[test]
9+
fn autotraits() {
10+
fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
11+
has_autotraits::<AppxDbscan>();
12+
has_autotraits::<StatusPoint>();
13+
has_autotraits::<CoreCellInfo<f64>>();
14+
has_autotraits::<Cell<f64>>();
15+
}
16+
717
#[test]
818
fn find_cells_test() {
919
let params = AppxDbscan::params(2)

algorithms/linfa-clustering/src/appx_dbscan/counting_tree/mod.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@ use linfa_nn::distance::{Distance, L2Dist};
44
use ndarray::{Array1, Array2, ArrayView1, Axis};
55
use std::collections::HashMap;
66

7-
#[derive(PartialEq, Debug)]
7+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
88
pub enum IntersectionType {
99
FullyCovered,
1010
Disjoint,
1111
Intersecting,
1212
}
1313

14-
#[derive(Clone)]
14+
#[derive(Clone, Debug, PartialEq)]
1515
/// Tree structure that divides the space in nested cells to perform approximate range counting
1616
/// Each member of this structure is a node in the tree
1717
pub struct TreeStructure<F: Float> {

algorithms/linfa-clustering/src/appx_dbscan/counting_tree/tests.rs

+7
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,13 @@ use approx::assert_abs_diff_eq;
66
use linfa::ParamGuard;
77
use ndarray::{arr1, ArrayView};
88

9+
#[test]
10+
fn autotraits() {
11+
fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
12+
has_autotraits::<IntersectionType>();
13+
has_autotraits::<TreeStructure<f64>>();
14+
}
15+
916
#[test]
1017
fn counting_test() {
1118
let params = AppxDbscan::params(2)

algorithms/linfa-clustering/src/appx_dbscan/hyperparams.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ pub struct AppxDbscanValidParams<F: Float, N> {
2121
pub(crate) nn_algo: N,
2222
}
2323

24-
#[derive(Debug)]
24+
#[derive(Debug, Clone, PartialEq)]
2525
/// Helper struct for building a set of [Approximated DBSCAN
2626
/// hyperparameters](struct.AppxDbscanParams.html)
2727
pub struct AppxDbscanParams<F: Float, N>(AppxDbscanValidParams<F, N>);

algorithms/linfa-clustering/src/appx_dbscan/tests.rs

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,23 @@
1-
use crate::{AppxDbscan, AppxDbscanParamsError, Dbscan};
1+
use crate::{AppxDbscan, AppxDbscanParams, AppxDbscanParamsError, AppxDbscanValidParams, Dbscan};
22
use linfa::traits::Transformer;
33
use linfa::ParamGuard;
44
use linfa_datasets::generate;
5+
use linfa_nn::distance::L2Dist;
56
use ndarray::{arr1, arr2, concatenate, s, Array1, Array2};
67
use ndarray_rand::rand::SeedableRng;
78
use ndarray_rand::rand_distr::Uniform;
89
use rand_xoshiro::Xoshiro256Plus;
910
use std::collections::HashMap;
1011

12+
#[test]
13+
fn autotraits() {
14+
fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
15+
has_autotraits::<AppxDbscan>();
16+
has_autotraits::<Dbscan>();
17+
has_autotraits::<AppxDbscanValidParams<f64, L2Dist>>();
18+
has_autotraits::<AppxDbscanParams<f64, L2Dist>>();
19+
}
20+
1121
#[test]
1222
fn appx_dbscan_parity() {
1323
let mut rng = Xoshiro256Plus::seed_from_u64(40);

algorithms/linfa-clustering/src/dbscan/algorithm.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use std::collections::VecDeque;
99
use linfa::Float;
1010
use linfa::{traits::Transformer, DatasetBase};
1111

12-
#[derive(Clone, Debug, PartialEq)]
12+
#[derive(Clone, Debug, PartialEq, Eq)]
1313
/// DBSCAN (Density-based Spatial Clustering of Applications with Noise)
1414
/// clusters together points which are close together with enough neighbors
1515
/// labelled points which are sparsely neighbored as noise. As points may be

algorithms/linfa-clustering/src/dbscan/hyperparams.rs

+11-3
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use thiserror::Error;
99
derive(Serialize, Deserialize),
1010
serde(crate = "serde_crate")
1111
)]
12-
#[derive(Debug)]
12+
#[derive(Debug, Clone, PartialEq)]
1313
/// The set of hyperparameters that can be specified for the execution of
1414
/// the [DBSCAN algorithm](struct.Dbscan.html).
1515
pub struct DbscanValidParams<F: Float, D: Distance<F>, N: NearestNeighbour> {
@@ -19,7 +19,7 @@ pub struct DbscanValidParams<F: Float, D: Distance<F>, N: NearestNeighbour> {
1919
pub(crate) nn_algo: N,
2020
}
2121

22-
#[derive(Debug)]
22+
#[derive(Debug, Clone, PartialEq)]
2323
/// Helper struct for building a set of [DBSCAN hyperparameters](struct.DbscanParams.html)
2424
pub struct DbscanParams<F: Float, D: Distance<F>, N: NearestNeighbour>(DbscanValidParams<F, D, N>);
2525

@@ -106,10 +106,18 @@ impl<F: Float, D: Distance<F>, N: NearestNeighbour> DbscanValidParams<F, D, N> {
106106

107107
#[cfg(test)]
108108
mod tests {
109-
use linfa_nn::{distance::L2Dist, CommonNearestNeighbour};
109+
use linfa_nn::{distance::L2Dist, CommonNearestNeighbour, KdTree};
110110

111111
use super::*;
112112

113+
#[test]
114+
fn autotraits() {
115+
fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
116+
has_autotraits::<DbscanParamsError>();
117+
has_autotraits::<DbscanParams<f64, L2Dist, KdTree>>();
118+
has_autotraits::<DbscanValidParams<f64, L2Dist, KdTree>>();
119+
}
120+
113121
#[test]
114122
fn tolerance_cannot_be_zero() {
115123
let res = DbscanParams::new(2, L2Dist, CommonNearestNeighbour::KdTree)

algorithms/linfa-clustering/src/gaussian_mixture/algorithm.rs

+11
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,17 @@ mod tests {
495495
use ndarray_rand::rand::SeedableRng;
496496
use ndarray_rand::rand_distr::{Distribution, StandardNormal};
497497

498+
#[test]
499+
fn autotraits() {
500+
fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
501+
has_autotraits::<GaussianMixtureModel<f64>>();
502+
has_autotraits::<GmmError>();
503+
has_autotraits::<GmmParams<f64, Xoshiro256Plus>>();
504+
has_autotraits::<GmmValidParams<f64, Xoshiro256Plus>>();
505+
has_autotraits::<GmmInitMethod>();
506+
has_autotraits::<GmmCovarType>();
507+
}
508+
498509
pub struct MultivariateNormal {
499510
pub mean: Array1<f64>,
500511
pub covariance: Array2<f64>,

algorithms/linfa-clustering/src/gaussian_mixture/hyperparams.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use linfa::{Float, ParamGuard};
1111
derive(Serialize, Deserialize),
1212
serde(crate = "serde_crate")
1313
)]
14-
#[derive(Clone, Copy, Debug, PartialEq)]
14+
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
1515
/// A specifier for the type of the relation between components' covariances.
1616
pub enum GmmCovarType {
1717
/// each component has its own general covariance matrix
@@ -23,7 +23,7 @@ pub enum GmmCovarType {
2323
derive(Serialize, Deserialize),
2424
serde(crate = "serde_crate")
2525
)]
26-
#[derive(Clone, Copy, Debug)]
26+
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
2727
/// A specifier for the method used for the initialization of the fitting algorithm of GMM
2828
pub enum GmmInitMethod {
2929
/// GMM fitting algorithm is initialized with the result of the [KMeans](struct.KMeans.html) clustering.
@@ -37,7 +37,7 @@ pub enum GmmInitMethod {
3737
derive(Serialize, Deserialize),
3838
serde(crate = "serde_crate")
3939
)]
40-
#[derive(Clone, Debug)]
40+
#[derive(Clone, Debug, PartialEq)]
4141
/// The set of hyperparameters that can be specified for the execution of
4242
/// the [GMM algorithm](struct.GaussianMixtureModel.html).
4343
pub struct GmmValidParams<F: Float, R: Rng> {
@@ -90,7 +90,7 @@ impl<F: Float, R: Rng + Clone> GmmValidParams<F, R> {
9090
derive(Serialize, Deserialize),
9191
serde(crate = "serde_crate")
9292
)]
93-
#[derive(Clone, Debug)]
93+
#[derive(Clone, Debug, PartialEq)]
9494
/// The set of hyperparameters that can be specified for the execution of
9595
/// the [GMM algorithm](struct.GaussianMixtureModel.html).
9696
pub struct GmmParams<F: Float, R: Rng>(GmmValidParams<F, R>);

algorithms/linfa-clustering/src/k_means/algorithm.rs

+10
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,7 @@ pub(crate) fn closest_centroid<F: Float, D: Distance<F>>(
587587
mod tests {
588588
use super::super::KMeansInit;
589589
use super::*;
590+
use crate::KMeansParamsError;
590591
use approx::assert_abs_diff_eq;
591592
use linfa_nn::distance::L1Dist;
592593
use ndarray::{array, concatenate, Array, Array1, Array2, Axis};
@@ -595,6 +596,15 @@ mod tests {
595596
use ndarray_rand::rand_distr::Uniform;
596597
use ndarray_rand::RandomExt;
597598

599+
#[test]
600+
fn autotraits() {
601+
fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
602+
has_autotraits::<KMeans<f64, L2Dist>>();
603+
has_autotraits::<KMeansParamsError>();
604+
has_autotraits::<KMeansError>();
605+
has_autotraits::<IncrKMeansError<String>>();
606+
}
607+
598608
fn function_test_1d(x: &Array2<f64>) -> Array2<f64> {
599609
let mut y = Array2::zeros(x.dim());
600610
Zip::from(&mut y).and(x).for_each(|yi, &xi| {

0 commit comments

Comments
 (0)