Skip to content

Commit 7c161b6

Browse files
committed
Port deviation functions from StatsBase.jl
1 parent 6722934 commit 7c161b6

File tree

2 files changed

+388
-0
lines changed

2 files changed

+388
-0
lines changed

src/deviation.rs

+386
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,386 @@
1+
use ndarray::{ArrayBase, Data, Dimension, Zip};
2+
use num_traits::{Float, Signed, ToPrimitive};
3+
use std::convert::Into;
4+
use std::ops::AddAssign;
5+
6+
/// Extension trait for `ArrayBase` providing functions
7+
/// to compute different deviation measures.
8+
pub trait DeviationExt<A, S, D>
9+
where
10+
S: Data<Elem = A>,
11+
D: Dimension,
12+
{
13+
fn count_eq(&self, other: &ArrayBase<S, D>) -> usize
14+
where
15+
A: PartialEq;
16+
17+
fn count_neq(&self, other: &ArrayBase<S, D>) -> usize
18+
where
19+
A: PartialEq;
20+
21+
fn sq_l2_dist(&self, other: &ArrayBase<S, D>) -> A
22+
where
23+
A: AddAssign + Clone + Signed;
24+
25+
fn l2_dist(&self, other: &ArrayBase<S, D>) -> f64
26+
where
27+
A: AddAssign + Clone + Signed + ToPrimitive;
28+
29+
fn l1_dist(&self, other: &ArrayBase<S, D>) -> A
30+
where
31+
A: AddAssign + Clone + Signed;
32+
33+
fn linf_dist(&self, other: &ArrayBase<S, D>) -> A
34+
where
35+
A: Clone + PartialOrd + Signed;
36+
37+
fn gkl_div(&self, other: &ArrayBase<S, D>) -> A
38+
where
39+
A: AddAssign + Clone + Float;
40+
41+
fn mean_abs_dev(&self, other: &ArrayBase<S, D>) -> f64
42+
where
43+
A: AddAssign + Clone + Signed + Into<f64>;
44+
45+
fn max_abs_dev(&self, other: &ArrayBase<S, D>) -> A
46+
where
47+
A: Clone + PartialOrd + Signed;
48+
49+
fn mean_sq_dev(&self, other: &ArrayBase<S, D>) -> f64
50+
where
51+
A: AddAssign + Clone + Signed + Into<f64>;
52+
53+
fn root_mean_sq_dev(&self, other: &ArrayBase<S, D>) -> f64
54+
where
55+
A: AddAssign + Clone + Signed + Into<f64>;
56+
57+
fn peak_signal_to_noise_ratio(&self, other: &ArrayBase<S, D>, maxv: A) -> f64
58+
where
59+
A: AddAssign + Clone + Signed + Into<f64>;
60+
}
61+
62+
impl<A, S, D> DeviationExt<A, S, D> for ArrayBase<S, D>
63+
where
64+
S: Data<Elem = A>,
65+
D: Dimension,
66+
{
67+
fn count_eq(&self, other: &ArrayBase<S, D>) -> usize
68+
where
69+
A: PartialEq,
70+
{
71+
let mut c = 0;
72+
73+
Zip::from(self).and(other).apply(|a, b| {
74+
if a == b {
75+
c += 1;
76+
}
77+
});
78+
79+
c
80+
}
81+
82+
fn count_neq(&self, other: &ArrayBase<S, D>) -> usize
83+
where
84+
A: PartialEq,
85+
{
86+
self.len() - self.count_eq(other)
87+
}
88+
89+
fn sq_l2_dist(&self, other: &ArrayBase<S, D>) -> A
90+
where
91+
A: AddAssign + Clone + Signed,
92+
{
93+
let mut r = A::zero();
94+
95+
Zip::from(self).and(other).apply(|self_i, other_i| {
96+
let (a, b) = (self_i.clone(), other_i.clone());
97+
let abs_diff = (a - b).abs();
98+
r += abs_diff.clone() * abs_diff;
99+
});
100+
101+
r
102+
}
103+
104+
fn l2_dist(&self, other: &ArrayBase<S, D>) -> f64
105+
where
106+
A: AddAssign + Clone + Signed + ToPrimitive,
107+
{
108+
self.sq_l2_dist(other).to_f64().unwrap().sqrt()
109+
}
110+
111+
fn l1_dist(&self, other: &ArrayBase<S, D>) -> A
112+
where
113+
A: AddAssign + Clone + Signed,
114+
{
115+
let mut r = A::zero();
116+
117+
Zip::from(self).and(other).apply(|self_i, other_i| {
118+
let (a, b) = (self_i.clone(), other_i.clone());
119+
r += (a - b).abs();
120+
});
121+
122+
r
123+
}
124+
125+
fn linf_dist(&self, other: &ArrayBase<S, D>) -> A
126+
where
127+
A: Clone + PartialOrd + Signed,
128+
{
129+
let mut max = A::zero();
130+
131+
Zip::from(self).and(other).apply(|self_i, other_i| {
132+
let (a, b) = (self_i.clone(), other_i.clone());
133+
let diff = (a - b).abs();
134+
if diff > max {
135+
max = diff;
136+
}
137+
});
138+
139+
max
140+
}
141+
142+
fn gkl_div(&self, other: &ArrayBase<S, D>) -> A
143+
where
144+
A: AddAssign + Clone + Float,
145+
{
146+
let mut r = A::zero();
147+
148+
Zip::from(self).and(other).apply(|self_i, other_i| {
149+
let (a, b) = (self_i.clone(), other_i.clone());
150+
r += a * (a / b).ln() - a + b;
151+
});
152+
153+
r
154+
}
155+
156+
fn mean_abs_dev(&self, other: &ArrayBase<S, D>) -> f64
157+
where
158+
A: AddAssign + Clone + Signed + Into<f64>,
159+
{
160+
let a: f64 = self.l1_dist(other).into();
161+
let b = self.len().to_f64().unwrap();
162+
a / b
163+
}
164+
165+
#[inline]
166+
fn max_abs_dev(&self, other: &ArrayBase<S, D>) -> A
167+
where
168+
A: Clone + PartialOrd + Signed,
169+
{
170+
self.linf_dist(other)
171+
}
172+
173+
fn mean_sq_dev(&self, other: &ArrayBase<S, D>) -> f64
174+
where
175+
A: AddAssign + Clone + Signed + Into<f64>,
176+
{
177+
let a: f64 = self.sq_l2_dist(other).into();
178+
let b = self.len().to_f64().unwrap();
179+
a / b
180+
}
181+
182+
fn root_mean_sq_dev(&self, other: &ArrayBase<S, D>) -> f64
183+
where
184+
A: AddAssign + Clone + Signed + Into<f64>,
185+
{
186+
self.mean_sq_dev(other).sqrt()
187+
}
188+
189+
fn peak_signal_to_noise_ratio(&self, other: &ArrayBase<S, D>, maxv: A) -> f64
190+
where
191+
A: AddAssign + Clone + Signed + Into<f64>,
192+
{
193+
let maxv_f: f64 = maxv.into();
194+
10. * f64::log10(maxv_f * maxv_f / self.mean_sq_dev(&other))
195+
}
196+
}
197+
198+
/// Unit tests for the `DeviationExt` methods on float, integer and empty
/// arrays.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::*;
    use ndarray_rand::RandomExt;
    use rand::distributions::Uniform;
    use std::f64;

    #[test]
    fn test_count_eq() {
        let a = array![0., 0.];
        let b = array![1., 0.];
        let c = array![0., 1.];
        let d = array![1., 1.];

        assert_eq!(a.count_eq(&a), 2);
        assert_eq!(a.count_eq(&b), 1);
        assert_eq!(a.count_eq(&c), 1);
        assert_eq!(a.count_eq(&d), 0);
    }

    #[test]
    fn test_count_neq() {
        let a = array![0., 0.];
        let b = array![1., 0.];
        let c = array![0., 1.];
        let d = array![1., 1.];

        assert_eq!(a.count_neq(&a), 0);
        assert_eq!(a.count_neq(&b), 1);
        assert_eq!(a.count_neq(&c), 1);
        assert_eq!(a.count_neq(&d), 2);
    }

    #[test]
    fn test_sq_l2_dist() {
        let a = array![0., 1., 4., 2.];
        let b = array![1., 1., 2., 4.];

        assert_eq!(a.sq_l2_dist(&b), 9.);
    }

    #[test]
    fn test_l2_dist() {
        let a = array![0., 1., 4., 2.];
        let b = array![1., 1., 2., 4.];

        assert_eq!(a.l2_dist(&b), 3.);
    }

    #[test]
    fn test_l1_dist() {
        let a = array![0., 1., 4., 2.];
        let b = array![1., 1., 2., 4.];

        assert_eq!(a.l1_dist(&b), 5.);
    }

    #[test]
    fn test_linf_dist() {
        let a = array![0., 0.];
        let b = array![1., 0.];
        let c = array![1., 2.];

        assert_eq!(a.linf_dist(&a), 0.);

        assert_eq!(a.linf_dist(&b), 1.);
        assert_eq!(b.linf_dist(&a), 1.);

        assert_eq!(a.linf_dist(&c), 2.);
        assert_eq!(c.linf_dist(&a), 2.);
    }

    #[test]
    fn test_gkl_div() {
        // `a` and `b` are drawn from disjoint ranges, so they always differ.
        let a = Array::random((5,), Uniform::new(0., 1.));
        let b = Array::random((5,), Uniform::new(1., 2.));
        let c = Array::random((5,), Uniform::new(-1., 0.));

        assert_eq!(a.gkl_div(&a), 0.);
        assert!(a.gkl_div(&b) > 0.);
        assert!(b.gkl_div(&a) > 0.);
        assert_ne!(a.gkl_div(&b), b.gkl_div(&a));

        // Negative entries in `other` make the ratio a / b negative, so the
        // logarithm, and with it the whole sum, is NaN.
        // TODO: decide whether to mirror the sign check of the StatsBase.jl
        // impl for such inputs.
        assert!(f64::is_nan(a.gkl_div(&c)));
    }

    #[test]
    fn test_mean_abs_dev() {
        let a = array![1., 1.];
        let b = array![3., 5.];

        assert_eq!(a.mean_abs_dev(&a), 0.);
        assert_eq!(a.mean_abs_dev(&b), 3.);
        assert_eq!(b.mean_abs_dev(&a), 3.);
    }

    #[test]
    fn test_max_abs_dev() {
        // This is effectively an alias for linf_dist, so not retesting deeply
        let a = array![0., 0.];
        let b = array![2., 4.];

        assert_eq!(a.max_abs_dev(&a), 0.);
        assert_eq!(a.max_abs_dev(&b), 4.);
        assert_eq!(b.max_abs_dev(&a), 4.);
    }

    #[test]
    fn test_mean_sq_dev() {
        let a = array![1., 1.];
        let b = array![3., 5.];

        assert_eq!(a.mean_sq_dev(&a), 0.);
        assert_eq!(a.mean_sq_dev(&b), 10.);
        assert_eq!(b.mean_sq_dev(&a), 10.);
    }

    #[test]
    fn test_root_mean_sq_dev() {
        let a = array![1., 1.];
        let b = array![3., 5.];

        // An explicitly typed literal avoids calling a method on an
        // ambiguous `{float}`, which only resolves while a float trait
        // happens to be in scope.
        let expected = 10f64.sqrt();

        assert_eq!(a.root_mean_sq_dev(&a), 0.);
        assert_abs_diff_eq!(a.root_mean_sq_dev(&b), expected);
        assert_abs_diff_eq!(b.root_mean_sq_dev(&a), expected);
    }

    #[test]
    fn test_peak_signal_to_noise_ratio() {
        // Identical inputs: zero mean squared deviation, infinite ratio.
        let a = array![1., 1.];
        assert!(a.peak_signal_to_noise_ratio(&a, 1.).is_infinite());

        let a = array![1., 2., 3., 4., 5., 6., 7.];
        let b = array![1., 3., 3., 4., 6., 7., 8.];
        let maxv = 8.;
        // Same quantity written as 20*log10(maxv) - 10*log10(msd).
        let expected = 20. * Float::log10(maxv) - 10. * Float::log10(a.mean_sq_dev(&b));
        let actual = a.peak_signal_to_noise_ratio(&b, maxv);

        assert_abs_diff_eq!(actual, expected);
    }

    #[test]
    fn test_deviations_with_n_by_m_ints() {
        let a = array![[0, 1], [4, 2]];
        let b = array![[1, 1], [2, 4]];

        assert_eq!(a.count_eq(&a), 4);
        assert_eq!(a.count_neq(&a), 0);
        assert_eq!(a.sq_l2_dist(&b), 9);
        assert_eq!(a.l2_dist(&b), 3.);
        assert_eq!(a.l1_dist(&b), 5);
        assert_eq!(a.linf_dist(&b), 2);

        assert_abs_diff_eq!(a.mean_abs_dev(&b), 1.25);
        assert_eq!(a.max_abs_dev(&b), 2);
        assert_abs_diff_eq!(a.mean_sq_dev(&b), 2.25);
        assert_abs_diff_eq!(a.root_mean_sq_dev(&b), 1.5);
        assert_abs_diff_eq!(
            a.peak_signal_to_noise_ratio(&b, 4),
            8.519374645445623,
            epsilon = f64::EPSILON
        );

        // `gkl_div` requires `A: Float`, so it is not available for integer
        // arrays and is not exercised here.
    }

    #[test]
    fn test_deviations_with_empty_inputs() {
        let a: Array1<f64> = array![];

        assert_eq!(a.count_eq(&a), 0);
        assert_eq!(a.count_neq(&a), 0);
        assert_eq!(a.sq_l2_dist(&a), 0.);
        assert_eq!(a.l2_dist(&a), 0.);
        assert_eq!(a.l1_dist(&a), 0.);
        assert_eq!(a.linf_dist(&a), 0.);

        // The means divide by a length of zero and therefore yield NaN.
        assert!(a.mean_abs_dev(&a).is_nan());
        assert_eq!(a.max_abs_dev(&a), 0.);
        assert!(a.mean_sq_dev(&a).is_nan());
        assert!(a.root_mean_sq_dev(&a).is_nan());
        assert!(a.peak_signal_to_noise_ratio(&a, 0.).is_nan());

        // An empty sum of divergence terms is zero.
        assert_eq!(a.gkl_div(&a), 0.);
    }
}

0 commit comments

Comments
 (0)