Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Linear regression example #166

Closed
wants to merge 14 commits from the contributor's branch into the base branch
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,6 @@ optional = true

[dev-dependencies]
paste = "0.1"
ndarray-stats = {git = "https://github.com/rust-ndarray/ndarray-stats", branch = "master"}
ndarray-rand = "0.9"
rand = "0.6"
99 changes: 99 additions & 0 deletions examples/linear_regression/linear_regression.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
#![allow(non_snake_case)]
use ndarray::{stack, Array, Array1, ArrayBase, Axis, Data, Ix1, Ix2};
use ndarray_linalg::Solve;

/// The simple linear regression model is
/// y = bX + e where e ~ N(0, sigma^2 * I)
/// In probabilistic terms this corresponds to
/// y - bX ~ N(0, sigma^2 * I)
/// y | X, b ~ N(bX, sigma^2 * I)
/// The loss for the model is simply the squared error between the model
/// predictions and the true values:
/// Loss = ||y - bX||^2
/// The maximum likelihood estimation for the model parameters `beta` can be computed
/// in closed form via the normal equation:
/// b = (X^T X)^{-1} X^T y
/// where (X^T X)^{-1} X^T is known as the pseudoinverse or Moore-Penrose inverse.
///
/// Adapted from: https://github.com/ddbourgin/numpy-ml
pub struct LinearRegression {
    /// Fitted coefficient vector; `None` until `fit` has been called.
    /// When `fit_intercept` is true, the first entry is the intercept weight.
    pub beta: Option<Array1<f64>>,
    /// If true, a constant column of ones is prepended to `X` during both
    /// fitting and prediction so the model learns an intercept term.
    fit_intercept: bool,
}

impl LinearRegression {
pub fn new(fit_intercept: bool) -> LinearRegression {
LinearRegression {
beta: None,
fit_intercept,
}
}

/// Given:
/// - an input matrix `X`, with shape `(n_samples, n_features)`;
/// - a target variable `y`, with shape `(n_samples,)`;
/// `fit` tunes the `beta` parameter of the linear regression model
/// to match the training data distribution.
///
/// `self` is modified in place, nothing is returned.
pub fn fit<A, B>(&mut self, X: ArrayBase<A, Ix2>, y: ArrayBase<B, Ix1>)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

x and y should be refence x: &ArrayBase<A, Ix2>

where
A: Data<Elem = f64>,
B: Data<Elem = f64>,
{
let (n_samples, _) = X.dim();

// Check that our inputs have compatible shapes
assert_eq!(y.dim(), n_samples);

// If we are fitting the intercept, we need an additional column
self.beta = if self.fit_intercept {
let dummy_column: Array<f64, _> = Array::ones((n_samples, 1));
let X = stack(Axis(1), &[dummy_column.view(), X.view()]).unwrap();
Some(LinearRegression::solve_normal_equation(X, y))
} else {
Some(LinearRegression::solve_normal_equation(X, y))
};
}

/// Given an input matrix `X`, with shape `(n_samples, n_features)`,
/// `predict` returns the target variable according to linear model
/// learned from the training data distribution.
///
/// **Panics** if `self` has not be `fit`ted before calling `predict.
pub fn predict<A>(&self, X: &ArrayBase<A, Ix2>) -> Array1<f64>
where
A: Data<Elem = f64>,
{
let (n_samples, _) = X.dim();

// If we are fitting the intercept, we need an additional column
if self.fit_intercept {
let dummy_column: Array<f64, _> = Array::ones((n_samples, 1));
let X = stack(Axis(1), &[dummy_column.view(), X.view()]).unwrap();
self._predict(&X)
} else {
self._predict(X)
}
}

fn solve_normal_equation<A, B>(X: ArrayBase<A, Ix2>, y: ArrayBase<B, Ix1>) -> Array1<f64>
where
A: Data<Elem = f64>,
B: Data<Elem = f64>,
{
let rhs = X.t().dot(&y);
let linear_operator = X.t().dot(&X);
linear_operator.solve_into(rhs).unwrap()
}

fn _predict<A>(&self, X: &ArrayBase<A, Ix2>) -> Array1<f64>
where
A: Data<Elem = f64>,
{
match &self.beta {
None => panic!("The linear regression estimator has to be fitted first!"),
Some(beta) => X.dot(beta),
}
}
}
49 changes: 49 additions & 0 deletions examples/linear_regression/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#![allow(non_snake_case)]
use ndarray::{Array1, Array2, Array, Axis};
use ndarray_linalg::random;
use ndarray_stats::DeviationExt;
use ndarray_rand::RandomExt;
use rand::distributions::StandardNormal;

// Import LinearRegression from other file ("module") in this example
mod linear_regression;
use linear_regression::LinearRegression;

/// Generates a synthetic regression dataset.
///
/// Returns a tuple of input data and the associated target variable;
/// the target is a linear function of the input, perturbed by gaussian noise.
fn get_data(n_samples: usize, n_features: usize) -> (Array2<f64>, Array1<f64>) {
    // Per-sample gaussian perturbation added to the linear signal.
    let gaussian_noise: Array1<f64> = Array::random(n_samples, StandardNormal);

    // Ground-truth coefficients, drawn uniformly at random and scaled up.
    let true_beta: Array1<f64> = random(n_features) * 10.;
    println!("Beta used to generate target variable: {:.3}", true_beta);

    let inputs: Array2<f64> = random((n_samples, n_features));
    let targets: Array1<f64> = inputs.dot(&true_beta) + gaussian_noise;
    (inputs, targets)
}

/// Fits a linear regression on synthetic data and reports the estimated
/// coefficients together with the held-out mean squared error.
pub fn main() {
    // Problem size: 5000 training rows, 1000 held-out rows, 3 features.
    let n_train_samples = 5000;
    let n_test_samples = 1000;
    let n_features = 3;

    // Draw one dataset and split it row-wise into train / test partitions.
    let (X, y) = get_data(n_train_samples + n_test_samples, n_features);
    let (X_train, X_test) = X.view().split_at(Axis(0), n_train_samples);
    let (y_train, y_test) = y.view().split_at(Axis(0), n_train_samples);

    // Fit without an intercept: get_data builds targets with no bias term.
    let mut model = LinearRegression::new(false);
    model.fit(X_train, y_train);

    // Evaluate on the held-out slice.
    let test_predictions = model.predict(&X_test);
    let test_mse = test_predictions.mean_sq_err(&y_test.to_owned()).unwrap();
    println!(
        "Beta estimated from the training data: {:.3}",
        model.beta.unwrap()
    );
    println!(
        "The fitted regressor has a mean squared error of {:.3}",
        test_mse
    );
}