Skip to content

Commit

Permalink
refactor(ttm/ttmpy): rename ttm tags.
Browse files Browse the repository at this point in the history
  • Loading branch information
bassoy committed Oct 27, 2024
1 parent 2a144f5 commit 90f3675
Show file tree
Hide file tree
Showing 9 changed files with 36 additions and 82 deletions.
22 changes: 0 additions & 22 deletions example/interface1.cpp
Original file line number Diff line number Diff line change
@@ -1,25 +1,3 @@
/*
# include MLK for fast execution
MKL_ROOT_DIR="/opt/intel/oneapi"
MKL_BLAS_DIR="${MKL_ROOT_DIR}/mkl/latest"
MKL_COMP_DIR="${MKL_ROOT_DIR}/compiler/2023.2.0/linux/compiler"
MKL_BLAS_INC="-I${MKL_BLAS_DIR}/include"
MKL_BLAS_LIB="-Wl,--start-group ${MKL_BLAS_DIR}/lib/libmkl_intel_ilp64.a ${MKL_BLAS_DIR}/lib/libmkl_intel_thread.a ${MKL_BLAS_DIR}/lib/libmkl_core.a ${MKL_COMP_DIR}/lib/intel64_lin/libiomp5.a -Wl,--end-group -lpthread -lm -ldl -m64"
MKL_BLAS_FLAGS="-DMKL_ILP64 -m64"
TLIB_DIR=..
TLIB_INC="-I${TLIB_DIR}/include"
INCS="${TLIB_INC} ${MKL_BLAS_INC}"
LIBS="${MKL_BLAS_LIB}"
FLAGS="${MKL_BLAS_FLAGS} -DUSE_MKLBLAS"
g++ ${INCS} -std=c++17 -Ofast -fopenmp interface1.cpp ${LIBS} ${FLAGS} -o interface1 && ./interface1
*/

#include <tlib/ttm.h>

#include <vector>
Expand Down
20 changes: 4 additions & 16 deletions example/interface2.cpp
Original file line number Diff line number Diff line change
@@ -1,15 +1,3 @@
/*
TLIB_DIR=..
TLIB_INC="-I${TLIB_DIR}/include"
INCS="${TLIB_INC} ${MKL_BLAS_INC}"
LIBS="${MKL_BLAS_LIB}"
FLAGS="${MKL_BLAS_FLAGS} -DUSE_MKLBLAS"
g++ ${INCS} -std=c++17 -Ofast -fopenmp interface1.cpp ${LIBS} ${FLAGS} -o interface1 && ./interface1
*/

#include <tlib/ttm.h>

#include <vector>
Expand Down Expand Up @@ -67,10 +55,10 @@ int main()


// correct shape, layout and strides of the output tensors C1,C2 are automatically computed and returned by the functions.
auto C1 = tlib::ttm(q, A,B, tlib::parallel_policy::threaded_gemm_t{} , tlib::slicing_policy::slice_t{}, tlib::fusion_policy::none_t{} );
auto C2 = tlib::ttm(q, A,B, tlib::parallel_policy::omp_forloop_t{} , tlib::slicing_policy::slice_t{}, tlib::fusion_policy::all_t{} );
auto C3 = tlib::ttm(q, A,B, tlib::parallel_policy::omp_forloop_t{} , tlib::slicing_policy::subtensor_t{}, tlib::fusion_policy::outer_t{} );
auto C4 = tlib::ttm(q, A,B, tlib::parallel_policy::batched_gemm_t{} , tlib::slicing_policy::subtensor_t{}, tlib::fusion_policy::outer_t{} );
auto C1 = tlib::ttm(q, A,B, tlib::parallel_policy::threaded_gemm , tlib::slicing_policy::slice, tlib::fusion_policy::none );
auto C2 = tlib::ttm(q, A,B, tlib::parallel_policy::omp_forloop , tlib::slicing_policy::slice, tlib::fusion_policy::all );
auto C3 = tlib::ttm(q, A,B, tlib::parallel_policy::omp_forloop , tlib::slicing_policy::subtensor, tlib::fusion_policy::all );
auto C4 = tlib::ttm(q, A,B, tlib::parallel_policy::batched_gemm , tlib::slicing_policy::subtensor, tlib::fusion_policy::all );


std::cout << "C1 = " << C1 << std::endl;
Expand Down
9 changes: 2 additions & 7 deletions example/interface3.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
/*
# include either -DUSE_OPENBLAS or -DUSE_INTELBLAS for fast execution
g++ -I../include/ -std=c++17 -Ofast -fopenmp interface3.cpp -o interface3 && ./interface3
*/

#include <tlib/ttm.h>

#include <vector>
Expand Down Expand Up @@ -49,14 +44,14 @@ int main()
std::cout << "B = [ "; std::copy(B.begin(), B.end(), iterator_t(std::cout, " ")); std::cout << " ];" << std::endl;

tlib::ttm(
tlib::parallel_policy::threaded_gemm_t{} , tlib::slicing_policy::slice_t{}, tlib::fusion_policy::none_t{},
tlib::parallel_policy::threaded_gemm , tlib::slicing_policy::slice, tlib::fusion_policy::none,
q, p,
A.data(), na.data(), wa.data(), pia.data(),
B.data(), nb.data(), pib.data(),
C1.data(), nc.data(), wc.data());

tlib::ttm(
tlib::parallel_policy::omp_forloop_t{} , tlib::slicing_policy::subtensor_t{}, tlib::fusion_policy::outer_t{},
tlib::parallel_policy::omp_forloop, tlib::slicing_policy::subtensor, tlib::fusion_policy::all,
q, p,
A.data(), na.data(), wa.data(), pia.data(),
B.data(), nb.data(), pib.data(),
Expand Down
4 changes: 2 additions & 2 deletions example/measure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ int main(int argc, char* argv[])

if(method == 2 || method == 7){
std::cout << "Algorithm: <par-loop | slice-qd, all>" << std::endl;
measure(q, A, B, C, tlib::parallel_policy::omp_forloop, tlib::slicing_policy::subtensor, tlib::fusion_policy::outer);
measure(q, A, B, C, tlib::parallel_policy::omp_forloop, tlib::slicing_policy::subtensor, tlib::fusion_policy::all );
std::cout << "---------" << std::endl << std::endl;
}

Expand All @@ -203,7 +203,7 @@ int main(int argc, char* argv[])

if(method == 6 || method == 7){
std::cout << "Algorithm: <par-gemm | slice-qd, all>" << std::endl;
measure(q, A, B, C, tlib::parallel_policy::threaded_gemm, tlib::slicing_policy::subtensor, tlib::fusion_policy::outer );
measure(q, A, B, C, tlib::parallel_policy::threaded_gemm, tlib::slicing_policy::subtensor, tlib::fusion_policy::all );
std::cout << "---------" << std::endl << std::endl;
}

Expand Down
26 changes: 12 additions & 14 deletions include/tlib/detail/tags.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ struct omp_taskloop_t {}; // omp_taskloops with single threaded gemm
struct omp_forloop_t {}; // omp_for with single threaded gemm
struct omp_forloop_and_threaded_gemm_t {}; // omp_for with multi-threaded gemm
struct batched_gemm_t {}; // multithreaded batched gemm with collapsed loops
struct depends_t {};
struct combined_t {};

inline constexpr sequential_t sequential;
inline constexpr threaded_gemm_t threaded_gemm;
inline constexpr omp_taskloop_t omp_taskloop;
inline constexpr omp_forloop_t omp_forloop;
inline constexpr batched_gemm_t batched_gemm;
inline constexpr depends_t depends;
inline constexpr combined_t combined;

}

Expand All @@ -43,27 +43,25 @@ namespace tlib::slicing_policy
{
struct slice_t {};
struct subtensor_t {};
struct depends_t {};
struct combined_t {};


inline constexpr depends_t depends;
inline constexpr slice_t slice;
inline constexpr subtensor_t subtensor;
inline constexpr combined_t combined;
inline constexpr slice_t slice;
inline constexpr subtensor_t subtensor;

}


namespace tlib::fusion_policy
{
struct depends_t {};
struct none_t {};
struct outer_t {};
struct all_t {};
struct none_t {};
struct outer_t {};
struct all_t {};

inline constexpr depends_t depends;
inline constexpr none_t none;
inline constexpr outer_t outer;
inline constexpr all_t all;
inline constexpr none_t none;
inline constexpr outer_t outer;
inline constexpr all_t all;
}


14 changes: 7 additions & 7 deletions include/tlib/detail/ttm.h
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,7 @@ inline void ttm(

template<class value_t, class size_t>
inline void ttm(
parallel_policy::threaded_gemm_t, slicing_policy::subtensor_t, fusion_policy::outer_t,
parallel_policy::threaded_gemm_t, slicing_policy::subtensor_t, fusion_policy::all_t,
unsigned const q, unsigned const p,
const value_t *a, size_t const*const na, size_t const*const wa, size_t const*const pia,
const value_t *b, size_t const*const nb, size_t const*const pib,
Expand Down Expand Up @@ -842,7 +842,7 @@ inline void ttm(

template<class value_t, class size_t>
inline void ttm(
parallel_policy::omp_forloop_t, slicing_policy::subtensor_t, fusion_policy::outer_t,
parallel_policy::omp_forloop_t, slicing_policy::subtensor_t, fusion_policy::all_t,
unsigned const q, unsigned const p,
const value_t *a, size_t const*const na, size_t const*const wa, size_t const*const pia,
const value_t *b, size_t const*const nb, size_t const*const pib,
Expand Down Expand Up @@ -906,7 +906,7 @@ inline void ttm(

template<class value_t, class size_t>
inline void ttm(
parallel_policy::omp_forloop_and_threaded_gemm_t, slicing_policy::subtensor_t, fusion_policy::outer_t,
parallel_policy::omp_forloop_and_threaded_gemm_t, slicing_policy::subtensor_t, fusion_policy::all_t,
unsigned const q, unsigned const p,
const value_t *a, size_t const*const na, size_t const*const wa, size_t const*const pia,
const value_t *b, size_t const*const nb, size_t const*const pib,
Expand Down Expand Up @@ -965,7 +965,7 @@ inline void ttm(

template<class value_t, class size_t>
inline void ttm(
parallel_policy::omp_forloop_and_threaded_gemm_t, slicing_policy::subtensor_t, fusion_policy::outer_t,
parallel_policy::omp_forloop_and_threaded_gemm_t, slicing_policy::subtensor_t, fusion_policy::all_t,
unsigned const q, unsigned const p, double ratio,
const value_t *a, size_t const*const na, size_t const*const wa, size_t const*const pia,
const value_t *b, size_t const*const nb, size_t const*const pib,
Expand Down Expand Up @@ -1028,7 +1028,7 @@ inline void ttm(

template<class value_t, class size_t>
inline void ttm(
parallel_policy::batched_gemm_t, slicing_policy::subtensor_t, fusion_policy::outer_t,
parallel_policy::batched_gemm_t, slicing_policy::subtensor_t, fusion_policy::all_t,
unsigned const q, unsigned const p,
const value_t *a, size_t const*const na, size_t const*const wa, size_t const*const pia,
const value_t *b, size_t const*const nb, size_t const*const pib,
Expand Down Expand Up @@ -1115,7 +1115,7 @@ inline void ttm(

template<class value_t, class size_t>
inline void ttm(
parallel_policy::depends_t, slicing_policy::depends_t, fusion_policy::depends_t,
parallel_policy::combined_t, slicing_policy::combined_t, fusion_policy::all_t,
unsigned const q, unsigned const p,
const value_t *a, size_t const*const na, size_t const*const wa, size_t const*const pia,
const value_t *b, size_t const*const nb, size_t const*const pib,
Expand Down Expand Up @@ -1157,7 +1157,7 @@ inline void ttm(
auto const outer = product(na, pia, qh+1,p+1);

if( outer >= cores){
ttm(parallel_policy::omp_forloop, slicing_policy::subtensor, fusion_policy::outer,
ttm(parallel_policy::omp_forloop, slicing_policy::subtensor, fusion_policy::all,
q, p,
a, na, wa, pia,
b, nb, pib,
Expand Down
2 changes: 1 addition & 1 deletion include/tlib/ttm.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,5 +134,5 @@ template<class value_t>
inline auto operator*(tlib::tensor_view<value_t> const& a, tlib::tensor<value_t> const& b)
{
return ttm(a.contraction_mode(), a.get_tensor(), b,
tlib::parallel_policy::depends, tlib::slicing_policy::depends, tlib::fusion_policy::depends) ;
tlib::parallel_policy::combined, tlib::slicing_policy::combined, tlib::fusion_policy::all) ;
}
6 changes: 3 additions & 3 deletions test/src/gtest_tlib_ttm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ TEST(TensorTimesMatrix, OmpForLoopSubtensorOuterFusion)
using size_type = std::size_t;
using execution_policy = tlib::parallel_policy::omp_forloop_t;
using slicing_policy = tlib::slicing_policy::subtensor_t;
using fusion_policy = tlib::fusion_policy::outer_t;
using fusion_policy = tlib::fusion_policy::all_t;

check_tensor_times_matrix<value_type,size_type,execution_policy,slicing_policy,fusion_policy,2u>(2u,3);
check_tensor_times_matrix<value_type,size_type,execution_policy,slicing_policy,fusion_policy,3u>(2u,3);
Expand All @@ -380,7 +380,7 @@ TEST(TensorTimesMatrix, OmpForLoopThreadedGemmSubtensorOuterFusion)
using size_type = std::size_t;
using execution_policy = tlib::parallel_policy::omp_forloop_and_threaded_gemm_t;
using slicing_policy = tlib::slicing_policy::subtensor_t;
using fusion_policy = tlib::fusion_policy::outer_t;
using fusion_policy = tlib::fusion_policy::all_t;

check_tensor_times_matrix<value_type,size_type,execution_policy,slicing_policy,fusion_policy,2u>(2u,3);
check_tensor_times_matrix<value_type,size_type,execution_policy,slicing_policy,fusion_policy,3u>(2u,3);
Expand All @@ -397,7 +397,7 @@ TEST(TensorTimesMatrix, BatchedGemmSubtensorOuterFusion)
using size_type = std::size_t;
using execution_policy = tlib::parallel_policy::batched_gemm_t;
using slicing_policy = tlib::slicing_policy::subtensor_t;
using fusion_policy = tlib::fusion_policy::outer_t;
using fusion_policy = tlib::fusion_policy::all_t;


check_tensor_times_matrix<value_type,size_type,execution_policy,slicing_policy,fusion_policy,2u>(2u,3);
Expand Down
15 changes: 5 additions & 10 deletions ttmpy/src/wrapped_ttm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,11 @@ ttm(std::size_t const contraction_mode,
auto* cptr = static_cast<T*>(cinfo.ptr); // extract data an shape of input array
// auto nnc = std::size_t(cinfo.size);


#if defined(USE_OPENBLAS) || defined(USE_MKL)
tlib::ttm<T>(tlib::parallel_policy::omp_forloop_t{}, tlib::slicing_policy::subtensor_t{}, tlib::fusion_policy::outer_t{},
q, p,
aptr, na.data(), wa.data(), pia.data(),
bptr, nb.data(), pib.data(),
cptr, nc.data(), wc.data());
#else

#endif
tlib::ttm<T>(tlib::parallel_policy::combined, tlib::slicing_policy::combined, tlib::fusion_policy::all,
q, p,
aptr, na.data(), wa.data(), pia.data(),
bptr, nb.data(), pib.data(),
cptr, nc.data(), wc.data());

return c;
}
Expand Down

0 comments on commit 90f3675

Please sign in to comment.