|
| 1 | +/* |
| 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | + * All rights reserved. |
| 4 | + * |
| 5 | + * This source code is licensed under the BSD-style license found in the |
| 6 | + * LICENSE file in the root directory of this source tree. |
| 7 | + */ |
| 8 | + |
| 9 | +#pragma once |
| 10 | + |
| 11 | +#include <ATen/ATen.h> |
| 12 | +#include <iostream> |
| 13 | +#include <type_traits> |
| 14 | +#include "fbgemm_gpu/utils/tensor_accessor_builder.h" |
| 15 | + |
| 16 | +namespace fbgemm_gpu::utils { |
| 17 | + |
// Local shorthand that widens an integral expression to uint64_t. It is only
// meant for use inside this header and is removed again by the matching
// #undef further down.
#define U64(x) (static_cast<uint64_t>(x))
| 19 | + |
| 20 | +//////////////////////////////////////////////////////////////////////////////// |
| 21 | +// Helpers to detect TensorAccessorBuilder type (regardless of template params) |
| 22 | +//////////////////////////////////////////////////////////////////////////////// |
| 23 | + |
// Primary template: by default a type is NOT considered a
// TensorAccessorBuilder, so the trait evaluates to false.
template <typename T>
struct is_tensor_accessor_builder : std::bool_constant<false> {};
| 26 | + |
| 27 | +template < |
| 28 | + typename T, |
| 29 | + size_t N, |
| 30 | + size_t INB, |
| 31 | + bool P, |
| 32 | + template <typename> |
| 33 | + class PT> |
| 34 | +struct is_tensor_accessor_builder<TensorAccessorBuilder<T, N, INB, P, PT>> |
| 35 | + : std::true_type {}; |
| 36 | + |
| 37 | +template <typename T> |
| 38 | +inline constexpr bool is_tensor_accessor_builder_v = |
| 39 | + is_tensor_accessor_builder<T>::value; |
| 40 | + |
| 41 | +//////////////////////////////////////////////////////////////////////////////// |
| 42 | +// Transform Kernel Argument |
| 43 | +// |
| 44 | +// Transform certain arguments before passing them to the kernel invocation |
| 45 | +//////////////////////////////////////////////////////////////////////////////// |
| 46 | + |
| 47 | +template <typename T> |
| 48 | +decltype(auto) transform_kernel_arg(const std::string_view& context, T&& arg) { |
| 49 | + if constexpr (is_tensor_accessor_builder_v<std::decay_t<T>>) { |
| 50 | + // If the arg is a TensorAccessorBuilder, build it out to a tensor accessor. |
| 51 | + // This is the mechanism that allows us to log kernel function names on |
| 52 | + // failed checks and assertions when comopiled with FBGEMM_GPU_MEMCHECK |
| 53 | + // turned ON. |
| 54 | + return arg.build( |
| 55 | +#ifdef FBGEMM_GPU_MEMCHECK |
| 56 | + context.data() |
| 57 | +#endif |
| 58 | + ); |
| 59 | + } else { |
| 60 | + // Otherwise, forward the argument as is |
| 61 | + return std::forward<T>(arg); |
| 62 | + } |
| 63 | +} |
| 64 | + |
| 65 | +//////////////////////////////////////////////////////////////////////////////// |
| 66 | +// Launch the kernel with all the ceremonial routines |
| 67 | +//////////////////////////////////////////////////////////////////////////////// |
| 68 | + |
| 69 | +template <typename KernelFunc, typename... Args> |
| 70 | +inline void launch_kernel( |
| 71 | + const std::string_view& context, |
| 72 | + const KernelFunc& kernel, |
| 73 | + const dim3 grid, |
| 74 | + const dim3 block, |
| 75 | + const size_t Ns, |
| 76 | + cudaStream_t stream, |
| 77 | + Args&&... args) { |
| 78 | +#ifdef USE_ROCM |
| 79 | + // ROCm has a limit of 2^32 elements per kernel launch, but doens't |
| 80 | + // automatically work around problem like CUDA does, see: |
| 81 | + // https://github.com/ROCm/hip/issues/2253 |
| 82 | + uint64_t grid_size = U64(grid.x) * U64(grid.y) * U64(grid.z) * U64(block.x) * |
| 83 | + U64(block.y) * U64(block.z); |
| 84 | + TORCH_CHECK( |
| 85 | + grid_size < U64(0xFFFFFFFF), |
| 86 | + "[ ", |
| 87 | + context, |
| 88 | + " ]: ", |
| 89 | + "Kernel launch grid size ", |
| 90 | + grid_size, |
| 91 | + " is greater than the ROCm limit of 2^32"); |
| 92 | +#endif |
| 93 | + |
| 94 | + kernel<<<grid, block, Ns, stream>>>( |
| 95 | + // Transform arguments to the kernel before forwarding them. |
| 96 | + transform_kernel_arg( |
| 97 | + // Pass the context for debugging |
| 98 | + context, |
| 99 | + std::forward<Args>(args))...); |
| 100 | + |
| 101 | + // Check for CUDA errors |
| 102 | + C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| 103 | + |
| 104 | + return; |
| 105 | +} |
| 106 | + |
| 107 | +#undef U64 |
| 108 | + |
| 109 | +} // namespace fbgemm_gpu::utils |
| 110 | + |
// Launches KERNEL with the given GRID and BLOCK dimensions on the current CUDA
// stream (at::cuda::getCurrentCUDAStream()), passing 0 as the third
// launch-configuration parameter, via fbgemm_gpu::utils::launch_kernel(). The
// stringified KERNEL expression is used as the context label.
//
// At least one kernel argument must be supplied after BLOCK.
//
// The constexpr reference to the kernel exists to produce better compilation
// error messages upon template mismatch.
//
// The expansion is wrapped in its own block so that:
//   - the helper reference does not leak into (or get redefined in) the
//     caller's scope when the macro is used more than once;
//   - the macro behaves as a single statement, e.g. under an unbraced `if`.
// The helper uses a deliberately unusual name so that it cannot shadow a
// caller variable (such as one called `kernel`) appearing in the arguments.
// A trailing semicolon at the call site is optional.
#define FBGEMM_LAUNCH_KERNEL(KERNEL, GRID, BLOCK, ...)              \
  {                                                                 \
    constexpr decltype(KERNEL)& fbgemm_launch_kernel_ref_ = KERNEL; \
    fbgemm_gpu::utils::launch_kernel(                               \
        #KERNEL,                                                    \
        fbgemm_launch_kernel_ref_,                                  \
        GRID,                                                       \
        BLOCK,                                                      \
        0,                                                          \
        at::cuda::getCurrentCUDAStream(),                           \
        __VA_ARGS__);                                               \
  }
0 commit comments