Commit cbc74ee

Add AffineQuantizedTensor based workflow doc and examples (#277)
1 parent 6dd63b8 commit cbc74ee

1 file changed: +111 -0 lines changed

torchao/quantization/README.md

@@ -164,6 +164,117 @@ model = torch.compile(model, mode='max-autotune')
## Affine Quantization

Affine quantization refers to quantization that maps floating point numbers to quantized numbers (typically integers) with an affine transformation, i.e. `quantized_val = float_val / scale + zero_point`, where `scale` and `zero_point` are quantization parameters computed from the data at some granularity (per tensor, per channel, per group, etc.).
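
To make the transformation concrete, here is a minimal, self-contained sketch of symmetric per-tensor quantization to int8 in plain PyTorch. The `toy_quantize`/`toy_dequantize` helpers are for illustration only and are not torchao APIs:

```python
import torch

def toy_quantize(x: torch.Tensor, scale: float, zero_point: int) -> torch.Tensor:
    # quantized_val = round(float_val / scale) + zero_point, clamped to the int8 range
    return torch.clamp(torch.round(x / scale) + zero_point, -128, 127).to(torch.int8)

def toy_dequantize(q: torch.Tensor, scale: float, zero_point: int) -> torch.Tensor:
    # float_val ~= (quantized_val - zero_point) * scale
    return (q.to(torch.float32) - zero_point) * scale

x = torch.randn(4, 4)
scale = x.abs().max().item() / 127  # one scale for the whole tensor (per-tensor granularity)
q = toy_quantize(x, scale, zero_point=0)
x_hat = toy_dequantize(q, scale, zero_point=0)
print((x - x_hat).abs().max())  # small reconstruction error
```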
### Quantization Primitives

We used to have different quantize and dequantize operators for quantization with different granularities. In the end these can all be expressed with a `block_size` argument, so we unified the existing quantization primitives into `choose_qparams_affine`, `quantize_affine` and `dequantize_affine`, which can represent symmetric/asymmetric per tensor/channel/token/channel_group quantization. These primitives are used to implement the unified quantized tensor subclass, and a sketch of how `block_size` encodes granularity follows.
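
The sketch below illustrates how different `block_size` settings map to familiar granularities; the exact keyword arguments and defaults of these primitives are assumptions here and may differ from the current torchao signatures:

```python
import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
    dequantize_affine,
)

weight = torch.randn(128, 64)

# block_size selects the granularity: one (scale, zero_point) pair per block of elements
per_tensor_block_size = weight.shape           # (128, 64): one pair for the whole tensor
per_channel_block_size = (1, weight.shape[1])  # (1, 64): one pair per row (output channel)
groupwise_block_size = (1, 32)                 # (1, 32): one pair per group of 32 elements in a row

scale, zero_point = choose_qparams_affine(
    weight, MappingType.SYMMETRIC, per_channel_block_size, torch.int8,
    eps=torch.finfo(torch.float32).eps, zero_point_dtype=torch.int64,
)
q = quantize_affine(weight, per_channel_block_size, scale, zero_point, torch.int8)
w_hat = dequantize_affine(q, per_channel_block_size, scale, zero_point, torch.int8,
                          output_dtype=weight.dtype)
```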
### Quantized Tensor Subclass

We also have a unified quantized tensor subclass that implements how to construct a quantized tensor from a floating point tensor and what it means to call linear ops, e.g. `F.linear` and `aten.addmm`, on an instance of the tensor. With this we can dispatch to different operators (e.g. the `int4mm` op) based on device (cpu, cuda), quantization settings (`int4`, `int8`) and packing format (e.g. a format optimized for the cpu int4 mm kernel).
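
To show the dispatch idea, here is a hypothetical, minimal tensor subclass that intercepts `F.linear` via `__torch_function__`. This is not the actual `AffineQuantizedTensor` implementation; where this sketch just dequantizes and falls back to the regular kernel, the real subclass would pick an optimized quantized mm kernel based on device, dtype and packing format:

```python
import torch
import torch.nn.functional as F

class ToyQuantizedTensor(torch.Tensor):
    """Toy wrapper subclass storing int8 data plus a per-channel scale."""

    @staticmethod
    def __new__(cls, int_data, scale, shape):
        # wrapper subclass: the outer tensor carries metadata only, no storage of its own
        return torch.Tensor._make_wrapper_subclass(cls, shape, dtype=scale.dtype, device=int_data.device)

    def __init__(self, int_data, scale, shape):
        self.int_data = int_data
        self.scale = scale

    @classmethod
    def from_float(cls, w):
        scale = w.abs().amax(dim=1, keepdim=True) / 127  # symmetric per-channel scale
        int_data = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
        return cls(int_data, scale, w.shape)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            x, w = args[0], args[1]
            bias = args[2] if len(args) > 2 else None
            # a real implementation would dispatch to an int8/int4 mm kernel here;
            # this sketch just dequantizes the weight and falls back to F.linear
            return F.linear(x, w.int_data.to(x.dtype) * w.scale, bias)
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

w = torch.randn(32, 64)
x = torch.randn(8, 64)
qw = ToyQuantizedTensor.from_float(w)
print(F.linear(x, qw).shape)  # torch.Size([8, 32]), routed through __torch_function__
```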
### Quantization Flow

With the primitives and the tensor subclass in place, what we need to do afterwards is roughly the following:

```python
import torch
from torch import nn

from torchao.dtypes.aqt import to_aq
from torchao.quantization.quant_primitives import MappingType

def apply_int8wo_quant(weight):
    mapping_type = MappingType.SYMMETRIC
    target_dtype = torch.int8
    eps = torch.finfo(torch.float32).eps
    zero_point_dtype = torch.int64
    block_size = (1, weight.shape[1])  # per-channel quantization
    return to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)

for n, m in model.named_modules():
    if isinstance(m, torch.nn.Linear):
        # optional filtering by module name, shape, etc.
        m.weight = nn.Parameter(apply_int8wo_quant(m.weight))
        # note: quantization for activations needs to be applied after the weight quantization
        # activation quantization (needed by dynamic quantization):
        # input_quant_func = apply_int8wo_quant  # specify how the input activation is quantized
        # m.weight = nn.Parameter(to_laq(m.weight, input_quant_func))
```

The model/tensor subclass should also be compatible with AOTI and torch.export. Currently we can support `torch.export.export` and `torch._export.aot_compile` with the following workaround:
```python
from torchao.quantization.utils import unwrap_tensor_subclass

m_unwrapped = unwrap_tensor_subclass(m)

# export
m = torch.export.export(m_unwrapped, example_inputs).module()

# aot_compile
torch._export.aot_compile(m_unwrapped, example_inputs)
```

We expect this to be integrated into the export path by default in the future.
### Example

Let's use int4 weight only quantization that targets the tinygemm int4 weight only quantized matmul kernel as an example:
```python
import torch
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
from torchao.dtypes import to_aq
from torch._inductor.runtime.runtime_utils import do_bench_gpu
import copy
from torchao.quantization.quant_api import (
    quantize,
    get_apply_int4wo_quant,
)

class ToyLinearModel(torch.nn.Module):
    def __init__(self, m=64, n=32, k=64):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)

    def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"):
        return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x

dtype = torch.bfloat16
m = ToyLinearModel(1024, 1024, 1024).eval().to(dtype).to("cuda")
m_bf16 = copy.deepcopy(m)
example_inputs = m.example_inputs(dtype=dtype, device="cuda")

m_bf16 = torch.compile(m_bf16, mode='max-autotune')

# apply int4 weight only quant (compatible with the tinygemm int4 weight only quant mm kernel in torchao)
groupsize = 32
m = quantize(m, get_apply_int4wo_quant(groupsize=groupsize))

torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True

# temporary workaround for tensor subclass + torch.compile
from torchao.quantization.utils import unwrap_tensor_subclass
m = unwrap_tensor_subclass(m)
# compile the model to improve performance
m = torch.compile(m, mode='max-autotune')

# benchmark to see the speedup
from torchao.utils import benchmark_model

num_runs = 100
torch._dynamo.reset()
bf16_time = benchmark_model(m_bf16, num_runs, example_inputs[0])
print(f"bf16 mean time: {bf16_time}")
int4_time = benchmark_model(m, num_runs, example_inputs[0])
print(f"int4 weight only quantized mean time: {int4_time}")
print(f"speedup: {bf16_time / int4_time}")

# output on a 1xA100 GPU machine:
# bf16 mean time: 71.457685546875
# int4 weight only quantized mean time: 31.4580908203125
# speedup: 2.2715200981216173
```

## Notes