#![allow(non_snake_case)]
use ndarray::{stack, Array, Array1, Array2, ArrayBase, Axis, Data, Ix1, Ix2};
use ndarray_linalg::{random, Solve};
use ndarray_rand::RandomExt;
use ndarray_stats::DeviationExt;
use rand::distributions::StandardNormal;

/// The simple linear regression model is
///     y = Xb + e    where e ~ N(0, sigma^2 * I)
/// In probabilistic terms this corresponds to
///     y - Xb ~ N(0, sigma^2 * I)
///     y | X, b ~ N(Xb, sigma^2 * I)
/// The loss for the model is the squared error between the model
/// predictions and the true values:
///     Loss = ||y - Xb||^2
/// The maximum likelihood estimate of the model parameters `beta` can be
/// computed in closed form via the normal equation:
///     b = (X^T X)^{-1} X^T y
/// where (X^T X)^{-1} X^T is known as the pseudoinverse, or Moore-Penrose inverse.
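///
/// The normal equation follows from setting the gradient of the loss to zero:
///     d/db ||y - Xb||^2 = -2 X^T (y - Xb) = 0   =>   (X^T X) b = X^T y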
///
/// Adapted from: https://github.com/xinscrs/numpy-ml
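///
/// A minimal usage sketch (`ignore`d by doctests; it assumes two-dimensional
/// arrays `x_train`, `x_test` and a one-dimensional `y_train` are in scope):
///
/// ```ignore
/// let mut model = LinearRegression::new(true);
/// model.fit(x_train, y_train);
/// let predictions = model.predict(&x_test);
/// ```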
struct LinearRegression {
    /// The fitted coefficients; `None` until `fit` has been called.
    pub beta: Option<Array1<f64>>,
    /// Whether to prepend a column of ones to `X` to model an intercept.
    fit_intercept: bool,
}

impl LinearRegression {
fn new(fit_intercept: bool) -> LinearRegression {
LinearRegression {
beta: None,
fit_intercept,
}
}
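
    /// Computes the maximum likelihood estimate of `beta` from the training
    /// data and stores it on the model.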
fn fit<A, B>(&mut self, X: ArrayBase<A, Ix2>, y: ArrayBase<B, Ix1>)
where
A: Data<Elem = f64>,
B: Data<Elem = f64>,
{
let (n_samples, _) = X.dim();
// Check that our inputs have compatible shapes
assert_eq!(y.dim(), n_samples);
        // If we are fitting the intercept, prepend a column of ones so that
        // the first coefficient acts as the intercept
self.beta = if self.fit_intercept {
let dummy_column: Array<f64, _> = Array::ones((n_samples, 1));
let X = stack(Axis(1), &[dummy_column.view(), X.view()]).unwrap();
Some(LinearRegression::solve_normal_equation(X, y))
} else {
Some(LinearRegression::solve_normal_equation(X, y))
};
}
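
    /// Solves the normal equation `(X^T X) b = X^T y` for `b`. `solve_into`
    /// factorizes `X^T X` rather than explicitly inverting it, which is
    /// cheaper and numerically more stable.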
fn solve_normal_equation<A, B>(X: ArrayBase<A, Ix2>, y: ArrayBase<B, Ix1>) -> Array1<f64>
where
A: Data<Elem = f64>,
B: Data<Elem = f64>,
{
let rhs = X.t().dot(&y);
let linear_operator = X.t().dot(&X);
linear_operator.solve_into(rhs).unwrap()
}
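
    /// Predicts targets for `X` with the fitted coefficients, panicking if
    /// the model has not been fitted yet.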
fn predict<A>(&self, X: &ArrayBase<A, Ix2>) -> Array1<f64>
where
A: Data<Elem = f64>,
{
let (n_samples, _) = X.dim();
        // If the model was fitted with an intercept, prepend the same column
        // of ones that was used during fitting
if self.fit_intercept {
let dummy_column: Array<f64, _> = Array::ones((n_samples, 1));
let X = stack(Axis(1), &[dummy_column.view(), X.view()]).unwrap();
self._predict(&X)
} else {
self._predict(X)
}
}
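
    /// Shared prediction path; assumes the intercept column, if any, has
    /// already been prepended to `X`.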
fn _predict<A>(&self, X: &ArrayBase<A, Ix2>) -> Array1<f64>
where
A: Data<Elem = f64>,
{
match &self.beta {
None => panic!("The linear regression estimator has to be fitted first!"),
Some(beta) => X.dot(beta),
}
}
}
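
/// Generates a synthetic regression problem: a random design matrix `X`,
/// a random coefficient vector `beta` (scaled by 10), and noisy targets
/// `y = X.dot(&beta) + noise` with standard normal noise.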
fn get_data(n_samples: usize, n_features: usize) -> (Array2<f64>, Array1<f64>) {
let shape = (n_samples, n_features);
let noise: Array1<f64> = Array::random(n_samples, StandardNormal);
let beta: Array1<f64> = random(n_features) * 10.;
println!("Beta used to generate target variable: {:.3}", beta);
let X: Array2<f64> = random(shape);
let y: Array1<f64> = X.dot(&beta) + noise;
(X, y)
}

pub fn main() {
let n_train_samples = 5000;
let n_test_samples = 1000;
let n_features = 3;
let (X, y) = get_data(n_train_samples + n_test_samples, n_features);
let (X_train, X_test) = X.view().split_at(Axis(0), n_train_samples);
let (y_train, y_test) = y.view().split_at(Axis(0), n_train_samples);
    // The synthetic data is generated without an intercept term,
    // so we fit without one as well
    let mut linear_regressor = LinearRegression::new(false);
linear_regressor.fit(X_train, y_train);
let test_predictions = linear_regressor.predict(&X_test);
    // `mean_sq_err` returns the mean squared error; the square root is taken
    // below when reporting the RMSE
    let mean_squared_error = test_predictions.mean_sq_err(&y_test.to_owned()).unwrap();
println!(
"Beta estimated from the training data: {:.3}",
linear_regressor.beta.unwrap()
);
println!(
"The fitted regressor has a root mean squared error of {:.3}",
        mean_squared_error.sqrt()
);
}