Skip to content

Commit

Permalink
feat(ttm): add taskloops
Browse files Browse the repository at this point in the history
  • Loading branch information
bassoy committed Nov 24, 2024
1 parent 32053ff commit 97b0900
Show file tree
Hide file tree
Showing 3 changed files with 213 additions and 23 deletions.
4 changes: 3 additions & 1 deletion include/tlib/detail/tags.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ namespace tlib::ttm::parallel_policy
struct sequential_t {}; // sequential loops and sequential gemm
struct parallel_blas_t {}; // multithreaded gemm
struct parallel_loop_t {}; // omp_for with single threaded gemm
struct parallel_taskloop_t {}; // omp_task for each loop with single threaded gemm
struct parallel_taskloop_t {}; // omp_taskloops for each loop with single threaded gemm
struct parallel_task_t {}; // omp_task for each recursion with single threaded gemm
struct parallel_loop_blas_t {}; // omp_for with multi-threaded gemm
struct batched_gemm_t {}; // multithreaded batched gemm with collapsed loops
struct combined_t {};
Expand All @@ -34,6 +35,7 @@ inline constexpr sequential_t sequential;
inline constexpr parallel_blas_t parallel_blas;
inline constexpr parallel_loop_t parallel_loop;
inline constexpr parallel_taskloop_t parallel_taskloop;
inline constexpr parallel_task_t parallel_task;
inline constexpr parallel_loop_blas_t parallel_loop_blas;
inline constexpr batched_gemm_t batched_gemm;
inline constexpr combined_t combined;
Expand Down
201 changes: 181 additions & 20 deletions include/tlib/detail/ttm.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,17 +236,46 @@ inline void taskloops_over_gemm_with_slices (
value_t *c, size_t const*const nc, size_t const*const wc)
{
if(r>1){
#pragma omp task untied
{
if (r == qh) { // q == pia[r]
taskloops_over_gemm_with_slices(std::forward<gemm_t>(gemm), r-1, qh, a,na,wa,pia, b, c,nc,wc);
}
else{ // r>1 && r != qh
auto pia_r = pia[r-1]-1;
for(unsigned i = 0; i < na[pia_r]; ++i, a+=wa[pia_r], c+=wc[pia_r]){
taskloops_over_gemm_with_slices(std::forward<gemm_t>(gemm), r-1, qh, a,na,wa,pia, b, c,nc,wc);
#pragma omp taskloop untied
for(unsigned i = 0; i < na[pia_r]; ++i){
auto aa=a+i*wa[pia_r];
auto cc=c+i*wc[pia_r];
taskloops_over_gemm_with_slices(std::forward<gemm_t>(gemm), r-1, qh, aa,na,wa,pia, b, cc,nc,wc);
}
}
}
else {
gemm( a, b, c );
}
}


template<class value_t, class size_t, class gemm_t>
inline void tasks_over_gemm_with_slices (
gemm_t && gemm,
unsigned const r, // starts with p
unsigned const qh, // 1 <= qh <= p with \hat{q} = pi^{-1}(q)
const value_t *a, size_t const*const na, size_t const*const wa, size_t const*const pia,
const value_t *b,
value_t *c, size_t const*const nc, size_t const*const wc)
{
#pragma omp task untied
if(r>1){
if (r == qh) { // q == pia[r]
tasks_over_gemm_with_slices(std::forward<gemm_t>(gemm), r-1, qh, a,na,wa,pia, b, c,nc,wc);
}
else{ // r>1 && r != qh
auto pia_r = pia[r-1]-1;
for(unsigned i = 0; i < na[pia_r]; ++i){
auto aa=a+i*wa[pia_r];
auto cc=c+i*wc[pia_r];
tasks_over_gemm_with_slices(std::forward<gemm_t>(gemm), r-1, qh, aa,na,wa,pia, b, cc,nc,wc);
}
}
}
else {
Expand Down Expand Up @@ -297,18 +326,48 @@ inline void taskloops_over_gemm_with_subtensors (
const value_t *b,
value_t *c, size_t const*const nc, size_t const*const wc )
{
//#pragma omp task untied
if(r>1){
#pragma omp task untied
{
if (r <= qh) {
taskloops_over_gemm_with_subtensors (std::forward<gemm_t>(gemm), r-1, qh, a,na,wa,pia, b, c,nc,wc);
}
else if (r > qh){
auto pia_r = pia[r-1]-1u;
for(size_t i = 0; i < na[pia_r]; ++i, a+=wa[pia_r], c+=wc[pia_r]){
taskloops_over_gemm_with_subtensors (std::forward<gemm_t>(gemm), r-1, qh, a,na,wa,pia, b, c,nc,wc);
#pragma omp taskloop untied
for(size_t i = 0; i < na[pia_r]; ++i){
auto aa=a+i*wa[pia_r];
auto cc=c+i*wc[pia_r];
taskloops_over_gemm_with_subtensors (std::forward<gemm_t>(gemm), r-1, qh, aa,na,wa,pia, b, cc,nc,wc);
}
}
}
else {
gemm(a,b,c);
}
}


template<class value_t, class size_t, class gemm_t>
inline void tasks_over_gemm_with_subtensors (
gemm_t && gemm,
unsigned const r, // starts with p
unsigned const qh, // qhat one-based
const value_t *a, size_t const*const na, size_t const*const wa, size_t const*const pia,
const value_t *b,
value_t *c, size_t const*const nc, size_t const*const wc )
{
#pragma omp task untied
if(r>1){
if (r <= qh) {
tasks_over_gemm_with_subtensors (std::forward<gemm_t>(gemm), r-1, qh, a,na,wa,pia, b, c,nc,wc);
}
else if (r > qh){
auto pia_r = pia[r-1]-1u;
for(size_t i = 0; i < na[pia_r]; ++i){
auto aa=a+i*wa[pia_r];
auto cc=c+i*wc[pia_r];
tasks_over_gemm_with_subtensors (std::forward<gemm_t>(gemm), r-1, qh, aa,na,wa,pia, b, cc,nc,wc);
}
}
}
else {
Expand Down Expand Up @@ -427,12 +486,11 @@ inline void ttm(parallel_policy::parallel_taskloop_t, slicing_policy::slice_t, f
const value_t *b, size_t const*const nb, size_t const*const pib,
value_t *c, size_t const*const nc, size_t const*const wc )
{
set_blas_threads_max();
assert(get_blas_threads() > 1u || get_blas_threads() <= cores);

set_omp_nested();
auto is_cm = pib[0] == 1;

if(!is_case<8>(p,q,pia)){
set_blas_threads_max();
assert(get_blas_threads() > 1u || get_blas_threads() <= cores);
if(is_cm)
mtm_cm(q, p, a, na, pia, b, nb, c, nc );
else
Expand All @@ -450,9 +508,61 @@ inline void ttm(parallel_policy::parallel_taskloop_t, slicing_policy::slice_t, f

auto gemm_col = std::bind(gemm_col_tr2::run<value_t>,_1,_2,_3, n1,m,nq, wq, m,wq); // a,b,c
auto gemm_row = std::bind(gemm_row:: run<value_t>,_2,_1,_3, m,n1,nq, nq,wq,wq); // b,a,c

#pragma omp parallel num_threads(cores)
{
set_blas_threads_min();
assert(get_blas_threads() == 1u);
#pragma omp single
{
if(is_cm) taskloops_over_gemm_with_slices(gemm_col, p, qh, a,na,wa,pia, b, c,nc,wc);
else taskloops_over_gemm_with_slices(gemm_row, p, qh, a,na,wa,pia, b, c,nc,wc);
}
}
}
}


// TTM dispatch for the recursive-task policy over tensor slices.
// Non-case-8 layouts fall back to a single multithreaded flat mtm call;
// case 8 spawns one OpenMP task per recursion level
// (tasks_over_gemm_with_slices), each task running a single-threaded gemm.
//
// q           : contraction mode (one-based)
// p           : order (number of modes) of tensor a
// a,na,wa,pia : input tensor data, extents, strides, layout permutation
// b,nb,pib    : input matrix data, extents, layout permutation
// c,nc,wc     : output tensor data, extents, strides
//
// Fix vs. previous revision: removed two stray lines
// (`if(is_cm) taskloops_over_gemm_with_slices(...)` / `else ...`) that sat
// before `is_cm`, `gemm_col`, `gemm_row` and `qh` were declared — ill-formed
// C++ — and that invoked the taskloop recursion from the task-policy overload.
template<class value_t, class size_t>
inline void ttm(parallel_policy::parallel_task_t, slicing_policy::slice_t, fusion_policy::none_t,
                unsigned const q, unsigned const p,
                const value_t *a, size_t const*const na, size_t const*const wa, size_t const*const pia,
                const value_t *b, size_t const*const nb, size_t const*const pib,
                value_t *c, size_t const*const nc, size_t const*const wc )
{
    set_omp_nested();
    auto is_cm = pib[0] == 1; // column-major iff the matrix layout starts with mode 1

    if(!is_case<8>(p,q,pia)){
        // Flat matrix-times-matrix path: let the BLAS use all threads.
        set_blas_threads_max();
        assert(get_blas_threads() > 1u || get_blas_threads() <= cores);
        if(is_cm)
            mtm_cm(q, p, a, na, pia, b, nb, c, nc );
        else
            mtm_rm(q, p, a, na, pia, b, nb, c, nc );
    }
    else {
        auto const qh = inverse_mode(pia, pia+p, q); // \hat{q} = pi^{-1}(q), one-based

        using namespace std::placeholders;

        auto n1 = na[pia[0]-1];
        auto m  = nc[q-1];
        auto nq = na[q-1];
        auto wq = wa[q-1];

        auto gemm_col = std::bind(gemm_col_tr2::run<value_t>,_1,_2,_3, n1,m,nq, wq, m,wq); // a,b,c
        auto gemm_row = std::bind(gemm_row:: run<value_t>,_2,_1,_3, m,n1,nq, nq,wq,wq);    // b,a,c

        #pragma omp parallel num_threads(cores)
        {
            // Each task runs a single-threaded gemm; parallelism comes from tasks.
            set_blas_threads_min();
            assert(get_blas_threads() == 1u);
            // One thread seeds the task recursion; tasks are picked up by the team.
            #pragma omp single
            {
                if(is_cm) tasks_over_gemm_with_slices(gemm_col, p, qh, a,na,wa,pia, b, c,nc,wc);
                else      tasks_over_gemm_with_slices(gemm_row, p, qh, a,na,wa,pia, b, c,nc,wc);
            }
        }
    }
}

Expand Down Expand Up @@ -886,12 +996,12 @@ inline void ttm(
value_t *c, size_t const*const nc, size_t const*const wc
)
{
set_blas_threads_max();
assert(get_blas_threads() > 1 || get_blas_threads() <= cores);



set_omp_nested();
auto is_cm = pib[0] == 1;
if(!is_case<8>(p,q,pia)){
set_blas_threads_max();
assert(get_blas_threads() > 1 || get_blas_threads() <= cores);
if(is_cm)
mtm_cm(q, p, a, na, pia, b, nb, c, nc );
else
Expand All @@ -908,9 +1018,60 @@ inline void ttm(
auto gemm_col = std::bind(gemm_col_tr2::run<value_t>,_1,_2,_3, nnq,m,nq, nnq, m,nnq); // a,b,c
auto gemm_row = std::bind(gemm_row:: run<value_t>,_2,_1,_3, m,nnq,nq, nq,nnq,nnq); // b,a,c

if(is_cm) taskloops_over_gemm_with_subtensors(gemm_col, p, qh, a,na,wa,pia, b, c,nc,wc);
else taskloops_over_gemm_with_subtensors(gemm_row, p, qh, a,na,wa,pia, b, c,nc,wc);
#pragma omp parallel num_threads(cores)
{
set_blas_threads_min();
assert(get_blas_threads() == 1);
#pragma omp single
{
if(is_cm) taskloops_over_gemm_with_subtensors(gemm_col, p, qh, a,na,wa,pia, b, c,nc,wc);
else taskloops_over_gemm_with_subtensors(gemm_row, p, qh, a,na,wa,pia, b, c,nc,wc);
}
}
}
}


// TTM dispatch for the recursive-task policy over full subtensors.
// Non-case-8 layouts fall back to a single multithreaded flat mtm call;
// case 8 spawns one OpenMP task per recursion level
// (tasks_over_gemm_with_subtensors), each task running a single-threaded gemm
// on a contiguous subtensor of size nnq = prod(n_{pi(1)}..n_{pi(qh-1)}).
//
// q           : contraction mode (one-based)
// p           : order (number of modes) of tensor a
// a,na,wa,pia : input tensor data, extents, strides, layout permutation
// b,nb,pib    : input matrix data, extents, layout permutation
// c,nc,wc     : output tensor data, extents, strides
template<class value_t, class size_t>
inline void ttm(
parallel_policy::parallel_task_t, slicing_policy::subtensor_t, fusion_policy::none_t,
unsigned const q, unsigned const p,
const value_t *a, size_t const*const na, size_t const*const wa, size_t const*const pia,
const value_t *b, size_t const*const nb, size_t const*const pib,
value_t *c, size_t const*const nc, size_t const*const wc
)
{

set_omp_nested();
auto is_cm = pib[0] == 1; // column-major iff the matrix layout starts with mode 1
if(!is_case<8>(p,q,pia)){
// Flat matrix-times-matrix path: let the BLAS use all threads.
set_blas_threads_max();
assert(get_blas_threads() > 1 || get_blas_threads() <= cores);
if(is_cm)
mtm_cm(q, p, a, na, pia, b, nb, c, nc );
else
mtm_rm(q, p, a, na, pia, b, nb, c, nc );
}
else {
auto const qh = inverse_mode(pia, pia+p, q);  // \hat{q} = pi^{-1}(q), one-based
auto const nnq = product(na, pia, 1, qh);     // fused extent of the leading modes up to qh
auto const m = nc[q-1];
auto const nq = na[q-1];

using namespace std::placeholders;
auto gemm_col = std::bind(gemm_col_tr2::run<value_t>,_1,_2,_3, nnq,m,nq, nnq, m,nnq); // a,b,c
auto gemm_row = std::bind(gemm_row:: run<value_t>,_2,_1,_3, m,nnq,nq, nq,nnq,nnq); // b,a,c

#pragma omp parallel num_threads(cores)
{
// Each task runs a single-threaded gemm; parallelism comes from tasks.
set_blas_threads_min();
assert(get_blas_threads() == 1);
// One thread seeds the task recursion; tasks are picked up by the team.
#pragma omp single
{
if(is_cm) tasks_over_gemm_with_subtensors(gemm_col, p, qh, a,na,wa,pia, b, c,nc,wc);
else      tasks_over_gemm_with_subtensors(gemm_row, p, qh, a,na,wa,pia, b, c,nc,wc);
}
}
}
}

Expand Down
31 changes: 29 additions & 2 deletions test/src/gtest_tlib_ttm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ TEST(TensorTimesMatrix, ParallelGemmSliceNoFusion)
check_tensor_times_matrix<vt,st,ep,sp,fp,4u>(2u,3);
}

TEST(TensorTimesMatrix, ParallelTaskSliceNoFusion)
TEST(TensorTimesMatrix, ParallelTaskLoopSliceNoFusion)
{
using vt = double;
using st = std::size_t;
Expand All @@ -300,6 +300,19 @@ TEST(TensorTimesMatrix, ParallelTaskSliceNoFusion)
check_tensor_times_matrix<vt,st,ep,sp,fp,4u>(2u,3);
}

// Exercises the new parallel_task_t policy (one OpenMP task per recursion)
// with slice-wise gemms and no loop fusion, for tensor orders 2..4.
TEST(TensorTimesMatrix, ParallelTaskSliceNoFusion)
{
using vt = double;
using st = std::size_t;
using ep = parallel_policy::parallel_task_t;
using sp = slicing_policy::slice_t;
using fp = fusion_policy::none_t;

check_tensor_times_matrix<vt,st,ep,sp,fp,2u>(2u,3);
check_tensor_times_matrix<vt,st,ep,sp,fp,3u>(2u,3);
check_tensor_times_matrix<vt,st,ep,sp,fp,4u>(2u,3);
}



TEST(TensorTimesMatrix, ParallelGemmSubtensorNoFusion)
Expand Down Expand Up @@ -379,7 +392,7 @@ TEST(TensorTimesMatrix, SequentialSubtensorNoFusion)
}


TEST(TensorTimesMatrix, ParallelTaskSubtensorNone)
TEST(TensorTimesMatrix, ParallelTaskLoopSubtensorNone)
{
using vt = double;
using st = std::size_t;
Expand All @@ -393,6 +406,20 @@ TEST(TensorTimesMatrix, ParallelTaskSubtensorNone)
// check_tensor_times_matrix<vt,st,ep,sp,fp,5u>(2u,3);
}

// Exercises the new parallel_task_t policy (one OpenMP task per recursion)
// with subtensor-wise gemms and no loop fusion, for tensor orders 2..4.
TEST(TensorTimesMatrix, ParallelTaskSubtensorNone)
{
using vt = double;
using st = std::size_t;
using ep = parallel_policy::parallel_task_t;
using sp = slicing_policy::subtensor_t;
using fp = fusion_policy::none_t;

check_tensor_times_matrix<vt,st,ep,sp,fp,2u>(2u,3);
check_tensor_times_matrix<vt,st,ep,sp,fp,3u>(2u,3);
// order-5 case disabled by the author; kept for reference — do not delete.
// check_tensor_times_matrix<vt,st,ep,sp,fp,5u>(2u,3);
check_tensor_times_matrix<vt,st,ep,sp,fp,4u>(2u,3);
}

TEST(TensorTimesMatrix, ParallelLoopSubtensorOuterFusion)
{
using vt = double;
Expand Down

0 comments on commit 97b0900

Please sign in to comment.