Skip to content

Commit 1bacd02

Browse files
committed
W4A8 based on CUTLASS
1 parent 8236a87 commit 1bacd02

File tree

7 files changed

+572
-2
lines changed

7 files changed

+572
-2
lines changed

setup.py

+7
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@ def get_extensions():
6565
extension = CUDAExtension if use_cuda else CppExtension
6666

6767
if not IS_WINDOWS:
68+
import cutlass_library
69+
cutlass_library_dir = os.path.dirname(cutlass_library.__file__)
70+
cutlass_include_dir = os.path.join(cutlass_library_dir, "source", "include")
71+
# FIXME: remove the hard-coded include path override below once the CUTLASS package is updated to include int4/int8 MM
72+
cutlass_include_dir = "/data/quansight/scratch/cutlass/include"
73+
6874
extra_link_args = []
6975
extra_compile_args = {
7076
"cxx": [
@@ -74,6 +80,7 @@ def get_extensions():
7480
"nvcc": [
7581
"-O3" if not debug_mode else "-O0",
7682
"-t=0",
83+
"-I" + cutlass_include_dir,
7784
]
7885
}
7986

test/test_s8s4_linear_cutlass.py

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# FIXME: move this test to the appropriate test file!!!
2+
3+
import copy
4+
5+
from torchao.quantization import quantize_
6+
from torchao.quantization.quant_api import int8_dynamic_activation_int4_weight
7+
8+
import torch
9+
from torch.testing._internal.common_utils import (
10+
TestCase,
11+
run_tests,
12+
)
13+
14+
import pytest
15+
16+
17+
class ToyModel(torch.nn.Module):
    """Minimal single-layer model used as the quantization target in this test."""

    def __init__(self) -> None:
        super().__init__()
        # One Linear layer mapping 128 input features to 32 output features.
        self.linear = torch.nn.Linear(128, 32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``x`` and return the result."""
        return self.linear(x)
25+
26+
27+
class TestS8S4LinearCUTLASS(TestCase):
    """Compare an int8-activation / int4-weight quantized ``ToyModel`` to fp16."""

    def test_s8s4_linear_cutlass_(self):
        """Quantized output must match the fp16 reference within rtol=1e-1.

        The model is deep-copied, quantized in place with
        ``int8_dynamic_activation_int4_weight(group_size=128)`` and run on the
        same CUDA input as the unquantized reference.
        """
        # Skip at runtime rather than via a pytest.mark.skipif decorator: the
        # decorator is only evaluated by pytest, whereas skipTest() also works
        # when this file is executed through run_tests().
        if not torch.cuda.is_available():
            self.skipTest("CUDA not available")

        # FIXME: remove this!
        torch.manual_seed(0)

        inputs = torch.rand((64, 128)).half().cuda()
        model = ToyModel().half().cuda()

        output_ref = model(inputs)

        # Quantize a copy so the reference model keeps its fp16 weights.
        modelq = copy.deepcopy(model)
        quantize_(modelq, int8_dynamic_activation_int4_weight(group_size=128))
        output = modelq(inputs)

        # Relative-only tolerance (atol=0), exactly as in the original check.
        # assertTrue is used so the check is not stripped under ``python -O``.
        self.assertTrue(torch.allclose(output, output_ref, rtol=1e-1, atol=0))
43+
44+
45+
# Allow running this test file directly as a script.
if __name__ == "__main__":
    run_tests()

0 commit comments

Comments
 (0)