|
1 | 1 | use ndarray::{ArrayBase, Data, Dimension, Zip};
|
| 2 | +use num_traits::sign::Signed; |
| 3 | +use std::ops::AddAssign; |
2 | 4 |
|
3 | 5 | /// Extension trait for `ArrayBase` providing functions
|
4 | 6 | /// to compute different deviation measures.
|
|
14 | 16 | fn count_neq(&self, other: &ArrayBase<S, D>) -> usize
|
15 | 17 | where
|
16 | 18 | A: PartialEq;
|
| 19 | + |
| 20 | + fn sq_l2_dist(&self, other: &ArrayBase<S, D>) -> A |
| 21 | + where |
| 22 | + A: AddAssign + Copy + Signed; |
17 | 23 | }
|
18 | 24 |
|
19 | 25 | impl<A, S, D> DeviationExt<A, S, D> for ArrayBase<S, D>
|
|
42 | 48 | {
|
43 | 49 | self.len() - self.count_eq(other)
|
44 | 50 | }
|
| 51 | + |
| 52 | + fn sq_l2_dist(&self, other: &ArrayBase<S, D>) -> A |
| 53 | + where |
| 54 | + A: AddAssign + Copy + Signed, |
| 55 | + { |
| 56 | + let mut r = A::zero(); |
| 57 | + |
| 58 | + Zip::from(self).and(other).apply(|&a, &b| { |
| 59 | + r += (a - b).abs(); |
| 60 | + }); |
| 61 | + |
| 62 | + r |
| 63 | + } |
45 | 64 | }
|
46 | 65 |
|
47 | 66 | #[cfg(test)]
|
@@ -74,4 +93,17 @@ mod tests {
|
74 | 93 | assert_eq!(b.count_neq(&c), 7);
|
75 | 94 | assert_eq!(d.count_neq(&e), 2);
|
76 | 95 | }
|
| 96 | + |
| 97 | + #[test] |
| 98 | + fn test_sq_l2_dist() { |
| 99 | + let a = array![1., 2., 3., 4., 5., 6., 7.]; |
| 100 | + let b = array![1., 3., 3., 4., 6., 7., 8.]; |
| 101 | + |
| 102 | + assert_eq!(a.sq_l2_dist(&b), (&a - &b).mapv(f64::abs).sum()); |
| 103 | + |
| 104 | + let a = array![[1, 2], [3, 4], [5, 6]]; |
| 105 | + let b = array![[1, 3], [3, 4], [6, 7]]; |
| 106 | + |
| 107 | + assert_eq!(a.sq_l2_dist(&b), (&a - &b).mapv(i32::abs).sum()); |
| 108 | + } |
77 | 109 | }
|
0 commit comments