From f47265f3771f48109c008f9b46d89b9e084b61b6 Mon Sep 17 00:00:00 2001 From: Xiaoyu <85524621+xiaoyu-work@users.noreply.github.com> Date: Mon, 18 Nov 2024 14:03:15 -0800 Subject: [PATCH] Fix several bugs & add Bert inc example to pipeline (#1492) ## Describe your changes Fix Olive bugs & Add Bert inc examples to example pipeline. - `batch_size` is needed for Inc dataloader. Set default size to 1. - Custom eval func doesn't have batch_size as input. - For `QuantizationAwareTraining` pass, `train_data_config` is not required if user provides `training_loop_func`. - Latest transformers package will automatically save trained model as safetensors format. Add `save_safetensors` as false to train argument. - Some passes may have nested data_config in its config. Update auto-fill data_config logic to achieve this. ## Checklist before requesting a review - [ ] Add unit tests for this change. - [ ] Make sure all tests can pass. - [ ] Update documents if necessary. - [ ] Lint and apply fixes to your code by running `lintrunner -a` - [ ] Is this a user-facing change? If yes, give a description of this change to be included in the release notes. - [ ] Is this PR including examples changes? If yes, please remember to update [example documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md) in a follow-up PR. ## (Optional) Issue link --- .azure_pipelines/olive-examples.yaml | 6 ++++ .../bert/bert_inc_smoothquant_ptq_cpu.json | 3 +- examples/bert/requirements.txt | 4 ++- examples/bert/user_script.py | 5 +-- examples/test/local/test_bert_inc.py | 35 +++++++++++++++++++ .../pytorch/quantization_aware_training.py | 4 +-- olive/workflows/run/config.py | 15 ++++++-- 7 files changed, 63 insertions(+), 9 deletions(-) create mode 100644 examples/test/local/test_bert_inc.py diff --git a/.azure_pipelines/olive-examples.yaml b/.azure_pipelines/olive-examples.yaml index efaaa7bb7..559ba90eb 100644 --- a/.azure_pipelines/olive-examples.yaml +++ b/.azure_pipelines/olive-examples.yaml @@ -25,6 +25,9 @@ jobs: bert_ptq_cpu_docker: exampleFolder: bert exampleName: bert_ptq_cpu_docker + bert_inc: + exampleFolder: bert + exampleName: bert_inc resnet_vitis_ai_ptq_cpu: exampleFolder: resnet exampleName: resnet_vitis_ai_ptq_cpu @@ -56,6 +59,9 @@ jobs: onnxruntime: onnxruntime python_version: '3.10' examples: + bert_inc: + exampleFolder: bert + exampleName: bert_inc bert_ptq_cpu_docker: exampleFolder: bert exampleName: bert_ptq_cpu_docker diff --git a/examples/bert/bert_inc_smoothquant_ptq_cpu.json b/examples/bert/bert_inc_smoothquant_ptq_cpu.json index f2b25c37c..538dd4c13 100644 --- a/examples/bert/bert_inc_smoothquant_ptq_cpu.json +++ b/examples/bert/bert_inc_smoothquant_ptq_cpu.json @@ -49,7 +49,6 @@ "quantization": { "type": "IncStaticQuantization", "quant_format": "QOperator", - "user_script": "user_script.py", "data_config": "inc_quat_data_config", "recipes": { "smooth_quant": true, "smooth_quant_args": { "alpha": 0.7 } }, "metric": { @@ -73,5 +72,5 @@ }, "evaluator": "common_evaluator", "cache_dir": "cache", - "output_dir": "models/bert_inc_static_ptq_cpu" + "output_dir": "models/bert_inc_smoothquant_ptq_cpu" } diff --git a/examples/bert/requirements.txt b/examples/bert/requirements.txt index 37904a737..d0aa14970 100644 --- a/examples/bert/requirements.txt +++ b/examples/bert/requirements.txt @@ -1,10 +1,12 @@ azure-ai-ml azure-identity -datasets +# TODO(anyone): load_metrics was removed since 3.0.0. Using evaluate instead +datasets<3.0.0 docker>=7.1.0 evaluate neural-compressor optimum +pytorch_lightning scikit-learn scipy tabulate diff --git a/examples/bert/user_script.py b/examples/bert/user_script.py index b6184294b..480b57276 100644 --- a/examples/bert/user_script.py +++ b/examples/bert/user_script.py @@ -167,7 +167,7 @@ def bert_inc_glue_calibration_dataset(data_dir, **kwargs): @Registry.register_dataloader() -def bert_inc_glue_calibration_dataloader(dataset, batch_size, **kwargs): +def bert_inc_glue_calibration_dataloader(dataset, batch_size=1, **kwargs): return DefaultDataLoader(dataset=dataset, batch_size=batch_size) @@ -176,7 +176,7 @@ def bert_inc_glue_calibration_dataloader(dataset, batch_size, **kwargs): # ------------------------------------------------------------------------- -def eval_accuracy(model: OliveModelHandler, device, execution_providers, batch_size, **kwargs): +def eval_accuracy(model: OliveModelHandler, device, execution_providers, batch_size=1, **kwargs): dataset = bert_dataset("Intel/bert-base-uncased-mrpc") dataloader = bert_dataloader(dataset, batch_size) preds = [] @@ -240,6 +240,7 @@ def training_loop_func(model): training_args.save_strategy = "steps" training_args.save_total_limit = 1 training_args.metric_for_best_model = "accuracy" + training_args.save_safetensors = False dataset = BertDataset("Intel/bert-base-uncased-mrpc") diff --git a/examples/test/local/test_bert_inc.py b/examples/test/local/test_bert_inc.py new file mode 100644 index 000000000..bf557be82 --- /dev/null +++ b/examples/test/local/test_bert_inc.py @@ -0,0 +1,35 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import json +import os + +import pytest + +from ..utils import check_output, get_example_dir + + +@pytest.fixture(scope="module", autouse=True) +def setup(): + """Setups any state specific to the execution of the given module.""" + os.chdir(get_example_dir("bert")) + + +@pytest.mark.parametrize( + "olive_json", + [ + "bert_inc_dynamic_ptq_cpu.json", + "bert_inc_ptq_cpu.json", + "bert_inc_smoothquant_ptq_cpu.json", + "bert_inc_static_ptq_cpu.json", + ], +) +def test_bert(olive_json): + from olive.workflows import run as olive_run + + with open(olive_json) as f: + olive_config = json.load(f) + + footprint = olive_run(olive_config, tempdir=os.environ.get("OLIVE_TEMPDIR", None)) + check_output(footprint) diff --git a/olive/passes/pytorch/quantization_aware_training.py b/olive/passes/pytorch/quantization_aware_training.py index 6575c94ad..6fdce4679 100755 --- a/olive/passes/pytorch/quantization_aware_training.py +++ b/olive/passes/pytorch/quantization_aware_training.py @@ -30,7 +30,6 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon **get_user_script_data_config(), "train_data_config": PassConfigParam( type_=Union[DataConfig, Dict], - required=True, description="Data config for training.", ), "val_data_config": PassConfigParam( @@ -102,7 +101,8 @@ def _run_for_config( if Path(output_model_path).suffix != ".pt": output_model_path += ".pt" - qat_trainer_config.train_data_config = validate_config(config["train_data_config"], DataConfig) + if config["train_data_config"]: + qat_trainer_config.train_data_config = validate_config(config["train_data_config"], DataConfig) if config["val_data_config"]: qat_trainer_config.val_data_config = validate_config(config["val_data_config"], DataConfig) if config["training_loop_func"]: diff --git a/olive/workflows/run/config.py b/olive/workflows/run/config.py index 9c1aa0e17..ef9cd3ef3 100644 --- a/olive/workflows/run/config.py +++ b/olive/workflows/run/config.py @@ -256,8 +256,8 @@ def validate_pass_search(cls, v, values): for param_name in v["config"]: if v["config"][param_name] == PassParamDefault.SEARCHABLE_VALUES: searchable_configs.add(param_name) - if param_name.endswith("data_config"): - v["config"] = _resolve_data_config(v["config"], values, param_name) + + resolve_all_data_configs(v["config"], values) if not values["engine"].search_strategy and searchable_configs: raise ValueError( @@ -274,6 +274,17 @@ def validate_workflow_host(cls, v, values): return _resolve_config(values, v) +def resolve_all_data_configs(config, values): + """Recursively traverse the config dictionary to resolve all 'data_config' keys.""" + for param_name, param_value in config.items(): + if param_name.endswith("data_config"): + _resolve_data_config(config, values, param_name) + continue + + if isinstance(param_value, dict): + resolve_all_data_configs(param_value, values) + + def _insert_azureml_client(config, azureml_client): """Insert azureml_client into config recursively.