Skip to content

Commit 2b40dad

Browse files
q10 authored and
facebook-github-bot committed
Better kernel launch utilities
Summary: - Add utilities for doing multiple checks prior to launching kernels Differential Revision: D72095960
1 parent 47635cf commit 2b40dad

File tree

5 files changed

+722
-106
lines changed

5 files changed

+722
-106
lines changed
+122
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <ATen/ATen.h>
12+
#include <iostream>
13+
#include <type_traits>
14+
#include "fbgemm_gpu/utils/tensor_accessor_builder.h"
15+
16+
namespace fbgemm_gpu::utils {
17+
18+
#define U64(x) static_cast<uint64_t>(x)
19+
20+
////////////////////////////////////////////////////////////////////////////////
21+
// Helpers to detect TensorAccessorBuilder type (regardless of template params)
22+
////////////////////////////////////////////////////////////////////////////////
23+
24+
template <typename>
25+
struct is_tensor_accessor_builder : std::false_type {};
26+
27+
template <
28+
typename T,
29+
size_t N,
30+
size_t INB,
31+
bool P,
32+
template <typename>
33+
class PT>
34+
struct is_tensor_accessor_builder<TensorAccessorBuilder<T, N, INB, P, PT>>
35+
: std::true_type {};
36+
37+
template <typename T>
38+
inline constexpr bool is_tensor_accessor_builder_v =
39+
is_tensor_accessor_builder<T>::value;
40+
41+
////////////////////////////////////////////////////////////////////////////////
42+
// Transform Kernel Argument
43+
//
44+
// Transform certain arguments before passing them to the kernel invocation
45+
////////////////////////////////////////////////////////////////////////////////
46+
47+
template <typename T>
48+
decltype(auto) transform_kernel_arg(const std::string_view& context, T&& arg) {
49+
if constexpr (is_tensor_accessor_builder_v<std::decay_t<T>>) {
50+
// If the arg is a TensorAccessorBuilder, build it out to a tensor accessor.
51+
// This is the mechanism that allows us to log kernel function names on
52+
// failed checks and assertions when comopiled with FBGEMM_GPU_MEMCHECK
53+
// turned ON.
54+
return arg.build(
55+
#ifdef FBGEMM_GPU_MEMCHECK
56+
context.data()
57+
#endif
58+
);
59+
} else {
60+
// Otherwise, forward the argument as is
61+
return std::forward<T>(arg);
62+
}
63+
}
64+
65+
////////////////////////////////////////////////////////////////////////////////
66+
// Launch the kernel with all the ceremonial routines
67+
////////////////////////////////////////////////////////////////////////////////
68+
69+
template <typename KernelFunc, typename... Args>
70+
inline void launch_kernel(
71+
const std::string_view& context,
72+
const KernelFunc& kernel,
73+
const dim3 grid,
74+
const dim3 block,
75+
const size_t Ns,
76+
cudaStream_t stream,
77+
Args&&... args) {
78+
#ifdef USE_ROCM
79+
// ROCm has a limit of 2^32 elements per kernel launch, but doens't
80+
// automatically work around problem like CUDA does, see:
81+
// https://github.com/ROCm/hip/issues/2253
82+
uint64_t grid_size = U64(grid.x) * U64(grid.y) * U64(grid.z) * U64(block.x) *
83+
U64(block.y) * U64(block.z);
84+
TORCH_CHECK(
85+
grid_size < U64(0xFFFFFFFF),
86+
"[ ",
87+
context,
88+
" ]: ",
89+
"Kernel launch grid size ",
90+
grid_size,
91+
" is greater than the ROCm limit of 2^32");
92+
#endif
93+
94+
kernel<<<grid, block, Ns, stream>>>(
95+
// Transform arguments to the kernel before forwarding them.
96+
transform_kernel_arg(
97+
// Pass the context for debugging
98+
context,
99+
std::forward<Args>(args))...);
100+
101+
// Check for CUDA errors
102+
C10_CUDA_KERNEL_LAUNCH_CHECK();
103+
104+
return;
105+
}
106+
107+
#undef U64
108+
109+
} // namespace fbgemm_gpu::utils
110+
111+
// Launch KERNEL through fbgemm_gpu::utils::launch_kernel().
//
//   KERNEL  - the kernel to launch; its stringified spelling is passed as the
//             launch context.
//   GRID    - forwarded as launch_kernel's `grid` argument.
//   BLOCK   - forwarded as launch_kernel's `block` argument.
//   ...     - the kernel arguments.
//
// The macro always passes 0 for launch_kernel's `Ns` parameter and
// at::cuda::getCurrentCUDAStream() for its `stream` parameter.
//
// The constexpr reference to the kernel is added to enable better compilation
// error messages upon template mismatch.
//
// The expansion is wrapped in its own brace scope, and the reference is given
// a name unlikely to clash with caller code, so that the macro can be invoked
// more than once in the same scope without redeclaring the reference and
// without shadowing a caller variable used in GRID, BLOCK or the kernel
// arguments. Call sites may be written with or without a trailing semicolon.
#define FBGEMM_LAUNCH_KERNEL(KERNEL, GRID, BLOCK, ...)              \
  {                                                                 \
    constexpr decltype(KERNEL)& fbgemm_launch_kernel_ref_ = KERNEL; \
    fbgemm_gpu::utils::launch_kernel(                               \
        #KERNEL,                                                    \
        fbgemm_launch_kernel_ref_,                                  \
        GRID,                                                       \
        BLOCK,                                                      \
        0,                                                          \
        at::cuda::getCurrentCUDAStream(),                           \
        __VA_ARGS__);                                               \
  }

0 commit comments

Comments
 (0)