Skip to content

Commit 3f12d39

Browse files
committed
Make mapv_into_any() work for ArcArray, resolves #1280
1 parent 0740695 commit 3f12d39

File tree

2 files changed

+37
-7
lines changed

2 files changed

+37
-7
lines changed

src/impl_methods.rs

+17-5
Original file line numberDiff line numberDiff line change
@@ -2586,15 +2586,27 @@ where
25862586
/// map is performed as in [`mapv`].
25872587
///
25882588
/// Elements are visited in arbitrary order.
2589-
///
2589+
///
2590+
/// Note that the compiler will need some hint about the return type, which
2591+
/// is generic over [`DataOwned`], and can thus be an [`Array`] or
2592+
/// [`ArcArray`]. Example:
2593+
///
2594+
/// ```rust
2595+
/// # use ndarray::{array, Array};
2596+
/// let a = array![[1., 2., 3.]];
2597+
/// let a_plus_one: Array<_, _> = a.mapv_into_any(|a| a + 1.);
2598+
/// ```
2599+
///
25902600
/// [`mapv_into`]: ArrayBase::mapv_into
25912601
/// [`mapv`]: ArrayBase::mapv
2592-
pub fn mapv_into_any<B, F>(self, mut f: F) -> Array<B, D>
2602+
pub fn mapv_into_any<B, F, T>(self, mut f: F) -> ArrayBase<T, D>
25932603
where
25942604
S: DataMut,
25952605
F: FnMut(A) -> B,
25962606
A: Clone + 'static,
25972607
B: 'static,
2608+
T: DataOwned<Elem = B>,
2609+
ArrayBase<T, D>: From<Array<B, D>>,
25982610
{
25992611
if core::any::TypeId::of::<A>() == core::any::TypeId::of::<B>() {
26002612
// A and B are the same type.
@@ -2606,14 +2618,14 @@ where
26062618
};
26072619
// Delegate to mapv_into() using the wrapped closure.
26082620
// Convert output to a uniquely owned array of type Array<A, D>.
2609-
let output = self.mapv_into(f).into_owned();
2621+
let output = self.mapv_into(f).into();
26102622
// Change the return type from Array<A, D> to Array<B, D>.
26112623
// Again, safe because A and B are the same type.
2612-
unsafe { unlimited_transmute::<Array<A, D>, Array<B, D>>(output) }
2624+
unsafe { unlimited_transmute::<ArrayBase<S, D>, ArrayBase<T, D>>(output) }
26132625
} else {
26142626
// A and B are not the same type.
26152627
// Fallback to mapv().
2616-
self.mapv(f)
2628+
self.mapv(f).into()
26172629
}
26182630
}
26192631

tests/array.rs

+20-2
Original file line numberDiff line numberDiff line change
@@ -995,14 +995,32 @@ fn map1() {
995995
fn mapv_into_any_same_type() {
996996
let a: Array<f64, _> = array![[1., 2., 3.], [4., 5., 6.]];
997997
let a_plus_one: Array<f64, _> = array![[2., 3., 4.], [5., 6., 7.]];
998-
assert_eq!(a.mapv_into_any(|a| a + 1.), a_plus_one);
998+
let b: Array<_, _> = a.mapv_into_any(|a| a + 1.);
999+
assert_eq!(b, a_plus_one);
9991000
}
10001001

10011002
#[test]
10021003
fn mapv_into_any_diff_types() {
10031004
let a: Array<f64, _> = array![[1., 2., 3.], [4., 5., 6.]];
10041005
let a_even: Array<bool, _> = array![[false, true, false], [true, false, true]];
1005-
assert_eq!(a.mapv_into_any(|a| a.round() as i32 % 2 == 0), a_even);
1006+
let b: Array<_, _> = a.mapv_into_any(|a| a.round() as i32 % 2 == 0);
1007+
assert_eq!(b, a_even);
1008+
}
1009+
1010+
#[test]
1011+
fn mapv_into_any_arcarray_same_type() {
1012+
let a: ArcArray<f64, _> = array![[1., 2., 3.], [4., 5., 6.]].into_shared();
1013+
let a_plus_one: Array<f64, _> = array![[2., 3., 4.], [5., 6., 7.]];
1014+
let b: ArcArray<_, _> = a.mapv_into(|a| a + 1.);
1015+
assert_eq!(b, a_plus_one);
1016+
}
1017+
1018+
#[test]
1019+
fn mapv_into_any_arcarray_diff_types() {
1020+
let a: ArcArray<f64, _> = array![[1., 2., 3.], [4., 5., 6.]].into_shared();
1021+
let a_even: Array<bool, _> = array![[false, true, false], [true, false, true]];
1022+
let b: ArcArray<_, _> = a.mapv_into_any(|a| a.round() as i32 % 2 == 0);
1023+
assert_eq!(b, a_even);
10061024
}
10071025

10081026
#[test]

0 commit comments

Comments
 (0)