diff --git a/aiter/ops/triton/gemm_a8w8.py b/aiter/ops/triton/gemm_a8w8.py
new file mode 100644
index 00000000..89a86de2
--- /dev/null
+++ b/aiter/ops/triton/gemm_a8w8.py
@@ -0,0 +1,248 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+import torch
+import triton
+import triton.language as tl
+from typing import Optional
+
+
+@triton.heuristics(
+    {
+        "EVEN_K": lambda args: args["K"] % args["BLOCK_SIZE_K"] == 0,
+        "GRID_MN": lambda args: triton.cdiv(args["M"], args["BLOCK_SIZE_M"])
+        * triton.cdiv(args["N"], args["BLOCK_SIZE_N"]),
+    }
+)
+@triton.jit
+def _gemm_a8w8_kernel(
+    # Pointers to matrices
+    a_ptr,
+    b_ptr,
+    c_ptr,
+    a_scale_ptr,
+    b_scale_ptr,
+    bias_ptr,
+    # Matrix dimensions
+    M,
+    N,
+    K,
+    # The stride variables represent how much to increase the ptr by when
+    # moving by 1 element in a particular dimension. E.g. `stride_am` is
+    # how much to increase `a_ptr` by to get the element one row down
+    # (A has M rows).
+    stride_am,
+    stride_ak,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    # Meta-parameters
+    HAS_BIAS: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    EVEN_K: tl.constexpr,
+    GRID_MN: tl.constexpr,
+):
+    """
+    Note: this is a Triton JIT-compiled function and is not meant to be called directly. Call the
+    gemm_a8w8 wrapper function below instead.
+
+    Computes the 8-bit matmul C = A x B, applies a conversion scale and optionally adds a bias to
+    the result.
+    The conversion scale is received in the form of two 1D tensors that are multiplied to form a
+    2D one before being applied.
+
+    Key parameters:
+    - A: Matrix A with shape (M, K).
+    - B: Matrix B with shape (K, N).
+    - C: Matrix C with shape (M, N).
+    - A_scale: First scale tensor with shape (M, 1).
+    - B_scale: Second scale tensor with shape (1, N).
+    - Bias: Bias tensor with shape (1, N).
+    """
+
+    NUM_XCDS: tl.constexpr = 8
+
+    tl.assume(stride_am > 0)
+    tl.assume(stride_ak > 0)
+    tl.assume(stride_bk > 0)
+    tl.assume(stride_bn > 0)
+    tl.assume(stride_cm > 0)
+    tl.assume(stride_cn > 0)
+
+    # -----------------------------------------------------------
+    # Map program ids `pid` to the block of C it should compute.
+    # This is done in a grouped ordering to promote L2 data reuse.
+    pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+
+    ## pid remapping on XCDs
+    # Number of pids per XCD in the new arrangement
+    pids_per_xcd = (GRID_MN + NUM_XCDS - 1) // NUM_XCDS
+    # When GRID_MN is not divisible by NUM_XCDS, some XCDs will have
+    # pids_per_xcd pids and the others will have pids_per_xcd - 1 pids.
+    # We call the number of XCDs that have pids_per_xcd pids tall_xcds.
+    tall_xcds = GRID_MN % NUM_XCDS
+    tall_xcds = NUM_XCDS if tall_xcds == 0 else tall_xcds
+    # Compute the current XCD and the local pid within that XCD
+    xcd = pid % NUM_XCDS
+    local_pid = pid // NUM_XCDS
+    # Calculate the new pid based on the new grouping.
+    # Note that we need to consider the following two cases:
+    # 1. the current pid is on a tall XCD
+    # 2. the current pid is on a short XCD
+    if xcd < tall_xcds:
+        pid = xcd * pids_per_xcd + local_pid
+    else:
+        pid = (
+            tall_xcds * pids_per_xcd
+            + (xcd - tall_xcds) * (pids_per_xcd - 1)
+            + local_pid
+        )
+
+    if GROUP_SIZE_M == 1:
+        pid_m = pid // num_pid_n
+        pid_n = pid % num_pid_n
+    else:
+        num_pid_in_group = GROUP_SIZE_M * num_pid_n
+        group_id = pid // num_pid_in_group
+        first_pid_m = group_id * GROUP_SIZE_M
+        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+        pid_m = first_pid_m + (pid % group_size_m)
+        pid_n = (pid % num_pid_in_group) // group_size_m
+
+    tl.assume(pid_m >= 0)
+    tl.assume(pid_n >= 0)
+
+    # Create pointers for the first block of the A and B input matrices
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+
+    # Create pointers for the scale tensors and load them
+    offs_a_scale = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_b_scale = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    a_scale = tl.load(a_scale_ptr + offs_a_scale)
+    b_scale = tl.load(b_scale_ptr + offs_b_scale)
+
+    acc_dtype = tl.float32 if c_ptr.type.element_ty != tl.int8 else tl.int32
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
+
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        # Load the next block of A and B, generating a mask by checking the K dimension.
+        # If it is out of bounds, set it to 0.
+        if EVEN_K:
+            a = tl.load(a_ptrs)
+            b = tl.load(b_ptrs)
+        else:
+            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
+        accumulator += tl.dot(a, b, input_precision="ieee")
+
+        # Advance the ptrs to the next K block.
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    # Apply scale
+    accumulator *= a_scale[:, None] * b_scale[None, :]
+
+    # Add bias
+    if HAS_BIAS:
+        offs_bias = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+        bias = tl.load(bias_ptr + offs_bias)
+        accumulator = accumulator.to(bias_ptr.type.element_ty) + bias[None, :]
+
+    c = accumulator.to(c_ptr.type.element_ty)
+
+    # Write back the block of the output matrix C with masks.
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, c, mask=c_mask)
+
+
+def gemm_a8w8(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    x_scale: torch.Tensor,
+    w_scale: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    dtype: Optional[torch.dtype] = torch.bfloat16,
+):
+    """
+    Computes the 8-bit matmul Y = X x W^T, applies a conversion scale and optionally adds a bias
+    to the result.
+    The conversion scale is received in the form of two 1D tensors that are multiplied to form a
+    2D one before being applied.
+
+    Key parameters:
+    - X: Matrix X with shape (M, K).
+    - W: Matrix W with shape (N, K).
+    - X_scale: First scale tensor with shape (M, 1).
+    - W_scale: Second scale tensor with shape (1, N).
+    - Bias: Bias tensor with shape (1, N).
+
+    Returns:
+    - Y: The output matrix with shape (M, N).
+    """
+
+    # Check constraints.
+    assert x.shape[1] == w.shape[1], "Incompatible dimensions!!!"
+
+    # Transpose w
+    w = w.T
+
+    M, K = x.shape
+    K, N = w.shape
+
+    y = torch.empty((M, N), dtype=dtype, device=x.device)
+
+    BLOCK_SIZE_M = 256
+    BLOCK_SIZE_N = 256
+    BLOCK_SIZE_K = 64
+    GROUP_SIZE_M = 4
+    waves_per_eu = 2
+    kpack = 1
+    matrix_instr_nonkdim = 16
+    num_warps = 8
+    num_stages = 2
+
+    grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N),)
+    _gemm_a8w8_kernel[grid](
+        x,
+        w,
+        y,
+        x_scale,
+        w_scale,
+        bias,
+        M,
+        N,
+        K,
+        x.stride(0),
+        x.stride(1),
+        w.stride(0),
+        w.stride(1),
+        y.stride(0),
+        y.stride(1),
+        bias is not None,
+        BLOCK_SIZE_M,
+        BLOCK_SIZE_N,
+        BLOCK_SIZE_K,
+        GROUP_SIZE_M,
+        waves_per_eu=waves_per_eu,
+        kpack=kpack,
+        matrix_instr_nonkdim=matrix_instr_nonkdim,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
+
+    return y
diff --git a/op_tests/triton/test_gemm_a8w8.py b/op_tests/triton/test_gemm_a8w8.py
new file mode 100644
index 00000000..eee821ec
--- /dev/null
+++ b/op_tests/triton/test_gemm_a8w8.py
@@ -0,0 +1,90 @@
+import torch
+import triton
+import triton.language as tl
+import pytest
+from aiter.ops.triton.gemm_a8w8 import gemm_a8w8
+import torch.nn.functional as F
+
+
+def run_torch(x, weight, x_scale, w_scale, bias=None, dtype=torch.bfloat16):
+    x = F.linear(x.to(torch.float32), weight.to(torch.float32))
+    scale = torch.matmul(x_scale, w_scale)
+    out = torch.mul(x, scale)
+    if bias is not None:
+        out = out.to(bias) + bias
+    return out.to(dtype)
+
+
+def run_triton(x, weight, x_scale, w_scale, bias=None, dtype=torch.bfloat16):
+    return gemm_a8w8(x, weight, x_scale, w_scale, bias)
+
+
+def is_cdna4():
+    return triton.runtime.driver.active.get_current_target().arch == "gfx950"
+
+
+e5m2_type = torch.float8_e5m2 if is_cdna4() else torch.float8_e5m2fnuz
+e4m3_type = torch.float8_e4m3fn if is_cdna4() else torch.float8_e4m3fnuz
+
+name_to_torch_types = {
+    "int8": torch.int8,
+    "int32": torch.int32,
+    "fp16": torch.float16,
+    "fp32": torch.float32,
+    "bf16": torch.bfloat16,
+    "fp8e5": e5m2_type,
+    "fp8e4": e4m3_type,
+}
+
+
+def get_x_vals():
+
+    x_vals = [(1024 * v, 1024 * v, 1024 * v) for v in range(1, 9)]
+    x_vals += [(4864, 4096, 8192), (9728, 8192, 65536), (4864, 8192, 4160)]
+    x_vals += [
+        (1, 1280, 8192),
+        (32, 1280, 8192),
+        (64, 1280, 8192),
+        (128, 1280, 8192),
+        (192, 1280, 8192),
+        (256, 1280, 8192),
+        (320, 1280, 8192),
+        (512, 1280, 8192),
+        (1024, 1280, 8192),
+        (2048, 1280, 8192),
+        (4096, 1280, 8192),
+        (8192, 1280, 8192),
+        (16384, 1280, 8192),
+        (1, 8192, 1024),
+        (32, 8192, 1024),
+        (64, 8192, 1024),
+        (128, 8192, 1024),
+        (192, 8192, 1024),
+        (256, 8192, 1024),
+        (320, 8192, 1024),
+        (512, 8192, 1024),
+        (1024, 8192, 1024),
+        (2048, 8192, 1024),
+        (4096, 8192, 1024),
+        (8192, 8192, 1024),
+        (16384, 8192, 1024),
+    ]
+    return x_vals
+
+
+@pytest.mark.parametrize(
+    "dtype, m, n, k", [(dtype, *shape) for dtype in ["bf16"] for shape in get_x_vals()]
+)
+def test_gemm(dtype, m, n, k):
+    dim = (m, n, k)
+    dtype = name_to_torch_types[dtype]
+    x = torch.randint(-20, 20, (m, k), dtype=torch.int8).cuda()
+    weight = torch.randint(-20, 20, (n, k), dtype=torch.int8).cuda()
+    x_scale = torch.rand([m, 1], dtype=torch.float32).cuda() + 1e-6
+    w_scale = torch.rand([1, n], dtype=torch.float32).cuda() + 1e-6
+    bias = torch.rand([1, n], dtype=dtype).cuda() * 10
+
+    a = run_torch(x, weight, x_scale, w_scale, bias, dtype)
+    b = run_triton(x, weight, x_scale, w_scale, bias, dtype)
+
+    triton.testing.assert_close(a, b, atol=0.01, rtol=1e-2)
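For context (not part of the diff), below is a minimal usage sketch of the new gemm_a8w8 wrapper, mirroring the int8 inputs and per-row/per-column float32 scales exercised by the test above. The concrete sizes, the "cuda" device string, and the bf16 output choice are illustrative assumptions, not requirements of the API.

# Hypothetical usage sketch of aiter.ops.triton.gemm_a8w8 (illustrative shapes and device).
import torch
from aiter.ops.triton.gemm_a8w8 import gemm_a8w8

M, N, K = 128, 1280, 8192
x = torch.randint(-20, 20, (M, K), dtype=torch.int8, device="cuda")    # int8 activations, shape (M, K)
w = torch.randint(-20, 20, (N, K), dtype=torch.int8, device="cuda")    # int8 weights, shape (N, K); transposed inside the wrapper
x_scale = torch.rand(M, 1, dtype=torch.float32, device="cuda") + 1e-6  # per-row dequantization scale, (M, 1)
w_scale = torch.rand(1, N, dtype=torch.float32, device="cuda") + 1e-6  # per-column dequantization scale, (1, N)
bias = torch.rand(1, N, dtype=torch.bfloat16, device="cuda")           # optional bias, (1, N)

y = gemm_a8w8(x, w, x_scale, w_scale, bias, dtype=torch.bfloat16)      # bf16 output, shape (M, N)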