
Testx #216


Open · wants to merge 145 commits into base: main


Commits (145)
9ab4df0
add ck moe stage1
valarLip Feb 6, 2025
84c8164
fix ref implement
valarLip Feb 6, 2025
a8ca6a0
calculate tflops
valarLip Feb 6, 2025
63af9f4
cherrypick int4 moe and merge
Jan 27, 2025
5c31867
enable hipblas legacy
HaiShaw Jan 29, 2025
89c771b
fix initial sglang integration error
HaiShaw Jan 30, 2025
4550df1
Fix PREBUILD_KERNELS errors, comment some missing header files
Feb 3, 2025
92e6623
cherrypick int32 for int4 and merge
Feb 4, 2025
80be63d
int4moe silu to gelu
Feb 7, 2025
74ee788
fix build
Feb 9, 2025
a0f270b
debugging coredump
Feb 9, 2025
caf8b1f
add torch_moe stage 2
charlifu Feb 13, 2025
a8a5293
pass torch_moe_stage2
valarLip Feb 13, 2025
1c02b7c
add debug version
valarLip Feb 14, 2025
9d7c5e1
update
valarLip Feb 14, 2025
dc7073c
rename num_tokens_post_padded -> num_valid_ids
valarLip Feb 14, 2025
6238087
fix moe_stage1
valarLip Feb 14, 2025
375cc35
update test
valarLip Feb 14, 2025
9c10f56
Merge remote-tracking branch 'origin/main' into ck_moe_2stage
junhaha666 Feb 14, 2025
5014832
fix asm_moe.cpp
junhaha666 Feb 14, 2025
7493e51
add template ck moe gemm
junhaha666 Feb 14, 2025
f50725c
merge code...
valarLip Feb 14, 2025
ec4cb82
disable stage2
valarLip Feb 14, 2025
4c749e7
compile pass with dev/ck_moe_gemm_hotfix
valarLip Feb 14, 2025
66fb030
add fp8 quant case
valarLip Feb 14, 2025
c210603
debug moe stage2
Feb 15, 2025
6e8491d
update
valarLip Feb 15, 2025
213d5f6
fix
Feb 15, 2025
81e32a5
Merge branch 'main' into ck_moe_2stage
valarLip Feb 15, 2025
3d8c531
add ck_moe_2stages
Feb 15, 2025
f1e2752
update
valarLip Feb 15, 2025
e43e16f
do CK copy and rm in mainprocess only
valarLip Feb 16, 2025
3fe47da
fix quant
junhaha666 Feb 17, 2025
ed698e2
fix quant2
junhaha666 Feb 17, 2025
1aeb356
use new ck, slight opt perf
Feb 17, 2025
a2d9da2
Move moe_ck_2stages out of module_moe
junhaha666 Feb 17, 2025
67c0e99
Merge branch 'main' into ck_moe_2stage
valarLip Feb 18, 2025
3fc08ee
add int4 W to moe stage2
amd-zfyu Feb 20, 2025
fe5bebd
update function issues: int4
amd-zfyu Feb 20, 2025
f17baac
gemm1 int add
amd-zfyu Feb 20, 2025
91aa062
update gemm1 port
amd-zfyu Feb 20, 2025
1722f3c
fix bugs
amd-zfyu Feb 20, 2025
6dee16d
format code
amd-zfyu Feb 21, 2025
341176e
clean py test debug code
amd-zfyu Feb 21, 2025
6399993
fix gemm2 port issue
amd-zfyu Feb 21, 2025
fbcb245
update BK1 in GEMM1&2 Port
amd-zfyu Feb 21, 2025
10a575c
fix coredump
Feb 23, 2025
187d760
refine code, use int4 and fp8 tests together
Feb 23, 2025
fe3bbad
merge main
Feb 23, 2025
3988ece
fix dtype mismatch in torch test
Feb 23, 2025
56f962d
fix ut err
Feb 23, 2025
742b97d
cleanup useless codes
Feb 23, 2025
becb1f6
refine codes in pr
Feb 23, 2025
284aef6
revert ck branch
Feb 24, 2025
6e62a8f
add new int4 optimized kernel
Feb 24, 2025
aca940b
change unit test conf
Feb 24, 2025
2846661
merge main
Feb 24, 2025
882b276
fix merge err
Feb 24, 2025
80de9ce
add new int4 optimized kernel
Feb 24, 2025
f9c65cb
change unit test conf
Feb 24, 2025
52ec7f5
fix merge err
Feb 24, 2025
14cd92e
change shuffle logic
Feb 24, 2025
46cc594
change select tile logic for f8a16
Feb 24, 2025
15da23c
fix typo
Feb 24, 2025
08856d6
merge fix from felix: shuffle+quant
amd-zfyu Feb 24, 2025
ad08d0b
change tile strategy
Feb 25, 2025
3c3d3b8
merge main
Feb 25, 2025
0a854f5
change select tile logic int4
Feb 25, 2025
f7b5b90
fix typo
Feb 25, 2025
0c29953
Sync paged_attention_rocm() changes from 1b6ab3ce
poyenc Feb 25, 2025
f2e5481
Add kv_indptr, kv_page_indices, kv_last_page_lens params
poyenc Feb 14, 2025
5907ae8
Use single workspace buffer to accommodate intermediate tensors
poyenc Feb 18, 2025
8dcf62c
ignore kv_last_page_lens when page size = 1 (#126)
fsx950223 Feb 19, 2025
2f62813
merge upstream main
Feb 25, 2025
b82aae1
change tile config
Feb 26, 2025
e1b5f65
MulABScale->MulABScaleWin4(with out*16)
amd-zfyu Feb 26, 2025
ff77f8b
CK moe stage 1 passed with input&W == 1
amd-zfyu Feb 26, 2025
af3cda6
CK MOE Stage one PASS with randn input
amd-zfyu Feb 26, 2025
edcc5b2
fix int4 no smooth
Feb 27, 2025
b7c5dcc
fix bugs for stage2 : skip mblk32 for gemm2
amd-zfyu Feb 28, 2025
c540f32
remove smooth buffer loading
Feb 28, 2025
54027cb
Merge remote-tracking branch 'origin/main' into testx
junhaha666 Feb 28, 2025
f457564
add fused_moe_api : aiter_moe
junhaha666 Feb 28, 2025
e6601be
add back mask
Mar 1, 2025
738250d
revert testing codes
shengnxu Mar 1, 2025
d5979c8
fix gemm2 bugs: randn input pass~
amd-zfyu Mar 3, 2025
b1c6dbf
add CK 2 stage merge run wint4
amd-zfyu Mar 4, 2025
bec0d21
gemm tune: CShuffleMXDLPerWave 4->1
amd-zfyu Mar 4, 2025
47b38f9
sync input data size to ASM version(w1 = w1/10)
amd-zfyu Mar 5, 2025
3cb87f7
update ck
Mar 5, 2025
8a230e3
fix interface mismatch
valarLip Mar 5, 2025
e43a9b1
update torch stages
valarLip Mar 5, 2025
3301579
int4 ck impl per token per channel quant ok
Mar 6, 2025
296eb00
fix bf16 and blocksize 32
Mar 6, 2025
2cedaf0
add Mblk select func: token>128? 128 : 32
amd-zfyu Mar 7, 2025
d7d320b
use hip quant and recover per tensor quant instance
junhaha666 Mar 7, 2025
5f76725
fix
junhaha666 Mar 7, 2025
f3bba95
use int64_t as scatter idx
Mar 10, 2025
e70ee4d
revert tile switch logic
junhaha666 Mar 10, 2025
f1dc7ed
Merge remote-tracking branch 'origin/main' into ck_moe_2stage_int4
junhaha666 Mar 12, 2025
d0d7f87
update ck and support activation type
Mar 12, 2025
311249a
fix ut accuracy due to act type
Mar 12, 2025
4f151db
fix typo
Mar 12, 2025
7899c18
update ck
Mar 13, 2025
fa9da09
fix ck dst oob
Mar 14, 2025
f3ce6ac
use uint for 23w
Mar 17, 2025
b78248a
remove torch.zeros
Mar 17, 2025
cf938a2
Revert "remove torch.zeros"
Mar 17, 2025
b8ad779
fix sorting output
Mar 17, 2025
01b40ab
updated MPerBlock selection logic for int4 2stage moe
Mar 17, 2025
4f9f890
merge main
Mar 18, 2025
eeedb5a
update ck fix build
Mar 18, 2025
35b778d
merge testx
Mar 18, 2025
bdbbd07
merge testx and ut ok
Mar 18, 2025
06d7055
merge
Mar 18, 2025
4ab9e82
unify act type
Mar 18, 2025
8220d56
fix prebuild err in sampling
Mar 18, 2025
0463e18
fix miss file
Mar 18, 2025
79c8f79
rm sampling in prebuild
Mar 18, 2025
62af8b1
temp rm sampling
Mar 18, 2025
a9e2d00
Revert "add tree_speculative_sampling_target_only (#168)"
valarLip Mar 18, 2025
2b0bbfb
add more tests
Mar 18, 2025
bca85f3
fix fp8 per token quant
Mar 18, 2025
8d09eb2
merge main
valarLip Mar 18, 2025
455b993
prebuild pass
valarLip Mar 18, 2025
db0c6d8
typo
valarLip Mar 18, 2025
6634e55
merge prebuild fix
Mar 18, 2025
582c6e9
support 128x
Mar 20, 2025
3fd21ba
run fp8
lalala-sh Apr 2, 2025
0f9707d
update
amd-ruitang3 Apr 3, 2025
01ea661
fix moe gemm12
Apr 3, 2025
196aa36
3rdparty/composable_kernel
Apr 3, 2025
5c06d97
enable act switch
amd-ruitang3 Apr 3, 2025
1427c49
int4 gemm2 ok
amd-ruitang3 Apr 3, 2025
1ed44b4
update ck
lalala-sh Apr 3, 2025
f4a2f46
fp8 16x16 ok
Apr 3, 2025
af5ace6
fix int4
Apr 4, 2025
9144522
fix fused int4 2stage
Apr 4, 2025
6d23877
add llvm opt
Apr 4, 2025
41edcc9
remove useless comments
lalala-sh Apr 8, 2025
bb7dd0d
fix moe fused act in dtype:fp8, i4
lalala-sh Apr 8, 2025
17f1f01
fix act switch for moe i4
lalala-sh Apr 8, 2025
d302d02
remove kernel_suffix: compatible to triton 3.3
HaiShaw Apr 8, 2025
4e54917
Merge pull request #275 from HaiShaw/kernel_suffix
HaiShaw Apr 8, 2025
d9c1a73
update ck, fix bug
Apr 9, 2025
2 changes: 1 addition & 1 deletion 3rdparty/composable_kernel
Submodule composable_kernel updated 53 files
+3 −3 CHANGELOG.md
+52 −7 Jenkinsfile
+3 −0 client_example/10_grouped_convnd_bwd_data/CMakeLists.txt
+3 −3 client_example/10_grouped_convnd_bwd_data/README.md
+205 −0 client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data_ngchw.cpp
+6 −6 example/01_gemm/CMakeLists.txt
+10 −1 example/09_convnd_fwd/CMakeLists.txt
+18 −2 example/15_grouped_gemm/run_grouped_gemm_example.inc
+6 −0 example/65_gemm_multiply_multiply/CMakeLists.txt
+61 −57 example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp
+39 −73 example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp
+52 −33 example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp
+12 −9 example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp
+1 −1 example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp
+621 −0 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp
+573 −0 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp
+99 −42 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp
+22 −3 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp
+3 −1 include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp
+7 −7 include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp
+397 −40 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
+7 −0 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
+19 −0 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
+25 −2 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
+26 −3 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp
+61 −9 include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp
+49 −8 include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp
+51 −9 include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
+12 −10 include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp
+4 −3 include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp
+282 −100 include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp
+4 −3 include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp
+15 −37 include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp
+19 −13 include/ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp
+32 −19 include/ck/utility/dynamic_buffer.hpp
+7 −0 include/ck/utility/tuple_helper.hpp
+10 −6 include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
+74 −15 library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp
+144 −0 ...tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp
+61 −1 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp
+91 −0 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc
+3 −0 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt
+48 −0 ...ance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp
+48 −0 ...tance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp
+48 −0 ...tance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp
+3 −0 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt
+49 −0 ...e/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instance.cpp
+49 −0 ...ce/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instance.cpp
+49 −0 ...ce/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instance.cpp
+6 −1 profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp
+33 −1 profiler/src/profile_grouped_conv_bwd_data.cpp
+2 −1 script/convert_miopen_driver_to_profiler.py
+7 −1 test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl.cpp
2 changes: 0 additions & 2 deletions aiter/__init__.py
@@ -31,8 +31,6 @@
 from .ops.rope import *
 from .ops.topk import *
 from .ops.mha import *
-from .ops.speculative_sampling import *
-from .ops.eagle_utils import *
 from .ops.gradlib import *
 from .aot.norm import *
 from . import mla
10 changes: 8 additions & 2 deletions aiter/aot/triton_compile.py
@@ -10,7 +10,6 @@
 from typing import List

 import triton
-from triton.compiler.code_generator import kernel_suffix
 from triton.backends.amd.driver import ty_to_cpp

 desc = """
@@ -104,7 +103,14 @@ def constexpr(s):
         arg_types += [signature[i]]

     # dump C stub code
-    suffix = kernel_suffix(signature.values(), attrs)
+    suffix = ''
+    for i, ty in enumerate(signature.values()):
+        suffix += str(i)
+        if hints.get((i, ), None) == 1:
+            suffix += 'c'
+        if hints.get((i, ), None) == 16:
+            suffix += 'd'
+
     func_name = '_'.join([out_name, sig_hash, suffix])

     hex_ = binascii.hexlify(ccinfo.asm["hsaco"]).decode('utf-8')
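For reference, a standalone sketch of the naming scheme the inline loop above produces, mirroring the convention of the removed kernel_suffix() helper; the signature and hints values here are hypothetical, not taken from the PR:

# Hypothetical repro of the suffix loop above; values are made up.
signature = {0: "*fp16", 1: "i32", 2: "i32"}
hints = {(1,): 1, (2,): 16}  # arg 1 hinted as constant 1, arg 2 as divisible by 16

suffix = ""
for i, ty in enumerate(signature.values()):
    suffix += str(i)
    if hints.get((i,), None) == 1:
        suffix += "c"
    if hints.get((i,), None) == 16:
        suffix += "d"

print(suffix)  # "01c2d": 'c' marks a constexpr-1 argument, 'd' a 16-divisible one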
95 changes: 95 additions & 0 deletions aiter/fused_moe_api.py
@@ -0,0 +1,95 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

import torch
import torch.nn.functional as F
import numpy as np
import sys
import os
from typing import Any, Callable, Dict, Optional, Tuple
import aiter
from aiter import logger
from aiter.fused_moe_bf16_asm import asm_moe


def aiter_moe(hidden_states,  # not quant
              w1,  # [expert(local_expert:EP), inter_dim*2, dim] N,K
              w2,  # [expert(local_expert:EP), dim, inter_dim]
              topk_weight, topk_ids,
              # following for int8 quant
              fc1_scale=None,  # [expert(local_expert:EP), inter_dim, 1]
              fc2_scale=None,  # [expert(local_expert:EP), model_dim, 1]
              fc1_smooth_scale=None,  # [expert(local_expert:EP), 1, model_dim]
              fc2_smooth_scale=None,  # [expert(local_expert:EP), 1, inter_dim]
              a16=False,
              acitvation=None,
              per_tensor_quant_scale=None,
              block_shape=None,
              expert_mask=None,
              ):
    useInt4Weight = True if w1.dtype in [torch.int32, torch.uint32] else False
    lastdim_mul = 8 if useInt4Weight else 1
    g1u1 = True if w1.shape[1] == w2.shape[2] * 2 * lastdim_mul else False
    dtype = hidden_states.dtype
    if acitvation is None:
        acitvation = 'silu' if g1u1 else 'gelu'
    assert acitvation in ['silu', 'gelu'], "aiter moe only support silu and gelu activation,\
        by default, 'silu' is used for g1u1 and 'gelu' is used for g1u0"

    if a16 == True:
        assert dtype == torch.bfloat16, "aiter a16 asm_moe only support bfloat16 hidden_states"
        assert w2.shape[2] % 512 == 0 or w2.shape[2] % 320 == 0, "aiter a16 asm_moe only support w2.shape[2] % 512 == 0 or w2.shape[2] % 320 == 0"
        assert (g1u1 and w1.dtype == torch.float8_e4m3fnuz) or (not g1u1 and w1.dtype ==
                                                                torch.int8), "aiter a16 asm_moe only support g1u1 with fp8 or g1u0 with int8"
        assert fc1_smooth_scale is not None and fc2_smooth_scale is not None, "aiter a16 asm_moe need smoothquant(per channel)"
        assert fc1_scale is not None and fc2_scale is not None, "aiter a16 asm_moe need w_scale(per channel)"
        assert per_tensor_quant_scale is None, "aiter a16 asm_moe not support per_tensor_quant_scale"
        return asm_moe(hidden_states, w1, w2, topk_weight, topk_ids, fc1_scale, fc2_scale,
                       fc1_smooth_scale, fc2_smooth_scale, True, None, expert_mask=expert_mask)

    elif useInt4Weight:
        assert dtype == torch.bfloat16, "aiter a8wint4 asm_moe only support bfloat16 hidden_states"
        assert g1u1, "aiter a8wint4 asm_moe only support g1u1"
        assert fc1_smooth_scale is None and fc2_smooth_scale is None, "aiter a8wint4 asm_moe not support smoothquant"
        return asm_moe(hidden_states, w1, w2, topk_weight, topk_ids, fc1_scale, fc2_scale,
                       fc1_smooth_scale, fc2_smooth_scale, False, per_tensor_quant_scale, expert_mask=expert_mask, activation=acitvation)

    elif block_shape is not None:
        assert dtype == torch.bfloat16, "aiter moe for block_scale only support bfloat16 hidden_states"
        assert block_shape == (
            128, 128), "aiter moe for block_scale only support (128, 128)"
        assert fc1_smooth_scale is None and fc2_smooth_scale is None, "aiter moe for block_scale not support smoothquant"
        assert per_tensor_quant_scale is None, "aiter moe for block_scale not support per_tensor_quant_scale"
        assert g1u1, "aiter moe for block_scale only support g1u1"
        assert acitvation == 'silu', "aiter moe for block_scale only support silu acitvation"
        return asm_moe(hidden_states, w1, w2, topk_weight, topk_ids, fc1_scale, fc2_scale,
                       fc1_smooth_scale, fc2_smooth_scale, False, None, block_shape=block_shape, expert_mask=expert_mask)

    elif fc1_smooth_scale is not None and fc2_smooth_scale is not None and w1.dtype in [torch.float8_e4m3fnuz, torch.int8]:
        assert dtype == torch.bfloat16, "aiter asm_moe for smoothquant only support bfloat16 hidden_states"
        if g1u1:
            assert acitvation == 'silu', "aiter asm_moe for g1u1 smoothquant only support silu acitvation"
        else:
            assert acitvation == 'gelu', "aiter asm_moe for g1u0 smoothquant only support gelu acitvation"
        assert g1u1 or (not g1u1 and w1.dtype ==
                        torch.int8), "aiter asm_moe for smoothquant not support g1u0 fp8 smoothquant"
        return asm_moe(hidden_states, w1, w2, topk_weight, topk_ids, fc1_scale, fc2_scale,
                       fc1_smooth_scale, fc2_smooth_scale, False, per_tensor_quant_scale, expert_mask=expert_mask)

    elif fc1_smooth_scale is None and fc2_smooth_scale is None and w1.dtype in [torch.float8_e4m3fnuz, torch.int8]:
        assert dtype == torch.bfloat16, "aiter asm_moe for fp8/int8 quant only support bfloat16 hidden_states"
        assert g1u1, "aiter asm_moe for fp8/int8 quant only support g1u1"
        assert acitvation == 'silu', "aiter asm_moe for fp8/int8 quant only support silu acitvation"
        return asm_moe(hidden_states, w1, w2, topk_weight, topk_ids, fc1_scale, fc2_scale,
                       fc1_smooth_scale, fc2_smooth_scale, False, per_tensor_quant_scale, expert_mask=expert_mask)

    elif fc1_scale is None and fc2_scale is None:
        assert fc1_smooth_scale is None and fc2_smooth_scale is None, "aiter moe for no quant not support smoothquant"
        assert per_tensor_quant_scale is None, "aiter moe for no quant not support per_tensor_quant_scale"
        if not g1u1 and acitvation == 'gelu':
            return asm_moe(hidden_states, w1, w2, topk_weight, topk_ids, fc1_scale, fc2_scale,
                           fc1_smooth_scale, fc2_smooth_scale, False, per_tensor_quant_scale, expert_mask=expert_mask)
        else:
            block_m = 32
            return aiter.ck_moe(hidden_states, w1, w2, topk_weight, topk_ids, fc1_scale, fc2_scale,
                                fc1_smooth_scale, fc2_smooth_scale, block_m, expert_mask, acitvation)
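A minimal usage sketch of the new aiter_moe entry point; the expert count, shapes, dtypes, and device below are illustrative assumptions rather than values from the PR, and it requires a ROCm build of aiter:

# Hypothetical example; shapes and dtypes are assumptions, not from the PR.
import torch
from aiter.fused_moe_api import aiter_moe

E, model_dim, inter_dim, tokens, topk = 8, 4096, 14336, 16, 2

hidden_states = torch.randn(tokens, model_dim, dtype=torch.bfloat16, device="cuda")
w1 = torch.randn(E, inter_dim * 2, model_dim, dtype=torch.bfloat16, device="cuda")  # gate+up (g1u1)
w2 = torch.randn(E, model_dim, inter_dim, dtype=torch.bfloat16, device="cuda")

score = torch.softmax(torch.randn(tokens, E, device="cuda"), dim=-1)
topk_weight, topk_ids = torch.topk(score, topk, dim=-1)

# With no quant scales and a g1u1 weight layout, this falls through to the
# aiter.ck_moe branch with block_m=32 and the default 'silu' activation.
out = aiter_moe(hidden_states, w1, w2, topk_weight, topk_ids)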