matmul4quantizer clean up
jambayk committed Feb 22, 2025
1 parent 0c7ff92 commit 3bfa232
Showing 2 changed files with 121 additions and 175 deletions.
241 changes: 104 additions & 137 deletions olive/passes/onnx/quantization.py
@@ -2,10 +2,10 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------
+import inspect
 import logging
 import tempfile
 from copy import deepcopy
-from functools import partial
 from pathlib import Path
 from typing import Callable, Dict, List, Type, Union
 
@@ -14,7 +14,7 @@
 
 from olive.common.config_utils import validate_config
 from olive.common.pydantic_v1 import validator
-from olive.common.utils import exclude_keys, hash_string
+from olive.common.utils import IntEnumBase, StrEnumBase, exclude_keys, hash_string
 from olive.data.config import DataConfig
 from olive.exception import OlivePassError
 from olive.hardware.accelerator import AcceleratorSpec
@@ -462,8 +462,6 @@ def _run_for_config(
         # get the dataloader
         dataloader = get_calibration_dataloader(config)
         if config.prepare_qnn_config:
-            import inspect
-
             from onnxruntime.quantization.execution_providers.qnn import get_qnn_qdq_config
 
             symmetric_options, qnn_extra_options = {}, {}
@@ -613,6 +611,19 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon
 class OnnxMatMul4Quantizer(Pass):
     """Quantize ONNX models' MatMul operations to 4-bit weights."""
 
+    class AccuracyLevel(IntEnumBase):
+        unset = 0
+        fp32 = 1
+        fp16 = 2
+        bf16 = 3
+        int8 = 4
+
+    class Algorithm(StrEnumBase):
+        DEFAULT = "DEFAULT"
+        HQQ = "HQQ"
+        RTN = "RTN"
+        GPTQ = "GPTQ"
+
     @classmethod
     def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]:
         return {
@@ -631,74 +642,50 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon
                 default_value=None,
                 description="List of node names to exclude from quantization.",
             ),
+            "nodes_to_include": PassConfigParam(
+                type_=list,
+                default_value=None,
+                description="List of node names to include in quantization.",
+            ),
+            "op_types_to_quantize": PassConfigParam(
+                type_=list,
+                default_value=None,
+                description=(
+                    'List of operator types to quantize. Default value is None = ["MatMul"]. Supported op types'
+                    " are: MatMul, Gather."
+                ),
+            ),
+            "quant_axes": PassConfigParam(
+                type_=Dict[str, int],
+                default_value=None,
+                description='op:axis, which axis to quantize for an op. Default is None = {"MatMul": 0, "Gather": 1}',
+            ),
             "accuracy_level": PassConfigParam(
-                # TODO(trajep): to make it searchable
-                type_=int,
+                type_=OnnxMatMul4Quantizer.AccuracyLevel,
                 default_value=None,
                 description=(
-                    "Available from onnxruntime>=1.17.0 "
-                    "The minimum accuracy level of input A, can be: 0(unset), 1(fp32), 2(fp16), 3(bf16), "
-                    "or 4(int8) (default unset when 0 or None). It is used to control how input A is quantized or"
-                    " downcast "
-                    "internally while doing computation, for example: 0 means input A will not be quantized "
-                    "or downcast while doing computation. 4 means input A can be quantized with the same "
-                    "block_size to int8 internally from type T1. "
-                    "Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details "
-                    "(https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits)."
+                    "Accuracy level of the 4-bit quantized MatMul computation. Refer to the MatMulNBits contrib op's"
+                    " 'accuracy_level' attribute for details"
+                    " (https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits)."
                 ),
             ),
             "algorithm": PassConfigParam(
-                type_=str,
+                type_=OnnxMatMul4Quantizer.Algorithm,
                 default_value=None,
                 description=(
-                    "If 'None', the Matmul node with fp32 const weight will be quantize to int4."
-                    "1. 'RTN' and 'GPTQ' are available from onnxruntime>=1.17.0 "
-                    "- For 4b quantize a model with RTN or GPTQ algorithm. Please refer to "
-                    "https://github.com/intel/neural-compressor/blob/master/docs/source/quantization_weight_only.md "
-                    "for more details on weight only quantization using Intel® Neural Compressor. "
-                    "2. 'DEFAULT', 'HQQ' are available from onnxruntime>=1.18.0 "
-                    "- `DEFAULT` takes the same effect as `None`"
-                    "- For HQQ, please refer to onnxruntime for more details: "
-                    "https://github.com/microsoft/onnxruntime/blob/7e613ee821405b1192d0b71b9434a4f94643f1e4/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py#L102C1-L126C25"
+                    "The algorithm used to quantize the weights. If None, the default algorithm is used with a quant"
+                    " config created from the pass configuration."
                 ),
             ),
             "weight_only_quant_configs": PassConfigParam(
                 type_=dict,
                 default_value=None,
-                description="""Available from onnxruntime>=1.17.0, if None, the default behavior
-                of given algorithm will be used.
-                The config is binding to the algorithm with following map:
-                1. "algorithm" is "DEFAULT", by default, the weight_only_quant_configs is:
-                    "weight_only_quant_configs": {
-                        "block_size": 128,
-                        "is_symmetric": False,
-                        "accuracy_level": None
-                    }
-                    https://github.com/microsoft/onnxruntime/blob/7e613ee821405b1192d0b71b9434a4f94643f1e4/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py#L129C1-L140C45
-                2. "algorithm" is "HQQ", by default, the weight_only_quant_configs is:
-                    "weight_only_quant_configs": {
-                        "block_size": 128, // channel number in one block to execute a GPTQ quantization iteration.
-                        "bits": 4, // how many bits to represent weight.
-                        "axis": 1, // 0 or 1. which axis to quantize. https://arxiv.org/pdf/2309.15531.pdf
-                    }
-                    https://github.com/microsoft/onnxruntime/blob/7e613ee821405b1192d0b71b9434a4f94643f1e4/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py#L129C1-L140C45
-                3. "algorithm" is "RTN", by default, the weight_only_quant_configs is:
-                    "weight_only_quant_configs": {
-                        "ratios": None, // type: dict, percentile of clip. Defaults to None.
-                    }
-                    https://github.com/microsoft/onnxruntime/blob/7e613ee821405b1192d0b71b9434a4f94643f1e4/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py#L42C1-L60C29
-                4. "algorithm" is "GPTQ", by default, the weight_only_quant_configs is:
-                    "weight_only_quant_configs": {
-                        "percdamp": 0.01, // percent of the average Hessian diagonal to use for dampening.
-                        "block_size": 128,
-                        "actorder": False, // whether rearrange Hessian matrix considering the diag's value.
-                        "mse": False, // whether get scale and zero point with mse error.
-                        "perchannel": True, // whether quantize weight per-channel.
-                    }
-                    For GPTQ's "calibration_data_reader", you can provider a dataloader function or a
-                    data config like what we do for onnx static quantization.
-                    https://github.com/microsoft/onnxruntime/blob/7e613ee821405b1192d0b71b9434a4f94643f1e4/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py#L63C1-L99C37
-                """,
+                description=(
+                    "If 'algorithm' is provided and this is None, the config is constructed from the pass"
+                    " configuration. If provided, it takes precedence. Refer to"
+                    " https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py"
+                    " for details."
+                ),
             ),
             **get_external_data_config(),
             # static_dataloder_config
@@ -708,8 +695,6 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon
     @classmethod
     def _validators(cls) -> Dict[str, Callable]:
         return {
-            "validate_accuracy_level": validator("accuracy_level", allow_reuse=True)(_validate_accuracy_level),
-            "validate_algorithm": validator("algorithm", allow_reuse=True)(_validate_algorithm),
             "validate_quant_config": validator("weight_only_quant_configs", allow_reuse=True)(
                 _validate_weight_only_quant_config
             ),
@@ -718,6 +703,26 @@ def _validators(cls) -> Dict[str, Callable]:
     def _run_for_config(
         self, model: ONNXModelHandler, config: Type[BasePassConfig], output_model_path: str
     ) -> ONNXModelHandler:
+        from onnxruntime import __version__ as OrtVersion
+
+        if version.parse(OrtVersion) < version.parse("1.18.0"):
+            raise ValueError("MatMul4BitsQuantizer is only supported for onnxruntime>=1.18.0")
+
+        from onnxruntime.quantization.matmul_4bits_quantizer import (
+            DefaultWeightOnlyQuantConfig,
+            GPTQWeightOnlyQuantConfig,
+            HQQWeightOnlyQuantConfig,
+            MatMul4BitsQuantizer,
+            RTNWeightOnlyQuantConfig,
+        )
+
+        algo_to_config = {
+            "DEFAULT": DefaultWeightOnlyQuantConfig,
+            "HQQ": HQQWeightOnlyQuantConfig,
+            "RTN": RTNWeightOnlyQuantConfig,
+            "GPTQ": GPTQWeightOnlyQuantConfig,
+        }
+
         if model_has_adapters(model.model_path) and config.algorithm not in {None, "DEFAULT"}:
             logger.info(
                 "Model has adapters which should only be quantized with algorithm=None or DEFAULT. Got %s. Returning"
@@ -726,65 +731,44 @@ def _run_for_config(
             )
             return model
 
-        from onnxruntime import __version__ as OrtVersion
-
-        if version.parse(OrtVersion) < version.parse("1.16.2"):
-            raise ValueError("MatMul4BitsQuantizer is only supported in onnxruntime >= 1.16.2")
-
-        from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer
-
         output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name)
 
-        weight_only_quant_config_class = None
-        weight_only_quant_config = None
-        algo_config = deepcopy(config.weight_only_quant_configs) or {}
-        if version.parse(OrtVersion) >= version.parse("1.17.0"):
-            from onnxruntime.quantization.matmul_4bits_quantizer import (
-                GPTQWeightOnlyQuantConfig,
-                RTNWeightOnlyQuantConfig,
-            )
-
-            if config.algorithm == "RTN":
-                weight_only_quant_config_class = RTNWeightOnlyQuantConfig
-            elif config.algorithm == "GPTQ":
-                if "block_size" in algo_config and version.parse(OrtVersion) < version.parse("1.18.0"):
-                    # ort 1.17.0+ uses blocksize instead of block_size :(
-                    algo_config["blocksize"] = algo_config["block_size"]
-                    algo_config.pop("block_size")
-                dataloader = get_calibration_dataloader(config)
-                weight_only_quant_config_class = partial(GPTQWeightOnlyQuantConfig, calibration_data_reader=dataloader)
-
-            if version.parse(OrtVersion) >= version.parse("1.18.0"):
-                from onnxruntime.quantization.matmul_4bits_quantizer import (
-                    DefaultWeightOnlyQuantConfig,
-                    HQQWeightOnlyQuantConfig,
-                )
-
-                if config.algorithm == "DEFAULT":
-                    weight_only_quant_config_class = DefaultWeightOnlyQuantConfig
-                elif config.algorithm == "HQQ":
-                    weight_only_quant_config_class = HQQWeightOnlyQuantConfig
-            elif config.algorithm in ("HQQ", "DEFAULT"):
-                raise ValueError("HQQ and DEFAULT algorithm are only supported in onnxruntime >= 1.18.0")
-
-        if weight_only_quant_config_class:
-            weight_only_quant_config = weight_only_quant_config_class(**algo_config)
-        quant = MatMul4BitsQuantizer(
-            model.load_model(),
-            block_size=config.block_size,
-            is_symmetric=config.is_symmetric,
-            nodes_to_exclude=config.nodes_to_exclude,
-            accuracy_level=config.accuracy_level,
-            algo_config=weight_only_quant_config,
-        )
+        kwargs = {
+            "block_size": config.block_size,
+            "is_symmetric": config.is_symmetric,
+            "nodes_to_exclude": config.nodes_to_exclude,
+            "accuracy_level": config.accuracy_level,
+        }
+        for key in ["nodes_to_include", "op_types_to_quantize", "quant_axes"]:
+            if value := getattr(config, key):
+                if version.parse(OrtVersion) < version.parse("1.20.0"):
+                    raise ValueError(f"MatMul4BitsQuantizer {key} is only supported for onnxruntime>=1.20.0")
+                kwargs[key] = value
+
+        if woq_config_class := algo_to_config.get(config.algorithm, None):
+            algo_config = config.weight_only_quant_configs or {}
+            for key in inspect.signature(woq_config_class.__init__).parameters:
+                if key in algo_config:
+                    if key in kwargs and kwargs[key] != algo_config[key]:
+                        # algo config overrides pass config
+                        logger.warning(
+                            "The pass config parameter %s's value %s is different from the algorithm config's value %s."
+                            " The algorithm config's value will be used.",
+                            key,
+                            kwargs[key],
+                            algo_config[key],
+                        )
+                        kwargs[key] = algo_config[key]
+                elif key in kwargs:
+                    # get value from pass config
+                    algo_config[key] = kwargs[key]
+            if config.algorithm == "GPTQ":
+                algo_config["calibration_data_reader"] = get_calibration_dataloader(config)
+            kwargs["algo_config"] = woq_config_class(**algo_config)
         else:
-            # TODO(trajep): remove this block once we migrate customer to onnxruntime>=1.17.0 all
-            quant = MatMul4BitsQuantizer(
-                model.load_model(),
-                block_size=config.block_size,
-                is_symmetric=config.is_symmetric,
-                nodes_to_exclude=config.nodes_to_exclude,
-            )
+            kwargs["algo_config"] = None
+
+        quant = MatMul4BitsQuantizer(model.load_model(), **kwargs)
         quant.process()
         # topologically sort the graph at the end since previous optimizations may have broken it
         quant.model.topological_sort()
@@ -794,26 +778,6 @@ def _run_for_config(
         return model_proto_to_olive_model(quant.model.model, output_model_path, config)
 
 
-def _validate_accuracy_level(v, values, field):
-    if not v:
-        return v
-
-    if v not in (0, 1, 2, 3, 4):
-        raise ValueError(f"OnnxMatMul4Quantizer {field.name} must be 0(unset), 1(fp32), 2(fp16), 3(bf16) or 4(int8)")
-
-    return v
-
-
-def _validate_algorithm(v, values, field):
-    if not v:
-        return v
-
-    if v not in ("DEFAULT", "HQQ", "RTN", "GPTQ"):
-        raise ValueError(f"OnnxMatMul4Quantizer {field.name} must be 'DEFAULT', 'HQQ', 'RTN', 'GPTQ'")
-
-    return v
-
-
 def _validate_weight_only_quant_config(v, values, field):
     if values.get("algorithm") is None:
         logger.debug("algorithm is not set, skip validation for weight_only_quant_configs")
@@ -831,6 +795,9 @@ def _validate_weight_only_quant_config(v, values, field):
         default_config_keys = ["block_size", "bits", "axis"]
     elif values["algorithm"] == "GPTQ":
         default_config_keys = ["percdamp", "block_size", "actorder", "mse", "perchannel"]
+    else:
+        raise ValueError(f"Unsupported algorithm: {values['algorithm']}")
+    default_config_keys.extend(["op_types_to_quantize", "quant_axes"])
 
     if not all(key in default_config_keys for key in config_keys):
         invalid_config_keys = set(config_keys) - set(default_config_keys)
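
For reference, the consolidated path above reduces to a handful of onnxruntime calls. Below is a minimal standalone sketch of that flow (onnxruntime>=1.18.0), using only imports and calls that appear in this diff; the file paths and the example config values are placeholders, not part of the commit:

import onnx
from onnxruntime.quantization.matmul_4bits_quantizer import (
    DefaultWeightOnlyQuantConfig,
    MatMul4BitsQuantizer,
)

# Equivalent of algorithm="DEFAULT": shared parameters such as block_size,
# is_symmetric and accuracy_level are carried into the algorithm config,
# which is what the new inspect.signature merging loop automates.
algo_config = DefaultWeightOnlyQuantConfig(block_size=128, is_symmetric=False, accuracy_level=4)

quant = MatMul4BitsQuantizer(onnx.load("model.onnx"), algo_config=algo_config)
quant.process()
# topologically sort the graph before saving, as the pass does
quant.model.topological_sort()
onnx.save(quant.model.model, "model_int4.onnx")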
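
And a hedged sketch of how the cleaned-up pass might be configured in an Olive workflow. The parameter names come from _default_config above; the wrapper shape ("type" plus flattened parameters) and the "data_config" reference for GPTQ calibration follow Olive's usual pass-config convention and are assumptions, not part of this diff:

# Hypothetical pass entry in an Olive workflow config.
matmul4_pass = {
    "type": "OnnxMatMul4Quantizer",
    "block_size": 128,
    "is_symmetric": False,
    "accuracy_level": 4,  # OnnxMatMul4Quantizer.AccuracyLevel.int8
    "algorithm": "GPTQ",  # dispatched to GPTQWeightOnlyQuantConfig via algo_to_config
    # These keys override the pass-level values inside the algorithm config;
    # a warning is logged when they conflict.
    "weight_only_quant_configs": {"percdamp": 0.01, "actorder": False},
    "data_config": "calib_data",  # assumed name; GPTQ requires a calibration dataloader
}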
