From ffe7c16e4cf1488f1040704c20272c4f804542f6 Mon Sep 17 00:00:00 2001 From: Kris Thielemans Date: Fri, 19 Jul 2024 09:20:22 +0100 Subject: [PATCH] add apply_binary_func_element_wise for Arrays --- src/include/stir/ArrayFunction.h | 102 ++++++++++++++++++----- src/include/stir/ArrayFunction.inl | 129 ++++++++++++++++++++++++++--- src/test/test_Array.cxx | 69 ++++++++++++++- 3 files changed, 262 insertions(+), 38 deletions(-) diff --git a/src/include/stir/ArrayFunction.h b/src/include/stir/ArrayFunction.h index 2d7b20f172..d97d19f40f 100644 --- a/src/include/stir/ArrayFunction.h +++ b/src/include/stir/ArrayFunction.h @@ -1,6 +1,7 @@ /* Copyright (C) 2000 PARAPET partners Copyright (C) 2000- 2007, Hammersmith Imanet Ltd + Copyright (C) 2024, University College London This file is part of STIR. SPDX-License-Identifier: Apache-2.0 AND License-ref-PARAPET-license @@ -22,37 +23,30 @@ functions which work on all stir::Array objects, and which change every element of the array: All these functions return a reference to the (modified) array +
  • Analoguous functions that take out_array and in_array +
  • +
  • + Functions that apply a binary function element-wise on arrays: + +
  • - - \warning Compilers without partial specialisation of templates are - catered for by explicit instantiations. If you need it for any other - types, you'd have to add them by hand. */ -/* History: - - KT 21/05/2001 - added in_place_apply_array_function_on_1st_index, - in_place_apply_array_function_on_each_index - - KT 06/12/2001 - added apply_array_function_on_1st_index, - apply_array_function_on_each_index -*/ - #ifndef __stir_ArrayFunction_H__ #define __stir_ArrayFunction_H__ @@ -253,15 +247,77 @@ inline void apply_array_functions_on_each_index(Array<1, elemT>& out_array, ActualFunctionObjectPtrIter start, ActualFunctionObjectPtrIter stop); -template //! 1d specialisation for general function objects /*! \ingroup Array */ +template inline void apply_array_functions_on_each_index(Array<1, elemT>& out_array, const Array<1, elemT>& in_array, FunctionObjectPtrIter start, FunctionObjectPtrIter stop); +//! \name apply binary function element-wise +/*! \ingroup Array + arrays need to have the same number of elements, but can have different shapes. +*/ +//@{ +//! apply binary function element-wise and store in other array +template +inline void apply_binary_func_element_wise(Array& out, + const Array& in1, + const Array& in2, + BinaryFunctionT f); + +//! conditionally apply binary function element-wise and store in other array +/*! + arrays need to have the same number of elements, but can have different shapes. + + Element-wise loop, only computing/storing results if predicate(*in1_full_iter, *in2_full_iter)==true. + Other elements in \c out are unassigned. +*/ +template +inline void apply_binary_func_element_wise(Array& out, + const Array& in1, + const Array& in2, + PredicateBinaryFunctionT predicate, + BinaryFunctionT f); + +//! conditionally apply binary function element-wise and store in other array +/*! + arrays need to have the same number of elements, but can have different shapes. + + Element-wise loop, only computing/storing results if bool(*where_full_iter)==true. + Other elements in \c out are unassigned. +*/ +template +inline void apply_binary_func_element_wise(Array& out, + const Array& in1, + const Array& in2, + const Array& where, + BinaryFunctionT f); + +template +inline void +in_place_apply_binary_func_element_wise(Array& inout, const Array& in2, BinaryFunctionT f); + +template +inline void in_place_apply_binary_func_element_wise(Array& inout, + const Array& in2, + PredicateBinaryFunctionT predicate, + BinaryFunctionT f); + +template +inline void in_place_apply_binary_func_element_wise(Array& inout, + const Array& in2, + const Array& where, + BinaryFunctionT f); +//@} + template inline void transform_array_to_periodic_indices(Array& out_array, const Array& in_array); template diff --git a/src/include/stir/ArrayFunction.inl b/src/include/stir/ArrayFunction.inl index 612b7d8c03..10ae2c7910 100644 --- a/src/include/stir/ArrayFunction.inl +++ b/src/include/stir/ArrayFunction.inl @@ -1,6 +1,7 @@ /* Copyright (C) 2000 PARAPET partners Copyright (C) 2000- 2007, Hammersmith Imanet Ltd + Copyright (C) 2024, University College London This file is part of STIR. SPDX-License-Identifier: Apache-2.0 AND License-ref-PARAPET-license @@ -16,10 +17,6 @@ \author Kris Thielemans (some functions based on some earlier work by Darren Hague) \author PARAPET project - - \warning Compilers without partial specialisation of templates are - catered for by explicit instantiations. If you need it for any other - types, you'd have to add them by hand. */ #include "stir/BasicCoordinate.h" #include "stir/array_index_functions.h" @@ -27,13 +24,7 @@ #include #include -#ifdef BOOST_NO_STDC_NAMESPACE -namespace std -{ -using ::log; -using ::exp; -} // namespace std -#endif +#include START_NAMESPACE_STIR @@ -110,6 +101,122 @@ in_place_apply_function(T& v, FUNCTION f) return v; } +template +inline void +apply_binary_func_element_wise(Array& out, + const Array& in1, + const Array& in2, + BinaryFunctionT f) +{ + std::transform(in1.begin_all(), in1.end_all(), in2.begin_all(), out.begin_all(), f); +} + +template +inline void +apply_binary_func_element_wise(Array& out, + const Array& in1, + const Array& in2, + PredicateBinaryFunctionT predicate, + BinaryFunctionT f) +{ + auto in1_iter = in1.begin_all(); + const auto in1_end = in1.end_all(); + auto in2_iter = in2.begin_all(); + auto out_iter = out.begin_all(); + while (in1_iter != in1_end) + { + if (predicate(*in1_iter, *in2_iter)) + *out_iter = f(*in1_iter, *in2_iter); + ++in1_iter; + ++in2_iter; + ++out_iter; + } +} + +template +inline void +apply_binary_func_element_wise(Array& out, + const Array& in1, + const Array& in2, + const Array& where, + BinaryFunctionT f) +{ + auto in1_iter = in1.begin_all(); + const auto in1_end = in1.end_all(); + auto in2_iter = in2.begin_all(); + auto out_iter = out.begin_all(); + auto where_iter = where.begin_all(); + while (in1_iter != in1_end) + { + if (*where_iter) + *out_iter = f(*in1_iter, *in2_iter); + ++in1_iter; + ++in2_iter; + ++where_iter; + ++out_iter; + } +} + +template +inline void +in_place_apply_binary_func_element_wise(Array& inout, const Array& in2, BinaryFunctionT f) +{ + auto inout_iter = inout.begin_all(); + const auto inout_end = inout.end_all(); + auto in2_iter = in2.begin_all(); + while (inout_iter != inout_end) + { + *inout_iter = f(*inout_iter, *in2_iter); + ++inout_iter; + ++in2_iter; + } +} + +template +inline void +in_place_apply_binary_func_element_wise(Array& inout, + const Array& in2, + PredicateBinaryFunctionT predicate, + BinaryFunctionT f) +{ + auto inout_iter = inout.begin_all(); + const auto inout_end = inout.end_all(); + auto in2_iter = in2.begin_all(); + while (inout_iter != inout_end) + { + if (predicate(*inout_iter, *in2_iter)) + *inout_iter = f(*inout_iter, *in2_iter); + ++inout_iter; + ++in2_iter; + } +} + +template +inline void +in_place_apply_binary_func_element_wise(Array& inout, + const Array& in2, + const Array& where, + BinaryFunctionT f) +{ + auto inout_iter = inout.begin_all(); + const auto inout_end = inout.end_all(); + auto in2_iter = in2.begin_all(); + auto where_iter = where.begin_all(); + while (inout_iter != inout_end) + { + if (*where_iter) + *inout_iter = f(*inout_iter, *in2_iter); + ++inout_iter; + ++in2_iter; + ++where_iter; + } +} + template inline void in_place_apply_array_function_on_1st_index(Array& array, FunctionObjectPtr f) diff --git a/src/test/test_Array.cxx b/src/test/test_Array.cxx index ad5db37152..b846a450fc 100644 --- a/src/test/test_Array.cxx +++ b/src/test/test_Array.cxx @@ -54,15 +54,15 @@ #include "stir/HighResWallClockTimer.h" -#include +#include #include #include #include using std::ofstream; using std::ifstream; -using std::plus; using std::cerr; using std::endl; +using std::FILE; START_NAMESPACE_STIR @@ -472,10 +472,17 @@ ArrayTests::run_tests() Array<2, float> t2 = t2fp + testfp; check_if_equal(t2[3][2], 5.5F, "test operator +(Array2D)"); + { + // tests using apply_binary_func_element_wise, reproducing + and - + Array<2, float> t2func(t2fp.get_index_range()); + apply_binary_func_element_wise(t2func, t2fp, testfp, std::plus<>()); + check_if_equal(t2, t2func, "test apply_binary_func_element_wise (Array2D)"); + in_place_apply_binary_func_element_wise(t2func, testfp, std::minus<>()); + check_if_equal(t2fp, t2func, "test in_place_apply_binary_func_element_wise (Array2D)"); + } t2fp += testfp; check_if_equal(t2fp[3][2], 5.5F, "test operator +=(Array2D)"); check_if_equal(t2, t2fp, "test comparing Array2D+= and +"); - { BasicCoordinate<2, int> c; c[1] = 3; @@ -686,7 +693,7 @@ ArrayTests::run_tests() check_if_zero(test3.sum() - 2 * tmp2.sum() - tmp.sum(), "test operator-(float)"); } - in_place_apply_function(test3ter, std::bind(plus(), std::placeholders::_1, 4.F)); + in_place_apply_function(test3ter, std::bind(std::plus(), std::placeholders::_1, 4.F)); test3quat += 4.F; check_if_equal(test3quat, test3ter, "test in_place_apply_function and operator+=(NUMBER)"); @@ -756,6 +763,60 @@ ArrayTests::run_tests() check_if_equal(test3, data_to_fill, "test on 3D copy_to, irregular range"); } } + + { + // tests using apply_binary_func_element_wise, reproducing + and - + IndexRange<3> range(Coordinate3D(0, 0, 1), Coordinate3D(2, 2, 3)); + Array<3, float> test(range), test1(range), test2(range); + std::iota(test2.begin_all(), test2.end_all(), -4.65F); + std::transform(test2.begin_all(), test2.end_all(), test1.begin_all(), [](auto a) { return square(a); }); + test.fill(-1000.F); + apply_binary_func_element_wise( + test, test1, test2, [](auto a, auto b) { return a > b; }, [](auto a, auto b) { return b; }); + { + auto test1_iter = test1.begin_all_const(); + auto test2_iter = test2.begin_all_const(); + for (auto test_iter = test.begin_all(); test_iter != test.end_all(); ++test_iter, ++test1_iter, ++test2_iter) + { + check_if_equal(*test_iter, + *test1_iter > *test2_iter ? *test2_iter : -1000.F, + "test apply_binary_func_element_wise with predicate"); + } + } + + Array<3, bool> where(range); + apply_binary_func_element_wise(where, test1, test2, [](auto a, auto b) { return a > b; }); + { + Array<3, float> test_with_where(range); + test_with_where.fill(-1000.F); + apply_binary_func_element_wise(test_with_where, test1, test2, where, [](auto a, auto b) { return b; }); + check_if_equal(test, test_with_where, "test apply_binary_func_element_wise with 'where;"); + } + + test = test1; + in_place_apply_binary_func_element_wise( + test, test2, [](auto a, auto b) { return a > b; }, [](auto a, auto b) { return b; }); + { + auto test1_iter = test1.begin_all_const(); + auto test2_iter = test2.begin_all_const(); + for (auto test_iter = test.begin_all(); test_iter != test.end_all(); ++test_iter, ++test1_iter, ++test2_iter) + { + check_if_equal( + *test_iter, std::min(*test1_iter, *test2_iter), "test in_place_apply_binary_func_element_wise with predicate"); + } + } + test = test1; + in_place_apply_binary_func_element_wise(test, test2, where, [](auto a, auto b) { return b; }); + { + auto test1_iter = test1.begin_all_const(); + auto test2_iter = test2.begin_all_const(); + for (auto test_iter = test.begin_all(); test_iter != test.end_all(); ++test_iter, ++test1_iter, ++test2_iter) + { + check_if_equal( + *test_iter, std::min(*test1_iter, *test2_iter), "test in_place_apply_binary_func_element_wise with 'where'"); + } + } + } } {