Skip to content

Commit e1a1ff1

Browse files
committed
W4A8 based on CUTLASS
1 parent 8236a87 commit e1a1ff1

File tree

10 files changed

+619
-0
lines changed

10 files changed

+619
-0
lines changed

docs/source/api_ref_quantization.rst

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ torchao.quantization
1818
Int4WeightOnlyQuantizer
1919
quantize_
2020
int8_dynamic_activation_int4_weight
21+
int8_dynamic_activation_int4_weight_cutlass
2122
int8_dynamic_activation_int8_weight
2223
int4_weight_only
2324
int8_weight_only

setup.py

+7
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@ def get_extensions():
6565
extension = CUDAExtension if use_cuda else CppExtension
6666

6767
if not IS_WINDOWS:
68+
import cutlass_library
69+
cutlass_library_dir = os.path.dirname(cutlass_library.__file__)
70+
cutlass_include_dir = os.path.join(cutlass_library_dir, "source", "include")
71+
# FIXME: remove this once CUTLASS package updated to include int4/int8 MM
72+
cutlass_include_dir = "/data/quansight/scratch/cutlass/include"
73+
6874
extra_link_args = []
6975
extra_compile_args = {
7076
"cxx": [
@@ -74,6 +80,7 @@ def get_extensions():
7480
"nvcc": [
7581
"-O3" if not debug_mode else "-O0",
7682
"-t=0",
83+
"-I" + cutlass_include_dir,
7784
]
7885
}
7986

test/dtypes/test_affine_quantized.py

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
int4_weight_only,
77
int8_weight_only,
88
int8_dynamic_activation_int4_weight,
9+
int8_dynamic_activation_int4_weight_cutlass,
910
int8_dynamic_activation_int8_weight,
1011
int8_dynamic_activation_int8_semi_sparse_weight,
1112
float8_weight_only,
@@ -25,6 +26,7 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool):
2526
base_functions = [
2627
int8_weight_only(),
2728
int8_dynamic_activation_int4_weight(),
29+
int8_dynamic_activation_int4_weight_cutlass(),
2830
int8_dynamic_activation_int8_weight(),
2931
]
3032
if do_int4:

test/quantization/test_quant_api.py

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
Quantizer,
4040
TwoStepQuantizer,
4141
int8_dynamic_activation_int4_weight,
42+
int8_dynamic_activation_int4_weight_cutlass,
4243
int4_weight_only,
4344
int8_weight_only,
4445
int8_dynamic_activation_int8_weight,

test/test_s8s4_linear_cutlass.py

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# FIXME: move this test to the appropriate test file!!!
2+
3+
import copy
4+
5+
from torchao.quantization import quantize_
6+
from torchao.quantization.quant_api import int8_dynamic_activation_int4_weight_cutlass
7+
8+
import torch
9+
from torch.testing._internal.common_utils import (
10+
TestCase,
11+
run_tests,
12+
)
13+
14+
import pytest
15+
16+
17+
class ToyModel(torch.nn.Module):
    """Minimal two-layer MLP (128 -> 256 -> 128) with a ReLU in between.

    Serves as a small quantization target for the S8S4 CUTLASS linear test.
    """

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(128, 256)
        self.linear2 = torch.nn.Linear(256, 128)

    def forward(self, x):
        # linear -> ReLU -> linear; identical computation to the original,
        # just expressed with an intermediate name instead of rebinding x.
        hidden = torch.nn.functional.relu(self.linear1(x))
        return self.linear2(hidden)
28+
29+
30+
class TestS8S4LinearCUTLASS(TestCase):
    """Smoke test for the CUTLASS-backed int8-activation / int4-weight linear path."""

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    def test_s8s4_linear_cutlass_(self):
        # FIXME: remove this!
        torch.manual_seed(0)

        # Half-precision CUDA input and reference model.
        x = torch.rand((64, 128)).half().cuda()
        reference = ToyModel().half().cuda()

        # Reference output from the unquantized model.
        expected = reference(x)

        # Quantize a deep copy so the reference model is left untouched,
        # then compare the quantized model's output against the reference.
        quantized = copy.deepcopy(reference)
        quantize_(quantized, int8_dynamic_activation_int4_weight_cutlass())
        actual = quantized(x)

        # Loose relative tolerance: int4 weights lose precision by design.
        assert torch.allclose(actual, expected, rtol=1e-1, atol=0)
46+
47+
48+
# Allow running this test file directly; delegates to the shared PyTorch test runner.
if __name__ == "__main__":
    run_tests()

0 commit comments

Comments (0)