Skip to content

Commit e65a7b4

Browse files
Add Triton MLA Decode Rope Kernel
- Moved the _fwd_kernel_stage2_asm Triton Kernel to aiter/mla.py. - Renamed decode_mla.py to mla_decode_ref.py and moved it to op_tests/triton/utils. For now, it will be used as reference to the unit tests of both the ASM and the Triton MLA decode rope implementations. - Added the Triton MLA Decode Rope and the stage2 Kernels to the mla_decode_ref.py file.
1 parent b07f750 commit e65a7b4

File tree

6 files changed

+1530
-177
lines changed

6 files changed

+1530
-177
lines changed

aiter/mla.py

+72-9
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,74 @@
11
# SPDX-License-Identifier: MIT
2-
# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
2+
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
33

44
# user interface
55

66
import torch
77
import aiter
8+
import triton
9+
import triton.language as tl
10+
11+
12+
@triton.jit
13+
def _fwd_kernel_stage2_asm(
14+
Mid_O,
15+
Mid_lse,
16+
O,
17+
kv_indptr,
18+
stride_mid_ob,
19+
stride_mid_oh,
20+
stride_mid_os,
21+
stride_obs,
22+
stride_oh,
23+
NUM_KV_SPLITS: tl.constexpr,
24+
BLOCK_DV: tl.constexpr,
25+
Lv: tl.constexpr,
26+
mgc: tl.constexpr,
27+
):
28+
cur_batch = tl.program_id(0)
29+
cur_head = tl.program_id(1)
30+
31+
cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - tl.load(
32+
kv_indptr + cur_batch
33+
)
34+
35+
offs_d = tl.arange(0, BLOCK_DV)
36+
mask_d = offs_d < Lv
37+
38+
e_sum = 0.0
39+
e_max = -float("inf")
40+
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
41+
42+
offs_v = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh) * Lv + offs_d
43+
offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh
44+
45+
for split_kv_id in range(0, NUM_KV_SPLITS):
46+
kv_len_per_split = tl.maximum(mgc, tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS))
47+
split_kv_start = kv_len_per_split * split_kv_id
48+
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
49+
50+
if split_kv_end > split_kv_start:
51+
tv = tl.load(
52+
Mid_O + offs_v + split_kv_id * stride_mid_os * Lv,
53+
mask=mask_d,
54+
other=0.0,
55+
)
56+
tlogic = tl.load(Mid_lse + offs_logic + split_kv_id * stride_mid_os)
57+
n_e_max = tl.maximum(tlogic, e_max)
58+
59+
old_scale = tl.exp(e_max - n_e_max)
60+
acc *= old_scale
61+
exp_logic = tl.exp(tlogic - n_e_max)
62+
acc += exp_logic * tv
63+
64+
e_sum = e_sum * old_scale + exp_logic
65+
e_max = n_e_max
66+
67+
tl.store(
68+
O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
69+
acc / e_sum,
70+
mask=mask_d,
71+
)
872

973

1074
def mla_decode_fwd(
@@ -29,12 +93,14 @@ def mla_decode_fwd(
2993
if num_kv_splits is None:
3094
device_properties = torch.cuda.get_device_properties(device)
3195
cu_num = device_properties.multi_processor_count
32-
num_kv_splits = min(16, max(1, cu_num//bs))
96+
num_kv_splits = min(16, max(1, cu_num // bs))
3397

3498
logits = torch.empty(
35-
(bs, num_kv_splits, nhead, v_head_dim), dtype=torch.float, device=device)
99+
(bs, num_kv_splits, nhead, v_head_dim), dtype=torch.float, device=device
100+
)
36101
attn_lse = torch.empty(
37-
(bs, num_kv_splits, nhead, 1), dtype=torch.float, device=device)
102+
(bs, num_kv_splits, nhead, 1), dtype=torch.float, device=device
103+
)
38104

39105
aiter.mla_stage1_asm_fwd(
40106
q,
@@ -47,14 +113,11 @@ def mla_decode_fwd(
47113
attn_lse,
48114
)
49115

50-
from aiter.ops.triton import decode_mla
51-
import triton
52116
Lv = v_head_dim
53117
BLOCK_DV = triton.next_power_of_2(Lv)
54118
grid = (bs, nhead)
55-
extra_kargs = {"waves_per_eu": 4,
56-
"matrix_instr_nonkdim": 16, "kpack": 2}
57-
decode_mla._fwd_kernel_stage2_asm[grid](
119+
extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
120+
_fwd_kernel_stage2_asm[grid](
58121
logits,
59122
attn_lse,
60123
o,

0 commit comments

Comments
 (0)