Skip to content

Commit 685bd4d

Browse files
committed
deviation: Implement sq_l2_dist
1 parent c055019 commit 685bd4d

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

src/deviation.rs

+32
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
use ndarray::{ArrayBase, Data, Dimension, Zip};
2+
use num_traits::sign::Signed;
3+
use std::ops::AddAssign;
24

35
/// Extension trait for `ArrayBase` providing functions
46
/// to compute different deviation measures.
@@ -14,6 +16,10 @@ where
1416
fn count_neq(&self, other: &ArrayBase<S, D>) -> usize
1517
where
1618
A: PartialEq;
19+
20+
fn sq_l2_dist(&self, other: &ArrayBase<S, D>) -> A
21+
where
22+
A: AddAssign + Copy + Signed;
1723
}
1824

1925
impl<A, S, D> DeviationExt<A, S, D> for ArrayBase<S, D>
@@ -42,6 +48,19 @@ where
4248
{
4349
self.len() - self.count_eq(other)
4450
}
51+
52+
fn sq_l2_dist(&self, other: &ArrayBase<S, D>) -> A
53+
where
54+
A: AddAssign + Copy + Signed,
55+
{
56+
let mut r = A::zero();
57+
58+
Zip::from(self).and(other).apply(|&a, &b| {
59+
r += (a - b).abs();
60+
});
61+
62+
r
63+
}
4564
}
4665

4766
#[cfg(test)]
@@ -74,4 +93,17 @@ mod tests {
7493
assert_eq!(b.count_neq(&c), 7);
7594
assert_eq!(d.count_neq(&e), 2);
7695
}
96+
97+
#[test]
98+
fn test_sq_l2_dist() {
99+
let a = array![1., 2., 3., 4., 5., 6., 7.];
100+
let b = array![1., 3., 3., 4., 6., 7., 8.];
101+
102+
assert_eq!(a.sq_l2_dist(&b), (&a - &b).mapv(f64::abs).sum());
103+
104+
let a = array![[1, 2], [3, 4], [5, 6]];
105+
let b = array![[1, 3], [3, 4], [6, 7]];
106+
107+
assert_eq!(a.sq_l2_dist(&b), (&a - &b).mapv(i32::abs).sum());
108+
}
77109
}

0 commit comments

Comments
 (0)