diff --git a/Cargo.toml b/Cargo.toml
index 9257b5dd..b3b5c9d7 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -49,3 +49,6 @@ optional = true
 
 [dev-dependencies]
 paste = "0.1"
+ndarray-stats = {git = "https://github.com/rust-ndarray/ndarray-stats", branch = "master"}
+ndarray-rand = "0.9"
+rand = "0.6"
diff --git a/examples/linear_regression/linear_regression.rs b/examples/linear_regression/linear_regression.rs
new file mode 100644
index 00000000..f11956be
--- /dev/null
+++ b/examples/linear_regression/linear_regression.rs
@@ -0,0 +1,99 @@
+#![allow(non_snake_case)]
+use ndarray::{stack, Array, Array1, ArrayBase, Axis, Data, Ix1, Ix2};
+use ndarray_linalg::Solve;
+
+/// The simple linear regression model is
+///     y = bX + e  where e ~ N(0, sigma^2 * I)
+/// In probabilistic terms this corresponds to
+///     y - bX ~ N(0, sigma^2 * I)
+///     y | X, b ~ N(bX, sigma^2 * I)
+/// The loss for the model is simply the squared error between the model
+/// predictions and the true values:
+///     Loss = ||y - bX||^2
+/// The maximum likelihood estimation for the model parameters `beta` can be computed
+/// in closed form via the normal equation:
+///     b = (X^T X)^{-1} X^T y
+/// where (X^T X)^{-1} X^T is known as the pseudoinverse or Moore-Penrose inverse.
+///
+/// Adapted from: https://github.com/ddbourgin/numpy-ml
+pub struct LinearRegression {
+    pub beta: Option<Array1<f64>>,
+    fit_intercept: bool,
+}
+
+impl LinearRegression {
+    pub fn new(fit_intercept: bool) -> LinearRegression {
+        LinearRegression {
+            beta: None,
+            fit_intercept,
+        }
+    }
+
+    /// Given:
+    /// - an input matrix `X`, with shape `(n_samples, n_features)`;
+    /// - a target variable `y`, with shape `(n_samples,)`;
+    /// `fit` tunes the `beta` parameter of the linear regression model
+    /// to match the training data distribution.
+    ///
+    /// `self` is modified in place, nothing is returned.
+    pub fn fit<A, B>(&mut self, X: ArrayBase<A, Ix2>, y: ArrayBase<B, Ix1>)
+    where
+        A: Data<Elem = f64>,
+        B: Data<Elem = f64>,
+    {
+        let (n_samples, _) = X.dim();
+
+        // Check that our inputs have compatible shapes
+        assert_eq!(y.dim(), n_samples);
+
+        // If we are fitting the intercept, we need an additional column
+        self.beta = if self.fit_intercept {
+            let dummy_column: Array<f64, _> = Array::ones((n_samples, 1));
+            let X = stack(Axis(1), &[dummy_column.view(), X.view()]).unwrap();
+            Some(LinearRegression::solve_normal_equation(X, y))
+        } else {
+            Some(LinearRegression::solve_normal_equation(X, y))
+        };
+    }
+
+    /// Given an input matrix `X`, with shape `(n_samples, n_features)`,
+    /// `predict` returns the target variable according to linear model
+    /// learned from the training data distribution.
+    ///
+    /// **Panics** if `self` has not been `fit`ted before calling `predict`.
+    pub fn predict<A>(&self, X: &ArrayBase<A, Ix2>) -> Array1<f64>
+    where
+        A: Data<Elem = f64>,
+    {
+        let (n_samples, _) = X.dim();
+
+        // If we are fitting the intercept, we need an additional column
+        if self.fit_intercept {
+            let dummy_column: Array<f64, _> = Array::ones((n_samples, 1));
+            let X = stack(Axis(1), &[dummy_column.view(), X.view()]).unwrap();
+            self._predict(&X)
+        } else {
+            self._predict(X)
+        }
+    }
+
+    fn solve_normal_equation<A, B>(X: ArrayBase<A, Ix2>, y: ArrayBase<B, Ix1>) -> Array1<f64>
+    where
+        A: Data<Elem = f64>,
+        B: Data<Elem = f64>,
+    {
+        let rhs = X.t().dot(&y);
+        let linear_operator = X.t().dot(&X);
+        linear_operator.solve_into(rhs).unwrap()
+    }
+
+    fn _predict<A>(&self, X: &ArrayBase<A, Ix2>) -> Array1<f64>
+    where
+        A: Data<Elem = f64>,
+    {
+        match &self.beta {
+            None => panic!("The linear regression estimator has to be fitted first!"),
+            Some(beta) => X.dot(beta),
+        }
+    }
+}
diff --git a/examples/linear_regression/main.rs b/examples/linear_regression/main.rs
new file mode 100644
index 00000000..4e76edb5
--- /dev/null
+++ b/examples/linear_regression/main.rs
@@ -0,0 +1,49 @@
+#![allow(non_snake_case)]
+use ndarray::{Array1, Array2, Array, Axis};
+use ndarray_linalg::random;
+use ndarray_stats::DeviationExt;
+use ndarray_rand::RandomExt;
+use rand::distributions::StandardNormal;
+
+// Import LinearRegression from other file ("module") in this example
+mod linear_regression;
+use linear_regression::LinearRegression;
+
+/// It returns a tuple: input data and the associated target variable.
+///
+/// The target variable is a linear function of the input, perturbed by gaussian noise.
+fn get_data(n_samples: usize, n_features: usize) -> (Array2<f64>, Array1<f64>) {
+    let shape = (n_samples, n_features);
+    let noise: Array1<f64> = Array::random(n_samples, StandardNormal);
+
+    let beta: Array1<f64> = random(n_features) * 10.;
+    println!("Beta used to generate target variable: {:.3}", beta);
+
+    let X: Array2<f64> = random(shape);
+    let y: Array1<f64> = X.dot(&beta) + noise;
+    (X, y)
+}
+
+pub fn main() {
+    let n_train_samples = 5000;
+    let n_test_samples = 1000;
+    let n_features = 3;
+
+    let (X, y) = get_data(n_train_samples + n_test_samples, n_features);
+    let (X_train, X_test) = X.view().split_at(Axis(0), n_train_samples);
+    let (y_train, y_test) = y.view().split_at(Axis(0), n_train_samples);
+
+    let mut linear_regressor = LinearRegression::new(false);
+    linear_regressor.fit(X_train, y_train);
+
+    let test_predictions = linear_regressor.predict(&X_test);
+    let mean_squared_error = test_predictions.mean_sq_err(&y_test.to_owned()).unwrap();
+    println!(
+        "Beta estimated from the training data: {:.3}",
+        linear_regressor.beta.unwrap()
+    );
+    println!(
+        "The fitted regressor has a mean squared error of {:.3}",
+        mean_squared_error
+    );
+}