QuantizedLinearNotImplementedError when running inference with Int8DynamicActivationInt4WeightConfig #1909

Closed
goldhuang opened this issue Mar 17, 2025 · 12 comments


@goldhuang

Hi, my inference code hits the exception here: https://github.com/pytorch/ao/blob/main/torchao/dtypes/affine_quantized_tensor_ops.py#L228
when I use Int8DynamicActivationInt4Weight. Inference is slower than bf16 inference, because it falls back and dequantizes back to bf16.
I'm on torch 2.5.0+cu124.
It also hits the exception when I disable torch.compile().

import torch
from torchao.quantization import (
    quantize_,
    Int8DynamicActivationInt4WeightConfig,
)
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    FromIntXQuantizationAwareTrainingConfig,
    IntXQuantizationAwareTrainingConfig,
)
from torchao.quantization.quant_primitives import (
    TorchAODType,
)


class PytorchLinear(torch.nn.Module):
    def __init__(self, in_features=4096, out_features=12288):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features).cuda()

    def forward(self, x):
        return self.linear(x)

model = PytorchLinear().cuda().to(torch.bfloat16)

input_tensors = [torch.randn((70000, 4096), dtype=torch.bfloat16, device="cuda") for _ in range(100)]

activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(TorchAODType.INT4, group_size=32)
quantize_(
    model,
    IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
)

# convert: transform fake quantization ops into actual quantized ops
# swap `FakeQuantizedLinear` back to `torch.nn.Linear` and inserts
# quantized activation and weight tensor subclasses
quantize_(model, FromIntXQuantizationAwareTrainingConfig())
quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))

model = torch.compile(model.eval(), mode="max-autotune-no-cudagraphs")

# CUDA events for precise timing
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

with torch.no_grad():
    for i in range(10):
        model(input_tensors[i])
torch.cuda.synchronize()

start_event.record()
with torch.no_grad():
    for i in range(50):
        model(input_tensors[i+10])
# Record end time
end_event.record()

# Wait for completion
torch.cuda.synchronize()

# Compute elapsed time (in milliseconds)
elapsed_time = start_event.elapsed_time(end_event) / 50  # Average per iteration

print(f"Avg Inference Time: {elapsed_time:.3f} ms")
@jerryzh168
Contributor

Int8DynamicActivationInt4Weight is supposed to be lowered to ExecuTorch to get a speedup, but we also support CutlassInt4PackedLayout, which I think gives a speedup on GPU:

`layout`: layout type for quantized weight tensor, only supports `MarlinQQQLayout()` and `CutlassInt4PackedLayout()` for now

Can you try that?

quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32, layout=CutlassInt4PackedLayout()))
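For reference, a minimal untested sketch of that call with the import it needs; the import path is the one given later in this thread, and a later comment refines the arguments to group_size=None with symmetric mapping:

# Untested sketch of the suggested CUTLASS layout change (import path from a later comment in this thread).
from torchao.dtypes.uintx import CutlassInt4PackedLayout
from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig

quantize_(
    model,
    Int8DynamicActivationInt4WeightConfig(group_size=32, layout=CutlassInt4PackedLayout()),
)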

@goldhuang
Author

goldhuang commented Mar 17, 2025

@jerryzh168 I changed it to:

quantize_(model, FromIntXQuantizationAwareTrainingConfig())
quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32, layout=CutlassInt4PackedLayout()))

It still hits the exception.

@jerryzh168
Contributor

@goldhuang can you paste the error message?

@goldhuang
Author

goldhuang commented Mar 18, 2025

torchao                  0.10.0+git6b76adbe
torch                    2.6.0
triton                   3.2.0
nvidia-cutlass           3.8.0.0

@jerryzh168 It's still not working after I upgraded to a newer torch. Int8DynamicActivationInt4WeightConfig is basically not working: it's actually running in bf16 (the original dtype of the model and input) with an extra dequantize().

@goldhuang
Author

goldhuang commented Mar 18, 2025

@jerryzh168 In `_linear_int8_act_int4_weight_cutlass_check`, the check
`input_tensor.tensor_impl.scale.dtype == torch.float32`
is False (it's torch.float64 instead), so `_linear_int8_act_int4_weight_cutlass` does not match in my case.

@jerryzh168
Contributor

I see, then yeah, Int8DynamicActivationInt4WeightConfig (the default layout) is supposed to be lowered to ExecuTorch for speedup; it won't have a speedup on GPUs. If you want a speedup on GPU for float64, the easiest might be to convert the model to float32 and use the CUTLASS option. Would that work for you?
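A minimal, untested sketch of that workaround, reusing the names from the repro at the top of the thread (PytorchLinear, input_tensors):

# Untested sketch: recreate the model and inputs in float32 before quantizing,
# so that the float32 scale check in the CUTLASS dispatch path can pass.
model = PytorchLinear().cuda().to(torch.float32)
input_tensors = [t.to(torch.float32) for t in input_tensors]
# then apply Int8DynamicActivationInt4WeightConfig(..., layout=CutlassInt4PackedLayout())
# as in the earlier snippet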

@goldhuang
Author

goldhuang commented Mar 18, 2025

@jerryzh168

  File "/opt/venv/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 125, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/colligo/ao/torchao/utils.py", line 421, in _dispatch__torch_function__
    return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)
  File "/home/colligo/ao/torchao/utils.py", line 400, in wrapper
    return func(f, types, args, kwargs)
  File "/home/colligo/ao/torchao/quantization/linear_activation_quantized_tensor.py", line 143, in _
    return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
  File "/home/colligo/ao/torchao/quantization/linear_activation_quantized_tensor.py", line 89, in _quantized_linear_op
    return torch.nn.functional.linear(
  File "/home/colligo/ao/torchao/utils.py", line 421, in _dispatch__torch_function__
    return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)
  File "/home/colligo/ao/torchao/utils.py", line 400, in wrapper
    return func(f, types, args, kwargs)
  File "/home/colligo/ao/torchao/dtypes/affine_quantized_tensor_ops.py", line 227, in _
    return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
  File "/home/colligo/ao/torchao/dtypes/affine_quantized_tensor_ops.py", line 140, in _quantized_linear_op
    return impl(input_tensor, weight_tensor, bias)
  File "/home/colligo/ao/torchao/dtypes/uintx/cutlass_int4_packed_layout.py", line 185, in _linear_int8_act_int4_weight_cutlass_impl
    out = rowwise_scaled_linear_cutlass_s8s4(
  File "/home/colligo/ao/torchao/ops.py", line 575, in rowwise_scaled_linear_cutlass_s8s4
    return torch.ops.torchao.rowwise_scaled_linear_cutlass_s8s4.default(
  File "/opt/venv/lib/python3.10/site-packages/torch/_ops.py", line 723, in __call__
    return self._op(*args, **kwargs)
NotImplementedError: Could not run 'torchao::rowwise_scaled_linear_cutlass_s8s4' with arguments from the 'CUDA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'torchao::rowwise_scaled_linear_cutlass_s8s4' is only available for these backends: [Meta, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradMPS, AutogradXPU, AutogradHPU, AutogradLazy, AutogradMTIA, AutogradMeta, Tracer, AutocastCPU, AutocastXPU, AutocastMPS, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

Meta: registered at /dev/null:198 [kernel]
BackendSelect: fallthrough registered at /pytorch/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at /pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:194 [backend fallback]
FuncTorchDynamicLayerBackMode: registered at /pytorch/aten/src/ATen/functorch/DynamicLayer.cpp:503 [backend fallback]
Functionalize: registered at /pytorch/aten/src/ATen/FunctionalizeFallbackKernel.cpp:349 [backend fallback]
Named: registered at /pytorch/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at /pytorch/aten/src/ATen/ConjugateFallback.cpp:17 [backend fallback]
Negative: registered at /pytorch/aten/src/ATen/native/NegateFallback.cpp:18 [backend fallback]
ZeroTensor: registered at /pytorch/aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
ADInplaceOrView: fallthrough registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:100 [backend fallback]
AutogradOther: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:63 [backend fallback]
AutogradCPU: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:67 [backend fallback]
AutogradCUDA: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:75 [backend fallback]
AutogradXLA: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:83 [backend fallback]
AutogradMPS: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:91 [backend fallback]
AutogradXPU: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:71 [backend fallback]
AutogradHPU: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:104 [backend fallback]
AutogradLazy: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:87 [backend fallback]
AutogradMTIA: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:79 [backend fallback]
AutogradMeta: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:95 [backend fallback]
Tracer: registered at /pytorch/torch/csrc/autograd/TraceTypeManual.cpp:294 [backend fallback]
AutocastCPU: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:322 [backend fallback]
AutocastXPU: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:465 [backend fallback]
AutocastMPS: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:209 [backend fallback]
AutocastCUDA: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:165 [backend fallback]
FuncTorchBatched: registered at /pytorch/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:731 [backend fallback]
BatchedNestedTensor: registered at /pytorch/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:758 [backend fallback]
FuncTorchVmapMode: fallthrough registered at /pytorch/aten/src/ATen/functorch/VmapModeRegistrations.cpp:27 [backend fallback]
Batched: registered at /pytorch/aten/src/ATen/LegacyBatchingRegistrations.cpp:1075 [backend fallback]
VmapMode: fallthrough registered at /pytorch/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
FuncTorchGradWrapper: registered at /pytorch/aten/src/ATen/functorch/TensorWrapper.cpp:207 [backend fallback]
PythonTLSSnapshot: registered at /pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:202 [backend fallback]
FuncTorchDynamicLayerFrontMode: registered at /pytorch/aten/src/ATen/functorch/DynamicLayer.cpp:499 [backend fallback]
PreDispatch: registered at /pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:206 [backend fallback]
PythonDispatcher: registered at /pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:198 [backend fallback]

I got this after changing this line to torch.float32: https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_api.py#L552
(My model is not in float64; I think the line above created a float64 scale.)

Could you guide me on the error above?

@jerryzh168
Contributor

Oh sorry, here is how you use the CUTLASS s8s4 kernel:

int8_dynamic_activation_int4_weight(

@goldhuang
Author

goldhuang commented Mar 19, 2025

@jerryzh168 Thanks for sharing the test code. But the code is obsolete and cannot run with the latest torchao code because of the redesign of quantize_().
Besides, I see from the comments in the source code that int8 activation + int4 weight is only for ExecuTorch, and I believe the int4 kernels are not available in a regular torchao build (USE_CPP=1 does not build the kernels either).
So I finally went with int8 activation + int8 weight. But to be honest, I don't see much value in it for a regular server deployment, because it's much slower than fp8 (and we are going to train with fp8 linear layers soon).
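For reference, a hedged sketch of the W8A8 route mentioned above, assuming the Int8DynamicActivationInt8WeightConfig config available in recent torchao releases:

# Hedged sketch: dynamic int8 activations + int8 weights instead of int4 weights.
from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig

quantize_(model, Int8DynamicActivationInt8WeightConfig())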

@jerryzh168
Contributor

I see, cc @alexsamardzic: is there any other flag we need to enable to include the CUTLASS kernel in the build?

I can test out the API a bit later.

Yeah, the default layout variant for int8 activation + int4 weight is only for ExecuTorch.

> I don't see much value of it in a regular server deployment because it's much slower than fp8 (and we are going to train with fp8 linear layers soon).

If you are referring to the default layout, then yeah, it's just for ET, but the CUTLASS one I think has some speedup on A100: #880 (and float8 is probably better on H100).

@alexsamardzic
Collaborator

In order to use the CUTLASS-based kernel for W4A8, the quantize_ call should be as follows:

quantize_(
    model,
    Int8DynamicActivationInt4WeightConfig(
        group_size=None,
        mapping_type=MappingType.SYMMETRIC,
        act_mapping_type=MappingType.SYMMETRIC,
        layout=CutlassInt4PackedLayout(),
    ),
)

Also, please add the following imports:

from torchao.dtypes.uintx import CutlassInt4PackedLayout
from torchao.quantization.quant_primitives import MappingType

To my knowledge, there is no support for S8/S4 GEMM either in Triton or in the CUTLASS auto-tuning back-end for Inductor, so the same CUTLASS-based W4A8 CUDA kernel from torchao should be executed for the above config in both eager and compiled mode. The speed-up expected over the non-quantized case is not particularly significant; I'm at the moment looking into some improvements. This CUTLASS-based kernel also has some caveats, for example group quantization is not supported. As @jerryzh168 mentioned above, this kernel is really just for the Ampere generation of GPUs (it will be compiled if TORCH_CUDA_ARCH_LIST contains an 8.x arch and USE_CPP is set to 1; there are no additional flags needed). For Hopper and later, quantizing to lower-precision FP data types (like FP8, or MX data types) is a much better option.
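Putting the pieces together, an untested sketch of the full W4A8 CUTLASS setup applied to the repro at the top of the thread (PytorchLinear is the module defined there; whether bf16 or float32 inputs work best is not verified here):

import torch
from torchao.dtypes.uintx import CutlassInt4PackedLayout
from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
from torchao.quantization.quant_primitives import MappingType

# Model as in the original repro (bf16 on CUDA).
model = PytorchLinear().cuda().to(torch.bfloat16)

quantize_(
    model,
    Int8DynamicActivationInt4WeightConfig(
        group_size=None,  # group quantization is not supported by this kernel
        mapping_type=MappingType.SYMMETRIC,
        act_mapping_type=MappingType.SYMMETRIC,
        layout=CutlassInt4PackedLayout(),
    ),
)

model = torch.compile(model.eval(), mode="max-autotune-no-cudagraphs")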
