From 3bfa2320659f14077584d112a9009040f8b9cba1 Mon Sep 17 00:00:00 2001
From: Jambay Kinley
Date: Sat, 22 Feb 2025 01:30:58 +0000
Subject: [PATCH] matmul4quantizer clean up

---
 olive/passes/onnx/quantization.py     | 241 ++++++++----------
 .../passes/onnx/test_quantization.py  |  55 ++--
 2 files changed, 121 insertions(+), 175 deletions(-)

diff --git a/olive/passes/onnx/quantization.py b/olive/passes/onnx/quantization.py
index 09b9c0dc0..6e2ee29a6 100644
--- a/olive/passes/onnx/quantization.py
+++ b/olive/passes/onnx/quantization.py
@@ -2,10 +2,10 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------
+import inspect
 import logging
 import tempfile
 from copy import deepcopy
-from functools import partial
 from pathlib import Path
 from typing import Callable, Dict, List, Type, Union
@@ -14,7 +14,7 @@
 from olive.common.config_utils import validate_config
 from olive.common.pydantic_v1 import validator
-from olive.common.utils import exclude_keys, hash_string
+from olive.common.utils import IntEnumBase, StrEnumBase, exclude_keys, hash_string
 from olive.data.config import DataConfig
 from olive.exception import OlivePassError
 from olive.hardware.accelerator import AcceleratorSpec
@@ -462,8 +462,6 @@ def _run_for_config(
         # get the dataloader
         dataloader = get_calibration_dataloader(config)
         if config.prepare_qnn_config:
-            import inspect
-
             from onnxruntime.quantization.execution_providers.qnn import get_qnn_qdq_config
 
             symmetric_options, qnn_extra_options = {}, {}
@@ -613,6 +611,19 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon
 class OnnxMatMul4Quantizer(Pass):
     """Quantize ONNX models' MatMul operations to 4-bit weights."""
 
+    class AccuracyLevel(IntEnumBase):
+        unset = 0
+        fp32 = 1
+        fp16 = 2
+        bf16 = 3
+        int8 = 4
+
+    class Algorithm(StrEnumBase):
+        DEFAULT = "DEFAULT"
+        HQQ = "HQQ"
+        RTN = "RTN"
+        GPTQ = "GPTQ"
+
     @classmethod
     def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]:
         return {
@@ -631,74 +642,50 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon
                 default_value=None,
                 description="List of node names to exclude from quantization.",
             ),
+            "nodes_to_include": PassConfigParam(
+                type_=list,
+                default_value=None,
+                description="List of node names to include in quantization.",
+            ),
+            "op_types_to_quantize": PassConfigParam(
+                type_=list,
+                default_value=None,
+                description=(
+                    'List of operator types to quantize. Default value is None = ["MatMul"]. Supported op types'
+                    " are: MatMul, Gather."
+                ),
+            ),
+            "quant_axes": PassConfigParam(
+                type_=Dict[str, int],
+                default_value=None,
+                description='Mapping from op type to the axis to quantize along. Default is None = {"MatMul": 0, "Gather": 1}.',
+            ),
             "accuracy_level": PassConfigParam(
-                # TODO(trajep): to make it searchable
-                type_=int,
+                type_=OnnxMatMul4Quantizer.AccuracyLevel,
                 default_value=None,
                 description=(
-                    "Available from onnxruntime>=1.17.0 "
-                    "The minimum accuracy level of input A, can be: 0(unset), 1(fp32), 2(fp16), 3(bf16), "
-                    "or 4(int8) (default unset when 0 or None). It is used to control how input A is quantized or"
-                    " downcast "
-                    "internally while doing computation, for example: 0 means input A will not be quantized "
-                    "or downcast while doing computation. 4 means input A can be quantized with the same "
-                    "block_size to int8 internally from type T1. "
-                    "Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details "
-                    "(https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits)."
+                    "Accuracy level of the 4-bit quantized MatMul computation. Refer to the MatMulNBits contrib op's"
+                    " 'accuracy_level' attribute for details"
+                    " (https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits)."
                 ),
             ),
             "algorithm": PassConfigParam(
-                type_=str,
+                type_=OnnxMatMul4Quantizer.Algorithm,
                 default_value=None,
                 description=(
-                    "If 'None', the Matmul node with fp32 const weight will be quantize to int4."
-                    "1. 'RTN' and 'GPTQ' are available from onnxruntime>=1.17.0 "
-                    "- For 4b quantize a model with RTN or GPTQ algorithm. Please refer to "
-                    "https://github.com/intel/neural-compressor/blob/master/docs/source/quantization_weight_only.md "
-                    "for more details on weight only quantization using Intel® Neural Compressor. "
-                    "2. 'DEFAULT', 'HQQ' are available from onnxruntime>=1.18.0 "
-                    "- `DEFAULT` takes the same effect as `None`"
-                    "- For HQQ, please refer to onnxruntime for more details: "
-                    "https://github.com/microsoft/onnxruntime/blob/7e613ee821405b1192d0b71b9434a4f94643f1e4/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py#L102C1-L126C25"
+                    "The algorithm used to quantize weights. If None, the default algorithm is used with a quant"
+                    " config created from the pass configuration."
                 ),
             ),
             "weight_only_quant_configs": PassConfigParam(
                 type_=dict,
                 default_value=None,
-                description="""Available from onnxruntime>=1.17.0, if None, the default behavior
-                    of given algorithm will be used.
-                    The config is binding to the algorithm with following map:
-                    1. "algorithm" is "DEFAULT", by default, the weight_only_quant_configs is:
-                        "weight_only_quant_configs": {
-                            "block_size": 128,
-                            "is_symmetric": False,
-                            "accuracy_level": None
-                        }
-                        https://github.com/microsoft/onnxruntime/blob/7e613ee821405b1192d0b71b9434a4f94643f1e4/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py#L129C1-L140C45
-                    2. "algorithm" is "HQQ", by default, the weight_only_quant_configs is:
-                        "weight_only_quant_configs": {
-                            "block_size": 128, // channel number in one block to execute a GPTQ quantization iteration.
-                            "bits": 4, // how many bits to represent weight.
-                            "axis": 1, // 0 or 1. which axis to quantize. https://arxiv.org/pdf/2309.15531.pdf
-                        }
-                        https://github.com/microsoft/onnxruntime/blob/7e613ee821405b1192d0b71b9434a4f94643f1e4/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py#L129C1-L140C45
-                    3. "algorithm" is "RTN", by default, the weight_only_quant_configs is:
-                        "weight_only_quant_configs": {
-                            "ratios": None, // type: dict, percentile of clip. Defaults to None.
-                        }
-                        https://github.com/microsoft/onnxruntime/blob/7e613ee821405b1192d0b71b9434a4f94643f1e4/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py#L42C1-L60C29
-                    4. "algorithm" is "GPTQ", by default, the weight_only_quant_configs is:
-                        "weight_only_quant_configs": {
-                            "percdamp": 0.01, // percent of the average Hessian diagonal to use for dampening.
-                            "block_size": 128,
-                            "actorder": False, // whether rearrange Hessian matrix considering the diag's value.
-                            "mse": False, // whether get scale and zero point with mse error.
-                            "perchannel": True, // whether quantize weight per-channel.
-                        }
-                    For GPTQ's "calibration_data_reader", you can provider a dataloader function or a
-                    data config like what we do for onnx static quantization.
-                    https://github.com/microsoft/onnxruntime/blob/7e613ee821405b1192d0b71b9434a4f94643f1e4/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py#L63C1-L99C37
-                """,
+                description=(
+                    "If 'algorithm' is provided and this is None, the config is constructed from the pass"
+                    " configuration. If provided, it takes precedence. Refer to"
+                    " https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py"
+                    " for details."
+                ),
             ),
             **get_external_data_config(),
             # static_dataloder_config
             **get_calibration_dataloader_config(),
         }
 
@@ -708,8 +695,6 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon
     @classmethod
     def _validators(cls) -> Dict[str, Callable]:
         return {
-            "validate_accuracy_level": validator("accuracy_level", allow_reuse=True)(_validate_accuracy_level),
-            "validate_algorithm": validator("algorithm", allow_reuse=True)(_validate_algorithm),
             "validate_quant_config": validator("weight_only_quant_configs", allow_reuse=True)(
                 _validate_weight_only_quant_config
             ),
@@ -718,6 +703,26 @@ def _run_for_config(
     def _run_for_config(
         self, model: ONNXModelHandler, config: Type[BasePassConfig], output_model_path: str
     ) -> ONNXModelHandler:
+        from onnxruntime import __version__ as OrtVersion
+
+        if version.parse(OrtVersion) < version.parse("1.18.0"):
+            raise ValueError("MatMul4BitsQuantizer is only supported for onnxruntime>=1.18.0")
+
+        from onnxruntime.quantization.matmul_4bits_quantizer import (
+            DefaultWeightOnlyQuantConfig,
+            GPTQWeightOnlyQuantConfig,
+            HQQWeightOnlyQuantConfig,
+            MatMul4BitsQuantizer,
+            RTNWeightOnlyQuantConfig,
+        )
+
+        algo_to_config = {
+            "DEFAULT": DefaultWeightOnlyQuantConfig,
+            "HQQ": HQQWeightOnlyQuantConfig,
+            "RTN": RTNWeightOnlyQuantConfig,
+            "GPTQ": GPTQWeightOnlyQuantConfig,
+        }
+
@@ -726,65 +731,44 @@ def _run_for_config(
         if model_has_adapters(model.model_path) and config.algorithm not in {None, "DEFAULT"}:
             logger.info(
                 "Model has adapters which should only be quantized with algorithm=None or DEFAULT. Got %s. Returning"
Returning" @@ -726,65 +731,44 @@ def _run_for_config( ) return model - from onnxruntime import __version__ as OrtVersion - - if version.parse(OrtVersion) < version.parse("1.16.2"): - raise ValueError("MatMul4BitsQuantizer is only supported in onnxruntime >= 1.16.2") - - from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer - output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name) - weight_only_quant_config_class = None - weight_only_quant_config = None - algo_config = deepcopy(config.weight_only_quant_configs) or {} - if version.parse(OrtVersion) >= version.parse("1.17.0"): - from onnxruntime.quantization.matmul_4bits_quantizer import ( - GPTQWeightOnlyQuantConfig, - RTNWeightOnlyQuantConfig, - ) - - if config.algorithm == "RTN": - weight_only_quant_config_class = RTNWeightOnlyQuantConfig - elif config.algorithm == "GPTQ": - if "block_size" in algo_config and version.parse(OrtVersion) < version.parse("1.18.0"): - # ort 1.17.0+ uses blocksize instead of block_size :( - algo_config["blocksize"] = algo_config["block_size"] - algo_config.pop("block_size") - dataloader = get_calibration_dataloader(config) - weight_only_quant_config_class = partial(GPTQWeightOnlyQuantConfig, calibration_data_reader=dataloader) - - if version.parse(OrtVersion) >= version.parse("1.18.0"): - from onnxruntime.quantization.matmul_4bits_quantizer import ( - DefaultWeightOnlyQuantConfig, - HQQWeightOnlyQuantConfig, - ) - - if config.algorithm == "DEFAULT": - weight_only_quant_config_class = DefaultWeightOnlyQuantConfig - elif config.algorithm == "HQQ": - weight_only_quant_config_class = HQQWeightOnlyQuantConfig - elif config.algorithm in ("HQQ", "DEFAULT"): - raise ValueError("HQQ and DEFAULT algorithm are only supported in onnxruntime >= 1.18.0") - - if weight_only_quant_config_class: - weight_only_quant_config = weight_only_quant_config_class(**algo_config) - quant = MatMul4BitsQuantizer( - model.load_model(), - block_size=config.block_size, - is_symmetric=config.is_symmetric, - nodes_to_exclude=config.nodes_to_exclude, - accuracy_level=config.accuracy_level, - algo_config=weight_only_quant_config, - ) + kwargs = { + "block_size": config.block_size, + "is_symmetric": config.is_symmetric, + "nodes_to_exclude": config.nodes_to_exclude, + "accuracy_level": config.accuracy_level, + } + for key in ["nodes_to_include", "op_types_to_quantize", "quant_axes"]: + if value := getattr(config, key): + if version.parse(OrtVersion) < version.parse("1.20.0"): + raise ValueError(f"MatMul4BitsQuantizer {key} is only supported for onnxruntime>=1.20.0") + kwargs[key] = value + + if woq_config_class := algo_to_config.get(config.algorithm, None): + algo_config = config.weight_only_quant_configs or {} + for key in inspect.signature(woq_config_class.__init__).parameters: + if key in algo_config: + if key in kwargs and kwargs[key] != algo_config[key]: + # algo config overrides pass config + logger.warning( + "The pass config parameter %s's value %s is different from the algorithm config's value %s." 
+ " The algorithm config's value will be used.", + key, + kwargs[key], + algo_config[key], + ) + kwargs[key] = algo_config[key] + elif key in kwargs: + # get value from pass config + algo_config[key] = kwargs[key] + if config.algorithm == "GPTQ": + algo_config["calibration_data_reader"] = get_calibration_dataloader(config) + kwargs["algo_config"] = woq_config_class(**algo_config) else: - # TODO(trajep): remove this block once we migrate customer to onnxruntime>=1.17.0 all - quant = MatMul4BitsQuantizer( - model.load_model(), - block_size=config.block_size, - is_symmetric=config.is_symmetric, - nodes_to_exclude=config.nodes_to_exclude, - ) + kwargs["algo_config"] = None + + quant = MatMul4BitsQuantizer(model.load_model(), **kwargs) quant.process() # topologically sort the graph at the end since previous optimizations may have broken it quant.model.topological_sort() @@ -794,26 +778,6 @@ def _run_for_config( return model_proto_to_olive_model(quant.model.model, output_model_path, config) -def _validate_accuracy_level(v, values, field): - if not v: - return v - - if v not in (0, 1, 2, 3, 4): - raise ValueError(f"OnnxMatMul4Quantizer {field.name} must be 0(unset), 1(fp32), 2(fp16), 3(bf16) or 4(int8)") - - return v - - -def _validate_algorithm(v, values, field): - if not v: - return v - - if v not in ("DEFAULT", "HQQ", "RTN", "GPTQ"): - raise ValueError(f"OnnxMatMul4Quantizer {field.name} must be 'DEFAULT', 'HQQ', 'RTN', 'GPTQ'") - - return v - - def _validate_weight_only_quant_config(v, values, field): if values.get("algorithm") is None: logger.debug("algorithm is not set, skip validation for weight_only_quant_configs") @@ -831,6 +795,9 @@ def _validate_weight_only_quant_config(v, values, field): default_config_keys = ["block_size", "bits", "axis"] elif values["algorithm"] == "GPTQ": default_config_keys = ["percdamp", "block_size", "actorder", "mse", "perchannel"] + else: + raise ValueError(f"Unsupported algorithm: {values['algorithm']}") + default_config_keys.extend(["op_types_to_quantize", "quant_axes"]) if not all(key in default_config_keys for key in config_keys): invalid_config_keys = set(config_keys) - set(default_config_keys) diff --git a/test/unit_test/passes/onnx/test_quantization.py b/test/unit_test/passes/onnx/test_quantization.py index 2825ea012..b986f7990 100644 --- a/test/unit_test/passes/onnx/test_quantization.py +++ b/test/unit_test/passes/onnx/test_quantization.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +import logging from test.unit_test.utils import get_onnx_model, get_pytorch_model_dummy_input import pytest @@ -14,6 +19,7 @@ class DummyCalibrationDataReader(CalibrationDataReader): + # pylint: disable=W0223 def __init__(self, batch_size: int = 16): super().__init__() self.sample_counter = 500 @@ -112,48 +118,16 @@ def test_qnn_quantization(tmp_path): assert out is not None -@pytest.mark.skipif( - version.parse(OrtVersion) < version.parse("1.16.2"), - reason="matmul 4bit quantization is only supported in onnxruntime>=1.16.2", -) @pytest.mark.parametrize( ("algorithm", "weight_only_quant_configs"), [ (None, None), ("RTN", {"ratios": {}}), - ], -) -def test_matmul_4bit_quantization_without_dataloader(tmp_path, algorithm, weight_only_quant_configs): - input_model = get_onnx_model() - config = { - "block_size": 32, - "is_symmetric": True, - "nodes_to_exclude": [], - "accuracy_level": 4, - "algorithm": algorithm, - "weight_only_quant_configs": weight_only_quant_configs, - } - accelerator_spec = AcceleratorSpec( - accelerator_type="CPU", - execution_provider="CPUExecutionProvider", - ) - p = create_pass_from_dict(OnnxMatMul4Quantizer, config, disable_search=True, accelerator_spec=accelerator_spec) - out = p.run(input_model, tmp_path) - assert out is not None - - -@pytest.mark.skipif( - version.parse(OrtVersion) < version.parse("1.18.0"), - reason="matmul 4bit quantization with `DEFAULT` and `HQQ` is only supported in onnxruntime<1.18.0", -) -@pytest.mark.parametrize( - ("algorithm", "weight_only_quant_configs"), - [ ("DEFAULT", None), ("HQQ", None), ], ) -def test_matmul_4bit_quantization_without_dataloader_ort_1_18(tmp_path, algorithm, weight_only_quant_configs): +def test_matmul_4bit_quantization_without_dataloader(tmp_path, algorithm, weight_only_quant_configs): input_model = get_onnx_model() config = { "block_size": 32, @@ -172,7 +146,7 @@ def test_matmul_4bit_quantization_without_dataloader_ort_1_18(tmp_path, algorith assert out is not None -def test_matmul_gptq_with_dataloader(tmp_path): +def test_matmul_4bits_gptq_with_dataloader(tmp_path, caplog): input_model = get_onnx_model() config = { "block_size": 32, @@ -191,15 +165,20 @@ def test_matmul_gptq_with_dataloader(tmp_path): accelerator_type="CPU", execution_provider="CPUExecutionProvider", ) + # capture log + logger = logging.getLogger("olive") + logger.propagate = True + p = create_pass_from_dict(OnnxMatMul4Quantizer, config, disable_search=True, accelerator_spec=accelerator_spec) out = p.run(input_model, tmp_path) assert out is not None + assert "Invalid weight_only_quant_configs: {'use_less_config'} for algorithm GPTQ" in caplog.text + assert ( + "The pass config parameter block_size's value 32 is different from the algorithm config's value 128. The" + " algorithm config's value will be used." in caplog.text + ) -@pytest.mark.skipif( - version.parse(OrtVersion) < version.parse("1.16.2"), - reason="matmul 4bit quantization is only supported in onnxruntime>=1.16.2", -) def test_invalid_config_for_matmul_4bits(): config = { "block_size": 32,
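
Usage note (not part of the patch): after this change the pass always requires onnxruntime>=1.18.0, and nodes_to_include, op_types_to_quantize, and quant_axes additionally require onnxruntime>=1.20.0. The sketch below mirrors how the updated tests drive the pass; the import paths are assumed from the repo layout in this patch, and `input_model` / `output_dir` are placeholders for any loaded ONNX model handler and output directory.

    # Minimal sketch, assuming the import paths below match the repo layout.
    from olive.hardware.accelerator import AcceleratorSpec
    from olive.passes.olive_pass import create_pass_from_dict
    from olive.passes.onnx.quantization import OnnxMatMul4Quantizer

    p = create_pass_from_dict(
        OnnxMatMul4Quantizer,
        {
            "block_size": 32,
            "is_symmetric": True,
            "accuracy_level": 4,  # OnnxMatMul4Quantizer.AccuracyLevel.int8
            "algorithm": "HQQ",   # one of DEFAULT / HQQ / RTN / GPTQ, or None
        },
        disable_search=True,
        accelerator_spec=AcceleratorSpec(
            accelerator_type="CPU",
            execution_provider="CPUExecutionProvider",
        ),
    )
    quantized_model = p.run(input_model, output_dir)  # placeholders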
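
The merge semantics in _run_for_config are worth spelling out: for every parameter accepted by the selected quant-config class, a value in weight_only_quant_configs overrides the pass-level value (with the warning asserted in test_matmul_4bits_gptq_with_dataloader), while pass-level values fill in keys the algorithm config leaves unset. A sketch of that conflicting case; `data_config` here is a placeholder for any valid Olive data config, which GPTQ uses to build its calibration_data_reader.

    # Sketch of the conflict case: the algorithm config's block_size=128 wins
    # over the pass-level 32, and the pass logs
    # "The pass config parameter block_size's value 32 is different from the
    #  algorithm config's value 128. The algorithm config's value will be used."
    gptq_config = {
        "block_size": 32,
        "algorithm": "GPTQ",
        "weight_only_quant_configs": {"block_size": 128},
        "data_config": data_config,  # placeholder calibration data config
    }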