From 6e5c3b4c1deda629349a9543228584be7f71bb7e Mon Sep 17 00:00:00 2001
From: Jambay Kinley
Date: Wed, 19 Feb 2025 19:57:59 +0000
Subject: [PATCH] latest

---
 examples/utils/generator.py        |  11 +--
 olive/passes/onnx/model_builder.py |  34 ++++++++--
 olive/passes/onnx/quantization.py  | 103 ++++++++++++++---------------
 3 files changed, 86 insertions(+), 62 deletions(-)

diff --git a/examples/utils/generator.py b/examples/utils/generator.py
index 96e34fa78..10ea55216 100644
--- a/examples/utils/generator.py
+++ b/examples/utils/generator.py
@@ -267,7 +267,7 @@ def generate(
             session = self.sessions["iterator"]
             io_binding = session.io_binding() if use_io_binding else None
 
-            print(np_buffers["attention_mask"])
+            # print(np_buffers["attention_mask"])
 
             if use_io_binding:
                 if idx < 2:
@@ -388,9 +388,9 @@ def generate(
             first_cache = outputs[1]
             if use_io_binding:
                 first_cache = first_cache.numpy()
-            print(first_cache[:, 0, :, 0])
-            if idx == 2:
-                sdcd
+            # print(first_cache[:, 0, :, 0])
+            # if idx == 2:
+            #     sdcd
 
             # update cache
             cache.update(outputs[1:])
@@ -439,7 +439,7 @@ def get_initial_inputs(
         else:
             padding_args = {"padding": "longest"}
         # encode prompt
-        encodings_dict = self.tokenizer(prompt, return_tensors="np", **padding_args)
+        encodings_dict = self.tokenizer(prompt, return_tensors="np", add_special_tokens=False, **padding_args)
         input_ids = encodings_dict["input_ids"].astype(self.input_info["input_ids"]["dtype"])
         batch_size, prompt_length = input_ids.shape
         attention_mask = encodings_dict["attention_mask"]
@@ -457,6 +457,7 @@
             int(attention_mask.sum(axis=-1).max()),
         )
         if isinstance(cache, GQASharedCache):
+            print(cache.max_cache_len, prompt_length)
             attention_mask = np.concatenate(
                 [attention_mask, np.zeros((batch_size, cache.max_cache_len - prompt_length), dtype=np.int32)], 1
             )
diff --git a/olive/passes/onnx/model_builder.py b/olive/passes/onnx/model_builder.py
index 7e7b9cf04..edc3fff30 100644
--- a/olive/passes/onnx/model_builder.py
+++ b/olive/passes/onnx/model_builder.py
@@ -8,10 +8,11 @@
 import json
 import logging
 from pathlib import Path
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import onnx
 import transformers
+from packaging import version
 
 from olive.common.utils import IntEnumBase, StrEnumBase
 from olive.hardware.accelerator import AcceleratorSpec, Device
@@ -66,8 +67,24 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon
             "int4_accuracy_level": PassConfigParam(
                 type_=ModelBuilder.AccuracyLevel,
                 required=False,
-                description="Specify the minimum accuracy level for activation of MatMul in int4 quantization.",
+                description=(
+                    "Specify the minimum accuracy level for activation of MatMul in int4 quantization. Default is"
+                    " None (= 0)."
+                ),
             ),
+            # "int4_is_symmetric": PassConfigParam(
+            #     type_=bool,
+            #     required=False,
+            #     description="Whether to use symmetric quantization for int4 quantization. Default is None (= True).",
+            # ),
+            # "int4_op_types_to_quantize": PassConfigParam(
+            #     type_=List[str],
+            #     required=False,
+            #     description=(
+            #         'Specify the op types to quantize for int4 quantization. Default is None (= [ "MatMul" ]).'
+            #         ' Example: ["MatMul", "Gemm"]'
+            #     ),
+            # ),
             "exclude_embeds": PassConfigParam(
                 type_=bool,
                 default_value=False,
@@ -127,6 +144,7 @@ def _run_for_config(
         config: Dict[str, Any],
         output_model_path: str,
     ) -> ONNXModelHandler:
+        from onnxruntime_genai import __version__ as genai_version
         from onnxruntime_genai.models.builder import create_model
 
         precision = config["precision"]
@@ -174,16 +192,24 @@ def _run_for_config(
         if config.get("int4_accuracy_level"):
             extra_args["int4_accuracy_level"] = config["int4_accuracy_level"].value
 
-        # args that are only checked for presence, not value
+        # args that are only checked for presence, not value, until 0.6.0
        for arg in ["exclude_embeds", "exclude_lm_head"]:
             if config[arg]:
                 extra_args[arg] = True
 
         # args that are checked for presence and value (if present)
-        for arg in ["enable_cuda_graph"]:
+        # for arg, min_version in [("enable_cuda_graph", None), ("int4_is_symmetric", "0.6.0")]:
+        for arg, min_version in [("enable_cuda_graph", None)]:
             if config[arg] is not None:
+                if min_version and version.parse(genai_version) < version.parse(min_version):
+                    raise ValueError(
+                        f"{arg} is not supported in genai version {genai_version}. Minimum version: {min_version}"
+                    )
                 extra_args[arg] = "1" if config[arg] else "0"
 
+        # if config["int4_op_types_to_quantize"]:
+        #     extra_args["int4_op_types_to_quantize"] = "/".join(config["int4_op_types_to_quantize"])
+
         model_attributes = copy.deepcopy(model.model_attributes or {})
 
         try:
diff --git a/olive/passes/onnx/quantization.py b/olive/passes/onnx/quantization.py
index f41e25513..399507586 100644
--- a/olive/passes/onnx/quantization.py
+++ b/olive/passes/onnx/quantization.py
@@ -678,6 +678,16 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon
                 default_value=None,
                 description="List of node names to exclude from quantization.",
             ),
+            "nodes_to_include": PassConfigParam(
+                type_=list,
+                default_value=None,
+                description="List of node names to include in quantization.",
+            ),
+            "op_types_to_quantize": PassConfigParam(
+                type_=list,
+                default_value=None,
+                description='List of operator types to quantize. Default value is None (= [ "MatMul" ]).',
+            ),
             "accuracy_level": PassConfigParam(
                 # TODO(trajep): to make it searchable
                 type_=int,
@@ -765,6 +775,19 @@ def _validators(cls) -> Dict[str, Callable]:
     def _run_for_config(
         self, model: ONNXModelHandler, config: Dict[str, Any], output_model_path: str
     ) -> ONNXModelHandler:
+        from onnxruntime import __version__ as OrtVersion
+
+        if version.parse(OrtVersion) < version.parse("1.18.0"):
+            raise ValueError("MatMul4BitsQuantizer is only supported for onnxruntime>=1.18.0")
+
+        from onnxruntime.quantization.matmul_4bits_quantizer import (
+            DefaultWeightOnlyQuantConfig,
+            GPTQWeightOnlyQuantConfig,
+            HQQWeightOnlyQuantConfig,
+            MatMul4BitsQuantizer,
+            RTNWeightOnlyQuantConfig,
+        )
+
         if model_has_adapters(model.model_path) and config["algorithm"] not in {None, "DEFAULT"}:
             logger.info(
                 "Model has adapters which should only be quantized with algorithm=None or DEFAULT. Got %s. Returning"
                 " the model without quantization.",
                 config["algorithm"],
             )
             return model
@@ -773,65 +796,38 @@ def _run_for_config(
 
-        from onnxruntime import __version__ as OrtVersion
-
-        if version.parse(OrtVersion) < version.parse("1.16.2"):
-            raise ValueError("MatMul4BitsQuantizer is only supported in onnxruntime >= 1.16.2")
-
-        from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer
-
         output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name)
 
         weight_only_quant_config_class = None
         weight_only_quant_config = None
         algo_config = deepcopy(config["weight_only_quant_configs"] or {})
-        if version.parse(OrtVersion) >= version.parse("1.17.0"):
-            from onnxruntime.quantization.matmul_4bits_quantizer import (
-                GPTQWeightOnlyQuantConfig,
-                RTNWeightOnlyQuantConfig,
-            )
-
-            if config["algorithm"] == "RTN":
-                weight_only_quant_config_class = RTNWeightOnlyQuantConfig
-            elif config["algorithm"] == "GPTQ":
-                if "block_size" in algo_config and version.parse(OrtVersion) < version.parse("1.18.0"):
-                    # ort 1.17.0+ uses blocksize instead of block_size :(
-                    algo_config["blocksize"] = algo_config["block_size"]
-                    algo_config.pop("block_size")
-                dataloader = get_calibration_dataloader(config, model.model_path, model.io_config)
-                weight_only_quant_config_class = partial(GPTQWeightOnlyQuantConfig, calibration_data_reader=dataloader)
-
-            if version.parse(OrtVersion) >= version.parse("1.18.0"):
-                from onnxruntime.quantization.matmul_4bits_quantizer import (
-                    DefaultWeightOnlyQuantConfig,
-                    HQQWeightOnlyQuantConfig,
-                )
-
-                if config["algorithm"] == "DEFAULT":
-                    weight_only_quant_config_class = DefaultWeightOnlyQuantConfig
-                elif config["algorithm"] == "HQQ":
-                    weight_only_quant_config_class = HQQWeightOnlyQuantConfig
-            elif config["algorithm"] in ("HQQ", "DEFAULT"):
-                raise ValueError("HQQ and DEFAULT algorithm are only supported in onnxruntime >= 1.18.0")
-
-            if weight_only_quant_config_class:
-                weight_only_quant_config = weight_only_quant_config_class(**algo_config)
-            quant = MatMul4BitsQuantizer(
-                model.load_model(),
-                block_size=config["block_size"],
-                is_symmetric=config["is_symmetric"],
-                nodes_to_exclude=config["nodes_to_exclude"],
-                accuracy_level=config["accuracy_level"],
-                algo_config=weight_only_quant_config,
-            )
-        else:
-            # TODO(trajep): remove this block once we migrate customer to onnxruntime>=1.17.0 all
-            quant = MatMul4BitsQuantizer(
-                model.load_model(),
-                block_size=config["block_size"],
-                is_symmetric=config["is_symmetric"],
-                nodes_to_exclude=config["nodes_to_exclude"],
-            )
+        if config["algorithm"] == "RTN":
+            weight_only_quant_config_class = RTNWeightOnlyQuantConfig
+        elif config["algorithm"] == "GPTQ":
+            dataloader = get_calibration_dataloader(config, model.model_path, model.io_config)
+            weight_only_quant_config_class = partial(GPTQWeightOnlyQuantConfig, calibration_data_reader=dataloader)
+        elif config["algorithm"] == "DEFAULT":
+            weight_only_quant_config_class = DefaultWeightOnlyQuantConfig
+        elif config["algorithm"] == "HQQ":
+            weight_only_quant_config_class = HQQWeightOnlyQuantConfig
+
+        if weight_only_quant_config_class:
+            weight_only_quant_config = weight_only_quant_config_class(**algo_config)
+
+        kwargs = {
+            "block_size": config["block_size"],
+            "is_symmetric": config["is_symmetric"],
+            "nodes_to_exclude": config["nodes_to_exclude"],
+            "accuracy_level": config["accuracy_level"],
+            "algo_config": weight_only_quant_config,
+        }
+        for key in ["nodes_to_include", "op_types_to_quantize"]:
+            if config[key]:
+                if version.parse(OrtVersion) < version.parse("1.20.0"):
+                    raise ValueError(f"MatMul4BitsQuantizer {key} is only supported for onnxruntime>=1.20.0")
+                kwargs[key] = config[key]
+        quant = MatMul4BitsQuantizer(model.load_model(), **kwargs)
         quant.process()
         # topologically sort the graph at the end since previous optimizations may have broken it
         quant.model.topological_sort()
@@ -878,6 +874,7 @@ def _validate_weight_only_quant_config(v, values, field):
         default_config_keys = ["block_size", "bits", "axis"]
     elif values["algorithm"] == "GPTQ":
         default_config_keys = ["percdamp", "block_size", "actorder", "mse", "perchannel"]
+        default_config_keys.append("op_types_to_quantize")
 
     if not all(key in default_config_keys for key in config_keys):
         invalid_config_keys = set(config_keys) - set(default_config_keys)