From 5566afaaeebc9553b67c137224b1d5c3556ff0ac Mon Sep 17 00:00:00 2001
From: vllmellm
Date: Tue, 4 Feb 2025 10:54:44 +0000
Subject: [PATCH 01/13] add float8_e4m3fnuz ck a8w8 gemm

---
 csrc/ck_gemm_a8w8/gemm_a8w8.cu                |  9 ++--
 csrc/ck_gemm_a8w8/gemm_a8w8_tune.cu           |  6 ++-
 .../ck_gemm_a8w8/include/gemm_a8w8_common.cuh | 53 +++++++++++++++++--
 3 files changed, 57 insertions(+), 11 deletions(-)

diff --git a/csrc/ck_gemm_a8w8/gemm_a8w8.cu b/csrc/ck_gemm_a8w8/gemm_a8w8.cu
index 00fcf0c3..e56ddb79 100644
--- a/csrc/ck_gemm_a8w8/gemm_a8w8.cu
+++ b/csrc/ck_gemm_a8w8/gemm_a8w8.cu
@@ -144,6 +144,7 @@ RowwiseKernel rowwise_dispatch(int M, int N, int K)
   return rowwise_heuristic_dispatch(M, N, K);
 }
 
+
 torch::Tensor gemm_a8w8(
     torch::Tensor &XQ,
     torch::Tensor &WQ,
@@ -153,10 +154,10 @@ torch::Tensor gemm_a8w8(
     std::optional<torch::Tensor> bias,
     int splitK)
 {
-  TORCH_CHECK(XQ.dtype() == at::ScalarType::Char && XQ.dtype() == WQ.dtype(),
-              "Weights and activations should both be int8!");
-  TORCH_CHECK(x_scale.dtype() == w_scale.dtype(),
-              "Scales should have the same dtype!");
+  const auto is_int8 = XQ.dtype() == at::ScalarType::Char && XQ.dtype() == WQ.dtype();
+  const auto is_fp8 = XQ.dtype() == at::ScalarType::Float8_e4m3fnuz && XQ.dtype() == WQ.dtype();
+  TORCH_CHECK(is_int8 || is_fp8, "Weights and activations should both be int8 or fp8!");
+  TORCH_CHECK(x_scale.dtype() == w_scale.dtype(), "Scales should have the same dtype!");
   if (bias != std::nullopt)
     TORCH_CHECK(bias.value().dtype() == Y.dtype(), "Out and bias should have the same dtype!");
 
diff --git a/csrc/ck_gemm_a8w8/gemm_a8w8_tune.cu b/csrc/ck_gemm_a8w8/gemm_a8w8_tune.cu
index de5d6fd7..6944afcc 100644
--- a/csrc/ck_gemm_a8w8/gemm_a8w8_tune.cu
+++ b/csrc/ck_gemm_a8w8/gemm_a8w8_tune.cu
@@ -67,8 +67,10 @@ torch::Tensor gemm_a8w8_tune(
     int kernelId,
     int splitK)
 {
-  TORCH_CHECK(XQ.dtype() == at::ScalarType::Char && XQ.dtype() == WQ.dtype(),
-              "Weights and activations should both be int8!");
+  const auto is_int8 = XQ.dtype() == at::ScalarType::Char && XQ.dtype() == WQ.dtype();
+  const auto is_fp8 = XQ.dtype() == at::ScalarType::Float8_e4m3fnuz && XQ.dtype() == WQ.dtype();
+  TORCH_CHECK(is_int8 || is_fp8, "Weights and activations should both be int8 or fp8!");
+
   TORCH_CHECK( x_scale.dtype() == w_scale.dtype(),
               "Scales should have the same dtype!");
 
   std::optional<torch::Tensor> bias = std::nullopt;
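A note on the new dtype gate above: float8_e4m3fnuz is the "fnuz" flavour of e4m3 (finite-only, with a single NaN encoding), as used on AMD Instinct MI300-series GPUs. A minimal Python-side sketch of operand pairs that pass the is_int8 / is_fp8 check; the shapes are hypothetical and not taken from this patch:

    import torch

    # int8 path: both operands must be int8.
    XQ_i8 = torch.randint(-20, 20, (32, 64), dtype=torch.int8, device="cuda")
    WQ_i8 = torch.randint(-20, 20, (128, 64), dtype=torch.int8, device="cuda")

    # fp8 path: both operands must be float8_e4m3fnuz.
    XQ_f8 = torch.rand(32, 64, device="cuda").to(torch.float8_e4m3fnuz)
    WQ_f8 = torch.rand(128, 64, device="cuda").to(torch.float8_e4m3fnuz)

    # Mixing the two (e.g. XQ_i8 with WQ_f8) fails both checks and raises.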
diff --git a/csrc/ck_gemm_a8w8/include/gemm_a8w8_common.cuh b/csrc/ck_gemm_a8w8/include/gemm_a8w8_common.cuh
index 5cf99d82..1e94dfa8 100644
--- a/csrc/ck_gemm_a8w8/include/gemm_a8w8_common.cuh
+++ b/csrc/ck_gemm_a8w8/include/gemm_a8w8_common.cuh
@@ -45,11 +45,11 @@ using F32 = float;
 using Row = ck::tensor_layout::gemm::RowMajor;
 using Col = ck::tensor_layout::gemm::ColumnMajor;
 
-using ADataType = I8;
-using BDataType = I8;
-using AccDataType = I32;
-using CShuffleDataType = I32;
-using ComputeDataType = I8;
+using ADataType = FP8;
+using BDataType = FP8;
+using AccDataType = F32;
+using CShuffleDataType = F32;
+using ComputeDataType = FP8;
 
 using ALayout = Row;
 using BLayout = Col;
@@ -148,6 +148,49 @@ struct MultiplyMultiplyAdd
     e = ck::type_convert<F16>(x0_f);
   }
 
+  template <>
+  __host__ __device__ constexpr void operator()(F16 &e,
+                                                const float& c,
+                                                const float& d0,
+                                                const float& d1,
+                                                const F16 &d2) const
+  {
+    const float x0_f = c * d0 * d1 + ck::type_convert<float>(d2);
+
+    e = ck::type_convert<F16>(x0_f);
+  }
+  template <>
+  __host__ __device__ constexpr void operator()(F16 &e,
+                                                const float& c,
+                                                const F16& d0,
+                                                const F16& d1,
+                                                const F16 &d2) const
+  {
+    const float x0_f = c * ck::type_convert<float>(d0) * ck::type_convert<float>(d1) +
+                       ck::type_convert<float>(d2);
+    e = ck::type_convert<F16>(x0_f);
+  }
+
+  template <>
+  __host__ __device__ constexpr void operator()(B16 &e,
+                                                const float& c,
+                                                const B16& d0,
+                                                const B16& d1,
+                                                const B16 &d2) const
+  {
+    const float x0_f = c * ck::type_convert<float>(d0) * ck::type_convert<float>(d1) +
+                       ck::type_convert<float>(d2);
+    e = ck::type_convert<B16>(x0_f);
+  }
+  template <>
+  __host__ __device__ constexpr void operator()(B16 &e,
+                                                const float& c,
+                                                const float& d0,
+                                                const float& d1,
+                                                const B16 &d2) const
+  {
+    const float x0_f = c * d0 * d1 + ck::type_convert<float>(d2);
+    e = ck::type_convert<B16>(x0_f);
+  }
+
   template <>
   __host__ __device__ constexpr void operator()(
       ck::bhalf_t &e, const int &c, const ck::bhalf_t &d0, const ck::bhalf_t &d1, const ck::bhalf_t &d2) const
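Every MultiplyMultiplyAdd specialization added above computes the same epilogue, e = c * d0 * d1 + d2: accumulator times per-row activation scale times per-column weight scale, plus bias; only the element types differ. A plain PyTorch reference for the whole rowwise-scaled GEMM, as a sketch for intuition rather than code from this repo:

    import torch

    def gemm_a8w8_reference(XQ, WQ, x_scale, w_scale, bias=None):
        # Widen before accumulating (the int8 kernels accumulate in int32,
        # the fp8 kernels in float32; float32 is a safe common reference).
        c = XQ.to(torch.float32) @ WQ.to(torch.float32).T
        # Per-element epilogue: c * d0 * d1 (+ d2).
        y = c * x_scale.to(torch.float32) * w_scale.to(torch.float32)
        if bias is not None:
            y = y + bias.to(torch.float32)
        return y.to(torch.bfloat16)

With x_scale shaped [M, 1] and w_scale shaped [1, N] (as in gemm_a8w8_tune.py below), broadcasting applies the right scale to each output element.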
From 1071d9b89c5a5382dd3e48872ffe08ea04045424 Mon Sep 17 00:00:00 2001
From: vllmellm
Date: Tue, 11 Feb 2025 12:33:08 +0000
Subject: [PATCH 02/13] templated input dtype for gemm_a8w8 (support int8 and
 fp8)

---
 csrc/ck_gemm_a8w8/gemm_a8w8.cu                | 216 ++++++--
 csrc/ck_gemm_a8w8/gemm_a8w8_common.py         | 197 +++----
 csrc/ck_gemm_a8w8/gemm_a8w8_tune.cu           | 110 ++--
 csrc/ck_gemm_a8w8/gemm_a8w8_tune.py           |  37 +-
 csrc/ck_gemm_a8w8/gen_instances.py            | 513 ++++++++++--------
 .../ck_gemm_a8w8/include/gemm_a8w8_common.cuh |  61 ++-
 6 files changed, 697 insertions(+), 437 deletions(-)

diff --git a/csrc/ck_gemm_a8w8/gemm_a8w8.cu b/csrc/ck_gemm_a8w8/gemm_a8w8.cu
index e56ddb79..e48b90a8 100644
--- a/csrc/ck_gemm_a8w8/gemm_a8w8.cu
+++ b/csrc/ck_gemm_a8w8/gemm_a8w8.cu
@@ -31,34 +31,82 @@ using RowwiseKernelMap = std::unordered_map<
     std::tuple<int, int, int>,
     RowwiseKernel,
     IntTupleHash>;
 
-template <typename DDataType, typename EDataType>
+template <
+    typename ADataType,
+    typename BDataType,
+    typename AccDataType,
+    typename CShuffleDataType,
+    typename ComputeDataType,
+    typename DDataType,
+    typename EDataType
+>
 RowwiseKernel rowwise_heuristic_dispatch(int M, int N, int K)
 {
   // Apply shape heuristics to find a suitable kernel implementation.
   if (M < 64 && N < 2048 && K < 2048)
   {
     // Kernel that generally works well on small shapes.
-    return a8w8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2<DDataType, EDataType>;
+    return a8w8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2<
+        ADataType,
+        BDataType,
+        AccDataType,
+        CShuffleDataType,
+        ComputeDataType,
+        DDataType,
+        EDataType
+    >;
   }
   else if (M < 64 && K < 2048)
   {
     // Kernel that works well for small batch size and small K.
-    return a8w8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2<DDataType, EDataType>;
+    return a8w8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2<
+        ADataType,
+        BDataType,
+        AccDataType,
+        CShuffleDataType,
+        ComputeDataType,
+        DDataType,
+        EDataType
+    >;
   }
   else if (M < 64 && N < 2048)
   {
     // Kernel that works well for small batch size and small N.
-    return a8w8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2<DDataType, EDataType>;
+    return a8w8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2<
+        ADataType,
+        BDataType,
+        AccDataType,
+        CShuffleDataType,
+        ComputeDataType,
+        DDataType,
+        EDataType
+    >;
   }
   else if (M < 64 && N > 2048 && K > 2048)
   {
     // Kernel that works well for small M but larger N and K.
-    return a8w8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x16x1x4_4x4x1_1x1_intrawave_v1<DDataType, EDataType>;
+    return a8w8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x16x1x4_4x4x1_1x1_intrawave_v1<
+        ADataType,
+        BDataType,
+        AccDataType,
+        CShuffleDataType,
+        ComputeDataType,
+        DDataType,
+        EDataType
+    >;
   }
   else if (M < 64)
   {
     // Fallback to generic small batch kernel if we can't find a good match.
-    return a8w8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2<DDataType, EDataType>;
+    return a8w8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2<
+        ADataType,
+        BDataType,
+        AccDataType,
+        CShuffleDataType,
+        ComputeDataType,
+        DDataType,
+        EDataType
+    >;
   /* } else if (((M < 512 && K < 8192) || (N <= 2048 && K <= 8192) || (K <= 2048 && N <= 8192)) && K >= 1024) {
     // Kernel that is optimized for larger batch sizes but otherwise small
     // tensors.
@@ -67,23 +115,54 @@
   else if (K < 1024)
   {
     // Special case for small K.
-    return a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1<DDataType, EDataType>;
+    return a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1<
+        ADataType,
+        BDataType,
+        AccDataType,
+        CShuffleDataType,
+        ComputeDataType,
+        DDataType,
+        EDataType
+    >;
   }
   else if (M < 1024)
   {
     // Kernel for generic medium batch sizes.
-    return a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<DDataType, EDataType>;
+    return a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<
+        ADataType,
+        BDataType,
+        AccDataType,
+        CShuffleDataType,
+        ComputeDataType,
+        DDataType,
+        EDataType
+    >;
   }
   else if (M >= 1024 && N >= 1024 && K >= 1024)
   {
     // Kernel for very large gemm
     // return a8w8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3;
-    return a8w8_rowwise_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1<DDataType, EDataType>;
+    return a8w8_rowwise_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1<
+        ADataType,
+        BDataType,
+        AccDataType,
+        CShuffleDataType,
+        ComputeDataType,
+        DDataType,
+        EDataType
+    >;
   }
   else
   {
     // Fallback large kernel.
-    return a8w8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3<DDataType, EDataType>;
+    return a8w8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3<
+        ADataType,
+        BDataType,
+        AccDataType,
+        CShuffleDataType,
+        ComputeDataType,
+        DDataType,
+        EDataType
+    >;
   }
 }
 
@@ -95,23 +174,51 @@ static constexpr int nextPow2(unsigned int num)
 {
   return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
 }
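Since rowwise_dispatch (below) buckets small M values up to the next power of two before consulting the tuned lookup table, here is the same rounding in Python for illustration; the num <= 1 guard is an assumption added here, while the C++ helper above relies on the caller passing num > 1:

    def next_pow2(num: int) -> int:
        # 24 -> 32, 33 -> 64: a tuned entry for {32, N, K} also serves M in (16, 32].
        return 1 if num <= 1 else 1 << (num - 1).bit_length()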
-template <typename DDataType, typename EDataType>
+template <
+    typename ADataType,
+    typename BDataType,
+    typename AccDataType,
+    typename CShuffleDataType,
+    typename ComputeDataType,
+    typename DDataType,
+    typename EDataType
+>
 RowwiseKernel rowwise_dispatch(int M, int N, int K)
 {
   // For a given shape, either find the best kernel via lookup or heuristic.
   // For many small M shapes, we bucket them to the next largest kernel.
   // This is fine since kernels are padded anyway.
   static const auto lookup = [] {
     if constexpr (std::is_same_v<EDataType, F16>) {
-      return RowwiseKernelMap{GENERATE_LOOKUP_TABLE(DDataType,F16)};
+      return RowwiseKernelMap{GENERATE_LOOKUP_TABLE(
+          ADataType,
+          BDataType,
+          AccDataType,
+          CShuffleDataType,
+          ComputeDataType,
+          DDataType,
+          F16
+        )
+      };
     } else if constexpr (std::is_same_v<EDataType, B16>) {
-      return RowwiseKernelMap{GENERATE_LOOKUP_TABLE(DDataType,B16)};
-    } else {
+      return RowwiseKernelMap{GENERATE_LOOKUP_TABLE(
+          ADataType,
+          BDataType,
+          AccDataType,
+          CShuffleDataType,
+          ComputeDataType,
+          DDataType,
+          B16
+        )
+      };
+    }
+    else {
       static_assert(false, "rowwise_dispatch used with unsupported dtype!");
-    } }();
-
+    }
+  }();
+
   // First check if this shape(M,N,K) is available in the direct lookup.
   auto it = lookup.find({M, N, K});
   // If we found an optimal kernel, use it.
@@ -141,10 +248,17 @@ RowwiseKernel rowwise_dispatch(int M, int N, int K)
     return it->second;
   }
   // Otherwise, use heuristics.
-  return rowwise_heuristic_dispatch<DDataType, EDataType>(M, N, K);
+  return rowwise_heuristic_dispatch<
+      ADataType,
+      BDataType,
+      AccDataType,
+      CShuffleDataType,
+      ComputeDataType,
+      DDataType,
+      EDataType
+  >(M, N, K);
 }
 
-
 torch::Tensor gemm_a8w8(
     torch::Tensor &XQ,
     torch::Tensor &WQ,
     torch::Tensor &x_scale,
     torch::Tensor &w_scale,
     torch::Tensor &Y,
@@ -166,26 +280,50 @@ torch::Tensor gemm_a8w8(
   int N = WQ.size(0);
   int K = XQ.size(1);
   int KBatch = std::pow(2, splitK);
-
-  if (x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::Half)
-  {
-    rowwise_dispatch<F32, F16>(M, N, K)(XQ, WQ, x_scale, w_scale, Y, bias, KBatch);
-  }
-  else if (x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::BFloat16)
-  {
-    rowwise_dispatch<F32, B16>(M, N, K)(XQ, WQ, x_scale, w_scale, Y, bias, KBatch);
-  }
-  else if (Y.dtype() == at::ScalarType::Half)
-  {
-    rowwise_dispatch<F16, F16>(M, N, K)(XQ, WQ, x_scale, w_scale, Y, bias, KBatch);
-  }
-  else if (Y.dtype() == at::ScalarType::BFloat16)
-  {
-    rowwise_dispatch<B16, B16>(M, N, K)(XQ, WQ, x_scale, w_scale, Y, bias, KBatch);
-  }
-  else
-  {
-    TORCH_CHECK(false, "Unsupported scales/output dtype!");
+  // TODO: simplify
+  if(is_int8){
+    if (x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::Half)
+    {
+      rowwise_dispatch<I8, I8, I32, I32, I8, F32, F16>(M, N, K)(XQ, WQ, x_scale, w_scale, Y, bias, KBatch);
+    }
+    else if (x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::BFloat16)
+    {
+      rowwise_dispatch<I8, I8, I32, I32, I8, F32, B16>(M, N, K)(XQ, WQ, x_scale, w_scale, Y, bias, KBatch);
+    }
+    else if (Y.dtype() == at::ScalarType::Half)
+    {
+      rowwise_dispatch<I8, I8, I32, I32, I8, F16, F16>(M, N, K)(XQ, WQ, x_scale, w_scale, Y, bias, KBatch);
+    }
+    else if (Y.dtype() == at::ScalarType::BFloat16)
+    {
+      rowwise_dispatch<I8, I8, I32, I32, I8, B16, B16>(M, N, K)(XQ, WQ, x_scale, w_scale, Y, bias, KBatch);
+    }
+    else
+    {
+      TORCH_CHECK(false, "Unsupported scales/output dtype!");
+    }
+  } else {
+    if (x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::Half)
+    {
+      rowwise_dispatch<FP8, FP8, F32, F32, FP8, F32, F16>(M, N, K)(XQ, WQ, x_scale, w_scale, Y, bias, KBatch);
+    }
+    else if (x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::BFloat16)
+    {
+      rowwise_dispatch<FP8, FP8, F32, F32, FP8, F32, B16>(M, N, K)(XQ, WQ, x_scale, w_scale, Y, bias, KBatch);
+    }
+    else if (Y.dtype() == at::ScalarType::Half)
+    {
+      rowwise_dispatch<FP8, FP8, F32, F32, FP8, F16, F16>(M, N, K)(XQ, WQ, x_scale, w_scale, Y, bias, KBatch);
+    }
+    else if (Y.dtype() == at::ScalarType::BFloat16)
+    {
+      rowwise_dispatch<FP8, FP8, F32, F32, FP8, B16, B16>(M, N, K)(XQ, WQ, x_scale, w_scale, Y, bias, KBatch);
+    }
+    else
+    {
+      TORCH_CHECK(false, "Unsupported scales/output dtype!");
+    }
   }
+
   return Y;
 }
 
diff --git a/csrc/ck_gemm_a8w8/gemm_a8w8_common.py b/csrc/ck_gemm_a8w8/gemm_a8w8_common.py
index 279e21fd..4c484caf 100644
--- 
a/csrc/ck_gemm_a8w8/gemm_a8w8_common.py +++ b/csrc/ck_gemm_a8w8/gemm_a8w8_common.py @@ -3,7 +3,7 @@ from dataclasses import dataclass @dataclass -class kernelInstance: +class KernelParameters: BLOCK_SIZE: int MPerBLOCK: int NPerBLOCK: int @@ -38,115 +38,116 @@ def name(self) -> str: ("x").join(map(lambda x: str(x), [ self.CSHUFFLE_MX_PER_WAVE_PERSHUFFLE, self.CSHUFFLE_NX_PER_WAVE_PERSHUFFLE])), self.LOOP_SCHED.lower(), - f"v{self.PIPELINE_VERSION}" + f"v{self.PIPELINE_VERSION}", + ]) -kernels_list = { +kernels_params_dict = { # id: kernel: BLOCK_SIZE| MPerBLOCK| NPerBLOCK| KPerBLOCK| WAVE_TILE_M| WAVE_TILE_N| WAVE_MAP_M| WAVE_MAP_N| ABLOCK_TRANSFER| BBLOCK_TRANSFER| CBLOCK_TRANSFER| CBLOCK_SPV| CSHUFFLE_MX| CSHUFFLE_NX| LOOP_SCHED| PIPELINE_VERSION - 0: kernelInstance( 256, 256, 256, 64, 32, 32, 4, 4, [4, 64, 1], [4, 64, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 4), - 1: kernelInstance( 256, 256, 256, 128, 32, 32, 4, 4, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 2: kernelInstance( 256, 256, 224, 128, 32, 32, 2, 7, [8, 32, 1], [8, 32, 1], [1, 64, 1, 4], [8, 8, 1], 2, 1, "Intrawave", 3), - 3: kernelInstance( 256, 256, 192, 128, 32, 32, 4, 3, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 4: kernelInstance( 256, 256, 160, 128, 32, 32, 2, 5, [8, 32, 1], [8, 32, 1], [1, 64, 1, 4], [8, 8, 1], 2, 1, "Intrawave", 3), - 5: kernelInstance( 256, 256, 128, 128, 32, 32, 4, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 6: kernelInstance( 256, 256, 96, 128, 32, 32, 2, 3, [8, 32, 1], [8, 32, 1], [1, 64, 1, 4], [8, 8, 1], 2, 1, "Intrawave", 3), - 7: kernelInstance( 256, 256, 64, 128, 32, 32, 4, 1, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 8: kernelInstance( 256, 128, 256, 128, 32, 32, 2, 4, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 9: kernelInstance( 256, 128, 224, 128, 32, 32, 1, 7, [8, 32, 1], [8, 32, 1], [1, 64, 1, 4], [8, 8, 1], 1, 1, "Intrawave", 3), - 10: kernelInstance( 256, 128, 192, 128, 32, 32, 2, 3, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 11: kernelInstance( 256, 128, 160, 128, 32, 32, 1, 5, [8, 32, 1], [8, 32, 1], [1, 64, 1, 4], [8, 8, 1], 1, 1, "Intrawave", 3), - 12: kernelInstance( 256, 128, 128, 256, 32, 32, 2, 2, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 13: kernelInstance( 256, 128, 128, 128, 32, 32, 2, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 14: kernelInstance( 256, 128, 96, 256, 32, 32, 1, 3, [16, 16, 1], [16, 16, 1], [1, 64, 1, 4], [8, 8, 1], 1, 1, "Intrawave", 3), - 15: kernelInstance( 256, 128, 64, 256, 32, 32, 2, 1, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 16: kernelInstance( 256, 64, 256, 128, 32, 32, 1, 4, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 17: kernelInstance( 256, 64, 224, 128, 16, 16, 2, 7, [8, 32, 1], [8, 32, 1], [1, 64, 1, 4], [8, 8, 1], 2, 1, "Intrawave", 3), - 18: kernelInstance( 256, 64, 192, 256, 32, 32, 1, 3, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 19: kernelInstance( 256, 64, 192, 128, 32, 32, 1, 3, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 20: kernelInstance( 256, 64, 160, 256, 16, 16, 2, 5, [16, 16, 1], [16, 16, 1], [1, 64, 1, 4], [8, 8, 1], 2, 1, "Intrawave", 3), - 21: kernelInstance( 256, 64, 128, 256, 32, 32, 1, 2, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], 
[8, 8, 1], 1, 1, "Intrawave", 3), - 22: kernelInstance( 256, 64, 96, 256, 16, 16, 2, 3, [16, 16, 1], [16, 16, 1], [1, 64, 1, 4], [8, 8, 1], 2, 1, "Intrawave", 3), - 23: kernelInstance( 256, 64, 64, 512, 32, 32, 1, 1, [32, 8, 1], [32, 8, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 24: kernelInstance( 256, 32, 256, 128, 32, 32, 1, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 25: kernelInstance( 256, 32, 224, 256, 16, 16, 1, 7, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 3), - 26: kernelInstance( 256, 32, 192, 256, 16, 16, 1, 6, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [8, 8, 1], 1, 2, "Intrawave", 3), - 27: kernelInstance( 256, 32, 160, 256, 16, 16, 1, 5, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 3), - 28: kernelInstance( 256, 32, 128, 256, 32, 32, 1, 1, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 29: kernelInstance( 256, 32, 96, 256, 16, 16, 1, 3, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 3), - 30: kernelInstance( 256, 32, 64, 512, 16, 16, 1, 2, [32, 8, 1], [32, 8, 1], [1, 32, 1, 8], [8, 8, 1], 1, 2, "Intrawave", 3), - 31: kernelInstance( 256, 16, 256, 128, 16, 16, 1, 4, [16, 16, 1], [8, 32, 1], [1, 16, 1, 16], [8, 8, 1], 1, 2, "Intrawave", 3), - 32: kernelInstance( 256, 16, 192, 256, 16, 16, 1, 3, [16, 16, 1], [16, 16, 1], [1, 16, 1, 16], [4, 4, 1], 1, 1, "Intrawave", 3), - 33: kernelInstance( 256, 16, 128, 256, 16, 16, 1, 2, [16, 16, 1], [16, 16, 1], [1, 16, 1, 16], [8, 8, 1], 1, 2, "Intrawave", 3), - 34: kernelInstance( 256, 16, 64, 512, 16, 16, 1, 1, [32, 8, 1], [32, 8, 1], [1, 16, 1, 16], [4, 4, 1], 1, 1, "Intrawave", 3), - 35: kernelInstance( 256, 128, 128, 128, 32, 32, 2, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 4), - 36: kernelInstance( 256, 128, 128, 64, 32, 32, 2, 2, [4, 64, 1], [4, 64, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 4), - 37: kernelInstance( 256, 256, 256, 128, 16, 16, 8, 8, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 2, "Intrawave", 3), - 38: kernelInstance( 256, 256, 256, 64, 16, 16, 8, 8, [4, 64, 1], [4, 64, 1], [1, 32, 1, 8], [8, 8, 1], 1, 2, "Intrawave", 3), - 39: kernelInstance( 256, 224, 256, 128, 16, 16, 7, 8, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 2, "Intrawave", 3), - 40: kernelInstance( 256, 256, 224, 128, 16, 16, 8, 7, [8, 32, 1], [8, 32, 1], [1, 64, 1, 4], [8, 8, 1], 2, 1, "Intrawave", 3), - 41: kernelInstance( 256, 128, 128, 128, 32, 32, 2, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 5), - 42: kernelInstance( 256, 128, 256, 64, 32, 32, 2, 4, [4, 64, 1], [4, 64, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Interwave", 1), - 43: kernelInstance( 256, 256, 128, 64, 32, 32, 4, 2, [4, 64, 1], [4, 64, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Interwave", 1), - 44: kernelInstance( 256, 128, 128, 128, 32, 32, 2, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Interwave", 1), - 45: kernelInstance( 256, 128, 64, 128, 32, 32, 2, 1, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 46: kernelInstance( 256, 64, 128, 128, 32, 32, 1, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 47: kernelInstance( 256, 64, 64, 128, 32, 32, 1, 1, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 0: KernelParameters( 256, 256, 256, 64, 32, 32, 4, 4, [4, 64, 1], [4, 64, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 4), + 1: KernelParameters( 256, 256, 256, 128, 
32, 32, 4, 4, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 2: KernelParameters( 256, 256, 224, 128, 32, 32, 2, 7, [8, 32, 1], [8, 32, 1], [1, 64, 1, 4], [8, 8, 1], 2, 1, "Intrawave", 3), + 3: KernelParameters( 256, 256, 192, 128, 32, 32, 4, 3, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 4: KernelParameters( 256, 256, 160, 128, 32, 32, 2, 5, [8, 32, 1], [8, 32, 1], [1, 64, 1, 4], [8, 8, 1], 2, 1, "Intrawave", 3), + 5: KernelParameters( 256, 256, 128, 128, 32, 32, 4, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 6: KernelParameters( 256, 256, 96, 128, 32, 32, 2, 3, [8, 32, 1], [8, 32, 1], [1, 64, 1, 4], [8, 8, 1], 2, 1, "Intrawave", 3), + 7: KernelParameters( 256, 256, 64, 128, 32, 32, 4, 1, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 8: KernelParameters( 256, 128, 256, 128, 32, 32, 2, 4, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 9: KernelParameters( 256, 128, 224, 128, 32, 32, 1, 7, [8, 32, 1], [8, 32, 1], [1, 64, 1, 4], [8, 8, 1], 1, 1, "Intrawave", 3), + 10: KernelParameters( 256, 128, 192, 128, 32, 32, 2, 3, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 11: KernelParameters( 256, 128, 160, 128, 32, 32, 1, 5, [8, 32, 1], [8, 32, 1], [1, 64, 1, 4], [8, 8, 1], 1, 1, "Intrawave", 3), + 12: KernelParameters( 256, 128, 128, 256, 32, 32, 2, 2, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 13: KernelParameters( 256, 128, 128, 128, 32, 32, 2, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 14: KernelParameters( 256, 128, 96, 256, 32, 32, 1, 3, [16, 16, 1], [16, 16, 1], [1, 64, 1, 4], [8, 8, 1], 1, 1, "Intrawave", 3), + 15: KernelParameters( 256, 128, 64, 256, 32, 32, 2, 1, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 16: KernelParameters( 256, 64, 256, 128, 32, 32, 1, 4, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 17: KernelParameters( 256, 64, 224, 128, 16, 16, 2, 7, [8, 32, 1], [8, 32, 1], [1, 64, 1, 4], [8, 8, 1], 2, 1, "Intrawave", 3), + 18: KernelParameters( 256, 64, 192, 256, 32, 32, 1, 3, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 19: KernelParameters( 256, 64, 192, 128, 32, 32, 1, 3, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 20: KernelParameters( 256, 64, 160, 256, 16, 16, 2, 5, [16, 16, 1], [16, 16, 1], [1, 64, 1, 4], [8, 8, 1], 2, 1, "Intrawave", 3), + 21: KernelParameters( 256, 64, 128, 256, 32, 32, 1, 2, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 22: KernelParameters( 256, 64, 96, 256, 16, 16, 2, 3, [16, 16, 1], [16, 16, 1], [1, 64, 1, 4], [8, 8, 1], 2, 1, "Intrawave", 3), + 23: KernelParameters( 256, 64, 64, 512, 32, 32, 1, 1, [32, 8, 1], [32, 8, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 24: KernelParameters( 256, 32, 256, 128, 32, 32, 1, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 25: KernelParameters( 256, 32, 224, 256, 16, 16, 1, 7, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 3), + 26: KernelParameters( 256, 32, 192, 256, 16, 16, 1, 6, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [8, 8, 1], 1, 2, "Intrawave", 3), + 27: KernelParameters( 256, 32, 160, 256, 16, 16, 1, 5, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 3), + 28: KernelParameters( 256, 32, 128, 256, 32, 32, 1, 1, [16, 16, 
1], [16, 16, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 29: KernelParameters( 256, 32, 96, 256, 16, 16, 1, 3, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 3), + 30: KernelParameters( 256, 32, 64, 512, 16, 16, 1, 2, [32, 8, 1], [32, 8, 1], [1, 32, 1, 8], [8, 8, 1], 1, 2, "Intrawave", 3), + 31: KernelParameters( 256, 16, 256, 128, 16, 16, 1, 4, [16, 16, 1], [8, 32, 1], [1, 16, 1, 16], [8, 8, 1], 1, 2, "Intrawave", 3), + 32: KernelParameters( 256, 16, 192, 256, 16, 16, 1, 3, [16, 16, 1], [16, 16, 1], [1, 16, 1, 16], [4, 4, 1], 1, 1, "Intrawave", 3), + 33: KernelParameters( 256, 16, 128, 256, 16, 16, 1, 2, [16, 16, 1], [16, 16, 1], [1, 16, 1, 16], [8, 8, 1], 1, 2, "Intrawave", 3), + 34: KernelParameters( 256, 16, 64, 512, 16, 16, 1, 1, [32, 8, 1], [32, 8, 1], [1, 16, 1, 16], [4, 4, 1], 1, 1, "Intrawave", 3), + 35: KernelParameters( 256, 128, 128, 128, 32, 32, 2, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 4), + 36: KernelParameters( 256, 128, 128, 64, 32, 32, 2, 2, [4, 64, 1], [4, 64, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 4), + 37: KernelParameters( 256, 256, 256, 128, 16, 16, 8, 8, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 2, "Intrawave", 3), + 38: KernelParameters( 256, 256, 256, 64, 16, 16, 8, 8, [4, 64, 1], [4, 64, 1], [1, 32, 1, 8], [8, 8, 1], 1, 2, "Intrawave", 3), + 39: KernelParameters( 256, 224, 256, 128, 16, 16, 7, 8, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 2, "Intrawave", 3), + 40: KernelParameters( 256, 256, 224, 128, 16, 16, 8, 7, [8, 32, 1], [8, 32, 1], [1, 64, 1, 4], [8, 8, 1], 2, 1, "Intrawave", 3), + 41: KernelParameters( 256, 128, 128, 128, 32, 32, 2, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 5), + 42: KernelParameters( 256, 128, 256, 64, 32, 32, 2, 4, [4, 64, 1], [4, 64, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Interwave", 1), + 43: KernelParameters( 256, 256, 128, 64, 32, 32, 4, 2, [4, 64, 1], [4, 64, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Interwave", 1), + 44: KernelParameters( 256, 128, 128, 128, 32, 32, 2, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Interwave", 1), + 45: KernelParameters( 256, 128, 64, 128, 32, 32, 2, 1, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 46: KernelParameters( 256, 64, 128, 128, 32, 32, 1, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 47: KernelParameters( 256, 64, 64, 128, 32, 32, 1, 1, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), # mem(Intrawave): Latency friendly - 48: kernelInstance( 128, 32, 16, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Intrawave", 1), - 49: kernelInstance( 64, 16, 16, 128, 16, 16, 1, 1, [8, 8, 1], [8, 8, 1], [1, 16, 1, 4], [4, 4, 1], 1, 1, "Intrawave", 1), - 50: kernelInstance( 128, 16, 32, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 1), + 48: KernelParameters( 128, 32, 16, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Intrawave", 1), + 49: KernelParameters( 64, 16, 16, 128, 16, 16, 1, 1, [8, 8, 1], [8, 8, 1], [1, 16, 1, 4], [4, 4, 1], 1, 1, "Intrawave", 1), + 50: KernelParameters( 128, 16, 32, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 1), # mem(Intrawave): Memory friendly, Col - 51: kernelInstance( 256, 256, 32, 128, 32, 32, 2, 1, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 2), - 52: kernelInstance( 256, 256, 16, 128, 
16, 16, 4, 1, [8, 32, 1], [8, 16, 1], [1, 32, 1, 8], [2, 2, 1], 1, 1, "Intrawave", 2), - 53: kernelInstance( 128, 128, 32, 128, 32, 32, 2, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 2), - 54: kernelInstance( 128, 128, 16, 128, 16, 16, 4, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Intrawave", 2), - 55: kernelInstance( 128, 64, 32, 128, 32, 32, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 2), - 56: kernelInstance( 128, 64, 16, 128, 16, 16, 2, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Intrawave", 2), - 57: kernelInstance( 128, 32, 16, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Intrawave", 2), - 58: kernelInstance( 64, 16, 16, 64, 16, 16, 1, 1, [4, 16, 1], [4, 16, 1], [1, 16, 1, 4], [4, 4, 1], 1, 1, "Intrawave", 2), - 59: kernelInstance( 64, 16, 16, 128, 16, 16, 1, 1, [8, 8, 1], [8, 8, 1], [1, 16, 1, 4], [4, 4, 1], 1, 1, "Intrawave", 2), - 60: kernelInstance( 128, 16, 32, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 2), - 61: kernelInstance( 128, 16, 64, 128, 16, 16, 1, 2, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 2), - 62: kernelInstance( 128, 32, 64, 128, 32, 32, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 2), - 63: kernelInstance( 128, 16, 128, 128, 16, 16, 1, 4, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Interwave", 2), - 64: kernelInstance( 128, 32, 128, 128, 32, 32, 1, 2, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [8, 8, 1], 1, 1, "Interwave", 2), - 65: kernelInstance( 256, 16, 256, 128, 16, 16, 1, 4, [8, 16, 1], [8, 16, 1], [1, 16, 1, 16], [4, 4, 1], 1, 1, "Interwave", 2), - 66: kernelInstance( 256, 32, 256, 128, 32, 32, 1, 2, [8, 32, 1], [8, 32, 1], [1, 16, 1, 16], [8, 8, 1], 1, 1, "Intrawave", 2), + 51: KernelParameters( 256, 256, 32, 128, 32, 32, 2, 1, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 2), + 52: KernelParameters( 256, 256, 16, 128, 16, 16, 4, 1, [8, 32, 1], [8, 16, 1], [1, 32, 1, 8], [2, 2, 1], 1, 1, "Intrawave", 2), + 53: KernelParameters( 128, 128, 32, 128, 32, 32, 2, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 2), + 54: KernelParameters( 128, 128, 16, 128, 16, 16, 4, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Intrawave", 2), + 55: KernelParameters( 128, 64, 32, 128, 32, 32, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 2), + 56: KernelParameters( 128, 64, 16, 128, 16, 16, 2, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Intrawave", 2), + 57: KernelParameters( 128, 32, 16, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Intrawave", 2), + 58: KernelParameters( 64, 16, 16, 64, 16, 16, 1, 1, [4, 16, 1], [4, 16, 1], [1, 16, 1, 4], [4, 4, 1], 1, 1, "Intrawave", 2), + 59: KernelParameters( 64, 16, 16, 128, 16, 16, 1, 1, [8, 8, 1], [8, 8, 1], [1, 16, 1, 4], [4, 4, 1], 1, 1, "Intrawave", 2), + 60: KernelParameters( 128, 16, 32, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 2), + 61: KernelParameters( 128, 16, 64, 128, 16, 16, 1, 2, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 2), + 62: KernelParameters( 128, 32, 64, 128, 32, 32, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 2), + 63: KernelParameters( 128, 16, 128, 128, 16, 16, 1, 4, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Interwave", 2), + 64: 
KernelParameters( 128, 32, 128, 128, 32, 32, 1, 2, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [8, 8, 1], 1, 1, "Interwave", 2), + 65: KernelParameters( 256, 16, 256, 128, 16, 16, 1, 4, [8, 16, 1], [8, 16, 1], [1, 16, 1, 16], [4, 4, 1], 1, 1, "Interwave", 2), + 66: KernelParameters( 256, 32, 256, 128, 32, 32, 1, 2, [8, 32, 1], [8, 32, 1], [1, 16, 1, 16], [8, 8, 1], 1, 1, "Intrawave", 2), # mem(Interwave): Latency friendly - 67: kernelInstance( 128, 32, 16, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Interwave", 1), - 68: kernelInstance( 64, 16, 16, 128, 16, 16, 1, 1, [8, 8, 1], [8, 8, 1], [1, 16, 1, 4], [4, 4, 1], 1, 1, "Interwave", 1), - 69: kernelInstance( 128, 16, 32, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Interwave", 1), + 67: KernelParameters( 128, 32, 16, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Interwave", 1), + 68: KernelParameters( 64, 16, 16, 128, 16, 16, 1, 1, [8, 8, 1], [8, 8, 1], [1, 16, 1, 4], [4, 4, 1], 1, 1, "Interwave", 1), + 69: KernelParameters( 128, 16, 32, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Interwave", 1), # mem(Interwave): Memory friendly, Col - 70: kernelInstance( 256, 256, 32, 128, 32, 32, 2, 1, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [4, 4, 1], 1, 1, "Interwave", 2), - 71: kernelInstance( 256, 256, 16, 128, 16, 16, 4, 1, [8, 32, 1], [8, 16, 1], [1, 32, 1, 8], [2, 2, 1], 1, 1, "Interwave", 2), - 72: kernelInstance( 128, 128, 32, 128, 32, 32, 2, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Interwave", 2), - 73: kernelInstance( 128, 128, 16, 128, 16, 16, 4, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Interwave", 2), - 74: kernelInstance( 128, 64, 32, 128, 32, 32, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Interwave", 2), - 75: kernelInstance( 128, 64, 16, 128, 16, 16, 2, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Interwave", 2), - 76: kernelInstance( 128, 32, 16, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Interwave", 2), - 77: kernelInstance( 64, 16, 16, 64, 16, 16, 1, 1, [4, 16, 1], [4, 16, 1], [1, 16, 1, 4], [4, 4, 1], 1, 1, "Interwave", 2), - 78: kernelInstance( 64, 16, 16, 128, 16, 16, 1, 1, [8, 8, 1], [8, 8, 1], [1, 16, 1, 4], [4, 4, 1], 1, 1, "Interwave", 2), - 79: kernelInstance( 128, 16, 32, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Interwave", 2), - 80: kernelInstance( 128, 16, 64, 128, 16, 16, 1, 2, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Interwave", 2), - 81: kernelInstance( 128, 32, 64, 128, 32, 32, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [8, 8, 1], 1, 1, "Interwave", 2), - 82: kernelInstance( 128, 16, 128, 128, 16, 16, 1, 4, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Interwave", 2), - 83: kernelInstance( 128, 32, 128, 128, 32, 32, 1, 2, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [8, 8, 1], 1, 1, "Interwave", 2), - 84: kernelInstance( 256, 16, 256, 128, 16, 16, 1, 4, [8, 16, 1], [8, 16, 1], [1, 16, 1, 16], [4, 4, 1], 1, 1, "Interwave", 2), - 85: kernelInstance( 256, 32, 256, 128, 32, 32, 1, 2, [8, 32, 1], [8, 32, 1], [1, 16, 1, 16], [8, 8, 1], 1, 1, "Interwave", 2), + 70: KernelParameters( 256, 256, 32, 128, 32, 32, 2, 1, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [4, 4, 1], 1, 1, "Interwave", 2), + 71: KernelParameters( 256, 256, 16, 128, 16, 16, 4, 1, [8, 32, 1], [8, 16, 1], [1, 32, 1, 8], [2, 2, 1], 1, 1, "Interwave", 2), + 72: KernelParameters( 128, 128, 
32, 128, 32, 32, 2, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Interwave", 2), + 73: KernelParameters( 128, 128, 16, 128, 16, 16, 4, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Interwave", 2), + 74: KernelParameters( 128, 64, 32, 128, 32, 32, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Interwave", 2), + 75: KernelParameters( 128, 64, 16, 128, 16, 16, 2, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Interwave", 2), + 76: KernelParameters( 128, 32, 16, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Interwave", 2), + 77: KernelParameters( 64, 16, 16, 64, 16, 16, 1, 1, [4, 16, 1], [4, 16, 1], [1, 16, 1, 4], [4, 4, 1], 1, 1, "Interwave", 2), + 78: KernelParameters( 64, 16, 16, 128, 16, 16, 1, 1, [8, 8, 1], [8, 8, 1], [1, 16, 1, 4], [4, 4, 1], 1, 1, "Interwave", 2), + 79: KernelParameters( 128, 16, 32, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Interwave", 2), + 80: KernelParameters( 128, 16, 64, 128, 16, 16, 1, 2, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Interwave", 2), + 81: KernelParameters( 128, 32, 64, 128, 32, 32, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [8, 8, 1], 1, 1, "Interwave", 2), + 82: KernelParameters( 128, 16, 128, 128, 16, 16, 1, 4, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Interwave", 2), + 83: KernelParameters( 128, 32, 128, 128, 32, 32, 1, 2, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [8, 8, 1], 1, 1, "Interwave", 2), + 84: KernelParameters( 256, 16, 256, 128, 16, 16, 1, 4, [8, 16, 1], [8, 16, 1], [1, 16, 1, 16], [4, 4, 1], 1, 1, "Interwave", 2), + 85: KernelParameters( 256, 32, 256, 128, 32, 32, 1, 2, [8, 32, 1], [8, 32, 1], [1, 16, 1, 16], [8, 8, 1], 1, 1, "Interwave", 2), } default_kernels_dict = { # ( M, N, K): kernel: BLOCK_SIZE| MPerBLOCK| NPerBLOCK| KPerBLOCK| WAVE_TILE_M| WAVE_TILE_N| WAVE_MAP_M| WAVE_MAP_N| ABLOCK_TRANSFER| BBLOCK_TRANSFER| CBLOCK_TRANSFER| CBLOCK_SPV| CSHUFFLE_MX| CSHUFFLE_NX| LOOP_SCHED|PIPELINE_VERSION - (-1): kernelInstance( 64, 16, 16, 128, 16, 16, 1, 1, [8, 8, 1], [8, 8, 1], [1, 16, 1, 4], [4, 4, 1], 1, 1, "Interwave", 2), - (-3): kernelInstance( 128, 32, 16, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Interwave", 2), - (-4): kernelInstance( 64, 16, 16, 256, 16, 16, 1, 1, [16, 4, 1], [16, 4, 1], [1, 16, 1, 4], [4, 4, 1], 1, 1, "Intrawave", 1), - (-5): kernelInstance( 128, 16, 32, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Intrawave", 2), - (-6): kernelInstance( 256, 128, 128, 128, 32, 32, 2, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Interwave", 1), - (-7): kernelInstance( 256, 128, 128, 128, 32, 32, 2, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - (-8): kernelInstance( 256, 256, 128, 64, 32, 32, 4, 2, [4, 64, 1], [4, 64, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Interwave", 1), - (-9): kernelInstance( 256, 224, 256, 128, 16, 16, 7, 8, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 2, "Intrawave", 3), - (-10): kernelInstance( 128, 16, 32, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 2), + (-1): KernelParameters( 64, 16, 16, 128, 16, 16, 1, 1, [8, 8, 1], [8, 8, 1], [1, 16, 1, 4], [4, 4, 1], 1, 1, "Interwave", 2), + (-3): KernelParameters( 128, 32, 16, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Interwave", 2), + (-4): KernelParameters( 64, 16, 16, 256, 16, 16, 1, 1, [16, 4, 1], [16, 4, 1], [1, 16, 1, 
4], [4, 4, 1], 1, 1, "Intrawave", 1),
+    (-5): KernelParameters( 128, 16, 32, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Intrawave", 2),
+    (-6): KernelParameters( 256, 128, 128, 128, 32, 32, 2, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Interwave", 1),
+    (-7): KernelParameters( 256, 128, 128, 128, 32, 32, 2, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3),
+    (-8): KernelParameters( 256, 256, 128, 64, 32, 32, 4, 2, [4, 64, 1], [4, 64, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Interwave", 1),
+    (-9): KernelParameters( 256, 224, 256, 128, 16, 16, 7, 8, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 2, "Intrawave", 3),
+    (-10): KernelParameters( 128, 16, 32, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 2),
 }
\ No newline at end of file
diff --git a/csrc/ck_gemm_a8w8/gemm_a8w8_tune.cu b/csrc/ck_gemm_a8w8/gemm_a8w8_tune.cu
index 6944afcc..ca92ce76 100644
--- a/csrc/ck_gemm_a8w8/gemm_a8w8_tune.cu
+++ b/csrc/ck_gemm_a8w8/gemm_a8w8_tune.cu
@@ -26,9 +26,16 @@ static constexpr int nextPow2(unsigned int num)
   return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
 }
 
-template <typename DDataType, typename EDataType>
-RowwiseKernel rowwise_dispatch(int id)
-{
+template <
+    typename ADataType,
+    typename BDataType,
+    typename AccDataType,
+    typename CShuffleDataType,
+    typename ComputeDataType,
+    typename DDataType,
+    typename EDataType
+>
+RowwiseKernel rowwise_dispatch(int id){
   // For a given shape, either find the best kernel via lookup or heuristic.
   // For many small M shapes, we bucket them to the next largest kernel.
   // This is fine since kernels are padded anyway.
@@ -37,15 +44,35 @@ RowwiseKernel rowwise_dispatch(int id)
   static const auto lookup = [] {
     if constexpr (std::is_same_v<EDataType, F16>) {
-      return RowwiseKernelMap{GENERATE_LOOKUP_TABLE(DDataType,F16)};
+      return RowwiseKernelMap{
+        GENERATE_LOOKUP_TABLE(
+          ADataType,
+          BDataType,
+          AccDataType,
+          CShuffleDataType,
+          ComputeDataType,
+          DDataType,
+          F16
+        )
+      };
     } else if constexpr (std::is_same_v<EDataType, B16>) {
-      return RowwiseKernelMap{GENERATE_LOOKUP_TABLE(DDataType,B16)};
+      return RowwiseKernelMap{
+        GENERATE_LOOKUP_TABLE(
+          ADataType,
+          BDataType,
+          AccDataType,
+          CShuffleDataType,
+          ComputeDataType,
+          DDataType,
+          B16
+        )
+      };
     } else {
-      static_assert(false, "rowwise_dispatch used with unsupported dtype!");
-    } }();
+      static_assert(false, "rowwise_dispatch used with unsupported dtype!");
+    }
+  }();
 
-  TORCH_CHECK(id < lookup.size(),
-              "Kernel id " + std::to_string(id) +" is out of range!");
+  TORCH_CHECK(id < lookup.size(), "Kernel id " + std::to_string(id) +" is out of range!");
   auto it = lookup.find(id);
   // If we found an optimal kernel, use it.
   if (it != lookup.end())
@@ -57,6 +84,33 @@ RowwiseKernel rowwise_dispatch(int id)
   {
     return it->second;
   }
 }
 
+torch::Tensor gemm_a8w8_int8(
+    torch::Tensor &XQ,
+    torch::Tensor &WQ,
+    torch::Tensor &x_scale,
+    torch::Tensor &w_scale,
+    torch::Tensor &Y,
+    int kernelId,
+    int KBatch)
+{
+  std::optional<torch::Tensor> bias = std::nullopt;
+  rowwise_dispatch<I8, I8, I32, I32, I8, B16, B16>(kernelId)(XQ, WQ, x_scale, w_scale, Y, bias, KBatch);
+  return Y;
+}
+
+torch::Tensor gemm_a8w8_fp8(
+    torch::Tensor &XQ,
+    torch::Tensor &WQ,
+    torch::Tensor &x_scale,
+    torch::Tensor &w_scale,
+    torch::Tensor &Y,
+    int kernelId,
+    int KBatch)
+{
+  std::optional<torch::Tensor> bias = std::nullopt;
+  rowwise_dispatch<FP8, FP8, F32, F32, FP8, B16, B16>(kernelId)(XQ, WQ, x_scale, w_scale, Y, bias, KBatch);
+  return Y;
+}
 
 torch::Tensor gemm_a8w8_tune(
     torch::Tensor &XQ,
@@ -65,14 +119,13 @@ torch::Tensor gemm_a8w8_tune(
     torch::Tensor &w_scale,
     torch::Tensor &Y,
     int kernelId,
-    int splitK)
+    int splitK
+    )
 {
   const auto is_int8 = XQ.dtype() == at::ScalarType::Char && XQ.dtype() == WQ.dtype();
   const auto is_fp8 = XQ.dtype() == at::ScalarType::Float8_e4m3fnuz && XQ.dtype() == WQ.dtype();
-  TORCH_CHECK(is_int8 || is_fp8, "Weights and activations should both be int8 or fp8!");
-
-  TORCH_CHECK( x_scale.dtype() == w_scale.dtype(),
-    "Scales should have the same dtype!");
+  TORCH_CHECK(is_int8 || is_fp8, "Weights and activations should both be either int8 or fp8!");
+  TORCH_CHECK( x_scale.dtype() == w_scale.dtype(), "Scales should have the same dtype!");
 
   std::optional<torch::Tensor> bias = std::nullopt;
   int M = XQ.size(0);
@@ -80,26 +133,11 @@ torch::Tensor gemm_a8w8_tune(
   int K = XQ.size(1);
   int KBatch = std::pow(2, splitK);
 
-  // if (x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::Half)
-  // {
-  //   rowwise_dispatch<F32, F16>(kernelId)(XQ, WQ, x_scale, w_scale, Y, bias);
-  // }
-  // else if (x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::BFloat16)
-  // {
-  //   rowwise_dispatch<F32, B16>(kernelId)(XQ, WQ, x_scale, w_scale, Y, bias);
-  // }
-  // else if (Y.dtype() == at::ScalarType::Half)
-  // {
-  //   rowwise_dispatch<F16, F16>(kernelId)(XQ, WQ, x_scale, w_scale, Y, bias);
-  // }
-  // else
-  if (Y.dtype() == at::ScalarType::BFloat16)
-  {
-    rowwise_dispatch<B16, B16>(kernelId)(XQ, WQ, x_scale, w_scale, Y, bias, KBatch);
-  }
-  else
-  {
+  if(Y.dtype() != at::ScalarType::BFloat16)
     TORCH_CHECK(false, "Unsupported scales/output dtype!");
-  }
-  return Y;
-}
+
+  if(is_fp8)
+    return gemm_a8w8_fp8(XQ, WQ, x_scale, w_scale, Y, kernelId, KBatch);
+
+  return gemm_a8w8_int8(XQ, WQ, x_scale, w_scale, Y, kernelId, KBatch);
+}
\ No newline at end of file
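One contract worth spelling out before the Python tuner below: splitK travels as an exponent, and the C++ side expands it into KBatch = 2**splitK partial GEMMs along K. A small sketch of the sweep the tuner performs; the shape and the 128x128x64 block sizes are hypothetical, while compute_gemm_SplitK is the helper gemm_a8w8_tune.py already uses:

    import aiter

    m, n, k = 128, 4096, 8192  # hypothetical GEMM shape
    # Upper bound on useful K-splits for a kernel with 128x128x64 tiles.
    max_split_k = aiter.compute_gemm_SplitK(m, n, k, 128, 128, 64)
    for split_k in range(max_split_k + 1):
        k_batch = 2 ** split_k  # mirrors `int KBatch = std::pow(2, splitK);`
        print(f"splitK={split_k} -> {k_batch} partial GEMMs along K")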
diff --git a/csrc/ck_gemm_a8w8/gemm_a8w8_tune.py b/csrc/ck_gemm_a8w8/gemm_a8w8_tune.py
index 0d7993c4..6d805520 100644
--- a/csrc/ck_gemm_a8w8/gemm_a8w8_tune.py
+++ b/csrc/ck_gemm_a8w8/gemm_a8w8_tune.py
@@ -8,7 +8,7 @@
 import torch.nn.functional as F
 import aiter
 from aiter.test_common import checkAllclose, perftest
-from gemm_a8w8_common import kernelInstance, kernels_list
+from gemm_a8w8_common import kernels_params_dict
 import argparse
 
 def checkClose(a, b, rtol=1e-3, atol=0.01):
@@ -48,11 +48,17 @@ def kernel_instance_test(x, weight, x_scale, w_scale, out, kernel_id, splitK=0):
     aiter.gemm_a8w8_tune(x, weight, x_scale, w_scale, out, kernel_id, splitK)
     return out
 
+def random_tensor(a: int, b: int, dtype: torch.dtype) -> torch.Tensor:
+    if dtype == torch.int8:
+        return torch.randint(-20, 20, (a, b), dtype=dtype, device="cuda")
+    elif dtype == torch.float8_e4m3fnuz:
+        return torch.rand((a, b), device="cuda").to(dtype)
+    raise RuntimeError("Unsupported data type.")
 
-def tune_gemm(m, n, k, 
useSplitK = False): +def tune_gemm(m, n, k, dtype: torch.dtype, useSplitK = False): dim = (m, n, k) - x = torch.randint(-20, 20, (m, k), dtype=torch.int8, device="cuda") - weight = torch.randint(-20, 20, (n, k), dtype=torch.int8, device="cuda") + x = random_tensor(m, k, dtype) + weight = random_tensor(n, k, dtype) x_scale = torch.rand([m, 1], dtype=torch.bfloat16, device="cuda") w_scale = torch.rand([1, n], dtype=torch.bfloat16, device="cuda") out = torch.empty(m, n, dtype=torch.bfloat16, device="cuda") @@ -61,11 +67,11 @@ def tune_gemm(m, n, k, useSplitK = False): print(f"*******************M:{m} X N:{n} X K:{k}**************************") print(f"Start tuning a8w8 gemm kernel for M:{m}, N:{n}, K{k}:") - kernels_num = len(kernels_list) + kernels_num = len(kernels_params_dict) best_kernelConfig = (-1, 0) best_time = -1 for i in range(kernels_num): - kernel = kernels_list[i] + kernel = kernels_params_dict[i] maxsplitK = aiter.compute_gemm_SplitK(m, n, k, kernel.MPerBLOCK, kernel.NPerBLOCK, kernel.KPerBLOCK) \ if useSplitK else 0 for splitK in range(maxsplitK+1): @@ -80,6 +86,7 @@ def tune_gemm(m, n, k, useSplitK = False): else: print(f"{str(dim):<20} kernelid:{i:<3d}\t No pass , {kernel.name}, {splitK=}") except RuntimeError as e: + print(str(e)) print(f"{str(dim):<20} kernelid:{i:<3d}\t No support , {kernel.name}, {splitK=}") best_kernelId, splitK = best_kernelConfig @@ -95,15 +102,16 @@ def tune_gemm(m, n, k, useSplitK = False): return best_kernelId, splitK, best_time -def tune_gemm_list(untunedf, tunedf, issorted = False, useSplitK = False): +def tune_gemm_list(untunedf, tunedf, dtype: torch.dtype, issorted = False, useSplitK = False): + print("untuned df is \n\n", untunedf) for i in range(len(untunedf)): M = untunedf.loc[i, "M"] N = untunedf.loc[i, "N"] K = untunedf.loc[i, "K"] if tunedf[(tunedf["M"]==M) & (tunedf["N"]==N) & (tunedf["K"]==K)].empty: - kernelId, splitK, time = tune_gemm(M, N, K, useSplitK) - kernelName = 'None' if kernelId == -1 else kernels_list[kernelId].name + kernelId, splitK, time = tune_gemm(M, N, K, dtype, useSplitK) + kernelName = 'None' if kernelId == -1 else kernels_params_dict[kernelId].name temp = pd.DataFrame({"M":[M], "N":[N], "K":[K], "kernelId":[kernelId], "splitK":[splitK], "us":[time], "kernelName":[kernelName]}) tunedf = pd.concat([tunedf, temp], ignore_index=True) @@ -150,6 +158,14 @@ def tune_gemm_list(untunedf, tunedf, issorted = False, useSplitK = False): help="Use splitK kernels" ) + parser.add_argument( + "-d", + "--dtype", + required=False, + default="int8", + help="int8 or fp8" + ) + parser.add_argument( "--sort", action='store_true', @@ -158,7 +174,8 @@ def tune_gemm_list(untunedf, tunedf, issorted = False, useSplitK = False): ) args = parser.parse_args() + dtype = torch.float8_e4m3fnuz if args.dtype == "fp8" else torch.int8 untunedf = get_untuned_gemm_list(args.untune_file) tunedf = get_tuned_gemm_list(args.tune_file) - tunedf = tune_gemm_list(untunedf, tunedf, args.sort, args.splitK) + tunedf = tune_gemm_list(untunedf, tunedf, dtype, args.sort, args.splitK) tunedf.to_csv(args.tune_file, index=False) diff --git a/csrc/ck_gemm_a8w8/gen_instances.py b/csrc/ck_gemm_a8w8/gen_instances.py index 347e8b23..ddbd3207 100644 --- a/csrc/ck_gemm_a8w8/gen_instances.py +++ b/csrc/ck_gemm_a8w8/gen_instances.py @@ -1,31 +1,200 @@ # SPDX-License-Identifier: MIT # Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
import os -import sys -from dataclasses import dataclass -import copy from pathlib import Path import pandas as pd import argparse -import shutil -from gemm_a8w8_common import kernelInstance, kernels_list, default_kernels_dict +from gemm_a8w8_common import KernelParameters, kernels_params_dict, default_kernels_dict + +TEMPLATE_PARAMS = "template_params" +INSTANCE_FILE_SUFFIX = "instance_file_suffix" +LOOK_UP_TABLE_HEADER_PATH = "gemm_a8w8_lookup.h" +MANIFEST_HEADER_PATH = "gemm_a8w8_manifest.h" + +PARAM_INSTANCE_SUFFIX_LIST= [ + { + TEMPLATE_PARAMS: "I8, I8, I32, I8, I32, B16, B16", + INSTANCE_FILE_SUFFIX: "I8_I8_I32_I8_I32_BF16_BF16.cpp" + }, + { + TEMPLATE_PARAMS: "I8, I8, I32, I8, I32, F32, B16", + INSTANCE_FILE_SUFFIX: "I8_I8_I32_I8_I32_F32_BF16.cpp" + }, + { + TEMPLATE_PARAMS: "I8, I8, I32, I8, I32, F16, F16", + INSTANCE_FILE_SUFFIX: "I8_I8_I32_I8_I32_F16_F16.cpp" + }, + { + TEMPLATE_PARAMS: "I8, I8, I32, I8, I32, F32, F16", + INSTANCE_FILE_SUFFIX: "I8_I8_I32_I8_I32_F32_F16.cpp" + }, + { + TEMPLATE_PARAMS: "FP8, FP8, F32, FP8, F32, B16, B16", + INSTANCE_FILE_SUFFIX: "FP8_FP8_F32_FP8_F32_B16_B16.cpp" + }, + { + TEMPLATE_PARAMS: "FP8, FP8, F32, FP8, F32, F32, B16", + INSTANCE_FILE_SUFFIX: "FP8_FP8_F32_FP8_F32_F32_B16.cpp" + }, + { + TEMPLATE_PARAMS: "FP8, FP8, F32, FP8, F32, F16, F16", + INSTANCE_FILE_SUFFIX: "FP8_FP8_F32_FP8_F32_F16_F16.cpp" + }, + { + TEMPLATE_PARAMS: "FP8, FP8, F32, FP8, F32, F32, F16", + INSTANCE_FILE_SUFFIX: "FP8_FP8_F32_FP8_F32_F32_F16.cpp" + } +] + +TUNING_PARAM_INSTANCE_SUFFIX_LIST = [ + { + TEMPLATE_PARAMS: "I8, I8, I32, I8, I32, B16, B16", + INSTANCE_FILE_SUFFIX: "I8_I8_I32_I8_I32_BF16_BF16.cpp" + }, + { + TEMPLATE_PARAMS: "FP8, FP8, F32, FP8, F32, B16, B16", + INSTANCE_FILE_SUFFIX: "FP8_FP8_F32_FP8_F32_B16_B16.cpp" + }, +] + +def gen_a8w8_device_gemm_call_skip_bias_branch(k: KernelParameters, gemm_specialization: str) -> str: + return f"""using DeviceGemmInstance = DeviceGemmHelper< + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ComputeDataType, + DDataType, + EDataType, + {k.BLOCK_SIZE}, + {k.MPerBLOCK}, + {k.NPerBLOCK}, + {k.KPerBLOCK}, + {k.WAVE_TILE_M}, + {k.WAVE_TILE_N}, + {k.WAVE_MAP_M}, + {k.WAVE_MAP_N}, + S<{(", ").join(map(lambda x:str(x),k.ABLOCK_TRANSFER))}>, + S<{(", ").join(map(lambda x:str(x),k.BBLOCK_TRANSFER))}>, + S<{(", ").join(map(lambda x:str(x),k.CBLOCK_TRANSFER))}>, + S<{(", ").join(map(lambda x:str(x),k.CBLOCK_SPV))}>, + {k.CSHUFFLE_MX_PER_WAVE_PERSHUFFLE}, + {k.CSHUFFLE_NX_PER_WAVE_PERSHUFFLE}, + ck::BlockGemmPipelineScheduler::{k.LOOP_SCHED}, + ck::BlockGemmPipelineVersion::v{k.PIPELINE_VERSION}, + ck::tensor_operation::device::GemmSpecialization::{gemm_specialization}>; + + return gemm_a8w8_rowwise_impl< + ADataType, + BDataType, + AccDataType, + DDataType, + EDataType, + DeviceGemmInstance + >(XQ, WQ, x_scale, w_scale, Y, bias, KBatch); +""" +def gen_a8w8_device_gemm_call(k: KernelParameters, gemm_specialization: str): + return f"""if (bias != std::nullopt) + {{ + using DeviceGemmInstance = DeviceGemmHelperMMA< + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ComputeDataType, + DDataType, + EDataType, + {k.BLOCK_SIZE}, + {k.MPerBLOCK}, + {k.NPerBLOCK}, + {k.KPerBLOCK}, + {k.WAVE_TILE_M}, + {k.WAVE_TILE_N}, + {k.WAVE_MAP_M}, + {k.WAVE_MAP_N}, + S<{(", ").join(map(lambda x:str(x),k.ABLOCK_TRANSFER))}>, + S<{(", ").join(map(lambda x:str(x),k.BBLOCK_TRANSFER))}>, + S<{(", ").join(map(lambda x:str(x),k.CBLOCK_TRANSFER))}>, + S<{(", ").join(map(lambda x:str(x),k.CBLOCK_SPV))}, {k.CBLOCK_SPV[0]}>, + 
{k.CSHUFFLE_MX_PER_WAVE_PERSHUFFLE}, + {k.CSHUFFLE_NX_PER_WAVE_PERSHUFFLE}, + ck::BlockGemmPipelineScheduler::{k.LOOP_SCHED}, + ck::BlockGemmPipelineVersion::v{k.PIPELINE_VERSION}, + ck::tensor_operation::device::GemmSpecialization::{gemm_specialization}>; + // Run kernel instance. + + return gemm_a8w8_mma_impl< + ADataType, + BDataType, + DDataType, + EDataType, + DeviceGemmInstance + >(XQ, WQ, x_scale, w_scale, Y, bias, KBatch); + }} + else + {{ + using DeviceGemmInstance = DeviceGemmHelper< + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ComputeDataType, + DDataType, + EDataType, + {k.BLOCK_SIZE}, + {k.MPerBLOCK}, + {k.NPerBLOCK}, + {k.KPerBLOCK}, + {k.WAVE_TILE_M}, + {k.WAVE_TILE_N}, + {k.WAVE_MAP_M}, + {k.WAVE_MAP_N}, + S<{(", ").join(map(lambda x:str(x),k.ABLOCK_TRANSFER))}>, + S<{(", ").join(map(lambda x:str(x),k.BBLOCK_TRANSFER))}>, + S<{(", ").join(map(lambda x:str(x),k.CBLOCK_TRANSFER))}>, + S<{(", ").join(map(lambda x:str(x),k.CBLOCK_SPV))}>, + {k.CSHUFFLE_MX_PER_WAVE_PERSHUFFLE}, + {k.CSHUFFLE_NX_PER_WAVE_PERSHUFFLE}, + ck::BlockGemmPipelineScheduler::{k.LOOP_SCHED}, + ck::BlockGemmPipelineVersion::v{k.PIPELINE_VERSION}, + ck::tensor_operation::device::GemmSpecialization::{gemm_specialization}>; + + return gemm_a8w8_rowwise_impl< + ADataType, + BDataType, + AccDataType, + DDataType, + EDataType, + DeviceGemmInstance + >(XQ, WQ, x_scale, w_scale, Y, bias, KBatch); + }} +""" -class gemm_a8w8_fwd_codegen: - def __init__(self, working_path, istune=False): - self.working_path = working_path - self.impl_path = os.path.join(working_path, "impl") - self.instances_path = os.path.join(working_path, "instances") - self.istune = istune - - def gen_instance(self, k: kernelInstance): - INSTANCE_IMPL = f"""// SPDX-License-Identifier: MIT +def gen_a8w8_implementation(k: KernelParameters, skip_bias_branch: bool) -> str: + if skip_bias_branch: + gemm_a8w8_device_gemm_instance_generator = gen_a8w8_device_gemm_call_skip_bias_branch + else: + gemm_a8w8_device_gemm_instance_generator = gen_a8w8_device_gemm_call + + padding = "MNKPadding" + no_padding = "Default" + return f""" +// SPDX-License-Identifier: MIT // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_a8w8_common.cuh" -template +template < + typename ADataType, + typename BDataType, + typename CShuffleDataType, + typename ComputeDataType, + typename AccDataType, + typename DDataType, + typename EDataType + > torch::Tensor {k.name}( torch::Tensor &XQ, @@ -35,117 +204,32 @@ def gen_instance(self, k: kernelInstance): torch::Tensor &Y, std::optional bias, int KBatch) -{{{{ +{{ // The smallest kernel we have available. Works well for memory bound shapes. // Check if this input needs to be padded. 
int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); int N = WQ.size(0); int K = WQ.size(1); + bool pad = (M % {k.MPerBLOCK} != 0) || (N % {k.NPerBLOCK} != 0) || (K % ({k.KPerBLOCK} * KBatch) != 0); - if (pad) - {{{{ - // pad - {{INSTANCE_CONTENT_pad}} - // pad - }}}} - else - {{{{ - // no pad - {{INSTANCE_CONTENT_nopad}} - // no pad - }}}} -}}}} - -""" - INSTANCE_CONTENT_bias = f"""if (bias != std::nullopt) - {{{{ - using DeviceGemmInstance = DeviceGemmHelperMMA< - DDataType, EDataType, - {k.BLOCK_SIZE}, - {k.MPerBLOCK}, - {k.NPerBLOCK}, - {k.KPerBLOCK}, - {k.WAVE_TILE_M}, - {k.WAVE_TILE_N}, - {k.WAVE_MAP_M}, - {k.WAVE_MAP_N}, - S<{(", ").join(map(lambda x:str(x),k.ABLOCK_TRANSFER))}>, - S<{(", ").join(map(lambda x:str(x),k.BBLOCK_TRANSFER))}>, - S<{(", ").join(map(lambda x:str(x),k.CBLOCK_TRANSFER))}>, - S<{(", ").join(map(lambda x:str(x),k.CBLOCK_SPV))}, {k.CBLOCK_SPV[0]}>, - {k.CSHUFFLE_MX_PER_WAVE_PERSHUFFLE}, - {k.CSHUFFLE_NX_PER_WAVE_PERSHUFFLE}, - ck::BlockGemmPipelineScheduler::{k.LOOP_SCHED}, - ck::BlockGemmPipelineVersion::v{k.PIPELINE_VERSION}, - ck::tensor_operation::device::GemmSpecialization::{{GemmSpec}}>; - // Run kernel instance. - return gemm_a8w8_mma_impl(XQ, WQ, x_scale, w_scale, Y, bias, KBatch); - }}}} - else - {{{{ - using DeviceGemmInstance = DeviceGemmHelper< - DDataType, EDataType, - {k.BLOCK_SIZE}, - {k.MPerBLOCK}, - {k.NPerBLOCK}, - {k.KPerBLOCK}, - {k.WAVE_TILE_M}, - {k.WAVE_TILE_N}, - {k.WAVE_MAP_M}, - {k.WAVE_MAP_N}, - S<{(", ").join(map(lambda x:str(x),k.ABLOCK_TRANSFER))}>, - S<{(", ").join(map(lambda x:str(x),k.BBLOCK_TRANSFER))}>, - S<{(", ").join(map(lambda x:str(x),k.CBLOCK_TRANSFER))}>, - S<{(", ").join(map(lambda x:str(x),k.CBLOCK_SPV))}>, - {k.CSHUFFLE_MX_PER_WAVE_PERSHUFFLE}, - {k.CSHUFFLE_NX_PER_WAVE_PERSHUFFLE}, - ck::BlockGemmPipelineScheduler::{k.LOOP_SCHED}, - ck::BlockGemmPipelineVersion::v{k.PIPELINE_VERSION}, - ck::tensor_operation::device::GemmSpecialization::{{GemmSpec}}>; - // Run kernel instance. - return gemm_a8w8_rowwise_impl(XQ, WQ, x_scale, w_scale, Y, bias, KBatch); - }}}} -""" - INSTANCE_CONTENT_nobias = f"""using DeviceGemmInstance = DeviceGemmHelper< - DDataType, EDataType, - {k.BLOCK_SIZE}, - {k.MPerBLOCK}, - {k.NPerBLOCK}, - {k.KPerBLOCK}, - {k.WAVE_TILE_M}, - {k.WAVE_TILE_N}, - {k.WAVE_MAP_M}, - {k.WAVE_MAP_N}, - S<{(", ").join(map(lambda x:str(x),k.ABLOCK_TRANSFER))}>, - S<{(", ").join(map(lambda x:str(x),k.BBLOCK_TRANSFER))}>, - S<{(", ").join(map(lambda x:str(x),k.CBLOCK_TRANSFER))}>, - S<{(", ").join(map(lambda x:str(x),k.CBLOCK_SPV))}>, - {k.CSHUFFLE_MX_PER_WAVE_PERSHUFFLE}, - {k.CSHUFFLE_NX_PER_WAVE_PERSHUFFLE}, - ck::BlockGemmPipelineScheduler::{k.LOOP_SCHED}, - ck::BlockGemmPipelineVersion::v{k.PIPELINE_VERSION}, - ck::tensor_operation::device::GemmSpecialization::{{GemmSpec}}>; - // Run kernel instance. 
- return gemm_a8w8_rowwise_impl(XQ, WQ, x_scale, w_scale, Y, bias, KBatch); -""" - if self.istune: - INSTANCE_IMPL_str = INSTANCE_IMPL.format(INSTANCE_CONTENT_pad=(INSTANCE_CONTENT_nobias.format(GemmSpec="MNKPadding")), - INSTANCE_CONTENT_nopad=(INSTANCE_CONTENT_nobias.format(GemmSpec="Default"))) - else: - INSTANCE_IMPL_str = INSTANCE_IMPL.format(INSTANCE_CONTENT_pad=INSTANCE_CONTENT_bias.format(GemmSpec="MNKPadding"), - INSTANCE_CONTENT_nopad=INSTANCE_CONTENT_bias.format(GemmSpec="Default")) - - Path(os.path.join(self.impl_path, f"{k.name}.cuh")).write_text( - INSTANCE_IMPL_str) - - INSTANCE_template = """// SPDX-License-Identifier: MIT + if (pad) {{ + {gemm_a8w8_device_gemm_instance_generator(k, padding)} + }} + else{{ + {gemm_a8w8_device_gemm_instance_generator(k, no_padding)} + }} +}} +""" + +def gen_a8w8_instance(k: KernelParameters, template_params: str) -> str: + return f"""// SPDX-License-Identifier: MIT // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. -#include "{name}.cuh" +#include "{k.name}.cuh" template torch::Tensor -{name}<{dtypes}>( +{k.name}<{template_params}>( torch::Tensor &XQ, torch::Tensor &WQ, torch::Tensor &x_scale, @@ -153,73 +237,42 @@ def gen_instance(self, k: kernelInstance): torch::Tensor &Y, std::optional bias, int KBatch); - """ - INSTANCE_dBF16_eBF16 = INSTANCE_template.format( - name=k.name, dtypes="B16") - INSTANCE_dFP32_eBF16 = INSTANCE_template.format( - name=k.name, dtypes="F32, B16") - INSTANCE_dFP16_eFP16 = INSTANCE_template.format( - name=k.name, dtypes="F16") - INSTANCE_dFP32_eFP16 = INSTANCE_template.format( - name=k.name, dtypes="F32, F16") - - if self.istune: - Path(os.path.join(self.instances_path, f"{k.name}_dBF16_eBF16.cpp")).write_text( - INSTANCE_dBF16_eBF16) - else: - Path(os.path.join(self.instances_path, f"{k.name}_dBF16_eBF16.cpp")).write_text( - INSTANCE_dBF16_eBF16) - Path(os.path.join(self.instances_path, f"{k.name}_dFP32_eBF16.cpp")).write_text( - INSTANCE_dFP32_eBF16) - Path(os.path.join(self.instances_path, f"{k.name}_dFP16_eFP16.cpp")).write_text( - INSTANCE_dFP16_eFP16) - Path(os.path.join(self.instances_path, f"{k.name}_dFP32_eFP16.cpp")).write_text( - INSTANCE_dFP32_eFP16) - - def gen_lookup_dict(self, kernels_dict): - LOOKUP_head = """#pragma once -// SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#ifdef USE_ROCM -#define GENERATE_LOOKUP_TABLE(DTYPE, ETYPE) \\ - { \\""" +def gen_kernel_dict_item_as_str(mnk: tuple, k: KernelParameters) -> str: + mnk_formatted = ', '.join(str(item) for item in mnk) + return f"{{ {{{mnk_formatted}}}, {k.name}}}" - LOOKUP_template = """ - {{{MNK}, \\ - {kernel_name}}}, \\""" +def gen_lookup_dict(kernel_dict: dict) -> str: - LOOKUP_end = """ - } - -#endif // USE_ROCM -""" - with open(os.path.join(self.working_path, "gemm_a8w8_lookup.h"), "w") as f: - f.write(LOOKUP_head) - for mnk, k in kernels_dict.items(): - # print((", ").join(map(lambda x: str(x), list(mnk))), ":", k.name) - if not self.istune and (isinstance(mnk, tuple) and mnk[0] > 0): - f.write(LOOKUP_template.format(MNK="{"+(", ").join( - map(lambda x: str(x), list(mnk))) + "}", kernel_name=k.name)) - elif self.istune and isinstance(mnk, int): - f.write(LOOKUP_template.format(MNK=mnk, kernel_name=k.name)) - f.write(LOOKUP_end) - - def gen_manifest_head(self, kernels_dict): - MAINFEST_head = """#pragma once + kernel_dict_items = [ + gen_kernel_dict_item_as_str(mnk, k) + for mnk, k in kernel_dict.items() + if isinstance(mnk, tuple) + ] + return f"""#pragma once // SPDX-License-Identifier: MIT // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #ifdef USE_ROCM - -#include - -#include +#define GENERATE_LOOKUP_TABLE(A_TYPE, B_TYPE, C_SHUFFLE_TYPE, COMPUTE_TYPE, ACC_TYPE, D_TYPE, E_TYPE) \\ +{{ \\ + {",\\\n ".join(kernel_dict_items) + " \\"} +}} +#endif // USE_ROCM """ - MAINFEST_template = """ -template + +def gen_kernel_definition(kernel_name: str) -> str: + return f""" +template < + typename ADataType, + typename BDataType, + typename CShuffleDataType, + typename ComputeDataType, + typename AccDataType, + typename DDataType, + typename EDataType +> torch::Tensor {kernel_name}( torch::Tensor &XQ, @@ -229,43 +282,72 @@ def gen_manifest_head(self, kernels_dict): torch::Tensor &Y, std::optional bias, int KBatch); -""" - MAINFEST_end = """ + """ -#endif // USE_ROCM -""" +def gen_manifest(kernels_dict: dict) -> str: + kernel_definition_list = [ + gen_kernel_definition(k.name) for k in kernels_dict.values() + ] + return f"""#pragma once +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
-        with open(os.path.join(self.working_path, "gemm_a8w8_manifest.h"), "w") as f:
-            f.write(MAINFEST_head)
-            for mnk, k in kernels_dict.items():
-                f.write(MAINFEST_template.format(kernel_name=k.name))
-            f.write(MAINFEST_end)
+#ifdef USE_ROCM

-    def gen_instances(self, kernels_dict):
-        if os.path.exists(self.impl_path):
-            shutil.rmtree(self.impl_path)
-        os.mkdir(self.impl_path)
-        if os.path.exists(self.instances_path):
-            shutil.rmtree(self.instances_path)
-        os.mkdir(self.instances_path)
+#include <torch/extension.h>

-        for mnk, k in kernels_dict.items():
-            self.gen_instance(k)
+#include <optional>
+    {"\n".join(kernel_definition_list)}
+#endif // USE_ROCM
+"""

-        self.gen_lookup_dict(kernels_dict)
-        self.gen_manifest_head(kernels_dict)
+def gemm_a8w8_fwd_codegen(working_path: Path, kernel_parameters_dict: dict, is_tune: bool):
+    impl_directory = Path(os.path.join(working_path, "impl"))
+    instance_directory = Path(os.path.join(working_path, "instances"))
+    impl_directory.mkdir(exist_ok=True)
+    instance_directory.mkdir(exist_ok=True)
+    # generate and write the implementation files
+    for _, kernel_parameters in kernel_parameters_dict.items():
+        impl_path = Path(os.path.join(impl_directory, f"{kernel_parameters.name}.cuh"))
+        kernels_impl_str = gen_a8w8_implementation(kernel_parameters, is_tune)
+        impl_path.write_text(kernels_impl_str)
+
+    # generate and write the instance files for each supported specialization
+    for _, kernel_parameters in kernel_parameters_dict.items():
+        param_instance_list = TUNING_PARAM_INSTANCE_SUFFIX_LIST if is_tune else PARAM_INSTANCE_SUFFIX_LIST
+        for param_instance in param_instance_list:
+            template_params = param_instance[TEMPLATE_PARAMS]
+            instance_file_suffix = param_instance[INSTANCE_FILE_SUFFIX]
+            instance_file_name = f"{kernel_parameters.name}_{instance_file_suffix.lower()}"
+            instance_path = Path(os.path.join(instance_directory, instance_file_name))
+            kernel_instance_str = gen_a8w8_instance(kernel_parameters, template_params)
+            instance_path.write_text(kernel_instance_str)
+
+    # generate and write the lookup table
+    look_up_dict_str = gen_lookup_dict(kernel_parameters_dict)
+    look_up_table_header_path = Path(os.path.join(working_path, LOOK_UP_TABLE_HEADER_PATH))
+    look_up_table_header_path.write_text(look_up_dict_str)
+
+    # generate and write the manifest
+    manifest_str = gen_manifest(kernel_parameters_dict)
+    manifest_header_path = Path(os.path.join(working_path, MANIFEST_HEADER_PATH))
+    manifest_header_path.write_text(manifest_str)
+
+def get_tune_dict(tune_dict_csv: Path) -> dict:
+    if not os.path.exists(tune_dict_csv):
+        return default_kernels_dict
+
+    tune_dict = default_kernels_dict
+    tune_df = pd.read_csv(tune_dict_csv)
+    for i in range(len(tune_df)):
+        M = tune_df.loc[i, "M"]
+        N = tune_df.loc[i, "N"]
+        K = tune_df.loc[i, "K"]
+        kid = tune_df.loc[i, "kernelId"]
+        tune_dict[(M, N, K)] = kernels_params_dict[kid]

-def get_tune_dict(tune_dict_csv):
-    tune_dict = default_kernels_dict
-    if os.path.exists(tune_dict_csv):
-        tune_df = pd.read_csv(tune_dict_csv)
-        for i in range(len(tune_df)):
-            M = tune_df.loc[i, "M"]
-            N = tune_df.loc[i, "N"]
-            K = tune_df.loc[i, "K"]
-            kid = tune_df.loc[i, "kernelId"]
-            tune_dict[(M, N, K)] = kernels_list[kid]
+
     return tune_dict

 if __name__ == "__main__":
@@ -298,29 +380,10 @@ def get_tune_dict(tune_dict_csv):
         help="generate tune instances"
     )

-    # parser.add_argument(
-    #     "--out_type",
-    #     default="all",
-    #     required=False,
-    #     help="Specifie the type of scale\n \
-    #         all: [bf16, fp16] \n \
-    #         bf16, fp16"
-    # )
-
-    # parser.add_argument(
"--scale_type", - # default="all", - # required=False, - # help="Specifie the type of scale\n \ - # all: [fp32, same as out] \n \ - # same: [same as out]" - # ) - - args = parser.parse_args() - codegen = gemm_a8w8_fwd_codegen(args.working_path, args.tune) + if args.tune: - codegen.gen_instances(kernels_list) + gemm_a8w8_fwd_codegen(args.working_path, kernels_params_dict, args.tune) else: - codegen.gen_instances(get_tune_dict(args.tune_file)) + gemm_a8w8_fwd_codegen(args.working_path, get_tune_dict(args.tune_file), args.tune) diff --git a/csrc/ck_gemm_a8w8/include/gemm_a8w8_common.cuh b/csrc/ck_gemm_a8w8/include/gemm_a8w8_common.cuh index 4265514b..5ee83ca1 100644 --- a/csrc/ck_gemm_a8w8/include/gemm_a8w8_common.cuh +++ b/csrc/ck_gemm_a8w8/include/gemm_a8w8_common.cuh @@ -46,12 +46,6 @@ using F32 = float; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; -using ADataType = FP8; -using BDataType = FP8; -using AccDataType = F32; -using CShuffleDataType = F32; -using ComputeDataType = FP8; - using ALayout = Row; using BLayout = Col; using D0Layout = Row; @@ -66,6 +60,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; using AElementOp = PassThrough; using BElementOp = PassThrough; +template struct RowwiseScale { template @@ -203,7 +198,6 @@ struct MultiplyMultiplyAdd } }; -using CDEElementOp = RowwiseScale; using CDEElementOp2 = MultiplyMultiplyAdd; template @@ -212,23 +206,15 @@ using DsDataType = ck::Tuple; template using DsDataType2 = ck::Tuple; -#if 0 -template -using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 -// clang-format off -///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| -///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| -///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S| -///###### RRR -/// < Row, Row, DsLayout, ELayout, ADataType, BDataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, I8>; -///###### RCR - < Row, Col, DsLayout, ELayout, ADataType, BDataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 16, 32, 32, 4, 
-// clang-format on
-#endif

 template <
-    typename DDataType, typename EDataType,
+    typename ADataType,
+    typename BDataType,
+    typename AccDataType,
+    typename CShuffleDataType,
+    typename ComputeDataType,
+    typename DDataType,
+    typename EDataType,
     int BLOCK_SIZE,
     int MBLOCK,
     int NBLOCK,
@@ -261,7 +247,7 @@ using DeviceGemmHelper =
         CShuffleDataType,
         AElementOp,
         BElementOp,
-        CDEElementOp,
+        RowwiseScale<AccDataType>,
         GEMM_SPEC,
         BLOCK_SIZE, // Block Size
         MBLOCK, // M per Block
@@ -296,7 +282,13 @@ using DeviceGemmHelper =

 template <
-    typename DDataType, typename EDataType,
+    typename ADataType,
+    typename BDataType,
+    typename AccDataType,
+    typename CShuffleDataType,
+    typename ComputeDataType,
+    typename DDataType,
+    typename EDataType,
     int BLOCK_SIZE,
     int MBLOCK,
     int NBLOCK,
@@ -363,8 +355,14 @@ using DeviceGemmHelperMMA =
         PIPELINE_VERSION,
         ComputeDataType>;

-template <typename DDataType, typename EDataType, typename DeviceGemmInstance>
+template <
+    typename ADataType,
+    typename BDataType,
+    typename AccDataType,
+    typename DDataType,
+    typename EDataType,
+    typename DeviceGemmInstance
+>
 __forceinline__ torch::Tensor gemm_a8w8_rowwise_impl(
     torch::Tensor &XQ,
     torch::Tensor &WQ,
@@ -388,7 +386,7 @@ __forceinline__ torch::Tensor gemm_a8w8_rowwise_impl(

     auto a_element_op = AElementOp{};
     auto b_element_op = BElementOp{};
-    auto cde_element_op = CDEElementOp{};
+    auto cde_element_op = RowwiseScale<AccDataType>{};

     constexpr ck::index_t NumDTensor = DeviceGemmInstance::NumDTensor;

@@ -416,8 +414,13 @@ __forceinline__ torch::Tensor gemm_a8w8_rowwise_impl(
     return Y;
 }

-template <typename DDataType, typename EDataType, typename DeviceGemmInstance>
+template <
+    typename ADataType,
+    typename BDataType,
+    typename DDataType,
+    typename EDataType,
+    typename DeviceGemmInstance
+>
 __forceinline__ torch::Tensor gemm_a8w8_mma_impl(
     torch::Tensor &XQ,
     torch::Tensor &WQ,

From a201d3a073efb39a213f11ebb6c4b21cece84a78 Mon Sep 17 00:00:00 2001
From: vllmellm
Date: Thu, 13 Feb 2025 05:34:57 +0000
Subject: [PATCH 03/13] tuning bug fix

---
 csrc/ck_gemm_a8w8/gemm_a8w8_tune.py |  4 +---
 csrc/ck_gemm_a8w8/gen_instances.py  | 16 ++++++++++------
 2 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/csrc/ck_gemm_a8w8/gemm_a8w8_tune.py b/csrc/ck_gemm_a8w8/gemm_a8w8_tune.py
index 6d805520..ccbe9b6c 100644
--- a/csrc/ck_gemm_a8w8/gemm_a8w8_tune.py
+++ b/csrc/ck_gemm_a8w8/gemm_a8w8_tune.py
@@ -67,10 +67,9 @@ def tune_gemm(m, n, k, dtype: torch.dtype, useSplitK = False):

     print(f"*******************M:{m} X N:{n} X K:{k}**************************")
     print(f"Start tuning a8w8 gemm kernel for M:{m}, N:{n}, K:{k}:")
-    kernels_num = len(kernels_params_dict)
     best_kernelConfig = (-1, 0)
     best_time = -1
-    for i in range(kernels_num):
+    for i in range(len(kernels_params_dict)):
         kernel = kernels_params_dict[i]
         maxsplitK = aiter.compute_gemm_SplitK(m, n, k, kernel.MPerBLOCK, kernel.NPerBLOCK, kernel.KPerBLOCK) \
             if useSplitK else 0
@@ -86,7 +85,6 @@ def tune_gemm(m, n, k, dtype: torch.dtype, useSplitK = False):
             else:
                 print(f"{str(dim):<20} kernelid:{i:<3d}\t No pass , {kernel.name}, {splitK=}")
         except RuntimeError as e:
-            print(str(e))
             print(f"{str(dim):<20} kernelid:{i:<3d}\t No support , {kernel.name}, {splitK=}")

     best_kernelId, splitK = best_kernelConfig
diff --git a/csrc/ck_gemm_a8w8/gen_instances.py b/csrc/ck_gemm_a8w8/gen_instances.py
index ddbd3207..48c59414 100644
--- a/csrc/ck_gemm_a8w8/gen_instances.py
+++ b/csrc/ck_gemm_a8w8/gen_instances.py
@@ -239,16 +239,20 @@ def gen_a8w8_instance(k: KernelParameters, template_params: str) -> str:
     int KBatch);
 """

-def gen_kernel_dict_item_as_str(mnk: tuple, k: KernelParameters) -> str:
-    mnk_formatted = ', '.join(str(item) for item in mnk)
+def gen_kernel_dict_item_as_str(mnk: tuple | int, k: KernelParameters) -> str:
+    if isinstance(mnk, tuple):
+        mnk_formatted = ', '.join(str(item) for item in mnk)
+    else:
+        mnk_formatted = str(mnk)
     return f"{{ {{{mnk_formatted}}}, {k.name}}}"

-def gen_lookup_dict(kernel_dict: dict) -> str:
-
+def gen_lookup_dict(kernel_dict: dict, is_tune: bool) -> str:
+    # Do not include default kernels in the lookup table for non-tuning calls.
+    filter_mnk = lambda mnk: is_tune or isinstance(mnk, tuple)
     kernel_dict_items = [
         gen_kernel_dict_item_as_str(mnk, k)
         for mnk, k in kernel_dict.items()
-        if isinstance(mnk, tuple)
+        if filter_mnk(mnk)
     ]
     return f"""#pragma once
 // SPDX-License-Identifier: MIT
@@ -325,7 +329,7 @@ def gemm_a8w8_fwd_codegen(working_path: Path, kernel_parameters_dict: dict, is_t
             instance_path.write_text(kernel_instance_str)

     # generate and write the lookup table
-    look_up_dict_str = gen_lookup_dict(kernel_parameters_dict)
+    look_up_dict_str = gen_lookup_dict(kernel_parameters_dict, is_tune)
     look_up_table_header_path = Path(os.path.join(working_path, LOOK_UP_TABLE_HEADER_PATH))
     look_up_table_header_path.write_text(look_up_dict_str)

From ecbefb907599472fb09cbb85792f0713c6732cde Mon Sep 17 00:00:00 2001
From: tjtanaa
Date: Fri, 14 Feb 2025 11:11:20 +0000
Subject: [PATCH 04/13] model MNK Shape config

Signed-off-by: tjtanaa
---
 README.md                                     |   3 +
 .../Llama-3.1-70B-Instruct-TP1.csv            | 189 ++++++++++++++++++
 .../Llama-3.1-70B-Instruct-TP2.csv            | 189 ++++++++++++++++++
 .../Llama-3.1-70B-Instruct-TP4.csv            | 189 ++++++++++++++++++
 .../Llama-3.1-70B-Instruct-TP8.csv            | 189 ++++++++++++++++++
 .../Llama-3.1-8B-Instruct-TP1.csv             | 189 ++++++++++++++++++
 .../Llama-3.1-8B-Instruct-TP2.csv             | 189 ++++++++++++++++++
 .../Llama-3.1-8B-Instruct-TP4.csv             | 189 ++++++++++++++++++
 .../Llama-3.1-8B-Instruct-TP8.csv             | 189 ++++++++++++++++++
 .../Mixtral-8x22B-Instruct-v0.1-TP2.csv       | 189 ++++++++++++++++++
 .../Mixtral-8x22B-Instruct-v0.1-TP4.csv       | 189 ++++++++++++++++++
 .../Mixtral-8x22B-Instruct-v0.1-TP8.csv       | 189 ++++++++++++++++++
 .../Mixtral-8x7B-Instruct-v0.1-TP1.csv        | 189 ++++++++++++++++++
 .../Mixtral-8x7B-Instruct-v0.1-TP2.csv        | 189 ++++++++++++++++++
 .../Mixtral-8x7B-Instruct-v0.1-TP4.csv        | 189 ++++++++++++++++++
 .../Mixtral-8x7B-Instruct-v0.1-TP8.csv        | 189 ++++++++++++++++++
 csrc/ck_gemm_a8w8/README.md                   |  36 +++-
 17 files changed, 2873 insertions(+), 1 deletion(-)
 create mode 100644 aiter/configs/a8w8_gemm_model_config/Llama-3.1-70B-Instruct-TP1.csv
 create mode 100644 aiter/configs/a8w8_gemm_model_config/Llama-3.1-70B-Instruct-TP2.csv
 create mode 100644 aiter/configs/a8w8_gemm_model_config/Llama-3.1-70B-Instruct-TP4.csv
 create mode 100644 aiter/configs/a8w8_gemm_model_config/Llama-3.1-70B-Instruct-TP8.csv
 create mode 100644 aiter/configs/a8w8_gemm_model_config/Llama-3.1-8B-Instruct-TP1.csv
 create mode 100644 aiter/configs/a8w8_gemm_model_config/Llama-3.1-8B-Instruct-TP2.csv
 create mode 100644 aiter/configs/a8w8_gemm_model_config/Llama-3.1-8B-Instruct-TP4.csv
 create mode 100644 aiter/configs/a8w8_gemm_model_config/Llama-3.1-8B-Instruct-TP8.csv
 create mode 100644 aiter/configs/a8w8_gemm_model_config/Mixtral-8x22B-Instruct-v0.1-TP2.csv
 create mode 100644 aiter/configs/a8w8_gemm_model_config/Mixtral-8x22B-Instruct-v0.1-TP4.csv
create mode 100644 aiter/configs/a8w8_gemm_model_config/Mixtral-8x22B-Instruct-v0.1-TP8.csv create mode 100644 aiter/configs/a8w8_gemm_model_config/Mixtral-8x7B-Instruct-v0.1-TP1.csv create mode 100644 aiter/configs/a8w8_gemm_model_config/Mixtral-8x7B-Instruct-v0.1-TP2.csv create mode 100644 aiter/configs/a8w8_gemm_model_config/Mixtral-8x7B-Instruct-v0.1-TP4.csv create mode 100644 aiter/configs/a8w8_gemm_model_config/Mixtral-8x7B-Instruct-v0.1-TP8.csv diff --git a/README.md b/README.md index adb2893e..5f3ae6e2 100644 --- a/README.md +++ b/README.md @@ -28,3 +28,6 @@ there are number of op test, you can run them like this: `python3 op_tests/test_ |GEMM | D=AxB+C | |FusedMoE | bf16 balabala | |WIP | coming soon... | + +## Ops +1. [INT8/FP8 A8W8 Per-Tensor/Rowwise Scaling GEMM](csrc/ck_gemm_a8w8/README.md) diff --git a/aiter/configs/a8w8_gemm_model_config/Llama-3.1-70B-Instruct-TP1.csv b/aiter/configs/a8w8_gemm_model_config/Llama-3.1-70B-Instruct-TP1.csv new file mode 100644 index 00000000..031c2055 --- /dev/null +++ b/aiter/configs/a8w8_gemm_model_config/Llama-3.1-70B-Instruct-TP1.csv @@ -0,0 +1,189 @@ +M,N,K +1,8192,8192 +1,10240,8192 +1,8192,28672 +1,57344,8192 +2,8192,8192 +2,8192,28672 +2,10240,8192 +2,57344,8192 +4,10240,8192 +4,8192,8192 +4,57344,8192 +4,8192,28672 +8,8192,8192 +8,8192,28672 +8,57344,8192 +8,10240,8192 +16,8192,8192 +16,8192,28672 +16,57344,8192 +16,10240,8192 +17,8192,28672 +17,10240,8192 +17,57344,8192 +17,8192,8192 +24,57344,8192 +24,8192,8192 +24,8192,28672 +24,10240,8192 +32,8192,28672 +32,57344,8192 +32,10240,8192 +32,8192,8192 +40,8192,8192 +40,8192,28672 +40,10240,8192 +40,57344,8192 +48,10240,8192 +48,8192,8192 +48,8192,28672 +48,57344,8192 +56,8192,8192 +56,8192,28672 +56,57344,8192 +56,10240,8192 +64,8192,8192 +64,8192,28672 +64,57344,8192 +64,10240,8192 +72,8192,8192 +72,8192,28672 +72,57344,8192 +72,10240,8192 +80,8192,28672 +80,57344,8192 +80,10240,8192 +80,8192,8192 +88,8192,8192 +88,8192,28672 +88,10240,8192 +88,57344,8192 +96,8192,8192 +96,8192,28672 +96,57344,8192 +96,10240,8192 +104,8192,28672 +104,57344,8192 +104,10240,8192 +104,8192,8192 +112,8192,8192 +112,8192,28672 +112,57344,8192 +112,10240,8192 +120,10240,8192 +120,8192,8192 +120,8192,28672 +120,57344,8192 +128,57344,8192 +128,10240,8192 +128,8192,8192 +128,8192,28672 +136,10240,8192 +136,8192,8192 +136,8192,28672 +136,57344,8192 +144,10240,8192 +144,8192,8192 +144,8192,28672 +144,57344,8192 +152,10240,8192 +152,8192,8192 +152,8192,28672 +152,57344,8192 +160,57344,8192 +160,10240,8192 +160,8192,8192 +160,8192,28672 +168,8192,8192 +168,8192,28672 +168,57344,8192 +168,10240,8192 +176,10240,8192 +176,8192,8192 +176,8192,28672 +176,57344,8192 +184,10240,8192 +184,8192,8192 +184,8192,28672 +184,57344,8192 +192,8192,8192 +192,8192,28672 +192,57344,8192 +192,10240,8192 +200,8192,28672 +200,57344,8192 +200,10240,8192 +200,8192,8192 +208,57344,8192 +208,10240,8192 +208,8192,8192 +208,8192,28672 +216,57344,8192 +216,10240,8192 +216,8192,8192 +216,8192,28672 +224,10240,8192 +224,8192,8192 +224,8192,28672 +224,57344,8192 +232,10240,8192 +232,8192,8192 +232,8192,28672 +232,57344,8192 +240,10240,8192 +240,8192,8192 +240,8192,28672 +240,57344,8192 +248,8192,28672 +248,57344,8192 +248,10240,8192 +248,8192,8192 +256,8192,28672 +256,10240,8192 +256,57344,8192 +256,8192,8192 +512,10240,8192 +512,8192,8192 +512,8192,28672 +512,57344,8192 +1024,10240,8192 +1024,8192,8192 +1024,8192,28672 +1024,57344,8192 +1536,8192,8192 +1536,8192,28672 +1536,57344,8192 +1536,10240,8192 +2048,8192,28672 +2048,57344,8192 
+2048,8192,8192 +2048,10240,8192 +3072,8192,28672 +3072,10240,8192 +3072,8192,8192 +3072,57344,8192 +4096,10240,8192 +4096,8192,8192 +4096,8192,28672 +4096,57344,8192 +8192,10240,8192 +8192,8192,8192 +8192,8192,28672 +8192,57344,8192 +16384,8192,28672 +16384,57344,8192 +16384,10240,8192 +16384,8192,8192 +18432,8192,8192 +18432,8192,28672 +18432,57344,8192 +18432,10240,8192 +20480,10240,8192 +20480,8192,8192 +20480,8192,28672 +20480,57344,8192 +32768,57344,8192 +32768,10240,8192 +32768,8192,8192 +32768,8192,28672 diff --git a/aiter/configs/a8w8_gemm_model_config/Llama-3.1-70B-Instruct-TP2.csv b/aiter/configs/a8w8_gemm_model_config/Llama-3.1-70B-Instruct-TP2.csv new file mode 100644 index 00000000..a1abdcda --- /dev/null +++ b/aiter/configs/a8w8_gemm_model_config/Llama-3.1-70B-Instruct-TP2.csv @@ -0,0 +1,189 @@ +M,N,K +1,5120,8192 +1,8192,14336 +1,8192,4096 +1,28672,8192 +2,8192,4096 +2,5120,8192 +2,8192,14336 +2,28672,8192 +4,5120,8192 +4,8192,14336 +4,8192,4096 +4,28672,8192 +8,8192,4096 +8,8192,14336 +8,28672,8192 +8,5120,8192 +16,8192,14336 +16,5120,8192 +16,28672,8192 +16,8192,4096 +17,8192,14336 +17,28672,8192 +17,8192,4096 +17,5120,8192 +24,28672,8192 +24,8192,4096 +24,5120,8192 +24,8192,14336 +32,8192,4096 +32,5120,8192 +32,28672,8192 +32,8192,14336 +40,8192,4096 +40,5120,8192 +40,8192,14336 +40,28672,8192 +48,5120,8192 +48,8192,4096 +48,28672,8192 +48,8192,14336 +56,8192,4096 +56,8192,14336 +56,28672,8192 +56,5120,8192 +64,5120,8192 +64,28672,8192 +64,8192,14336 +64,8192,4096 +72,8192,14336 +72,8192,4096 +72,28672,8192 +72,5120,8192 +80,8192,4096 +80,5120,8192 +80,28672,8192 +80,8192,14336 +88,8192,4096 +88,5120,8192 +88,28672,8192 +88,8192,14336 +96,5120,8192 +96,8192,14336 +96,8192,4096 +96,28672,8192 +104,8192,14336 +104,8192,4096 +104,28672,8192 +104,5120,8192 +112,28672,8192 +112,5120,8192 +112,8192,14336 +112,8192,4096 +120,5120,8192 +120,8192,4096 +120,28672,8192 +120,8192,14336 +128,5120,8192 +128,28672,8192 +128,8192,4096 +128,8192,14336 +136,5120,8192 +136,8192,14336 +136,8192,4096 +136,28672,8192 +144,5120,8192 +144,8192,14336 +144,8192,4096 +144,28672,8192 +152,8192,4096 +152,28672,8192 +152,8192,14336 +152,5120,8192 +160,8192,4096 +160,5120,8192 +160,28672,8192 +160,8192,14336 +168,8192,14336 +168,5120,8192 +168,8192,4096 +168,28672,8192 +176,5120,8192 +176,28672,8192 +176,8192,4096 +176,8192,14336 +184,28672,8192 +184,5120,8192 +184,8192,14336 +184,8192,4096 +192,8192,4096 +192,5120,8192 +192,28672,8192 +192,8192,14336 +200,8192,14336 +200,8192,4096 +200,5120,8192 +200,28672,8192 +208,5120,8192 +208,8192,14336 +208,8192,4096 +208,28672,8192 +216,8192,14336 +216,5120,8192 +216,8192,4096 +216,28672,8192 +224,5120,8192 +224,28672,8192 +224,8192,4096 +224,8192,14336 +232,8192,4096 +232,5120,8192 +232,28672,8192 +232,8192,14336 +240,8192,4096 +240,28672,8192 +240,8192,14336 +240,5120,8192 +248,8192,4096 +248,28672,8192 +248,5120,8192 +248,8192,14336 +256,8192,14336 +256,28672,8192 +256,8192,4096 +256,5120,8192 +512,8192,4096 +512,5120,8192 +512,28672,8192 +512,8192,14336 +1024,5120,8192 +1024,8192,4096 +1024,28672,8192 +1024,8192,14336 +1536,8192,4096 +1536,28672,8192 +1536,8192,14336 +1536,5120,8192 +2048,28672,8192 +2048,5120,8192 +2048,8192,14336 +2048,8192,4096 +3072,5120,8192 +3072,8192,14336 +3072,8192,4096 +3072,28672,8192 +4096,5120,8192 +4096,8192,14336 +4096,8192,4096 +4096,28672,8192 +8192,5120,8192 +8192,28672,8192 +8192,8192,4096 +8192,8192,14336 +16384,8192,4096 +16384,5120,8192 +16384,28672,8192 +16384,8192,14336 +18432,8192,4096 +18432,8192,14336 
+18432,5120,8192 +18432,28672,8192 +20480,8192,4096 +20480,5120,8192 +20480,28672,8192 +20480,8192,14336 +32768,5120,8192 +32768,28672,8192 +32768,8192,4096 +32768,8192,14336 diff --git a/aiter/configs/a8w8_gemm_model_config/Llama-3.1-70B-Instruct-TP4.csv b/aiter/configs/a8w8_gemm_model_config/Llama-3.1-70B-Instruct-TP4.csv new file mode 100644 index 00000000..232b0412 --- /dev/null +++ b/aiter/configs/a8w8_gemm_model_config/Llama-3.1-70B-Instruct-TP4.csv @@ -0,0 +1,189 @@ +M,N,K +1,2560,8192 +1,14336,8192 +1,8192,2048 +1,8192,7168 +2,2560,8192 +2,8192,7168 +2,8192,2048 +2,14336,8192 +4,14336,8192 +4,8192,2048 +4,8192,7168 +4,2560,8192 +8,8192,7168 +8,2560,8192 +8,8192,2048 +8,14336,8192 +16,2560,8192 +16,8192,2048 +16,14336,8192 +16,8192,7168 +17,2560,8192 +17,14336,8192 +17,8192,7168 +17,8192,2048 +24,14336,8192 +24,2560,8192 +24,8192,7168 +24,8192,2048 +32,8192,7168 +32,2560,8192 +32,14336,8192 +32,8192,2048 +40,8192,7168 +40,8192,2048 +40,2560,8192 +40,14336,8192 +48,14336,8192 +48,8192,7168 +48,8192,2048 +48,2560,8192 +56,8192,7168 +56,2560,8192 +56,8192,2048 +56,14336,8192 +64,8192,7168 +64,8192,2048 +64,14336,8192 +64,2560,8192 +72,8192,2048 +72,8192,7168 +72,2560,8192 +72,14336,8192 +80,8192,7168 +80,2560,8192 +80,14336,8192 +80,8192,2048 +88,8192,7168 +88,8192,2048 +88,14336,8192 +88,2560,8192 +96,8192,2048 +96,8192,7168 +96,2560,8192 +96,14336,8192 +104,8192,7168 +104,2560,8192 +104,14336,8192 +104,8192,2048 +112,8192,2048 +112,2560,8192 +112,14336,8192 +112,8192,7168 +120,14336,8192 +120,8192,7168 +120,8192,2048 +120,2560,8192 +128,8192,7168 +128,8192,2048 +128,14336,8192 +128,2560,8192 +136,2560,8192 +136,14336,8192 +136,8192,2048 +136,8192,7168 +144,8192,2048 +144,8192,7168 +144,2560,8192 +144,14336,8192 +152,8192,7168 +152,8192,2048 +152,2560,8192 +152,14336,8192 +160,2560,8192 +160,14336,8192 +160,8192,7168 +160,8192,2048 +168,2560,8192 +168,14336,8192 +168,8192,7168 +168,8192,2048 +176,8192,7168 +176,2560,8192 +176,8192,2048 +176,14336,8192 +184,8192,7168 +184,8192,2048 +184,2560,8192 +184,14336,8192 +192,8192,7168 +192,8192,2048 +192,14336,8192 +192,2560,8192 +200,8192,7168 +200,2560,8192 +200,14336,8192 +200,8192,2048 +208,2560,8192 +208,14336,8192 +208,8192,2048 +208,8192,7168 +216,14336,8192 +216,8192,7168 +216,8192,2048 +216,2560,8192 +224,8192,7168 +224,2560,8192 +224,8192,2048 +224,14336,8192 +232,14336,8192 +232,8192,7168 +232,8192,2048 +232,2560,8192 +240,8192,7168 +240,2560,8192 +240,8192,2048 +240,14336,8192 +248,8192,7168 +248,2560,8192 +248,14336,8192 +248,8192,2048 +256,2560,8192 +256,14336,8192 +256,8192,7168 +256,8192,2048 +512,8192,7168 +512,8192,2048 +512,2560,8192 +512,14336,8192 +1024,14336,8192 +1024,8192,7168 +1024,8192,2048 +1024,2560,8192 +1536,8192,7168 +1536,2560,8192 +1536,8192,2048 +1536,14336,8192 +2048,2560,8192 +2048,14336,8192 +2048,8192,2048 +2048,8192,7168 +3072,2560,8192 +3072,14336,8192 +3072,8192,2048 +3072,8192,7168 +4096,8192,2048 +4096,8192,7168 +4096,2560,8192 +4096,14336,8192 +8192,8192,7168 +8192,8192,2048 +8192,14336,8192 +8192,2560,8192 +16384,8192,7168 +16384,2560,8192 +16384,8192,2048 +16384,14336,8192 +18432,8192,7168 +18432,2560,8192 +18432,8192,2048 +18432,14336,8192 +20480,14336,8192 +20480,8192,7168 +20480,8192,2048 +20480,2560,8192 +32768,14336,8192 +32768,2560,8192 +32768,8192,7168 +32768,8192,2048 diff --git a/aiter/configs/a8w8_gemm_model_config/Llama-3.1-70B-Instruct-TP8.csv b/aiter/configs/a8w8_gemm_model_config/Llama-3.1-70B-Instruct-TP8.csv new file mode 100644 index 00000000..5a4d4fc0 --- /dev/null +++ 
b/aiter/configs/a8w8_gemm_model_config/Llama-3.1-70B-Instruct-TP8.csv @@ -0,0 +1,189 @@ +M,N,K +1,1280,8192 +1,7168,8192 +1,8192,3584 +1,8192,1024 +2,7168,8192 +2,1280,8192 +2,8192,3584 +2,8192,1024 +4,7168,8192 +4,1280,8192 +4,8192,3584 +4,8192,1024 +8,1280,8192 +8,8192,3584 +8,8192,1024 +8,7168,8192 +16,7168,8192 +16,1280,8192 +16,8192,3584 +16,8192,1024 +17,8192,3584 +17,7168,8192 +17,1280,8192 +17,8192,1024 +24,7168,8192 +24,1280,8192 +24,8192,3584 +24,8192,1024 +32,8192,3584 +32,8192,1024 +32,1280,8192 +32,7168,8192 +40,8192,1024 +40,1280,8192 +40,7168,8192 +40,8192,3584 +48,1280,8192 +48,8192,3584 +48,8192,1024 +48,7168,8192 +56,8192,3584 +56,8192,1024 +56,7168,8192 +56,1280,8192 +64,1280,8192 +64,7168,8192 +64,8192,3584 +64,8192,1024 +72,8192,3584 +72,8192,1024 +72,7168,8192 +72,1280,8192 +80,7168,8192 +80,1280,8192 +80,8192,3584 +80,8192,1024 +88,7168,8192 +88,8192,3584 +88,8192,1024 +88,1280,8192 +96,8192,3584 +96,8192,1024 +96,1280,8192 +96,7168,8192 +104,8192,3584 +104,8192,1024 +104,7168,8192 +104,1280,8192 +112,7168,8192 +112,1280,8192 +112,8192,3584 +112,8192,1024 +120,7168,8192 +120,1280,8192 +120,8192,3584 +120,8192,1024 +128,1280,8192 +128,7168,8192 +128,8192,1024 +128,8192,3584 +136,7168,8192 +136,1280,8192 +136,8192,3584 +136,8192,1024 +144,8192,3584 +144,8192,1024 +144,7168,8192 +144,1280,8192 +152,8192,3584 +152,8192,1024 +152,7168,8192 +152,1280,8192 +160,1280,8192 +160,8192,3584 +160,7168,8192 +160,8192,1024 +168,8192,3584 +168,1280,8192 +168,7168,8192 +168,8192,1024 +176,7168,8192 +176,1280,8192 +176,8192,3584 +176,8192,1024 +184,7168,8192 +184,1280,8192 +184,8192,3584 +184,8192,1024 +192,8192,3584 +192,8192,1024 +192,1280,8192 +192,7168,8192 +200,8192,3584 +200,8192,1024 +200,7168,8192 +200,1280,8192 +208,7168,8192 +208,1280,8192 +208,8192,3584 +208,8192,1024 +216,7168,8192 +216,1280,8192 +216,8192,3584 +216,8192,1024 +224,1280,8192 +224,8192,3584 +224,8192,1024 +224,7168,8192 +232,8192,1024 +232,8192,3584 +232,1280,8192 +232,7168,8192 +240,7168,8192 +240,1280,8192 +240,8192,3584 +240,8192,1024 +248,7168,8192 +248,1280,8192 +248,8192,3584 +248,8192,1024 +256,8192,3584 +256,7168,8192 +256,1280,8192 +256,8192,1024 +512,8192,3584 +512,8192,1024 +512,7168,8192 +512,1280,8192 +1024,1280,8192 +1024,8192,3584 +1024,8192,1024 +1024,7168,8192 +1536,8192,3584 +1536,8192,1024 +1536,7168,8192 +1536,1280,8192 +2048,1280,8192 +2048,7168,8192 +2048,8192,3584 +2048,8192,1024 +3072,7168,8192 +3072,1280,8192 +3072,8192,3584 +3072,8192,1024 +4096,7168,8192 +4096,1280,8192 +4096,8192,3584 +4096,8192,1024 +8192,7168,8192 +8192,8192,1024 +8192,8192,3584 +8192,1280,8192 +16384,8192,3584 +16384,8192,1024 +16384,7168,8192 +16384,1280,8192 +18432,8192,3584 +18432,8192,1024 +18432,1280,8192 +18432,7168,8192 +20480,7168,8192 +20480,8192,1024 +20480,1280,8192 +20480,8192,3584 +32768,7168,8192 +32768,1280,8192 +32768,8192,3584 +32768,8192,1024 diff --git a/aiter/configs/a8w8_gemm_model_config/Llama-3.1-8B-Instruct-TP1.csv b/aiter/configs/a8w8_gemm_model_config/Llama-3.1-8B-Instruct-TP1.csv new file mode 100644 index 00000000..6039ac82 --- /dev/null +++ b/aiter/configs/a8w8_gemm_model_config/Llama-3.1-8B-Instruct-TP1.csv @@ -0,0 +1,189 @@ +M,N,K +1,6144,4096 +1,4096,4096 +1,4096,14336 +1,28672,4096 +2,4096,14336 +2,4096,4096 +2,6144,4096 +2,28672,4096 +4,4096,14336 +4,28672,4096 +4,6144,4096 +4,4096,4096 +8,6144,4096 +8,28672,4096 +8,4096,4096 +8,4096,14336 +16,6144,4096 +16,4096,4096 +16,4096,14336 +16,28672,4096 +17,28672,4096 +17,4096,14336 +17,4096,4096 +17,6144,4096 +24,4096,4096 
+24,4096,14336 +24,6144,4096 +24,28672,4096 +32,4096,14336 +32,6144,4096 +32,4096,4096 +32,28672,4096 +40,28672,4096 +40,4096,14336 +40,6144,4096 +40,4096,4096 +48,28672,4096 +48,6144,4096 +48,4096,4096 +48,4096,14336 +56,6144,4096 +56,28672,4096 +56,4096,4096 +56,4096,14336 +64,28672,4096 +64,6144,4096 +64,4096,14336 +64,4096,4096 +72,28672,4096 +72,6144,4096 +72,4096,4096 +72,4096,14336 +80,6144,4096 +80,4096,14336 +80,28672,4096 +80,4096,4096 +88,28672,4096 +88,6144,4096 +88,4096,4096 +88,4096,14336 +96,28672,4096 +96,6144,4096 +96,4096,4096 +96,4096,14336 +104,28672,4096 +104,6144,4096 +104,4096,4096 +104,4096,14336 +112,4096,4096 +112,6144,4096 +112,4096,14336 +112,28672,4096 +120,4096,14336 +120,28672,4096 +120,4096,4096 +120,6144,4096 +128,28672,4096 +128,6144,4096 +128,4096,14336 +128,4096,4096 +136,4096,14336 +136,28672,4096 +136,6144,4096 +136,4096,4096 +144,28672,4096 +144,6144,4096 +144,4096,4096 +144,4096,14336 +152,28672,4096 +152,6144,4096 +152,4096,4096 +152,4096,14336 +160,4096,4096 +160,6144,4096 +160,4096,14336 +160,28672,4096 +168,6144,4096 +168,28672,4096 +168,4096,4096 +168,4096,14336 +176,28672,4096 +176,4096,4096 +176,6144,4096 +176,4096,14336 +184,4096,4096 +184,6144,4096 +184,4096,14336 +184,28672,4096 +192,6144,4096 +192,28672,4096 +192,4096,14336 +192,4096,4096 +200,28672,4096 +200,6144,4096 +200,4096,4096 +200,4096,14336 +208,4096,4096 +208,4096,14336 +208,28672,4096 +208,6144,4096 +216,4096,14336 +216,4096,4096 +216,6144,4096 +216,28672,4096 +224,28672,4096 +224,6144,4096 +224,4096,4096 +224,4096,14336 +232,28672,4096 +232,4096,14336 +232,6144,4096 +232,4096,4096 +240,28672,4096 +240,6144,4096 +240,4096,4096 +240,4096,14336 +248,6144,4096 +248,4096,4096 +248,4096,14336 +248,28672,4096 +256,28672,4096 +256,4096,14336 +256,4096,4096 +256,6144,4096 +512,6144,4096 +512,28672,4096 +512,4096,14336 +512,4096,4096 +1024,4096,4096 +1024,6144,4096 +1024,28672,4096 +1024,4096,14336 +1536,6144,4096 +1536,28672,4096 +1536,4096,14336 +1536,4096,4096 +2048,6144,4096 +2048,28672,4096 +2048,4096,4096 +2048,4096,14336 +3072,4096,4096 +3072,4096,14336 +3072,6144,4096 +3072,28672,4096 +4096,4096,14336 +4096,28672,4096 +4096,6144,4096 +4096,4096,4096 +8192,28672,4096 +8192,4096,14336 +8192,6144,4096 +8192,4096,4096 +16384,4096,4096 +16384,4096,14336 +16384,28672,4096 +16384,6144,4096 +18432,4096,14336 +18432,6144,4096 +18432,4096,4096 +18432,28672,4096 +20480,4096,4096 +20480,4096,14336 +20480,28672,4096 +20480,6144,4096 +32768,28672,4096 +32768,6144,4096 +32768,4096,4096 +32768,4096,14336 diff --git a/aiter/configs/a8w8_gemm_model_config/Llama-3.1-8B-Instruct-TP2.csv b/aiter/configs/a8w8_gemm_model_config/Llama-3.1-8B-Instruct-TP2.csv new file mode 100644 index 00000000..43859492 --- /dev/null +++ b/aiter/configs/a8w8_gemm_model_config/Llama-3.1-8B-Instruct-TP2.csv @@ -0,0 +1,189 @@ +M,N,K +1,14336,4096 +1,4096,7168 +1,4096,2048 +1,3072,4096 +2,4096,7168 +2,4096,2048 +2,14336,4096 +2,3072,4096 +4,14336,4096 +4,3072,4096 +4,4096,7168 +4,4096,2048 +8,3072,4096 +8,14336,4096 +8,4096,7168 +8,4096,2048 +16,3072,4096 +16,14336,4096 +16,4096,7168 +16,4096,2048 +17,4096,2048 +17,14336,4096 +17,4096,7168 +17,3072,4096 +24,4096,7168 +24,4096,2048 +24,14336,4096 +24,3072,4096 +32,4096,2048 +32,14336,4096 +32,4096,7168 +32,3072,4096 +40,3072,4096 +40,4096,2048 +40,14336,4096 +40,4096,7168 +48,3072,4096 +48,4096,7168 +48,4096,2048 +48,14336,4096 +56,3072,4096 +56,4096,7168 +56,4096,2048 +56,14336,4096 +64,3072,4096 +64,4096,2048 +64,4096,7168 +64,14336,4096 +72,3072,4096 +72,4096,7168 
+72,4096,2048 +72,14336,4096 +80,3072,4096 +80,4096,2048 +80,14336,4096 +80,4096,7168 +88,3072,4096 +88,4096,7168 +88,4096,2048 +88,14336,4096 +96,3072,4096 +96,4096,7168 +96,4096,2048 +96,14336,4096 +104,3072,4096 +104,14336,4096 +104,4096,7168 +104,4096,2048 +112,3072,4096 +112,4096,7168 +112,4096,2048 +112,14336,4096 +120,3072,4096 +120,4096,7168 +120,4096,2048 +120,14336,4096 +128,14336,4096 +128,3072,4096 +128,4096,2048 +128,4096,7168 +136,4096,7168 +136,4096,2048 +136,14336,4096 +136,3072,4096 +144,3072,4096 +144,4096,7168 +144,4096,2048 +144,14336,4096 +152,3072,4096 +152,4096,7168 +152,4096,2048 +152,14336,4096 +160,4096,7168 +160,4096,2048 +160,14336,4096 +160,3072,4096 +168,3072,4096 +168,4096,7168 +168,4096,2048 +168,14336,4096 +176,4096,7168 +176,3072,4096 +176,4096,2048 +176,14336,4096 +184,3072,4096 +184,4096,7168 +184,4096,2048 +184,14336,4096 +192,4096,2048 +192,4096,7168 +192,14336,4096 +192,3072,4096 +200,4096,7168 +200,4096,2048 +200,14336,4096 +200,3072,4096 +208,4096,7168 +208,4096,2048 +208,14336,4096 +208,3072,4096 +216,4096,2048 +216,4096,7168 +216,14336,4096 +216,3072,4096 +224,3072,4096 +224,4096,7168 +224,4096,2048 +224,14336,4096 +232,3072,4096 +232,4096,7168 +232,4096,2048 +232,14336,4096 +240,3072,4096 +240,14336,4096 +240,4096,7168 +240,4096,2048 +248,3072,4096 +248,4096,7168 +248,4096,2048 +248,14336,4096 +256,4096,2048 +256,14336,4096 +256,4096,7168 +256,3072,4096 +512,14336,4096 +512,3072,4096 +512,4096,2048 +512,4096,7168 +1024,4096,7168 +1024,14336,4096 +1024,3072,4096 +1024,4096,2048 +1536,3072,4096 +1536,4096,2048 +1536,4096,7168 +1536,14336,4096 +2048,14336,4096 +2048,4096,7168 +2048,4096,2048 +2048,3072,4096 +3072,4096,7168 +3072,4096,2048 +3072,14336,4096 +3072,3072,4096 +4096,4096,2048 +4096,3072,4096 +4096,14336,4096 +4096,4096,7168 +8192,4096,2048 +8192,3072,4096 +8192,4096,7168 +8192,14336,4096 +16384,4096,7168 +16384,4096,2048 +16384,14336,4096 +16384,3072,4096 +18432,3072,4096 +18432,4096,2048 +18432,4096,7168 +18432,14336,4096 +20480,4096,7168 +20480,4096,2048 +20480,14336,4096 +20480,3072,4096 +32768,3072,4096 +32768,4096,7168 +32768,4096,2048 +32768,14336,4096 diff --git a/aiter/configs/a8w8_gemm_model_config/Llama-3.1-8B-Instruct-TP4.csv b/aiter/configs/a8w8_gemm_model_config/Llama-3.1-8B-Instruct-TP4.csv new file mode 100644 index 00000000..a44a194e --- /dev/null +++ b/aiter/configs/a8w8_gemm_model_config/Llama-3.1-8B-Instruct-TP4.csv @@ -0,0 +1,189 @@ +M,N,K +1,4096,3584 +1,4096,1024 +1,1536,4096 +1,7168,4096 +2,4096,3584 +2,7168,4096 +2,1536,4096 +2,4096,1024 +4,1536,4096 +4,4096,3584 +4,4096,1024 +4,7168,4096 +8,1536,4096 +8,4096,3584 +8,4096,1024 +8,7168,4096 +16,4096,3584 +16,4096,1024 +16,7168,4096 +16,1536,4096 +17,1536,4096 +17,4096,3584 +17,4096,1024 +17,7168,4096 +24,4096,1024 +24,4096,3584 +24,7168,4096 +24,1536,4096 +32,4096,3584 +32,4096,1024 +32,1536,4096 +32,7168,4096 +40,1536,4096 +40,7168,4096 +40,4096,3584 +40,4096,1024 +48,4096,3584 +48,7168,4096 +48,1536,4096 +48,4096,1024 +56,1536,4096 +56,4096,3584 +56,4096,1024 +56,7168,4096 +64,1536,4096 +64,4096,3584 +64,4096,1024 +64,7168,4096 +72,1536,4096 +72,4096,3584 +72,4096,1024 +72,7168,4096 +80,7168,4096 +80,1536,4096 +80,4096,3584 +80,4096,1024 +88,1536,4096 +88,7168,4096 +88,4096,1024 +88,4096,3584 +96,1536,4096 +96,4096,3584 +96,4096,1024 +96,7168,4096 +104,4096,3584 +104,4096,1024 +104,7168,4096 +104,1536,4096 +112,4096,1024 +112,7168,4096 +112,4096,3584 +112,1536,4096 +120,4096,3584 +120,7168,4096 +120,1536,4096 +120,4096,1024 +128,7168,4096 +128,1536,4096 
+128,4096,3584 +128,4096,1024 +136,7168,4096 +136,1536,4096 +136,4096,3584 +136,4096,1024 +144,1536,4096 +144,4096,3584 +144,4096,1024 +144,7168,4096 +152,7168,4096 +152,4096,1024 +152,4096,3584 +152,1536,4096 +160,4096,3584 +160,4096,1024 +160,1536,4096 +160,7168,4096 +168,4096,3584 +168,4096,1024 +168,7168,4096 +168,1536,4096 +176,1536,4096 +176,7168,4096 +176,4096,3584 +176,4096,1024 +184,4096,1024 +184,4096,3584 +184,1536,4096 +184,7168,4096 +192,1536,4096 +192,4096,3584 +192,4096,1024 +192,7168,4096 +200,4096,3584 +200,4096,1024 +200,7168,4096 +200,1536,4096 +208,4096,3584 +208,4096,1024 +208,7168,4096 +208,1536,4096 +216,4096,3584 +216,4096,1024 +216,7168,4096 +216,1536,4096 +224,1536,4096 +224,7168,4096 +224,4096,1024 +224,4096,3584 +232,1536,4096 +232,4096,3584 +232,4096,1024 +232,7168,4096 +240,1536,4096 +240,4096,3584 +240,4096,1024 +240,7168,4096 +248,4096,3584 +248,4096,1024 +248,1536,4096 +248,7168,4096 +256,1536,4096 +256,4096,3584 +256,4096,1024 +256,7168,4096 +512,1536,4096 +512,4096,3584 +512,4096,1024 +512,7168,4096 +1024,4096,3584 +1024,4096,1024 +1024,7168,4096 +1024,1536,4096 +1536,1536,4096 +1536,4096,3584 +1536,4096,1024 +1536,7168,4096 +2048,4096,3584 +2048,4096,1024 +2048,7168,4096 +2048,1536,4096 +3072,4096,3584 +3072,4096,1024 +3072,7168,4096 +3072,1536,4096 +4096,1536,4096 +4096,4096,3584 +4096,4096,1024 +4096,7168,4096 +8192,1536,4096 +8192,4096,3584 +8192,4096,1024 +8192,7168,4096 +16384,4096,1024 +16384,7168,4096 +16384,1536,4096 +16384,4096,3584 +18432,1536,4096 +18432,4096,3584 +18432,4096,1024 +18432,7168,4096 +20480,4096,3584 +20480,4096,1024 +20480,7168,4096 +20480,1536,4096 +32768,1536,4096 +32768,7168,4096 +32768,4096,3584 +32768,4096,1024 diff --git a/aiter/configs/a8w8_gemm_model_config/Llama-3.1-8B-Instruct-TP8.csv b/aiter/configs/a8w8_gemm_model_config/Llama-3.1-8B-Instruct-TP8.csv new file mode 100644 index 00000000..a0ea181a --- /dev/null +++ b/aiter/configs/a8w8_gemm_model_config/Llama-3.1-8B-Instruct-TP8.csv @@ -0,0 +1,189 @@ +M,N,K +1,4096,512 +1,4096,1792 +1,3584,4096 +1,768,4096 +2,4096,1792 +2,3584,4096 +2,4096,512 +2,768,4096 +4,3584,4096 +4,768,4096 +4,4096,512 +4,4096,1792 +8,768,4096 +8,4096,512 +8,4096,1792 +8,3584,4096 +16,4096,512 +16,4096,1792 +16,3584,4096 +16,768,4096 +17,4096,1792 +17,3584,4096 +17,4096,512 +17,768,4096 +24,4096,1792 +24,3584,4096 +24,4096,512 +24,768,4096 +32,4096,1792 +32,4096,512 +32,768,4096 +32,3584,4096 +40,768,4096 +40,3584,4096 +40,4096,1792 +40,4096,512 +48,4096,512 +48,3584,4096 +48,768,4096 +48,4096,1792 +56,768,4096 +56,4096,512 +56,4096,1792 +56,3584,4096 +64,768,4096 +64,4096,1792 +64,4096,512 +64,3584,4096 +72,768,4096 +72,4096,512 +72,4096,1792 +72,3584,4096 +80,4096,1792 +80,768,4096 +80,3584,4096 +80,4096,512 +88,3584,4096 +88,768,4096 +88,4096,1792 +88,4096,512 +96,768,4096 +96,4096,512 +96,4096,1792 +96,3584,4096 +104,4096,512 +104,4096,1792 +104,768,4096 +104,3584,4096 +112,4096,1792 +112,4096,512 +112,768,4096 +112,3584,4096 +120,4096,512 +120,3584,4096 +120,768,4096 +120,4096,1792 +128,768,4096 +128,3584,4096 +128,4096,1792 +128,4096,512 +136,4096,1792 +136,3584,4096 +136,768,4096 +136,4096,512 +144,768,4096 +144,4096,512 +144,4096,1792 +144,3584,4096 +152,4096,1792 +152,4096,512 +152,3584,4096 +152,768,4096 +160,4096,1792 +160,4096,512 +160,3584,4096 +160,768,4096 +168,4096,512 +168,4096,1792 +168,3584,4096 +168,768,4096 +176,3584,4096 +176,768,4096 +176,4096,512 +176,4096,1792 +184,4096,1792 +184,4096,512 +184,3584,4096 +184,768,4096 +192,768,4096 +192,3584,4096 +192,4096,512 
+192,4096,1792 +200,4096,512 +200,4096,1792 +200,3584,4096 +200,768,4096 +208,4096,1792 +208,3584,4096 +208,768,4096 +208,4096,512 +216,4096,1792 +216,3584,4096 +216,4096,512 +216,768,4096 +224,768,4096 +224,4096,1792 +224,4096,512 +224,3584,4096 +232,3584,4096 +232,768,4096 +232,4096,512 +232,4096,1792 +240,3584,4096 +240,768,4096 +240,4096,512 +240,4096,1792 +248,4096,512 +248,4096,1792 +248,3584,4096 +248,768,4096 +256,4096,1792 +256,3584,4096 +256,4096,512 +256,768,4096 +512,768,4096 +512,4096,1792 +512,3584,4096 +512,4096,512 +1024,4096,512 +1024,768,4096 +1024,4096,1792 +1024,3584,4096 +1536,4096,1792 +1536,3584,4096 +1536,4096,512 +1536,768,4096 +2048,4096,512 +2048,4096,1792 +2048,3584,4096 +2048,768,4096 +3072,4096,512 +3072,4096,1792 +3072,3584,4096 +3072,768,4096 +4096,768,4096 +4096,3584,4096 +4096,4096,1792 +4096,4096,512 +8192,3584,4096 +8192,4096,1792 +8192,4096,512 +8192,768,4096 +16384,4096,1792 +16384,3584,4096 +16384,768,4096 +16384,4096,512 +18432,4096,1792 +18432,4096,512 +18432,3584,4096 +18432,768,4096 +20480,4096,1792 +20480,3584,4096 +20480,768,4096 +20480,4096,512 +32768,3584,4096 +32768,768,4096 +32768,4096,512 +32768,4096,1792 diff --git a/aiter/configs/a8w8_gemm_model_config/Mixtral-8x22B-Instruct-v0.1-TP2.csv b/aiter/configs/a8w8_gemm_model_config/Mixtral-8x22B-Instruct-v0.1-TP2.csv new file mode 100644 index 00000000..bf12b198 --- /dev/null +++ b/aiter/configs/a8w8_gemm_model_config/Mixtral-8x22B-Instruct-v0.1-TP2.csv @@ -0,0 +1,189 @@ +M,N,K +1,6144,3072 +1,16384,6144 +1,4096,6144 +1,6144,16384 +2,4096,6144 +2,6144,3072 +2,6144,16384 +2,16384,6144 +4,4096,6144 +4,16384,6144 +4,6144,16384 +4,6144,3072 +8,6144,3072 +8,16384,6144 +8,6144,16384 +8,4096,6144 +16,16384,6144 +16,4096,6144 +16,6144,16384 +16,6144,3072 +17,6144,16384 +17,6144,3072 +17,16384,6144 +17,4096,6144 +24,6144,3072 +24,16384,6144 +24,4096,6144 +24,6144,16384 +32,6144,3072 +32,4096,6144 +32,6144,16384 +32,16384,6144 +40,6144,3072 +40,16384,6144 +40,4096,6144 +40,6144,16384 +48,6144,3072 +48,16384,6144 +48,6144,16384 +48,4096,6144 +56,6144,3072 +56,16384,6144 +56,6144,16384 +56,4096,6144 +64,16384,6144 +64,6144,16384 +64,4096,6144 +64,6144,3072 +72,16384,6144 +72,6144,3072 +72,6144,16384 +72,4096,6144 +80,6144,16384 +80,16384,6144 +80,4096,6144 +80,6144,3072 +88,16384,6144 +88,6144,16384 +88,6144,3072 +88,4096,6144 +96,16384,6144 +96,6144,3072 +96,6144,16384 +96,4096,6144 +104,16384,6144 +104,6144,3072 +104,6144,16384 +104,4096,6144 +112,6144,3072 +112,16384,6144 +112,6144,16384 +112,4096,6144 +120,6144,16384 +120,4096,6144 +120,6144,3072 +120,16384,6144 +128,16384,6144 +128,6144,16384 +128,4096,6144 +128,6144,3072 +136,16384,6144 +136,6144,3072 +136,4096,6144 +136,6144,16384 +144,16384,6144 +144,6144,3072 +144,6144,16384 +144,4096,6144 +152,6144,16384 +152,6144,3072 +152,16384,6144 +152,4096,6144 +160,16384,6144 +160,6144,16384 +160,4096,6144 +160,6144,3072 +168,16384,6144 +168,6144,16384 +168,4096,6144 +168,6144,3072 +176,6144,3072 +176,16384,6144 +176,6144,16384 +176,4096,6144 +184,6144,3072 +184,16384,6144 +184,6144,16384 +184,4096,6144 +192,6144,3072 +192,16384,6144 +192,6144,16384 +192,4096,6144 +200,6144,3072 +200,16384,6144 +200,4096,6144 +200,6144,16384 +208,6144,16384 +208,6144,3072 +208,16384,6144 +208,4096,6144 +216,16384,6144 +216,4096,6144 +216,6144,16384 +216,6144,3072 +224,6144,16384 +224,6144,3072 +224,16384,6144 +224,4096,6144 +232,16384,6144 +232,6144,3072 +232,4096,6144 +232,6144,16384 +240,6144,3072 +240,16384,6144 +240,4096,6144 +240,6144,16384 +248,6144,3072 
+248,4096,6144 +248,6144,16384 +248,16384,6144 +256,6144,16384 +256,6144,3072 +256,16384,6144 +256,4096,6144 +512,6144,16384 +512,6144,3072 +512,16384,6144 +512,4096,6144 +1024,4096,6144 +1024,6144,16384 +1024,6144,3072 +1024,16384,6144 +1536,16384,6144 +1536,6144,16384 +1536,6144,3072 +1536,4096,6144 +2048,16384,6144 +2048,6144,16384 +2048,4096,6144 +2048,6144,3072 +3072,4096,6144 +3072,6144,3072 +3072,16384,6144 +3072,6144,16384 +4096,16384,6144 +4096,6144,3072 +4096,6144,16384 +4096,4096,6144 +8192,6144,3072 +8192,16384,6144 +8192,4096,6144 +8192,6144,16384 +16384,6144,16384 +16384,6144,3072 +16384,16384,6144 +16384,4096,6144 +18432,6144,16384 +18432,6144,3072 +18432,16384,6144 +18432,4096,6144 +20480,6144,16384 +20480,6144,3072 +20480,16384,6144 +20480,4096,6144 +65536,6144,3072 +65536,16384,6144 +65536,4096,6144 +65536,6144,16384 diff --git a/aiter/configs/a8w8_gemm_model_config/Mixtral-8x22B-Instruct-v0.1-TP4.csv b/aiter/configs/a8w8_gemm_model_config/Mixtral-8x22B-Instruct-v0.1-TP4.csv new file mode 100644 index 00000000..05f238ce --- /dev/null +++ b/aiter/configs/a8w8_gemm_model_config/Mixtral-8x22B-Instruct-v0.1-TP4.csv @@ -0,0 +1,189 @@ +M,N,K +1,16384,6144 +1,6144,1536 +1,6144,16384 +1,2048,6144 +2,2048,6144 +2,6144,1536 +2,6144,16384 +2,16384,6144 +4,2048,6144 +4,16384,6144 +4,6144,1536 +4,6144,16384 +8,16384,6144 +8,6144,1536 +8,6144,16384 +8,2048,6144 +16,16384,6144 +16,6144,1536 +16,6144,16384 +16,2048,6144 +17,6144,16384 +17,16384,6144 +17,2048,6144 +17,6144,1536 +24,16384,6144 +24,2048,6144 +24,6144,1536 +24,6144,16384 +32,6144,16384 +32,2048,6144 +32,6144,1536 +32,16384,6144 +40,2048,6144 +40,16384,6144 +40,6144,1536 +40,6144,16384 +48,2048,6144 +48,6144,1536 +48,16384,6144 +48,6144,16384 +56,16384,6144 +56,6144,1536 +56,6144,16384 +56,2048,6144 +64,6144,1536 +64,16384,6144 +64,6144,16384 +64,2048,6144 +72,16384,6144 +72,6144,1536 +72,6144,16384 +72,2048,6144 +80,6144,1536 +80,6144,16384 +80,2048,6144 +80,16384,6144 +88,2048,6144 +88,6144,1536 +88,16384,6144 +88,6144,16384 +96,6144,1536 +96,16384,6144 +96,6144,16384 +96,2048,6144 +104,6144,1536 +104,16384,6144 +104,6144,16384 +104,2048,6144 +112,6144,1536 +112,16384,6144 +112,6144,16384 +112,2048,6144 +120,6144,1536 +120,6144,16384 +120,2048,6144 +120,16384,6144 +128,2048,6144 +128,6144,1536 +128,16384,6144 +128,6144,16384 +136,2048,6144 +136,16384,6144 +136,6144,1536 +136,6144,16384 +144,16384,6144 +144,6144,1536 +144,6144,16384 +144,2048,6144 +152,2048,6144 +152,6144,1536 +152,6144,16384 +152,16384,6144 +160,6144,1536 +160,16384,6144 +160,6144,16384 +160,2048,6144 +168,16384,6144 +168,6144,1536 +168,6144,16384 +168,2048,6144 +176,2048,6144 +176,6144,1536 +176,16384,6144 +176,6144,16384 +184,6144,1536 +184,16384,6144 +184,6144,16384 +184,2048,6144 +192,16384,6144 +192,6144,16384 +192,2048,6144 +192,6144,1536 +200,16384,6144 +200,6144,1536 +200,6144,16384 +200,2048,6144 +208,6144,16384 +208,2048,6144 +208,16384,6144 +208,6144,1536 +216,16384,6144 +216,2048,6144 +216,6144,1536 +216,6144,16384 +224,2048,6144 +224,6144,1536 +224,6144,16384 +224,16384,6144 +232,16384,6144 +232,6144,1536 +232,6144,16384 +232,2048,6144 +240,16384,6144 +240,6144,1536 +240,6144,16384 +240,2048,6144 +248,6144,1536 +248,6144,16384 +248,16384,6144 +248,2048,6144 +256,6144,1536 +256,6144,16384 +256,16384,6144 +256,2048,6144 +512,6144,1536 +512,6144,16384 +512,16384,6144 +512,2048,6144 +1024,2048,6144 +1024,6144,1536 +1024,6144,16384 +1024,16384,6144 +1536,6144,1536 +1536,16384,6144 +1536,6144,16384 +1536,2048,6144 +2048,16384,6144 +2048,6144,1536 
+2048,6144,16384 +2048,2048,6144 +3072,2048,6144 +3072,16384,6144 +3072,6144,1536 +3072,6144,16384 +4096,6144,1536 +4096,16384,6144 +4096,6144,16384 +4096,2048,6144 +8192,16384,6144 +8192,6144,1536 +8192,6144,16384 +8192,2048,6144 +16384,6144,16384 +16384,2048,6144 +16384,16384,6144 +16384,6144,1536 +18432,6144,1536 +18432,6144,16384 +18432,16384,6144 +18432,2048,6144 +20480,6144,16384 +20480,2048,6144 +20480,16384,6144 +20480,6144,1536 +65536,16384,6144 +65536,6144,1536 +65536,6144,16384 +65536,2048,6144 diff --git a/aiter/configs/a8w8_gemm_model_config/Mixtral-8x22B-Instruct-v0.1-TP8.csv b/aiter/configs/a8w8_gemm_model_config/Mixtral-8x22B-Instruct-v0.1-TP8.csv new file mode 100644 index 00000000..195089f7 --- /dev/null +++ b/aiter/configs/a8w8_gemm_model_config/Mixtral-8x22B-Instruct-v0.1-TP8.csv @@ -0,0 +1,189 @@ +M,N,K +1,6144,768 +1,1024,6144 +1,16384,6144 +1,6144,16384 +2,1024,6144 +2,6144,768 +2,6144,16384 +2,16384,6144 +4,16384,6144 +4,6144,768 +4,1024,6144 +4,6144,16384 +8,6144,768 +8,1024,6144 +8,16384,6144 +8,6144,16384 +16,6144,768 +16,1024,6144 +16,16384,6144 +16,6144,16384 +17,6144,16384 +17,6144,768 +17,1024,6144 +17,16384,6144 +24,16384,6144 +24,1024,6144 +24,6144,768 +24,6144,16384 +32,6144,768 +32,1024,6144 +32,6144,16384 +32,16384,6144 +40,6144,768 +40,1024,6144 +40,16384,6144 +40,6144,16384 +48,6144,768 +48,1024,6144 +48,16384,6144 +48,6144,16384 +56,6144,768 +56,1024,6144 +56,16384,6144 +56,6144,16384 +64,1024,6144 +64,16384,6144 +64,6144,16384 +64,6144,768 +72,1024,6144 +72,16384,6144 +72,6144,768 +72,6144,16384 +80,6144,768 +80,6144,16384 +80,1024,6144 +80,16384,6144 +88,6144,768 +88,16384,6144 +88,6144,16384 +88,1024,6144 +96,1024,6144 +96,16384,6144 +96,6144,768 +96,6144,16384 +104,1024,6144 +104,16384,6144 +104,6144,768 +104,6144,16384 +112,6144,768 +112,1024,6144 +112,16384,6144 +112,6144,16384 +120,1024,6144 +120,6144,16384 +120,16384,6144 +120,6144,768 +128,6144,768 +128,1024,6144 +128,16384,6144 +128,6144,16384 +136,16384,6144 +136,6144,768 +136,1024,6144 +136,6144,16384 +144,1024,6144 +144,16384,6144 +144,6144,768 +144,6144,16384 +152,6144,768 +152,6144,16384 +152,1024,6144 +152,16384,6144 +160,6144,768 +160,1024,6144 +160,16384,6144 +160,6144,16384 +168,6144,768 +168,1024,6144 +168,16384,6144 +168,6144,16384 +176,1024,6144 +176,6144,768 +176,16384,6144 +176,6144,16384 +184,6144,768 +184,1024,6144 +184,16384,6144 +184,6144,16384 +192,6144,768 +192,1024,6144 +192,16384,6144 +192,6144,16384 +200,6144,768 +200,1024,6144 +200,16384,6144 +200,6144,16384 +208,6144,16384 +208,6144,768 +208,1024,6144 +208,16384,6144 +216,16384,6144 +216,6144,768 +216,1024,6144 +216,6144,16384 +224,6144,768 +224,6144,16384 +224,1024,6144 +224,16384,6144 +232,16384,6144 +232,6144,768 +232,1024,6144 +232,6144,16384 +240,6144,768 +240,1024,6144 +240,16384,6144 +240,6144,16384 +248,6144,768 +248,1024,6144 +248,6144,16384 +248,16384,6144 +256,6144,16384 +256,1024,6144 +256,16384,6144 +256,6144,768 +512,6144,768 +512,6144,16384 +512,1024,6144 +512,16384,6144 +1024,6144,768 +1024,6144,16384 +1024,1024,6144 +1024,16384,6144 +1536,6144,768 +1536,1024,6144 +1536,16384,6144 +1536,6144,16384 +2048,1024,6144 +2048,16384,6144 +2048,6144,16384 +2048,6144,768 +3072,6144,768 +3072,1024,6144 +3072,16384,6144 +3072,6144,16384 +4096,1024,6144 +4096,16384,6144 +4096,6144,768 +4096,6144,16384 +8192,6144,768 +8192,1024,6144 +8192,16384,6144 +8192,6144,16384 +16384,6144,16384 +16384,6144,768 +16384,1024,6144 +16384,16384,6144 +18432,6144,16384 +18432,6144,768 +18432,1024,6144 +18432,16384,6144 
+20480,6144,16384 +20480,6144,768 +20480,1024,6144 +20480,16384,6144 +65536,6144,768 +65536,1024,6144 +65536,16384,6144 +65536,6144,16384 diff --git a/aiter/configs/a8w8_gemm_model_config/Mixtral-8x7B-Instruct-v0.1-TP1.csv b/aiter/configs/a8w8_gemm_model_config/Mixtral-8x7B-Instruct-v0.1-TP1.csv new file mode 100644 index 00000000..7ff7b198 --- /dev/null +++ b/aiter/configs/a8w8_gemm_model_config/Mixtral-8x7B-Instruct-v0.1-TP1.csv @@ -0,0 +1,189 @@ +M,N,K +1,14336,4096 +1,6144,4096 +1,4096,4096 +1,4096,14336 +2,14336,4096 +2,4096,14336 +2,4096,4096 +2,6144,4096 +4,14336,4096 +4,4096,14336 +4,6144,4096 +4,4096,4096 +8,14336,4096 +8,6144,4096 +8,4096,4096 +8,4096,14336 +16,14336,4096 +16,6144,4096 +16,4096,4096 +16,4096,14336 +17,14336,4096 +17,4096,14336 +17,4096,4096 +17,6144,4096 +24,4096,4096 +24,14336,4096 +24,4096,14336 +24,6144,4096 +32,14336,4096 +32,4096,14336 +32,6144,4096 +32,4096,4096 +40,14336,4096 +40,4096,14336 +40,6144,4096 +40,4096,4096 +48,6144,4096 +48,4096,4096 +48,14336,4096 +48,4096,14336 +56,6144,4096 +56,4096,4096 +56,14336,4096 +56,4096,14336 +64,6144,4096 +64,4096,14336 +64,4096,4096 +64,14336,4096 +72,6144,4096 +72,4096,4096 +72,14336,4096 +72,4096,14336 +80,14336,4096 +80,6144,4096 +80,4096,14336 +80,4096,4096 +88,6144,4096 +88,4096,4096 +88,14336,4096 +88,4096,14336 +96,6144,4096 +96,4096,4096 +96,14336,4096 +96,4096,14336 +104,14336,4096 +104,6144,4096 +104,4096,4096 +104,4096,14336 +112,4096,4096 +112,14336,4096 +112,6144,4096 +112,4096,14336 +120,4096,14336 +120,4096,4096 +120,14336,4096 +120,6144,4096 +128,14336,4096 +128,6144,4096 +128,4096,14336 +128,4096,4096 +136,14336,4096 +136,4096,14336 +136,6144,4096 +136,4096,4096 +144,6144,4096 +144,4096,4096 +144,14336,4096 +144,4096,14336 +152,6144,4096 +152,4096,4096 +152,14336,4096 +152,4096,14336 +160,4096,4096 +160,6144,4096 +160,4096,14336 +160,14336,4096 +168,6144,4096 +168,4096,4096 +168,14336,4096 +168,4096,14336 +176,4096,4096 +176,14336,4096 +176,6144,4096 +176,4096,14336 +184,4096,4096 +184,14336,4096 +184,6144,4096 +184,4096,14336 +192,6144,4096 +192,4096,14336 +192,4096,4096 +192,14336,4096 +200,6144,4096 +200,4096,4096 +200,14336,4096 +200,4096,14336 +208,4096,4096 +208,14336,4096 +208,4096,14336 +208,6144,4096 +216,4096,14336 +216,4096,4096 +216,14336,4096 +216,6144,4096 +224,6144,4096 +224,4096,4096 +224,14336,4096 +224,4096,14336 +232,4096,14336 +232,6144,4096 +232,4096,4096 +232,14336,4096 +240,14336,4096 +240,6144,4096 +240,4096,4096 +240,4096,14336 +248,6144,4096 +248,4096,4096 +248,14336,4096 +248,4096,14336 +256,14336,4096 +256,4096,14336 +256,4096,4096 +256,6144,4096 +512,14336,4096 +512,6144,4096 +512,4096,14336 +512,4096,4096 +1024,4096,4096 +1024,14336,4096 +1024,6144,4096 +1024,4096,14336 +1536,6144,4096 +1536,4096,14336 +1536,4096,4096 +1536,14336,4096 +2048,6144,4096 +2048,14336,4096 +2048,4096,4096 +2048,4096,14336 +3072,4096,4096 +3072,14336,4096 +3072,4096,14336 +3072,6144,4096 +4096,4096,14336 +4096,14336,4096 +4096,6144,4096 +4096,4096,4096 +8192,4096,14336 +8192,6144,4096 +8192,4096,4096 +8192,14336,4096 +16384,4096,4096 +16384,14336,4096 +16384,4096,14336 +16384,6144,4096 +18432,4096,14336 +18432,6144,4096 +18432,4096,4096 +18432,14336,4096 +20480,4096,4096 +20480,14336,4096 +20480,4096,14336 +20480,6144,4096 +32768,6144,4096 +32768,4096,4096 +32768,14336,4096 +32768,4096,14336 diff --git a/aiter/configs/a8w8_gemm_model_config/Mixtral-8x7B-Instruct-v0.1-TP2.csv b/aiter/configs/a8w8_gemm_model_config/Mixtral-8x7B-Instruct-v0.1-TP2.csv new file mode 100644 index 00000000..7c7256c8 
--- /dev/null +++ b/aiter/configs/a8w8_gemm_model_config/Mixtral-8x7B-Instruct-v0.1-TP2.csv @@ -0,0 +1,189 @@ +M,N,K +1,14336,4096 +1,4096,2048 +1,4096,14336 +1,3072,4096 +2,4096,2048 +2,14336,4096 +2,4096,14336 +2,3072,4096 +4,14336,4096 +4,4096,14336 +4,3072,4096 +4,4096,2048 +8,3072,4096 +8,14336,4096 +8,4096,2048 +8,4096,14336 +16,3072,4096 +16,14336,4096 +16,4096,2048 +16,4096,14336 +17,4096,2048 +17,14336,4096 +17,4096,14336 +17,3072,4096 +24,4096,2048 +24,14336,4096 +24,4096,14336 +24,3072,4096 +32,4096,2048 +32,14336,4096 +32,4096,14336 +32,3072,4096 +40,3072,4096 +40,4096,2048 +40,14336,4096 +40,4096,14336 +48,3072,4096 +48,4096,2048 +48,14336,4096 +48,4096,14336 +56,3072,4096 +56,4096,2048 +56,14336,4096 +56,4096,14336 +64,3072,4096 +64,4096,2048 +64,4096,14336 +64,14336,4096 +72,3072,4096 +72,4096,2048 +72,14336,4096 +72,4096,14336 +80,3072,4096 +80,4096,2048 +80,14336,4096 +80,4096,14336 +88,3072,4096 +88,4096,2048 +88,14336,4096 +88,4096,14336 +96,3072,4096 +96,4096,2048 +96,14336,4096 +96,4096,14336 +104,3072,4096 +104,14336,4096 +104,4096,2048 +104,4096,14336 +112,3072,4096 +112,4096,2048 +112,14336,4096 +112,4096,14336 +120,4096,14336 +120,3072,4096 +120,4096,2048 +120,14336,4096 +128,14336,4096 +128,3072,4096 +128,4096,2048 +128,4096,14336 +136,4096,2048 +136,14336,4096 +136,4096,14336 +136,3072,4096 +144,3072,4096 +144,4096,2048 +144,14336,4096 +144,4096,14336 +152,3072,4096 +152,4096,2048 +152,14336,4096 +152,4096,14336 +160,4096,2048 +160,4096,14336 +160,14336,4096 +160,3072,4096 +168,3072,4096 +168,4096,2048 +168,14336,4096 +168,4096,14336 +176,3072,4096 +176,4096,2048 +176,14336,4096 +176,4096,14336 +184,3072,4096 +184,4096,2048 +184,14336,4096 +184,4096,14336 +192,4096,2048 +192,4096,14336 +192,14336,4096 +192,3072,4096 +200,4096,2048 +200,14336,4096 +200,4096,14336 +200,3072,4096 +208,4096,2048 +208,14336,4096 +208,4096,14336 +208,3072,4096 +216,4096,2048 +216,4096,14336 +216,14336,4096 +216,3072,4096 +224,3072,4096 +224,4096,2048 +224,14336,4096 +224,4096,14336 +232,3072,4096 +232,4096,14336 +232,4096,2048 +232,14336,4096 +240,3072,4096 +240,14336,4096 +240,4096,2048 +240,4096,14336 +248,3072,4096 +248,4096,2048 +248,14336,4096 +248,4096,14336 +256,4096,2048 +256,14336,4096 +256,4096,14336 +256,3072,4096 +512,14336,4096 +512,3072,4096 +512,4096,2048 +512,4096,14336 +1024,14336,4096 +1024,3072,4096 +1024,4096,2048 +1024,4096,14336 +1536,3072,4096 +1536,4096,2048 +1536,4096,14336 +1536,14336,4096 +2048,14336,4096 +2048,4096,2048 +2048,4096,14336 +2048,3072,4096 +3072,4096,2048 +3072,14336,4096 +3072,4096,14336 +3072,3072,4096 +4096,4096,2048 +4096,4096,14336 +4096,3072,4096 +4096,14336,4096 +8192,4096,2048 +8192,3072,4096 +8192,4096,14336 +8192,14336,4096 +16384,4096,2048 +16384,14336,4096 +16384,4096,14336 +16384,3072,4096 +18432,3072,4096 +18432,4096,2048 +18432,4096,14336 +18432,14336,4096 +20480,4096,2048 +20480,14336,4096 +20480,4096,14336 +20480,3072,4096 +32768,3072,4096 +32768,4096,2048 +32768,14336,4096 +32768,4096,14336 diff --git a/aiter/configs/a8w8_gemm_model_config/Mixtral-8x7B-Instruct-v0.1-TP4.csv b/aiter/configs/a8w8_gemm_model_config/Mixtral-8x7B-Instruct-v0.1-TP4.csv new file mode 100644 index 00000000..da2e030c --- /dev/null +++ b/aiter/configs/a8w8_gemm_model_config/Mixtral-8x7B-Instruct-v0.1-TP4.csv @@ -0,0 +1,189 @@ +M,N,K +1,14336,4096 +1,4096,1024 +1,4096,14336 +1,1536,4096 +2,14336,4096 +2,4096,14336 +2,1536,4096 +2,4096,1024 +4,14336,4096 +4,4096,14336 +4,1536,4096 +4,4096,1024 +8,1536,4096 +8,14336,4096 +8,4096,1024 +8,4096,14336 
+16,14336,4096 +16,4096,1024 +16,4096,14336 +16,1536,4096 +17,1536,4096 +17,14336,4096 +17,4096,14336 +17,4096,1024 +24,4096,1024 +24,14336,4096 +24,4096,14336 +24,1536,4096 +32,14336,4096 +32,4096,14336 +32,4096,1024 +32,1536,4096 +40,1536,4096 +40,14336,4096 +40,4096,14336 +40,4096,1024 +48,1536,4096 +48,4096,1024 +48,14336,4096 +48,4096,14336 +56,1536,4096 +56,4096,1024 +56,14336,4096 +56,4096,14336 +64,1536,4096 +64,4096,14336 +64,4096,1024 +64,14336,4096 +72,1536,4096 +72,4096,1024 +72,14336,4096 +72,4096,14336 +80,14336,4096 +80,4096,14336 +80,1536,4096 +80,4096,1024 +88,1536,4096 +88,4096,1024 +88,14336,4096 +88,4096,14336 +96,1536,4096 +96,4096,1024 +96,14336,4096 +96,4096,14336 +104,14336,4096 +104,4096,1024 +104,4096,14336 +104,1536,4096 +112,4096,1024 +112,14336,4096 +112,4096,14336 +112,1536,4096 +120,4096,14336 +120,1536,4096 +120,4096,1024 +120,14336,4096 +128,14336,4096 +128,1536,4096 +128,4096,14336 +128,4096,1024 +136,14336,4096 +136,4096,14336 +136,1536,4096 +136,4096,1024 +144,1536,4096 +144,4096,1024 +144,14336,4096 +144,4096,14336 +152,4096,1024 +152,14336,4096 +152,4096,14336 +152,1536,4096 +160,4096,14336 +160,4096,1024 +160,14336,4096 +160,1536,4096 +168,4096,1024 +168,14336,4096 +168,4096,14336 +168,1536,4096 +176,1536,4096 +176,4096,1024 +176,14336,4096 +176,4096,14336 +184,4096,1024 +184,14336,4096 +184,4096,14336 +184,1536,4096 +192,1536,4096 +192,4096,14336 +192,4096,1024 +192,14336,4096 +200,4096,1024 +200,14336,4096 +200,4096,14336 +200,1536,4096 +208,4096,1024 +208,14336,4096 +208,4096,14336 +208,1536,4096 +216,4096,14336 +216,4096,1024 +216,14336,4096 +216,1536,4096 +224,1536,4096 +224,4096,1024 +224,14336,4096 +224,4096,14336 +232,1536,4096 +232,4096,14336 +232,4096,1024 +232,14336,4096 +240,1536,4096 +240,14336,4096 +240,4096,1024 +240,4096,14336 +248,4096,1024 +248,14336,4096 +248,4096,14336 +248,1536,4096 +256,1536,4096 +256,14336,4096 +256,4096,14336 +256,4096,1024 +512,14336,4096 +512,1536,4096 +512,4096,14336 +512,4096,1024 +1024,4096,1024 +1024,14336,4096 +1024,1536,4096 +1024,4096,14336 +1536,1536,4096 +1536,4096,14336 +1536,4096,1024 +1536,14336,4096 +2048,14336,4096 +2048,4096,1024 +2048,4096,14336 +2048,1536,4096 +3072,4096,1024 +3072,14336,4096 +3072,4096,14336 +3072,1536,4096 +4096,4096,14336 +4096,1536,4096 +4096,14336,4096 +4096,4096,1024 +8192,1536,4096 +8192,4096,14336 +8192,4096,1024 +8192,14336,4096 +16384,4096,1024 +16384,14336,4096 +16384,4096,14336 +16384,1536,4096 +18432,1536,4096 +18432,4096,14336 +18432,4096,1024 +18432,14336,4096 +20480,4096,1024 +20480,14336,4096 +20480,4096,14336 +20480,1536,4096 +32768,1536,4096 +32768,4096,1024 +32768,14336,4096 +32768,4096,14336 diff --git a/aiter/configs/a8w8_gemm_model_config/Mixtral-8x7B-Instruct-v0.1-TP8.csv b/aiter/configs/a8w8_gemm_model_config/Mixtral-8x7B-Instruct-v0.1-TP8.csv new file mode 100644 index 00000000..3a337241 --- /dev/null +++ b/aiter/configs/a8w8_gemm_model_config/Mixtral-8x7B-Instruct-v0.1-TP8.csv @@ -0,0 +1,189 @@ +M,N,K +1,14336,4096 +1,4096,512 +1,4096,14336 +1,768,4096 +2,14336,4096 +2,4096,14336 +2,4096,512 +2,768,4096 +4,14336,4096 +4,4096,14336 +4,768,4096 +4,4096,512 +8,768,4096 +8,14336,4096 +8,4096,512 +8,4096,14336 +16,14336,4096 +16,4096,512 +16,4096,14336 +16,768,4096 +17,14336,4096 +17,4096,14336 +17,4096,512 +17,768,4096 +24,14336,4096 +24,4096,14336 +24,4096,512 +24,768,4096 +32,14336,4096 +32,4096,14336 +32,4096,512 +32,768,4096 +40,768,4096 +40,14336,4096 +40,4096,14336 +40,4096,512 +48,4096,512 +48,768,4096 +48,14336,4096 +48,4096,14336 +56,768,4096 
+56,4096,512 +56,14336,4096 +56,4096,14336 +64,768,4096 +64,4096,14336 +64,4096,512 +64,14336,4096 +72,768,4096 +72,4096,512 +72,14336,4096 +72,4096,14336 +80,14336,4096 +80,4096,14336 +80,768,4096 +80,4096,512 +88,768,4096 +88,14336,4096 +88,4096,14336 +88,4096,512 +96,768,4096 +96,4096,512 +96,14336,4096 +96,4096,14336 +104,14336,4096 +104,4096,512 +104,4096,14336 +104,768,4096 +112,14336,4096 +112,4096,14336 +112,4096,512 +112,768,4096 +120,4096,14336 +120,4096,512 +120,768,4096 +120,14336,4096 +128,14336,4096 +128,768,4096 +128,4096,14336 +128,4096,512 +136,14336,4096 +136,4096,14336 +136,768,4096 +136,4096,512 +144,768,4096 +144,4096,512 +144,14336,4096 +144,4096,14336 +152,14336,4096 +152,4096,14336 +152,4096,512 +152,768,4096 +160,4096,14336 +160,4096,512 +160,14336,4096 +160,768,4096 +168,4096,512 +168,14336,4096 +168,4096,14336 +168,768,4096 +176,768,4096 +176,4096,512 +176,14336,4096 +176,4096,14336 +184,14336,4096 +184,4096,14336 +184,4096,512 +184,768,4096 +192,768,4096 +192,4096,14336 +192,4096,512 +192,14336,4096 +200,4096,512 +200,14336,4096 +200,4096,14336 +200,768,4096 +208,14336,4096 +208,4096,14336 +208,768,4096 +208,4096,512 +216,4096,14336 +216,4096,512 +216,14336,4096 +216,768,4096 +224,768,4096 +224,14336,4096 +224,4096,14336 +224,4096,512 +232,768,4096 +232,4096,14336 +232,4096,512 +232,14336,4096 +240,768,4096 +240,14336,4096 +240,4096,512 +240,4096,14336 +248,4096,512 +248,14336,4096 +248,4096,14336 +248,768,4096 +256,14336,4096 +256,4096,14336 +256,4096,512 +256,768,4096 +512,14336,4096 +512,768,4096 +512,4096,14336 +512,4096,512 +1024,4096,512 +1024,14336,4096 +1024,768,4096 +1024,4096,14336 +1536,4096,14336 +1536,4096,512 +1536,14336,4096 +1536,768,4096 +2048,14336,4096 +2048,4096,512 +2048,4096,14336 +2048,768,4096 +3072,4096,512 +3072,14336,4096 +3072,4096,14336 +3072,768,4096 +4096,4096,14336 +4096,768,4096 +4096,14336,4096 +4096,4096,512 +8192,4096,14336 +8192,4096,512 +8192,14336,4096 +8192,768,4096 +16384,14336,4096 +16384,4096,14336 +16384,768,4096 +16384,4096,512 +18432,4096,14336 +18432,4096,512 +18432,14336,4096 +18432,768,4096 +20480,14336,4096 +20480,4096,14336 +20480,768,4096 +20480,4096,512 +32768,768,4096 +32768,4096,512 +32768,14336,4096 +32768,4096,14336 diff --git a/csrc/ck_gemm_a8w8/README.md b/csrc/ck_gemm_a8w8/README.md index 884a7d94..e220a8a3 100644 --- a/csrc/ck_gemm_a8w8/README.md +++ b/csrc/ck_gemm_a8w8/README.md @@ -15,4 +15,38 @@ You can find the results of the tuning in `aiter/configs/a8w8_tuned_gemm.csv`. ## More If you want to re-install gemm_a8w8, you should remove `aiter/jit/module_gemm_a8w8.so` and `aiter/jit/build/module_gemm_a8w8` first. -If you use flag `PREBUILD_KERNELS=1 USE_CK_A8W8=1` when you install aiter, it will build gemm a8w8 kernels in `aiter/configs/a8w8_tuned_gemm.csv` by default. If you want to use the new result of gemm_a8w8_tune, please remove `build` and `*.so` first, then re-intall aiter after finishing tune. \ No newline at end of file +If you use flag `PREBUILD_KERNELS=1 USE_CK_A8W8=1` when you install aiter, it will build gemm a8w8 kernels in `aiter/configs/a8w8_tuned_gemm.csv` by default. If you want to use the new result of gemm_a8w8_tune, please remove `build` and `*.so` first, then re-intall aiter after finishing tune. + + +## FP8 A8W8 Rowwise Scaling GEMM + +The following steps will walk you through the full process of getting the best performance out of your hardware. + +0. Clear gemm_a8w8: Remove `aiter/jit/module_gemm_a8w8.so` and `aiter/jit/build/module_gemm_a8w8`. 
+```bash
+rm -rf aiter/jit/module_gemm_a8w8.so
+rm -rf aiter/jit/build/module_gemm_a8w8
+```
+
+1. Install the AITER library:
+```bash
+python3 setup.py develop
+```
+
+2. Tune your GEMM kernel:
+```bash
+python3 csrc/ck_gemm_a8w8/gemm_a8w8_tune.py -i aiter/configs/a8w8_gemm_model_config/Mixtral-8x7B-Instruct-v0.1-TP1.csv -o aiter/configs/a8w8_tuned_gemm.csv --dtype fp8 -k
+```
+
+3. Use the operator:
+```python
+from aiter.ops.gemm_op_a8w8 import gemm_a8w8_CK
+
+output = gemm_a8w8_CK(
+    qinput,        # [M, K]
+    weight,        # [N, K]
+    x_scale,       # [M, 1]
+    weight_scale,  # [1, N]
+    bias,
+    dtype=out_dtype  # torch.bfloat16, torch.float16
+)
+```
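+
+The call above assumes you already have rowwise-quantized inputs. A minimal
+sketch of producing `qinput` and `x_scale` with plain PyTorch is shown below;
+the tensor names and shapes are illustrative only (not part of the AITER API),
+and `torch.float8_e4m3fnuz` assumes a gfx94x-class GPU:
+```python
+import torch
+
+x = torch.randn(64, 8192, dtype=torch.bfloat16, device="cuda")
+
+# Illustrative rowwise (per-token) quantization; AITER may ship its own helper.
+# One scale per row, chosen so each row fits the fp8 representable range.
+fp8_max = torch.finfo(torch.float8_e4m3fnuz).max
+x_scale = (x.abs().amax(dim=1, keepdim=True).float() / fp8_max).clamp(min=1e-12)  # [M, 1]
+qinput = (x / x_scale).to(torch.float8_e4m3fnuz)  # [M, K]
+```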

From 0beaefc841fe880e92a3fe9f67ba16cd370ee3b4 Mon Sep 17 00:00:00 2001
From: vllmellm
Date: Tue, 18 Feb 2025 07:48:30 +0000
Subject: [PATCH 05/13] refactor gemm_a8w8_tests

---
 op_tests/__init__.py       |   0
 op_tests/test_gemm_a8w8.py | 183 ++++++++++++++++++++++---------------
 op_tests/utils.py          |  83 +++++++++++++++++
 3 files changed, 191 insertions(+), 75 deletions(-)
 create mode 100644 op_tests/__init__.py
 create mode 100644 op_tests/utils.py

diff --git a/op_tests/__init__.py b/op_tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/op_tests/test_gemm_a8w8.py b/op_tests/test_gemm_a8w8.py
index a3903f5a..90c6a564 100644
--- a/op_tests/test_gemm_a8w8.py
+++ b/op_tests/test_gemm_a8w8.py
@@ -1,18 +1,55 @@
 # SPDX-License-Identifier: MIT
 # Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
-
-from aiter.test_common import checkAllclose, perftest, tensor_dump
+import pytest
 import torch
 import torch.nn.functional as F
-import numpy as np
-import sys
-import os
 import aiter
 from aiter.ops.shuffle import shuffle_weight
+from .utils import rand_tensor, check_all_close
+
+
+MNK = [
+    # qkv_proj
+    (1, 1280, 8192),
+    (32, 1280, 8192),
+    (64, 1280, 8192),
+    (128, 1280, 8192),
+    (192, 1280, 8192),
+    (256, 1280, 8192),
+    (320, 1280, 8192),
+    (512, 1280, 8192),
+    (1024, 1280, 8192),
+    (2048, 1280, 8192),
+    (4096, 1280, 8192),
+    (8192, 1280, 8192),
+    (16384, 1280, 8192),
+    # attn_out
+    (1, 8192, 1024),
+    (32, 8192, 1024),
+    (64, 8192, 1024),
+    (128, 8192, 1024),
+    (192, 8192, 1024),
+    (256, 8192, 1024),
+    (320, 8192, 1024),
+    (512, 8192, 1024),
+    (1024, 8192, 1024),
+    (2048, 8192, 1024),
+    (4096, 8192, 1024),
+    (8192, 8192, 1024),
+    (16384, 8192, 1024),
+]
+
+# ck_gemm only accepts the following combination for | scale : output | dtypes
+# | F32 : F16 | F32 : BF16 | F16 : F16 | BF16 : BF16 |
+CK_GEMM_SCALES_OUTPUT_DTYPES = [
+    (torch.float32, torch.float16),
+    (torch.float32, torch.bfloat16),
+    (torch.float16, torch.float16),
+    (torch.bfloat16, torch.bfloat16),
+]
 
-@perftest()
-def run_torch(x, weight, x_scale, w_scale, bias=None, dtype=torch.bfloat16):
+
+def torch_scaled_mm(x, weight, x_scale, w_scale, bias=None, dtype=torch.bfloat16):
     x = F.linear(x.to(torch.float32), weight.to(torch.float32))
     scale = torch.matmul(x_scale, w_scale)
     out = torch.mul(x, scale)
@@ -21,77 +58,73 @@
     return out.to(dtype)
 
-@perftest()
-def run_gemm_ck(x, weight, x_scale, w_scale, bias=None, dtype=torch.bfloat16):
-    return aiter.gemm_a8w8_CK(x, weight, x_scale, w_scale, bias)
+def setup(
+    mnk: tuple[int, int, int],
+    ab_dtype: torch.dtype,
+    scales_dtype: torch.dtype,
+    bias_dtype: torch.dtype,
+    use_bias: bool,
+) -> tuple[torch.tensor, ...]:
+
+    m, n, k = mnk
+    a = rand_tensor(shape=(m, k), dtype=ab_dtype).cuda()
+    b = rand_tensor(shape=(n, k), dtype=ab_dtype).cuda()
+    a_scale = rand_tensor(shape=(m, 1), dtype=scales_dtype).cuda() + 1e-6
+    b_scale = rand_tensor(shape=(1, n), dtype=scales_dtype).cuda() + 1e-6
+    bias = (rand_tensor(shape=[1, n], dtype=torch.bfloat16).cuda()).to(dtype=bias_dtype) if use_bias else None
+    return a, b, a_scale, b_scale, bias
 
-@perftest()
-def run_gemm_asm(x, weightshuffle, x_scale, w_scale, bias=None, dtype=torch.bfloat16):
-    return aiter.gemm_a8w8_ASM(x, weightshuffle, x_scale, w_scale, bias)
+@pytest.mark.parametrize("mnk", MNK)
+@pytest.mark.parametrize("ab_dtype", [torch.int8, torch.float8_e4m3fnuz])
+@pytest.mark.parametrize("scales_output_dtype", CK_GEMM_SCALES_OUTPUT_DTYPES)
+@pytest.mark.parametrize("use_bias", [True, False])
+def test_ck_gemm_close_to_torch(
+    mnk: tuple[int, int, int],
+    ab_dtype: torch.dtype,
+    scales_output_dtype: tuple[torch.dtype, torch.dtype],
+    use_bias: bool
+) -> None:
+    # using bias further reduces precision, so we use a higher tolerance
+    if use_bias:
+        rtol, atol = (1e-1, 1e-1)
+    else:
+        rtol, atol = (1e-2, 1e-1)
+    scales_dtype, out_dtype = scales_output_dtype
+    bias_dtype = out_dtype
+    a, b, a_scale, b_scale, bias = setup(
+        mnk,
+        ab_dtype,
+        scales_dtype,
+        bias_dtype,
+        use_bias,
+    )
 
-def test_gemm(dtype, m, n, k):
-    dim = (m, n, k)
-    x = torch.randint(-20, 20, (m, k), dtype=torch.int8).cuda()
-    weight = torch.randint(-20, 20, (n, k), dtype=torch.int8).cuda()
-    x_scale = torch.rand([m, 1], dtype=torch.float32).cuda() + 1e-6
-    w_scale = torch.rand([1, n], dtype=torch.float32).cuda() + 1e-6
-    bias = torch.rand([1, n], dtype=dtype).cuda() * 10
-    weightshuffle = shuffle_weight(weight,layout=(32,16))
-    bias_f32 = bias.to(torch.float)
-    x_pad, _ = F.pad(x,(0,128), "constant", 0).split([x.shape[1], 128],dim=1)
-    # print(f"{x_pad.shape=}{x_pad.stride()}")
-    # tensor_dump(x, 'x')
-    # tensor_dump(weight, 'weight')
-    # tensor_dump(shuffle_weight(weight), 'weight_shuffled')
-    # tensor_dump(x_scale, 'x_scale')
-    # tensor_dump(w_scale, 'w_scale')
-    # tensor_dump(bias, 'bias')
+    output = aiter.gemm_a8w8_CK(a, b, a_scale, b_scale, bias, dtype=out_dtype)
+    expected = torch_scaled_mm(a, b, a_scale, b_scale, bias, dtype=out_dtype)
 
-    a, avg_a = run_torch(x, weight, x_scale, w_scale, bias, dtype)
-    b, avg_b = run_gemm_ck(x, weight, x_scale, w_scale, bias, dtype)
-    c, avg_c = run_gemm_asm(x, weightshuffle, x_scale, w_scale, bias_f32, dtype)
-    if c is None:
-        msg = f"[perf] dim: {str(dim):<20} dtype: {dtype}, torch avg: {avg_a:<8.2f} us, ck avg: {avg_b:<8.2f} us, asm : not support, uplift: {avg_a/avg_b-1:<5.1%}"
-    else:
-        msg = f"[perf] dim: {str(dim):<20} dtype: {dtype}, torch avg: {avg_a:<8.2f} us, ck avg: {avg_b:<8.2f} us, asm avg: {avg_c:<8.2f} us, uplift: {avg_a/min(avg_b,avg_c)-1:<5.1%}"
-    checkAllclose(a, b, msg="a,b: "+msg, rtol=1e-2, atol=0.01)
-    if c != None:
-        checkAllclose(a, c, msg="\033[1A\033[2K" + "a,c: "+ msg, rtol=1e-2, atol=0.01)
+    check_all_close(output, expected, rtol=rtol, atol=atol)
 
+@pytest.mark.parametrize("mnk", MNK)
+def test_asm_gemm_close_to_torch(
+    mnk: tuple[int, int, int],
+) -> None:
+    rtol, atol = (1e-1, 1e-1)
+    ab_dtype = torch.int8
+    out_dtype = torch.bfloat16
+    scales_dtype = torch.float32
+    bias_dtype = torch.float
+    # asm_gemm requires bias and shuffle
+    a, b, a_scale, b_scale, bias = setup(
+        mnk,
+        ab_dtype,
+        scales_dtype,
+        bias_dtype,
+        use_bias=True,
+    )
+    b_shuffled= shuffle_weight(b, layout=(32, 16))
 
-for dtype in [torch.bfloat16]:
-    # qkv_proj
-    for (m, n, k) in [
-        (1, 1280, 8192),
-        (32, 1280, 8192),
-        (64, 1280, 8192),
-        (128, 1280, 8192),
-        (192, 1280, 8192),
-        (256, 1280, 8192),
-        (320,
1280, 8192), - (512, 1280, 8192), - (1024, 1280, 8192), - (2048, 1280, 8192), - (4096, 1280, 8192), - (8192, 1280, 8192), - (16384, 1280, 8192), - ]: - test_gemm(dtype, m, n, k) - # attn_out - for (m, n, k) in [ - (1, 8192, 1024), - (32, 8192, 1024), - (64, 8192, 1024), - (128, 8192, 1024), - (192, 8192, 1024), - (256, 8192, 1024), - (320, 8192, 1024), - (512, 8192, 1024), - (1024, 8192, 1024), - (2048, 8192, 1024), - (4096, 8192, 1024), - (8192, 8192, 1024), - (16384, 8192, 1024), - ]: - test_gemm(dtype, m, n, k) + output = aiter.gemm_a8w8_ASM(a, b_shuffled, a_scale, b_scale, bias) + expected = torch_scaled_mm(a, b, a_scale, b_scale, bias, dtype=out_dtype) + if output is not None and torch.sum(output.isnan()==True) ==0: + check_all_close(output, expected, rtol=rtol, atol=atol) diff --git a/op_tests/utils.py b/op_tests/utils.py new file mode 100644 index 00000000..ceca6537 --- /dev/null +++ b/op_tests/utils.py @@ -0,0 +1,83 @@ +import torch + +MAX_RAND_INT = 20 +MIN_RAND_INT = -20 + +def rand_tensor( + shape: tuple[int, int], + dtype: torch.dtype +) -> torch.tensor: + """ + Generate a random PyTorch tensor with specified shape and data type. + + - For integer types: Uses torch.randint to generate random integers within a fixed range. + - For float types: Uses torch.rand to generate random floats between 0 and 1. + + Parameters: + ----------- + shape : tuple[int, int] + The shape of the output tensor. Must be a tuple of two integers. + dtype : torch.dtype + The desired data type of the output tensor. + + Returns: + -------- + torch.Tensor + A random tensor of the specified shape and data type. + + Raises: + ------- + ValueError + If an unsupported data type is provided. + """ + if dtype in [torch.int8, torch.int16, torch.int32, torch.int64]: + # For integer types, use randint + return torch.randint(MIN_RAND_INT, MAX_RAND_INT, shape, dtype=dtype) + elif dtype in [torch.float16, torch.float32, torch.float64, torch.bfloat16]: + # For float types, use rand + return torch.rand(shape, dtype=dtype) + elif dtype == torch.float8_e4m3fnuz: + # Special case for float8_e4m3fnuz + return torch.rand(shape, dtype=torch.float16).to(torch.float8_e4m3fnuz) + + raise ValueError(f"Unsupported dtype: {dtype}") + +def check_all_close( + a: torch.tensor, + b: torch.tensor, + rtol: float, + atol: float, +) -> None: + """ + Check if all elements in two tensors are close within specified tolerances. + + Parameters: + ----------- + a : torch.Tensor + First input tensor. + b : torch.Tensor + Second input tensor to compare with 'a'. + rtol : float + Relative tolerance. + atol : float + Absolute tolerance. + + Raises: + ------- + AssertionError + If any elements in 'a' and 'b' are not close within the specified tolerances. + The error message includes details about the maximum and average delta, + and the percentage of elements that are not close. + """ + is_close = torch.isclose(a, b, rtol=rtol, atol=atol) + is_not_close = ~is_close + num_not_close = is_not_close.sum() + delta = (a-b)[is_not_close] + percent = num_not_close/a.numel() + message = "" if num_not_close == 0 else f""" +check_all_close failed! 
+max delta:{delta.max()} +average delta:{delta.mean()} +delta details: {percent:.1%} ({num_not_close} of {a.numel()}) elements + """ + assert is_close.all(), message \ No newline at end of file From 856a908e2e4ec9ebc691cea56bf34ae8b8d26e68 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 18 Feb 2025 08:32:38 +0000 Subject: [PATCH 06/13] benchmarking WIP --- op_benchmarks/benchmark.py | 90 ++++++++++++++++++++++++++++ op_benchmarks/benchmark_gemm_a8w8.py | 0 2 files changed, 90 insertions(+) create mode 100644 op_benchmarks/benchmark.py create mode 100644 op_benchmarks/benchmark_gemm_a8w8.py diff --git a/op_benchmarks/benchmark.py b/op_benchmarks/benchmark.py new file mode 100644 index 00000000..91bc2264 --- /dev/null +++ b/op_benchmarks/benchmark.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +from typing import Callable +import torch +import torch.profiler as tpf +import os +import numpy as np +import pandas as pd +from aiter import logger + + + +def get_trace_perf(prof, num_iters): + assert (num_iters > 1) + num_iters -= 1 + df = [] + cols = ['name', 'self_cpu_time_total', 'self_device_time_total', + 'device_type', 'device_index',] + for el in prof.events(): + df.append([getattr(el, x, None) for x in cols]) + df = pd.DataFrame(df, columns=cols) + df['cnt'] = 1 + rets = [] + for name, d in df.groupby('name', sort=False): + r = d.iloc[1:][['cnt', + 'self_cpu_time_total', + 'self_device_time_total']].sum() + if not r.empty: + device_type = str(d['device_type'].iat[0]).split('.')[-1] + r['name'] = name + r['device_type'] = device_type + r['device_index'] = str(d['device_index'].iat[0]) + if device_type == 'CUDA': + r['device_time_total'] = r['self_device_time_total'] + r['host_time_total'] = 0 + else: + r['host_time_total'] = r['self_device_time_total'] + r['device_time_total'] = 0 + + rets.append(r) + df = pd.DataFrame(rets) + + cols = ['name', 'cnt', 'host_time_total', 'device_time_total', + 'device_type', 'device_index',] + cols = [el for el in cols if el in df.columns] + df = df[(df.host_time_total > 0) | (df.device_time_total > 0)] + + timerList = ['host_time_total', 'device_time_total', ] + df = df[cols].sort_values(timerList, ignore_index=True) + avg_name = '[avg us/iter]' + for el in timerList: + df.at[avg_name, el] = df[el].sum()/num_iters + if int(os.environ.get('AITER_LOG_MORE', 0)): + logger.info(f'{df}') + return df.at[avg_name, 'device_time_total'] + +def execute_callback(num_iterations: int, func: Callable, *args, **kwargs) -> None: + for _ in range(num_iterations): + func(*args, **kwargs) + +def profile( + num_iterations: int, + num_warmup_iterations:int, + func: Callable, + *args, + **kwargs +): + # warmup + execute_callback(num_warmup_iterations, func, *args, **kwargs) + with tpf.profile( + activities=[ + tpf.ProfilerActivity.CPU, + tpf.ProfilerActivity.CUDA + ], + profile_memory=True, + with_stack=True, + with_modules=True, + ) as prof: + execute_callback(func, *args, **kwargs) + + avg = get_trace_perf(prof, num_iterations) + +def profile_cuda_graph( + num_iterations: int, + num_warmup_iterations:int, func: Callable, *args, **kwargs): + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + execute_callback(1, func, *args, **kwargs) + profile(num_iterations, num_warmup_iterations, func, *args, **kwargs) diff --git a/op_benchmarks/benchmark_gemm_a8w8.py b/op_benchmarks/benchmark_gemm_a8w8.py new file mode 100644 index 00000000..e69de29b From 24fad1610c07ef15dd34e78e55dee9e3c236d740 
Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 18 Feb 2025 12:26:03 +0000 Subject: [PATCH 07/13] add benchmark hooks. Implement the benchmark hooks for gemm_ck_a8w8 test --- op_benchmarks/benchmark.py | 90 ------------ op_benchmarks/benchmark_gemm_a8w8.py | 0 op_tests/test_gemm_a8w8.py | 46 ++---- op_tests/utils.py | 204 ++++++++++++++++++++++++++- requirements.txt | 1 + 5 files changed, 219 insertions(+), 122 deletions(-) delete mode 100644 op_benchmarks/benchmark.py delete mode 100644 op_benchmarks/benchmark_gemm_a8w8.py diff --git a/op_benchmarks/benchmark.py b/op_benchmarks/benchmark.py deleted file mode 100644 index 91bc2264..00000000 --- a/op_benchmarks/benchmark.py +++ /dev/null @@ -1,90 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. - -from typing import Callable -import torch -import torch.profiler as tpf -import os -import numpy as np -import pandas as pd -from aiter import logger - - - -def get_trace_perf(prof, num_iters): - assert (num_iters > 1) - num_iters -= 1 - df = [] - cols = ['name', 'self_cpu_time_total', 'self_device_time_total', - 'device_type', 'device_index',] - for el in prof.events(): - df.append([getattr(el, x, None) for x in cols]) - df = pd.DataFrame(df, columns=cols) - df['cnt'] = 1 - rets = [] - for name, d in df.groupby('name', sort=False): - r = d.iloc[1:][['cnt', - 'self_cpu_time_total', - 'self_device_time_total']].sum() - if not r.empty: - device_type = str(d['device_type'].iat[0]).split('.')[-1] - r['name'] = name - r['device_type'] = device_type - r['device_index'] = str(d['device_index'].iat[0]) - if device_type == 'CUDA': - r['device_time_total'] = r['self_device_time_total'] - r['host_time_total'] = 0 - else: - r['host_time_total'] = r['self_device_time_total'] - r['device_time_total'] = 0 - - rets.append(r) - df = pd.DataFrame(rets) - - cols = ['name', 'cnt', 'host_time_total', 'device_time_total', - 'device_type', 'device_index',] - cols = [el for el in cols if el in df.columns] - df = df[(df.host_time_total > 0) | (df.device_time_total > 0)] - - timerList = ['host_time_total', 'device_time_total', ] - df = df[cols].sort_values(timerList, ignore_index=True) - avg_name = '[avg us/iter]' - for el in timerList: - df.at[avg_name, el] = df[el].sum()/num_iters - if int(os.environ.get('AITER_LOG_MORE', 0)): - logger.info(f'{df}') - return df.at[avg_name, 'device_time_total'] - -def execute_callback(num_iterations: int, func: Callable, *args, **kwargs) -> None: - for _ in range(num_iterations): - func(*args, **kwargs) - -def profile( - num_iterations: int, - num_warmup_iterations:int, - func: Callable, - *args, - **kwargs -): - # warmup - execute_callback(num_warmup_iterations, func, *args, **kwargs) - with tpf.profile( - activities=[ - tpf.ProfilerActivity.CPU, - tpf.ProfilerActivity.CUDA - ], - profile_memory=True, - with_stack=True, - with_modules=True, - ) as prof: - execute_callback(func, *args, **kwargs) - - avg = get_trace_perf(prof, num_iterations) - -def profile_cuda_graph( - num_iterations: int, - num_warmup_iterations:int, func: Callable, *args, **kwargs): - graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph): - execute_callback(1, func, *args, **kwargs) - profile(num_iterations, num_warmup_iterations, func, *args, **kwargs) diff --git a/op_benchmarks/benchmark_gemm_a8w8.py b/op_benchmarks/benchmark_gemm_a8w8.py deleted file mode 100644 index e69de29b..00000000 diff --git a/op_tests/test_gemm_a8w8.py b/op_tests/test_gemm_a8w8.py index 90c6a564..8465828c 100644 --- 
a/op_tests/test_gemm_a8w8.py +++ b/op_tests/test_gemm_a8w8.py @@ -1,41 +1,21 @@ # SPDX-License-Identifier: MIT # Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +import aiter +from aiter.ops.shuffle import shuffle_weight import pytest import torch import torch.nn.functional as F -import aiter -from aiter.ops.shuffle import shuffle_weight -from .utils import rand_tensor, check_all_close +from .utils import DefaultBenchmarkHook, check_all_close, rand_tensor MNK = [ # qkv_proj (1, 1280, 8192), - (32, 1280, 8192), - (64, 1280, 8192), - (128, 1280, 8192), - (192, 1280, 8192), - (256, 1280, 8192), - (320, 1280, 8192), (512, 1280, 8192), - (1024, 1280, 8192), - (2048, 1280, 8192), - (4096, 1280, 8192), - (8192, 1280, 8192), (16384, 1280, 8192), # attn_out (1, 8192, 1024), - (32, 8192, 1024), - (64, 8192, 1024), - (128, 8192, 1024), - (192, 8192, 1024), - (256, 8192, 1024), - (320, 8192, 1024), (512, 8192, 1024), - (1024, 8192, 1024), - (2048, 8192, 1024), - (4096, 8192, 1024), - (8192, 8192, 1024), (16384, 8192, 1024), ] @@ -79,12 +59,14 @@ def setup( @pytest.mark.parametrize("scales_output_dtype", CK_GEMM_SCALES_OUTPUT_DTYPES) @pytest.mark.parametrize("use_bias", [True, False]) def test_ck_gemm_close_to_torch( + benchmark, mnk: tuple[int, int, int], ab_dtype: torch.dtype, scales_output_dtype: tuple[torch.dtype, torch.dtype], use_bias: bool ) -> None: - # using bias further reduces precision, so we use a higher tolerance + # using bias further reduces precision, + # so we use a higher tolerance when bias is enabled. if use_bias: rtol, atol = (1e-1, 1e-1) else: @@ -100,21 +82,21 @@ def test_ck_gemm_close_to_torch( use_bias, ) - output = aiter.gemm_a8w8_CK(a, b, a_scale, b_scale, bias, dtype=out_dtype) + _, output = benchmark( + DefaultBenchmarkHook(aiter.gemm_a8w8_CK, a, b, a_scale, b_scale, bias, out_dtype) + ) expected = torch_scaled_mm(a, b, a_scale, b_scale, bias, dtype=out_dtype) check_all_close(output, expected, rtol=rtol, atol=atol) @pytest.mark.parametrize("mnk", MNK) -def test_asm_gemm_close_to_torch( - mnk: tuple[int, int, int], -) -> None: +def test_asm_gemm_close_to_torch(benchmark, mnk: tuple[int, int, int]) -> None: rtol, atol = (1e-1, 1e-1) ab_dtype = torch.int8 out_dtype = torch.bfloat16 scales_dtype = torch.float32 bias_dtype = torch.float - # asm_gemm requires bias and shuffle + # gemm_a8w8_ASM requires bias and shuffle a, b, a_scale, b_scale, bias = setup( mnk, ab_dtype, @@ -124,7 +106,9 @@ def test_asm_gemm_close_to_torch( ) b_shuffled= shuffle_weight(b, layout=(32, 16)) - output = aiter.gemm_a8w8_ASM(a, b_shuffled, a_scale, b_scale, bias) + _, output = benchmark( + DefaultBenchmarkHook(aiter.gemm_a8w8_ASM, a, b_shuffled, a_scale, b_scale, bias) + ) expected = torch_scaled_mm(a, b, a_scale, b_scale, bias, dtype=out_dtype) - if output is not None and torch.sum(output.isnan()==True) ==0: + if output is not None and torch.sum(output.isnan()==True) == 0: check_all_close(output, expected, rtol=rtol, atol=atol) diff --git a/op_tests/utils.py b/op_tests/utils.py index ceca6537..e65e24d2 100644 --- a/op_tests/utils.py +++ b/op_tests/utils.py @@ -1,4 +1,10 @@ +from functools import partial +import pandas as pd import torch +import torch.profiler as tpf +from typing import Any, Callable, Generic, TypeVar + +T = TypeVar("T") MAX_RAND_INT = 20 MIN_RAND_INT = -20 @@ -80,4 +86,200 @@ def check_all_close( average delta:{delta.mean()} delta details: {percent:.1%} ({num_not_close} of {a.numel()}) elements """ - assert is_close.all(), message \ No newline at end of file + 
assert is_close.all(), message + +def get_trace_perf(prof, num_iters): + # TODO: clean up + assert (num_iters > 1) + num_iters -= 1 + df = [] + cols = ['name', 'self_cpu_time_total', 'self_device_time_total', + 'device_type', 'device_index',] + for el in prof.events(): + df.append([getattr(el, x, None) for x in cols]) + df = pd.DataFrame(df, columns=cols) + df['cnt'] = 1 + rets = [] + for name, d in df.groupby('name', sort=False): + r = d.iloc[1:][['cnt', + 'self_cpu_time_total', + 'self_device_time_total']].sum() + if not r.empty: + device_type = str(d['device_type'].iat[0]).split('.')[-1] + r['name'] = name + r['device_type'] = device_type + r['device_index'] = str(d['device_index'].iat[0]) + if device_type == 'CUDA': + r['device_time_total'] = r['self_device_time_total'] + r['host_time_total'] = 0 + else: + r['host_time_total'] = r['self_device_time_total'] + r['device_time_total'] = 0 + + rets.append(r) + df = pd.DataFrame(rets) + + cols = ['name', 'cnt', 'host_time_total', 'device_time_total', + 'device_type', 'device_index',] + cols = [el for el in cols if el in df.columns] + df = df[(df.host_time_total > 0) | (df.device_time_total > 0)] + + timerList = ['host_time_total', 'device_time_total', ] + df = df[cols].sort_values(timerList, ignore_index=True) + avg_name = '[avg us/iter]' + for el in timerList: + df.at[avg_name, el] = df[el].sum()/num_iters + return df.at[avg_name, 'device_time_total'] + +def execute_callback(num_iterations: int, func: Callable[..., T], *args, **kwargs) -> T: + """" + Execute a function multiple times and return the result of the last execution. + + Returns: + T: The result of the last function execution. + """ + for _ in range(num_iterations): + result = func(*args, **kwargs) + return result + +def profile( + num_iterations: int, + num_warmup_iterations:int, + func: Callable[..., T], + *args, + **kwargs +) -> T: + """ + Profile the execution of a function using PyTorch Profiler. + + This function performs warmup iterations, then profiles the actual execution + of the function for a specified number of iterations. + + Returns: + tuple[float, T]: A tuple containing: + - float: The average execution time. + - T: The result of the last function execution. + """ + # warmup + execute_callback(num_warmup_iterations, func, *args, **kwargs) + with tpf.profile( + activities=[ + tpf.ProfilerActivity.CPU, + tpf.ProfilerActivity.CUDA + ], + profile_memory=True, + with_stack=True, + with_modules=True, + ) as prof: + result = execute_callback(num_iterations, func, *args, **kwargs) + + return get_trace_perf(prof, num_iterations), result + +def profile_cuda_graph( + num_iterations: int, + num_warmup_iterations:int, + func: Callable[..., T], + *args, + **kwargs +) -> T: + """ + Profile the execution of a function using CUDA Graph and PyTorch Profiler. + + This function creates a CUDA Graph for the given function, then profiles its + execution using the standard profile function. + + Returns: + tuple[float, T]: A tuple containing: + - float: The average execution time or other performance metric. + - T: The result of the last function execution. + """ + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + execute_callback(1, func, *args, **kwargs) + return profile(num_iterations, num_warmup_iterations, func, *args, **kwargs) + + +class BenchmarkHook(Generic[T]): + """ + A generic class for custom benchmarking in pytest using pytest-benchmark. 
+
+    This class allows for fine-grained control over benchmarking parameters and
+    can be used with the pytest-benchmark fixture to measure function performance.
+
+    Type Parameters:
+    ----------------
+    T : The return type of the function being benchmarked.
+
+    Attributes:
+    -----------
+    num_iterations : int
+        The number of times the function will be executed during benchmarking.
+    num_warmup_iterations : int
+        The number of times the function will be executed before actual benchmarking begins.
+        These iterations are not included in the final measurements.
+    func : Callable[..., T]
+        The function to be benchmarked.
+    args : tuple
+        Positional arguments to be passed to the benchmarked function.
+    kwargs : dict
+        Keyword arguments to be passed to the benchmarked function.
+
+    Methods:
+    --------
+    __call__() -> tuple[float, T]
+        Executes the benchmarking process and returns a tuple containing
+        the execution time and the result of the benchmarked function.
+
+    Usage:
+    ------
+    def test_example(benchmark):
+        def function_to_benchmark(x: int) -> int:
+            return x * 2
+
+        hook = BenchmarkHook(
+            1000,   # num_iterations
+            100,    # num_warmup_iterations
+            False,  # use_cuda_graph
+            function_to_benchmark,
+            5
+        )
+
+        execution_time, result = benchmark(hook)
+
+        assert result == 10
+        assert execution_time > 0
+    """
+
+    def __init__(self,
+                 num_iterations:int,
+                 num_warmup_iterations:int,
+                 use_cuda_graph: bool,
+                 func: Callable[..., T],
+                 *args,
+                 **kwargs
+                 ) -> None:
+        self.num_iterations = num_iterations
+        self.num_warmup_iterations = num_warmup_iterations
+        self.use_cuda_graph = use_cuda_graph
+        self.func = func
+        self.args = args
+        self.kwargs = kwargs
+
+    def __call__(self) -> tuple[float, T]:
+        if self.use_cuda_graph:
+            return profile_cuda_graph(
+                self.num_iterations,
+                self.num_warmup_iterations,
+                self.func,
+                *self.args,
+                **self.kwargs
+            )
+        return profile(
+            self.num_iterations,
+            self.num_warmup_iterations,
+            self.func,
+            *self.args,
+            **self.kwargs
+        )
+
+DefaultBenchmarkHook = partial(BenchmarkHook, 100, 10, False)
+DefaultCudaGraphBenchmarkHook = partial(BenchmarkHook, 100, 10, True)
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 3800156d..01cdf87c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,3 @@
 pandas<=2.0.3
 pytest
+pytest-benchmark==5.1.*
\ No newline at end of file

From 314f2deb5ead418067721569c1662a9dea0ada77 Mon Sep 17 00:00:00 2001
From: vllmellm
Date: Thu, 20 Feb 2025 03:17:14 +0000
Subject: [PATCH 08/13] Update documentation and add cuda events profiler
 option

---
 op_tests/utils.py | 199 ++++++++++++++++++++++++++++++++++------------
 1 file changed, 149 insertions(+), 50 deletions(-)

diff --git a/op_tests/utils.py b/op_tests/utils.py
index e65e24d2..3e2d6f24 100644
--- a/op_tests/utils.py
+++ b/op_tests/utils.py
@@ -1,8 +1,10 @@
 from functools import partial
+import numpy as np
+import os
 import pandas as pd
 import torch
 import torch.profiler as tpf
-from typing import Any, Callable, Generic, TypeVar
+from typing import Callable, Generic, TypeVar
 
 T = TypeVar("T")
 
@@ -10,7 +12,7 @@
 MIN_RAND_INT = -20
 
 def rand_tensor(
-    shape: tuple[int, int],
+    size: torch.Size,
     dtype: torch.dtype
 ) -> torch.tensor:
     """
@@ -19,32 +21,31 @@
     - For integer types: Uses torch.randint to generate random integers within a fixed range.
     - For float types: Uses torch.rand to generate random floats between 0 and 1.
 
-    Parameters:
-    -----------
-    shape : tuple[int, int]
-        The shape of the output tensor. Must be a tuple of two integers.
-    dtype : torch.dtype
-        The desired data type of the output tensor.
-
-    Returns:
-    --------
-    torch.Tensor
-        A random tensor of the specified shape and data type.
+    Parameters
+    ----------
+    size: torch.Size
+        The size of the generated tensor.
+    dtype: torch.dtype
+        The data type of the generated tensor.
 
-    Raises:
+    Returns
     -------
+    torch.Tensor : A random tensor of the specified shape and data type.
+
+    Raises
+    ------
     ValueError
         If an unsupported data type is provided.
     """
     if dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
         # For integer types, use randint
-        return torch.randint(MIN_RAND_INT, MAX_RAND_INT, shape, dtype=dtype)
+        return torch.randint(MIN_RAND_INT, MAX_RAND_INT, size=size, dtype=dtype)
     elif dtype in [torch.float16, torch.float32, torch.float64, torch.bfloat16]:
         # For float types, use rand
-        return torch.rand(shape, dtype=dtype)
+        return torch.rand(size=size, dtype=dtype)
     elif dtype == torch.float8_e4m3fnuz:
         # Special case for float8_e4m3fnuz
-        return torch.rand(shape, dtype=torch.float16).to(torch.float8_e4m3fnuz)
+        return torch.rand(size=size, dtype=torch.float16).to(torch.float8_e4m3fnuz)
 
     raise ValueError(f"Unsupported dtype: {dtype}")
 
@@ -56,19 +57,19 @@
 ) -> None:
     """
     Check if all elements in two tensors are close within specified tolerances.
-
-    Parameters:
-    -----------
-    a : torch.Tensor
+
+    Parameters
+    ----------
+    a: torch.tensor
         First input tensor.
-    b : torch.Tensor
-        Second input tensor to compare with 'a'.
-    rtol : float
-        Relative tolerance.
-    atol : float
-        Absolute tolerance.
-
-    Raises:
+    b: torch.tensor
+        Second input tensor.
+    rtol: float
+        Relative tolerance.
+    atol: float
+        Absolute tolerance.
+
+    Raises
     -------
     AssertionError
         If any elements in 'a' and 'b' are not close within the specified tolerances.
@@ -131,37 +132,118 @@
         df.at[avg_name, el] = df[el].sum()/num_iters
     return df.at[avg_name, 'device_time_total']
 
-def execute_callback(num_iterations: int, func: Callable[..., T], *args, **kwargs) -> T:
-    """"
-    Execute a function multiple times and return the result of the last execution.
-
-    Returns:
-        T: The result of the last function execution.
+def execute_callback(
+    num_iterations: int,
+    func: Callable[..., T],
+    *args,
+    **kwargs
+) -> T:
+    """
+    Executes a callback for a given number of iterations.
+
+    Parameters
+    ----------
+    num_iterations : int.
+        The number of iterations to use for profiling the callback.
+    func : Callable[..., T].
+        A callback function with arbitrary arguments to be executed.
+    *args
+        Variable length argument list for the callback function.
+    **kwargs
+        Keyword arguments for the callback function.
+
+    Returns
+    -------
+    T: The last value returned by the callback function.
     """
     for _ in range(num_iterations):
         result = func(*args, **kwargs)
     return result
 
+def time_callback_with_cuda_event(
+    num_iterations: int,
+    func: Callable[..., T],
+    *args,
+    **kwargs
+) -> float:
+    """
+    Measure the average execution time of a given function using CUDA
+    events in milliseconds.
+
+    Parameters
+    ----------
+    num_iterations : int.
+        The number of iterations to use for profiling the callback.
+    func : Callable[..., T].
+        A callback function with arbitrary arguments to be executed.
+    *args
+        Variable length argument list for the callback function.
+    **kwargs
+        Keyword arguments for the callback function.
+    Returns
+    -------
+    float: The average execution time in milliseconds over all iterations.
+ """ + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + latency_list = [] + for _ in range(num_iterations): + start_event.record() + func(*args, **kwargs) + end_event.record() + end_event.synchronize() + latency = start_event.elapsed_time(end_event) + latency_list.append(latency) + return np.mean(latency_list) + def profile( num_iterations: int, num_warmup_iterations:int, func: Callable[..., T], *args, **kwargs -) -> T: +) -> tuple[float, T]: """ Profile the execution of a function using PyTorch Profiler. This function performs warmup iterations, then profiles the actual execution of the function for a specified number of iterations. + + Parameters + ---------- + num_iterations : int. + The number of iterations to use for profiling the callback. + num_warmup_iterations : int. + The number of iterations to use for warmup before profiling the callback. + func : Callable[..., T]. + A callback function with arbitrary arguments to be executed. + *args + Variable length argument list for the callback function. + **kwargs + Keyword arguments for the callback function. - Returns: + Returns + ------- tuple[float, T]: A tuple containing: - - float: The average execution time. + - float: The average execution time in milliseconds. - T: The result of the last function execution. """ + # warmup - execute_callback(num_warmup_iterations, func, *args, **kwargs) + result = execute_callback(num_warmup_iterations, func, *args, **kwargs) + + # Profile using cuda.Event + # Note: The use of AITER_LOG_MORE variable + # is temporary (as we shift to using pytest) + # and should be replaced with a more descriptive + # flag. + if int(os.environ.get('AITER_LOG_MORE', 0)): + average_latency = time_callback_with_cuda_event( + num_iterations, func, *args, **kwargs) + return average_latency, result + + # Profile using torch.profiler with tpf.profile( activities=[ tpf.ProfilerActivity.CPU, @@ -171,9 +253,11 @@ def profile( with_stack=True, with_modules=True, ) as prof: - result = execute_callback(num_iterations, func, *args, **kwargs) + execute_callback(num_iterations, func, *args, **kwargs) + + average_latency = get_trace_perf(prof, num_iterations) - return get_trace_perf(prof, num_iterations), result + return average_latency, result def profile_cuda_graph( num_iterations: int, @@ -181,16 +265,30 @@ def profile_cuda_graph( func: Callable[..., T], *args, **kwargs -) -> T: +) -> tuple[float, T]: """ Profile the execution of a function using CUDA Graph and PyTorch Profiler. This function creates a CUDA Graph for the given function, then profiles its execution using the standard profile function. - Returns: + Parameters + ---------- + num_iterations : int. + The number of iterations to use for profiling the callback. + num_warmup_iterations : int. + The number of iterations to use for warmup before profiling the callback. + func : Callable[..., T]. + A callback function with arbitrary arguments to be executed. + *args + Variable length argument list for the callback function. + **kwargs + Keyword arguments for the callback function. + + Returns + ------- tuple[float, T]: A tuple containing: - - float: The average execution time or other performance metric. + - float: The average execution time in milliseconds. - T: The result of the last function execution. """ graph = torch.cuda.CUDAGraph() @@ -210,8 +308,8 @@ class BenchmarkHook(Generic[T]): ---------------- T : The return type of the function being benchmarked. 
-    Attributes:
-    -----------
+    Attributes
+    ----------
     num_iterations : int
         The number of times the function will be executed during benchmarking.
     num_warmup_iterations : int
@@ -224,14 +322,14 @@ class BenchmarkHook(Generic[T]):
     kwargs : dict
         Keyword arguments to be passed to the benchmarked function.
 
-    Methods:
-    --------
+    Methods
+    -------
     __call__() -> tuple[float, T]
         Executes the benchmarking process and returns a tuple containing
        the execution time and the result of the benchmarked function.
 
-    Usage:
-    ------
+    Usage
+    -----
     def test_example(benchmark):
         def function_to_benchmark(x: int) -> int:
             return x * 2
 
         hook = BenchmarkHook(
             1000,   # num_iterations
             100,    # num_warmup_iterations
             False,  # use_cuda_graph
             function_to_benchmark,
             5
         )
 
         execution_time, result = benchmark(hook)
 
         assert result == 10
         assert execution_time > 0
     """
 
     def __init__(self,
@@ -280,6 +378,7 @@ def __call__(self) -> tuple[float, T]:
             *self.args,
             **self.kwargs
         )
-
+
+# Define some BenchmarkHook partials for convenience
 DefaultBenchmarkHook = partial(BenchmarkHook, 100, 10, False)
 DefaultCudaGraphBenchmarkHook = partial(BenchmarkHook, 100, 10, True)
\ No newline at end of file

From 5cbd670d3e89474bdfb00085e70ab3a9293cd7f0 Mon Sep 17 00:00:00 2001
From: vllmellm
Date: Thu, 20 Feb 2025 05:48:05 +0000
Subject: [PATCH 09/13] cleanup utils and add required MNK test values

---
 op_tests/test_gemm_a8w8.py | 30 ++++++++++++---
 op_tests/utils.py          | 77 +++++++++++++++----------------------
 2 files changed, 56 insertions(+), 51 deletions(-)

diff --git a/op_tests/test_gemm_a8w8.py b/op_tests/test_gemm_a8w8.py
index 8465828c..e360def3 100644
--- a/op_tests/test_gemm_a8w8.py
+++ b/op_tests/test_gemm_a8w8.py
@@ -11,11 +11,31 @@
 MNK = [
     # qkv_proj
     (1, 1280, 8192),
+    (32, 1280, 8192),
+    (64, 1280, 8192),
+    (128, 1280, 8192),
+    (192, 1280, 8192),
+    (256, 1280, 8192),
+    (320, 1280, 8192),
     (512, 1280, 8192),
+    (1024, 1280, 8192),
+    (2048, 1280, 8192),
+    (4096, 1280, 8192),
+    (8192, 1280, 8192),
     (16384, 1280, 8192),
     # attn_out
     (1, 8192, 1024),
+    (32, 8192, 1024),
+    (64, 8192, 1024),
+    (128, 8192, 1024),
+    (192, 8192, 1024),
+    (256, 8192, 1024),
+    (320, 8192, 1024),
     (512, 8192, 1024),
+    (1024, 8192, 1024),
+    (2048, 8192, 1024),
+    (4096, 8192, 1024),
+    (8192, 8192, 1024),
     (16384, 8192, 1024),
 ]
 
@@ -47,11 +67,11 @@ def setup(
 ) -> tuple[torch.tensor, ...]:
 
     m, n, k = mnk
-    a = rand_tensor(shape=(m, k), dtype=ab_dtype).cuda()
-    b = rand_tensor(shape=(n, k), dtype=ab_dtype).cuda()
-    a_scale = rand_tensor(shape=(m, 1), dtype=scales_dtype).cuda() + 1e-6
-    b_scale = rand_tensor(shape=(1, n), dtype=scales_dtype).cuda() + 1e-6
-    bias = (rand_tensor(shape=[1, n], dtype=torch.bfloat16).cuda()).to(dtype=bias_dtype) if use_bias else None
+    a = rand_tensor(size=(m, k), dtype=ab_dtype).cuda()
+    b = rand_tensor(size=(n, k), dtype=ab_dtype).cuda()
+    a_scale = rand_tensor(size=(m, 1), dtype=scales_dtype).cuda() + 1e-6
+    b_scale = rand_tensor(size=(1, n), dtype=scales_dtype).cuda() + 1e-6
+    bias = (rand_tensor(size=(1, n), dtype=bias_dtype).cuda()) if use_bias else None
     return a, b, a_scale, b_scale, bias
 
 @pytest.mark.parametrize("mnk", MNK)
diff --git a/op_tests/utils.py b/op_tests/utils.py
index 3e2d6f24..bd7edf2c 100644
--- a/op_tests/utils.py
+++ b/op_tests/utils.py
@@ -1,7 +1,6 @@
 from functools import partial
 import numpy as np
 import os
-import pandas as pd
 import torch
 import torch.profiler as tpf
 from typing import Callable, Generic, TypeVar
@@ -89,49 +88,36 @@
     """
     assert is_close.all(), message
 
-def get_trace_perf(prof, num_iters):
-    # TODO: clean up
-    assert (num_iters > 1)
-    num_iters -= 1
-    df = []
-    cols = ['name', 'self_cpu_time_total', 'self_device_time_total',
-            'device_type', 'device_index',]
-    for el in prof.events():
-        df.append([getattr(el, x, None) for x in cols])
-    df = pd.DataFrame(df, columns=cols)
-    df['cnt'] = 1
-    rets = []
-    for name, d in df.groupby('name', sort=False):
-        r = d.iloc[1:][['cnt',
-                        'self_cpu_time_total',
-                        'self_device_time_total']].sum()
-        if not r.empty:
-            device_type = str(d['device_type'].iat[0]).split('.')[-1]
-            r['name'] = name
-            r['device_type'] = device_type
-            r['device_index'] = str(d['device_index'].iat[0])
-            if device_type == 'CUDA':
-                r['device_time_total'] = r['self_device_time_total']
-                r['host_time_total'] = 0
-            else:
-                r['host_time_total'] = r['self_device_time_total']
-                r['device_time_total'] = 0
-
-        rets.append(r)
-    df = pd.DataFrame(rets)
-
-    cols = ['name', 'cnt', 'host_time_total', 'device_time_total',
-            'device_type', 'device_index',]
-    cols = [el for el in cols if el in df.columns]
-    df = df[(df.host_time_total > 0) | (df.device_time_total > 0)]
-
-    timerList = ['host_time_total', 'device_time_total', ]
-    df = df[cols].sort_values(timerList, ignore_index=True)
-    avg_name = '[avg us/iter]'
-    for el in timerList:
-        df.at[avg_name, el] = df[el].sum()/num_iters
-    return df.at[avg_name, 'device_time_total']
 
+def extract_avg_cuda_time_trace(torch_profiler: torch.profiler) -> float:
+    """
+    Extract the average CUDA time from a PyTorch profiler trace.
+
+    This function calculates the mean of the self device time for all CUDA events
+    in the profiler trace, excluding the first event.
+
+    Parameters
+    ----------
+    torch_profiler : torch.profiler
+        A PyTorch profiler object containing trace events.
+
+    Returns
+    -------
+    float: The average CUDA time across all CUDA events, or 0.0 if there are
+        insufficient events or no CUDA events.
+    """
+    get_cuda_total_time = lambda event: getattr(event, "self_device_time_total", 0.0)
+    is_cuda = lambda event: getattr(event, "device_type", None) == torch.profiler.DeviceType.CUDA
+
+    if len(torch_profiler.events()) <= 1:
+        return 0.0
+
+    return np.mean(
+        [
+            get_cuda_total_time(event)
+            for event in torch_profiler.events()[1:]
+            if is_cuda(event)
+        ]
+    )
 
 def execute_callback(
     num_iterations: int,
@@ -235,7 +221,7 @@
 
     # Profile using cuda.Event
     # Note: The use of AITER_LOG_MORE variable
-    # is temporary (as we shift to using pytest)
+    # is temporary (until we completely shift to using pytest)
     # and should be replaced with a more descriptive
     # flag.
 if int(os.environ.get('AITER_LOG_MORE', 0)):
@@ -255,8 +241,7 @@
     ) as prof:
         execute_callback(num_iterations, func, *args, **kwargs)
 
-    average_latency = get_trace_perf(prof, num_iterations)
-
+    average_latency = extract_avg_cuda_time_trace(prof)
     return average_latency, result
 
 def profile_cuda_graph(

From 40450a848e6388ead31e113561d7debdbe351d Mon Sep 17 00:00:00 2001
From: vllmellm
Date: Thu, 20 Feb 2025 08:14:27 +0000
Subject: [PATCH 10/13] add custom benchmark fixture and fix time_trace unit

---
 op_tests/conftest.py       | 301 +++++++++++++++++++++++++++++++++
 op_tests/test_gemm_a8w8.py |  10 +-
 op_tests/utils.py          | 280 ----------------------------------
 requirements.txt           |   2 +-
 4 files changed, 305 insertions(+), 288 deletions(-)
 create mode 100644 op_tests/conftest.py

diff --git a/op_tests/conftest.py b/op_tests/conftest.py
new file mode 100644
index 00000000..ed0c6a46
--- /dev/null
+++ b/op_tests/conftest.py
@@ -0,0 +1,301 @@
+import pytest
+from collections import defaultdict
+import os
+import numpy as np
+from prettytable import PrettyTable
+import torch
+import torch.profiler as tpf
+from typing import Callable, Generator, TypeVar
+from functools import partial
+
+T = TypeVar("T")
+
+
+def execute_callback(
+    num_iterations: int,
+    func: Callable[..., T],
+    *args,
+    **kwargs
+) -> T:
+    """
+    Executes a callback for a given number of iterations.
+
+    Parameters
+    ----------
+    num_iterations : int.
+        The number of iterations to use for profiling the callback.
+    func : Callable[..., T].
+        A callback function with arbitrary arguments to be executed.
+    *args
+        Variable length argument list for the callback function.
+    **kwargs
+        Keyword arguments for the callback function.
+
+    Returns
+    -------
+    T: The last value returned by the callback function.
+
+    """
+    for _ in range(num_iterations):
+        result = func(*args, **kwargs)
+    return result
+
+def time_callback_with_cuda_event(
+    num_iterations: int,
+    func: Callable[..., T],
+    *args,
+    **kwargs
+) -> list[float]:
+    """
+    Measure the average execution time of a given function using CUDA
+    events in milliseconds.
+
+    Parameters
+    ----------
+    num_iterations : int.
+        The number of iterations to use for profiling the callback.
+    func : Callable[..., T].
+        A callback function with arbitrary arguments to be executed.
+    *args
+        Variable length argument list for the callback function.
+    **kwargs
+        Keyword arguments for the callback function.
+    Returns
+    -------
+    list[float]: The execution times in milliseconds.
+    """
+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+    latency_list = []
+    for _ in range(num_iterations):
+        start_event.record()
+        func(*args, **kwargs)
+        end_event.record()
+        end_event.synchronize()
+        latency = start_event.elapsed_time(end_event)
+        latency_list.append(latency)
+    return latency_list
+
+def extract_cuda_time_trace(torch_profiler: torch.profiler) -> list[float]:
+    """
+    Extract the CUDA times from a PyTorch profiler trace.
+
+    This function extracts self device time for all CUDA events
+    in the profiler trace, excluding the first event.
+
+    Parameters
+    ----------
+    torch_profiler : torch.profiler
+        A PyTorch profiler object containing trace events.
+
+    Returns
+    -------
+    list[float]: The execution times in milliseconds.
+
+    """
+    get_cuda_total_time = lambda event: getattr(event, "self_device_time_total", 0.0)
+    is_cuda = lambda event: getattr(event, "device_type", None) == torch.profiler.DeviceType.CUDA
+
+    if len(torch_profiler.events()) <= 1:
+        return []
+
+    return [
+        get_cuda_total_time(event) / 1000
+        for event in torch_profiler.events()[1:]
+        if is_cuda(event)
+    ]
+
+def profile(
+    num_iterations: int,
+    num_warmup_iterations:int,
+    func: Callable[..., T],
+    *args,
+    **kwargs
+) -> tuple[list[float], T]:
+    """
+    Profile the execution of a function using PyTorch Profiler.
+
+    This function performs warmup iterations, then profiles the actual execution
+    of the function for a specified number of iterations.
+
+    Parameters
+    ----------
+    num_iterations : int.
+        The number of iterations to use for profiling the callback.
+    num_warmup_iterations : int.
+        The number of iterations to use for warmup before profiling the callback.
+    func : Callable[..., T].
+        A callback function with arbitrary arguments to be executed.
+    *args
+        Variable length argument list for the callback function.
+    **kwargs
+        Keyword arguments for the callback function.
+
+    Returns
+    -------
+    tuple[list[float], T]: A tuple containing:
+        - list[float]: The execution times in milliseconds.
+        - T: The result of the last function execution.
+    """
+
+    # warmup
+    result = execute_callback(num_warmup_iterations, func, *args, **kwargs)
+
+    # Profile using cuda.Event
+    # Note: The use of AITER_LOG_MORE variable
+    # is temporary (until we completely shift to using pytest)
+    # and should be replaced with a more descriptive
+    # flag.
+    if int(os.environ.get('AITER_LOG_MORE', 0)):
+        latency_list = time_callback_with_cuda_event(
+            num_iterations, func, *args, **kwargs)
+        return latency_list, result
+
+    # Profile using torch.profiler
+    with tpf.profile(
+        activities=[
+            tpf.ProfilerActivity.CPU,
+            tpf.ProfilerActivity.CUDA
+        ],
+        profile_memory=True,
+        with_stack=True,
+        with_modules=True,
+    ) as prof:
+        execute_callback(num_iterations, func, *args, **kwargs)
+
+    latency_list = extract_cuda_time_trace(prof)
+    return latency_list, result
+
+def profile_cuda_graph(
+    num_iterations: int,
+    num_warmup_iterations:int,
+    func: Callable[..., T],
+    *args,
+    **kwargs
+) -> tuple[list[float], T]:
+    """
+    Profile the execution of a function using CUDA Graph and PyTorch Profiler.
+
+    This function creates a CUDA Graph for the given function, then profiles its
+    execution using the standard profile function.
+
+    Parameters
+    ----------
+    num_iterations : int.
+        The number of iterations to use for profiling the callback.
+    num_warmup_iterations : int.
+        The number of iterations to use for warmup before profiling the callback.
+    func : Callable[..., T].
+        A callback function with arbitrary arguments to be executed.
+    *args
+        Variable length argument list for the callback function.
+    **kwargs
+        Keyword arguments for the callback function.
+
+    Returns
+    -------
+    tuple[list[float], T]: A tuple containing:
+        - list[float]: The execution times in milliseconds.
+        - T: The result of the last function execution.
+    """
+    graph = torch.cuda.CUDAGraph()
+    with torch.cuda.graph(graph):
+        execute_callback(1, func, *args, **kwargs)
+    return profile(num_iterations, num_warmup_iterations, func, *args, **kwargs)
+
+
+@pytest.fixture(scope="session")
+def benchmark() -> Generator[Callable[..., T], None, None]:
+    """
+    A pytest fixture that provides a benchmarking function for performance testing.
+
+
+@pytest.fixture(scope="session")
+def benchmark() -> Generator[Callable[..., T], None, None]:
+    """
+    A pytest fixture that provides a benchmarking function for performance testing.
+
+    This fixture creates a generator that yields a function which can be used to benchmark
+    other functions, optionally using CUDA graphs. It collects statistics on execution times
+    and presents the results in a formatted table at the end of the test session.
+
+    Yields
+    ------
+    Callable[..., T]: A benchmarking function that can be called within tests.
+
+    Usage
+    -----
+    def test_example(benchmark):
+        result = benchmark(
+            num_iterations=100,
+            num_warmup_iterations=10,
+            use_cuda_graph=False,
+            func=my_function_to_test,
+            arg1=value1,
+            arg2=value2
+        )
+    """
+    test_result_data: dict[str, list[dict[str, float]]] = defaultdict(list)
+
+    def _benchmark(
+        num_iterations: int,
+        num_warmup_iterations: int,
+        use_cuda_graph: bool,
+        func: Callable[..., T],
+        *args,
+        **kwargs
+    ) -> T:
+
+        test_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
+
+        if use_cuda_graph:
+            execution_time_list, result = profile_cuda_graph(
+                num_iterations,
+                num_warmup_iterations,
+                func,
+                *args,
+                **kwargs
+            )
+        else:
+            execution_time_list, result = profile(
+                num_iterations,
+                num_warmup_iterations,
+                func,
+                *args,
+                **kwargs
+            )
+
+        stats = {
+            "test_name": test_name,
+            "avg_time": np.mean(execution_time_list),
+            "std_time": np.std(execution_time_list),
+            "min_time": np.min(execution_time_list),
+            "max_time": np.max(execution_time_list),
+            "iterations": num_iterations,
+            "warmup": num_warmup_iterations,
+            "cuda_graph": use_cuda_graph
+        }
+        test_result_data[test_name].append(stats)
+
+        return result
+
+    yield _benchmark
+
+    table = PrettyTable()
+    table.field_names = ["Test Name", "Avg Time (ms)", "Std Dev (ms)", "Min Time (ms)", "Max Time (ms)", "Iterations", "Warmup", "CUDA Graph"]
+
+    for test_results in test_result_data.values():
+        for result in test_results:
+            table.add_row([
+                result["test_name"],
+                f"{result['avg_time']:.6f}",
+                f"{result['std_time']:.6f}",
+                f"{result['min_time']:.6f}",
+                f"{result['max_time']:.6f}",
+                result["iterations"],
+                result["warmup"],
+                "Yes" if result["cuda_graph"] else "No"
+            ])
+
+    table.align = "r"  # Right-align all columns
+    table.align["Test Name"] = "l"  # Left-align the Test Name column
+    table.float_format = ".6"  # Set float precision to 6 decimal places
+    table.title = "Benchmark Results"
+    table.border = True
+    table.header = True
+
+    print("\n", table)
+

diff --git a/op_tests/test_gemm_a8w8.py b/op_tests/test_gemm_a8w8.py
index e360def3..88cd3e02 100644
--- a/op_tests/test_gemm_a8w8.py
+++ b/op_tests/test_gemm_a8w8.py
@@ -6,7 +6,7 @@
 import torch
 import torch.nn.functional as F
 
-from .utils import DefaultBenchmarkHook, check_all_close, rand_tensor
+from .utils import check_all_close, rand_tensor
 
 MNK = [
     # qkv_proj
@@ -102,9 +102,7 @@ def test_ck_gemm_close_to_torch(
         use_bias,
     )
 
-    _, output = benchmark(
-        DefaultBenchmarkHook(aiter.gemm_a8w8_CK, a, b, a_scale, b_scale, bias, out_dtype)
-    )
+    output = benchmark(100, 10, False, aiter.gemm_a8w8_CK, a, b, a_scale, b_scale, bias, out_dtype)
     expected = torch_scaled_mm(a, b, a_scale, b_scale, bias, dtype=out_dtype)
 
     check_all_close(output, expected, rtol=rtol, atol=atol)
@@ -126,9 +124,7 @@ def test_asm_gemm_close_to_torch(benchmark, mnk: tuple[int, int, int]) -> None:
     )
     b_shuffled= shuffle_weight(b, layout=(32, 16))
 
-    _, output = benchmark(
-        DefaultBenchmarkHook(aiter.gemm_a8w8_ASM, a, b_shuffled, a_scale, b_scale, bias)
-    )
+    output = benchmark(100, 10, False, aiter.gemm_a8w8_ASM, a, b_shuffled, a_scale, b_scale, bias)
     expected = torch_scaled_mm(a, b, a_scale, b_scale,
bias, dtype=out_dtype) if output is not None and torch.sum(output.isnan()==True) == 0: check_all_close(output, expected, rtol=rtol, atol=atol) diff --git a/op_tests/utils.py b/op_tests/utils.py index bd7edf2c..a177c333 100644 --- a/op_tests/utils.py +++ b/op_tests/utils.py @@ -87,283 +87,3 @@ def check_all_close( delta details: {percent:.1%} ({num_not_close} of {a.numel()}) elements """ assert is_close.all(), message - -def extract_avg_cuda_time_trace(torch_profiler: torch.profiler) -> float: - """ - Extract the average CUDA time from a PyTorch profiler trace. - - This function calculates the mean of the self device time for all CUDA events - in the profiler trace, excluding the first event. - - Parameters - ---------- - torch_profiler : torch.profiler - A PyTorch profiler object containing trace events. - - Returns - ------- - float: The average CUDA time across all CUDA events, or 0.0 if there are - insufficient events or no CUDA events. - """ - get_cuda_total_time = lambda event: getattr(event, "self_device_time_total", 0.0) - is_cuda = lambda event: getattr(event, "device_type", None) == torch.profiler.DeviceType.CUDA - - if len(torch_profiler.events()) <=1: - return 0.0 - - return np.mean( - [ - get_cuda_total_time(event) - for event in torch_profiler.events()[1:] - if is_cuda(event) - ] - ) - -def execute_callback( - num_iterations: int, - func: Callable[..., T], - *args, - **kwargs -) -> T: - """ - Executes a callback for a given number of iterations. - - Parameters - ---------- - num_iterations : int. - The number of iterations to use for profiling the callback. - func : Callable[..., T]. - A callback function with arbitrary arguments to be executed. - *args - Variable length argument list for the callback function. - **kwargs - Keyword arguments for the callback function. - - Returns - ------- - T: The last value returned by the callback function. - """ - for _ in range(num_iterations): - result = func(*args, **kwargs) - return result - -def time_callback_with_cuda_event( - num_iterations: int, - func: Callable[..., T], - *args, - **kwargs -) -> float: - """ - Measure the average execution time of a given function using CUDA - events in milliseconds. - - Parameters - ---------- - num_iterations : int. - The number of iterations to use for profiling the callback. - func : Callable[..., T]. - A callback function with arbitrary arguments to be executed. - *args - Variable length argument list for the callback function. - **kwargs - Keyword arguments for the callback function. - Returns - ------- - float: The average execution time in milliseconds over all iterations. - """ - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - latency_list = [] - for _ in range(num_iterations): - start_event.record() - func(*args, **kwargs) - end_event.record() - end_event.synchronize() - latency = start_event.elapsed_time(end_event) - latency_list.append(latency) - return np.mean(latency_list) - -def profile( - num_iterations: int, - num_warmup_iterations:int, - func: Callable[..., T], - *args, - **kwargs -) -> tuple[float, T]: - """ - Profile the execution of a function using PyTorch Profiler. - - This function performs warmup iterations, then profiles the actual execution - of the function for a specified number of iterations. - - Parameters - ---------- - num_iterations : int. - The number of iterations to use for profiling the callback. - num_warmup_iterations : int. - The number of iterations to use for warmup before profiling the callback. 
- func : Callable[..., T]. - A callback function with arbitrary arguments to be executed. - *args - Variable length argument list for the callback function. - **kwargs - Keyword arguments for the callback function. - - Returns - ------- - tuple[float, T]: A tuple containing: - - float: The average execution time in milliseconds. - - T: The result of the last function execution. - """ - - # warmup - result = execute_callback(num_warmup_iterations, func, *args, **kwargs) - - # Profile using cuda.Event - # Note: The use of AITER_LOG_MORE variable - # is temporary (until we completly shift to using pytest) - # and should be replaced with a more descriptive - # flag. - if int(os.environ.get('AITER_LOG_MORE', 0)): - average_latency = time_callback_with_cuda_event( - num_iterations, func, *args, **kwargs) - return average_latency, result - - # Profile using torch.profiler - with tpf.profile( - activities=[ - tpf.ProfilerActivity.CPU, - tpf.ProfilerActivity.CUDA - ], - profile_memory=True, - with_stack=True, - with_modules=True, - ) as prof: - execute_callback(num_iterations, func, *args, **kwargs) - - average_latency = extract_avg_cuda_time_trace(prof) - return average_latency, result - -def profile_cuda_graph( - num_iterations: int, - num_warmup_iterations:int, - func: Callable[..., T], - *args, - **kwargs -) -> tuple[float, T]: - """ - Profile the execution of a function using CUDA Graph and PyTorch Profiler. - - This function creates a CUDA Graph for the given function, then profiles its - execution using the standard profile function. - - Parameters - ---------- - num_iterations : int. - The number of iterations to use for profiling the callback. - num_warmup_iterations : int. - The number of iterations to use for warmup before profiling the callback. - func : Callable[..., T]. - A callback function with arbitrary arguments to be executed. - *args - Variable length argument list for the callback function. - **kwargs - Keyword arguments for the callback function. - - Returns - ------- - tuple[float, T]: A tuple containing: - - float: The average execution time in milliseconds. - - T: The result of the last function execution. - """ - graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph): - execute_callback(1, func, *args, **kwargs) - return profile(num_iterations, num_warmup_iterations, func, *args, **kwargs) - - -class BenchmarkHook(Generic[T]): - """ - A generic class for custom benchmarking in pytest using pytest-benchmark. - - This class allows for fine-grained control over benchmarking parameters and - can be used with the pytest-benchmark fixture to measure function performance. - - Type Parameters: - ---------------- - T : The return type of the function being benchmarked. - - Attributes - ---------- - num_iterations : int - The number of times the function will be executed during benchmarking. - num_warmup_iterations : int - The number of times the function will be executed before actual benchmarking begins. - These iterations are not included in the final measurements. - func : Callable[..., T] - The function to be benchmarked. - args : tuple - Positional arguments to be passed to the benchmarked function. - kwargs : dict - Keyword arguments to be passed to the benchmarked function. - - Methods - ------- - __call__() -> tuple[float, T] - Executes the benchmarking process and returns a tuple containing - the execution time and the result of the benchmarked function. 
-
-    Usage
-    -----
-    def test_example(benchmark):
-        def function_to_benchmark(x: int) -> int:
-            return x * 2
-
-        hook = BenchmarkHook(
-            num_iterations=1000,
-            num_warmup_iterations=100,
-            func=function_to_benchmark,
-            5
-        )
-
-        execution_time, result = benchmark(hook)
-
-        assert result == 10
-        assert execution_time > 0
-    """
-
-    def __init__(self,
-        num_iterations:int,
-        num_warmup_iterations:int,
-        use_cuda_graph: bool,
-        func: Callable[..., T],
-        *args,
-        **kwargs
-    ) -> None:
-        self.num_iterations = num_iterations
-        self.num_warmup_iterations = num_warmup_iterations
-        self.use_cuda_graph = use_cuda_graph
-        self.func = func
-        self.args = args
-        self.kwargs = kwargs
-
-    def __call__(self) -> tuple[float, T]:
-        if self.use_cuda_graph:
-            return profile_cuda_graph(
-                self.num_iterations,
-                self.num_warmup_iterations,
-                self.func,
-                *self.args,
-                **self.kwargs
-            )
-        return profile(
-            self.num_iterations,
-            self.num_warmup_iterations,
-            self.func,
-            *self.args,
-            **self.kwargs
-        )
-
-# Define some BenchmarkHook partials for conveniance
-DefaultBenchmarkHook = partial(BenchmarkHook, 100, 10, False)
-DefaultCudaGraphBenchmarkHook = partial(BenchmarkHook, 100, 10, True)
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 01cdf87c..e4849dc3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,3 @@
 pandas<=2.0.3
 pytest
-pytest-benchmark==5.1.*
\ No newline at end of file
+prettytable==3.14.*
\ No newline at end of file

From 7bf68b1435c8754870623fbb53467b31a998cb5c Mon Sep 17 00:00:00 2001
From: vllmellm
Date: Fri, 21 Feb 2025 06:42:01 +0000
Subject: [PATCH 11/13] use env vars for profiling params

---
 op_tests/conftest.py       | 35 +++++++++++++----------------------
 op_tests/test_gemm_a8w8.py |  4 ++--
 2 files changed, 15 insertions(+), 24 deletions(-)

diff --git a/op_tests/conftest.py b/op_tests/conftest.py
index ed0c6a46..85765f23 100644
--- a/op_tests/conftest.py
+++ b/op_tests/conftest.py
@@ -10,6 +10,10 @@
 
 T = TypeVar("T")
 
+AITER_LOG_MORE = int(os.environ.get('AITER_LOG_MORE', 0))
+NUM_ITERATIONS = int(os.environ.get('NUM_ITERATIONS', 100))
+NUM_WARMUP_ITERATION = int(os.environ.get('NUM_WARMUP_ITERATION', 10))
+USE_CUDA_GRAPH = int(os.environ.get('USE_CUDA_GRAPH', 0))
 
 def execute_callback(
@@ -146,7 +150,7 @@ def profile(
     # is temporary (until we completely shift to using pytest)
     # and should be replaced with a more descriptive
     # flag.
-    if int(os.environ.get('AITER_LOG_MORE', 0)):
+    if AITER_LOG_MORE:
         latency_list = time_callback_with_cuda_event(
             num_iterations, func, *args, **kwargs)
         return latency_list, result
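
(An aside, not part of the patch.) Because these constants are evaluated once at module import, they must be set in the environment before pytest starts, for example (shell syntax, values illustrative):

    NUM_ITERATIONS=200 NUM_WARMUP_ITERATION=20 USE_CUDA_GRAPH=1 pytest op_tests/test_gemm_a8w8.py
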
@@ -221,9 +225,6 @@ def test_example(benchmark):
         result = benchmark(
-            num_iterations=100,
-            num_warmup_iterations=10,
-            use_cuda_graph=False,
             func=my_function_to_test,
             arg1=value1,
             arg2=value2
@@ -232,27 +233,17 @@ def test_example(benchmark):
     test_result_data: dict[str, list[dict[str, float]]] = defaultdict(list)
 
     def _benchmark(
-        num_iterations: int,
-        num_warmup_iterations: int,
-        use_cuda_graph: bool,
         func: Callable[..., T],
         *args,
         **kwargs
     ) -> T:
 
         test_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
+        profile_func = profile_cuda_graph if USE_CUDA_GRAPH else profile
 
-        if use_cuda_graph:
-            execution_time_list, result = profile_cuda_graph(
-                num_iterations,
-                num_warmup_iterations,
-                func,
-                *args,
-                **kwargs
-            )
-        else:
-            execution_time_list, result = profile(
-                num_iterations,
-                num_warmup_iterations,
-                func,
-                *args,
-                **kwargs
-            )
+        execution_time_list, result = profile_func(
+            NUM_ITERATIONS,
+            NUM_WARMUP_ITERATION,
+            func,
+            *args,
+            **kwargs
+        )
@@ -264,9 +255,9 @@ def _benchmark(
             "std_time": np.std(execution_time_list),
             "min_time": np.min(execution_time_list),
             "max_time": np.max(execution_time_list),
-            "iterations": num_iterations,
-            "warmup": num_warmup_iterations,
-            "cuda_graph": use_cuda_graph
+            "iterations": NUM_ITERATIONS,
+            "warmup": NUM_WARMUP_ITERATION,
+            "cuda_graph": USE_CUDA_GRAPH
         }
         test_result_data[test_name].append(stats)

diff --git a/op_tests/test_gemm_a8w8.py b/op_tests/test_gemm_a8w8.py
index 88cd3e02..71926837 100644
--- a/op_tests/test_gemm_a8w8.py
+++ b/op_tests/test_gemm_a8w8.py
@@ -102,7 +102,7 @@ def test_ck_gemm_close_to_torch(
         use_bias,
     )
 
-    output = benchmark(100, 10, False, aiter.gemm_a8w8_CK, a, b, a_scale, b_scale, bias, out_dtype)
+    output = benchmark(aiter.gemm_a8w8_CK, a, b, a_scale, b_scale, bias, out_dtype)
     expected = torch_scaled_mm(a, b, a_scale, b_scale, bias, dtype=out_dtype)
 
     check_all_close(output, expected, rtol=rtol, atol=atol)
@@ -124,7 +124,7 @@ def test_asm_gemm_close_to_torch(benchmark, mnk: tuple[int, int, int]) -> None:
     )
     b_shuffled= shuffle_weight(b, layout=(32, 16))
 
-    output = benchmark(100, 10, False, aiter.gemm_a8w8_ASM, a, b_shuffled, a_scale, b_scale, bias)
+    output = benchmark(aiter.gemm_a8w8_ASM, a, b_shuffled, a_scale, b_scale, bias)
     expected = torch_scaled_mm(a, b, a_scale, b_scale, bias, dtype=out_dtype)
     if output is not None and torch.sum(output.isnan()==True) == 0:
         check_all_close(output, expected, rtol=rtol, atol=atol)
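
(An aside, not part of the patches.) With the fixture trimmed down like this, a test body reduces to a single call; a hedged sketch, where my_op and the tensor setup are placeholders and iteration counts come from the environment variables above:

    def test_my_op(benchmark):
        x = rand_tensor((128, 128), dtype=torch.bfloat16).cuda()
        out = benchmark(my_op, x)
        assert out is not None
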

From b5477f93ca4e00f98960cce212da873eed228a85 Mon Sep 17 00:00:00 2001
From: vllmellm
Date: Fri, 21 Feb 2025 14:41:28 +0000
Subject: [PATCH 12/13] move utils to test_common

---
 aiter/test_common.py       | 96 ++++++++++++++++++++++++++++++++++++--
 op_tests/test_gemm_a8w8.py |  6 +--
 op_tests/utils.py          | 89 -----------------------------------
 3 files changed, 94 insertions(+), 97 deletions(-)
 delete mode 100644 op_tests/utils.py

diff --git a/aiter/test_common.py b/aiter/test_common.py
index 3a1075f0..9729c93d 100644
--- a/aiter/test_common.py
+++ b/aiter/test_common.py
@@ -1,12 +1,18 @@
 # SPDX-License-Identifier: MIT
 # Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
-
-import torch
-import torch.profiler as tpf
-import os
+from aiter import logger
 import numpy as np
+import os
 import pandas as pd
-from aiter import logger
+import torch
+import torch.profiler as tpf
+from typing import TypeVar
+
+
+T = TypeVar("T")
+
+MAX_RAND_INT = 20
+MIN_RAND_INT = -20
 
 def perftest(num_iters=101, num_warmup=10, testGraph=False):
@@ -153,3 +159,83 @@ def tensor_load(filename: str):
     metafile = '.'.join(filename.split('.')[:-1])+'.meta'
     shape, dtype = [eval(line.strip()) for line in open(metafile)]
     return torch.tensor(DWs).view(dtype).view(shape)
+
+
+def rand_tensor(
+    size: torch.Size,
+    dtype: torch.dtype
+) -> torch.Tensor:
+    """
+    Generate a random PyTorch tensor with specified shape and data type.
+
+    - For integer types: Uses torch.randint to generate random integers within a fixed range.
+    - For float types: Uses torch.rand to generate random floats between 0 and 1.
+
+    Parameters
+    ----------
+    size: torch.Size
+        The size of the generated tensor.
+    dtype: torch.dtype
+        The data type of the generated tensor.
+
+    Returns
+    -------
+    torch.Tensor : A random tensor of the specified shape and data type.
+
+    Raises
+    ------
+    ValueError
+        If an unsupported data type is provided.
+    """
+    if dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
+        # For integer types, use randint
+        return torch.randint(MIN_RAND_INT, MAX_RAND_INT, size=size, dtype=dtype)
+    elif dtype in [torch.float16, torch.float32, torch.float64, torch.bfloat16]:
+        # For float types, use rand
+        return torch.rand(size=size, dtype=dtype)
+    elif dtype == torch.float8_e4m3fnuz:
+        # Special case for float8_e4m3fnuz
+        return torch.rand(size=size, dtype=torch.float16).to(torch.float8_e4m3fnuz)
+
+    raise ValueError(f"Unsupported dtype: {dtype}")
+
+
+def assert_all_close(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    rtol: float,
+    atol: float,
+) -> None:
+    """
+    Check if all elements in two tensors are close within specified tolerances.
+
+    Parameters
+    ----------
+    a: torch.Tensor
+        First input tensor.
+    b: torch.Tensor
+        Second input tensor.
+    rtol: float
+        Relative tolerance.
+    atol: float
+        Absolute tolerance.
+
+    Raises
+    ------
+    AssertionError
+        If any elements in 'a' and 'b' are not close within the specified tolerances.
+        The error message includes details about the maximum and average delta,
+        and the percentage of elements that are not close.
+    """
+    is_close = torch.isclose(a, b, rtol=rtol, atol=atol)
+    is_not_close = ~is_close
+    num_not_close = is_not_close.sum()
+    delta = (a-b)[is_not_close]
+    percent = num_not_close/a.numel()
+    message = "" if num_not_close == 0 else f"""
+assert_all_close failed!
+max delta:{delta.max()}
+average delta:{delta.mean()}
+delta details: {percent:.1%} ({num_not_close} of {a.numel()}) elements
+    """
+    assert is_close.all(), message
diff --git a/op_tests/test_gemm_a8w8.py b/op_tests/test_gemm_a8w8.py
index 71926837..03ef8fa4 100644
--- a/op_tests/test_gemm_a8w8.py
+++ b/op_tests/test_gemm_a8w8.py
@@ -1,12 +1,12 @@
 # SPDX-License-Identifier: MIT
 # Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
import aiter +from aiter.test_common import assert_all_close, rand_tensor from aiter.ops.shuffle import shuffle_weight import pytest import torch import torch.nn.functional as F -from .utils import check_all_close, rand_tensor MNK = [ # qkv_proj @@ -105,7 +105,7 @@ def test_ck_gemm_close_to_torch( output = benchmark(aiter.gemm_a8w8_CK, a, b, a_scale, b_scale, bias, out_dtype) expected = torch_scaled_mm(a, b, a_scale, b_scale, bias, dtype=out_dtype) - check_all_close(output, expected, rtol=rtol, atol=atol) + assert_all_close(output, expected, rtol=rtol, atol=atol) @pytest.mark.parametrize("mnk", MNK) def test_asm_gemm_close_to_torch(benchmark, mnk: tuple[int, int, int]) -> None: @@ -127,4 +127,4 @@ def test_asm_gemm_close_to_torch(benchmark, mnk: tuple[int, int, int]) -> None: output = benchmark(aiter.gemm_a8w8_ASM, a, b_shuffled, a_scale, b_scale, bias) expected = torch_scaled_mm(a, b, a_scale, b_scale, bias, dtype=out_dtype) if output is not None and torch.sum(output.isnan()==True) == 0: - check_all_close(output, expected, rtol=rtol, atol=atol) + assert_all_close(output, expected, rtol=rtol, atol=atol) diff --git a/op_tests/utils.py b/op_tests/utils.py deleted file mode 100644 index a177c333..00000000 --- a/op_tests/utils.py +++ /dev/null @@ -1,89 +0,0 @@ -from functools import partial -import numpy as np -import os -import torch -import torch.profiler as tpf -from typing import Callable, Generic, TypeVar - -T = TypeVar("T") - -MAX_RAND_INT = 20 -MIN_RAND_INT = -20 - -def rand_tensor( - size: torch.Size, - dtype: torch.dtype -) -> torch.tensor: - """ - Generate a random PyTorch tensor with specified shape and data type. - - - For integer types: Uses torch.randint to generate random integers within a fixed range. - - For float types: Uses torch.rand to generate random floats between 0 and 1. - - Parameters - ---------- - size: torch.Size - The size of the generated tensor. - dtype: torch.dtype - The data type of the generated tensor. - - Returns - ------- - torch.Tensor : A random tensor of the specified shape and data type. - - Raises - ------ - ValueError - If an unsupported data type is provided. - """ - if dtype in [torch.int8, torch.int16, torch.int32, torch.int64]: - # For integer types, use randint - return torch.randint(MIN_RAND_INT, MAX_RAND_INT, size=size, dtype=dtype) - elif dtype in [torch.float16, torch.float32, torch.float64, torch.bfloat16]: - # For float types, use rand - return torch.rand(size=size, dtype=dtype) - elif dtype == torch.float8_e4m3fnuz: - # Special case for float8_e4m3fnuz - return torch.rand(size=size, dtype=torch.float16).to(torch.float8_e4m3fnuz) - - raise ValueError(f"Unsupported dtype: {dtype}") - -def check_all_close( - a: torch.tensor, - b: torch.tensor, - rtol: float, - atol: float, -) -> None: - """ - Check if all elements in two tensors are close within specified tolerances. - - Parameters - ---------- - a: torch.tensor - First input tensor. - b: torch.tensor - Second input tensor. - rtol: float - Relative tolerence. - atol: float - Absolute tolerence. - - Raises - ------- - AssertionError - If any elements in 'a' and 'b' are not close within the specified tolerances. - The error message includes details about the maximum and average delta, - and the percentage of elements that are not close. 
- """ - is_close = torch.isclose(a, b, rtol=rtol, atol=atol) - is_not_close = ~is_close - num_not_close = is_not_close.sum() - delta = (a-b)[is_not_close] - percent = num_not_close/a.numel() - message = "" if num_not_close == 0 else f""" -check_all_close failed! -max delta:{delta.max()} -average delta:{delta.mean()} -delta details: {percent:.1%} ({num_not_close} of {a.numel()}) elements - """ - assert is_close.all(), message From 6301420dcd71f1819731fb357171155d91fca2b7 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Wed, 26 Feb 2025 08:14:19 +0000 Subject: [PATCH 13/13] fix gen_instances and test bugs --- csrc/ck_gemm_a8w8/gen_instances.py | 19 ++++++++++++------- op_tests/test_gemm_a8w8.py | 6 +++--- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/csrc/ck_gemm_a8w8/gen_instances.py b/csrc/ck_gemm_a8w8/gen_instances.py index 48c59414..9431f3a9 100644 --- a/csrc/ck_gemm_a8w8/gen_instances.py +++ b/csrc/ck_gemm_a8w8/gen_instances.py @@ -248,22 +248,25 @@ def gen_kernel_dict_item_as_str(mnk: tuple | int, k: KernelParameters) -> str: def gen_lookup_dict(kernel_dict: dict, is_tune: bool) -> str: # Do not include default kernels in the lookup table for non-tuning calls. - filter_mnk = lambda mnk : True if is_tune else isinstance(mnk, tuple) + filter_mnk = lambda mnk: True if is_tune else isinstance(mnk, tuple) kernel_dict_items = [ gen_kernel_dict_item_as_str(mnk, k) for mnk, k in kernel_dict.items() if filter_mnk(mnk) ] + + lookup_table = ",\\\n ".join(kernel_dict_items) + return f"""#pragma once // SPDX-License-Identifier: MIT // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #ifdef USE_ROCM -#define GENERATE_LOOKUP_TABLE(A_TYPE, B_TYPE, C_SHUFFLE_TYPE, COMPUTE_TYPE, ACC_TYPE, D_TYPE, E_TYPE) \\ -{{ \\ - {",\\\n ".join(kernel_dict_items) + " \\"} +#define GENERATE_LOOKUP_TABLE(A_TYPE, B_TYPE, C_SHUFFLE_TYPE, COMPUTE_TYPE, ACC_TYPE, D_TYPE, E_TYPE) \\ +{{ \\ + {lookup_table} \\ }} -#endif // USE_ROCM +#endif """ def gen_kernel_definition(kernel_name: str) -> str: @@ -292,6 +295,8 @@ def gen_manifest(kernels_dict: dict) -> str: kernel_definition_list = [ gen_kernel_definition(k.name) for k in kernels_dict.values() ] + kernel_definitions = "\n".join(kernel_definition_list) + return f"""#pragma once // SPDX-License-Identifier: MIT // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
@@ -301,8 +306,8 @@ def gen_manifest(kernels_dict: dict) -> str: #include #include - {"\n".join(kernel_definition_list)} -#endif // USE_ROCM +{kernel_definitions} +#endif """ def gemm_a8w8_fwd_codegen(working_path: Path, kernel_parameters_dict: dict, is_tune: bool): diff --git a/op_tests/test_gemm_a8w8.py b/op_tests/test_gemm_a8w8.py index 03ef8fa4..4b786839 100644 --- a/op_tests/test_gemm_a8w8.py +++ b/op_tests/test_gemm_a8w8.py @@ -58,7 +58,7 @@ def torch_scaled_mm(x, weight, x_scale, w_scale, bias=None, dtype=torch.bfloat16 return out.to(dtype) -def setup( +def setup_gemm_test( mnk: tuple[int, int, int], ab_dtype: torch.dtype, scales_dtype: torch.dtype, @@ -94,7 +94,7 @@ def test_ck_gemm_close_to_torch( scales_dtype, out_dtype = scales_output_dtype bias_dtype = out_dtype - a, b, a_scale, b_scale, bias = setup( + a, b, a_scale, b_scale, bias = setup_gemm_test( mnk, ab_dtype, scales_dtype, @@ -115,7 +115,7 @@ def test_asm_gemm_close_to_torch(benchmark, mnk: tuple[int, int, int]) -> None: scales_dtype = torch.float32 bias_dtype = torch.float # gemm_a8w8_ASM requires bias and shuffle - a, b, a_scale, b_scale, bias = setup( + a, b, a_scale, b_scale, bias = setup_gemm_test( mnk, ab_dtype, scales_dtype,