Skip to content

Implement arithmetic ops on more combinations of types #744

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 76 additions & 35 deletions src/impl_ops.rs
Original file line number Diff line number Diff line change
@@ -53,9 +53,7 @@ macro_rules! impl_binary_op(
/// Perform elementwise
#[doc=$doc]
/// between `self` and `rhs`,
/// and return the result (based on `self`).
///
/// `self` must be an `Array` or `ArcArray`.
/// and return the result.
///
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
///
@@ -64,13 +62,13 @@ impl<A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for ArrayBase<S, D>
where
A: Clone + $trt<B, Output=A>,
B: Clone,
S: DataOwned<Elem=A> + DataMut,
S: Data<Elem=A>,
S2: Data<Elem=B>,
D: Dimension,
E: Dimension,
{
type Output = ArrayBase<S, D>;
fn $mth(self, rhs: ArrayBase<S2, E>) -> ArrayBase<S, D>
type Output = Array<A, D>;
fn $mth(self, rhs: ArrayBase<S2, E>) -> Array<A, D>
{
self.$mth(&rhs)
}
@@ -79,7 +77,7 @@ where
/// Perform elementwise
#[doc=$doc]
/// between `self` and reference `rhs`,
/// and return the result (based on `self`).
/// and return the result.
///
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
///
@@ -88,18 +86,19 @@ impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for ArrayBase<S, D>
where
A: Clone + $trt<B, Output=A>,
B: Clone,
S: DataOwned<Elem=A> + DataMut,
S: Data<Elem=A>,
S2: Data<Elem=B>,
D: Dimension,
E: Dimension,
{
type Output = ArrayBase<S, D>;
fn $mth(mut self, rhs: &ArrayBase<S2, E>) -> ArrayBase<S, D>
type Output = Array<A, D>;
fn $mth(self, rhs: &ArrayBase<S2, E>) -> Array<A, D>
{
self.zip_mut_with(rhs, |x, y| {
let mut lhs = self.into_owned();
lhs.zip_mut_with(rhs, |x, y| {
*x = x.clone() $operator y.clone();
});
self
lhs
}
}

@@ -129,22 +128,45 @@ where

/// Perform elementwise
#[doc=$doc]
/// between `self` and the scalar `x`,
/// and return the result (based on `self`).
/// between `self` and `rhs`,
/// and return the result as a new `Array`.
///
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
///
/// `self` must be an `Array` or `ArcArray`.
/// **Panics** if broadcasting isn’t possible.
impl<'a, A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for &'a ArrayBase<S, D>
where
A: Clone + $trt<B, Output=A>,
B: Clone,
S: Data<Elem=A>,
S2: Data<Elem=B>,
D: Dimension,
E: Dimension,
{
type Output = Array<A, D>;
fn $mth(self, rhs: ArrayBase<S2, E>) -> Array<A, D> {
// FIXME: Can we co-broadcast arrays here? And how?
self.to_owned().$mth(rhs)
}
}

/// Perform elementwise
#[doc=$doc]
/// between `self` and the scalar `x`,
/// and return the result.
impl<A, S, D, B> $trt<B> for ArrayBase<S, D>
where A: Clone + $trt<B, Output=A>,
S: DataOwned<Elem=A> + DataMut,
S: Data<Elem=A>,
D: Dimension,
B: ScalarOperand,
{
type Output = ArrayBase<S, D>;
fn $mth(mut self, x: B) -> ArrayBase<S, D> {
self.unordered_foreach_mut(move |elt| {
type Output = Array<A, D>;
fn $mth(self, x: B) -> Array<A, D> {
let mut lhs = self.into_owned();
lhs.unordered_foreach_mut(move |elt| {
*elt = elt.clone() $operator x.clone();
});
self
lhs
}
}

@@ -183,17 +205,17 @@ macro_rules! impl_scalar_lhs_op {
// these have no doc -- they are not visible in rustdoc
// Perform elementwise
// between the scalar `self` and array `rhs`,
// and return the result (based on `self`).
// and return the result.
impl<S, D> $trt<ArrayBase<S, D>> for $scalar
where S: DataOwned<Elem=$scalar> + DataMut,
where S: Data<Elem=$scalar>,
D: Dimension,
{
type Output = ArrayBase<S, D>;
fn $mth(self, rhs: ArrayBase<S, D>) -> ArrayBase<S, D> {
type Output = Array<$scalar, D>;
fn $mth(self, rhs: ArrayBase<S, D>) -> Array<$scalar, D> {
if_commutative!($commutative {
rhs.$mth(self)
} or {{
let mut rhs = rhs;
let mut rhs = rhs.into_owned();
rhs.unordered_foreach_mut(move |elt| {
*elt = self $operator *elt;
});
@@ -293,16 +315,17 @@ mod arithmetic_ops {
impl<A, S, D> Neg for ArrayBase<S, D>
where
A: Clone + Neg<Output = A>,
S: DataOwned<Elem = A> + DataMut,
S: Data<Elem = A>,
D: Dimension,
{
type Output = Self;
type Output = Array<A, D>;
/// Perform an elementwise negation of `self` and return the result.
fn neg(mut self) -> Self {
self.unordered_foreach_mut(|elt| {
fn neg(self) -> Array<A, D> {
let mut array = self.into_owned();
array.unordered_foreach_mut(|elt| {
*elt = -elt.clone();
});
self
array
}
}

@@ -323,16 +346,17 @@ mod arithmetic_ops {
impl<A, S, D> Not for ArrayBase<S, D>
where
A: Clone + Not<Output = A>,
S: DataOwned<Elem = A> + DataMut,
S: Data<Elem = A>,
D: Dimension,
{
type Output = Self;
type Output = Array<A, D>;
/// Perform an elementwise unary not of `self` and return the result.
fn not(mut self) -> Self {
self.unordered_foreach_mut(|elt| {
fn not(self) -> Array<A, D> {
let mut array = self.into_owned();
array.unordered_foreach_mut(|elt| {
*elt = !elt.clone();
});
self
array
}
}

@@ -359,6 +383,23 @@ mod assign_ops {
($trt:ident, $method:ident, $doc:expr) => {
use std::ops::$trt;

#[doc=$doc]
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
///
/// **Panics** if broadcasting isn’t possible.
impl<A, S, S2, D, E> $trt<ArrayBase<S2, E>> for ArrayBase<S, D>
where
A: Clone + $trt<A>,
S: DataMut<Elem = A>,
S2: Data<Elem = A>,
D: Dimension,
E: Dimension,
{
fn $method(&mut self, rhs: ArrayBase<S2, E>) {
self.$method(&rhs)
}
}

#[doc=$doc]
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
///
31 changes: 13 additions & 18 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -607,18 +607,14 @@ pub type Ixs = isize;
///
/// ### Binary Operators with Two Arrays
///
/// Let `A` be an array or view of any kind. Let `B` be an array
/// with owned storage (either `Array` or `ArcArray`).
/// Let `C` be an array with mutable data (either `Array`, `ArcArray`
/// or `ArrayViewMut`).
/// The following combinations of operands
/// are supported for an arbitrary binary operator denoted by `@` (it can be
/// `+`, `-`, `*`, `/` and so on).
///
/// - `&A @ &A` which produces a new `Array`
/// - `B @ A` which consumes `B`, updates it with the result, and returns it
/// - `B @ &A` which consumes `B`, updates it with the result, and returns it
/// - `C @= &A` which performs an arithmetic operation in place
/// Let `A` be an array or view of any kind. Let `M` be an array with mutable
/// data (either `Array`, `ArcArray` or `ArrayViewMut`). The following
/// combinations of operands are supported for an arbitrary binary operator
/// denoted by `@` (it can be `+`, `-`, `*`, `/` and so on).
///
/// - `&A @ &A` or `&A @ A` which produce a new `Array`
/// - `A @ &A` or `A @ A` which may reuse the allocation of the LHS if it's an owned array
/// - `M @= &A` or `M @= A` which performs an arithmetic operation in place on `M`
///
/// Note that the element type needs to implement the operator trait and the
/// `Clone` trait.
@@ -647,17 +643,16 @@ pub type Ixs = isize;
/// `ScalarOperand` docs has the detailed condtions).
///
/// - `&A @ K` or `K @ &A` which produces a new `Array`
/// - `B @ K` or `K @ B` which consumes `B`, updates it with the result and returns it
/// - `C @= K` which performs an arithmetic operation in place
/// - `A @ K` or `K @ A` which may reuse the allocation of the array if it's an owned array
/// - `M @= K` which performs an arithmetic operation in place
///
/// ### Unary Operators
///
/// Let `A` be an array or view of any kind. Let `B` be an array with owned
/// storage (either `Array` or `ArcArray`). The following operands are supported
/// for an arbitrary unary operator denoted by `@` (it can be `-` or `!`).
/// The following operands are supported for an arbitrary unary operator
/// denoted by `@` (it can be `-` or `!`).
///
/// - `@&A` which produces a new `Array`
/// - `@B` which consumes `B`, updates it with the result, and returns it
/// - `@A` which may reuse the allocation of the array if it's an owned array
///
/// ## Broadcasting
///
10 changes: 5 additions & 5 deletions tests/array.rs
Original file line number Diff line number Diff line change
@@ -394,11 +394,11 @@ fn test_add() {
}

let B = A.clone();
A = A + &B;
assert_eq!(A[[0, 0]], 0);
assert_eq!(A[[0, 1]], 2);
assert_eq!(A[[1, 0]], 4);
assert_eq!(A[[1, 1]], 6);
let C = A + &B;
assert_eq!(C[[0, 0]], 0);
assert_eq!(C[[0, 1]], 2);
assert_eq!(C[[1, 0]], 4);
assert_eq!(C[[1, 1]], 6);
}

#[test]