Skip to content

Commit d5a1611

Browse files
committed
Port deviation functions from StatsBase.jl
1 parent 6f898f6 commit d5a1611

File tree

5 files changed

+493
-2
lines changed

5 files changed

+493
-2
lines changed

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ quickcheck = { version = "0.8.1", default-features = false }
3030
ndarray-rand = "0.9"
3131
approx = "0.3"
3232
quickcheck_macros = "0.8"
33+
num-bigint = "0.2.2"
3334

3435
[[bench]]
3536
name = "sort"

src/deviation.rs

+240
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
use ndarray::{ArrayBase, Data, Dimension, Zip};
2+
use num_traits::{Signed, ToPrimitive};
3+
use std::convert::Into;
4+
use std::ops::AddAssign;
5+
6+
use crate::errors::{MultiInputError, ShapeMismatch};
7+
8+
/// Extension trait for `ArrayBase` providing functions
9+
/// to compute different deviation measures.
10+
pub trait DeviationExt<A, S, D>
11+
where
12+
S: Data<Elem = A>,
13+
D: Dimension,
14+
{
15+
/// Counts the positions at which the elements of `self` and `other`
/// compare equal.
///
/// # Errors
///
/// * `MultiInputError::EmptyInput` if `self` is empty
/// * the `MultiInputError` converted from `ShapeMismatch` if `self` and
///   `other` have different shapes
fn count_eq(&self, other: &ArrayBase<S, D>) -> Result<usize, MultiInputError>
16+
where
17+
A: PartialEq;
18+
19+
/// Counts the positions at which the elements of `self` and `other`
/// differ, i.e. `self.len()` minus the result of `count_eq`.
///
/// # Errors
///
/// Same conditions as `count_eq`.
fn count_neq(&self, other: &ArrayBase<S, D>) -> Result<usize, MultiInputError>
20+
where
21+
A: PartialEq;
22+
23+
/// Computes the squared L2 distance between `self` and `other`:
/// the sum over all positions of `(self_i - other_i)^2`.
///
/// # Errors
///
/// * `MultiInputError::EmptyInput` if `self` is empty
/// * the `MultiInputError` converted from `ShapeMismatch` if `self` and
///   `other` have different shapes
fn sq_l2_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError>
24+
where
25+
A: AddAssign + Clone + Signed;
26+
27+
/// Computes the L2 distance between `self` and `other`: the square root
/// of `sq_l2_dist`, evaluated in `f64`.
///
/// # Errors
///
/// Same conditions as `sq_l2_dist`.
///
/// # Panics
///
/// Panics if the squared distance of type `A` cannot be converted to `f64`.
fn l2_dist(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
28+
where
29+
A: AddAssign + Clone + Signed + ToPrimitive;
30+
31+
/// Computes the L1 distance between `self` and `other`:
/// the sum over all positions of `|self_i - other_i|`.
///
/// # Errors
///
/// * `MultiInputError::EmptyInput` if `self` is empty
/// * the `MultiInputError` converted from `ShapeMismatch` if `self` and
///   `other` have different shapes
fn l1_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError>
32+
where
33+
A: AddAssign + Clone + Signed;
34+
35+
/// Computes the L∞ distance between `self` and `other`:
/// the largest `|self_i - other_i|` over all positions.
///
/// The search starts from zero and only replaces the running maximum when
/// a difference compares strictly greater than it.
///
/// # Errors
///
/// * `MultiInputError::EmptyInput` if `self` is empty
/// * the `MultiInputError` converted from `ShapeMismatch` if `self` and
///   `other` have different shapes
fn linf_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError>
36+
where
37+
A: Clone + PartialOrd + Signed;
38+
39+
/// Computes the mean absolute deviation between `self` and `other`:
/// `l1_dist` divided by the number of elements of `self`, in `f64`.
///
/// # Errors
///
/// Same conditions as `l1_dist`.
///
/// # Panics
///
/// Panics if the L1 distance of type `A` cannot be converted to `f64`.
fn mean_abs_dev(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
40+
where
41+
A: AddAssign + Clone + Signed + ToPrimitive;
42+
43+
/// Computes the maximum absolute deviation between `self` and `other`.
///
/// This is the same quantity as `linf_dist` and fails under the same
/// conditions.
fn max_abs_dev(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError>
44+
where
45+
A: Clone + PartialOrd + Signed;
46+
47+
/// Computes the mean squared deviation between `self` and `other`:
/// `sq_l2_dist` divided by the number of elements of `self`, in `f64`.
///
/// # Errors
///
/// Same conditions as `sq_l2_dist`.
///
/// # Panics
///
/// Panics if the squared distance of type `A` cannot be converted to `f64`.
fn mean_sq_dev(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
48+
where
49+
A: AddAssign + Clone + Signed + ToPrimitive;
50+
51+
/// Computes the root mean squared deviation between `self` and `other`:
/// the square root of `mean_sq_dev`.
///
/// # Errors
///
/// Same conditions as `mean_sq_dev`.
///
/// # Panics
///
/// Panics under the same conditions as `mean_sq_dev`.
fn root_mean_sq_dev(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
52+
where
53+
A: AddAssign + Clone + Signed + ToPrimitive;
54+
55+
/// Computes the peak signal-to-noise ratio between `self` and `other` as
/// `10 * log10(maxv^2 / mean_sq_dev)`, evaluated in `f64`.
///
/// # Errors
///
/// Same conditions as `mean_sq_dev`.
///
/// # Panics
///
/// Panics if `maxv` or the squared distance cannot be converted to `f64`.
fn peak_signal_to_noise_ratio(
56+
&self,
57+
other: &ArrayBase<S, D>,
58+
maxv: A,
59+
) -> Result<f64, MultiInputError>
60+
where
61+
A: AddAssign + Clone + Signed + ToPrimitive;
62+
63+
// NOTE(review): crate-internal macro defined outside this file; presumably
// it seals the trait against implementations outside the crate — confirm.
private_decl! {}
64+
}
65+
66+
/// Returns early from the enclosing function with
/// `Err(MultiInputError::EmptyInput)` when `$arr` holds no elements.
///
/// `$arr` may be any expression whose value offers an `is_empty()` method;
/// the enclosing function must return a `Result` whose error type is
/// `MultiInputError`.
macro_rules! return_err_if_empty {
    ($arr:expr) => {
        if $arr.is_empty() {
            return Err(MultiInputError::EmptyInput);
        }
    };
}
73+
/// Returns early from the enclosing function with a shape-mismatch error
/// when `$arr_a` and `$arr_b` report different shapes.
///
/// The error is a `ShapeMismatch` carrying both shapes, converted with
/// `.into()` to the error type of the enclosing function.
///
/// Each argument expression is evaluated exactly once: both are bound by
/// reference in a `match` scrutinee, which also keeps temporaries alive for
/// the whole check.
macro_rules! return_err_unless_same_shape {
    ($arr_a:expr, $arr_b:expr) => {
        match (&$arr_a, &$arr_b) {
            (first, second) => {
                if first.shape() != second.shape() {
                    return Err(ShapeMismatch {
                        first_shape: first.shape().to_vec(),
                        second_shape: second.shape().to_vec(),
                    }
                    .into());
                }
            }
        }
    };
}
84+
85+
impl<A, S, D> DeviationExt<A, S, D> for ArrayBase<S, D>
86+
where
87+
S: Data<Elem = A>,
88+
D: Dimension,
89+
{
90+
fn count_eq(&self, other: &ArrayBase<S, D>) -> Result<usize, MultiInputError>
91+
where
92+
A: PartialEq,
93+
{
94+
return_err_if_empty!(self);
95+
return_err_unless_same_shape!(self, other);
96+
97+
let mut count = 0;
98+
99+
Zip::from(self).and(other).apply(|a, b| {
100+
if a == b {
101+
count += 1;
102+
}
103+
});
104+
105+
Ok(count)
106+
}
107+
108+
fn count_neq(&self, other: &ArrayBase<S, D>) -> Result<usize, MultiInputError>
109+
where
110+
A: PartialEq,
111+
{
112+
self.count_eq(other).map(|n_eq| self.len() - n_eq)
113+
}
114+
115+
fn sq_l2_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError>
116+
where
117+
A: AddAssign + Clone + Signed,
118+
{
119+
return_err_if_empty!(self);
120+
return_err_unless_same_shape!(self, other);
121+
122+
let mut result = A::zero();
123+
124+
Zip::from(self).and(other).apply(|self_i, other_i| {
125+
let (a, b) = (self_i.clone(), other_i.clone());
126+
let abs_diff = (a - b).abs();
127+
result += abs_diff.clone() * abs_diff;
128+
});
129+
130+
Ok(result)
131+
}
132+
133+
fn l2_dist(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
134+
where
135+
A: AddAssign + Clone + Signed + ToPrimitive,
136+
{
137+
let sq_l2_dist = self
138+
.sq_l2_dist(other)?
139+
.to_f64()
140+
.expect("failed cast from type A to f64");
141+
142+
Ok(sq_l2_dist.sqrt())
143+
}
144+
145+
fn l1_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError>
146+
where
147+
A: AddAssign + Clone + Signed,
148+
{
149+
return_err_if_empty!(self);
150+
return_err_unless_same_shape!(self, other);
151+
152+
let mut result = A::zero();
153+
154+
Zip::from(self).and(other).apply(|self_i, other_i| {
155+
let (a, b) = (self_i.clone(), other_i.clone());
156+
result += (a - b).abs();
157+
});
158+
159+
Ok(result)
160+
}
161+
162+
fn linf_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError>
163+
where
164+
A: Clone + PartialOrd + Signed,
165+
{
166+
return_err_if_empty!(self);
167+
return_err_unless_same_shape!(self, other);
168+
169+
let mut max = A::zero();
170+
171+
Zip::from(self).and(other).apply(|self_i, other_i| {
172+
let (a, b) = (self_i.clone(), other_i.clone());
173+
let diff = (a - b).abs();
174+
if diff > max {
175+
max = diff;
176+
}
177+
});
178+
179+
Ok(max)
180+
}
181+
182+
fn mean_abs_dev(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
183+
where
184+
A: AddAssign + Clone + Signed + ToPrimitive,
185+
{
186+
let a = self
187+
.l1_dist(other)?
188+
.to_f64()
189+
.expect("failed cast from type A to f64");
190+
let b = self.len() as f64;
191+
192+
Ok(a / b)
193+
}
194+
195+
#[inline]
196+
fn max_abs_dev(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError>
197+
where
198+
A: Clone + PartialOrd + Signed,
199+
{
200+
self.linf_dist(other)
201+
}
202+
203+
fn mean_sq_dev(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
204+
where
205+
A: AddAssign + Clone + Signed + ToPrimitive,
206+
{
207+
let a = self
208+
.sq_l2_dist(other)?
209+
.to_f64()
210+
.expect("failed cast from type A to f64");
211+
let b = self.len() as f64;
212+
213+
Ok(a / b)
214+
}
215+
216+
fn root_mean_sq_dev(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
217+
where
218+
A: AddAssign + Clone + Signed + ToPrimitive,
219+
{
220+
let msd = self.mean_sq_dev(other)?;
221+
Ok(msd.sqrt())
222+
}
223+
224+
fn peak_signal_to_noise_ratio(
225+
&self,
226+
other: &ArrayBase<S, D>,
227+
maxv: A,
228+
) -> Result<f64, MultiInputError>
229+
where
230+
A: AddAssign + Clone + Signed + ToPrimitive,
231+
{
232+
let maxv_f = maxv.to_f64().expect("failed cast from type A to f64");
233+
let msd = self.mean_sq_dev(&other)?;
234+
let psnr = 10. * f64::log10(maxv_f * maxv_f / msd);
235+
236+
Ok(psnr)
237+
}
238+
239+
private_impl! {}
240+
}

src/errors.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ impl From<EmptyInput> for MinMaxError {
4646
/// An error used by methods and functions that take two arrays as argument and
4747
/// expect them to have exactly the same shape
4848
/// (e.g. `ShapeMismatch` is returned when `a.shape() == b.shape()` evaluates to `false`).
49-
#[derive(Clone, Debug)]
49+
#[derive(Clone, Debug, PartialEq)]
5050
pub struct ShapeMismatch {
5151
pub first_shape: Vec<usize>,
5252
pub second_shape: Vec<usize>,
@@ -65,7 +65,7 @@ impl fmt::Display for ShapeMismatch {
6565
impl Error for ShapeMismatch {}
6666

6767
/// An error for methods that take multiple non-empty array inputs.
68-
#[derive(Clone, Debug)]
68+
#[derive(Clone, Debug, PartialEq)]
6969
pub enum MultiInputError {
7070
/// One or more of the arrays were empty.
7171
EmptyInput,

src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
//! [`StatsBase.jl`]: https://juliastats.github.io/StatsBase.jl/latest/
2929
3030
pub use crate::correlation::CorrelationExt;
31+
pub use crate::deviation::DeviationExt;
3132
pub use crate::entropy::EntropyExt;
3233
pub use crate::histogram::HistogramExt;
3334
pub use crate::maybe_nan::{MaybeNan, MaybeNanExt};
@@ -69,6 +70,7 @@ mod private {
6970
}
7071

7172
mod correlation;
73+
mod deviation;
7274
mod entropy;
7375
pub mod errors;
7476
pub mod histogram;

0 commit comments

Comments
 (0)