|
4 | 4 | # This source code is licensed under the license found in the
|
5 | 5 | # LICENSE file in the root directory of this source tree.
|
6 | 6 |
|
| 7 | +from torchao.kernel import ( |
| 8 | + int_scaled_matmul, |
| 9 | + safe_int_mm, |
| 10 | +) |
| 11 | + |
7 | 12 | from .autoquant import (
|
8 | 13 | DEFAULT_AUTOQUANT_CLASS_LIST,
|
9 | 14 | DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
|
|
51 | 56 | )
|
52 | 57 | from .quant_primitives import (
|
53 | 58 | MappingType,
|
| 59 | + TorchAODType, |
54 | 60 | ZeroPointDomain,
|
55 | 61 | choose_qparams_affine,
|
| 62 | + choose_qparams_affine_floatx, |
| 63 | + choose_qparams_affine_with_min_max, |
| 64 | + choose_qparams_and_quantize_affine_hqq, |
56 | 65 | dequantize_affine,
|
| 66 | + dequantize_affine_floatx, |
| 67 | + fake_quantize_affine, |
| 68 | + fake_quantize_affine_cachemask, |
57 | 69 | quantize_affine,
|
| 70 | + quantize_affine_floatx, |
58 | 71 | )
|
59 | 72 | from .smoothquant import (
|
60 | 73 | SmoothFakeDynamicallyQuantizedLinear,
|
|
72 | 85 | from .weight_only import WeightOnlyInt8QuantLinear
|
73 | 86 |
|
# Public API of this package. Names are grouped by stability/role:
# top-level entry points first, then internal building blocks, then the
# raw quant-primitive ops, kernels, and supporting types/quantizers.
__all__ = [
    # top level API - auto
    "autoquant",
    "DEFAULT_AUTOQUANT_CLASS_LIST",
    "DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
    "OTHER_AUTOQUANT_CLASS_LIST",
    # top level API - manual
    "quantize_",
    "int8_dynamic_activation_int4_weight",
    "int8_dynamic_activation_int8_weight",
    "int8_dynamic_activation_int8_semi_sparse_weight",
    "int4_weight_only",
    "int8_weight_only",
    "float8_weight_only",
    "float8_dynamic_activation_float8_weight",
    "float8_static_activation_float8_weight",
    "uintx_weight_only",
    "fpx_weight_only",
    # smooth quant - subject to change
    "swap_conv2d_1x1_to_linear",
    "get_scale",
    "SmoothFakeDynQuantMixin",
    "SmoothFakeDynamicallyQuantizedLinear",
    "swap_linear_with_smooth_fq_linear",
    "smooth_fq_linear_to_inference",
    "set_smooth_fq_attribute",
    "compute_error",
    # building blocks
    "to_linear_activation_quantized",
    "to_weight_tensor_with_linear_activation_scale_metadata",
    "AffineQuantizedMinMaxObserver",
    "AffineQuantizedObserverBase",
    # quant primitive ops
    "choose_qparams_affine",
    "choose_qparams_affine_with_min_max",
    "choose_qparams_affine_floatx",
    "quantize_affine",
    "quantize_affine_floatx",
    "dequantize_affine",
    "dequantize_affine_floatx",
    "choose_qparams_and_quantize_affine_hqq",
    "fake_quantize_affine",
    "fake_quantize_affine_cachemask",
    # operators/kernels
    "safe_int_mm",
    "int_scaled_matmul",
    # dataclasses and types
    "MappingType",
    "ZeroPointDomain",
    "TorchAODType",
    "PerTensor",
    "PerAxis",
    "PerGroup",
    "PerRow",
    "PerToken",
    "LinearActivationQuantizedTensor",
    "Int4WeightOnlyGPTQQuantizer",
    "Int4WeightOnlyQuantizer",
    "Int8DynActInt4WeightGPTQQuantizer",
    "Int8DynActInt4WeightQuantizer",
    "Int8DynActInt4WeightLinear",
    "WeightOnlyInt8QuantLinear",
    "TwoStepQuantizer",
    "Quantizer",
]
|
0 commit comments