Skip to content

Commit c557a23

Browse files
committed
idx_abs_max: add customization point kokkos#96
Note the different behavior with kokkos-kernels, tracked in issue kokkos#114.
1 parent a015a09 commit c557a23

File tree

6 files changed

+116
-4
lines changed

6 files changed

+116
-4
lines changed

examples/kokkos-based/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@
22
linalg_add_example(add_kokkos)
33
linalg_add_example(dot_kokkos)
44
linalg_add_example(dotc_kokkos)
5+
linalg_add_example(idx_abs_max_kokkos)
56
linalg_add_example(simple_scale_kokkos)
67
linalg_add_example(matrix_vector_product_kokkos)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#include <experimental/linalg>
2+
#include <iostream>
3+
4+
int main(int argc, char* argv[])
5+
{
6+
std::cout << "idx_abs_max example: calling kokkos-kernels" << std::endl;
7+
8+
std::size_t N = 10;
9+
Kokkos::initialize(argc,argv);
10+
{
11+
using value_type = double;
12+
13+
Kokkos::View<value_type*> a_view("A",N);
14+
value_type* a_ptr = a_view.data();
15+
16+
// Requires CTAD working, GCC 11.1 works but some others are buggy
17+
// std::experimental::mdspan a(a_ptr,N);
18+
using extents_type = std::experimental::extents<std::experimental::dynamic_extent>;
19+
std::experimental::mdspan<value_type, extents_type> a(a_ptr,N);
20+
a(0) = 0.5;
21+
a(1) = 0.2;
22+
a(2) = 0.1;
23+
a(3) = 0.4;
24+
a(4) = -0.8;
25+
a(5) = -1.7;
26+
a(6) = -0.3;
27+
a(7) = 0.5;
28+
a(8) = -1.7;
29+
a(9) = -0.9;
30+
31+
namespace stdla = std::experimental::linalg;
32+
33+
// This goes to the base implementation
34+
const auto idx = stdla::idx_abs_max(std::execution::seq, a);
35+
printf("Seq result = %i\n", idx);
36+
37+
// This forwards to KokkosKernels (https://github.com/kokkos/kokkos-kernels
38+
const auto idx_kk = stdla::idx_abs_max(KokkosKernelsSTD::kokkos_exec<>(), a);
39+
printf("Kokkos result = %i\n", idx_kk);
40+
}
41+
Kokkos::finalize();
42+
}

include/experimental/__p1673_bits/blas1_vector_idx_abs_max.hpp

+47-3
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,33 @@ namespace experimental {
4848
inline namespace __p1673_version_0 {
4949
namespace linalg {
5050

51+
// begin anonymous namespace
52+
namespace {
53+
54+
template <class Exec, class v_t, class = void>
55+
struct is_custom_idx_abs_max_avail : std::false_type {};
56+
57+
template <class Exec, class v_t>
58+
struct is_custom_idx_abs_max_avail<
59+
Exec, v_t,
60+
std::enable_if_t<
61+
std::is_integral<
62+
decltype(idx_abs_max(std::declval<Exec>(),
63+
std::declval<v_t>()
64+
)
65+
)
66+
>::value
67+
>
68+
>
69+
{
70+
static constexpr bool value = !std::is_same<Exec,std::experimental::linalg::impl::inline_exec_t>::value;
71+
};
72+
5173
template<class ElementType,
5274
extents<>::size_type ext0,
5375
class Layout,
5476
class Accessor>
55-
extents<>::size_type idx_abs_max(
77+
extents<>::size_type idx_abs_max_default_impl(
5678
std::experimental::mdspan<ElementType, std::experimental::extents<ext0>, Layout, Accessor> v)
5779
{
5880
using std::abs;
@@ -73,16 +95,38 @@ extents<>::size_type idx_abs_max(
7395
return maxInd; // FIXME check for NaN "never less than" stuff
7496
}
7597

98+
} // end anonymous namespace
99+
76100
template<class ExecutionPolicy,
77101
class ElementType,
78102
extents<>::size_type ext0,
79103
class Layout,
80104
class Accessor>
81105
extents<>::size_type idx_abs_max(
82-
ExecutionPolicy&& /* exec */,
106+
ExecutionPolicy&& exec,
107+
std::experimental::mdspan<ElementType, std::experimental::extents<ext0>, Layout, Accessor> v)
108+
{
109+
constexpr bool use_custom = is_custom_idx_abs_max_avail<
110+
decltype(execpolicy_mapper(exec)), decltype(v)
111+
>::value;
112+
113+
if constexpr(use_custom){
114+
using return_type = extents<>::size_type;
115+
return return_type(idx_abs_max(execpolicy_mapper(exec), v));
116+
}
117+
else{
118+
return idx_abs_max_default_impl(v);
119+
}
120+
}
121+
122+
template<class ElementType,
123+
extents<>::size_type ext0,
124+
class Layout,
125+
class Accessor>
126+
extents<>::size_type idx_abs_max(
83127
std::experimental::mdspan<ElementType, std::experimental::extents<ext0>, Layout, Accessor> v)
84128
{
85-
return idx_abs_max(v);
129+
return idx_abs_max(std::experimental::linalg::impl::default_exec_t(), v);
86130
}
87131

88132
} // end namespace linalg

include/experimental/__p1673_bits/linalg_execpolicy_mapper.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ struct default_exec_t {};
1717
}
1818

1919

20-
#if defined(LINALG_ENABLE_KOKKOS) || defined(LINALG_ENABLE_KOKKOS_DEFAULT)
20+
#if defined(LINALG_ENABLE_KOKKOS) && defined(LINALG_ENABLE_KOKKOS_DEFAULT)
2121
#include <experimental/__p1673_bits/kokkos-kernels/exec_policy_wrapper_kk.hpp>
2222
#endif
2323

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
2+
#ifndef LINALG_TPLIMPLEMENTATIONS_INCLUDE_EXPERIMENTAL_P1673_BITS_KOKKOSKERNELS_IDX_ABS_MAX_HPP_
3+
#define LINALG_TPLIMPLEMENTATIONS_INCLUDE_EXPERIMENTAL_P1673_BITS_KOKKOSKERNELS_IDX_ABS_MAX_HPP_
4+
5+
#include <KokkosBlas1_iamax.hpp>
6+
7+
namespace KokkosKernelsSTD {
8+
9+
template<class ExecSpace,
10+
class ElementType,
11+
std::experimental::extents<>::size_type ext0,
12+
class Layout,
13+
class Accessor>
14+
auto idx_abs_max(kokkos_exec<ExecSpace>,
15+
std::experimental::mdspan<ElementType, std::experimental::extents<ext0>, Layout, Accessor> v)
16+
{
17+
// note that -1 here, this is related to:
18+
// https://github.com/kokkos/stdBLAS/issues/114
19+
20+
return KokkosBlas::iamax(Impl::mdspan_to_view(v))-1;
21+
}
22+
23+
}
24+
#endif

tpl-implementations/include/experimental/linalg_kokkoskernels

+1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@
55
#include "__p1673_bits/kokkos-kernels/blas1_dot_kk.hpp"
66
#include "__p1673_bits/kokkos-kernels/blas1_add_kk.hpp"
77
#include "__p1673_bits/kokkos-kernels/blas1_scale_kk.hpp"
8+
#include "__p1673_bits/kokkos-kernels/blas1_idx_abs_max_kk.hpp"
89
#include "__p1673_bits/kokkos-kernels/blas2_matrix_vector_product_kk.hpp"

0 commit comments

Comments
 (0)