
Commit 064ee66

Merge branch 'main' into lusantos/gemm_a8w8_triton
2 parents 1c2939f + b07f750 commit 064ee66

69 files changed, +4612 -371 lines

aiter/__init__.py

+1
@@ -16,6 +16,7 @@
 from .ops.quant import *
 from .ops.gemm_op_a8w8 import *
 from .ops.batched_gemm_op_a8w8 import *
+from .ops.batched_gemm_op_bf16 import *
 from .ops.aiter_operator import *
 from .ops.activation import *
 from .ops.attention import *

aiter/configs/a8w8_untuned_gemm.csv

+1 -1

@@ -24,4 +24,4 @@ M,N,K
 2048, 8192, 1024
 4096, 8192, 1024
 8192, 8192, 1024
-16384, 8192, 1024
+16384, 8192, 1024
aiter/configs/bf16_tuned_batched_gemm.csv

+27

@@ -0,0 +1,27 @@
+B,M,N,K,kernelId,splitK,us,kernelName
+16,1,1280,8192,78,0,96.9067,bf16_batched_64x16x16x64_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2
+16,32,1280,8192,28,0,112.8655,bf16_batched_256x32x128x128_32x32_1x1_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3
+16,64,1280,8192,21,0,130.2174,bf16_batched_256x64x128x128_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3
+16,128,1280,8192,14,0,165.8107,bf16_batched_256x128x96x128_32x32_1x3_16x16x1_16x16x1_1x64x1x4_8x8x1_1x1_intrawave_v3
+16,192,1280,8192,21,0,245.0521,bf16_batched_256x64x128x128_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3
+16,256,1280,8192,11,0,272.8916,bf16_batched_256x128x160x64_32x32_1x5_8x32x1_8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3
+16,320,1280,8192,8,0,341.1548,bf16_batched_256x128x256x64_32x32_2x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3
+16,512,1280,8192,14,0,486.314,bf16_batched_256x128x96x128_32x32_1x3_16x16x1_16x16x1_1x64x1x4_8x8x1_1x1_intrawave_v3
+16,1024,1280,8192,10,0,804.6945,bf16_batched_256x128x192x64_32x32_2x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3
+16,2048,1280,8192,41,0,1491.0997,bf16_batched_256x128x128x64_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5
+16,4096,1280,8192,41,0,2898.0224,bf16_batched_256x128x128x64_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5
+16,8192,1280,8192,8,0,5616.5567,bf16_batched_256x128x256x64_32x32_2x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3
+16,16384,1280,8192,8,0,11396.9711,bf16_batched_256x128x256x64_32x32_2x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3
+16,1,8192,1024,81,0,57.5454,bf16_batched_128x32x64x64_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2
+16,32,8192,1024,25,0,67.8632,bf16_batched_256x32x224x128_16x16_1x7_16x16x1_16x16x1_1x32x1x8_4x4x1_1x1_intrawave_v3
+16,64,8192,1024,20,0,88.4667,bf16_batched_256x64x160x128_16x16_2x5_16x16x1_16x16x1_1x64x1x4_8x8x1_2x1_intrawave_v3
+16,128,8192,1024,13,0,124.6653,bf16_batched_256x128x128x64_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3
+16,192,8192,1024,41,0,177.1559,bf16_batched_256x128x128x64_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5
+16,256,8192,1024,13,0,192.2976,bf16_batched_256x128x128x64_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3
+16,320,8192,1024,13,0,257.184,bf16_batched_256x128x128x64_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3
+16,512,8192,1024,13,0,340.1269,bf16_batched_256x128x128x64_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3
+16,1024,8192,1024,13,0,624.9993,bf16_batched_256x128x128x64_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3
+16,2048,8192,1024,0,0,1176.1171,bf16_batched_256x256x256x32_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4
+16,4096,8192,1024,0,0,2271.2554,bf16_batched_256x256x256x32_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4
+16,8192,8192,1024,0,0,4531.6427,bf16_batched_256x256x256x32_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4
+16,16384,8192,1024,0,0,8533.7636,bf16_batched_256x256x256x32_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4
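The kernelId and splitK columns above are what the runtime dispatch consumes, while the trailing kernelName encodes the CK tile configuration. A minimal sketch of how the tile sizes can be recovered from a kernelName, mirroring the split('_')[2] parsing used by get_CKBatchedGEMM_config in aiter/ops/batched_gemm_op_bf16.py later in this commit; the example row is copied from the CSV above:

```python
# Parse tile_m/tile_n/tile_k out of a tuned kernelName, the same way
# get_CKBatchedGEMM_config does (see aiter/ops/batched_gemm_op_bf16.py below).
kernel_name = "bf16_batched_256x128x128x64_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5"

# Token 2 is "<block>x<tile_m>x<tile_n>x<tile_k>"; drop the leading block size.
tile_m, tile_n, tile_k = (int(v) for v in kernel_name.split("_")[2].split("x")[1:])
print(tile_m, tile_n, tile_k)  # 128 128 64
```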
@@ -0,0 +1,27 @@
+B,M,N,K
+16, 1, 1280, 8192
+16, 32, 1280, 8192
+16, 64, 1280, 8192
+16, 128, 1280, 8192
+16, 192, 1280, 8192
+16, 256, 1280, 8192
+16, 320, 1280, 8192
+16, 512, 1280, 8192
+16, 1024, 1280, 8192
+16, 2048, 1280, 8192
+16, 4096, 1280, 8192
+16, 8192, 1280, 8192
+16, 16384, 1280, 8192
+16, 1, 8192, 1024
+16, 32, 8192, 1024
+16, 64, 8192, 1024
+16, 128, 8192, 1024
+16, 192, 8192, 1024
+16, 256, 8192, 1024
+16, 320, 8192, 1024
+16, 512, 8192, 1024
+16, 1024, 8192, 1024
+16, 2048, 8192, 1024
+16, 4096, 8192, 1024
+16, 8192, 8192, 1024
+16, 16384, 8192, 1024

aiter/configs/tuned_gemm.csv

+1 -1

@@ -1 +1 @@
-M,N,K,bias,dtype,outdtype,libtype,solidx,soltimes
+M,N,K,bias,dtype,outdtype,scaleAB,libtype,solidx,soltimes,kernelName

aiter/fused_moe_bf16_asm.py

+2 -2

@@ -71,8 +71,8 @@ def asm_moe(hidden_states,
                          sorted_weights, sorted_expert_ids, num_valid_ids, topk)
     elif a16:
         # a16w8 smooth quant fmoe
-        if w1.dtype == torch.float8_e4m3fnuz and inter_dim*2 == w1.shape[1]:
-            aiter.fmoe_fp8_g1u1_a16(moe_buf, hidden_states, w1, w2, sorted_ids,
+        if w1.dtype in [torch.float8_e4m3fnuz, torch.int8] and inter_dim*2 == w1.shape[1]:
+            aiter.fmoe_g1u1_a16(moe_buf, hidden_states, w1, w2, sorted_ids,
                                 sorted_weights, sorted_expert_ids, num_valid_ids,
                                 topk,
                                 fc1_scale,

aiter/jit/core.py

+2 -2

@@ -177,7 +177,7 @@ def build_module(md_name, srcs, flags_extra_cc, flags_extra_hip, blob_gen_cmd, e
         "-Wno-switch-bool",
         "-Wno-vla-cxx-extension",
         "-Wno-undefined-func-template",
-
+        "-Wno-macro-redefined",
         "-fgpu-flush-denormals-to-zero",
     ]

@@ -252,7 +252,7 @@ def exec_blob(blob_gen_cmd, op_dir, src_dir, sources):
             md_name,
            '-->'.join(traceback.format_exception(*sys.exc_info()))
         ))
-        sys.exit()
+        raise Exception(f"failed build jit [{md_name}]...")
     logger.info(
         f'finish build [{md_name}], cost {time.perf_counter()-startTS:.8f}s')
     return module
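With sys.exit() replaced by a raised Exception, a failed JIT build can now be handled by the caller instead of terminating the whole process. A minimal sketch of what this enables; the try_build helper and the fallback policy are hypothetical and not part of this commit:

```python
import logging

logger = logging.getLogger(__name__)

def try_build(build_fn, *args, **kwargs):
    # build_fn stands in for aiter's JIT build entry point; with the new
    # behavior it raises on failure rather than calling sys.exit(), so the
    # caller can log the error and fall back to another code path.
    try:
        return build_fn(*args, **kwargs)
    except Exception as exc:
        logger.warning("JIT build failed, using fallback path: %s", exc)
        return None
```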

aiter/jit/optCompilerConfig.json

+41 -7

@@ -115,6 +115,19 @@
         "verbose": "False",
         "blob_gen_cmd": "''"
     },
+    "module_batched_gemm_bf16": {
+        "srcs": [
+            "f'{AITER_CSRC_DIR}/ck_batched_gemm_bf16/include'",
+            "f'{AITER_CSRC_DIR}/pybind/batched_gemm_bf16_pybind.cu'",
+            "f'{AITER_CSRC_DIR}/ck_batched_gemm_bf16/batched_gemm_bf16.cu'"
+        ],
+        "flags_extra_cc": [],
+        "flags_extra_hip": [],
+        "extra_ldflags": "None",
+        "extra_include": [],
+        "verbose": "False",
+        "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_batched_gemm_bf16/gen_instances.py --working_path {{}} --tune_file {AITER_CORE_DIR}/aiter/configs/bf16_tuned_batched_gemm.csv'"
+    },
     "module_batched_gemm_a8w8": {
         "srcs": [
             "f'{AITER_CSRC_DIR}/ck_batched_gemm_a8w8/include'",

@@ -166,6 +179,18 @@
         "verbose": "False",
         "blob_gen_cmd": "''"
     },
+    "module_gemm_a8w8_blockscale_asm": {
+        "srcs": [
+            "f'{AITER_CSRC_DIR}/py_itfs_cu/asm_flatmm_a8w8_blockscale.cpp'",
+            "f'{AITER_CSRC_DIR}/pybind/flatmm_a8w8_blockscale_asm_pybind.cu'"
+        ],
+        "flags_extra_cc": [],
+        "flags_extra_hip": [],
+        "extra_ldflags": "None",
+        "extra_include": [],
+        "verbose": "False",
+        "blob_gen_cmd": "''"
+    },
     "module_moe_asm": {
         "srcs": [
             "f'{AITER_CSRC_DIR}/pybind/moe_op_pybind.cu'",

@@ -284,6 +309,19 @@
         "verbose": "False",
         "blob_gen_cmd": "''"
     },
+    "module_batched_gemm_bf16_tune": {
+        "srcs": [
+            "f'{AITER_CSRC_DIR}/pybind/batched_gemm_bf16_tune_pybind.cu'",
+            "f'{AITER_CSRC_DIR}/ck_batched_gemm_bf16/batched_gemm_bf16_tune.cu'",
+            "f'{AITER_CSRC_DIR}/ck_batched_gemm_bf16/include'"
+        ],
+        "flags_extra_cc": [],
+        "flags_extra_hip": [],
+        "extra_ldflags": "None",
+        "extra_include": [],
+        "verbose": "False",
+        "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_batched_gemm_bf16/gen_instances.py --working_path {{}} --tune'"
+    },
     "module_batched_gemm_a8w8_tune": {
         "srcs": [
             "f'{AITER_CSRC_DIR}/pybind/batched_gemm_a8w8_tune_pybind.cu'",

@@ -480,15 +518,13 @@
         "f'{AITER_GRADLIB_DIR}/csrc/rocsolgemm.cu'"
     ],
     "flags_extra_cc": [
-        "'-O3'",
-        "'-DLEGACY_HIPBLAS_DIRECT=ON'"
+        "'-O3'"
     ],
     "flags_extra_hip": [
         "'-O3'",
         "'-U__CUDA_NO_HALF_OPERATORS__'",
         "'-U__CUDA_NO_HALF_CONVERSIONS__'",
-        "'-ftemplate-depth=1024'",
-        "'-DLEGACY_HIPBLAS_DIRECT=ON'"
+        "'-ftemplate-depth=1024'"
     ],
     "extra_ldflags": "None",
     "extra_include": [

@@ -502,15 +538,13 @@
         "f'{AITER_GRADLIB_DIR}/csrc/hipbsolgemm.cu'"
     ],
     "flags_extra_cc": [
-        "'-O3'",
-        "'-DLEGACY_HIPBLAS_DIRECT=ON'"
+        "'-O3'"
     ],
     "flags_extra_hip": [
         "'-O3'",
         "'-U__CUDA_NO_HALF_OPERATORS__'",
         "'-U__CUDA_NO_HALF_CONVERSIONS__'",
         "'-ftemplate-depth=1024'",
-        "'-DLEGACY_HIPBLAS_DIRECT=ON'",
         "'-DENABLE_TORCH_FP8' if hasattr(torch, 'float8_e4m3fnuz') else '' "
     ],
     "extra_ldflags": "None",

aiter/ops/batched_gemm_op_bf16.py

+91

@@ -0,0 +1,91 @@
+# SPDX-License-Identifier: MIT
+# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+
+import torch
+from torch import Tensor
+from typing import List, Optional
+import functools
+import pandas as pd
+from ..jit.core import compile_ops, CK_DIR, AITER_CSRC_DIR, AITER_ROOT_DIR, AITER_CORE_DIR
+
+
+@compile_ops("module_batched_gemm_bf16", fc_name="batched_gemm_bf16")
+def batched_gemm_bf16(
+    XQ: Tensor,
+    WQ: Tensor,
+    out: Tensor,
+    bias: Optional[Tensor] = None,
+    splitK = 0
+): ...
+
+
+@functools.lru_cache(maxsize=1024)
+def compute_batched_gemm_SplitK(
+        M: int,
+        N: int,
+        K: int,
+        tile_m: int,
+        tile_n: int,
+        tile_k: int):
+
+    device_properties = torch.cuda.get_device_properties(0)
+    cu_num = device_properties.multi_processor_count
+    tile_num = ((M + tile_m - 1) // tile_m) * ((N + tile_n - 1) // tile_n)
+    cusPerTile = cu_num / tile_num
+    splitK = 0
+    while( cusPerTile >= pow(2, splitK+1) and (pow(2, splitK+1) * tile_k) < 2 * K):
+        splitK += 1
+    return splitK
+
+
+@functools.lru_cache(maxsize=1024)
+def get_CKBatchedGEMM_config(
+    B: int,
+    M: int,
+    N: int,
+    K: int,
+):
+    if not hasattr(get_CKBatchedGEMM_config, "ck_batched_gemm_dict"):
+        ck_batched_gemm_dict = pd.read_csv(f"{AITER_CORE_DIR}/aiter/configs/bf16_tuned_batched_gemm.csv").drop_duplicates()
+        get_CKBatchedGEMM_config.ck_batched_gemm_dict = ck_batched_gemm_dict.set_index(['B','M','N','K']).to_dict('index')
+    config = get_CKBatchedGEMM_config.ck_batched_gemm_dict.get((B,M,N,K), None)
+    if config != None:
+        mnk = config['kernelName'].split('_')[2].split('x')[1:]
+        config["tile_m"] = int(mnk[0])
+        config["tile_n"] = int(mnk[1])
+        config["tile_k"] = int(mnk[2])
+    return config
+
+def batched_gemm_bf16_CK(
+    XQ: Tensor,
+    WQ: Tensor,
+    bias: Optional[Tensor] = None,
+    dtype=torch.bfloat16,
+    splitK: Optional[int] = None
+):
+    assert dtype in [
+        torch.bfloat16,
+        torch.float16,
+    ], f"Output {dtype=} is currently not supported in batched_gemm_bf16"
+
+    b = XQ.shape[0]
+    m = XQ.shape[1]
+    n = WQ.shape[1]
+    k = XQ.shape[2]
+    ck_config = get_CKBatchedGEMM_config(b, m, n, k)
+    if splitK == None:
+        if ck_config != None:
+            splitK = ck_config['splitK']
+        else:
+            splitK = 0
+    Y = torch.empty(b, m, n, dtype=dtype, device=XQ.device)
+    return batched_gemm_bf16(XQ, WQ, Y, bias, splitK)
+
+@compile_ops("module_batched_gemm_bf16_tune",fc_name="batched_gemm_bf16_tune")
+def batched_gemm_bf16_tune(
+    XQ: Tensor,
+    WQ: Tensor,
+    out: Tensor,
+    kernelId: int,
+    splitK = 0
+): ...
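A short usage sketch for the new CK-backed wrapper. The shape convention (XQ is B x M x K, WQ is B x N x K) follows how batched_gemm_bf16_CK reads n from WQ.shape[1] and k from XQ.shape[2]; the explicit module import is used here rather than assuming a package-level re-export:

```python
# Hedged usage sketch for batched_gemm_bf16_CK (assumes a ROCm GPU and a
# successfully built module_batched_gemm_bf16). Shapes: XQ (B, M, K), WQ (B, N, K).
import torch
from aiter.ops.batched_gemm_op_bf16 import batched_gemm_bf16_CK

B, M, N, K = 16, 128, 1280, 8192  # a shape present in bf16_tuned_batched_gemm.csv
XQ = torch.randn(B, M, K, dtype=torch.bfloat16, device="cuda")
WQ = torch.randn(B, N, K, dtype=torch.bfloat16, device="cuda")

# splitK=None lets the wrapper take splitK from the tuned CSV when this
# (B, M, N, K) shape has an entry, and fall back to 0 otherwise.
Y = batched_gemm_bf16_CK(XQ, WQ, dtype=torch.bfloat16)
print(Y.shape)  # torch.Size([16, 128, 1280])
```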

aiter/ops/gemm_op_a8w8.py

+24

@@ -47,6 +47,15 @@ def gemm_a8w8_blockscale(
     out: Tensor,
 ): ...

+@compile_ops("module_gemm_a8w8_blockscale_asm", fc_name="flatmm_a8w8_blockscale_asm")
+def flatmm_a8w8_blockscale_asm(
+    XQ: Tensor,
+    WQ: Tensor,
+    x_scale: Tensor,
+    w_scale: Tensor,
+    out: Tensor,
+): ...
+
 @functools.lru_cache(maxsize=1024)
 def compute_gemm_SplitK(
     M: int,

@@ -176,6 +185,21 @@ def gemm_a8w8_blockscale_CK(
     Y = torch.empty(m, n, dtype=dtype, device=XQ.device)
     return gemm_a8w8_blockscale(XQ, WQ, x_scale, w_scale, Y)

+def flatmm_a8w8_blockscale_ASM(
+    XQ: Tensor,
+    WQ: Tensor,
+    x_scale: Tensor,
+    w_scale: Tensor,
+    dtype=torch.float16,
+):
+    assert dtype in [
+        torch.float16,
+    ], f"Output {dtype=} is currently not supported in gemm_a8w8"
+    m = XQ.shape[0]
+    n = WQ.shape[0]
+    k = XQ.shape[-1]
+    Y = torch.empty(m, n, dtype=dtype, device=XQ.device)
+    return flatmm_a8w8_blockscale_asm(XQ, WQ, x_scale, w_scale, Y)

 @compile_ops("module_gemm_a8w8_tune",fc_name="gemm_a8w8_tune")
 def gemm_a8w8_tune(