1
1
# SPDX-License-Identifier: MIT
2
- # Copyright (c ) 2024, Advanced Micro Devices, Inc. All rights reserved.
2
+ # Copyright (C ) 2024-2025 , Advanced Micro Devices, Inc. All rights reserved.
3
3
4
4
# user interface
5
5
6
6
import torch
7
7
import aiter
8
+ import triton
9
+ import triton .language as tl
10
+
11
+
12
+ @triton .jit
13
+ def _fwd_kernel_stage2_asm (
14
+ Mid_O ,
15
+ Mid_lse ,
16
+ O ,
17
+ kv_indptr ,
18
+ stride_mid_ob ,
19
+ stride_mid_oh ,
20
+ stride_mid_os ,
21
+ stride_obs ,
22
+ stride_oh ,
23
+ NUM_KV_SPLITS : tl .constexpr ,
24
+ BLOCK_DV : tl .constexpr ,
25
+ Lv : tl .constexpr ,
26
+ mgc : tl .constexpr ,
27
+ ):
28
+ cur_batch = tl .program_id (0 )
29
+ cur_head = tl .program_id (1 )
30
+
31
+ cur_batch_seq_len = tl .load (kv_indptr + cur_batch + 1 ) - tl .load (
32
+ kv_indptr + cur_batch
33
+ )
34
+
35
+ offs_d = tl .arange (0 , BLOCK_DV )
36
+ mask_d = offs_d < Lv
37
+
38
+ e_sum = 0.0
39
+ e_max = - float ("inf" )
40
+ acc = tl .zeros ([BLOCK_DV ], dtype = tl .float32 )
41
+
42
+ offs_v = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh ) * Lv + offs_d
43
+ offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh
44
+
45
+ for split_kv_id in range (0 , NUM_KV_SPLITS ):
46
+ kv_len_per_split = tl .maximum (mgc , tl .cdiv (cur_batch_seq_len , NUM_KV_SPLITS ))
47
+ split_kv_start = kv_len_per_split * split_kv_id
48
+ split_kv_end = tl .minimum (split_kv_start + kv_len_per_split , cur_batch_seq_len )
49
+
50
+ if split_kv_end > split_kv_start :
51
+ tv = tl .load (
52
+ Mid_O + offs_v + split_kv_id * stride_mid_os * Lv ,
53
+ mask = mask_d ,
54
+ other = 0.0 ,
55
+ )
56
+ tlogic = tl .load (Mid_lse + offs_logic + split_kv_id * stride_mid_os )
57
+ n_e_max = tl .maximum (tlogic , e_max )
58
+
59
+ old_scale = tl .exp (e_max - n_e_max )
60
+ acc *= old_scale
61
+ exp_logic = tl .exp (tlogic - n_e_max )
62
+ acc += exp_logic * tv
63
+
64
+ e_sum = e_sum * old_scale + exp_logic
65
+ e_max = n_e_max
66
+
67
+ tl .store (
68
+ O + cur_batch * stride_obs + cur_head * stride_oh + offs_d ,
69
+ acc / e_sum ,
70
+ mask = mask_d ,
71
+ )
8
72
9
73
10
74
def mla_decode_fwd (
@@ -29,12 +93,14 @@ def mla_decode_fwd(
29
93
if num_kv_splits is None :
30
94
device_properties = torch .cuda .get_device_properties (device )
31
95
cu_num = device_properties .multi_processor_count
32
- num_kv_splits = min (16 , max (1 , cu_num // bs ))
96
+ num_kv_splits = min (16 , max (1 , cu_num // bs ))
33
97
34
98
logits = torch .empty (
35
- (bs , num_kv_splits , nhead , v_head_dim ), dtype = torch .float , device = device )
99
+ (bs , num_kv_splits , nhead , v_head_dim ), dtype = torch .float , device = device
100
+ )
36
101
attn_lse = torch .empty (
37
- (bs , num_kv_splits , nhead , 1 ), dtype = torch .float , device = device )
102
+ (bs , num_kv_splits , nhead , 1 ), dtype = torch .float , device = device
103
+ )
38
104
39
105
aiter .mla_stage1_asm_fwd (
40
106
q ,
@@ -47,14 +113,11 @@ def mla_decode_fwd(
47
113
attn_lse ,
48
114
)
49
115
50
- from aiter .ops .triton import decode_mla
51
- import triton
52
116
Lv = v_head_dim
53
117
BLOCK_DV = triton .next_power_of_2 (Lv )
54
118
grid = (bs , nhead )
55
- extra_kargs = {"waves_per_eu" : 4 ,
56
- "matrix_instr_nonkdim" : 16 , "kpack" : 2 }
57
- decode_mla ._fwd_kernel_stage2_asm [grid ](
119
+ extra_kargs = {"waves_per_eu" : 4 , "matrix_instr_nonkdim" : 16 , "kpack" : 2 }
120
+ _fwd_kernel_stage2_asm [grid ](
58
121
logits ,
59
122
attn_lse ,
60
123
o ,
0 commit comments