
Commit debd6bb

[Kernel] Add ModelOpt FP4 Checkpoint Support (vllm-project#12520)
Signed-off-by: Pavani Majety <[email protected]>
1 parent 5c538c3 commit debd6bb

10 files changed: +388 -30 lines changed

csrc/ops.h (+5 -3)

@@ -160,14 +160,16 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W,
 int64_t ggml_moe_get_block_size(int64_t type);
 
 #ifndef USE_ROCM
+
+bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
+bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
+bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
+
 void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
                            torch::Tensor const& B, torch::Tensor const& A_sf,
                            torch::Tensor const& B_sf,
                            torch::Tensor const& alpha);
 
-bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
-bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
-
 void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                        torch::Tensor const& b, torch::Tensor const& a_scales,
                        torch::Tensor const& b_scales,

csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu (+6)

@@ -36,3 +36,9 @@ void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
              "be compiled using CUDA 12.8 and target "
              "compute capability 100 or above.");
 }
+
+bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) {
+  int runtimeVersion;
+  cudaRuntimeGetVersion(&runtimeVersion);
+  return cuda_device_capability >= 100 && runtimeVersion >= 12080;
+}
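For reference, the check above encodes CUDA 12.8 as runtime version 12080 and Blackwell (SM100) as compute capability 100. A rough Python analogue, illustrative only and not part of this commit, which uses the CUDA version PyTorch was built with rather than cudaRuntimeGetVersion:

import torch

# Sketch only: approximate the C++ support check above. Assumes the build-time
# CUDA version reported by torch.version.cuda is an acceptable stand-in for the
# runtime version queried via cudaRuntimeGetVersion.
major, minor = torch.cuda.get_device_capability()  # e.g. (10, 0) on B200
cuda_major, cuda_minor = map(int, torch.version.cuda.split(".")[:2])
supports_fp4 = (major * 10 + minor) >= 100 and (cuda_major, cuda_minor) >= (12, 8)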

csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu (+4 -3)

@@ -201,10 +201,11 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
 #endif  // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
 
 #define CHECK_TYPE(x, st, m) \
-  TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type:", m)
-#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
+  TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
+#define CHECK_TH_CUDA(x, m) \
+  TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x, m) \
-  TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
+  TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
 #define CHECK_INPUT(x, st, m) \
   CHECK_TH_CUDA(x, m); \
   CHECK_CONTIGUOUS(x, m); \

csrc/torch_bindings.cpp (+4)

@@ -434,6 +434,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
           " Tensor! output_scale, Tensor input_scale) -> ()");
   ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
 
+  // Check if cutlass_scaled_mm_fp4 is supported for CUDA devices
+  // of the given capability
+  ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool");
+  ops.impl("cutlass_scaled_mm_supports_fp4", &cutlass_scaled_mm_supports_fp4);
 #endif
 
   // Quantized GEMM for GPTQ.
New test file (+82)

# SPDX-License-Identifier: Apache-2.0
# flake8: noqa
"""Tests Model Optimizer nvfp4 models against ground truth generation
Note: these tests will only pass on B200
"""
import os
from typing import List

import pytest
from transformers import AutoTokenizer

from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams

os.environ["TOKENIZERS_PARALLELISM"] = "true"

MAX_MODEL_LEN = 1024

MODELS = ["nvidia/Llama-3.3-70B-Instruct-FP4"]

EXPECTED_STRS_MAP = {
    "nvidia/Llama-3.3-70B-Instruct-FP4": [
        'vLLM (Vectorized Large Language Model) is indeed a high-throughput and memory-efficient inference',
        'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
        'Artificial intelligence (AI) and human intelligence (HI) are two distinct forms of intelligence that process',
        'A neural network is a type of machine learning model inspired by the structure and function of the human brain',
        'In the heart of a cutting-edge robotics lab, a team of engineers had been working tirelessly to push',
        'The COVID-19 pandemic has had a profound impact on global economic structures and future business models, leading',
        'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
        'Here are the translations:\n\n* Japanese: (Sasuga no tori ga miwa o ts'
    ]
}


# This test compares against golden strings for exact match since
# there is no baseline implementation to compare against
# and is unstable w.r.t specifics of the fp4 implementation or
# the hardware being run on.
# Disabled to prevent it from breaking the build
@pytest.mark.skip(
    reason=
    "Prevent unstable test based on golden strings from breaking the build "
    " and test input model being too large and hanging the system.")
@pytest.mark.quant_model
@pytest.mark.skipif(not is_quant_method_supported("nvfp4"),
                    reason="nvfp4 is not supported on this GPU type.")
@pytest.mark.parametrize("model_name", MODELS)
def test_models(example_prompts, model_name) -> None:
    model = LLM(
        model=model_name,
        max_model_len=MAX_MODEL_LEN,
        trust_remote_code=True,
        enforce_eager=True,
        quantization="nvfp4",
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    formatted_prompts = [
        tokenizer.apply_chat_template([{
            "role": "user",
            "content": prompt
        }],
                                      tokenize=False,
                                      add_generation_prompt=True)
        for prompt in example_prompts
    ]
    params = SamplingParams(max_tokens=20, temperature=0)
    generations: List[str] = []
    # Note: these need to be run 1 at a time due to numerical precision,
    # since the expected strs were generated this way.
    for prompt in formatted_prompts:
        outputs = model.generate(prompt, params)
        generations.append(outputs[0].outputs[0].text)
    del model

    print(model_name, generations)
    expected_strs = EXPECTED_STRS_MAP[model_name]
    for i in range(len(example_prompts)):
        generated_str = generations[i]
        expected_str = expected_strs[i]
        assert expected_str == generated_str, (
            f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}")

vllm/_custom_ops.py (+4)

@@ -467,6 +467,10 @@ def _ggml_moe_a8_fake(
 
 
 # cutlass
+def cutlass_scaled_mm_supports_fp4(cuda_device_capability: int) -> bool:
+    return torch.ops._C.cutlass_scaled_mm_supports_fp4(cuda_device_capability)
+
+
 def cutlass_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor,
                           block_scale_a: torch.Tensor,
                           block_scale_b: torch.Tensor, alpha: torch.Tensor,
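A caller can use this wrapper to gate the CUTLASS FP4 path on the running device; the capability argument is encoded as major*10 + minor, matching the >= 100 check in the kernel entry point. A hypothetical gating sketch, not from this commit:

import torch
from vllm import _custom_ops as ops

# Hypothetical sketch: query the device capability (major*10 + minor, e.g. 100
# for SM100/Blackwell) and ask the custom op whether the CUTLASS scaled-MM FP4
# kernel can be used with this device/runtime combination.
major, minor = torch.cuda.get_device_capability()
if ops.cutlass_scaled_mm_supports_fp4(major * 10 + minor):
    print("cutlass_scaled_fp4_mm path available")
else:
    print("falling back to a non-FP4 path")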

vllm/config.py (+1 -1)

@@ -613,7 +613,7 @@ def _verify_quantization(self) -> None:
         optimized_quantization_methods = [
             "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
             "awq_marlin", "fbgemm_fp8", "compressed_tensors",
-            "compressed-tensors", "experts_int8", "quark"
+            "compressed-tensors", "experts_int8", "quark", "nvfp4"
         ]
         if self.quantization is not None:
             self.quantization = self.quantization.lower()

vllm/model_executor/layers/linear.py (+17 -6)

@@ -30,12 +30,23 @@
 logger = init_logger(__name__)
 
 WEIGHT_LOADER_V2_SUPPORTED = [
-    "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
-    "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
-    "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
-    "TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod",
-    "ModelOptFp8LinearMethod", "IPEXAWQLinearMethod", "IPEXGPTQLinearMethod",
-    "HQQMarlinMethod", "QuarkLinearMethod"
+    "CompressedTensorsLinearMethod",
+    "AWQMarlinLinearMethod",
+    "AWQLinearMethod",
+    "GPTQMarlinLinearMethod",
+    "Fp8LinearMethod",
+    "MarlinLinearMethod",
+    "QQQLinearMethod",
+    "GPTQMarlin24LinearMethod",
+    "TPUInt8LinearMethod",
+    "GPTQLinearMethod",
+    "FBGEMMFp8LinearMethod",
+    "ModelOptFp8LinearMethod",
+    "IPEXAWQLinearMethod",
+    "IPEXGPTQLinearMethod",
+    "HQQMarlinMethod",
+    "QuarkLinearMethod",
+    "ModelOptNvFp4LinearMethod",
 ]
 
 
vllm/model_executor/layers/quantization/__init__.py (+3 -1)

@@ -14,6 +14,7 @@
     "ptpc_fp8",
     "fbgemm_fp8",
     "modelopt",
+    "nvfp4",
     # The order of gptq methods is important for config.py iteration over
     # override_quantization_method(..)
     "marlin",
@@ -97,7 +98,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     from .hqq_marlin import HQQMarlinConfig
     from .ipex_quant import IPEXConfig
     from .marlin import MarlinConfig
-    from .modelopt import ModelOptFp8Config
+    from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config
     from .moe_wna16 import MoeWNA16Config
     from .neuron_quant import NeuronQuantConfig
     from .ptpc_fp8 import PTPCFp8Config
@@ -112,6 +113,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
         "fp8": Fp8Config,
         "fbgemm_fp8": FBGEMMFp8Config,
         "modelopt": ModelOptFp8Config,
+        "nvfp4": ModelOptNvFp4Config,
         # The order of gptq methods is important for config.py iteration over
         # override_quantization_method(..)
         "marlin": MarlinConfig,
