Commit 39f16f4
Update torchao api reference and add contributor guide (#1255)
* Update torchao api reference and add contributor guide

  Summary:
  1. Updated the torchao API reference for quantization to include the APIs we want to expose, renamed torchao/quantization/linear_activation_weight_observer.py to linear_activation_weight_observed_tensor.py, and removed safe_int_mm and int_scaled_matmul from quant_primitives.py.
  2. Added #391 to the torchao docs.

  Test Plan: CI

* format
* typo
* renaming
* comma
* format
* comments
1 parent: 333bde6 · commit: 39f16f4

16 files changed: +731 -57 lines

docs/source/api_ref_dtypes.rst (+3 -1)

@@ -12,9 +12,11 @@ torchao.dtypes
 
    to_nf4
    to_affine_quantized_intx
-   to_affine_quantized_floatx
    to_affine_quantized_intx_static
+   to_affine_quantized_floatx
    to_affine_quantized_floatx_static
+   to_affine_quantized_fpx
+   NF4Tensor
    AffineQuantizedTensor
 
 ..

docs/source/api_ref_intro.rst (+3 -6)

@@ -1,16 +1,13 @@
 ``torchao`` API Reference
 =========================
 
-This section introduces the torchao API reference.
-Dive into the details of how torchao integrates with PyTorch to
-optimize your machine learning models.
+This section introduces the torchao API reference. Dive into the details of how torchao integrates with PyTorch to optimize your machine learning models.
 
 .. toctree::
    :glob:
    :maxdepth: 1
    :caption: Python API Reference
 
-   api_ref_sparsity
-   api_ref_quantization
    api_ref_dtypes
-   api_ref_kernel
+   api_ref_quantization
+   api_ref_sparsity

docs/source/api_ref_quantization.rst (+31 -7)

@@ -9,15 +9,39 @@ torchao.quantization
 .. autosummary::
    :toctree: generated/
    :nosignatures:
-
-   SmoothFakeDynQuantMixin
-   SmoothFakeDynamicallyQuantizedLinear
-   swap_linear_with_smooth_fq_linear
-   smooth_fq_linear_to_inference
-   Int4WeightOnlyGPTQQuantizer
-   Int4WeightOnlyQuantizer
+   autoquant
+
    quantize_
    int8_dynamic_activation_int4_weight
    int8_dynamic_activation_int8_weight
    int4_weight_only
    int8_weight_only
+   float8_weight_only
+   float8_dynamic_activation_float8_weight
+   float8_static_activation_float8_weight
+   uintx_weight_only
+   fpx_weight_only
+
+   to_linear_activation_quantized
+
+   swap_linear_with_smooth_fq_linear
+   smooth_fq_linear_to_inference
+
+   choose_qparams_affine
+   choose_qparams_affine_with_min_max
+   choose_qparams_affine_floatx
+   quantize_affine
+   quantize_affine_floatx
+   dequantize_affine
+   dequantize_affine_floatx
+   choose_qparams_and_quantize_affine_hqq
+   fake_quantize_affine
+   fake_quantize_affine_cachemask
+
+   safe_int_mm
+   int_scaled_matmul
+
+   MappingType
+   ZeroPointDomain
+   TorchAODType
+
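
The manual path documented above is driven by quantize_; a minimal sketch of the call pattern (assuming a CUDA, bfloat16 model, which int4_weight_only currently targets):

    import torch
    from torchao.quantization import int4_weight_only, quantize_

    # toy model; quantize_ swaps each Linear weight for an int4 quantized tensor in place
    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)
    quantize_(model, int4_weight_only())
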
docs/source/contributor_guide.rst (+604)

Large diffs are not rendered by default.

docs/source/index.rst (+11 -7)

@@ -1,9 +1,7 @@
 Welcome to the torchao Documentation
 =======================================
 
-**torchao** is an open-source library that provides the functionality
-to quantize and prune your models using native PyTorch. Our documentation is under development
-with more content coming soon.
+`**torchao** <https://github.com/pytorch/ao>`__ is a library for custom data types & optimizations. Quantize and sparsify weights, gradients, optimizers & activations for inference and training using native PyTorch. Please checkout torchao `README <https://github.com/pytorch/ao#torchao-pytorch-architecture-optimization>`__ for an overall introduction to the library and recent highlight and updates. The documentation here will focus on 1. API Reference 2. Developer / Researcher Contribution Guide 3. Tutorials.
 
 ..
 .. grid:: 3
@@ -81,13 +79,19 @@
    :maxdepth: 1
    :caption: API Reference
 
-   api_ref_sparsity
-   api_ref_intro
-   api_ref_quantization
    api_ref_dtypes
+   api_ref_quantization
+   api_ref_sparsity
 ..
    api_ref_kernel
-
+
+.. toctree::
+   :glob:
+   :maxdepth: 1
+   :caption: Contributor Guide
+
+   contributor_guide
+
 .. toctree::
    :glob:
    :maxdepth: 1

test/integration/test_integration.py (+3 -1)

@@ -34,8 +34,10 @@
     change_linear_weights_to_int8_woqtensors,
     change_linear_weights_to_int4_woqtensors,
 )
-from torchao.quantization.quant_primitives import (
+from torchao.quantization import (
     safe_int_mm,
+)
+from torchao.quantization.quant_primitives import (
     choose_qparams_affine,
     quantize_affine,
     dequantize_affine,

torchao/dtypes/affine_quantized_tensor.py (+3 -1)

@@ -21,6 +21,9 @@
     addmm_float8_unwrapped_inference,
     preprocess_data,
 )
+from torchao.kernel import (
+    int_scaled_matmul,
+)
 from torchao.quantization.quant_primitives import (
     FP8_TYPES,
     MappingType,
@@ -31,7 +34,6 @@
     choose_qparams_and_quantize_affine_hqq,
     dequantize_affine,
     dequantize_affine_floatx,
-    int_scaled_matmul,
     quantize_affine,
     quantize_affine_floatx,
 )

torchao/kernel/__init__.py (+7)

@@ -0,0 +1,7 @@
+from torchao.kernel.intmm import int_scaled_matmul
+from torchao.kernel.intmm import safe_int_mm
+
+__all__ = [
+    "safe_int_mm",
+    "int_scaled_matmul",
+]
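
The two kernels re-homed here are int8 matmul helpers. A small sketch of how they can be called; the shapes and the per-row scale layout for int_scaled_matmul are assumptions:

    import torch
    from torchao.kernel import int_scaled_matmul, safe_int_mm

    a = torch.randint(-128, 127, (16, 32), dtype=torch.int8)
    b = torch.randint(-128, 127, (32, 64), dtype=torch.int8)

    c = safe_int_mm(a, b)  # int8 x int8 matmul with int32 accumulation

    scales = torch.rand(16, 1)  # assumed: one scale per row of `a`
    c_scaled = int_scaled_matmul(a, b, scales)  # same matmul with the row scales applied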

torchao/quantization/__init__.py (+55 -25)

@@ -4,6 +4,11 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+from torchao.kernel import (
+    int_scaled_matmul,
+    safe_int_mm,
+)
+
 from .autoquant import (
     DEFAULT_AUTOQUANT_CLASS_LIST,
     DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
@@ -51,10 +56,18 @@
 )
 from .quant_primitives import (
     MappingType,
+    TorchAODType,
     ZeroPointDomain,
     choose_qparams_affine,
+    choose_qparams_affine_floatx,
+    choose_qparams_affine_with_min_max,
+    choose_qparams_and_quantize_affine_hqq,
     dequantize_affine,
+    dequantize_affine_floatx,
+    fake_quantize_affine,
+    fake_quantize_affine_cachemask,
     quantize_affine,
+    quantize_affine_floatx,
 )
 from .smoothquant import (
     SmoothFakeDynamicallyQuantizedLinear,
@@ -72,50 +85,67 @@
 from .weight_only import WeightOnlyInt8QuantLinear
 
 __all__ = [
-    "swap_conv2d_1x1_to_linear",
+    # top level API - auto
     "autoquant",
     "DEFAULT_AUTOQUANT_CLASS_LIST",
     "DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
     "OTHER_AUTOQUANT_CLASS_LIST",
-    "get_scale",
-    "SmoothFakeDynQuantMixin",
-    "SmoothFakeDynamicallyQuantizedLinear",
-    "swap_linear_with_smooth_fq_linear",
-    "smooth_fq_linear_to_inference",
-    "set_smooth_fq_attribute",
-    "compute_error",
-    "Int4WeightOnlyGPTQQuantizer",
-    "Int4WeightOnlyQuantizer",
-    "quantize_affine",
-    "dequantize_affine",
-    "choose_qparams_affine",
+    # top level API - manual
     "quantize_",
     "int8_dynamic_activation_int4_weight",
     "int8_dynamic_activation_int8_weight",
     "int8_dynamic_activation_int8_semi_sparse_weight",
     "int4_weight_only",
     "int8_weight_only",
+    "float8_weight_only",
+    "float8_dynamic_activation_float8_weight",
+    "float8_static_activation_float8_weight",
     "uintx_weight_only",
     "fpx_weight_only",
-    "LinearActivationQuantizedTensor",
+    # smooth quant - subject to change
+    "swap_conv2d_1x1_to_linear",
+    "get_scale",
+    "SmoothFakeDynQuantMixin",
+    "SmoothFakeDynamicallyQuantizedLinear",
+    "swap_linear_with_smooth_fq_linear",
+    "smooth_fq_linear_to_inference",
+    "set_smooth_fq_attribute",
+    "compute_error",
+    # building blocks
     "to_linear_activation_quantized",
     "to_weight_tensor_with_linear_activation_scale_metadata",
-    "float8_weight_only",
-    "float8_dynamic_activation_float8_weight",
-    "float8_static_activation_float8_weight",
-    "Int8DynActInt4WeightGPTQQuantizer",
-    "Int8DynActInt4WeightQuantizer",
-    "Int8DynActInt4WeightLinear",
-    "WeightOnlyInt8QuantLinear",
-    "TwoStepQuantizer",
-    "Quantizer",
-    "ZeroPointDomain",
-    "MappingType",
     "AffineQuantizedMinMaxObserver",
     "AffineQuantizedObserverBase",
+    # quant primitive ops
+    "choose_qparams_affine",
+    "choose_qparams_affine_with_min_max",
+    "choose_qparams_affine_floatx",
+    "quantize_affine",
+    "quantize_affine_floatx",
+    "dequantize_affine",
+    "dequantize_affine_floatx",
+    "choose_qparams_and_quantize_affine_hqq",
+    "fake_quantize_affine",
+    "fake_quantize_affine_cachemask",
+    # operators/kernels
+    "safe_int_mm",
+    "int_scaled_matmul",
+    # dataclasses and types
+    "MappingType",
+    "ZeroPointDomain",
+    "TorchAODType",
     "PerTensor",
     "PerAxis",
     "PerGroup",
     "PerRow",
     "PerToken",
+    "LinearActivationQuantizedTensor",
+    "Int4WeightOnlyGPTQQuantizer",
+    "Int4WeightOnlyQuantizer",
+    "Int8DynActInt4WeightGPTQQuantizer",
+    "Int8DynActInt4WeightQuantizer",
+    "Int8DynActInt4WeightLinear",
+    "WeightOnlyInt8QuantLinear",
+    "TwoStepQuantizer",
+    "Quantizer",
 ]
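
After this regrouping, every name in __all__ still resolves from the top-level package; a sketch with one import drawn from each category:

    from torchao.quantization import (
        autoquant,  # top level API - auto
        quantize_,  # top level API - manual
        to_linear_activation_quantized,  # building blocks
        choose_qparams_affine,  # quant primitive ops
        safe_int_mm,  # kernels, re-exported from torchao.kernel
        MappingType,  # dataclasses and types
    )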

torchao/quantization/autoquant.py (+1 -1)

@@ -10,6 +10,7 @@
     TensorCoreTiledLayout,
 )
 from torchao.float8.inference import Float8MMConfig
+from torchao.kernel import safe_int_mm
 from torchao.quantization.linear_activation_quantized_tensor import (
     LinearActivationQuantizedTensor,
 )
@@ -24,7 +25,6 @@
     PerRow,
     PerTensor,
 )
-from .quant_primitives import safe_int_mm
 from .subclass import ( # noqa
     Int8DynamicallyQuantizedLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,
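
With safe_int_mm now sourced from torchao.kernel, autoquant itself is unchanged; it is typically composed with torch.compile, as in this minimal sketch (assuming a CUDA model):

    import torch
    import torchao

    model = torch.nn.Sequential(torch.nn.Linear(64, 64)).cuda()
    model = torchao.autoquant(torch.compile(model, mode="max-autotune"))
    model(torch.randn(8, 64, device="cuda"))  # first call benchmarks candidate kernels and picks the best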

torchao/quantization/linear_activation_quantized_tensor.py (+2 -2)

@@ -19,7 +19,7 @@
 class LinearActivationQuantizedTensor(TorchAOBaseTensor):
     """
     Applies activation quantization for linear operator, this is used to support
-    dynamic quantization or static quantization, user can pass in a `input_quant_func`
+    dynamic quantization, user can pass in a `input_quant_func`
     that is used to quantize the activation
 
     Args:
@@ -60,7 +60,7 @@ def __init__(
         self.quant_kwargs = quant_kwargs
 
     def __repr__(self):
-        return f"LinearActivationQuantizedTensor({self.original_weight_tensor}, {self.input_quant_func}, quant_kwargs={self.quant_kwargs}))"
+        return f"{self.__class__.__name__}({self.original_weight_tensor}, {self.input_quant_func}, quant_kwargs={self.quant_kwargs}))"
 
     def __tensor_flatten__(self):
         return ["original_weight_tensor"], [self.input_quant_func, self.quant_kwargs]
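
The docstring edit narrows this class to dynamic quantization; the static path is handled by the tensor in weight_tensor_linear_activation_quantization.py (diffed below). A rough sketch of how an input_quant_func plugs in; the per-row int8 choice and 2-D activations are illustrative assumptions, not the library default:

    import torch
    from torchao.dtypes import to_affine_quantized_intx
    from torchao.quantization import MappingType, to_linear_activation_quantized

    def int8_act_quant(x: torch.Tensor):
        # quantize each row of the incoming 2-D activation to int8 on the fly
        return to_affine_quantized_intx(x, MappingType.SYMMETRIC, (1, x.shape[-1]), torch.int8)

    linear = torch.nn.Linear(128, 256)
    linear.weight = torch.nn.Parameter(
        to_linear_activation_quantized(linear.weight, int8_act_quant),
        requires_grad=False,
    )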

torchao/quantization/quant_api.py (+1 -1)

@@ -38,7 +38,7 @@
 )
 from torchao.dtypes.uintx.uintx import UintxLayout
 from torchao.float8.inference import Float8MMConfig
-from torchao.quantization.linear_activation_weight_observer import (
+from torchao.quantization.linear_activation_weight_observed_tensor import (
     LinearActivationWeightObservedTensor,
 )
 from torchao.quantization.observer import AffineQuantizedObserverBase, get_block_size
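
For downstream code, the module rename means the same one-line import update made here:

    # before this commit
    from torchao.quantization.linear_activation_weight_observer import (
        LinearActivationWeightObservedTensor,
    )

    # after this commit
    from torchao.quantization.linear_activation_weight_observed_tensor import (
        LinearActivationWeightObservedTensor,
    )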

torchao/quantization/quant_primitives.py (+3 -3)

@@ -10,7 +10,6 @@
 
 import torch
 
-from torchao.kernel.intmm import int_scaled_matmul, safe_int_mm
 from torchao.prototype.custom_fp_utils import (
     _f32_to_floatx_unpacked,
     _floatx_unpacked_to_f32,
@@ -24,8 +23,6 @@
 )
 
 __all__ = [
-    "safe_int_mm",
-    "int_scaled_matmul",
     "choose_qparams_affine",
     "choose_qparams_affine_with_min_max",
     "choose_qparams_affine_floatx",
@@ -36,6 +33,9 @@
     "fake_quantize_affine",
     "fake_quantize_affine_cachemask",
     "choose_qparams_and_quantize_affine_hqq",
+    "MappingType",
+    "ZeroPointDomain",
+    "TorchAODType",
 ]
 
 
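The primitives kept in this module compose into the usual round trip; a minimal sketch with per-row int8 parameters (the block size is chosen for illustration):

    import torch
    from torchao.quantization.quant_primitives import (
        MappingType,
        choose_qparams_affine,
        dequantize_affine,
        quantize_affine,
    )

    x = torch.randn(4, 8)
    block_size = (1, 8)  # one (scale, zero_point) pair per row

    scale, zero_point = choose_qparams_affine(x, MappingType.ASYMMETRIC, block_size, torch.int8)
    xq = quantize_affine(x, block_size, scale, zero_point, torch.int8)
    xdq = dequantize_affine(xq, block_size, scale, zero_point, torch.int8)  # approximates x
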
torchao/quantization/utils.py (+3 -1)

@@ -9,12 +9,14 @@
 import torch
 from torch.utils._python_dispatch import TorchDispatchMode
 
+from torchao.kernel import (
+    int_scaled_matmul,
+)
 from torchao.quantization.quant_primitives import (
     MappingType,
     ZeroPointDomain,
     choose_qparams_affine,
     dequantize_affine,
-    int_scaled_matmul,
     quantize_affine,
 )
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

torchao/quantization/weight_tensor_linear_activation_quantization.py (+1 -1)

@@ -70,7 +70,7 @@ def __init__(
         self.quant_kwargs = quant_kwargs
 
     def __repr__(self):
-        return f"LinearActivationQuantizedTensor({self.original_weight_tensor}, {self.input_quant_func_static}, scale={self.scale}, zero_point={self.zero_point}, quant_kwargs={self.quant_kwargs})"
+        return f"{self.__class__.__name__}({self.original_weight_tensor}, {self.input_quant_func_static}, scale={self.scale}, zero_point={self.zero_point}, quant_kwargs={self.quant_kwargs})"
 
     def __tensor_flatten__(self):
         tensor_data = ["original_weight_tensor", "scale"]
