Commit 45f3f3f
[ROCm][Bugfix] Ensure that the moe_wna16_gemm kernel is not built on ROCm platforms. (vllm-project#14629)
Signed-off-by: Sage Moore <[email protected]>
1 parent ff47aab commit 45f3f3f

4 files changed: +8 -3 lines
CMakeLists.txt (+1 -1)

@@ -559,7 +559,6 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
 set(VLLM_MOE_EXT_SRC
   "csrc/moe/torch_bindings.cpp"
   "csrc/moe/moe_align_sum_kernels.cu"
-  "csrc/moe/moe_wna16.cu"
   "csrc/moe/topk_softmax_kernels.cu")

 set_gencode_flags_for_srcs(
@@ -574,6 +573,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     SRCS "${VLLM_MOE_WNA16_SRC}"
     CUDA_ARCHS "${CUDA_ARCHS}")

+  list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
   cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
   if (MARLIN_MOE_ARCHS)
     set(MARLIN_MOE_SRC

csrc/moe/moe_ops.h (+2 -1)

@@ -18,7 +18,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                               torch::Tensor sorted_token_ids,
                               torch::Tensor experts_ids,
                               torch::Tensor num_tokens_post_pad);
-
+#ifndef USE_ROCM
 torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                              torch::Tensor b_qweight, torch::Tensor b_scales,
                              std::optional<torch::Tensor> b_qzeros,
@@ -28,3 +28,4 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                              torch::Tensor num_tokens_post_pad, int64_t top_k,
                              int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
                              int64_t BLOCK_SIZE_K, int64_t bit);
+#endif

csrc/moe/torch_bindings.cpp (+1 -1)

@@ -31,6 +31,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
       " Tensor! num_tokens_post_pad) -> ()");
   m.impl("sgl_moe_align_block_size", torch::kCUDA, &sgl_moe_align_block_size);

+#ifndef USE_ROCM
   m.def(
       "moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, "
       "Tensor b_scales, Tensor? b_qzeros, "
@@ -41,7 +42,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {

   m.impl("moe_wna16_gemm", torch::kCUDA, &moe_wna16_gemm);

-#ifndef USE_ROCM
   m.def(
       "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
       "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "

vllm/_custom_ops.py (+4)

@@ -1146,6 +1146,10 @@ def moe_wna16_gemm(input: torch.Tensor, output: torch.Tensor,
                    num_tokens_post_pad: torch.Tensor, top_k: int,
                    BLOCK_SIZE_M: int, BLOCK_SIZE_N: int, BLOCK_SIZE_K: int,
                    bit: int) -> torch.Tensor:
+    if not current_platform.is_cuda():
+        raise NotImplementedError(
+            "The optimized moe_wna16_gemm kernel is only "
+            "available on CUDA platforms")
     torch.ops._moe_C.moe_wna16_gemm(input, output, b_qweight, b_scales,
                                     b_qzeros, topk_weights, sorted_token_ids,
                                     experts_ids, num_tokens_post_pad, top_k,
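
The Python wrapper now mirrors the build-time guard with an explicit runtime check. Below is a standalone sketch of the same pattern; it substitutes torch.version.cuda / torch.version.hip for vLLM's current_platform.is_cuda() helper, so treat it as an illustrative assumption rather than vLLM's exact implementation:

import torch

def call_cuda_only_op(op, *args, **kwargs):
    """Dispatch to a custom op whose kernel is only built for CUDA."""
    # On ROCm or CPU-only builds the kernel was never compiled in, so fail
    # with a clear NotImplementedError instead of a cryptic missing-op error.
    if torch.version.cuda is None or torch.version.hip is not None:
        raise NotImplementedError(
            "This optimized kernel is only available on CUDA platforms")
    return op(*args, **kwargs)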
