Skip to content

Commit 21357e2

Browse files
YuhanLiinPABannier
andauthored
Adding Multi-Task ElasticNet support (rust-ml#238)
* added block coordinate descent function * added duality_gap_mtl computation * ENH cd pass to be consistent with bcd * added prox operator for MTL Enet * added helper functions for tests * working ent mtl penalties * bcd lower objective test pass * added MultiTaskEnet struct * added MTENET documentation * added API MTENET * added variance, z-score, conf interval for multitask ENET * added multi-task estimators * added tests for MTL * added tests for Enet and MTL * WIP: made variance params generic over the number of tasks * added z_score and confidence_95th for MTL * WIP make compute_variance generic over the dimension * Replace for loops in block_coordinate_descent with general_mat_mul calls * Bring back generic compute_intercept * Replace manual norm calculations with norm trait calls * Add docs and derives to multi task types * Add example for multitask_elasticnet * Rename shape() calls to nrows and ncols Co-authored-by: Pierre-Antoine Bannier <[email protected]>
1 parent 44b244c commit 21357e2

File tree

7 files changed

+754
-67
lines changed

7 files changed

+754
-67
lines changed

algorithms/linfa-elasticnet/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,6 @@ thiserror = "1.0"
4040
linfa = { version = "0.6.0", path = "../.." }
4141

4242
[dev-dependencies]
43-
linfa-datasets = { version = "0.6.0", path = "../../datasets", features = ["diabetes"] }
43+
linfa-datasets = { version = "0.6.0", path = "../../datasets", features = ["diabetes", "linnerud"] }
4444
ndarray-rand = "0.14"
4545
rand_xoshiro = "0.6"

algorithms/linfa-elasticnet/examples/elasticnet.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ fn main() -> Result<()> {
55
// load Diabetes dataset
66
let (train, valid) = linfa_datasets::diabetes().split_with_ratio(0.90);
77

8-
// train pure LASSO model with 0.1 penalty
8+
// train pure LASSO model with 0.3 penalty
99
let model = ElasticNet::params()
1010
.penalty(0.3)
1111
.l1_ratio(1.0)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
use linfa::prelude::*;
2+
use linfa_elasticnet::{MultiTaskElasticNet, Result};
3+
4+
fn main() -> Result<()> {
5+
// load Diabetes dataset
6+
let (train, valid) = linfa_datasets::linnerud().split_with_ratio(0.80);
7+
8+
// train pure LASSO model with 0.1 penalty
9+
let model = MultiTaskElasticNet::params()
10+
.penalty(0.1)
11+
.l1_ratio(1.0)
12+
.fit(&train)?;
13+
14+
println!("intercept: {}", model.intercept());
15+
println!("params: {}", model.hyperplane());
16+
17+
println!("z score: {:?}", model.z_score());
18+
19+
// validate
20+
let y_est = model.predict(&valid);
21+
println!("predicted variance: {}", y_est.r2(&valid)?);
22+
23+
Ok(())
24+
}

0 commit comments

Comments
 (0)