Skip to content

Commit

Permalink
vector type
Browse files Browse the repository at this point in the history
  • Loading branch information
chhwang committed Jan 30, 2025
1 parent 3565bfd commit 63809ab
Show file tree
Hide file tree
Showing 5 changed files with 190 additions and 155 deletions.
58 changes: 58 additions & 0 deletions include/mscclpp/gpu_data_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,62 @@ using __bfloat162 = __nv_bfloat162;

#endif

#include <mscclpp/device.hpp>

namespace mscclpp {

/// Word array.
template <int Bytes>
struct alignas(Bytes) Words {
static_assert(Bytes > 0, "Bytes must be greater than 0");
static_assert(Bytes % 4 == 0, "Bytes must be multiple of 4");
uint32_t w[Bytes / 4];

MSCCLPP_HOST_DEVICE_INLINE Words() {}

MSCCLPP_HOST_DEVICE_INLINE uint32_t& operator[](int i) { return w[i]; }

MSCCLPP_HOST_DEVICE_INLINE const uint32_t& operator[](int i) const { return w[i]; }
};

/// Vector type.
template <typename T, int N>
union alignas(sizeof(T) * N) VectorType {
static_assert(N > 0, "N must be greater than 0");

T data[N];
Words<sizeof(T) * N> words;

using ElementType = T;
constexpr static int Size = N;

MSCCLPP_HOST_DEVICE_INLINE VectorType() {}

MSCCLPP_HOST_DEVICE_INLINE T& operator[](int i) { return data[i]; }

MSCCLPP_HOST_DEVICE_INLINE const T& operator[](int i) const { return data[i]; }
};

using i32x1 = VectorType<int32_t, 1>;
using u32x1 = VectorType<uint32_t, 1>;
using f64x1 = VectorType<double, 1>;
using f32x1 = VectorType<float, 1>;

using i32x2 = VectorType<int32_t, 2>;
using u32x2 = VectorType<uint32_t, 2>;
using f32x2 = VectorType<float, 2>;
using f16x2 = VectorType<__half, 2>;
using bf16x2 = VectorType<__bfloat16, 2>;

using i32x4 = VectorType<int32_t, 4>;
using u32x4 = VectorType<uint32_t, 4>;
using f32x4 = VectorType<float, 4>;
using f16x4 = VectorType<__half, 4>;
using bf16x4 = VectorType<__bfloat16, 4>;

using f16x8 = VectorType<__half, 8>;
using bf16x8 = VectorType<__bfloat16, 8>;

} // namespace mscclpp

#endif // MSCCLPP_GPU_DATA_TYPES_HPP_
117 changes: 66 additions & 51 deletions include/mscclpp/nvls_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,80 +27,95 @@ struct DeviceMulticastPointerDeviceHandle {
size_t bufferSize;

#if defined(MSCCLPP_DEVICE_CUDA)
template <typename TValue, typename T>
MSCCLPP_DEVICE_INLINE static void multimemLoadReduce(TValue& val, T* ptr) {
if constexpr (std::is_same_v<TValue, int32_t> && std::is_same_v<T, int32_t>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.s32 %0, [%1];" : "=r"(val) : "l"(ptr) : "memory");
} else if constexpr (std::is_same_v<TValue, uint32_t> && std::is_same_v<T, uint32_t>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.u32 %0, [%1];" : "=r"(val) : "l"(ptr) : "memory");
} else if constexpr (std::is_same_v<TValue, uint4> && std::is_same_v<T, float>) {
template <typename VectorType>
MSCCLPP_DEVICE_INLINE static VectorType multimemLoadReduce(VectorType* ptr) {
VectorType val;
if constexpr (std::is_same_v<VectorType, i32x1>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.s32 %0, [%1];" : "=r"(val.words[0]) : "l"(ptr) : "memory");
} else if constexpr (std::is_same_v<VectorType, u32x1>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.u32 %0, [%1];" : "=r"(val.words[0]) : "l"(ptr) : "memory");
} else if constexpr (std::is_same_v<VectorType, f32x1>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.f32 %0, [%1];" : "=r"(val.words[0]) : "l"(ptr) : "memory");
} else if constexpr (std::is_same_v<VectorType, f32x2>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.v2.f32 {%0,%1}, [%2];"
: "=r"(val.words[0]), "=r"(val.words[1])
: "l"(ptr)
: "memory");
} else if constexpr (std::is_same_v<VectorType, f32x4>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.v4.f32 {%0,%1,%2,%3}, [%4];"
: "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
: "=r"(val.words[0]), "=r"(val.words[1]), "=r"(val.words[2]), "=r"(val.words[3])
: "l"(ptr)
: "memory");
} else if constexpr (std::is_same_v<TValue, uint2> && std::is_same_v<T, float>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.v2.f32 {%0,%1}, [%2];"
: "=r"(val.x), "=r"(val.y)
} else if constexpr (std::is_same_v<VectorType, f16x2>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.f16x2 %0, [%1];" : "=r"(val.words[0]) : "l"(ptr) : "memory");
} else if constexpr (std::is_same_v<VectorType, f16x4>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.v2.f16x2 {%0,%1}, [%2];"
: "=r"(val.words[0]), "=r"(val.words[1])
: "l"(ptr)
: "memory");
} else if constexpr (std::is_same_v<TValue, uint1> && std::is_same_v<T, float>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.f32 {%0}, [%1];" : "=r"(val.x) : "l"(ptr) : "memory");
} else if constexpr (std::is_same_v<TValue, uint4> && std::is_same_v<T, __half2>) {
} else if constexpr (std::is_same_v<VectorType, f16x8>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.v4.f16x2 {%0,%1,%2,%3}, [%4];"
: "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
: "=r"(val.words[0]), "=r"(val.words[1]), "=r"(val.words[2]), "=r"(val.words[3])
: "l"(ptr)
: "memory");
} else if constexpr (std::is_same_v<TValue, uint2> && std::is_same_v<T, __half2>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.v2.f16x2 {%0,%1}, [%2];"
: "=r"(val.x), "=r"(val.y)
} else if constexpr (std::is_same_v<VectorType, bf16x2>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.bf16x2 %0, [%1];" : "=r"(val.words[0]) : "l"(ptr) : "memory");
} else if constexpr (std::is_same_v<VectorType, bf16x4>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.v2.bf16x2 {%0,%1}, [%2];"
: "=r"(val.words[0]), "=r"(val.words[1])
: "l"(ptr)
: "memory");
} else if constexpr (std::is_same_v<TValue, uint1> && std::is_same_v<T, __half2>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.f16x2 {%0}, [%1];" : "=r"(val.x) : "l"(ptr) : "memory");
} else if constexpr (std::is_same_v<TValue, uint4> && std::is_same_v<T, __bfloat162>) {
} else if constexpr (std::is_same_v<VectorType, bf16x8>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.v4.bf16x2 {%0,%1,%2,%3}, [%4];"
: "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
: "=r"(val.words[0]), "=r"(val.words[1]), "=r"(val.words[2]), "=r"(val.words[3])
: "l"(ptr)
: "memory");
} else if constexpr (std::is_same_v<TValue, uint1> && std::is_same_v<T, __bfloat162>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.bf16x2 {%0}, [%1];" : "=r"(val.x) : "l"(ptr) : "memory");
} else {
static_assert(dependentFalse<T>, "Not supported type");
static_assert(dependentFalse<VectorType>, "Not supported type");
}
return val;
};

template <typename TValue, typename T>
MSCCLPP_DEVICE_INLINE static void multimemStore(const TValue& val, T* ptr) {
if constexpr (std::is_same_v<TValue, int32_t> && std::is_same_v<T, int32_t>) {
asm volatile("multimem.st.relaxed.sys.global.s32 [%0], %1;" ::"l"(ptr), "r"(val) : "memory");
} else if constexpr (std::is_same_v<TValue, uint32_t> && std::is_same_v<T, uint32_t>) {
asm volatile("multimem.st.relaxed.sys.global.u32 [%0], %1;" ::"l"(ptr), "r"(val) : "memory");
} else if constexpr (std::is_same_v<TValue, uint4> && std::is_same_v<T, float>) {
asm volatile("multimem.st.relaxed.sys.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y),
"r"(val.z), "r"(val.w)
template <typename VectorType, typename T>
MSCCLPP_DEVICE_INLINE static void multimemStore(const VectorType& val, T* ptr) {
if constexpr (std::is_same_v<VectorType, i32x1>) {
asm volatile("multimem.st.relaxed.sys.global.s32 [%0], %1;" ::"l"(ptr), "r"(val.words[0]) : "memory");
} else if constexpr (std::is_same_v<VectorType, u32x1>) {
asm volatile("multimem.st.relaxed.sys.global.u32 [%0], %1;" ::"l"(ptr), "r"(val.words[0]) : "memory");
} else if constexpr (std::is_same_v<VectorType, f64x1>) {
asm volatile("multimem.st.relaxed.sys.global.f64 [%0], %1;" ::"l"(ptr), "d"(val.words[0]) : "memory");
} else if constexpr (std::is_same_v<VectorType, f32x1>) {
asm volatile("multimem.st.relaxed.sys.global.f32 [%0], %1;" ::"l"(ptr), "r"(val.words[0]) : "memory");
} else if constexpr (std::is_same_v<VectorType, f32x2>) {
asm volatile("multimem.st.relaxed.sys.global.v2.f32 [%0], {%1,%2};" ::"l"(ptr), "r"(val.words[0]),
"r"(val.words[1])
: "memory");
} else if constexpr (std::is_same_v<TValue, uint2> && std::is_same_v<T, float>) {
asm volatile("multimem.st.relaxed.sys.global.v2.f32 [%0], {%1,%2};" ::"l"(ptr), "r"(val.x), "r"(val.y)
} else if constexpr (std::is_same_v<VectorType, f32x4>) {
asm volatile("multimem.st.relaxed.sys.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.words[0]),
"r"(val.words[1]), "r"(val.words[2]), "r"(val.words[3])
: "memory");
} else if constexpr (std::is_same_v<TValue, uint1> && std::is_same_v<T, float>) {
asm volatile("multimem.st.relaxed.sys.global.f32 [%0], {%1};" ::"l"(ptr), "r"(val.x) : "memory");
} else if constexpr (std::is_same_v<TValue, uint4> && std::is_same_v<T, __half2>) {
asm volatile("multimem.st.relaxed.sys.global.v4.f16x2 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y),
"r"(val.z), "r"(val.w)
} else if constexpr (std::is_same_v<VectorType, f16x2>) {
asm volatile("multimem.st.relaxed.sys.global.f16x2 [%0], %1;" ::"l"(ptr), "r"(val.words[0]) : "memory");
} else if constexpr (std::is_same_v<VectorType, f16x4>) {
asm volatile("multimem.st.relaxed.sys.global.v2.f16x2 [%0], {%1,%2};" ::"l"(ptr), "r"(val.words[0]),
"r"(val.words[1])
: "memory");
} else if constexpr (std::is_same_v<TValue, uint2> && std::is_same_v<T, __half2>) {
asm volatile("multimem.st.relaxed.sys.global.v2.f16x2 [%0], {%1,%2};" ::"l"(ptr), "r"(val.x), "r"(val.y)
} else if constexpr (std::is_same_v<VectorType, f16x8>) {
asm volatile("multimem.st.relaxed.sys.global.v4.f16x2 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.words[0]),
"r"(val.words[1]), "r"(val.words[2]), "r"(val.words[3])
: "memory");
} else if constexpr (std::is_same_v<TValue, uint1> && std::is_same_v<T, __half2>) {
asm volatile("multimem.st.relaxed.sys.global.f16x2 [%0], {%1};" ::"l"(ptr), "r"(val.x) : "memory");
} else if constexpr (std::is_same_v<TValue, uint4> && std::is_same_v<T, __bfloat162>) {
asm volatile("multimem.st.relaxed.sys.global.v4.bf16x2 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y),
"r"(val.z), "r"(val.w)
} else if constexpr (std::is_same_v<VectorType, bf16x2>) {
asm volatile("multimem.st.relaxed.sys.global.bf16x2 [%0], %1;" ::"l"(ptr), "r"(val.words[0]) : "memory");
} else if constexpr (std::is_same_v<VectorType, bf16x4>) {
asm volatile("multimem.st.relaxed.sys.global.v2.bf16x2 [%0], {%1,%2};" ::"l"(ptr), "r"(val.words[0]),
"r"(val.words[1])
: "memory");
} else if constexpr (std::is_same_v<VectorType, bf16x8>) {
asm volatile("multimem.st.relaxed.sys.global.v4.bf16x2 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.words[0]),
"r"(val.words[1]), "r"(val.words[2]), "r"(val.words[3])
: "memory");
} else if constexpr (std::is_same_v<TValue, uint1> && std::is_same_v<T, __bfloat162>) {
asm volatile("multimem.st.relaxed.sys.global.bf16x2 [%0], {%1};" ::"l"(ptr), "r"(val.x) : "memory");
} else {
static_assert(dependentFalse<T>, "Not supported type");
static_assert(dependentFalse<VectorType>, "Not supported type");
}
};

Expand Down
68 changes: 22 additions & 46 deletions python/mscclpp_benchmark/allreduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -804,58 +804,34 @@ __forceinline__ __device__ void barrier(mscclpp::MemoryDevice2DeviceSemaphoreDev
deviceSyncer.sync(num_blocks);
}

// Assumes \p kVecSize is 1, 2, 4, or 8 (default 8)
template <typename DataType = float, int kVecSize = 8>
// Assumes kVecSize is 1, 2, 4, or 8
template <typename DataType, int kVecSize>
MSCCLPP_DEVICE_INLINE void allreduce6_helper(mscclpp::MemoryDevice2DeviceSemaphoreDeviceHandle* semaphores,
mscclpp::DeviceMulticastPointerDeviceHandle nvlsPtrs, int my_rank,
int num_ranks, size_t num_elements) {
DataType* mc_ptr = (DataType*)nvlsPtrs.mcPtr;
int tid = threadIdx.x;
int bid = blockIdx.x;
int num_threads_per_block = blockDim.x;
int num_blocks = gridDim.x;
mscclpp::DeviceMulticastPointerDeviceHandle nvlsPtrs, size_t my_rank,
size_t num_ranks, size_t num_elements) {
using VectorType = mscclpp::VectorType<DataType, kVecSize>;
VectorType* mc_ptr = reinterpret_cast<VectorType*>(nvlsPtrs.mcPtr);
size_t tid = threadIdx.x;
size_t bid = blockIdx.x;
size_t num_threads_per_block = blockDim.x;
size_t num_blocks = gridDim.x;

// start with a barrier to ensure all devices have written their values
// to their own memory (that is part of the multicast memory)
// before reading them in this kernel
barrier(semaphores, tid, bid, num_blocks, num_ranks);

// every device loads, reduces, and stores a partition of the multicast memory
int rank_start = ((int64_t)num_elements * (int64_t)my_rank) / (int64_t)num_ranks;
int rank_end = ((int64_t)num_elements * (int64_t)(my_rank + 1)) / (int64_t)num_ranks;

int thread_offset = (bid * num_threads_per_block + tid) * kVecSize;
int thread_step = (num_threads_per_block * num_blocks) * kVecSize; // number of threads * vector size

for (int idx = rank_start + thread_offset; idx < rank_end; idx += thread_step) {
if constexpr (std::is_same_v<DataType, float> && (kVecSize == 4)) {
uint4 val; // fits 4 float elements
mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, (float*)(mc_ptr + idx));
mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, (float*)(mc_ptr + idx));
} else if constexpr (std::is_same_v<DataType, float> && (kVecSize == 2)) {
uint2 val; // fits 2 float elements
mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, (float*)(mc_ptr + idx));
mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, (float*)(mc_ptr + idx));
} else if constexpr (std::is_same_v<DataType, float> && (kVecSize == 1)) {
uint1 val; // fits 1 float element
mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, (float*)(mc_ptr + idx));
mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, (float*)(mc_ptr + idx));
} else if constexpr (std::is_same_v<DataType, __half> && (kVecSize == 8)) {
uint4 val; // fits 8 cutlass::half_t elements; i.e., 4 half2 elements
mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, (half2*)(mc_ptr + idx));
mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, (half2*)(mc_ptr + idx));
} else if constexpr (std::is_same_v<DataType, __half> && (kVecSize == 4)) {
uint2 val; // fits 4 cutlass::half_t elements; i.e., 2 half2 elements
mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, (half2*)(mc_ptr + idx));
mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, (half2*)(mc_ptr + idx));
} else if constexpr (std::is_same_v<DataType, __half> && (kVecSize == 2)) {
uint1 val; // fits 2 cutlass::half_t elements; i.e., 1 half2 element
mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, (half2*)(mc_ptr + idx));
mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, (half2*)(mc_ptr + idx));
} else {
// not supported: cannot use static_assert because of the way TYPE is handled in this file
assert(false); // Unsupported data type and vector size combination
}
size_t num_vectors = num_elements / VectorType::Size;
size_t rank_start = (num_vectors * my_rank) / num_ranks;
size_t rank_end = (num_vectors * (my_rank + 1)) / num_ranks;

size_t thread_offset = bid * num_threads_per_block + tid;
size_t thread_step = num_threads_per_block * num_blocks; // number of threads * vector size

for (size_t idx = rank_start + thread_offset; idx < rank_end; idx += thread_step) {
VectorType val = mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(mc_ptr + idx);
mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, mc_ptr + idx);
}

// end with a barrier to ensure all devices can now read their values
Expand All @@ -866,8 +842,8 @@ MSCCLPP_DEVICE_INLINE void allreduce6_helper(mscclpp::MemoryDevice2DeviceSemapho

extern "C" __global__ void __launch_bounds__(1024, 1)
allreduce6(mscclpp::MemoryDevice2DeviceSemaphoreDeviceHandle* semaphores,
mscclpp::DeviceMulticastPointerDeviceHandle nvlsPtrs, int my_rank, int num_ranks, size_t num_elements,
size_t vector_size) {
mscclpp::DeviceMulticastPointerDeviceHandle nvlsPtrs, size_t my_rank, size_t num_ranks,
size_t num_elements, size_t vector_size) {
if (vector_size == 8) {
allreduce6_helper<TYPE, 8>(semaphores, nvlsPtrs, my_rank, num_ranks, num_elements);
} else if (vector_size == 4) {
Expand Down
13 changes: 6 additions & 7 deletions python/test/nvls_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
mscclpp::MemoryDevice2DeviceSemaphoreDeviceHandle* semaphores, int my_rank, int nranks, int nbytes) {
int nelem = nbytes / sizeof(float);
float* dev_ptr = (float*)nvlsPtrs.devicePtr;
float* mc_ptr = (float*)nvlsPtrs.mcPtr;
mscclpp::f32x4* mc_ptr = (mscclpp::f32x4*)nvlsPtrs.mcPtr;
int tid = threadIdx.x;
int bid = blockIdx.x;

Expand All @@ -33,15 +33,14 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
}
deviceSyncer.sync(gridDim.x);

int my_st = ((int64_t)nelem * (int64_t)my_rank) / (int64_t)nranks;
int my_en = ((int64_t)nelem * (int64_t)(my_rank + 1)) / (int64_t)nranks;
int my_st = ((int64_t)nelem / 4 * (int64_t)my_rank) / (int64_t)nranks;
int my_en = ((int64_t)nelem / 4 * (int64_t)(my_rank + 1)) / (int64_t)nranks;

int my_offset = (tid + bid * blockDim.x) * 4;
int my_step = blockDim.x * gridDim.x * 4;
int my_offset = (tid + bid * blockDim.x);
int my_step = blockDim.x * gridDim.x;

for (int idx = my_st + my_offset; idx < my_en; idx += my_step) {
uint4 val;
mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, mc_ptr + idx);
mscclpp::f32x4 val = mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(mc_ptr + idx);
mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, mc_ptr + idx);
}

Expand Down
Loading

0 comments on commit 63809ab

Please sign in to comment.