Add GEMM A8W8 Triton Kernel #191

Merged 1 commit on Mar 26, 2025
aiter/ops/triton/gemm_a8w8.py (248 additions, 0 deletions)
@@ -0,0 +1,248 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

import torch
import triton
import triton.language as tl
from typing import Optional


@triton.heuristics(
{
"EVEN_K": lambda args: args["K"] % args["BLOCK_SIZE_K"] == 0,
"GRID_MN": lambda args: triton.cdiv(args["M"], args["BLOCK_SIZE_M"])
* triton.cdiv(args["N"], args["BLOCK_SIZE_N"]),
}
)
@triton.jit
def _gemm_a8w8_kernel(
# Pointers to matrices
a_ptr,
b_ptr,
c_ptr,
a_scale_ptr,
b_scale_ptr,
bias_ptr,
# Matrix dimensions
M,
N,
K,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
# Meta-parameters
HAS_BIAS: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
EVEN_K: tl.constexpr,
GRID_MN: tl.constexpr,
):
"""
    Note: this is a Triton JIT-compiled function and is not meant to be called directly.
    Call the gemm_a8w8 wrapper function below instead.

    Computes the 8-bit matmul C = A x B, applies a conversion scale, and optionally adds a bias
    to the result.
    The conversion scale is supplied as two 1D tensors whose outer product forms the 2D scale,
    which is applied element-wise.

Key parameters:
- A: Matrix A with shape (M, K).
- B: Matrix B with shape (K, N).
- C: Matrix C with shape (M, N).
- A_scale: First scale tensor with shape (M, 1).
- B_scale: Second scale tensor with shape (1, N).
- Bias: Bias tensor with shape (1, N).
"""

NUM_XCDS: tl.constexpr = 8
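    # Assumed number of compute chiplets (XCDs) on the target accelerator
    # (e.g. 8 on MI300X-class GPUs).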

tl.assume(stride_am > 0)
tl.assume(stride_ak > 0)
tl.assume(stride_bk > 0)
tl.assume(stride_bn > 0)
tl.assume(stride_cm > 0)
tl.assume(stride_cn > 0)

# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

## pid remapping on xcds
# Number of pids per XCD in the new arrangement
pids_per_xcd = (GRID_MN + NUM_XCDS - 1) // NUM_XCDS
    # When GRID_MN is not evenly divisible by NUM_XCDS, some XCDs are assigned
    # pids_per_xcd pids while the others are assigned pids_per_xcd - 1.
    # tall_xcds is the number of XCDs that are assigned pids_per_xcd pids.
tall_xcds = GRID_MN % NUM_XCDS
tall_xcds = NUM_XCDS if tall_xcds == 0 else tall_xcds
# Compute current XCD and local pid within the XCD
xcd = pid % NUM_XCDS
local_pid = pid // NUM_XCDS
# Calculate new pid based on the new grouping
# Note that we need to consider the following two cases:
# 1. the current pid is on a tall xcd
# 2. the current pid is on a short xcd
if xcd < tall_xcds:
pid = xcd * pids_per_xcd + local_pid
else:
pid = (
tall_xcds * pids_per_xcd
+ (xcd - tall_xcds) * (pids_per_xcd - 1)
+ local_pid
)
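    # Example (illustrative sizes): with GRID_MN = 10 and NUM_XCDS = 8, pids_per_xcd = 2 and
    # tall_xcds = 2, so XCDs 0 and 1 own two pids each and XCDs 2..7 own one each; pids
    # 0..9 are remapped to 0, 2, 4, 5, 6, 7, 8, 9, 1, 3, which makes the pids that share
    # an XCD contiguous in the new numbering.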

if GROUP_SIZE_M == 1:
pid_m = pid // num_pid_n
pid_n = pid % num_pid_n
else:
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
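        # Example (illustrative sizes): with GROUP_SIZE_M = 2 and num_pid_n = 4, pids 0..7
        # map to (pid_m, pid_n) = (0,0), (1,0), (0,1), (1,1), (0,2), (1,2), (0,3), (1,3),
        # so consecutive pids reuse the same column block of B while it is hot in cache.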

    tl.assume(pid_m >= 0)
    tl.assume(pid_n >= 0)

# Create pointers for first block of A and B input matrices
offs_k = tl.arange(0, BLOCK_SIZE_K)
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

# Create pointers for the scale tensors and load them
    offs_a_scale = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_b_scale = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
a_scale = tl.load(a_scale_ptr + offs_a_scale)
b_scale = tl.load(b_scale_ptr + offs_b_scale)

acc_dtype = tl.float32 if c_ptr.type.element_ty != tl.int8 else tl.int32
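    # (int8 outputs accumulate in int32; all other output types accumulate in fp32.)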
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)

for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the K dimension.
# If it is out of bounds, set it to 0.
if EVEN_K:
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
else:
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

accumulator += tl.dot(a, b, input_precision="ieee")

# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk

# Apply scale
accumulator *= a_scale[:, None] * b_scale[None, :]

# Add bias
if HAS_BIAS:
        offs_bias = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
bias = tl.load(bias_ptr + offs_bias)
accumulator = accumulator.to(bias_ptr.type.element_ty) + bias[None, :]

c = accumulator.to(c_ptr.type.element_ty)

# Write back the block of the output matrix C with masks.
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)


def gemm_a8w8(
x: torch.Tensor,
w: torch.Tensor,
x_scale: torch.Tensor,
w_scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
    dtype: Optional[torch.dtype] = torch.bfloat16,
):
"""
    Computes the 8-bit matmul Y = X x W^T, applies a conversion scale, and optionally adds a bias
    to the result.
    The conversion scale is supplied as two 1D tensors whose outer product forms the 2D scale,
    which is applied element-wise.

Key parameters:
- X: Matrix X with shape (M, K).
- W: Matrix W with shape (N, K).
- X_scale: First scale tensor with shape (M, 1).
- W_scale: Second scale tensor with shape (1, N).
- Bias: Bias tensor with shape (1, N).

Returns:
- Y: The output matrix with shape (M, N).
"""

# Check constraints.
    assert x.shape[1] == w.shape[1], "Incompatible dimensions: x and w must have the same K dimension"

# Transpose w
w = w.T

M, K = x.shape
K, N = w.shape

y = torch.empty((M, N), dtype=dtype, device=x.device)

BLOCK_SIZE_M = 256
BLOCK_SIZE_N = 256
BLOCK_SIZE_K = 64
GROUP_SIZE_M = 4
waves_per_eu = 2
kpack = 1
matrix_instr_nonkdim = 16
num_warps = 8
num_stages = 2

grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N),)
_gemm_a8w8_kernel[grid](
x,
w,
y,
x_scale,
w_scale,
bias,
M,
N,
K,
x.stride(0),
x.stride(1),
w.stride(0),
w.stride(1),
y.stride(0),
y.stride(1),
bias is not None,
BLOCK_SIZE_M,
BLOCK_SIZE_N,
BLOCK_SIZE_K,
GROUP_SIZE_M,
waves_per_eu=waves_per_eu,
kpack=kpack,
matrix_instr_nonkdim=matrix_instr_nonkdim,
num_warps=num_warps,
num_stages=num_stages,
)

return y
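
# Minimal usage sketch (illustrative only; shapes and values are assumptions):
#   x = torch.randint(-20, 20, (M, K), dtype=torch.int8, device="cuda")
#   w = torch.randint(-20, 20, (N, K), dtype=torch.int8, device="cuda")
#   x_scale = torch.rand((M, 1), dtype=torch.float32, device="cuda")
#   w_scale = torch.rand((1, N), dtype=torch.float32, device="cuda")
#   y = gemm_a8w8(x, w, x_scale, w_scale)  # (M, N) output in torch.bfloat16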
op_tests/triton/test_gemm_a8w8.py (90 additions, 0 deletions)
@@ -0,0 +1,90 @@
import torch
import triton
import triton.language as tl
import pytest
from aiter.ops.triton.gemm_a8w8 import gemm_a8w8
import torch.nn.functional as F


def run_torch(x, weight, x_scale, w_scale, bias=None, dtype=torch.bfloat16):
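    # Reference path: compute the matmul in fp32, apply the (M, N) outer-product scale, then add bias.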
x = F.linear(x.to(torch.float32), weight.to(torch.float32))
scale = torch.matmul(x_scale, w_scale)
out = torch.mul(x, scale)
if bias is not None:
out = out.to(bias) + bias
return out.to(dtype)


def run_triton(x, weight, x_scale, w_scale, bias=None, dtype=torch.bfloat16):
    return gemm_a8w8(x, weight, x_scale, w_scale, bias, dtype)


def is_cdna4():
return triton.runtime.driver.active.get_current_target().arch == "gfx950"


e5m2_type = torch.float8_e5m2 if is_cdna4() else torch.float8_e5m2fnuz
e4m3_type = torch.float8_e4m3fn if is_cdna4() else torch.float8_e4m3fnuz

name_to_torch_types = {
"int8": torch.int8,
"int32": torch.int32,
"fp16": torch.float16,
"fp32": torch.float32,
"bf16": torch.bfloat16,
"fp8e5": e5m2_type,
"fp8e4": e4m3_type,
}


def get_x_vals():

x_vals = [(1024 * v, 1024 * v, 1024 * v) for v in range(1, 9)]
x_vals += [(4864, 4096, 8192), (9728, 8192, 65536), (4864, 8192, 4160)]
x_vals += [
(1, 1280, 8192),
(32, 1280, 8192),
(64, 1280, 8192),
(128, 1280, 8192),
(192, 1280, 8192),
(256, 1280, 8192),
(320, 1280, 8192),
(512, 1280, 8192),
(1024, 1280, 8192),
(2048, 1280, 8192),
(4096, 1280, 8192),
(8192, 1280, 8192),
(16384, 1280, 8192),
(1, 8192, 1024),
(32, 8192, 1024),
(64, 8192, 1024),
(128, 8192, 1024),
(192, 8192, 1024),
(256, 8192, 1024),
(320, 8192, 1024),
(512, 8192, 1024),
(1024, 8192, 1024),
(2048, 8192, 1024),
(4096, 8192, 1024),
(8192, 8192, 1024),
(16384, 8192, 1024),
]
return x_vals


@pytest.mark.parametrize(
"dtype, m, n, k", [(dtype, *shape) for dtype in ["bf16"] for shape in get_x_vals()]
)
def test_gemm(dtype, m, n, k):
dim = (m, n, k)
dtype = name_to_torch_types[dtype]
x = torch.randint(-20, 20, (m, k), dtype=torch.int8).cuda()
weight = torch.randint(-20, 20, (n, k), dtype=torch.int8).cuda()
x_scale = torch.rand([m, 1], dtype=torch.float32).cuda() + 1e-6
w_scale = torch.rand([1, n], dtype=torch.float32).cuda() + 1e-6
bias = torch.rand([1, n], dtype=dtype).cuda() * 10

a = run_torch(x, weight, x_scale, w_scale, bias, dtype)
b = run_triton(x, weight, x_scale, w_scale, bias, dtype)

triton.testing.assert_close(a, b, atol=0.01, rtol=1e-2)