
Commit 825df6b

W4A8 based on CUTLASS
A CUTLASS-based s8s4_linear_cutlass() operator is introduced. It performs a linear transformation over a quantized 8-bit input tensor and a quantized 4-bit weight tensor, each with a corresponding floating-point scale tensor attached. A benchmark script comparing matrix multiplication based on this operator against matrix multiplication over 16-bit floating-point tensors is supplied in benchmarks/benchmark_s8s4_cutlass.py. The Llama generator script torchao/_models/llama/generate.py is extended with an "int8adq-int4w-symm" quantization option, which in turn activates the s8s4_linear_cutlass() operator. With this quantization enabled, i.e. when generate.py is run as: python generate.py --compile --precision=torch.float16 -q int8adq-int4w-symm, the generator achieves around 133 tok/sec on an A100, versus around 93 tok/sec without quantization, i.e. when generate.py is run as: python generate.py --compile --precision=torch.float16
1 parent 567cb46 commit 825df6b

20 files changed (+1024, −7 lines)
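For orientation, here is a minimal sketch of how the new operator is called, following the argument order used by benchmarks/benchmark_s8s4_cutlass.py and test/test_s8s4_linear_cutlass.py below; the shapes, scale dtype, and output-dtype comment are illustrative assumptions, not part of the commit, and a CUDA build with the CUTLASS extension is assumed.

    import torch
    from torchao.ops import s8s4_linear_cutlass

    # Illustrative shapes: (m, k) int8 activations with one fp16 scale per row,
    # and an (n, k // 2) int8 tensor holding two signed 4-bit weight values per
    # byte, with one fp16 scale per output column.
    m, n, k = 32, 4096, 4096
    A = torch.randint(-128, 127, (m, k), dtype=torch.int8, device="cuda")
    A_scale = torch.randn((m,), dtype=torch.half, device="cuda")
    B = torch.randint(-128, 127, (n, k // 2), dtype=torch.int8, device="cuda")
    B_scale = torch.randn((n,), dtype=torch.half, device="cuda")
    bias = None  # optionally an (n,) tensor in the scale dtype, as in the test below

    C = s8s4_linear_cutlass(A, A_scale, B, B_scale, bias)  # expected (m, n) output in the scale dtype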

.github/workflows/float8_test.yml (+1)

@@ -35,6 +35,7 @@ jobs:
       runner: ${{ matrix.runs-on }}
       gpu-arch-type: ${{ matrix.gpu-arch-type }}
       gpu-arch-version: ${{ matrix.gpu-arch-version }}
+      submodules: recursive
       script: |
         conda create -n venv python=3.9 -y
         conda activate venv

.github/workflows/nightly_smoke_test.yml (+1)

@@ -31,6 +31,7 @@ jobs:
       runner: ${{ matrix.runs-on }}
       gpu-arch-type: ${{ matrix.gpu-arch-type }}
       gpu-arch-version: ${{ matrix.gpu-arch-version }}
+      submodules: recursive
       script: |
         python -m pip install --upgrade pip
         pip install ${{ matrix.torch-spec }}

.github/workflows/regression_test.yml (+2)

@@ -40,6 +40,7 @@ jobs:
       runner: ${{ matrix.runs-on }}
       gpu-arch-type: ${{ matrix.gpu-arch-type }}
       gpu-arch-version: ${{ matrix.gpu-arch-version }}
+      submodules: recursive
       script: |
         conda create -n venv python=3.9 -y
         conda activate venv
@@ -93,6 +94,7 @@ jobs:
       runner: ${{ matrix.runs-on }}
       gpu-arch-type: ${{ matrix.gpu-arch-type }}
       gpu-arch-version: ${{ matrix.gpu-arch-version }}
+      submodules: recursive
       script: |
         conda create -n venv python=3.9 -y
         conda activate venv

.gitmodules (+3)

@@ -0,0 +1,3 @@
+[submodule "third_party/cutlass"]
+	path = third_party/cutlass
+	url = https://github.com/NVIDIA/cutlass

benchmarks/benchmark_s8s4_cutlass.py (new file, +53)

import torch
import pandas as pd
from tqdm import tqdm

from torchao.ops import s8s4_linear_cutlass
from torchao.utils import benchmark_torch_function_in_microseconds


def get_problem(m, n, k):
    dev = torch.device("cuda")

    # Reference fp16 operands for the baseline matmul.
    A_ref = torch.randn((m, k), dtype=torch.half, device=dev)
    B_ref = torch.randn((k, n), dtype=torch.half, device=dev)

    # Quantized operands: int8 activations with per-row scales, and int4
    # weights packed two per byte into an (n, k // 2) int8 tensor with
    # per-column scales.
    A = torch.randint(-128, 127, (m, k), dtype=torch.int8, device=dev)
    A_scale = torch.randn((m,), dtype=torch.half, device=dev)
    B = torch.randint(-128, 127, size=(n, k // 2), dtype=torch.int8, device=dev)
    B_scale = torch.randn((n,), dtype=torch.half, device=dev)
    C = None

    return A_ref, B_ref, A, A_scale, B, B_scale, C


def benchmark(m: int, k: int, n: int):
    A_ref, B_ref, A, A_scale, B, B_scale, C = get_problem(m, n, k)

    fp16_time = benchmark_torch_function_in_microseconds(torch.matmul, A_ref, B_ref)
    s8s4_linear_cutlass_time = benchmark_torch_function_in_microseconds(
        s8s4_linear_cutlass, A, A_scale, B, B_scale, C
    )

    return {
        "m": m,
        "k": k,
        "n": n,
        "fp16_latency (us)": fp16_time,
        "s8s4_linear_cutlass latency (us)": s8s4_linear_cutlass_time,
        "speedup (fp16 / s8s4)": fp16_time / s8s4_linear_cutlass_time,
    }


if __name__ == "__main__":
    k_vals = (8192, 8192, 8192, 28672)
    n_vals = (8192, 10240, 57344, 8192)

    results = []
    for m in tqdm([1 << i for i in range(10)]):
        for n, k in zip(n_vals, k_vals):
            results.append(benchmark(m, k, n))

    df = pd.DataFrame(results)
    df.to_csv("s8s4_linear_cutlass_time_results.csv", index=False)
    print(df.to_markdown(index=False))

setup.py (+12)

@@ -93,6 +93,18 @@ def get_extensions():
            extra_compile_args["nvcc"].append("-g")
            extra_link_args.append("/DEBUG")

+    use_cutlass = False
+    if use_cuda and not IS_WINDOWS:
+        use_cutlass = True
+        this_dir = os.path.abspath(os.path.curdir)
+        cutlass_dir = os.path.join(this_dir, "third_party", "cutlass")
+        cutlass_include_dir = os.path.join(cutlass_dir, "include")
+    if use_cutlass:
+        extra_compile_args["nvcc"].extend([
+            "-DTORCHAO_USE_CUTLASS",
+            "-I" + cutlass_include_dir,
+        ])
+
    this_dir = os.path.dirname(os.path.curdir)
    extensions_dir = os.path.join(this_dir, "torchao", "csrc")
    sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))

test/dtypes/test_affine_quantized.py (+10, −1)

@@ -8,7 +8,7 @@
     run_tests,
 )

-from torchao.dtypes import Int4CPULayout, SemiSparseLayout
+from torchao.dtypes import Int4CPULayout, Int4PackedLayout, SemiSparseLayout
 from torchao.quantization import (
     float8_weight_only,
     int4_weight_only,
@@ -48,6 +48,15 @@ def get_quantization_functions(
         )
     else:
         base_functions.append(int4_weight_only(group_size=32))
+    if device == "cuda":
+        base_functions.append(
+            int8_dynamic_activation_int4_weight(
+                group_size=None,
+                mapping_type=MappingType.SYMMETRIC,
+                act_mapping_type=MappingType.SYMMETRIC,
+                layout=Int4PackedLayout(),
+            )
+        )

     if do_sparse:
         base_functions.append(

test/test_s8s4_linear_cutlass.py (new file, +80)

import itertools

import pytest
import torch

import torchao
from torchao.ops import s8s4_linear_cutlass
from torchao.quantization.utils import group_quantize_tensor_symmetric
from torchao.utils import compute_max_diff

S8S4_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16]
S8S4_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
S8S4_LINEAR_CUTLASS_SIZE_MNK = [
    (2, 512, 128),
    (3, 2048, 2048),
    (4, 3584, 640),
    (13, 8704, 8576),
    (26, 18944, 1664),
    (67, 6656, 1408),
]
S8S4_LINEAR_CUTLASS_USE_BIAS = [False, True]
S8S4_LINEAR_CUTLASS_TEST_PARAMS = list(
    itertools.product(
        S8S4_LINEAR_CUTLASS_DTYPE,
        S8S4_LINEAR_CUTLASS_BATCH_SIZE,
        S8S4_LINEAR_CUTLASS_SIZE_MNK,
        S8S4_LINEAR_CUTLASS_USE_BIAS,
    )
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize(
    "dtype, batch_size, size_mnk, use_bias", S8S4_LINEAR_CUTLASS_TEST_PARAMS
)
def test_s8s4_linear_cutlass(dtype, batch_size, size_mnk, use_bias):
    size_m, size_n, size_k = size_mnk

    input = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda")
    weight = torch.rand((size_n, size_k), dtype=dtype, device="cuda")
    bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None

    # Symmetric 8-bit quantization of the activations (one group per row).
    input_2d = input.view(-1, input.shape[-1])
    input_2d_s8, input_2d_scales, input_2d_zeros = group_quantize_tensor_symmetric(
        input_2d, 8, size_k, dtype
    )
    assert torch.all(input_2d_zeros == 0)
    input_s8 = input_2d_s8.reshape(input.shape)
    input_scales = input_2d_scales.reshape(input.shape[:-1])

    # Symmetric 4-bit quantization of the weights, then packing of two signed
    # 4-bit values per byte (even-indexed columns in the low nibble,
    # odd-indexed columns in the high nibble).
    weight_s8, weight_scales, weight_zeros = group_quantize_tensor_symmetric(
        weight, 4, size_n, dtype
    )
    assert torch.all(weight_zeros == 0)
    weight_s4 = ((weight_s8[:, 1::2] & 0xF) << 4) | (weight_s8[:, 0::2] & 0xF)

    # If torch.nn.functional.linear(input, weight, bias) were used as the
    # reference, the error would be too large. The calculation below is
    # approximately what the s8s4_linear_cutlass kernel does (except that the
    # matrix multiplication there is performed over integers).
    size_m_2d = input_2d.shape[0]
    output_ref = (
        (input_2d_s8.to(dtype) @ weight_s8.to(dtype).T)
        * input_2d_scales.view(size_m_2d, 1)
        * weight_scales.view(1, size_n)
    )
    if bias is not None:
        output_ref += bias
    output_ref = output_ref.reshape(input.shape[:-1] + (size_n,))

    fn_inputs = (input_s8, input_scales, weight_s4, weight_scales, bias)
    try:
        output = s8s4_linear_cutlass(*fn_inputs)
    except NotImplementedError:
        pytest.xfail("s8s4_linear_cutlass() op not implemented")

    max_diff = compute_max_diff(output, output_ref)
    assert max_diff < 5e-3

third_party/cutlass (submodule added at bf9da7b)

torchao/_models/llama/generate.py (+12, −1)

@@ -405,6 +405,17 @@ def ffn_or_attn_only(mod, fqn):
            256,
        ], f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}"
        quantize_(model, int4_weight_only(group_size=group_size))
+    elif "int8adq-int4w-symm" in quantization:
+        from torchao.dtypes import Int4PackedLayout
+        quantize_(
+            model,
+            int8_dynamic_activation_int4_weight(
+                group_size=None,
+                mapping_type=MappingType.SYMMETRIC,
+                act_mapping_type=MappingType.SYMMETRIC,
+                layout=Int4PackedLayout(),
+            )
+        )
    if "marlin" in quantization:
        if "qqq" in quantization:
            from torchao.dtypes import MarlinQQQLayout
@@ -1004,7 +1015,7 @@ def callback(x):
        help=(
            "Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, "
            + "autoquant-int4, autoquant-gemlite-int4, autoquant-float8, autoquant-sparse, autoquant-all, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, "
-            + "embed-int8wo, marlin_qqq, gemlite-<pack_bitwidth>-<nbits>-<groupsize>"
+            + "embed-int8wo, marlin_qqq, gemlite-<pack_bitwidth>-<nbits>-<groupsize>, int8adq-int4w-symm"
        ),
    )
    parser.add_argument(
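The configuration that the new "int8adq-int4w-symm" branch of generate.py applies can also be applied directly to a model with quantize_(). The sketch below reuses exactly the arguments from the diff above; the toy model and the top-level import paths (torchao.quantization and torchao.quantization.quant_primitives) are assumptions for illustration, not part of the commit.

    import torch
    from torchao.dtypes import Int4PackedLayout
    from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
    from torchao.quantization.quant_primitives import MappingType

    # Illustrative model; any module containing nn.Linear layers is handled the same way.
    model = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).half().cuda()

    # Symmetric dynamic 8-bit activation quantization plus symmetric 4-bit weight
    # quantization with the packed int4 layout backed by s8s4_linear_cutlass().
    quantize_(
        model,
        int8_dynamic_activation_int4_weight(
            group_size=None,
            mapping_type=MappingType.SYMMETRIC,
            act_mapping_type=MappingType.SYMMETRIC,
            layout=Int4PackedLayout(),
        ),
    )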
