From 4a76579a5135a4b10ef027c07ee3c8b1cbea5fe2 Mon Sep 17 00:00:00 2001
From: Jambay Kinley
Date: Wed, 29 Jan 2025 20:19:06 -0800
Subject: [PATCH] GPTQ: Default wikitext calibration dataset (#1581)

## Describe your changes
Use the wikitext train split as the default calibration dataset for `GptqQuantizer`. The calibration data uses a sequence length of 2048 and 128 samples.

## Checklist before requesting a review
- [x] Add unit tests for this change.
- [x] Make sure all tests can pass.
- [x] Update documents if necessary.
- [ ] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this change to be included in the release notes.
- [ ] Is this PR including examples changes? If yes, please remember to update [example documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md) in a follow-up PR.

## (Optional) Issue link

---
 olive/cli/quantize.py                      | 10 +--
 olive/passes/pytorch/autoawq.py            |  2 +-
 olive/passes/pytorch/gptq.py               | 77 ++++++++++++++++------
 test/unit_test/passes/pytorch/test_gptq.py | 45 +++++++------
 4 files changed, 88 insertions(+), 46 deletions(-)

diff --git a/olive/cli/quantize.py b/olive/cli/quantize.py
index ff823c09e..f11e54e23 100644
--- a/olive/cli/quantize.py
+++ b/olive/cli/quantize.py
@@ -86,6 +86,11 @@ def _get_run_config(self, tempdir: str) -> Dict[str, Any]:
         is_hf_model = config["input_model"]["type"].lower() == "hfmodel"
         if is_hf_model and self.args.algorithm not in ["awq", "gptq", "rtn"]:
             raise ValueError("Selected algorithm is not supported for HuggingFace models.")
+        if not is_hf_model and "gptq" in self.args.algorithm and not self.args.data_name:
+            # hf models don't require user-provided data
+            raise ValueError("data_name is required to use gptq.")
+        if self.args.data_name:
+            config["passes"]["gptq"]["data_config"] = "default_data_config"
 
         defaults_key = "hf_model_defaults" if is_hf_model else "onnx_model_defaults"
 
@@ -141,9 +146,6 @@ def _get_run_config(self, tempdir: str) -> Dict[str, Any]:
         return config
 
     def run(self):
-        if ("gptq" in self.args.algorithm) and (not self.args.data_name):
-            raise ValueError("data_name is required to use gptq.")
-
         self._run_workflow()
 
 
@@ -168,7 +170,7 @@ def run(self):
     "passes": {
         # Pytorch algorithms
         "awq": {"type": "AutoAWQQuantizer", "w_bit": 4},
-        "gptq": {"type": "GptqQuantizer", "bits": 4, "data_config": "default_data_config"},
+        "gptq": {"type": "GptqQuantizer", "bits": 4},
         # Onnx algorithms
         "bnb4": {"type": "OnnxBnb4Quantization", "quant_type": "nf4"},
         "matmul4": {"type": "OnnxMatMul4Quantizer", "accuracy_level": 4},
diff --git a/olive/passes/pytorch/autoawq.py b/olive/passes/pytorch/autoawq.py
index 00db1386c..3e83b2858 100644
--- a/olive/passes/pytorch/autoawq.py
+++ b/olive/passes/pytorch/autoawq.py
@@ -103,7 +103,7 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]:
             "data_config": PassConfigParam(
                 type_=Union[DataConfig, Dict],
                 default_value=None,
-                description="Data config for quantization. Default value is None.",
+                description="Data config for quantization. If not provided, pile validation data will be used.",
             ),
         }
 
diff --git a/olive/passes/pytorch/gptq.py b/olive/passes/pytorch/gptq.py
index a81ff1161..bb42cbf52 100644
--- a/olive/passes/pytorch/gptq.py
+++ b/olive/passes/pytorch/gptq.py
@@ -7,7 +7,7 @@
 from argparse import Namespace
 from copy import deepcopy
 from pathlib import Path
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 from packaging import version
@@ -16,6 +16,7 @@
 from olive.common.config_utils import validate_config
 from olive.common.hf.wrapper import ModelWrapper
 from olive.data.config import DataConfig
+from olive.data.template import huggingface_data_config_template
 from olive.hardware.accelerator import AcceleratorSpec
 from olive.model import HfModelHandler, PyTorchModelHandler
 from olive.model.utils.path_utils import normalize_path_suffix
@@ -93,9 +94,10 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]:
             "data_config": PassConfigParam(
                 type_=Union[DataConfig, Dict],
                 default_value=None,
-                description="""
-                    Data config for quantization. Default value is None.
-                """,
+                description=(
+                    "Data config for quantization. If not provided, wikitext train data will be used for HfModels."
+                    " Required for PyTorch models."
+                ),
             ),
         }
 
@@ -112,22 +114,7 @@ def _run_for_config(
             # will move each block(layer) to cuda before quantization and move back to cpu when finished.
             raise ValueError("Please use GPU to run gptq quantization.")
 
-        dataset = None
-        if config["data_config"]:
-            data_config = validate_config(config["data_config"], DataConfig)
-            dataloader = data_config.to_data_container().create_dataloader()
-            dataset = [data[0] for data in dataloader]
-
-        if (
-            not dataset
-            or not isinstance(dataset, list)
-            or not isinstance(dataset[0], dict)
-            or ("input_ids" not in dataset[0] or "attention_mask" not in dataset[0])
-        ):
-            raise ValueError(
-                "Provided dataset is invalid. The returned datasets is a list of tokenized data "
-                "(e.g. [{ 'input_ids': [ 1, 100, 15, ... ],'attention_mask': [ 1, 1, 1, ... ]},...])"
-            )
+        dataset = self.get_dataset(model, config)
 
         adapter_path = None
         if isinstance(model, HfModelHandler) and model.adapter_path:
@@ -240,6 +227,35 @@ def _run_for_config(
             new_load_kwargs["extra_args"]["use_safetensors"] = True
 
         return inherit_hf_from_hf(model, output_model_path, adapter_path=adapter_path, load_kwargs=new_load_kwargs)
 
+    def get_dataset(
+        self, model: Union[HfModelHandler, PyTorchModelHandler], config: Dict[str, Any]
+    ) -> List[Dict[str, Any]]:
+        """Get the dataset for quantization."""
+        data_config = config["data_config"]
+        if not data_config and isinstance(model, HfModelHandler):
+            data_config = self.get_calibration_data_config(
+                model.model_name_or_path, trust_remote_code=model.get_load_kwargs().get("trust_remote_code", None)
+            )
+        elif not data_config:
+            raise ValueError("Data config is required for PyTorch model.")
+        data_config = validate_config(data_config, DataConfig)
+        dataloader = data_config.to_data_container().create_dataloader()
+        # each batch consists of (input_data, labels)
+        dataset = [data[0] for data in dataloader]
+
+        if (
+            not dataset
+            or not isinstance(dataset, list)
+            or not isinstance(dataset[0], dict)
+            or ("input_ids" not in dataset[0] or "attention_mask" not in dataset[0])
+        ):
+            raise ValueError(
+                "Provided dataset is invalid. The returned dataset should be a list of tokenized data "
+                "(e.g. [{ 'input_ids': [[ 1, 100, 15, ... ]],'attention_mask': [[ 1, 1, 1, ... ]]},...])"
+            )
+
+        return dataset
+
     @staticmethod
     def get_gptq_info(model_wrapper: ModelWrapper, name: str) -> List[str]:
         """Get the GPTQ info from the model wrapper."""
@@ -257,3 +273,24 @@ def get_gptq_info(model_wrapper: ModelWrapper, name: str) -> List[str]:
             return model_wrapper.get_layers()[1]
 
         raise ValueError(f"Unknown key {name}")
+
+    @staticmethod
+    def get_calibration_data_config(model_name_or_path: str, trust_remote_code: Optional[bool] = None):
+        return huggingface_data_config_template(
+            model_name=model_name_or_path,
+            task="text-generation",
+            load_dataset_config={
+                "data_name": "wikitext",
+                "subset": "wikitext-2-raw-v1",
+                # only require 128 samples for calibration
+                "split": "train[:1000]",
+                "trust_remote_code": trust_remote_code,
+            },
+            pre_process_data_config={
+                # should we randomize the data?
+                "add_special_tokens": False,
+                "max_seq_len": 2048,
+                "max_samples": 128,
+                "trust_remote_code": trust_remote_code,
+            },
+        )
diff --git a/test/unit_test/passes/pytorch/test_gptq.py b/test/unit_test/passes/pytorch/test_gptq.py
index 13c832c3a..8f9241f15 100644
--- a/test/unit_test/passes/pytorch/test_gptq.py
+++ b/test/unit_test/passes/pytorch/test_gptq.py
@@ -13,38 +13,41 @@
 from olive.passes.olive_pass import create_pass_from_dict
 from olive.passes.pytorch.gptq import GptqQuantizer
 
+test_gptq_dc_config = DataConfig(
+    name="test_gptq_dc_config",
+    type="DummyDataContainer",
+    load_dataset_config=DataComponentConfig(
+        type="dummy_dataset",
+        params={
+            "input_names": ["input_ids", "attention_mask"],
+            "input_shapes": [[1, 128], [1, 128]],
+            "input_types": ["int64", "int64"],
+            "max_samples": 128,
+        },
+    ),
+    pre_process_data_config=DataComponentConfig(type="skip_pre_process"),
+    post_process_data_config=DataComponentConfig(type="skip_post_process"),
+)
+
 
 @pytest.mark.skipif(
     not torch.cuda.is_available(),
     reason="gptq requires GPU.",
 )
 @pytest.mark.parametrize(
-    ("model_path", "expected_model_type"),
-    [("katuni4ka/tiny-random-phi3", "Phi3ForCausalLM"), ("facebook/opt-125m", "OPTForCausalLM")],
+    ("model_path", "expected_model_type", "data_config"),
+    [
+        ("katuni4ka/tiny-random-phi3", "Phi3ForCausalLM", None),
+        ("katuni4ka/tiny-random-phi3", "Phi3ForCausalLM", test_gptq_dc_config),
+        ("facebook/opt-125m", "OPTForCausalLM", test_gptq_dc_config),
+    ],
 )
-def test_gptq_default(tmp_path: Path, model_path: str, expected_model_type: str):
+def test_gptq_default(tmp_path: Path, model_path: str, expected_model_type: str, data_config: DataConfig):
     # setup
     input_model = HfModelHandler(model_path=model_path)
-    config = {
-        "data_config": DataConfig(
-            name="test_gptq_dc_config",
-            type="DummyDataContainer",
-            load_dataset_config=DataComponentConfig(
-                type="dummy_dataset",
-                params={
-                    "input_names": ["input_ids", "attention_mask"],
-                    "input_shapes": [[1, 128], [1, 128]],
-                    "input_types": ["int64", "int64"],
-                    "max_samples": 128,
-                },
-            ),
-            pre_process_data_config=DataComponentConfig(type="skip_pre_process"),
-            post_process_data_config=DataComponentConfig(type="skip_post_process"),
-        )
-    }
     p = create_pass_from_dict(
         GptqQuantizer,
-        config,
+        {"data_config": data_config},
         disable_search=True,
         accelerator_spec=AcceleratorSpec(accelerator_type=Device.GPU, execution_provider="CUDAExecutionProvider"),
     )
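
A quick usage sketch (not part of the patch): with this change, an HF model can be GPTQ-quantized without any user-supplied `data_config`; the pass builds the wikitext calibration config itself. The snippet mirrors the imports and pass construction visible in this diff; the `Device` import path, the output path, and the `p.run(...)` call are assumptions based on the current Olive `Pass` API, and the model name is only illustrative.

```python
# Sketch only: assumes a CUDA device and a small HF model; mirrors the updated unit test.
from olive.hardware.accelerator import AcceleratorSpec, Device
from olive.model import HfModelHandler
from olive.passes.olive_pass import create_pass_from_dict
from olive.passes.pytorch.gptq import GptqQuantizer

input_model = HfModelHandler(model_path="facebook/opt-125m")

# No "data_config" given: GptqQuantizer falls back to the wikitext train split
# (max_seq_len=2048, max_samples=128) via get_calibration_data_config.
p = create_pass_from_dict(
    GptqQuantizer,
    {"bits": 4},
    disable_search=True,
    accelerator_spec=AcceleratorSpec(accelerator_type=Device.GPU, execution_provider="CUDAExecutionProvider"),
)

# Assumed call signature (model, output path); adjust if your Olive version differs.
quantized_model = p.run(input_model, "gptq_output_model")
```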