From 3e6fbfd0d5f48d4166f77eadf8cfc2e480bf6da3 Mon Sep 17 00:00:00 2001 From: shaahji Date: Thu, 16 Jan 2025 16:03:27 -0800 Subject: [PATCH] Implement pass search TODO: Remove old implementation and update documentation --- .gitignore | 1 + examples/bert/bert.py | 27 + ...a_gpu.json => bert_cuda_gpu.template.json} | 6 +- examples/bert/bert_inc_dynamic_ptq_cpu.json | 2 +- examples/bert/bert_inc_ptq_cpu.json | 2 +- examples/bert/bert_ptq_cpu.json | 2 +- examples/bert/notebook/bert_auto_opt_gpu.json | 2 +- examples/bert/notebook/multi_ep_search.ipynb | 10 +- examples/deberta/deberta.json | 2 +- examples/directml/llm/config_llm.json | 2 +- .../config_text_encoder.json | 1 - .../config_text_encoder_2.json | 1 - .../stable_diffusion_xl/config_unet.json | 1 - .../config_vae_decoder.json | 1 - .../config_vae_encoder.json | 1 - .../stable_diffusion_xl.py | 10 +- examples/llama2/llama2.py | 17 +- examples/llama2/llama2_lmeval.json | 2 +- examples/llama2/llama2_model_builder.py | 7 +- ...nfig_cpu.json => config_cpu.template.json} | 0 ...nfig_gpu.json => config_gpu.template.json} | 0 ..._ep.json => config_multi_ep.template.json} | 0 .../llama2/notebook/llama2_multiep/llama2.py | 32 + examples/mistral/mistral.py | 2 +- examples/mistral/mistral_fp16.json | 1 - examples/mistral/mistral_int4.json | 1 - examples/mobilenet/prepare_config.py | 32 +- examples/mobilenet/raw_qnn_sdk_template.json | 4 +- examples/phi2/phi2.py | 17 +- examples/phi3/phi3.py | 7 +- examples/phi3/phi3_template.json | 1 - examples/resnet/resnet_dynamic_ptq_cpu.json | 2 +- examples/resnet/resnet_ptq_cpu.json | 2 +- .../resnet/resnet_ptq_cpu_aml_dataset.json | 2 +- examples/resnet/resnet_static_ptq_cpu.json | 2 +- .../config_safety_checker.json | 1 - .../stable_diffusion/config_text_encoder.json | 1 - examples/stable_diffusion/config_unet.json | 1 - .../stable_diffusion/config_vae_decoder.json | 1 - .../stable_diffusion/config_vae_encoder.json | 1 - examples/stable_diffusion/sd_utils/ort.py | 5 +- 
examples/stable_diffusion/sd_utils/ov.py | 2 +- .../test/azureml/test_bert_ptq_cpu_aml.py | 2 +- examples/test/azureml/test_llama2.py | 6 +- .../test/azureml/test_resnet_ptq_cpu_aml.py | 6 +- examples/test/local/test_bert_cuda_gpu.py | 6 +- examples/test/local/test_bert_ptq_cpu.py | 6 +- .../test/local/test_bert_ptq_cpu_docker.py | 6 +- examples/test/local/test_llama2.py | 6 +- examples/test/local/test_resnet_ptq_cpu.py | 6 +- examples/test/local/test_resnet_qat.py | 6 +- examples/test/utils.py | 12 +- olive/cache.py | 2 +- olive/cli/auto_opt.py | 2 +- olive/cli/base.py | 2 +- olive/cli/quantize.py | 4 +- olive/engine/config.py | 40 +- olive/engine/engine.py | 322 ++- olive/evaluator/olive_evaluator.py | 1 + olive/passes/olive_pass.py | 67 +- olive/passes/onnx/inc_quantization.py | 2 +- olive/passes/onnx/nvmo_quantization.py | 2 +- olive/passes/onnx/quantization.py | 2 +- olive/passes/onnx/session_params_tuning.py | 14 +- olive/passes/onnx/vitis_ai_quantization.py | 2 +- olive/passes/pass_config.py | 2 +- olive/passes/pytorch/lora.py | 2 +- olive/passes/snpe/quantization.py | 2 +- olive/search/__init__.py | 4 + olive/search/samplers/__init__.py | 12 + olive/search/samplers/optuna_sampler.py | 140 ++ olive/search/samplers/random_sampler.py | 61 + olive/search/samplers/search_sampler.py | 69 + olive/search/samplers/sequential_sampler.py | 46 + olive/search/samplers/tpe_sampler.py | 45 + olive/search/search_parameter.py | 315 +++ olive/search/search_point.py | 56 + olive/search/search_results.py | 124 ++ olive/search/search_sample.py | 60 + olive/search/search_space.py | 126 ++ olive/search/search_strategy.py | 318 +++ olive/search/utils.py | 87 + olive/systems/azureml/aml_system.py | 7 +- olive/systems/docker/docker_system.py | 13 +- olive/systems/olive_system.py | 7 +- olive/workflows/run/config.py | 118 +- olive/workflows/run/run.py | 86 +- .../packaging/test_packaging_generator.py | 2 +- test/unit_test/engine/test_engine.py | 57 +- 
.../passes/common/test_user_script.py | 4 +- .../passes/test_pass_serialization.py | 3 +- .../search/samplers/test_random_sampler.py | 101 + .../samplers/test_sequential_sampler.py | 93 + .../search/samplers/test_tpe_sampler.py | 168 ++ test/unit_test/search/test_search_results.py | 106 + test/unit_test/search/test_search_space.py | 1963 +++++++++++++++++ test/unit_test/search/test_search_strategy.py | 849 +++++++ test/unit_test/utils.py | 9 +- test/unit_test/workflows/test_run_config.py | 55 +- test/unit_test/workflows/test_workflow_run.py | 3 +- 100 files changed, 5311 insertions(+), 540 deletions(-) create mode 100644 examples/bert/bert.py rename examples/bert/{bert_cuda_gpu.json => bert_cuda_gpu.template.json} (89%) rename examples/llama2/notebook/llama2_multiep/{config_cpu.json => config_cpu.template.json} (100%) rename examples/llama2/notebook/llama2_multiep/{config_gpu.json => config_gpu.template.json} (100%) rename examples/llama2/notebook/llama2_multiep/{config_multi_ep.json => config_multi_ep.template.json} (100%) create mode 100644 examples/llama2/notebook/llama2_multiep/llama2.py create mode 100644 olive/search/__init__.py create mode 100644 olive/search/samplers/__init__.py create mode 100644 olive/search/samplers/optuna_sampler.py create mode 100644 olive/search/samplers/random_sampler.py create mode 100644 olive/search/samplers/search_sampler.py create mode 100644 olive/search/samplers/sequential_sampler.py create mode 100644 olive/search/samplers/tpe_sampler.py create mode 100644 olive/search/search_parameter.py create mode 100644 olive/search/search_point.py create mode 100644 olive/search/search_results.py create mode 100644 olive/search/search_sample.py create mode 100644 olive/search/search_space.py create mode 100644 olive/search/search_strategy.py create mode 100644 olive/search/utils.py create mode 100644 test/unit_test/search/samplers/test_random_sampler.py create mode 100644 test/unit_test/search/samplers/test_sequential_sampler.py create 
mode 100644 test/unit_test/search/samplers/test_tpe_sampler.py create mode 100644 test/unit_test/search/test_search_results.py create mode 100644 test/unit_test/search/test_search_space.py create mode 100644 test/unit_test/search/test_search_strategy.py diff --git a/.gitignore b/.gitignore index e2c5a77c4..bc06032a2 100644 --- a/.gitignore +++ b/.gitignore @@ -126,6 +126,7 @@ celerybeat.pid # Environments .env .venv +.vs env/ venv/ ENV/ diff --git a/examples/bert/bert.py b/examples/bert/bert.py new file mode 100644 index 000000000..cda726c3b --- /dev/null +++ b/examples/bert/bert.py @@ -0,0 +1,27 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import argparse +import json +from pathlib import Path + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--optimize", + action="store_true", + help="If set, run transformers optimization pass", + ) + args = parser.parse_args() + + input_filename = "bert_cuda_gpu.template.json" + with Path(input_filename).open("r") as f: + config = json.load(f) + + if not args.optimize: + del config["passes"]["transformers_optimization"] + + output_filename = input_filename.replace(".template", "") + with Path(output_filename).open("w") as strm: + json.dump(config, fp=strm, indent=4) diff --git a/examples/bert/bert_cuda_gpu.json b/examples/bert/bert_cuda_gpu.template.json similarity index 89% rename from examples/bert/bert_cuda_gpu.json rename to examples/bert/bert_cuda_gpu.template.json index 85aaad56a..0eae68317 100644 --- a/examples/bert/bert_cuda_gpu.json +++ b/examples/bert/bert_cuda_gpu.template.json @@ -51,11 +51,7 @@ "transformers_optimization": { "type": "OrtTransformersOptimization", "float16": true }, "session_params_tuning": { "type": "OrtSessionParamsTuning", "data_config": 
"glue_mrpc", "io_bind": true } }, - "pass_flows": [ - [ "conversion", "transformers_optimization", "session_params_tuning" ], - [ "conversion", "session_params_tuning" ] - ], - "search_strategy": { "execution_order": "joint", "search_algorithm": "tpe", "num_samples": 3, "seed": 0 }, + "search_strategy": { "execution_order": "joint", "sampler": "tpe", "max_samples": 3, "seed": 0 }, "host": "local_system", "target": "local_system", "evaluator": "common_evaluator", diff --git a/examples/bert/bert_inc_dynamic_ptq_cpu.json b/examples/bert/bert_inc_dynamic_ptq_cpu.json index 9d9f52cfa..889beff3c 100644 --- a/examples/bert/bert_inc_dynamic_ptq_cpu.json +++ b/examples/bert/bert_inc_dynamic_ptq_cpu.json @@ -34,7 +34,7 @@ "transformers_optimization": { "type": "OrtTransformersOptimization", "model_type": "bert" }, "dynamic_quantization": { "type": "IncDynamicQuantization" } }, - "search_strategy": { "execution_order": "joint", "search_algorithm": "exhaustive" }, + "search_strategy": { "execution_order": "joint", "sampler": "sequential" }, "evaluator": "common_evaluator", "cache_dir": "cache", "output_dir": "models/bert_inc_dynamic_ptq_cpu" diff --git a/examples/bert/bert_inc_ptq_cpu.json b/examples/bert/bert_inc_ptq_cpu.json index 09558de46..82220ac84 100644 --- a/examples/bert/bert_inc_ptq_cpu.json +++ b/examples/bert/bert_inc_ptq_cpu.json @@ -64,7 +64,7 @@ } } }, - "search_strategy": { "execution_order": "joint", "search_algorithm": "exhaustive" }, + "search_strategy": { "execution_order": "joint", "sampler": "sequential" }, "evaluator": "common_evaluator", "cache_dir": "cache", "output_dir": "models/bert_inc_ptq_cpu" diff --git a/examples/bert/bert_ptq_cpu.json b/examples/bert/bert_ptq_cpu.json index 5a717acf1..625415925 100644 --- a/examples/bert/bert_ptq_cpu.json +++ b/examples/bert/bert_ptq_cpu.json @@ -67,7 +67,7 @@ }, "session_params_tuning": { "type": "OrtSessionParamsTuning", "data_config": "glue_mrpc" } }, - "search_strategy": { "execution_order": "joint", 
"search_algorithm": "tpe", "num_samples": 3, "seed": 0 }, + "search_strategy": { "execution_order": "joint", "sampler": "tpe", "max_samples": 3, "seed": 0 }, "evaluator": "common_evaluator", "host": "local_system", "target": "local_system", diff --git a/examples/bert/notebook/bert_auto_opt_gpu.json b/examples/bert/notebook/bert_auto_opt_gpu.json index f1ed055a5..929626f46 100644 --- a/examples/bert/notebook/bert_auto_opt_gpu.json +++ b/examples/bert/notebook/bert_auto_opt_gpu.json @@ -43,7 +43,7 @@ ] } }, - "search_strategy": { "execution_order": "joint", "search_algorithm": "tpe", "num_samples": 1, "seed": 0 }, + "search_strategy": { "execution_order": "joint", "sampler": "tpe", "max_samples": 1, "seed": 0 }, "evaluator": "common_evaluator", "host": "local_system", "target": "local_system", diff --git a/examples/bert/notebook/multi_ep_search.ipynb b/examples/bert/notebook/multi_ep_search.ipynb index fcedbfa1b..fd9ccfe70 100644 --- a/examples/bert/notebook/multi_ep_search.ipynb +++ b/examples/bert/notebook/multi_ep_search.ipynb @@ -141,17 +141,15 @@ "#### Engine and search strategy\n", "\n", "Engine is used to manage the optimization process where we run optimization on host device, and run evaluation on target device.\n", - "Search strategy is used to search the optimal optimization among different EPs. In this notebook, we use `joint` as the `execution_order` and `tpe` as the `search_algorithm`. We set the `num_samples` to 1 and `seed` to 0.\n", + "Search strategy is used to search the optimal optimization among different EPs. In this notebook, we use `joint` as the `execution_order` and `tpe` as the `sampler`. 
We set the `max_samples` to 1 and `seed` to 0.\n", "\n", "```json\n", "\"engine\": {\n", " \"search_strategy\": {\n", " \"execution_order\": \"joint\",\n", - " \"search_algorithm\": \"tpe\",\n", - " \"search_algorithm_config\": {\n", - " \"num_samples\": 1,\n", - " \"seed\": 0\n", - " }\n", + " \"sampler\": \"tpe\",\n", + " \"max_samples\": 1,\n", + " \"seed\": 0\n", " },\n", " \"evaluator\": \"common_evaluator\",\n", " \"host\": \"local_system\",\n", diff --git a/examples/deberta/deberta.json b/examples/deberta/deberta.json index 025a0c2cf..56c8a60ae 100644 --- a/examples/deberta/deberta.json +++ b/examples/deberta/deberta.json @@ -46,7 +46,7 @@ "quantization": { "type": "OnnxQuantization", "data_config": "glue_mnli_matched" }, "session_params_tuning": { "type": "OrtSessionParamsTuning", "data_config": "glue_mnli_matched" } }, - "search_strategy": { "execution_order": "joint", "search_algorithm": "tpe", "num_samples": 3, "seed": 0 }, + "search_strategy": { "execution_order": "joint", "sampler": "tpe", "max_samples": 3, "seed": 0 }, "evaluator": "common_evaluator", "cache_dir": "cache", "output_dir": "models/microsoft-deberta" diff --git a/examples/directml/llm/config_llm.json b/examples/directml/llm/config_llm.json index a052fc3dc..8b1178df6 100644 --- a/examples/directml/llm/config_llm.json +++ b/examples/directml/llm/config_llm.json @@ -90,7 +90,7 @@ } } }, - "search_strategy": { "execution_order": "joint", "search_algorithm": "exhaustive" }, + "search_strategy": { "execution_order": "joint", "sampler": "sequential" }, "evaluator": "common_evaluator", "evaluate_input_model": false, "host": "local_system", diff --git a/examples/directml/stable_diffusion_xl/config_text_encoder.json b/examples/directml/stable_diffusion_xl/config_text_encoder.json index c131054a9..a205209f7 100644 --- a/examples/directml/stable_diffusion_xl/config_text_encoder.json +++ b/examples/directml/stable_diffusion_xl/config_text_encoder.json @@ -112,7 +112,6 @@ "keep_io_types": true } }, - 
"pass_flows": [ [ "convert", "optimize" ] ], "evaluator": "common_evaluator", "evaluate_input_model": false, "host": "local_system", diff --git a/examples/directml/stable_diffusion_xl/config_text_encoder_2.json b/examples/directml/stable_diffusion_xl/config_text_encoder_2.json index dffc5de9f..a26dd0833 100644 --- a/examples/directml/stable_diffusion_xl/config_text_encoder_2.json +++ b/examples/directml/stable_diffusion_xl/config_text_encoder_2.json @@ -152,7 +152,6 @@ "keep_io_types": true } }, - "pass_flows": [ [ "convert", "optimize" ] ], "evaluator": "common_evaluator", "evaluate_input_model": false, "host": "local_system", diff --git a/examples/directml/stable_diffusion_xl/config_unet.json b/examples/directml/stable_diffusion_xl/config_unet.json index 6e38b41c0..9b2c513af 100644 --- a/examples/directml/stable_diffusion_xl/config_unet.json +++ b/examples/directml/stable_diffusion_xl/config_unet.json @@ -96,7 +96,6 @@ "keep_io_types": true } }, - "pass_flows": [ [ "convert", "optimize" ] ], "evaluator": "common_evaluator", "evaluate_input_model": false, "host": "local_system", diff --git a/examples/directml/stable_diffusion_xl/config_vae_decoder.json b/examples/directml/stable_diffusion_xl/config_vae_decoder.json index ad22ac2b9..c6c3c493e 100644 --- a/examples/directml/stable_diffusion_xl/config_vae_decoder.json +++ b/examples/directml/stable_diffusion_xl/config_vae_decoder.json @@ -108,7 +108,6 @@ "use_gpu": true } }, - "pass_flows": [ [ "convert", "optimize" ] ], "evaluator": "common_evaluator", "evaluate_input_model": false, "host": "local_system", diff --git a/examples/directml/stable_diffusion_xl/config_vae_encoder.json b/examples/directml/stable_diffusion_xl/config_vae_encoder.json index ef889387f..fc4c8fd1d 100644 --- a/examples/directml/stable_diffusion_xl/config_vae_encoder.json +++ b/examples/directml/stable_diffusion_xl/config_vae_encoder.json @@ -87,7 +87,6 @@ "keep_io_types": true } }, - "pass_flows": [ [ "convert", "optimize" ] ], "evaluator": 
"common_evaluator", "evaluate_input_model": false, "host": "local_system", diff --git a/examples/directml/stable_diffusion_xl/stable_diffusion_xl.py b/examples/directml/stable_diffusion_xl/stable_diffusion_xl.py index 8bb69dca4..ac5d592ea 100644 --- a/examples/directml/stable_diffusion_xl/stable_diffusion_xl.py +++ b/examples/directml/stable_diffusion_xl/stable_diffusion_xl.py @@ -7,6 +7,7 @@ import shutil import sys import warnings +from collections import OrderedDict from pathlib import Path from typing import Dict @@ -289,9 +290,10 @@ def run_inference( def update_config_with_provider(config: Dict, provider: str, is_fp16: bool) -> Dict: + used_passes = {} if provider == "dml": # DirectML EP is the default, so no need to update config. - return config + used_passes = {"convert", "optimize"} elif provider == "cuda": if version.parse(OrtVersion) < version.parse("1.17.0"): # disable skip_group_norm fusion since there is a shape inference bug which leads to invalid models @@ -299,12 +301,14 @@ def update_config_with_provider(config: Dict, provider: str, is_fp16: bool) -> D # keep model fully in fp16 if use_fp16_fixed_vae is set if is_fp16: config["passes"]["optimize_cuda"].update({"float16": True, "keep_io_types": False}) - config["pass_flows"] = [["convert", "optimize_cuda"]] + used_passes = {"convert", "optimize_cuda"} config["systems"]["local_system"]["accelerators"][0]["execution_providers"] = ["CUDAExecutionProvider"] - return config else: raise ValueError(f"Unsupported provider: {provider}") + config["passes"] = OrderedDict([(k, v) for k, v in config["passes"].items() if k in used_passes]) + return config + def optimize( model_id: str, diff --git a/examples/llama2/llama2.py b/examples/llama2/llama2.py index 388655764..5cb7424a3 100644 --- a/examples/llama2/llama2.py +++ b/examples/llama2/llama2.py @@ -173,17 +173,26 @@ def get_general_config(args): gqa = "gqa" if args.use_gqa else "mha" config_name = f"llama2_{device}_{gqa}" - # add pass flows + # add pass 
names if not args.use_gptq: - template_json["pass_flows"] = [flow for flow in SUPPORTED_WORKFLOWS[device] if "gptq" not in flow[0]] + used_passes = [ + pass_name + for passes_names in SUPPORTED_WORKFLOWS[device] + if "gptq" not in passes_names[0] + for pass_name in passes_names + ] else: - template_json["pass_flows"] = [flow for flow in SUPPORTED_WORKFLOWS[device] if "gptq" in flow[0]] + used_passes = [ + pass_name + for passes_names in SUPPORTED_WORKFLOWS[device] + if "gptq" in passes_names[0] + for pass_name in passes_names + ] auto_gptq_logger = logging.getLogger("auto_gptq") auto_gptq_logger.addHandler(logging.StreamHandler(sys.stdout)) auto_gptq_logger.setLevel(logging.INFO) # remove unused passes and set gqa related configs - used_passes = {pass_name for pass_flow in SUPPORTED_WORKFLOWS[device] for pass_name in pass_flow} for pass_name in list(template_json["passes"].keys()): if pass_name not in used_passes: del template_json["passes"][pass_name] diff --git a/examples/llama2/llama2_lmeval.json b/examples/llama2/llama2_lmeval.json index b6630d25d..4ea3e1050 100644 --- a/examples/llama2/llama2_lmeval.json +++ b/examples/llama2/llama2_lmeval.json @@ -57,7 +57,7 @@ } }, "auto_optimizer_config": { "opt_level": 0, "disable_auto_optimizer": true, "precision": "fp16" }, - "search_strategy": { "execution_order": "joint", "search_algorithm": "exhaustive" }, + "search_strategy": { "execution_order": "joint", "sampler": "sequential" }, "evaluator": "evaluator", "host": "local_system", "target": "local_system", diff --git a/examples/llama2/llama2_model_builder.py b/examples/llama2/llama2_model_builder.py index f613f2c37..40263ba74 100644 --- a/examples/llama2/llama2_model_builder.py +++ b/examples/llama2/llama2_model_builder.py @@ -5,6 +5,7 @@ import argparse import json +from collections import OrderedDict from olive.workflows import run as olive_run @@ -38,10 +39,8 @@ def main(raw_args=None): template_json = json.loads(template_json_str) # add pass flows - if
args.metadata_only: - template_json["pass_flows"] = [["conversion", "metadata"]] - else: - template_json["pass_flows"] = [["builder", "session_params_tuning"]] + used_passes = {"conversion", "metadata"} if args.metadata_only else {"builder", "session_params_tuning"} + template_json["passes"] = OrderedDict([(k, v) for k, v in template_json["passes"].items() if k in used_passes]) template_json["output_dir"] = f"models/{model_name}" # dump config diff --git a/examples/llama2/notebook/llama2_multiep/config_cpu.json b/examples/llama2/notebook/llama2_multiep/config_cpu.template.json similarity index 100% rename from examples/llama2/notebook/llama2_multiep/config_cpu.json rename to examples/llama2/notebook/llama2_multiep/config_cpu.template.json diff --git a/examples/llama2/notebook/llama2_multiep/config_gpu.json b/examples/llama2/notebook/llama2_multiep/config_gpu.template.json similarity index 100% rename from examples/llama2/notebook/llama2_multiep/config_gpu.json rename to examples/llama2/notebook/llama2_multiep/config_gpu.template.json diff --git a/examples/llama2/notebook/llama2_multiep/config_multi_ep.json b/examples/llama2/notebook/llama2_multiep/config_multi_ep.template.json similarity index 100% rename from examples/llama2/notebook/llama2_multiep/config_multi_ep.json rename to examples/llama2/notebook/llama2_multiep/config_multi_ep.template.json diff --git a/examples/llama2/notebook/llama2_multiep/llama2.py b/examples/llama2/notebook/llama2_multiep/llama2.py new file mode 100644 index 000000000..069ee8dbb --- /dev/null +++ b/examples/llama2/notebook/llama2_multiep/llama2.py @@ -0,0 +1,32 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +import argparse +import json +from pathlib import Path + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--device", + choices=["cpu", "gpu", "multi_ep"], + help="Device to use", + ) + parser.add_argument( + "--quantize", + action="store_true", + help="If set, run blockwise int4 quantization pass", + ) + args = parser.parse_args() + + input_filename = f"config_{args.device}.template.json" + with Path(input_filename).open("r") as f: + config = json.load(f) + + if not args.quantize: + del config["passes"]["blockwise_quant_int4"] + + output_filename = input_filename.replace(".template", "") + with Path(output_filename).open("w") as strm: + json.dump(config, fp=strm, indent=4) diff --git a/examples/mistral/mistral.py b/examples/mistral/mistral.py index 39ec7dd2b..2fdaa8450 100644 --- a/examples/mistral/mistral.py +++ b/examples/mistral/mistral.py @@ -108,7 +108,7 @@ def main(raw_args=None): optimized_model_dir = ( script_dir / config["output_dir"] - / "-".join(config["pass_flows"][0]) + / "-".join(config["passes"].keys()) / f"{config['output_name']}_{accelerator}-{ep_header}_model" ) diff --git a/examples/mistral/mistral_fp16.json b/examples/mistral/mistral_fp16.json index 926d13e6e..364e51630 100644 --- a/examples/mistral/mistral_fp16.json +++ b/examples/mistral/mistral_fp16.json @@ -45,7 +45,6 @@ "enable_profiling": false } }, - "pass_flows": [ [ "convert", "optimize", "session_params_tuning" ] ], "evaluate_input_model": false, "evaluator": "common_evaluator", "host": "local_system", diff --git a/examples/mistral/mistral_int4.json b/examples/mistral/mistral_int4.json index adfe1949f..16f89acbd 100644 --- a/examples/mistral/mistral_int4.json +++ b/examples/mistral/mistral_int4.json @@ -57,7 +57,6 @@ "all_tensors_to_one_file": true } }, - "pass_flows": [ [ "convert", "optimize", "quantization" ] ], "evaluate_input_model": false, "evaluator":
"common_evaluator", "host": "local_system", diff --git a/examples/mobilenet/prepare_config.py b/examples/mobilenet/prepare_config.py index db299bdea..756eed3b3 100644 --- a/examples/mobilenet/prepare_config.py +++ b/examples/mobilenet/prepare_config.py @@ -2,27 +2,32 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- +import argparse import json import platform +from collections import OrderedDict from pathlib import Path from olive.common.constants import OS -def raw_qnn_config(): +def raw_qnn_config(mode: str): # pylint: disable=redefined-outer-name - with Path("./raw_qnn_sdk_template.json").open("r") as f: + with Path("raw_qnn_sdk_template.json").open("r") as f: raw_qnn_config = json.load(f) sys_platform = platform.system() + if mode == "convert": + used_passes = {"converter", "build_model_lib"} + elif mode == "quantize": + used_passes = {"quantization", "build_model_lib"} + else: + raise ValueError(f"Unsupported mode: {mode}") + if sys_platform == OS.LINUX: - raw_qnn_config["passes"]["qnn_context_binary"] = { - "type": "QNNContextBinaryGenerator", - "backend": "libQnnHtp.so", - } - raw_qnn_config["pass_flows"].append(["converter", "build_model_lib", "qnn_context_binary"]) + used_passes.add("qnn_context_binary") raw_qnn_config["passes"]["build_model_lib"]["lib_targets"] = "x86_64-linux-clang" elif sys_platform == OS.WINDOWS: raw_qnn_config["passes"]["build_model_lib"]["lib_targets"] = "x86_64-windows-msvc" @@ -33,11 +38,20 @@ def raw_qnn_config(): elif sys_platform == OS.LINUX: metric_config["user_config"]["inference_settings"]["qnn"]["backend"] = "libQnnCpu" + raw_qnn_config["passes"] = OrderedDict([(k, v) for k, v in raw_qnn_config["passes"].items() if k in used_passes]) + with Path("raw_qnn_sdk_config.json").open("w") as f: json_str = json.dumps(raw_qnn_config, indent=4) - input_file_path = 
Path("./data/eval/input_order.txt").resolve().as_posix() + input_file_path = Path("data/eval/input_order.txt").resolve().as_posix() f.write(json_str.replace("<input_list.txt>", str(input_file_path))) if __name__ == "__main__": - raw_qnn_config() + parser = argparse.ArgumentParser() + parser.add_argument( + "--mode", + choices=["convert", "quantize"], + help="Mode selection", + ) + args = parser.parse_args() + raw_qnn_config(args.mode) diff --git a/examples/mobilenet/raw_qnn_sdk_template.json b/examples/mobilenet/raw_qnn_sdk_template.json index c94d40837..e4f5d1eeb 100644 --- a/examples/mobilenet/raw_qnn_sdk_template.json +++ b/examples/mobilenet/raw_qnn_sdk_template.json @@ -46,9 +46,9 @@ "dynamic_shape_to_fixed": { "type": "DynamicToFixedShape", "dim_param": [ "batch_size" ], "dim_value": [ 1 ] }, "converter": { "type": "QNNConversion" }, "quantization": { "type": "QNNConversion", "extra_args": "--input_list <input_list.txt>" }, - "build_model_lib": { "type": "QNNModelLibGenerator", "lib_targets": "x86_64-linux-clang" } + "build_model_lib": { "type": "QNNModelLibGenerator", "lib_targets": "x86_64-linux-clang" }, + "qnn_context_binary": { "type": "QNNContextBinaryGenerator", "backend": "libQnnHtp.so" } }, - "pass_flows": [ [ "converter", "build_model_lib" ], [ "quantization", "build_model_lib" ] ], "log_severity_level": 0, "host": "local_system", "target": "local_system", diff --git a/examples/phi2/phi2.py b/examples/phi2/phi2.py index 1f9370a6a..bd2ef06d7 100644 --- a/examples/phi2/phi2.py +++ b/examples/phi2/phi2.py @@ -186,25 +186,23 @@ def main(raw_args=None): legacy_optimization_setting(template_json) # add pass flows - pass_flows = [[]] + used_passes = set() if args.finetune_method: - pass_flows[0].append(args.finetune_method) + used_passes.add(args.finetune_method) # torch fine tuning does not require execution provider, just set it to CUDAExecutionProvider update_accelerator(template_json, "gpu") if args.slicegpt: - pass_flows[0].extend(SUPPORTED_WORKFLOWS["slicegpt"][0]) +
used_passes.update(SUPPORTED_WORKFLOWS["slicegpt"][0]) update_accelerator(template_json, "gpu") del template_json["input_model"]["io_config"] if model_type: - pass_flows[0].extend(SUPPORTED_WORKFLOWS[model_type][0]) - template_json["pass_flows"] = pass_flows + used_passes.update(SUPPORTED_WORKFLOWS[model_type][0]) if args.optimum_optimization: legacy_optimization_setting(template_json) - for pass_flow in template_json["pass_flows"]: - pass_flow[0] = "optimum_convert" - if "session_params_tuning" in pass_flow: - pass_flow.remove("session_params_tuning") + used_passes.discard("convert") + used_passes.discard("session_params_tuning") + used_passes.add("optimum_convert") if "cuda" in model_type: update_accelerator(template_json, "gpu") @@ -217,7 +215,6 @@ def main(raw_args=None): template_json["evaluate_input_model"] = False del template_json["evaluator"] - used_passes = {pass_name for pass_flow in pass_flows for pass_name in pass_flow} for pass_name in list(template_json["passes"].keys()): if pass_name not in used_passes: del template_json["passes"][pass_name] diff --git a/examples/phi3/phi3.py b/examples/phi3/phi3.py index 16277dd82..4d2e114b7 100644 --- a/examples/phi3/phi3.py +++ b/examples/phi3/phi3.py @@ -191,7 +191,10 @@ def use_passes(template_json, *passes): else: del template_json["data_configs"] - template_json["pass_flows"] = [passes] + for pass_name in set(template_json["passes"].keys()): + if pass_name not in passes: + template_json["passes"].pop(pass_name, None) + return template_json @@ -223,7 +226,7 @@ def generate_config(args): if args.tune_session_params: passes_to_use.append("tune_session_params") - template_json["search_strategy"] = {"execution_order": "joint", "search_algorithm": "exhaustive"} + template_json["search_strategy"] = {"execution_order": "joint", "sampler": "sequential"} template_json["evaluator"] = "common_evaluator" else: del template_json["evaluators"] diff --git a/examples/phi3/phi3_template.json
b/examples/phi3/phi3_template.json index e5508d878..04e7151cf 100644 --- a/examples/phi3/phi3_template.json +++ b/examples/phi3/phi3_template.json @@ -105,7 +105,6 @@ "execution_mode_list": [ 0, 1 ] } }, - "pass_flows": [ [ "" ] ], "cache_dir": "cache", "output_dir": "models", "host": "local_system", diff --git a/examples/resnet/resnet_dynamic_ptq_cpu.json b/examples/resnet/resnet_dynamic_ptq_cpu.json index 17a2c65d6..9351eadf8 100644 --- a/examples/resnet/resnet_dynamic_ptq_cpu.json +++ b/examples/resnet/resnet_dynamic_ptq_cpu.json @@ -56,7 +56,7 @@ "onnx_dynamic_quantization": { "type": "OnnxDynamicQuantization", "weight_type": "QUInt8" }, "session_params_tuning": { "type": "OrtSessionParamsTuning", "data_config": "cifar10_data_config" } }, - "search_strategy": { "execution_order": "joint", "search_algorithm": "exhaustive" }, + "search_strategy": { "execution_order": "joint", "sampler": "sequential" }, "host": "local_system", "target": "local_system", "evaluator": "common_evaluator", diff --git a/examples/resnet/resnet_ptq_cpu.json b/examples/resnet/resnet_ptq_cpu.json index 1c0922d3e..fdfaa07d4 100644 --- a/examples/resnet/resnet_ptq_cpu.json +++ b/examples/resnet/resnet_ptq_cpu.json @@ -72,7 +72,7 @@ }, "session_params_tuning": { "type": "OrtSessionParamsTuning", "data_config": "cifar10_data_config" } }, - "search_strategy": { "execution_order": "joint", "search_algorithm": "exhaustive" }, + "search_strategy": { "execution_order": "joint", "sampler": "sequential" }, "host": "local_system", "target": "local_system", "evaluator": "common_evaluator", diff --git a/examples/resnet/resnet_ptq_cpu_aml_dataset.json b/examples/resnet/resnet_ptq_cpu_aml_dataset.json index 1ee41bb03..7a05b4f38 100644 --- a/examples/resnet/resnet_ptq_cpu_aml_dataset.json +++ b/examples/resnet/resnet_ptq_cpu_aml_dataset.json @@ -81,7 +81,7 @@ }, "session_params_tuning": { "type": "OrtSessionParamsTuning", "data_config": "cifar10_data_config" } }, - "search_strategy": { "execution_order": 
"joint", "search_algorithm": "exhaustive" }, + "search_strategy": { "execution_order": "joint", "sampler": "sequential" }, "evaluator": "common_evaluator", "cache_dir": "cache", "output_dir": "models/resnet_ptq_cpu" diff --git a/examples/resnet/resnet_static_ptq_cpu.json b/examples/resnet/resnet_static_ptq_cpu.json index 65e7b2950..771d8e5f4 100644 --- a/examples/resnet/resnet_static_ptq_cpu.json +++ b/examples/resnet/resnet_static_ptq_cpu.json @@ -65,7 +65,7 @@ "data_config": "session_params_tuning_data_config" } }, - "search_strategy": { "execution_order": "joint", "search_algorithm": "exhaustive" }, + "search_strategy": { "execution_order": "joint", "sampler": "sequential" }, "host": "local_system", "target": "local_system", "evaluator": "common_evaluator", diff --git a/examples/stable_diffusion/config_safety_checker.json b/examples/stable_diffusion/config_safety_checker.json index fc36d6c8c..33bb6fcd4 100644 --- a/examples/stable_diffusion/config_safety_checker.json +++ b/examples/stable_diffusion/config_safety_checker.json @@ -88,7 +88,6 @@ "keep_io_types": false } }, - "pass_flows": [ [ "convert", "optimize" ] ], "log_severity_level": 0, "evaluator": "common_evaluator", "evaluate_input_model": false, diff --git a/examples/stable_diffusion/config_text_encoder.json b/examples/stable_diffusion/config_text_encoder.json index 3747e4f80..baa6de7d3 100644 --- a/examples/stable_diffusion/config_text_encoder.json +++ b/examples/stable_diffusion/config_text_encoder.json @@ -85,7 +85,6 @@ "keep_io_types": false } }, - "pass_flows": [ [ "convert", "optimize" ] ], "log_severity_level": 0, "evaluator": "common_evaluator", "evaluate_input_model": false, diff --git a/examples/stable_diffusion/config_unet.json b/examples/stable_diffusion/config_unet.json index dfad5e88c..2ee455b9b 100644 --- a/examples/stable_diffusion/config_unet.json +++ b/examples/stable_diffusion/config_unet.json @@ -100,7 +100,6 @@ "keep_io_types": false } }, - "pass_flows": [ [ "convert", "optimize" ] 
], "log_severity_level": 0, "evaluator": "common_evaluator", "evaluate_input_model": false, diff --git a/examples/stable_diffusion/config_vae_decoder.json b/examples/stable_diffusion/config_vae_decoder.json index 362f49cb9..5288528c9 100644 --- a/examples/stable_diffusion/config_vae_decoder.json +++ b/examples/stable_diffusion/config_vae_decoder.json @@ -92,7 +92,6 @@ "keep_io_types": false } }, - "pass_flows": [ [ "convert", "optimize" ] ], "log_severity_level": 0, "evaluator": "common_evaluator", "evaluate_input_model": false, diff --git a/examples/stable_diffusion/config_vae_encoder.json b/examples/stable_diffusion/config_vae_encoder.json index 61e46d298..a038ecdcb 100644 --- a/examples/stable_diffusion/config_vae_encoder.json +++ b/examples/stable_diffusion/config_vae_encoder.json @@ -87,7 +87,6 @@ "keep_io_types": false } }, - "pass_flows": [ [ "convert", "optimize" ] ], "log_severity_level": 0, "evaluator": "common_evaluator", "evaluate_input_model": false, diff --git a/examples/stable_diffusion/sd_utils/ort.py b/examples/stable_diffusion/sd_utils/ort.py index 01eab8923..57e90504f 100644 --- a/examples/stable_diffusion/sd_utils/ort.py +++ b/examples/stable_diffusion/sd_utils/ort.py @@ -23,7 +23,10 @@ def update_cuda_config(config_cuda: Dict): if version.parse(OrtVersion) < version.parse("1.17.0"): # disable skip_group_norm fusion since there is a shape inference bug which leads to invalid models config_cuda["passes"]["optimize_cuda"]["optimization_options"] = {"enable_skip_group_norm": False} - config_cuda["pass_flows"] = [["convert", "optimize_cuda"]] + used_passes = {"convert", "optimize_cuda"} + for pass_name in set(config_cuda["passes"].keys()): + if pass_name not in used_passes: + config_cuda["passes"].pop(pass_name, None) config_cuda["systems"]["local_system"]["accelerators"][0]["execution_providers"] = ["CUDAExecutionProvider"] return config_cuda diff --git a/examples/stable_diffusion/sd_utils/ov.py b/examples/stable_diffusion/sd_utils/ov.py index 
95905227f..5c17e228c 100644 --- a/examples/stable_diffusion/sd_utils/ov.py +++ b/examples/stable_diffusion/sd_utils/ov.py @@ -442,7 +442,7 @@ def get_timesteps(self, num_inference_steps: int, strength: float): def update_ov_config(config: Dict): - config["pass_flows"] = [["ov_convert"]] + config["passes"] = {"ov_convert": config["passes"]["ov_convert"]} config["search_strategy"] = False config["systems"]["local_system"]["accelerators"][0]["execution_providers"] = ["CPUExecutionProvider"] del config["evaluators"] diff --git a/examples/test/azureml/test_bert_ptq_cpu_aml.py b/examples/test/azureml/test_bert_ptq_cpu_aml.py index 0e9384ebf..f01378d61 100644 --- a/examples/test/azureml/test_bert_ptq_cpu_aml.py +++ b/examples/test/azureml/test_bert_ptq_cpu_aml.py @@ -27,7 +27,7 @@ def setup(): ], ) def test_bert(olive_test_knob): - # olive_config: (config_json_path, search_algorithm, execution_order, system) + # olive_config: (config_json_path, sampler, execution_order, system) # bert_ptq_cpu.json: use huggingface model id # bert_ptq_cpu_aml.json: use aml model path from olive.workflows import run as olive_run diff --git a/examples/test/azureml/test_llama2.py b/examples/test/azureml/test_llama2.py index 85da88e71..1e5c17fa2 100644 --- a/examples/test/azureml/test_llama2.py +++ b/examples/test/azureml/test_llama2.py @@ -22,15 +22,15 @@ def setup(): os.chdir(get_example_dir("llama2")) -@pytest.mark.parametrize("search_algorithm", [False]) +@pytest.mark.parametrize("sampler", [False]) @pytest.mark.parametrize("execution_order", [None]) @pytest.mark.parametrize("system", ["local_system"]) @pytest.mark.parametrize("cache_config", [None, {"account_name": account_name, "container_name": container_name}]) @pytest.mark.parametrize("olive_json", ["llama2_qlora.json"]) -def test_llama2(search_algorithm, execution_order, system, cache_config, olive_json): +def test_llama2(sampler, execution_order, system, cache_config, olive_json): from olive.workflows import run as olive_run - 
olive_config = patch_config(olive_json, search_algorithm, execution_order, system, is_gpu=False, hf_token=True) + olive_config = patch_config(olive_json, sampler, execution_order, system, is_gpu=False, hf_token=True) # reduce qlora steps for faster test olive_config["passes"]["f"]["training_args"]["max_steps"] = 5 diff --git a/examples/test/azureml/test_resnet_ptq_cpu_aml.py b/examples/test/azureml/test_resnet_ptq_cpu_aml.py index 8550a0e24..402eff8d6 100644 --- a/examples/test/azureml/test_resnet_ptq_cpu_aml.py +++ b/examples/test/azureml/test_resnet_ptq_cpu_aml.py @@ -23,7 +23,7 @@ def setup(): retry_func(run_subprocess, kwargs={"cmd": "python prepare_model_data.py", "check": True}) -@pytest.mark.parametrize("search_algorithm", ["random"]) +@pytest.mark.parametrize("sampler", ["random"]) @pytest.mark.parametrize("execution_order", ["pass-by-pass"]) @pytest.mark.parametrize("system", ["aml_system"]) @pytest.mark.parametrize( @@ -40,10 +40,10 @@ def setup(): version.parse(OrtVersion) == version.parse("1.16.0"), reason="resnet is not supported in ORT 1.16.0 caused by https://github.com/microsoft/onnxruntime/issues/17627", ) -def test_resnet(search_algorithm, execution_order, system, olive_json): +def test_resnet(sampler, execution_order, system, olive_json): from olive.workflows import run as olive_run - olive_config = patch_config(olive_json, search_algorithm, execution_order, system) + olive_config = patch_config(olive_json, sampler, execution_order, system) footprint = olive_run(olive_config, tempdir=os.environ.get("OLIVE_TEMPDIR", None)) check_output(footprint) diff --git a/examples/test/local/test_bert_cuda_gpu.py b/examples/test/local/test_bert_cuda_gpu.py index a5c6b0d82..19d088ebb 100644 --- a/examples/test/local/test_bert_cuda_gpu.py +++ b/examples/test/local/test_bert_cuda_gpu.py @@ -16,14 +16,14 @@ def setup(): @pytest.mark.skip(reason="Disable failing tests") -@pytest.mark.parametrize("search_algorithm", ["tpe"]) +@pytest.mark.parametrize("sampler", 
["tpe"]) @pytest.mark.parametrize("execution_order", ["joint", "pass-by-pass"]) @pytest.mark.parametrize("system", ["local_system"]) @pytest.mark.parametrize("olive_json", ["bert_cuda_gpu.json"]) -def test_bert(search_algorithm, execution_order, system, olive_json): +def test_bert(sampler, execution_order, system, olive_json): from olive.workflows import run as olive_run - olive_config = patch_config(olive_json, search_algorithm, execution_order, system, is_gpu=True) + olive_config = patch_config(olive_json, sampler, execution_order, system, is_gpu=True) olive_config["passes"]["session_params_tuning"]["enable_cuda_graph"] = False footprint = olive_run(olive_config, tempdir=os.environ.get("OLIVE_TEMPDIR", None)) diff --git a/examples/test/local/test_bert_ptq_cpu.py b/examples/test/local/test_bert_ptq_cpu.py index ee2975c44..9c2e9bdab 100644 --- a/examples/test/local/test_bert_ptq_cpu.py +++ b/examples/test/local/test_bert_ptq_cpu.py @@ -15,14 +15,14 @@ def setup(): os.chdir(get_example_dir("bert")) -@pytest.mark.parametrize("search_algorithm", ["tpe"]) +@pytest.mark.parametrize("sampler", ["tpe"]) @pytest.mark.parametrize("execution_order", ["joint"]) @pytest.mark.parametrize("system", ["local_system"]) @pytest.mark.parametrize("olive_json", ["bert_ptq_cpu.json"]) -def test_bert(search_algorithm, execution_order, system, olive_json): +def test_bert(sampler, execution_order, system, olive_json): from olive.workflows import run as olive_run - olive_config = patch_config(olive_json, search_algorithm, execution_order, system) + olive_config = patch_config(olive_json, sampler, execution_order, system) footprint = olive_run(olive_config, tempdir=os.environ.get("OLIVE_TEMPDIR", None)) check_output(footprint) diff --git a/examples/test/local/test_bert_ptq_cpu_docker.py b/examples/test/local/test_bert_ptq_cpu_docker.py index 1f812cd9c..27958daab 100644 --- a/examples/test/local/test_bert_ptq_cpu_docker.py +++ b/examples/test/local/test_bert_ptq_cpu_docker.py @@ -18,17 +18,17 
@@ def setup(): os.chdir(get_example_dir("bert")) -@pytest.mark.parametrize("search_algorithm", ["tpe"]) +@pytest.mark.parametrize("sampler", ["tpe"]) @pytest.mark.parametrize("execution_order", ["joint"]) @pytest.mark.parametrize("system", ["docker_system"]) @pytest.mark.parametrize("olive_json", ["bert_ptq_cpu.json"]) -def test_bert(search_algorithm, execution_order, system, olive_json): +def test_bert(sampler, execution_order, system, olive_json): if system == "docker_system" and platform.system() == OS.WINDOWS: pytest.skip("Skip Linux containers on Windows host test case.") from olive.workflows import run as olive_run - olive_config = patch_config(olive_json, search_algorithm, execution_order, system) + olive_config = patch_config(olive_json, sampler, execution_order, system) footprint = olive_run(olive_config, tempdir=os.environ.get("OLIVE_TEMPDIR", None)) check_output(footprint) diff --git a/examples/test/local/test_llama2.py b/examples/test/local/test_llama2.py index 2ec2b1691..ee47a475d 100644 --- a/examples/test/local/test_llama2.py +++ b/examples/test/local/test_llama2.py @@ -16,16 +16,16 @@ def setup(): os.chdir(get_example_dir("llama2")) -@pytest.mark.parametrize("search_algorithm", [False]) +@pytest.mark.parametrize("sampler", [False]) @pytest.mark.parametrize("execution_order", [None]) @pytest.mark.parametrize("system", ["local_system"]) @pytest.mark.parametrize("olive_json", ["llama2_qlora.json"]) -def test_llama2(search_algorithm, execution_order, system, olive_json): +def test_llama2(sampler, execution_order, system, olive_json): from onnxruntime import __version__ as ort_version from olive.workflows import run as olive_run - olive_config = patch_config(olive_json, search_algorithm, execution_order, system) + olive_config = patch_config(olive_json, sampler, execution_order, system) # replace meta-llama with open-llama version of the model # doesn't require login diff --git a/examples/test/local/test_resnet_ptq_cpu.py 
b/examples/test/local/test_resnet_ptq_cpu.py index 0f7528167..5bccfe731 100644 --- a/examples/test/local/test_resnet_ptq_cpu.py +++ b/examples/test/local/test_resnet_ptq_cpu.py @@ -23,7 +23,7 @@ def setup(): retry_func(run_subprocess, kwargs={"cmd": "python prepare_model_data.py", "check": True}) -@pytest.mark.parametrize("search_algorithm", ["random"]) +@pytest.mark.parametrize("sampler", ["random"]) @pytest.mark.parametrize("execution_order", ["pass-by-pass"]) @pytest.mark.parametrize("system", ["local_system"]) @pytest.mark.parametrize( @@ -40,10 +40,10 @@ def setup(): version.parse(OrtVersion) == version.parse("1.16.0"), reason="resnet is not supported in ORT 1.16.0 caused by https://github.com/microsoft/onnxruntime/issues/17627", ) -def test_resnet(search_algorithm, execution_order, system, olive_json): +def test_resnet(sampler, execution_order, system, olive_json): from olive.workflows import run as olive_run - olive_config = patch_config(olive_json, search_algorithm, execution_order, system) + olive_config = patch_config(olive_json, sampler, execution_order, system) footprint = olive_run(olive_config, tempdir=os.environ.get("OLIVE_TEMPDIR", None)) check_output(footprint) diff --git a/examples/test/local/test_resnet_qat.py b/examples/test/local/test_resnet_qat.py index afdfc9efc..8c745876e 100644 --- a/examples/test/local/test_resnet_qat.py +++ b/examples/test/local/test_resnet_qat.py @@ -21,16 +21,16 @@ def setup(): retry_func(run_subprocess, kwargs={"cmd": "python prepare_model_data.py", "check": True}) -@pytest.mark.parametrize("search_algorithm", ["random"]) +@pytest.mark.parametrize("sampler", ["random"]) @pytest.mark.parametrize("execution_order", ["pass-by-pass"]) @pytest.mark.parametrize("system", ["local_system"]) @pytest.mark.parametrize( "olive_json", ["resnet_qat_default_train_loop_cpu.json", "resnet_qat_lightning_module_cpu.json"] ) -def test_resnet(search_algorithm, execution_order, system, olive_json): +def test_resnet(sampler, execution_order, 
system, olive_json): from olive.workflows import run as olive_run - olive_config = patch_config(olive_json, search_algorithm, execution_order, system) + olive_config = patch_config(olive_json, sampler, execution_order, system) footprint = olive_run(olive_config, tempdir=os.environ.get("OLIVE_TEMPDIR", None)) check_output(footprint) diff --git a/examples/test/utils.py b/examples/test/utils.py index 77e67e247..462e07705 100644 --- a/examples/test/utils.py +++ b/examples/test/utils.py @@ -35,7 +35,7 @@ def assert_metrics(footprints): def patch_config( config_json_path: str, - search_algorithm: str, + sampler: str, execution_order: str, system: str, is_gpu: bool = False, @@ -50,15 +50,15 @@ def patch_config( olive_config["clean_cache"] = True # update search strategy - if not search_algorithm: + if not sampler: olive_config["search_strategy"] = False else: olive_config["search_strategy"] = { - "search_algorithm": search_algorithm, + "sampler": sampler, "execution_order": execution_order, } - if search_algorithm in ("random", "tpe"): - olive_config["search_strategy"].update({"num_samples": 3, "seed": 0}) + if sampler in ("random", "tpe"): + olive_config["search_strategy"].update({"max_samples": 3, "seed": 0}) update_azureml_config(olive_config) if system == "aml_system": @@ -76,7 +76,7 @@ def patch_config( # as our docker image is big, we need to reduce the agent size to avoid timeout # for the docker system test, we skip to search for transformers optimization as # it is tested in other olive system tests - olive_config["search_strategy"]["num_samples"] = 2 + olive_config["search_strategy"]["max_samples"] = 2 return olive_config diff --git a/olive/cache.py b/olive/cache.py index 23384adcd..0029c52b8 100644 --- a/olive/cache.py +++ b/olive/cache.py @@ -280,7 +280,7 @@ def get_output_model_id( input_model_id: str, accelerator_spec: "AcceleratorSpec" = None, ): - run_json = self.get_run_json(pass_name, pass_config, input_model_id, accelerator_spec) + run_json = 
self.get_run_json(pass_name.lower(), pass_config, input_model_id, accelerator_spec) return hash_dict(run_json)[:8] def get_cache_dir(self) -> Path: diff --git a/olive/cli/auto_opt.py b/olive/cli/auto_opt.py index 2c358fdc3..11213b50e 100644 --- a/olive/cli/auto_opt.py +++ b/olive/cli/auto_opt.py @@ -415,7 +415,7 @@ def _get_passes_config(self, config: Dict[str, Any], olive_config: OlivePackageC TEMPLATE = { "input_model": {"type": "HfModel"}, "auto_optimizer_config": {}, - "search_strategy": {"execution_order": "joint", "search_algorithm": "tpe", "num_samples": 5, "seed": 0}, + "search_strategy": {"execution_order": "joint", "sampler": "tpe", "max_samples": 5, "seed": 0}, "systems": { "local_system": { "type": "LocalSystem", diff --git a/olive/cli/base.py b/olive/cli/base.py index aec4cfd70..7f6bd1fc9 100644 --- a/olive/cli/base.py +++ b/olive/cli/base.py @@ -610,7 +610,7 @@ def update_search_options(args, config): "search_strategy", { "execution_order": "joint", - "search_algorithm": args.enable_search, + "sampler": args.enable_search, "seed": args.seed, }, ), diff --git a/olive/cli/quantize.py b/olive/cli/quantize.py index ca6b8ac62..70eb39eb8 100644 --- a/olive/cli/quantize.py +++ b/olive/cli/quantize.py @@ -8,6 +8,7 @@ import tempfile from argparse import ArgumentParser +from collections import OrderedDict from copy import deepcopy from typing import Any, Dict @@ -121,7 +122,6 @@ def _get_run_config(self, tempdir: str) -> Dict[str, Any]: self.args.implementation = [self.args.implementation] to_replace = [ - ("pass_flows", [self.args.implementation]), (("passes", "awq", "w_bit"), precision), (("passes", "gptq", "bits"), precision), (("passes", "bnb4", "quant_type"), precision), @@ -137,6 +137,7 @@ def _get_run_config(self, tempdir: str) -> Dict[str, Any]: if v is not None: set_nested_dict_value(config, k, v) + config["passes"] = OrderedDict([(k, v) for k, v in config["passes"].items() if k in self.args.implementation]) return config def run(self): @@ -184,7 
+185,6 @@ def run(self): # "inc_static": {"type": "IncStaticQuantization", "data_config": "default_data_config"}, # "vitis": {"type": "VitisAIQuantization", "data_config": "default_data_config"}, }, - "pass_flows": [], "output_dir": "models", "host": "local_system", "target": "local_system", diff --git a/olive/engine/config.py b/olive/engine/config.py index a64d5d189..db820b952 100644 --- a/olive/engine/config.py +++ b/olive/engine/config.py @@ -5,9 +5,10 @@ from typing import Union from olive.common.config_utils import ConfigBase -from olive.common.pydantic_v1 import Extra +from olive.common.pydantic_v1 import Extra, Field from olive.evaluator.olive_evaluator import OliveEvaluatorConfig -from olive.strategy.search_strategy import SearchStrategyConfig +from olive.passes.pass_config import AbstractPassConfig +from olive.search.search_strategy import SearchStrategyConfig from olive.systems.system_config import SystemConfig # pass search-point was pruned due to failed run @@ -25,3 +26,38 @@ class EngineConfig(ConfigBase, extra=Extra.forbid): evaluator: OliveEvaluatorConfig = None plot_pareto_frontier: bool = False no_artifacts: bool = False + + +class RunPassConfig(AbstractPassConfig): + """Pass configuration for Olive workflow. + + This is the configuration for a single pass in Olive workflow. It includes configurations for pass type, config, + etc. + + Example: + .. code-block:: json + + { + "type": "OlivePass", + "config": { + "param1": "value1", + "param2": "value2" + } + } + + """ + + host: Union[SystemConfig, str] = Field( + None, + description=( + "Host system for the pass. If it is a string, must refer to a system config under `systems` section. If not" + " provided, use the engine's host system." + ), + ) + evaluator: Union[OliveEvaluatorConfig, str] = Field( + None, + description=( + "Evaluator for the pass. If it is a string, must refer to an evaluator config under `evaluators` section." + " If not provided, use the engine's evaluator." 
+ ), + ) diff --git a/olive/engine/engine.py b/olive/engine/engine.py index 06ba28137..e6ead573c 100644 --- a/olive/engine/engine.py +++ b/olive/engine/engine.py @@ -8,14 +8,15 @@ import time from collections import OrderedDict, defaultdict from contextlib import contextmanager +from copy import deepcopy from datetime import datetime from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union from olive.cache import CacheConfig, OliveCache from olive.common.config_utils import validate_config from olive.common.constants import DEFAULT_WORKFLOW_ID, LOCAL_INPUT_MODEL_ID -from olive.engine.config import FAILED_CONFIG, INVALID_CONFIG, PRUNED_CONFIGS +from olive.engine.config import FAILED_CONFIG, INVALID_CONFIG, PRUNED_CONFIGS, RunPassConfig from olive.engine.footprint import Footprint, FootprintNode, FootprintNodeMetric, get_best_candidate_node from olive.engine.packaging.packaging_generator import generate_output_artifacts from olive.evaluator.metric import Metric @@ -26,8 +27,8 @@ from olive.logging import enable_filelog from olive.model import ModelConfig from olive.package_config import OlivePackageConfig -from olive.strategy.search_parameter import SearchParameter -from olive.strategy.search_strategy import SearchStrategy, SearchStrategyConfig +from olive.search.search_sample import SearchSample +from olive.search.search_strategy import SearchStrategy, SearchStrategyConfig from olive.systems.common import SystemType from olive.systems.system_config import SystemConfig from olive.systems.utils import create_managed_system_with_cache @@ -35,7 +36,7 @@ if TYPE_CHECKING: from olive.engine.packaging.packaging_config import PackagingConfig from olive.passes.olive_pass import Pass - from olive.systems.olive_system import OliveSystem + from olive.search.search_parameter import SearchParameter logger = logging.getLogger(__name__) @@ -84,10 +85,9 @@ def 
__init__( self.skip_saving_artifacts = no_artifacts self.azureml_client_config = azureml_client_config - self.pass_run_configs: Dict[str, Dict[str, Any]] = OrderedDict() - self.pass_flows: List[List[str]] = [] - self.search_spaces: List[List[Tuple[str, Dict[str, SearchParameter]]]] = [] - self.footprints = defaultdict(Footprint) + self.input_passes_configs: Dict[str, List[RunPassConfig]] = OrderedDict() + self.computed_passes_configs: Dict[str, RunPassConfig] = OrderedDict() + self.footprints: Dict[AcceleratorSpec, Footprint] = defaultdict(Footprint) self._initialized = False @@ -108,18 +108,17 @@ def initialize(self, log_to_file: bool = False, log_severity_level: int = 1): if self.evaluator_config: self.evaluator_config = self.cache.prepare_resources_for_local(self.evaluator_config) - for pass_run_config in self.pass_run_configs.values(): - if pass_run_config["evaluator"]: - pass_run_config["evaluator"] = self.cache.prepare_resources_for_local(pass_run_config["evaluator"]) + for passes_configs in self.input_passes_configs.values(): + for pass_config in passes_configs: + if pass_config.evaluator: + pass_config.evaluator = self.cache.prepare_resources_for_local(pass_config.evaluator) - for pass_run_config in self.pass_run_configs.values(): - host_type = pass_run_config["host"].system_type if pass_run_config["host"] else self.host_config.type - if host_type != SystemType.AzureML: - pass_run_config["input_config"] = self.cache.prepare_resources_for_local( - pass_run_config["input_config"] - ) + for passes_configs in self.input_passes_configs.values(): + for pass_config in passes_configs: + host_type = pass_config.host.system_type if pass_config.host else self.host_config.type + if host_type != SystemType.AzureML: + pass_config.config = self.cache.prepare_resources_for_local(pass_config.config) - self.set_pass_flows(self.pass_flows) self._initialized = True def register( @@ -127,12 +126,12 @@ def register( pass_type: Union[Type["Pass"], str], config: Dict[str, Any] = 
None, name: str = None, - host: "OliveSystem" = None, + host: SystemConfig = None, evaluator_config: OliveEvaluatorConfig = None, ): """Register a pass configuration so that it could be instantiated and executed later.""" if name: - assert name not in self.pass_run_configs, f"Pass with name {name} already registered" + assert name not in self.input_passes_configs, f"Pass with name {name} already registered" else: idx = 0 while True: @@ -140,26 +139,22 @@ def register( if idx > 0: name = f"{name}_{idx}" idx += 1 - if name not in self.pass_run_configs: + if name not in self.input_passes_configs: break pass_type_name = pass_type if isinstance(pass_type, str) else pass_type.__name__ - logger.debug("Registering pass %s", pass_type_name) - self.pass_run_configs[name] = { - "type": pass_type_name, - "input_config": config or {}, - "host": host, - "evaluator": evaluator_config, - } - - def set_pass_flows(self, pass_flows: List[List[str]] = None): - """Construct pass flows from a list of pass names. - - Args: - pass_flows: a list of pass names, each pass name is a string. + logger.debug("Registering pass %s:%s", name, pass_type_name) + self.input_passes_configs[name] = [ + RunPassConfig( + type=pass_type_name, + config=config or {}, + host=host, + evaluator=evaluator_config, + ) + ] - """ - self.pass_flows = pass_flows or [list(self.pass_run_configs.keys())] + def set_input_passes_configs(self, pass_configs: Dict[str, List[RunPassConfig]]): + self.input_passes_configs = pass_configs def run( self, @@ -248,7 +243,7 @@ def run( run_history = self.footprints[accelerator_spec].summarize_run_history() self.dump_run_history(run_history, output_subdirs[accelerator_spec] / "run_history.txt") - if packaging_config and self.pass_run_configs: + if packaging_config and self.input_passes_configs: # TODO(trajep): should we support packaging pytorch model? 
logger.info("Package top ranked %d models as artifacts", sum(len(f.nodes) for f in outputs.values())) generate_output_artifacts( @@ -264,7 +259,7 @@ def run( # TODO(team): refactor output structure # Do not change condition order. For no search, values of outputs are MetricResult # Consolidate the output structure for search and no search mode - if outputs and self.pass_run_configs and not next(iter(outputs.values())).check_empty_nodes(): + if outputs and self.input_passes_configs and not next(iter(outputs.values())).check_empty_nodes(): best_node: FootprintNode = get_best_candidate_node(outputs, self.footprints) self.cache.save_model(model_id=best_node.model_id, output_dir=output_dir, overwrite=True) if len(accelerator_output_dir_list) > 1 and self.skip_saving_artifacts: @@ -280,12 +275,6 @@ def run_accelerator( evaluate_input_model: bool, accelerator_spec: "AcceleratorSpec", ): - # Setup pass configs - self._setup_pass_configs(accelerator_spec) - - # generate search space - self._setup_search_spaces(accelerator_spec) - # hash the input model input_model_id = input_model_config.get_model_id() if input_model_id == LOCAL_INPUT_MODEL_ID and self.cache.enable_shared_cache: @@ -300,32 +289,29 @@ def run_accelerator( try: if evaluate_input_model and not self.evaluator_config: logger.debug("evaluate_input_model is True but no evaluator provided. 
Skipping input model evaluation.") + elif evaluate_input_model: results = self._evaluate_model( input_model_config, input_model_id, self.evaluator_config, accelerator_spec ) logger.info("Input model evaluation results: %s", results) + if not self.skip_saving_artifacts: results_path = output_dir / "input_model_metrics.json" with results_path.open("w") as f: json.dump(results.to_json(), f, indent=4) logger.info("Saved evaluation results of input model to %s", results_path) - if not self.pass_run_configs: + if not self.input_passes_configs: logger.debug("No passes registered, return input model evaluation results.") return results if self.search_strategy: logger.debug("Running Olive in search mode ...") - output_footprint = self.run_search( - input_model_config, - input_model_id, - accelerator_spec, - output_dir, - ) + output_footprint = self._run_search(input_model_config, input_model_id, accelerator_spec, output_dir) else: logger.debug("Running Olive in no-search mode ...") - output_footprint = self.run_no_search(input_model_config, input_model_id, accelerator_spec, output_dir) + output_footprint = self._run_no_search(input_model_config, input_model_id, accelerator_spec, output_dir) except EXCEPTIONS_TO_RAISE: raise except Exception: @@ -343,29 +329,16 @@ def get_host_device(self): # for host device, we will always use the first accelerator device return self.host_config.config.accelerators[0].device if self.host_config.config.accelerators else None - def _setup_pass_configs(self, accelerator_spec: "AcceleratorSpec"): - disable_search = self.search_strategy is None - for pass_run_config in self.pass_run_configs.values(): - pass_cls: Type[Pass] = self.olive_config.import_pass_module(pass_run_config["type"]) - pass_run_config["config"] = pass_cls.generate_config( - accelerator_spec, pass_run_config["input_config"], disable_search - ) + def _compute_no_search_pass_configs(self, accelerator_spec: "AcceleratorSpec"): + self.computed_passes_configs.clear() + for name, 
passes_configs in self.input_passes_configs.items(): + pass_config = validate_config(deepcopy(passes_configs[0].dict()), RunPassConfig) - def _setup_search_spaces(self, accelerator_spec: "AcceleratorSpec"): - self.search_spaces.clear() - if self.search_strategy is None: - return + pass_cls: Type[Pass] = self.olive_config.import_pass_module(pass_config.type) + pass_config.config = pass_cls.generate_config(accelerator_spec, pass_config.config, {}, True) + self.computed_passes_configs[name] = pass_config - for pass_flow in self.pass_flows: - pass_search_spaces: List[Tuple[str, Dict[str, SearchParameter]]] = [] - for pass_name in pass_flow: - pass_run_config = self.pass_run_configs[pass_name] - pass_search_spaces.append( - (pass_name, {k: v for k, v in pass_run_config["config"].items() if isinstance(v, SearchParameter)}) - ) - self.search_spaces.append(pass_search_spaces) - - def run_no_search( + def _run_no_search( self, input_model_config: ModelConfig, input_model_id: str, @@ -375,46 +348,91 @@ def run_no_search( """Run all the registered Olive pass flows in no-search mode.""" output_model_dir = Path(output_dir) - output_model_ids = [] - for pass_flow in self.pass_flows: - # search point is empty since there is no search - passes_to_run = [(pass_name, {}) for pass_name in pass_flow] - - # run all the passes in the pass flow - logger.debug("Running %s with no search ...", pass_flow) - should_prune, signal, model_ids = self._run_passes( - passes_to_run, - input_model_config, - input_model_id, - accelerator_spec, - ) + # Compute pass configs + self._compute_no_search_pass_configs(accelerator_spec) - if should_prune: - failed_pass = pass_flow[len(model_ids)] - logger.warning( - "Flow %s is pruned due to failed or invalid config for pass '%s'", pass_flow, failed_pass - ) - continue + # run all the passes in the pass flow + pass_flow = list(self.computed_passes_configs.keys()) + logger.debug("Running %s with no search ...", pass_flow) + should_prune, signal, model_ids = 
self._run_passes(input_model_config, input_model_id, accelerator_spec) - # use output_model_dir if there is only one pass flow - # else output_model_dir/pass_flow - flow_output_dir = output_model_dir / "-".join(pass_flow) if len(self.pass_flows) > 1 else output_model_dir - flow_output_dir.mkdir(parents=True, exist_ok=True) + if should_prune: + failed_pass = pass_flow[len(model_ids)] + logger.warning("Flow %s is pruned due to failed or invalid config for pass '%s'", pass_flow, failed_pass) + return Footprint() - if signal is not None and not self.skip_saving_artifacts: - results_path = flow_output_dir / "metrics.json" - with open(results_path, "w") as f: - json.dump(signal.to_json(), f, indent=4) - logger.info("Saved evaluation results of output model to %s", results_path) + # use output_model_dir if there is only one pass flow + # else output_model_dir/pass_flow + flow_output_dir = output_model_dir / "-".join(pass_flow) if len(pass_flow) > 1 else output_model_dir + flow_output_dir.mkdir(parents=True, exist_ok=True) - output_model_ids.append(model_ids[-1]) + if signal is not None and not self.skip_saving_artifacts: + results_path = flow_output_dir / "metrics.json" + with open(results_path, "w") as f: + json.dump(signal.to_json(), f, indent=4) + logger.info("Saved evaluation results of output model to %s", results_path) - output_footprints = self.footprints[accelerator_spec].create_footprints_by_model_ids(output_model_ids) + output_footprints = self.footprints[accelerator_spec].create_footprints_by_model_ids([model_ids[-1]]) if not self.skip_saving_artifacts: output_footprints.to_file(output_dir / "output_footprints.json") return output_footprints - def run_search( + def _get_search_space_config(self, accelerator_spec: "AcceleratorSpec"): + space_config: Dict[str, List[Dict[str, SearchParameter]]] = OrderedDict() + for pass_name, passes_configs in self.input_passes_configs.items(): + space_config[pass_name] = pass_params_config = [] + for pass_config in 
passes_configs: + pass_cls = self.olive_config.import_pass_module(pass_config.type) + pass_params_config.append(pass_cls.get_search_space_params(accelerator_spec, pass_config.config, False)) + return space_config + + def _get_search_space_objectives( + self, + input_model_config: ModelConfig, + input_model_id: str, + accelerator_spec: "AcceleratorSpec", + ) -> Dict[str, List[Dict[str, Any]]]: + objectives_by_pass_name: Dict[str, List[Dict[str, Any]]] = {} + objectives_by_evaluator_name: Dict[str, Dict[str, Any]] = {} + for pass_name, passes_configs in self.input_passes_configs.items(): + objectives_by_pass_name[pass_name] = passes_objectives = [] + for pass_config in passes_configs: + evaluator_config = pass_config.evaluator or self.evaluator_config + if evaluator_config.name not in objectives_by_evaluator_name: + objectives_by_evaluator_name[evaluator_config.name] = self.resolve_objectives( + input_model_config, input_model_id, evaluator_config.metrics, accelerator_spec + ) + passes_objectives.append(objectives_by_evaluator_name[evaluator_config.name]) + + accelerator_objectives: Dict[str, Any] = {} + for objectives in objectives_by_evaluator_name.values(): + accelerator_objectives.update(objectives) + self.footprints[accelerator_spec].record_objective_dict(accelerator_objectives) + return objectives_by_pass_name + + def _compute_search_pass_configs(self, accelerator_spec: "AcceleratorSpec", sample: SearchSample): + self.computed_passes_configs.clear() + sample_passes_configs = sample.passes_configs + if not sample_passes_configs: + return + + disable_pass_params_search = not self.search_strategy.config.include_pass_params + for pass_name, passes_configs in self.input_passes_configs.items(): + if pass_name in sample_passes_configs: + sample_pass_config = sample_passes_configs[pass_name] + pass_config = passes_configs[sample_pass_config["index"]] + pass_config = validate_config(deepcopy(pass_config.dict()), RunPassConfig) + + pass_cls = 
self.olive_config.import_pass_module(pass_config.type) + pass_config.config = pass_cls.generate_config( + accelerator_spec, + pass_config.config, + sample_pass_config["params"], + disable_pass_params_search, + ) + self.computed_passes_configs[pass_name] = pass_config + + def _run_search( self, input_model_config: ModelConfig, input_model_id: str, @@ -422,55 +440,34 @@ def run_search( output_dir: Path, ): """Run all the registered Olive passes in search model where search strategy is not None.""" - # get objective_dict - evaluator_config = self.evaluator_for_pass(list(self.pass_run_configs.keys())[-1]) - - if evaluator_config is None: - raise ValueError("No evaluator provided for the last pass") - else: - objective_dict = self.resolve_objectives( - input_model_config, input_model_id, evaluator_config.metrics, accelerator_spec - ) - self.footprints[accelerator_spec].record_objective_dict(objective_dict) - # initialize the search strategy - self.search_strategy.initialize(self.search_spaces, input_model_id, objective_dict) - output_model_num = self.search_strategy.get_output_model_num() - - # record start time - start_time = time.time() - iter_num = 0 - while True: - iter_num += 1 - - # get the next step - next_step = self.search_strategy.next_step() + search_space_config = self._get_search_space_config(accelerator_spec) + search_space_objectives = self._get_search_space_objectives( + input_model_config, input_model_id, accelerator_spec + ) + self.search_strategy.initialize(search_space_config, input_model_id, search_space_objectives) - # if no more steps, break - if next_step is None: - break + for sample in self.search_strategy: # pylint: disable=not-an-iterable + self._compute_search_pass_configs(accelerator_spec, sample) - # get the model id of the first input model - model_id = next_step["model_id"] - model_config = input_model_config if model_id == input_model_id else self._load_model(model_id) + if self.computed_passes_configs: + # get the model id of the first 
input model + model_id = sample.model_ids[0] + model_config = input_model_config if model_id == input_model_id else self._load_model(model_id) - logger.debug("Step %d with search point %s ...", iter_num, next_step["search_point"]) + logger.info( + "Step %d with search point %s ...", self.search_strategy.iteration_count, sample.search_point + ) - # run all the passes in the step - should_prune, signal, model_ids = self._run_passes( - next_step["passes"], - model_config, - model_id, - accelerator_spec, - ) + # run all the passes in the step + should_prune, signal, model_ids = self._run_passes(model_config, model_id, accelerator_spec) + else: + should_prune, signal, model_ids = True, None, [] # record feedback signal - self.search_strategy.record_feedback_signal(next_step["search_point"], signal, model_ids, should_prune) - - time_diff = time.time() - start_time - self.search_strategy.check_exit_criteria(iter_num, time_diff, signal) + self.search_strategy.record_feedback_signal(sample.search_point.index, signal, model_ids, should_prune) - return self.create_pareto_frontier_footprints(accelerator_spec, output_model_num, output_dir) + return self.create_pareto_frontier_footprints(accelerator_spec, None, output_dir) def create_pareto_frontier_footprints( self, accelerator_spec: "AcceleratorSpec", output_model_num: int, output_dir: Path @@ -598,13 +595,13 @@ def resolve_goals( return resolved_goals - def host_for_pass(self, pass_name: str) -> "OliveSystem": - host: SystemConfig = self.pass_run_configs[pass_name]["host"] - return host or self.host + def host_for_pass(self, pass_name: str) -> "OliveSystem": + host: SystemConfig = self.computed_passes_configs[pass_name].host + return host.create_system() if host else self.host def evaluator_for_pass(self, pass_name: str) -> OliveEvaluatorConfig: """Return evaluator for the given pass.""" - return self.pass_run_configs[pass_name]["evaluator"] or self.evaluator_config + return self.computed_passes_configs[pass_name].evaluator or
self.evaluator_config def _cache_model(self, model_id: str, model: Union[ModelConfig, str], check_object: bool = True): # TODO(trajep): move model/pass run/evaluation cache into footprints @@ -623,7 +620,6 @@ def _load_model(self, model_id: str) -> Union[ModelConfig, str]: def _run_passes( self, - passes: List[Tuple[str, Dict[str, Any]]], model_config: ModelConfig, model_id: str, accelerator_spec: "AcceleratorSpec", @@ -637,10 +633,9 @@ def _run_passes( model_ids = [] pass_name = None - for pass_name, pass_search_point in passes: + for pass_name in self.computed_passes_configs: model_config, model_id = self._run_pass( pass_name, - pass_search_point, model_config, model_id, accelerator_spec, @@ -663,7 +658,7 @@ def _run_passes( else: logger.info("Run model evaluation for the final model...") signal = self._evaluate_model(model_config, model_id, evaluator_config, accelerator_spec) - logger.debug("Signal: %s", signal) + logger.debug("Signal: %s, %s", signal, model_ids) else: signal = None logger.warning("Skipping evaluation as model was pruned") @@ -673,22 +668,21 @@ def _run_passes( def _run_pass( self, pass_name: str, - pass_search_point: Dict[str, Any], input_model_config: ModelConfig, input_model_id: str, accelerator_spec: "AcceleratorSpec", ): """Run a pass on the input model.""" run_start_time = datetime.now().timestamp() - pass_run_config: Dict[str, Any] = self.pass_run_configs[pass_name] - pass_type_name = pass_run_config["type"] - logger.info("Running pass %s:%s %s", pass_name, pass_type_name, pass_search_point) + pass_config: RunPassConfig = self.computed_passes_configs[pass_name] + pass_type_name = pass_config.type + + logger.info("Running pass %s:%s", pass_name, pass_type_name) # check whether the config is valid - pass_cls: Type[Pass] = self.olive_config.import_pass_module(pass_run_config["type"]) - pass_config = pass_cls.config_at_search_point(pass_search_point, accelerator_spec, pass_run_config["config"]) - if not pass_cls.validate_config(pass_config, 
accelerator_spec, self.search_strategy is None): + pass_cls: Type[Pass] = self.olive_config.import_pass_module(pass_config.type) + if not pass_cls.validate_config(pass_config.config, accelerator_spec, self.search_strategy is None): logger.warning("Invalid config, pruned.") logger.debug(pass_config) # no need to record in footprint since there was no run and thus no valid/failed model @@ -697,12 +691,13 @@ def _run_pass( # this helps reusing cached models for different accelerator specs return INVALID_CONFIG, None - p: Pass = pass_cls(accelerator_spec, pass_config, self.get_host_device()) - pass_config = p.serialize_config(pass_config, check_object=True) + p: Pass = pass_cls(accelerator_spec, pass_config.config, self.get_host_device()) + pass_config = p.serialize_config(pass_config.config, check_object=True) output_model_config = None # load run from cache if it exists run_accel = None if p.is_accelerator_agnostic(accelerator_spec) else accelerator_spec + output_model_id = self.cache.get_output_model_id(pass_type_name, pass_config, input_model_id, run_accel) run_cache = self.cache.load_run_from_model_id(output_model_id) if run_cache: @@ -869,6 +864,7 @@ def create_system(config: "SystemConfig", accelerator_spec): target_start_time = time.time() self.target = create_system(self.target_config, accelerator_spec) logger.info("Target system created in %f seconds", time.time() - target_start_time) + if not self.host: host_accelerators = self.host_config.config.accelerators if host_accelerators and host_accelerators[0].execution_providers: diff --git a/olive/evaluator/olive_evaluator.py b/olive/evaluator/olive_evaluator.py index d23b85748..2503503fb 100644 --- a/olive/evaluator/olive_evaluator.py +++ b/olive/evaluator/olive_evaluator.py @@ -1141,6 +1141,7 @@ def evaluate( class OliveEvaluatorConfig(NestedConfig): _nested_field_name = "type_args" + name: str = None type: str = None type_args: Dict = Field(default_factory=dict) diff --git a/olive/passes/olive_pass.py 
b/olive/passes/olive_pass.py index 1f9cbb061..2de00a31b 100644 --- a/olive/passes/olive_pass.py +++ b/olive/passes/olive_pass.py @@ -23,15 +23,14 @@ create_config_class, ) from olive.resource_path import ResourcePath -from olive.strategy.search_parameter import ( +from olive.search.search_parameter import ( Categorical, Conditional, ConditionalDefault, SearchParameter, SpecialParamValue, ) -from olive.strategy.search_space import SearchSpace -from olive.strategy.utils import cyclic_search_space, order_search_parameters +from olive.search.utils import cyclic_search_space, order_search_parameters logger = logging.getLogger(__name__) @@ -110,11 +109,14 @@ def generate_config( cls, accelerator_spec: AcceleratorSpec, config: Optional[Dict[str, Any]] = None, + point: Optional[Dict[str, Any]] = None, disable_search: Optional[bool] = False, ) -> Dict[str, Any]: - """Generate search space for the pass.""" + """Get the configuration for the pass at a specific point in the search space.""" assert accelerator_spec is not None, "Please specify the accelerator spec for the pass" + config = config or {} + point = point or {} # Get the config class with default value or default search value config_class, default_config = cls.get_config_class(accelerator_spec, disable_search) @@ -125,10 +127,36 @@ def generate_config( # Generate the search space by using both default value and default search value and user provided config config = validate_config(config, config_class) - config = cls._resolve_config(config, default_config) fixed_values, search_params = cls._init_fixed_and_search_params(config, default_config) - return {**fixed_values, **search_params} + assert ( + set(point.keys()).intersection(set(search_params.keys())) == point.keys() + ), "Search point is not in the search space." 
+ return {**fixed_values, **search_params, **point} + + @classmethod + def get_search_space_params( + cls, + accelerator_spec: AcceleratorSpec, + config: Optional[Dict[str, Any]] = None, + disable_search: Optional[bool] = False, + ) -> Dict[str, Any]: + """Generate search space for the pass.""" + assert accelerator_spec is not None, "Please specify the accelerator spec for the pass" + config = config or {} + + # Get the config class with default value or default search value + config_class, default_config = cls.get_config_class(accelerator_spec, disable_search) + + if not disable_search: + # Replace user-provided values with Categorical if user intended to search + config = cls._identify_search_values(config, default_config) + + # Generate the search space by using both default value and default search value and user provided config + config = validate_config(config, config_class) + config = cls._resolve_config(config, default_config) + _, search_params = cls._init_fixed_and_search_params(config, default_config) + return search_params @classmethod def _identify_search_values( @@ -184,29 +212,6 @@ def default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConf ), f"{param} ending with data_config must be of type DataConfig." 
return config - @classmethod - def config_at_search_point( - cls, - point: Dict[str, Any], - accelerator_spec: AcceleratorSpec, - config: Optional[Dict[str, Any]] = None, - ) -> Dict[str, Any]: - """Get the configuration for the pass at a specific point in the search space.""" - assert accelerator_spec is not None, "Please specify the accelerator spec for the pass" - - # Get the config class with default search value - config_class, default_config = cls.get_config_class(accelerator_spec) - - # Replace user-provided values with Categorical if user intended to search - config = cls._identify_search_values(config or {}, default_config) - - # Generate the search space by using both default value and default search value and user provided config - config = validate_config(config, config_class) - config = cls._resolve_config(config, default_config) - fixed_values, search_params = cls._init_fixed_and_search_params(config, default_config) - assert set(point.keys()) == set(search_params.keys()), "Search point is not in the search space." - return {**fixed_values, **search_params, **point} - @classmethod def validate_config( cls, @@ -451,8 +456,6 @@ def _init_fixed_and_search_params( value = str(Path(value).resolve()) fixed_params[key] = value assert not cyclic_search_space(search_space), "Search space is cyclic." - # TODO(jambayk): better error message, e.g. which parameters are invalid, how they are invalid - assert SearchSpace({"search_space": search_space}).size() > 0, "There are no valid points in the search space." 
return fixed_params, search_space @classmethod @@ -528,5 +531,5 @@ def create_pass_from_dict( if accelerator_spec is None: accelerator_spec = DEFAULT_CPU_ACCELERATOR - config = pass_cls.generate_config(accelerator_spec, config, disable_search) + config = pass_cls.generate_config(accelerator_spec, config, disable_search=disable_search) return pass_cls(accelerator_spec, config, host_device) diff --git a/olive/passes/onnx/inc_quantization.py b/olive/passes/onnx/inc_quantization.py index 0ff18ddd6..d697ad2ab 100644 --- a/olive/passes/onnx/inc_quantization.py +++ b/olive/passes/onnx/inc_quantization.py @@ -24,7 +24,7 @@ from olive.passes import Pass from olive.passes.onnx.common import get_external_data_config, model_has_adapters, model_proto_to_olive_model from olive.passes.pass_config import PassConfigParam -from olive.strategy.search_parameter import Boolean, Categorical, Conditional +from olive.search.search_parameter import Boolean, Categorical, Conditional logger = logging.getLogger(__name__) diff --git a/olive/passes/onnx/nvmo_quantization.py b/olive/passes/onnx/nvmo_quantization.py index b44feb65f..f3d88b09c 100644 --- a/olive/passes/onnx/nvmo_quantization.py +++ b/olive/passes/onnx/nvmo_quantization.py @@ -20,7 +20,7 @@ from olive.passes import Pass from olive.passes.onnx.common import model_proto_to_olive_model from olive.passes.pass_config import PassConfigParam -from olive.strategy.search_parameter import Categorical +from olive.search.search_parameter import Categorical logger = logging.getLogger(__name__) diff --git a/olive/passes/onnx/quantization.py b/olive/passes/onnx/quantization.py index 42504b90b..e7c8a78ba 100644 --- a/olive/passes/onnx/quantization.py +++ b/olive/passes/onnx/quantization.py @@ -29,7 +29,7 @@ ) from olive.passes.pass_config import PassConfigParam from olive.resource_path import LocalFile -from olive.strategy.search_parameter import Boolean, Categorical, Conditional, ConditionalDefault +from olive.search.search_parameter import 
Boolean, Categorical, Conditional, ConditionalDefault logger = logging.getLogger(__name__) diff --git a/olive/passes/onnx/session_params_tuning.py b/olive/passes/onnx/session_params_tuning.py index 9e69f4ec1..701c35080 100644 --- a/olive/passes/onnx/session_params_tuning.py +++ b/olive/passes/onnx/session_params_tuning.py @@ -23,7 +23,7 @@ from olive.model import ONNXModelHandler from olive.passes import Pass from olive.passes.pass_config import PassConfigParam, get_user_script_data_config -from olive.strategy.search_parameter import Categorical +from olive.search.search_parameter import Categorical logger = logging.getLogger(__name__) @@ -114,25 +114,25 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon "providers_list": PassConfigParam( type_=str, default_value=execution_provider, - searchable_values=Categorical(AcceleratorLookup.get_execution_providers_for_device(device)), + search_defaults=Categorical(AcceleratorLookup.get_execution_providers_for_device(device)), description="Execution providers framework list to execute the ONNX models.", ), "provider_options_list": PassConfigParam( type_=Dict[str, Any], default_value={}, - searchable_values=Categorical([{}]), + search_defaults=Categorical([{}]), description="Execution provider options to execute the ONNX models.", ), "execution_mode_list": PassConfigParam( type_=int, default_value=None, - searchable_values=Categorical([None]), + search_defaults=Categorical([None]), description="Parallelism list between operators.", ), "opt_level_list": PassConfigParam( type_=int, default_value=None, - searchable_values=Categorical([None]), + search_defaults=Categorical([None]), description="Optimization level list for ONNX model.", ), "trt_fp16_enable": PassConfigParam( @@ -141,13 +141,13 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon "intra_thread_num_list": PassConfigParam( type_=int, default_value=None, - searchable_values=Categorical([None]), + 
search_defaults=Categorical([None]), description="List of intra thread number for test.", ), "inter_thread_num_list": PassConfigParam( type_=int, default_value=None, - searchable_values=Categorical([None]), + search_defaults=Categorical([None]), description="List of inter thread number for test.", ), "extra_session_config": PassConfigParam( diff --git a/olive/passes/onnx/vitis_ai_quantization.py b/olive/passes/onnx/vitis_ai_quantization.py index c61324619..bd0088cac 100644 --- a/olive/passes/onnx/vitis_ai_quantization.py +++ b/olive/passes/onnx/vitis_ai_quantization.py @@ -25,7 +25,7 @@ ) from olive.passes.pass_config import PassConfigParam from olive.resource_path import LocalFile -from olive.strategy.search_parameter import Boolean, Categorical, Conditional +from olive.search.search_parameter import Boolean, Categorical, Conditional logger = logging.getLogger(__name__) diff --git a/olive/passes/pass_config.py b/olive/passes/pass_config.py index 82695f9d9..27144e420 100644 --- a/olive/passes/pass_config.py +++ b/olive/passes/pass_config.py @@ -18,7 +18,7 @@ from olive.hardware.accelerator import Device from olive.hardware.constants import DEVICE_TO_EXECUTION_PROVIDERS from olive.resource_path import validate_resource_path -from olive.strategy.search_parameter import SearchParameter, SpecialParamValue, json_to_search_parameter +from olive.search.search_parameter import SearchParameter, SpecialParamValue, json_to_search_parameter class PassParamDefault(StrEnumBase): diff --git a/olive/passes/pytorch/lora.py b/olive/passes/pytorch/lora.py index d08aba474..6ce05aca9 100644 --- a/olive/passes/pytorch/lora.py +++ b/olive/passes/pytorch/lora.py @@ -31,7 +31,7 @@ from olive.model.config.hf_config import HfLoadKwargs from olive.passes import Pass from olive.passes.olive_pass import PassConfigParam -from olive.strategy.search_parameter import Categorical +from olive.search.search_parameter import Categorical if TYPE_CHECKING: import torch diff --git 
a/olive/passes/snpe/quantization.py b/olive/passes/snpe/quantization.py index a7cd97f57..398cfed21 100644 --- a/olive/passes/snpe/quantization.py +++ b/olive/passes/snpe/quantization.py @@ -14,7 +14,7 @@ from olive.platform_sdk.qualcomm.snpe.tools.dev import quantize_dlc from olive.platform_sdk.qualcomm.utils.data_loader import FileListCommonDataLoader, FileListDataLoader from olive.resource_path import LocalFile -from olive.strategy.search_parameter import Boolean +from olive.search.search_parameter import Boolean class SNPEQuantization(Pass): diff --git a/olive/search/__init__.py b/olive/search/__init__.py new file mode 100644 index 000000000..862c45ce3 --- /dev/null +++ b/olive/search/__init__.py @@ -0,0 +1,4 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- diff --git a/olive/search/samplers/__init__.py b/olive/search/samplers/__init__.py new file mode 100644 index 000000000..70e45d124 --- /dev/null +++ b/olive/search/samplers/__init__.py @@ -0,0 +1,12 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +from olive.search.samplers.random_sampler import RandomSampler +from olive.search.samplers.search_sampler import SearchSampler +from olive.search.samplers.sequential_sampler import SequentialSampler +from olive.search.samplers.tpe_sampler import TPESampler + +REGISTRY = SearchSampler.registry + +__all__ = ["REGISTRY", "RandomSampler", "SearchSampler", "SequentialSampler", "TPESampler"] diff --git a/olive/search/samplers/optuna_sampler.py b/olive/search/samplers/optuna_sampler.py new file mode 100644 index 000000000..01cd7f36d --- /dev/null +++ b/olive/search/samplers/optuna_sampler.py @@ -0,0 +1,140 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from abc import abstractmethod +from collections import OrderedDict +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union + +import optuna + +from olive.common.config_utils import ConfigBase, ConfigParam +from olive.search.samplers.search_sampler import SearchSampler +from olive.search.search_parameter import Categorical, Conditional, SearchParameter +from olive.search.search_point import SearchPoint +from olive.search.search_space import SearchSpace + +if TYPE_CHECKING: + from optuna.trial import Trial + + from olive.evaluator.metric_result import MetricResult + + +optuna.logging.set_verbosity(optuna.logging.WARNING) + + +class OptunaSampler(SearchSampler): + """Optuna sampler for search sampling.""" + + name = "optuna" + + @classmethod + def _default_config(cls) -> Dict[str, ConfigParam]: + return { + **super()._default_config(), + "seed": ConfigParam(type_=int, default_value=1, description="Seed for the rng."), + } + + def __init__( + self, + search_space: SearchSpace, + config: Optional[Union[Dict[str, Any], ConfigBase]] = 
None, + ): + super().__init__(search_space, config) + + # Initialize the searcher + self._sampler = self._create_sampler() + # TODO(olivedev): There is no absolute direction to set. + # directions = ["maximize" if hib else "minimize" for hib in self._higher_is_betters] + # self._study = optuna.create_study(directions=directions, sampler=self._sampler) + self._study = optuna.create_study(sampler=self._sampler) + self._num_samples_suggested = 0 + self._search_point_index_to_trail_id = {} + + @property + def num_samples_suggested(self) -> int: + """Returns the number of samples suggested so far.""" + return self._num_samples_suggested + + @abstractmethod + def _create_sampler(self) -> optuna.samplers.BaseSampler: + """Create the sampler.""" + + def suggest(self) -> SearchPoint: + """Suggest a new configuration to try.""" + if self.should_stop: + return None + + while True: + trial = self._study.ask() + values: Dict[str, Tuple[int, Any]] = OrderedDict() + spi, _, values = self._get_search_point_values("", "", self._search_space, trial, values) + if spi not in self._search_point_index_to_trail_id: + break + + self._search_point_index_to_trail_id[spi] = trial.number + self._num_samples_suggested += 1 + return SearchPoint(spi, values) + + def _get_search_point_values( + self, + prefix: str, + name: str, + param: Union[SearchParameter, SearchSpace], + trial: "Trial", + values: Dict[str, Tuple[int, Any]], + ) -> Tuple[int, int, Union[Dict[str, Any], Any]]: + if isinstance(param, SearchParameter): + suggestion_name = f"{prefix}__{name}" if prefix else name + + if isinstance(param, Categorical): + suggestions = param.get_support() + elif isinstance(param, Conditional): + parent_values = {parent: values[parent][1] for parent in param.parents} + suggestions = param.get_support_with_args(parent_values) + max_length = max(len(support.get_support()) for support in param.support.values()) + suggestions = suggestions + param.default.get_support() * (max_length - len(suggestions)) +
suggestion_name += "___" + " ".join(str(v) for v in parent_values.values()) + + suggestion_len = len(suggestions) + suggestion_index = trial.suggest_categorical(suggestion_name, list(range(suggestion_len))) + suggestion = suggestions[suggestion_index] + + if isinstance(suggestion, (SearchParameter, SearchSpace)): + suggestion_index, suggestion_len, _ = self._get_search_point_values( + prefix, name, suggestion, trial, values + ) + else: + values[name] = suggestion_index, suggestion + + return suggestion_index, suggestion_len, suggestion + + elif isinstance(param, SearchSpace): + child_values = OrderedDict() + indicies_lengths = [] + for child_name, child_param in param.parameters: + child_index, suggestions_len, _ = self._get_search_point_values( + prefix, child_name, child_param, trial, child_values + ) + indicies_lengths.append((child_index, suggestions_len)) + values[name] = (0, child_values) + + spi = 0 + for child_index, suggestions_len in reversed(indicies_lengths): + spi *= suggestions_len + spi += child_index + + return spi, len(param), child_values + + else: + raise ValueError(f"Unsupported parameter type: {type(param)}") + + def record_feedback_signal( + self, search_point_index: int, objectives: Dict[str, dict], signal: "MetricResult", should_prune: bool = False + ): + trial_id = self._search_point_index_to_trail_id[search_point_index] + if should_prune: + self._study.tell(trial_id, state=optuna.trial.TrialState.PRUNED) + else: + values = [signal[objective].value for objective in objectives] + self._study.tell(trial_id, values) diff --git a/olive/search/samplers/random_sampler.py b/olive/search/samplers/random_sampler.py new file mode 100644 index 000000000..a0f8b4099 --- /dev/null +++ b/olive/search/samplers/random_sampler.py @@ -0,0 +1,61 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +from random import Random +from typing import TYPE_CHECKING, Any, Dict, Optional, Union + +from olive.common.config_utils import ConfigBase, ConfigParam +from olive.search.samplers.search_sampler import SearchSampler +from olive.search.search_space import SearchSpace + +if TYPE_CHECKING: + from olive.search.search_point import SearchPoint + + +class RandomSampler(SearchSampler): + """Random sampler. Samples random points from the search space.""" + + name = "random" + + @classmethod + def _default_config(cls) -> Dict[str, ConfigParam]: + return { + **super()._default_config(), + "seed": ConfigParam(type_=int, default_value=1, description="Seed for the rng."), + } + + def __init__( + self, + search_space: SearchSpace, + config: Optional[Union[Dict[str, Any], ConfigBase]] = None, + ): + super().__init__(search_space, config) + + self._rng = Random(self.config.seed) + self._search_points = [None] * len(self._search_space) + self._available = list(range(len(self._search_space))) + + def reset_rng(self): + """Reset the random number generator.""" + self._rng = Random(self.config.seed) + + @property + def num_samples_suggested(self) -> int: + """Returns the number of samples suggested so far.""" + return len(self._search_space) - len(self._available) + + @property + def should_stop(self) -> bool: + """Check if the searcher should stop at the current trial.""" + return super().should_stop or (len(self._available) == 0) + + def suggest(self) -> "SearchPoint": + """Suggest a new configuration to try.""" + if self.should_stop: + return None + + index = self._available[self._rng.randrange(len(self._available))] + self._available.remove(index) + self._search_points[index] = self._search_space[index] + return self._search_points[index] diff --git a/olive/search/samplers/search_sampler.py b/olive/search/samplers/search_sampler.py new file mode 100644 index 000000000..8e09a37ac --- /dev/null +++ 
b/olive/search/samplers/search_sampler.py @@ -0,0 +1,69 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from abc import abstractmethod +from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Type, Union + +from olive.common.auto_config import AutoConfigClass +from olive.common.config_utils import ConfigBase, ConfigParam +from olive.search.search_space import SearchSpace + +if TYPE_CHECKING: + from olive.evaluator.metric_result import MetricResult + from olive.search.search_point import SearchPoint + + +class SearchSampler(AutoConfigClass): + """Abstract base class for searchers.""" + + registry: ClassVar[Dict[str, Type["SearchSampler"]]] = {} + + @classmethod + def _default_config(cls) -> Dict[str, ConfigParam]: + return { + "max_samples": ConfigParam( + type_=int, + default_value=0, + description="Maximum number of samples to suggest. 
Exhaustive if set to zero.", + ), + } + + def __init__( + self, + search_space: SearchSpace, + config: Optional[Union[Dict[str, Any], ConfigBase]] = None, + ): + super().__init__(config) + + self._search_space = search_space + + @property + @abstractmethod + def num_samples_suggested(self) -> int: + """Returns the number of samples suggested so far.""" + return 0 + + @property + def max_samples(self) -> int: + """Returns the maximum number of samples to suggest.""" + return self.config.max_samples + + @property + def should_stop(self) -> bool: + """Check if the searcher should stop at the current trial.""" + return ( + (len(self._search_space) == 0) + or (self.num_samples_suggested >= len(self._search_space)) + or ((self.max_samples > 0) and (self.num_samples_suggested >= self.max_samples)) + ) + + @abstractmethod + def suggest(self) -> "SearchPoint": + """Suggest a new configuration to try.""" + return None + + def record_feedback_signal( + self, search_point_index: int, objectives: Dict[str, dict], signal: "MetricResult", should_prune: bool = False + ): + """Report the result of a configuration.""" diff --git a/olive/search/samplers/sequential_sampler.py b/olive/search/samplers/sequential_sampler.py new file mode 100644 index 000000000..e0e6f6663 --- /dev/null +++ b/olive/search/samplers/sequential_sampler.py @@ -0,0 +1,46 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from typing import TYPE_CHECKING, Any, Dict, Optional, Union + +from olive.common.config_utils import ConfigBase +from olive.search.samplers.search_sampler import SearchSampler +from olive.search.search_space import SearchSpace + +if TYPE_CHECKING: + from olive.search.search_point import SearchPoint + + +class SequentialSampler(SearchSampler): + """Sequential Search Algorithm. 
Does a grid search over the search space.""" + + name = "sequential" + + @classmethod + def _default_config(cls): + return super()._default_config() + + def __init__( + self, + search_space: SearchSpace, + config: Optional[Union[Dict[str, Any], ConfigBase]] = None, + ): + super().__init__(search_space, config) + + self._index = 0 + + @property + def num_samples_suggested(self) -> int: + """Returns the number of samples suggested so far.""" + return self._index + + def suggest(self) -> "SearchPoint": + """Suggest a new configuration to try.""" + if self.should_stop: + return None + + index = self._index + self._index += 1 + + return self._search_space[index] diff --git a/olive/search/samplers/tpe_sampler.py b/olive/search/samplers/tpe_sampler.py new file mode 100644 index 000000000..f52f2831f --- /dev/null +++ b/olive/search/samplers/tpe_sampler.py @@ -0,0 +1,45 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from typing import Dict + +import optuna + +from olive.common.config_utils import ConfigParam +from olive.search.samplers.optuna_sampler import OptunaSampler + + +class TPESampler(OptunaSampler): + """Sample using TPE (Tree-structured Parzen Estimator) algorithm. Uses optuna TPESampler underneath. + + Refer to https://optuna.readthedocs.io/en/stable/reference/samplers/generated/optuna.samplers.TPESampler.html + for more details about the sampler. + """ + + name = "tpe" + + @classmethod + def _default_config(cls) -> Dict[str, ConfigParam]: + return { + **super()._default_config(), + "multivariate": ConfigParam( + type_=bool, default_value=True, description="Use multivariate TPE when suggesting parameters." 
class SearchParameter(ABC):
    """Abstract base for every kind of search parameter.

    Concrete search elements subclass this and implement all four abstract members.
    """

    @abstractmethod
    def __init__(self, **kwargs):
        """Accept arbitrary keyword arguments; the base implementation does nothing."""

    @abstractmethod
    def get_support(self) -> List[Any]:
        """Return the values this parameter can take."""
        raise NotImplementedError

    @abstractmethod
    def __repr__(self):
        """Return a readable representation of the parameter."""
        raise NotImplementedError

    @abstractmethod
    def to_json(self):
        """Return a json representation of the parameter."""
        raise NotImplementedError
class Conditional(SearchParameter):
    """Conditional search parameter.

    The support of this parameter is selected by the values of its parent parameters.

    Examples
    --------
    # conditional search parameter with one parent
    # when parent1 is value1, the support is [1, 2, 3],
    # when parent1 is value2, the support is [4, 5, 6],
    # otherwise the support is [7, 8, 9]
    >>> Conditional(
            parents=("parent1",),
            support={
                ("value1",): Categorical([1, 2, 3]),
                ("value2",): Categorical([4, 5, 6])
            },
            default=Categorical([7, 8, 9])
        )

    # conditional search parameter with two parents
    # when parent1 is value1 and parent2 is value2, the support is [1, 2, 3], otherwise the support is Invalid
    >>> Conditional(parents=("parent1", "parent2"), support={("value1", "value2"): Categorical([1, 2, 3])})

    # when parent1 is value1 and parent2 is value2, the support is [1, 2, 3],
    # when parent1 is value1 and parent2 is value3, the support is Invalid,
    # otherwise the support is Ignored
    >>> Conditional(
            parents=("parent1", "parent2"),
            support={
                ("value1", "value2"): Categorical([1, 2, 3]),
                ("value1", "value3"): Conditional.get_invalid_choice()
            },
            default=Conditional.get_ignored_choice()
        )

    # NOTE: The order of "parents" and keys in "support" should be the same.

    # When the value is set to be "invalid", the generated config would also be invalid.
    # When the value is set to be "ignored", the generated config would not include the parameter at all.

    """

    def __init__(
        self,
        parents: Tuple[str],
        support: Dict[Tuple[Any], SearchParameter],
        default: SearchParameter = None,
    ):
        """Create a conditional parameter.

        :param parents: names of the parent parameters, in the same order as the entries of each support key.
        :param support: maps a tuple of parent values to the search parameter to use for that combination.
        :param default: search parameter used when no support key matches; the invalid choice when omitted.
        """
        # Accept a list as well (e.g. "parents" read back from json) and normalize it to a tuple.
        if isinstance(parents, list):
            parents = tuple(parents)
        assert isinstance(parents, tuple), "parents must be a tuple"
        for key in support:
            assert isinstance(key, tuple), "support key must be a tuple"
            assert len(key) == len(parents), "support key length must match the number of parents"

        self.parents = parents
        self.support = support
        self.default = default or self.get_invalid_choice()

    def get_support(self) -> List[Any]:
        """Not supported: the support depends on the parent values."""
        raise NotImplementedError("Use get_support_with_args instead")

    def get_support_with_args(
        self, parent_values: Dict[str, Any]
    ) -> Union[List[str], List[int], List[float], List[bool]]:
        """Get the support for the search parameter for a given parent value.

        :param parent_values: value for every parent; the keys must be exactly the parent names.
        :return: support of the matching entry, or of the default when no entry matches.
        """
        # pylint: disable=arguments-differ
        assert parent_values.keys() == set(self.parents), "parent values keys do not match the parents"
        parent_values = tuple(parent_values[parent] for parent in self.parents)
        return self.support.get(parent_values, self.default).get_support()

    def condition(self, parent_values: Dict[str, Any]) -> SearchParameter:
        """Fix the parent value and return a new search parameter.

        :param parent_values: values for a subset of the parents. The mapping is not modified.
        :return: the matching support entry (or the default) once every parent is fixed, otherwise a new
            Conditional over the parents that are still free. With no values to fix, this object itself.
        """
        assert set(parent_values.keys()).issubset(set(self.parents)), "parent values keys not a subset of the parents"

        # nothing to fix, so the parameter is unchanged
        if not parent_values:
            return self

        # if there is only one parent, return the support for the given parent value
        if len(self.parents) == 1:
            key = (parent_values[self.parents[0]],)
            return self.support.get(key, self.default)

        # work on a copy so the caller's mapping is never mutated
        remaining = dict(parent_values)

        # condition the first parent (in declaration order) that has a value; one always exists because
        # the mapping is a non-empty subset of the parents
        parent_idx, parent = next((i, name) for i, name in enumerate(self.parents) if name in remaining)
        parent_value = remaining.pop(parent)

        new_parents = self.parents[:parent_idx] + self.parents[parent_idx + 1 :]  # noqa: E203, RUF100
        new_support = {
            key[:parent_idx] + key[parent_idx + 1 :]: value  # noqa: E203, RUF100
            for key, value in self.support.items()
            if key[parent_idx] == parent_value
        }
        # if there is no support for the given parent value, return the default
        if not new_support:
            return self.default
        # create a new conditional
        new_conditional = Conditional(new_parents, new_support, self.default)

        # condition the new conditional if there are more parents to condition, else return the new conditional
        if not remaining:
            return new_conditional
        return new_conditional.condition(remaining)

    def __repr__(self):
        return f"Conditional(parents: {self.parents}, support: {self.support}, default: {self.default})"

    def to_json(self):
        """Return a json representation with the support keys expanded by ``unflatten_dict``."""
        support = {}
        for key, value in self.support.items():
            support[key] = value.to_json()
        support = unflatten_dict(support)

        return {
            "olive_parameter_type": "SearchParameter",
            "type": "Conditional",
            "parents": self.parents,
            "support": support,
            "default": self.default.to_json(),
        }

    @staticmethod
    def get_invalid_choice():
        """Return a categorical search parameter with the invalid choice."""
        return Categorical([SpecialParamValue.INVALID])

    @staticmethod
    def get_ignored_choice():
        """Return a categorical search parameter with the ignored choice."""
        return Categorical([SpecialParamValue.IGNORED])
class ConditionalDefault(Conditional):
    """Parameter with conditional default value.

    Every support entry and the default hold exactly one value, stored internally as a
    single-element Categorical.

    Examples
    --------
    # conditional default with one parent
    # when parent1 is value1, the default is 1,
    # when parent1 is value2, the default is 2,
    # otherwise the default is 3
    >>> ConditionalDefault(
            parents=("parent1",),
            support={
                ("value1",): 1,
                ("value2",): 2
            },
            default=3
        )

    # conditional default with two parents
    # when parent1 is value1 and parent2 is value2, the default is 1,
    # otherwise the default is Invalid
    >>> ConditionalDefault(
            parents=("parent1", "parent2"),
            support={("value1", "value2"): 1}
        )

    """

    def __init__(self, parents: Tuple[str], support: Dict[Tuple[Any], Any], default: Any = SpecialParamValue.INVALID):
        # Wrap each plain value in a one-element Categorical, which is what the base class stores.
        wrapped_support = {}
        for parent_key, raw_value in support.items():
            wrapped_support[parent_key] = Categorical([raw_value])
        wrapped_default = Categorical([default])
        super().__init__(parents, wrapped_support, wrapped_default)

    def get_support_with_args(self, parent_values: Dict[str, Any]) -> Union[bool, int, float, str]:
        """Return the single default value selected by the given parent values."""
        choices = super().get_support_with_args(parent_values)
        return choices[0]

    def condition(self, parent_values: Dict[str, Any]) -> Union[bool, int, float, str, "ConditionalDefault"]:
        """Fix some parents; return a plain value when fully resolved, else a new ConditionalDefault."""
        conditioned = super().condition(parent_values)
        if isinstance(conditioned, Categorical):
            return conditioned.get_support()[0]
        if not isinstance(conditioned, Conditional):
            raise ValueError(f"Unknown search parameter type {type(conditioned)}")
        return self.conditional_to_conditional_default(conditioned)

    @staticmethod
    def conditional_to_conditional_default(conditional: Conditional) -> "ConditionalDefault":
        """Build a ConditionalDefault from a Conditional whose entries are all one-element Categoricals."""
        flat_support = {}
        for parent_key, entry in conditional.support.items():
            assert isinstance(entry, Categorical), "Conditional support must be categorical"
            entry_values = entry.get_support()
            assert len(entry_values) == 1, "Conditional support must have only one value"
            flat_support[parent_key] = entry_values[0]

        fallback = conditional.default
        assert isinstance(fallback, Categorical), "Conditional default must be categorical"
        fallback_values = fallback.get_support()
        assert len(fallback_values) == 1, "Conditional default must have only one value"
        return ConditionalDefault(conditional.parents, flat_support, fallback_values[0])

    @staticmethod
    def conditional_default_to_conditional(conditional_default: "ConditionalDefault") -> Conditional:
        """Build a plain Conditional that reuses the parents, support and default of a ConditionalDefault."""
        source = conditional_default
        return Conditional(source.parents, source.support, source.default)

    def __repr__(self):
        # Show the unwrapped values rather than the one-element Categoricals that hold them.
        default = self.default.get_support()[0]
        support = {}
        for parent_key, entry in self.support.items():
            support[parent_key] = entry.get_support()[0]
        return f"ConditionalDefault(parents: {self.parents}, support: {support}, default: {default})"

    def to_json(self):
        """Return the base class json representation with the type set to ConditionalDefault."""
        payload = super().to_json()
        payload["type"] = "ConditionalDefault"
        return payload

    @staticmethod
    def get_invalid_choice():
        """Return the raw invalid marker value (not wrapped in a Categorical)."""
        return SpecialParamValue.INVALID

    @staticmethod
    def get_ignored_choice():
        """Return the raw ignored marker value (not wrapped in a Categorical)."""
        return SpecialParamValue.IGNORED
@dataclass
class SearchPoint:
    """Search point for search space.

    ``values`` maps each parameter name to an ``(index, value)`` pair. ``value`` is either the
    chosen value itself or, for a nested search space, an OrderedDict of the same shape.
    """

    def __init__(self, index: int, values: Dict[str, Tuple[int, Any]]):
        self.index = index
        self.values = values

    def _format(self, arg: Any) -> str:
        """Render nested values as ``{name[index]: value, ...}``; anything else via ``str``."""
        return (
            "{" + ", ".join(f"{k}[{i}]: {self._format(v)}" for k, (i, v) in arg.items()) + "}"
            if isinstance(arg, OrderedDict)
            else str(arg)
        )

    def __repr__(self):
        return f"SearchPoint({self.index}, {self._format(self.values)})"

    def __eq__(self, other):
        return (
            (self.index == other.index) and (self.values == other.values) if isinstance(other, SearchPoint) else False
        )

    def is_valid(self) -> bool:
        """Return False if any value, at any nesting depth, is the INVALID marker."""

        def _is_valid(values: Dict[str, Tuple[int, Any]]) -> bool:
            # Each entry is an (index, value) pair; only the value can be nested or INVALID.
            # Testing the pair itself never matches either check and reports every point as valid.
            for _, value in values.values():
                if isinstance(value, OrderedDict):
                    if not _is_valid(value):
                        return False
                elif value == SpecialParamValue.INVALID:
                    return False
            return True

        return _is_valid(self.values)

    def to_json(self):
        """Return a json representation."""
        return {"index": self.index, "values": self.values}

    @classmethod
    def from_json(cls, json_dict):
        """Create a SearchPoint object from a json representation."""
        return cls(json_dict["index"], json_dict["values"])
class SearchResults:
    """Store the evaluation result of each search point and rank the points by their objectives."""

    def __init__(self):
        # Indexed by search point index. A slot is None until feedback for that point is recorded;
        # afterwards it holds (metric result, model ids, objectives).
        self._results: List[Tuple["MetricResult", List[str], Dict[str, Any]]] = []
        # Search point indices ordered best first; filled in by sort().
        self._sorted_indicies: List[int] = []

    def record_feedback_signal(
        self, search_point_index: int, objectives: Dict[str, dict], result: "MetricResult", model_ids: List[str]
    ):
        """Record the evaluation result of a search point."""
        # Pad with empty slots so the entry can be stored at its own index.
        self._results += [None] * ((search_point_index + 1) - len(self._results))
        self._results[search_point_index] = (result, model_ids, objectives)

    def meets_goals(self, search_point_index: int) -> bool:
        """Check if the result satisfies the constraints.

        Returns False when nothing was recorded for the index, and True when the recorded
        objectives define no goal.
        """
        if search_point_index >= len(self._results):
            return False

        if not self._results[search_point_index]:
            return False

        result, _, objectives = self._results[search_point_index]
        goals = {name: obj["goal"] for name, obj in objectives.items() if obj.get("goal") is not None}
        if not goals:
            return True  # if goals are not set, always return True

        # multiplier for each objective and goals
        multipliers = {
            name: 1 if objective.get("higher_is_better", False) else -1 for name, objective in objectives.items()
        }
        return all((multipliers[obj] * result[obj].value) >= (multipliers[obj] * goal) for obj, goal in goals.items())

    def sort(self, apply_goals: bool = False):
        """Order the recorded search points from best to worst.

        :param apply_goals: when True, only points that meet their goals are ranked.
        :return: True if at least one point had objective values to rank, else False.
        """
        indicies, results = self._get_results_list(apply_goals)
        if not results:
            self._sorted_indicies = indicies
            return False

        # sort by objectives, left most objective has highest priority
        # flip the order of the objectives since np.lexsort prioritizes the last column
        # negate the results since np.lexsort sorts in ascending order
        results = -np.flip(np.array(results), 1)
        sorted_indices = np.lexsort(results.T)
        self._sorted_indicies = [indicies[i] for i in sorted_indices]

        return True

    def get_next_best_result(self, start_index: int) -> Tuple[int, int, List[str]]:
        """Return the entry that follows ``start_index`` in the sorted order.

        :param start_index: position in the sorted order of the previous result; -1 to get the best one.
        :return: (position in sorted order, search point index, model ids), or (None, None, None)
            when there is no further entry.
        """
        assert start_index is not None, "Expecting an index, got None"

        if start_index < -1:
            return None, None, None

        next_best_index = start_index + 1
        if next_best_index >= len(self._sorted_indicies):
            return None, None, None

        _, model_ids, _ = self._results[self._sorted_indicies[next_best_index]]
        return next_best_index, self._sorted_indicies[next_best_index], model_ids

    def _get_results_list(self, apply_goals: bool = False) -> Tuple[List[int], List[float]]:
        """Return the results as a tuple of indicies and values.

        Values are multiplied by the objective multiplier so that higher is better for all objectives.
        """
        # Union of objective names over the qualifying entries, mapped to their higher_is_better flag.
        all_objectives = {}
        for spi, entry in enumerate(self._results):
            if entry and (not apply_goals or self.meets_goals(spi)):
                _, _, objectives = entry
                for name in objectives:
                    if name in all_objectives:
                        assert all_objectives[name] == objectives[name].get(
                            "higher_is_better", False
                        ), "Conflicting values for higher_is_better across same named objectives"
                    else:
                        all_objectives[name] = objectives[name].get("higher_is_better", False)

        indicies = []
        values = []
        if not all_objectives:
            # If no objectives, then use the indicies of the valid results in no specific order
            indicies = [spi for spi, entry in enumerate(self._results) if entry]
            return indicies, values

        # NOTE: The values array needs to be rectangular, but entries may define different sets of
        # objectives, so every row is built over the union of all objective names.

        # Values are already oriented so that higher is better, so an objective the entry does not
        # define must get the lowest possible value whichever direction the objective has. -inf also
        # survives the negation done in sort(), unlike the smallest 64 bit integer.
        worst = float("-inf")

        for spi, entry in enumerate(self._results):
            if entry and (not apply_goals or self.meets_goals(spi)):
                result, _, objectives = entry
                if objectives:
                    indicies.append(spi)
                    v = []
                    for name, hib in all_objectives.items():
                        if name in objectives:
                            v.append((1 if hib else -1) * result[name].value)
                        else:
                            v.append(worst)
                    values.append(v)

        return indicies, values

    def to_json(self):
        """Return a json representation of the search results."""
        # NOTE(review): the stored tuples are returned as-is, metric result objects included;
        # confirm the caller can serialize them.
        return {"results": self._results}

    @classmethod
    def from_json(cls, json_dict):
        """Create a SearchResults object from a json representation."""
        search_results = cls()
        search_results._results = json_dict["results"]
        return search_results
@dataclass
class SearchSample:
    """Search step result from search strategy."""

    def __init__(self, search_point: SearchPoint, model_ids: List[str]):
        self.search_point = search_point
        self.model_ids = model_ids

    def __repr__(self):
        return f"SearchSample({self.search_point.index}, {self.passes_configs}, {self.model_ids})"

    @property
    def passes_configs(self) -> Dict[str, Any]:
        """Return the per-pass configuration derived from the search point.

        The result maps each pass name to ``{"index": <choice index>, "params": {name: value}}``.
        Parameters whose value is the IGNORED marker are left out. Returns None as soon as any
        parameter holds the INVALID marker.
        """
        passes_configs = OrderedDict()
        for pass_name, (pass_index, params) in self.search_point.values.items():
            passes_configs[pass_name] = OrderedDict(
                [
                    ("index", pass_index),
                    ("params", OrderedDict()),
                ]
            )

            for param_name, (_, param_value) in params.items():
                if param_value == SpecialParamValue.INVALID:
                    return None  # Prune out invalid configurations
                if param_value != SpecialParamValue.IGNORED:
                    passes_configs[pass_name]["params"][param_name] = param_value

        return passes_configs

    def to_json(self):
        """Return a json representation."""
        return {
            "search_point": self.search_point.to_json(),
            "passes_configs": self.passes_configs,
            "model_ids": self.model_ids,
        }

    @classmethod
    def from_json(cls, json_dict):
        """Create a SearchSample object from a json representation."""
        # "passes_configs" is derived from the search point, so it is not passed to the
        # constructor, which takes only the search point and the model ids.
        return cls(
            SearchPoint.from_json(json_dict["search_point"]),
            json_dict["model_ids"],
        )
class SearchSpace:
    """Search space for a search algorithm.

    A space is an ordered list of named parameters, each a SearchParameter or a nested
    SearchSpace. Every integer in ``range(len(space))`` identifies one search point.
    """

    def __init__(self, parameters: List[Tuple[str, Union[SearchParameter, "SearchSpace"]]]):
        assert len(parameters) == len(
            {name for name, _ in parameters}
        ), "Parameter name in search space should be unique."

        self._parameters = self._order_search_space(parameters)

        # Number of search points: the product of the lengths of the individual parameters.
        self._length = 1
        for _, param in self._parameters:
            self._length *= SearchSpace.get_param_length(param)

    @property
    def parameters(self) -> List[Tuple[str, Union[SearchParameter, "SearchSpace"]]]:
        """Return the (name, parameter) pairs in their resolved order."""
        return self._parameters

    def __repr__(self):
        return f"SearchSpace({self._parameters}, {self._length})"

    def __len__(self) -> int:
        return self._length

    def __iter__(self) -> Generator[SearchPoint, None, None]:
        for index in range(self._length):
            yield self[index]

    def __getitem__(self, index: int) -> SearchPoint:
        assert index < self._length
        return SearchPoint(index, self.get_sample_point_values(index))

    def _order_search_space(self, parameters: List[Tuple[str, SearchParameter]]) -> List[Tuple[str, SearchParameter]]:
        """Order the search space by topological order of parameters for each pass_id/space_name."""
        unordered_nodes = dict(parameters)
        ordered_nodes = order_search_parameters(unordered_nodes)
        return [(name, unordered_nodes[name]) for name in ordered_nodes]

    def get_sample_point_values(self, index: int) -> Dict[str, Tuple[int, Any]]:
        """Decode a search point index into an ordered ``{name: (choice index, value)}`` mapping."""
        assert index < self._length

        values = OrderedDict()
        for name, param in self._parameters:
            # Each parameter consumes part of the index and hands the remainder to the next one.
            index, values[name] = SearchSpace.get_suggestion(param, index, values)
        return values

    @staticmethod
    def get_param_length(param: Any) -> int:
        """Return the number of distinct choices a parameter contributes; 0 for any other object."""
        if isinstance(param, SearchParameter):
            if isinstance(param, Categorical):
                # A nested parameter or space counts with its own length, a plain value counts as one.
                return sum(
                    (
                        SearchSpace.get_param_length(suggestion)
                        if isinstance(suggestion, (SearchParameter, SearchSpace))
                        else 1
                    )
                    for suggestion in param.get_support()
                )

            elif isinstance(param, Conditional):
                # The length must not depend on the parent values, so use the longest support entry.
                return max(SearchSpace.get_param_length(support) for support in param.support.values())

        elif isinstance(param, SearchSpace):
            return len(param)

        return 0

    @staticmethod
    def get_param_suggestions(param: Any, values: Dict[str, Any]) -> Union[List[Any], "SearchSpace"]:
        """Return the candidates for a parameter given the values already chosen for earlier parameters.

        For a Conditional the list is padded with the default's values up to the longest support
        entry, so that it has the length reported by ``get_param_length``.
        """
        if isinstance(param, SearchParameter):
            if isinstance(param, Categorical):
                return param.get_support()

            elif isinstance(param, Conditional):
                parent_values = {k: values[k][1] for k in param.parents}
                # Copy before padding: get_support_with_args hands back the selected parameter's own
                # support list, and extending that list in place would permanently alter the parameter.
                suggestions = list(param.get_support_with_args(parent_values))
                max_length = max(SearchSpace.get_param_length(support) for support in param.support.values())
                suggestions.extend(param.default.get_support() * (max_length - len(suggestions)))
                return suggestions

        elif isinstance(param, SearchSpace):
            return param

        return []

    @staticmethod
    def get_suggestion(param: Any, index: int, values: Dict[str, Any]) -> Tuple[int, Tuple[int, Any]]:
        """Pick the choice of ``param`` addressed by ``index``.

        :return: (remaining index for the parameters that follow, (choice index, chosen value)).
        """
        length = SearchSpace.get_param_length(param)

        if index < length:
            if isinstance(param, SearchParameter):
                suggestions = SearchSpace.get_param_suggestions(param, values)

                # Walk the candidates, consuming the index until it falls inside one of them.
                for i, suggestion in enumerate(suggestions):
                    if isinstance(suggestion, (SearchParameter, SearchSpace)):
                        suggestion_length = SearchSpace.get_param_length(suggestion)
                        if index < suggestion_length:
                            _, (_, isuggestion) = SearchSpace.get_suggestion(suggestion, index, values)
                            return 0, (i, isuggestion)
                        else:
                            index -= suggestion_length
                    elif index > 0:
                        index -= 1
                    else:
                        return 0, (i, suggestion)

            elif isinstance(param, SearchSpace):
                return 0, (index, param.get_sample_point_values(index))

            else:
                return index, param

        # The index reaches beyond this parameter: resolve the part that belongs to it and pass the
        # quotient on to the next parameter.
        _, suggestion = SearchSpace.get_suggestion(param, index % length, values)
        return index // length, suggestion
validate_enum(SearchStrategyExecutionOrder, v) + + @validator("sampler", pre=True) + def _validate_sampler(cls, v): + if v not in REGISTRY: + raise ValueError(f"Unknown sampler: {v}") + return v + + @validator("sampler_config", pre=True, always=True) + def _validate_sampler_config(cls, v, values): + if "sampler" not in values: + raise ValueError("Invalid sampler") + + config_class = REGISTRY[values["sampler"]].get_config_class() + return validate_config(v, config_class) + + @validator("stop_when_goals_met", "max_iter", "max_time", pre=True) + def _validate_stop_when_goals_met(cls, v, values, field): + if "execution_order" not in values: + raise ValueError("Invalid execution_order") + if v and values["execution_order"] != SearchStrategyExecutionOrder.JOINT: + logger.info("%s is only supported for joint execution order. Ignoring...", field.name) + return field.default + return v + + +@dataclass +class SearchWalkState: + def __init__(self, path: List[int], model_ids: List[str], sampler: SearchSampler, results: SearchResults): + self.path: List[int] = path + self.sampler: SearchSampler = sampler + self.results: SearchResults = results + self.model_ids: List[str] = model_ids + self.best_result_index = -1 + + +class SearchStrategy: + def __init__(self, config: Union[Dict[str, Any], SearchStrategyConfig]): + self.config: SearchStrategyConfig = validate_config(config, SearchStrategyConfig) + + # Initialization variables + self._search_spaces: List[SearchSpace] = None + self._objectives: Dict[str, Dict[str, Any]] = None + self._init_model_id: str = None + + # State variables + self._path: List[int] = None + self._state: Dict[Tuple, SearchWalkState] = None + + self._start_time: float = 0 + self._iteration_count: int = 0 + self._num_samples_suggested: int = 0 + + self._initialized: bool = False + + def initialize( + self, + space_config: Dict[str, List[Dict[str, "SearchParameter"]]], + init_model_id: str, + objectives: Dict[str, List[Dict[str, Dict[str, Any]]]], + ): + 
"""Initialize the search strategy. + + space_config: Ordered dictionary of format {pass_name, [{param_name: SearchParameter}]} + objectives: dictionary of format {objective_name: {"higher_is_better": bool, "goal": float}} + """ + # for the fist search group, init_model_id must be provided + if not init_model_id: + raise ValueError("init_model_id must be provided for search") + + if self.config.execution_order == SearchStrategyExecutionOrder.JOINT: + self._search_spaces = [ + SearchSpace( + [ + (pass_name, Categorical([SearchSpace(list(params.items())) for params in passes])) + for pass_name, passes in space_config.items() + ] + ) + ] + elif self.config.execution_order == SearchStrategyExecutionOrder.PASS_BY_PASS: + self._search_spaces = [ + SearchSpace([(pass_name, Categorical([SearchSpace(list(params.items())) for params in passes]))]) + for pass_name, passes in space_config.items() + ] + else: + raise ValueError(f"Unsupported execution order: {self.config.execution_order}") + + if objectives: + for pass_name, pass_params in space_config.items(): + pass_objectives = objectives.get(pass_name) + assert not pass_objectives or len(pass_objectives) == len( + pass_params + ), "Expect none or all passes to have objectives" + + self._objectives = objectives or {} + self._init_model_id = init_model_id + + self._path = None + self._state = None + + self._start_time = 0 + self._iteration_count = 0 + self._num_samples_suggested = 0 + + self._initialized = True + + @property + def search_spaces(self): + """Returns the list of search spaces.""" + return self._search_spaces + + @property + def iteration_count(self) -> int: + """Returns the number of iterations so far across all search spaces.""" + return self._iteration_count + + @property + def start_time(self) -> float: + """Returns the start time of current iteration.""" + return self._start_time + + @property + def elapsed_time(self) -> float: + """Returns elapsed time of the current iteration.""" + return (time.time() - 
self._start_time) if self._start_time else 0 + + @property + def num_samples_suggested(self) -> int: + """Returns the number of samples suggested so far across all search spaces.""" + return self._num_samples_suggested + + @property + def max_samples(self) -> int: + """Returns the maximum number of samples.""" + count = 1 + for space in self._search_spaces: + count *= len(space) + return count + + def __iter__(self) -> Generator[SearchSample, None, None]: + self._path = [] + search_space = self._search_spaces[len(self._path)] + self._state = { + tuple(self._path): SearchWalkState( + self._path, + [self._init_model_id], + self._create_sampler(search_space), + self._create_results(), + ) + } + + self._start_time = time.time() + self._iteration_count = 0 + self._num_samples_suggested = 0 + + while True: + state = self._state[tuple(self._path)] + + while not state.sampler.should_stop: + self._iteration_count += 1 + search_point = state.sampler.suggest() + + self._num_samples_suggested += 1 + yield SearchSample(search_point, state.model_ids) + + # If this is the last pass in the walk, evaluate the model to see if all goals are met. + if ( + self.config.stop_when_goals_met + and (len(self._path) == (len(self._search_spaces) - 1)) + and state.results.meets_goals(search_point.index) + ): + return None + + # Check is any of the global stop criteria are met. + # NOTE: Search will run at least one step before stopping. + if self.should_stop: + return None + + # Try stepping down the tree, and if that fails, try stepping up. 
If both fail, we are done + if not self._step_down() and not self._step_up(): + return None + + def _create_sampler(self, search_space: SearchSpace) -> SearchSampler: + """Create a search algorithm.""" + if self.config.sampler not in REGISTRY: + raise ValueError(f"Unsupported search algorithm: {self.config.sampler}") + + return REGISTRY[self.config.sampler](search_space, self.config.sampler_config) + + def _create_results(self) -> SearchResults: + """Create and return a search result.""" + return SearchResults() + + def _initialize_step(self) -> bool: + state = self._state[tuple(self._path)] + + # NOTE: Two scenarios for pass-by-pass mode - + # 1. All search points are evaluated for each search space before moving down the tree. + # 2. Evaluating search points until we find a suitable candidate and move down. If failed, move up to continue + # finding next candidate. + # Implementing option 1 currently i.e. all search points are evaluated for search space before moving down. + # Logic here can be customized to support the other if need be. + + # Get the next best result index + state.best_result_index, next_search_point, next_model_ids = state.results.get_next_best_result( + state.best_result_index + ) + if state.best_result_index is not None: + self._path.append(next_search_point) + search_space = self._search_spaces[len(self._path)] + self._state[tuple(self._path)] = SearchWalkState( + self._path, + next_model_ids, + self._create_sampler(search_space), + self._create_results(), + ) + return True + + return False + + def _step_up(self) -> bool: + if not self._path: + return False + + self._path.pop() + + # Results here are already sorted, don't have to do it again! + return self._initialize_step() + + def _step_down(self) -> bool: + if len(self._path) == (len(self._search_spaces) - 1): + return False + + state = self._state[tuple(self._path)] + + # Current state is potentially modified, sort the collected results again!
+ if not state.results.sort(apply_goals=True): + logger.warning( + "No models in path %s met the goals. Sorting the models without applying goals...", self._path + ) + state.results.sort(apply_goals=False) + + return self._initialize_step() + + def record_feedback_signal( + self, + search_point_index: int, + signal: "MetricResult", + model_ids: List[str], + should_prune: bool = False, + ): + """Record the feedback signal for the given search point.""" + assert self._initialized, "Search strategy is not initialized" + + state = self._state[tuple(self._path)] + search_space = self._search_spaces[len(self._path)] + search_point = search_space[search_point_index] + pass_name, _ = search_space.parameters[-1] + pass_index, _ = search_point.values[pass_name] + passes_objectives = self._objectives.get(pass_name, []) + objectives = passes_objectives[pass_index] if pass_index < len(passes_objectives) else {} + state.results.record_feedback_signal(search_point_index, objectives, signal, model_ids) + state.sampler.record_feedback_signal(search_point_index, objectives, signal, should_prune) + + @property + def should_stop(self): + """Check if the search should stop.""" + # NOTE: Individual goal criteria is checked at the end of each step in the iteration loop + return ((self.config.max_iter is not None) and (self._iteration_count > self.config.max_iter)) or ( + (self.config.max_time is not None) and (self.elapsed_time > self.config.max_time) + ) + + def get_output_model_num(self): + return self.config.output_model_num diff --git a/olive/search/utils.py b/olive/search/utils.py new file mode 100644 index 000000000..2fe2a4efa --- /dev/null +++ b/olive/search/utils.py @@ -0,0 +1,87 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +from typing import Dict, List, Set, Tuple + +from olive.search.search_parameter import Conditional, SearchParameter + + +class DirectedGraph: + def __init__(self, vertices: List[str], edges: List[Tuple[str, str]] = None): + self.vertices = vertices + self.graph = {v: [] for v in vertices} + edges = edges or [] + for v1, v2 in edges: + self.add_edge(v1, v2) + + def add_edge(self, v1: str, v2: str): + assert v1 in self.vertices + assert v2 in self.vertices + self.graph[v1].append(v2) + + def _is_cyclic_util(self, v: str, visited: Set[str], rec_stack: Set[str]): + visited.add(v) + rec_stack.add(v) + + for neighbor in self.graph[v]: + if neighbor not in visited: + if self._is_cyclic_util(neighbor, visited, rec_stack): + return True + elif neighbor in rec_stack: + return True + + rec_stack.remove(v) + return False + + def is_cyclic(self): + visited = set() + rec_stack = set() + + return any(v not in visited and self._is_cyclic_util(v, visited, rec_stack) for v in self.vertices) + + def _topological_sort_util(self, v: str, visited: Set[str], order: List[str]): + visited.add(v) + + for neighbor in self.graph[v]: + if neighbor not in visited: + self._topological_sort_util(neighbor, visited, order) + + order.insert(0, v) + + def topological_sort(self): + assert not self.is_cyclic(), "Graph is cyclic, cannot perform topological sort." + visited = set() + order = [] + + # Since the dependee vertex is inserted in front, iterate the vertices in + # reverse order to retain the relative order of vertices in the graph. + # Without it the graph vertices are reversed even in cases where no + # dependency exists.
+ for v in reversed(self.vertices): + if v not in visited: + self._topological_sort_util(v, visited, order) + + return order + + +def _search_space_graph(search_space: Dict[str, SearchParameter]) -> DirectedGraph: + """Create a directed graph from the search space.""" + graph = DirectedGraph(list(search_space.keys())) + for name, param in search_space.items(): + if isinstance(param, Conditional): + for parent in param.parents: + graph.add_edge(parent, name) + return graph + + +def cyclic_search_space(search_space: Dict[str, SearchParameter]) -> bool: + """Check if the search space is cyclic.""" + graph = _search_space_graph(search_space) + return graph.is_cyclic() + + +def order_search_parameters(search_space: Dict[str, SearchParameter]) -> List[str]: + """Order the search parameters in a topological order.""" + graph = _search_space_graph(search_space) + return graph.topological_sort() diff --git a/olive/systems/azureml/aml_system.py b/olive/systems/azureml/aml_system.py index a286dbbf2..76634908d 100644 --- a/olive/systems/azureml/aml_system.py +++ b/olive/systems/azureml/aml_system.py @@ -243,12 +243,7 @@ def _assert_not_none(self, obj): if obj is None: raise ValueError(f"{obj.__class__.__name__} is missing in the inputs!") - def run_pass( - self, - the_pass: "Pass", - model_config: ModelConfig, - output_model_path: str, - ) -> ModelConfig: + def run_pass(self, the_pass: "Pass", model_config: ModelConfig, output_model_path: str) -> ModelConfig: """Run the pass on the model.""" ml_client = self.azureml_client_config.create_client() diff --git a/olive/systems/docker/docker_system.py b/olive/systems/docker/docker_system.py index 29560fe41..47d84d0cf 100644 --- a/olive/systems/docker/docker_system.py +++ b/olive/systems/docker/docker_system.py @@ -103,22 +103,13 @@ def __init__( _print_docker_logs(e.build_log, logging.ERROR) raise - def run_pass( - self, - the_pass: "Pass", - model_config: "ModelConfig", - output_model_path: str, - ) -> "ModelConfig": + def 
run_pass(self, the_pass: "Pass", model_config: "ModelConfig", output_model_path: str) -> "ModelConfig": """Run the pass on the model.""" with tempfile.TemporaryDirectory() as tempdir: return self._run_pass_container(Path(tempdir), the_pass, model_config, output_model_path) def _run_pass_container( - self, - workdir: Path, - the_pass: "Pass", - model_config: "ModelConfig", - output_model_path: str, + self, workdir: Path, the_pass: "Pass", model_config: "ModelConfig", output_model_path: str ) -> "ModelConfig": pass_config = the_pass.to_json(check_object=True) diff --git a/olive/systems/olive_system.py b/olive/systems/olive_system.py index 7a39a1b90..6c3339d9b 100644 --- a/olive/systems/olive_system.py +++ b/olive/systems/olive_system.py @@ -41,12 +41,7 @@ def __init__( self.hf_token = hf_token @abstractmethod - def run_pass( - self, - the_pass: "Pass", - model_config: "ModelConfig", - output_model_path: str, - ) -> "ModelConfig": + def run_pass(self, the_pass: "Pass", model_config: "ModelConfig", output_model_path: str) -> "ModelConfig": """Run the pass on the model at a specific point in the search space.""" raise NotImplementedError diff --git a/olive/workflows/run/config.py b/olive/workflows/run/config.py index 5265e58e0..e9d5fac6b 100644 --- a/olive/workflows/run/config.py +++ b/olive/workflows/run/config.py @@ -15,50 +15,16 @@ from olive.data.config import DataComponentConfig, DataConfig from olive.data.container.dummy_data_container import TRANSFORMER_DUMMY_DATA_CONTAINER from olive.data.container.huggingface_container import HuggingfaceContainer -from olive.engine import Engine, EngineConfig +from olive.engine import Engine +from olive.engine.config import EngineConfig, RunPassConfig from olive.engine.packaging.packaging_config import PackagingConfig from olive.evaluator.olive_evaluator import OliveEvaluatorConfig from olive.model import ModelConfig -from olive.passes.pass_config import AbstractPassConfig, PassParamDefault +from olive.passes.pass_config import 
PassParamDefault from olive.resource_path import AZUREML_RESOURCE_TYPES from olive.systems.system_config import SystemConfig -class RunPassConfig(AbstractPassConfig): - """Pass configuration for Olive workflow. - - This is the configuration for a single pass in Olive workflow. It includes configurations for pass type, config, - etc. - - Example: - .. code-block:: json - - { - "type": "OlivePass", - "config": { - "param1": "value1", - "param2": "value2" - } - } - - """ - - host: Union[SystemConfig, str] = Field( - None, - description=( - "Host system for the pass. If it is a string, must refer to a system config under `systems` section. If not" - " provided, use the engine's host system." - ), - ) - evaluator: Union[OliveEvaluatorConfig, str] = Field( - None, - description=( - "Evaluator for the pass. If it is a string, must refer to an evaluator config under `evaluators` section." - " If not provided, use the engine's evaluator." - ), - ) - - class RunEngineConfig(EngineConfig): evaluate_input_model: bool = True output_dir: Union[Path, str] = None @@ -138,15 +104,7 @@ class RunConfig(NestedConfig): " no-search or auto-optimizer mode based on whether passes field is provided." ), ) - passes: Dict[str, RunPassConfig] = Field(default_factory=dict, description="Pass configurations.") - pass_flows: List[List[str]] = Field( - None, - description=( - "Pass flows. Each member must be a list of pass names from `passes` field. If provided," - " each flow will be run sequentially. If not provided, all passes will be run as a single flow in the order" - " of `passes` field." - ), - ) + passes: Dict[str, List[RunPassConfig]] = Field(None, description="Pass configurations.") auto_optimizer_config: AutoOptimizerConfig = Field( default_factory=AutoOptimizerConfig, description="Auto optimizer configuration. Only valid when passes field is empty or not provided.", @@ -155,11 +113,25 @@ class RunConfig(NestedConfig): None, description="Workflow host. None by default. 
If provided, the workflow will be run on the specified host." ) + @root_validator(pre=True) + def patch_evaluators(cls, values): + if "evaluators" in values: + for name, evaluator_config in values["evaluators"].items(): + evaluator_config["name"] = name + return values + + @root_validator(pre=True) + def patch_passes(cls, values): + if "passes" in values: + for name, passes_config in values["passes"].items(): + if isinstance(passes_config, dict): + values["passes"][name] = [passes_config] + return values + @root_validator(pre=True) def insert_azureml_client(cls, values): values = convert_configs_to_dicts(values) _insert_azureml_client(values, values.get("azureml_client")) - return values @validator("data_configs", pre=True) @@ -237,43 +209,45 @@ def validate_evaluators(cls, v, values): def validate_engine(cls, v, values): v = _resolve_system(v, values, "host") v = _resolve_system(v, values, "target") + if v.get("search_strategy") and not v.get("evaluator"): + raise ValueError( + "Can't search without a valid evaluator config. " + "Either provider a valid evaluator config or disable search." 
+ ) return _resolve_evaluator(v, values) @validator("passes", pre=True, each_item=True) def validate_pass_host_evaluator(cls, v, values): - v = _resolve_system(v, values, "host") - return _resolve_evaluator(v, values) + for i, _ in enumerate(v): + v[i] = _resolve_system(v[i], values, "host") + v[i] = _resolve_evaluator(v[i], values) + return v @validator("passes", pre=True, each_item=True) def validate_pass_search(cls, v, values): if "engine" not in values: raise ValueError("Invalid engine") - # validate first to gather config params - v = validate_config(v, RunPassConfig).dict() - - if not v.get("config"): - return v - - searchable_configs = set() - for param_name in v["config"]: - if v["config"][param_name] == PassParamDefault.SEARCHABLE_VALUES: - searchable_configs.add(param_name) - - resolve_all_data_configs(v["config"], values) - - if not values["engine"].search_strategy and searchable_configs: - raise ValueError( - f"You cannot disable search for {v['type']} and" - f" set {searchable_configs} to SEARCHABLE_VALUES at the same time." - " Please remove SEARCHABLE_VALUES or enable search(needs search strategy configs)." - ) + for i, _ in enumerate(v): + # validate first to gather config params + v[i] = iv = validate_config(v[i], RunPassConfig).dict() + + if iv.get("config"): + searchable_configs = set() + for param_name in iv["config"]: + if iv["config"][param_name] == PassParamDefault.SEARCHABLE_VALUES: + searchable_configs.add(param_name) + if param_name.endswith("data_config"): + iv["config"] = _resolve_data_config(iv["config"], values, param_name) + + if not values["engine"].search_strategy and searchable_configs: + raise ValueError( + f"You cannot disable search for {iv['type']} and" + f" set {searchable_configs} to SEARCHABLE_VALUES at the same time." + " Please remove SEARCHABLE_VALUES or enable search (needs search strategy configs)." 
+ ) return v - @validator("pass_flows", pre=True) - def validate_pass_flows(cls, v, values): - return v or [] - @validator("workflow_host", pre=True) def validate_workflow_host(cls, v, values): if v is None: diff --git a/olive/workflows/run/run.py b/olive/workflows/run/run.py index 67c0b691b..4f339160c 100644 --- a/olive/workflows/run/run.py +++ b/olive/workflows/run/run.py @@ -8,14 +8,17 @@ import sys from copy import deepcopy from pathlib import Path -from typing import Generator, List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Union from olive.common.utils import set_tempdir from olive.logging import set_default_logger_severity, set_ort_logger_severity, set_verbosity_info from olive.package_config import OlivePackageConfig from olive.systems.accelerator_creator import create_accelerators from olive.systems.common import SystemType -from olive.workflows.run.config import RunConfig, RunPassConfig +from olive.workflows.run.config import RunConfig + +if TYPE_CHECKING: + from olive.engine.config import RunPassConfig logger = logging.getLogger(__name__) @@ -58,18 +61,19 @@ def get_pass_extras(pass_type): # add dependencies for passes if run_config.passes: - for pass_config in run_config.passes.values(): - host = pass_config.host or run_config.engine.host - if (host and host.type == SystemType.Local) or not host: - local_packages.extend(get_pass_extras(pass_config.type)) - else: - remote_packages.extend(get_pass_extras(pass_config.type)) - if pass_config.type in ["SNPEConversion", "SNPEQuantization", "SNPEtoONNXConversion"]: - logger.info( - "Please refer to https://microsoft.github.io/Olive/tutorials/passes/snpe.html to install SNPE" - " prerequisites for pass %s", - pass_config.type, - ) + for passes_configs in run_config.passes.values(): + for pass_config in passes_configs: + host = pass_config.host or run_config.engine.host + if (host and host.type == SystemType.Local) or not host: + local_packages.extend(get_pass_extras(pass_config.type)) 
+ else: + remote_packages.extend(get_pass_extras(pass_config.type)) + if pass_config.type in ["SNPEConversion", "SNPEQuantization", "SNPEtoONNXConversion"]: + logger.info( + "Please refer to https://microsoft.github.io/Olive/tutorials/passes/snpe.html to install SNPE" + " prerequisites for pass %s", + pass_config.type, + ) # add dependencies for engine host_type = None @@ -126,8 +130,11 @@ def get_pass_module_path(pass_type: str, package_config: OlivePackageConfig) -> def is_execution_provider_required(run_config: RunConfig, package_config: OlivePackageConfig) -> bool: - passes = get_used_passes(run_config) - return any(get_pass_module_path(p.type, package_config).startswith("olive.passes.onnx") for p in passes) + return any( + get_pass_module_path(p.type, package_config).startswith("olive.passes.onnx") + for passes_configs in run_config.passes.values() + for p in passes_configs + ) def run_engine(package_config: OlivePackageConfig, run_config: RunConfig): @@ -164,48 +171,31 @@ def run_engine(package_config: OlivePackageConfig, run_config: RunConfig): and run_config.auto_optimizer_config is not None and not run_config.auto_optimizer_config.disable_auto_optimizer ) - is_ep_required = auto_optimizer_enabled or is_execution_provider_required(run_config, package_config) - - # Register passes since we need to know whether they need to run on target - used_passes = list(get_used_passes(run_config)) - for pass_config in used_passes: - logger.debug("Registering pass %s", pass_config.type) - package_config.import_pass_module(pass_config.type) # check if target is not used + used_passes_configs = get_used_passes_configs(run_config) target_not_used = ( # no evaluator given (also implies no search) engine.evaluator_config is None # not using auto optimizer - and used_passes + and used_passes_configs # no pass specific evaluator # no pass needs to run on target and all( pass_config.evaluator is None and not get_run_on_target(package_config, pass_config) - for pass_config in 
used_passes + for pass_config in used_passes_configs ) ) + + is_ep_required = auto_optimizer_enabled or is_execution_provider_required(run_config, package_config) accelerator_specs = create_accelerators( engine.target_config, skip_supported_eps_check=target_not_used, is_ep_required=is_ep_required ) - # Initializes the passes and register it with the engine - passes_to_run = ( - {pass_name for pass_flow in run_config.pass_flows for pass_name in pass_flow} - if run_config.pass_flows - else set(run_config.passes.keys()) - ) - for pass_name in passes_to_run: - pass_run_config = run_config.passes[pass_name] - engine.register( - pass_run_config.type, - config=pass_run_config.config, - name=pass_name, - host=pass_run_config.host.create_system() if pass_run_config.host is not None else None, - evaluator_config=pass_run_config.evaluator, - ) - engine.set_pass_flows(run_config.pass_flows or [list(run_config.passes.keys())]) + # Set passes with the engine + engine.set_input_passes_configs(run_config.passes) + # run return engine.run( run_config.input_model, accelerator_specs, @@ -334,18 +324,10 @@ def get_local_ort_packages() -> List[str]: return local_ort_packages -def get_used_passes(run_config: RunConfig) -> Generator["RunPassConfig", None, None]: - if run_config.pass_flows: - passes = set() - for pass_flow in run_config.pass_flows: - for pass_name in pass_flow: - if run_config.passes[pass_name].type not in passes: - passes.add(run_config.passes[pass_name].type) - yield run_config.passes[pass_name] - elif run_config.passes: - yield from run_config.passes.values() +def get_used_passes_configs(run_config: RunConfig) -> List["RunPassConfig"]: + return [pass_config for _, pass_configs in run_config.passes.items() for pass_config in pass_configs] -def get_run_on_target(package_config: OlivePackageConfig, pass_config: RunPassConfig) -> bool: +def get_run_on_target(package_config: OlivePackageConfig, pass_config: "RunPassConfig") -> bool: pass_module_config = 
package_config.get_pass_module_config(pass_config.type) return pass_module_config.run_on_target diff --git a/test/unit_test/engine/packaging/test_packaging_generator.py b/test/unit_test/engine/packaging/test_packaging_generator.py index 98261d6c7..9e108fea6 100644 --- a/test/unit_test/engine/packaging/test_packaging_generator.py +++ b/test/unit_test/engine/packaging/test_packaging_generator.py @@ -54,7 +54,7 @@ def test_generate_zipfile_artifacts(mock_sys_getsizeof, save_as_external_data, m }, "search_strategy": { "execution_order": "joint", - "search_algorithm": "random", + "sampler": "random", }, "evaluator": evaluator_config, } diff --git a/test/unit_test/engine/test_engine.py b/test/unit_test/engine/test_engine.py index 0e4e82cff..4e7293721 100644 --- a/test/unit_test/engine/test_engine.py +++ b/test/unit_test/engine/test_engine.py @@ -19,6 +19,7 @@ from olive.data.config import DataComponentConfig, DataConfig from olive.engine import Engine +from olive.engine.config import RunPassConfig from olive.evaluator.metric import AccuracySubType from olive.evaluator.metric_result import MetricResult, joint_metric_key from olive.evaluator.olive_evaluator import OliveEvaluatorConfig @@ -28,7 +29,6 @@ from olive.passes.onnx.quantization import OnnxDynamicQuantization, OnnxStaticQuantization from olive.systems.accelerator_creator import create_accelerators from olive.systems.common import SystemType -from olive.systems.local import LocalSystem from olive.systems.system_config import LocalTargetUserConfig, SystemConfig # pylint: disable=protected-access @@ -43,7 +43,7 @@ def test_register(self, tmpdir): # setup p = get_onnxconversion_pass() name = p.__class__.__name__ - system = LocalSystem() + host = SystemConfig(type=SystemType.Local) evaluator_config = OliveEvaluatorConfig(metrics=[get_accuracy_metric(AccuracySubType.ACCURACY_SCORE)]) options = { @@ -53,19 +53,20 @@ def test_register(self, tmpdir): }, "search_strategy": { "execution_order": "joint", - "search_algorithm": 
"random", + "sampler": "random", }, } engine = Engine(**options) # execute - engine.register(OnnxConversion, host=system, evaluator_config=evaluator_config) + engine.register(OnnxConversion, host=host, evaluator_config=evaluator_config) # assert - assert name in engine.pass_run_configs - assert engine.pass_run_configs[name]["type"] == OnnxConversion.__name__ - assert engine.pass_run_configs[name]["host"] == system - assert engine.pass_run_configs[name]["evaluator"] == evaluator_config + assert name in engine.input_passes_configs + assert len(engine.input_passes_configs[name]) == 1 + assert engine.input_passes_configs[name][0].type == OnnxConversion.__name__.lower() + assert engine.input_passes_configs[name][0].host == host + assert engine.input_passes_configs[name][0].evaluator == evaluator_config def test_register_no_search(self, tmpdir): # setup @@ -82,7 +83,7 @@ def test_register_no_search(self, tmpdir): engine.register(OnnxDynamicQuantization) # assert - assert "OnnxDynamicQuantization" in engine.pass_run_configs + assert "OnnxDynamicQuantization" in engine.input_passes_configs def test_default_engine_run(self, tmpdir): # setup @@ -100,7 +101,7 @@ def test_default_engine_run(self, tmpdir): for fp_nodes in outputs.values(): for node in fp_nodes.nodes.values(): assert node.model_config - assert node.from_pass == "OnnxConversion" + assert node.from_pass == "onnxconversion" assert node.metrics is None, "Should not evaluate input/output model by default" @patch("olive.systems.local.LocalSystem") @@ -118,7 +119,7 @@ def test_run(self, mock_local_system, tmp_path): }, "search_strategy": { "execution_order": "joint", - "search_algorithm": "random", + "sampler": "random", }, "evaluator": evaluator_config, } @@ -143,16 +144,18 @@ def test_run(self, mock_local_system, tmp_path): system_object.olive_managed_env = False engine = Engine(**options) - p1_name = "converter_13" - p2_name = "converter_14" - p1, p1_config = get_onnxconversion_pass(ignore_pass_config=False, 
target_opset=13) - p2, p2_config = get_onnxconversion_pass(ignore_pass_config=False, target_opset=14) - engine.register(OnnxConversion, name=p1_name, config=p1_config) - engine.register(OnnxConversion, name=p2_name, config=p2_config) - engine.set_pass_flows([[p1_name], [p2_name]]) + p_name = "converter" + p1, p1_run_config = get_onnxconversion_pass(ignore_pass_config=False, target_opset=13) + p2, p2_run_config = get_onnxconversion_pass(ignore_pass_config=False, target_opset=14) + p1_run_config = RunPassConfig(**p1_run_config) + p2_run_config = RunPassConfig(**p2_run_config) + engine.set_input_passes_configs({p_name: [p1_run_config, p2_run_config]}) + + p1_pass_config = p1.serialize_config(p1_run_config.config, check_object=True) + p2_pass_config = p2.serialize_config(p2_run_config.config, check_object=True) model_ids = [ - engine.cache.get_output_model_id(p1.__class__.__name__, p1_config, input_model_id), - engine.cache.get_output_model_id(p2.__class__.__name__, p2_config, input_model_id), + engine.cache.get_output_model_id(p1.__class__.__name__, p1_pass_config, input_model_id), + engine.cache.get_output_model_id(p2.__class__.__name__, p2_pass_config, input_model_id), ] expected_res = { model_id: { @@ -214,7 +217,6 @@ def test_run_no_search_model_components(self, mock_local_system_init, tmpdir): engine = Engine(cache_config={"cache_dir": tmpdir}) engine.register(OptimumConversion) - engine.set_pass_flows() # output model to output_dir output_dir = Path(tmpdir) @@ -257,7 +259,6 @@ def test_run_no_search(self, mock_local_system_init, tmp_path): engine = Engine(**options) _, p_config = get_onnxconversion_pass(ignore_pass_config=False, target_opset=13) engine.register(OnnxConversion, config=p_config) - engine.set_pass_flows() accelerator_spec = DEFAULT_CPU_ACCELERATOR output_model_id = engine.cache.get_output_model_id( @@ -310,7 +311,7 @@ def test_run_no_search(self, mock_local_system_init, tmp_path): [ { "execution_order": "joint", - "search_algorithm": "random", + 
"sampler": "random", }, None, ], @@ -361,7 +362,7 @@ def test_pass_exception(self, caplog, tmpdir): }, "search_strategy": { "execution_order": "joint", - "search_algorithm": "random", + "sampler": "random", }, "evaluator": evaluator_config, } @@ -499,7 +500,7 @@ def test_pass_cache(self, mock_get_available_providers, mock_local_system_init, }, "search_strategy": { "execution_order": "joint", - "search_algorithm": "random", + "sampler": "random", }, "evaluator": evaluator_config, } @@ -563,7 +564,7 @@ def test_pass_value_error(self, caplog, tmpdir): }, "search_strategy": { "execution_order": "joint", - "search_algorithm": "random", + "sampler": "random", }, "evaluator": evaluator_config, } @@ -601,7 +602,9 @@ def test_pass_quantization_error(self, is_search, caplog, tmpdir): }, "search_strategy": { "execution_order": "joint", - "search_algorithm": "random", + "sampler": "random", + "max_samples": 1, + "seed": 1, }, "evaluator": evaluator_config, } diff --git a/test/unit_test/passes/common/test_user_script.py b/test/unit_test/passes/common/test_user_script.py index d21fe9934..c1460dea4 100644 --- a/test/unit_test/passes/common/test_user_script.py +++ b/test/unit_test/passes/common/test_user_script.py @@ -8,6 +8,6 @@ class TestUserScriptConfig: def test_no_config(self): - config = {} - config = OrtSessionParamsTuning.generate_config(DEFAULT_CPU_ACCELERATOR, config, True) + config = OrtSessionParamsTuning.generate_config(DEFAULT_CPU_ACCELERATOR, disable_search=True) assert config + assert OrtSessionParamsTuning.validate_config(config, DEFAULT_CPU_ACCELERATOR, disable_search=True) diff --git a/test/unit_test/passes/test_pass_serialization.py b/test/unit_test/passes/test_pass_serialization.py index 88e09906e..ea05b0e35 100644 --- a/test/unit_test/passes/test_pass_serialization.py +++ b/test/unit_test/passes/test_pass_serialization.py @@ -11,8 +11,7 @@ @pytest.mark.parametrize("host_device", [None, "cpu", "gpu"]) def test_pass_serialization(host_device): - config = {} - 
config = OnnxConversion.generate_config(DEFAULT_CPU_ACCELERATOR, config) + config = OnnxConversion.generate_config(DEFAULT_CPU_ACCELERATOR) onnx_conversion = OnnxConversion(DEFAULT_CPU_ACCELERATOR, config, host_device=host_device) json = onnx_conversion.to_json(True) diff --git a/test/unit_test/search/samplers/test_random_sampler.py b/test/unit_test/search/samplers/test_random_sampler.py new file mode 100644 index 000000000..6f9179a2f --- /dev/null +++ b/test/unit_test/search/samplers/test_random_sampler.py @@ -0,0 +1,101 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from collections import OrderedDict +from unittest.mock import patch + +from olive.search.samplers.random_sampler import RandomSampler +from olive.search.search_parameter import Categorical +from olive.search.search_space import SearchSpace + +# ruff: noqa: PD011 + + +class TestRandomSampler: + @patch("olive.search.search_space.SearchSpace.__getitem__") + def test_call_count(self, mock_search_space_get_item): + search_space = SearchSpace( + [ + ("PassA", Categorical([1, 2])), + ("PassB", Categorical([1, 2, 3])), + ("PassC", Categorical(["a", "b"])), + ("PassD", Categorical(["a", "b", "c"])), + ] + ) + assert len(search_space) == 36 + + config = {"max_samples": 50} + sampler = RandomSampler(search_space, config=config) + + count = 0 + while not sampler.should_stop: + sampler.suggest() + count += 1 + + assert count == 36 + assert mock_search_space_get_item.call_count == 36 + + def test_iteration(self): + search_space = SearchSpace( + [ + ("PassA", Categorical([1, 2])), + ("PassB", Categorical([1, 2, 3])), + ("PassC", Categorical(["a", "b"])), + ("PassD", Categorical(["a", "b", "c"])), + ] + ) + assert len(search_space) == 36 + + config = {"seed": 101, "max_samples": 50} + sampler = 
RandomSampler(search_space, config=config) + + count = 0 + actual = [] + while not sampler.should_stop: + actual.append(sampler.suggest()) + count += 1 + + actual = [(search_point.index, search_point.values) for search_point in actual] + expected = [ + (12, OrderedDict({"PassA": (0, 1), "PassB": (0, 1), "PassC": (0, "a"), "PassD": (1, "b")})), + (35, OrderedDict({"PassA": (1, 2), "PassB": (2, 3), "PassC": (1, "b"), "PassD": (2, "c")})), + (23, OrderedDict({"PassA": (1, 2), "PassB": (2, 3), "PassC": (1, "b"), "PassD": (1, "b")})), + (31, OrderedDict({"PassA": (1, 2), "PassB": (0, 1), "PassC": (1, "b"), "PassD": (2, "c")})), + (3, OrderedDict({"PassA": (1, 2), "PassB": (1, 2), "PassC": (0, "a"), "PassD": (0, "a")})), + (24, OrderedDict({"PassA": (0, 1), "PassB": (0, 1), "PassC": (0, "a"), "PassD": (2, "c")})), + (18, OrderedDict({"PassA": (0, 1), "PassB": (0, 1), "PassC": (1, "b"), "PassD": (1, "b")})), + (7, OrderedDict({"PassA": (1, 2), "PassB": (0, 1), "PassC": (1, "b"), "PassD": (0, "a")})), + (25, OrderedDict({"PassA": (1, 2), "PassB": (0, 1), "PassC": (0, "a"), "PassD": (2, "c")})), + (9, OrderedDict({"PassA": (1, 2), "PassB": (1, 2), "PassC": (1, "b"), "PassD": (0, "a")})), + (13, OrderedDict({"PassA": (1, 2), "PassB": (0, 1), "PassC": (0, "a"), "PassD": (1, "b")})), + (21, OrderedDict({"PassA": (1, 2), "PassB": (1, 2), "PassC": (1, "b"), "PassD": (1, "b")})), + (33, OrderedDict({"PassA": (1, 2), "PassB": (1, 2), "PassC": (1, "b"), "PassD": (2, "c")})), + (8, OrderedDict({"PassA": (0, 1), "PassB": (1, 2), "PassC": (1, "b"), "PassD": (0, "a")})), + (16, OrderedDict({"PassA": (0, 1), "PassB": (2, 3), "PassC": (0, "a"), "PassD": (1, "b")})), + (26, OrderedDict({"PassA": (0, 1), "PassB": (1, 2), "PassC": (0, "a"), "PassD": (2, "c")})), + (2, OrderedDict({"PassA": (0, 1), "PassB": (1, 2), "PassC": (0, "a"), "PassD": (0, "a")})), + (15, OrderedDict({"PassA": (1, 2), "PassB": (1, 2), "PassC": (0, "a"), "PassD": (1, "b")})), + (11, OrderedDict({"PassA": (1, 2), 
"PassB": (2, 3), "PassC": (1, "b"), "PassD": (0, "a")})), + (10, OrderedDict({"PassA": (0, 1), "PassB": (2, 3), "PassC": (1, "b"), "PassD": (0, "a")})), + (4, OrderedDict({"PassA": (0, 1), "PassB": (2, 3), "PassC": (0, "a"), "PassD": (0, "a")})), + (20, OrderedDict({"PassA": (0, 1), "PassB": (1, 2), "PassC": (1, "b"), "PassD": (1, "b")})), + (32, OrderedDict({"PassA": (0, 1), "PassB": (1, 2), "PassC": (1, "b"), "PassD": (2, "c")})), + (27, OrderedDict({"PassA": (1, 2), "PassB": (1, 2), "PassC": (0, "a"), "PassD": (2, "c")})), + (17, OrderedDict({"PassA": (1, 2), "PassB": (2, 3), "PassC": (0, "a"), "PassD": (1, "b")})), + (6, OrderedDict({"PassA": (0, 1), "PassB": (0, 1), "PassC": (1, "b"), "PassD": (0, "a")})), + (29, OrderedDict({"PassA": (1, 2), "PassB": (2, 3), "PassC": (0, "a"), "PassD": (2, "c")})), + (22, OrderedDict({"PassA": (0, 1), "PassB": (2, 3), "PassC": (1, "b"), "PassD": (1, "b")})), + (14, OrderedDict({"PassA": (0, 1), "PassB": (1, 2), "PassC": (0, "a"), "PassD": (1, "b")})), + (30, OrderedDict({"PassA": (0, 1), "PassB": (0, 1), "PassC": (1, "b"), "PassD": (2, "c")})), + (19, OrderedDict({"PassA": (1, 2), "PassB": (0, 1), "PassC": (1, "b"), "PassD": (1, "b")})), + (28, OrderedDict({"PassA": (0, 1), "PassB": (2, 3), "PassC": (0, "a"), "PassD": (2, "c")})), + (1, OrderedDict({"PassA": (1, 2), "PassB": (0, 1), "PassC": (0, "a"), "PassD": (0, "a")})), + (0, OrderedDict({"PassA": (0, 1), "PassB": (0, 1), "PassC": (0, "a"), "PassD": (0, "a")})), + (34, OrderedDict({"PassA": (0, 1), "PassB": (2, 3), "PassC": (1, "b"), "PassD": (2, "c")})), + (5, OrderedDict({"PassA": (1, 2), "PassB": (2, 3), "PassC": (0, "a"), "PassD": (0, "a")})), + ] + + assert count == 36 + assert actual == expected + assert len({sp for sp, _ in expected}) == len(expected) diff --git a/test/unit_test/search/samplers/test_sequential_sampler.py b/test/unit_test/search/samplers/test_sequential_sampler.py new file mode 100644 index 000000000..daa72a820 --- /dev/null +++ 
b/test/unit_test/search/samplers/test_sequential_sampler.py @@ -0,0 +1,93 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from collections import OrderedDict +from unittest.mock import patch + +from olive.search.samplers.sequential_sampler import SequentialSampler +from olive.search.search_parameter import Categorical +from olive.search.search_space import SearchSpace + +# ruff: noqa: PD011 + + +class TestSequentialSampler: + @patch("olive.search.search_space.SearchSpace.__getitem__") + def test_length(self, mock_search_space_get_item): + search_space = SearchSpace( + [ + ("PassA", Categorical([1, 2])), + ("PassB", Categorical([1, 2, 3])), + ("PassC", Categorical(["a", "b"])), + ("PassD", Categorical(["a", "b", "c"])), + ] + ) + assert len(search_space) == 36 + + sampler = SequentialSampler(search_space) + + suggestions = [] + while not sampler.should_stop: + suggestions.append(sampler.suggest()) + + assert len(suggestions) == 36 + assert mock_search_space_get_item.call_count == 36 + + def test_iteration(self): + search_space = SearchSpace( + [ + ("PassA", Categorical([1, 2])), + ("PassB", Categorical([1, 2, 3])), + ("PassC", Categorical(["a", "b"])), + ("PassD", Categorical(["a", "b", "c"])), + ] + ) + assert len(search_space) == 36 + + sampler = SequentialSampler(search_space) + + actual = [] + while not sampler.should_stop: + actual.append(sampler.suggest()) + + actual = [(search_point.index, search_point.values) for search_point in actual] + expected = [ + (0, OrderedDict([("PassA", (0, 1)), ("PassB", (0, 1)), ("PassC", (0, "a")), ("PassD", (0, "a"))])), + (1, OrderedDict([("PassA", (1, 2)), ("PassB", (0, 1)), ("PassC", (0, "a")), ("PassD", (0, "a"))])), + (2, OrderedDict([("PassA", (0, 1)), ("PassB", (1, 2)), ("PassC", (0, "a")), ("PassD", (0, "a"))])), + 
(3, OrderedDict([("PassA", (1, 2)), ("PassB", (1, 2)), ("PassC", (0, "a")), ("PassD", (0, "a"))])), + (4, OrderedDict([("PassA", (0, 1)), ("PassB", (2, 3)), ("PassC", (0, "a")), ("PassD", (0, "a"))])), + (5, OrderedDict([("PassA", (1, 2)), ("PassB", (2, 3)), ("PassC", (0, "a")), ("PassD", (0, "a"))])), + (6, OrderedDict([("PassA", (0, 1)), ("PassB", (0, 1)), ("PassC", (1, "b")), ("PassD", (0, "a"))])), + (7, OrderedDict([("PassA", (1, 2)), ("PassB", (0, 1)), ("PassC", (1, "b")), ("PassD", (0, "a"))])), + (8, OrderedDict([("PassA", (0, 1)), ("PassB", (1, 2)), ("PassC", (1, "b")), ("PassD", (0, "a"))])), + (9, OrderedDict([("PassA", (1, 2)), ("PassB", (1, 2)), ("PassC", (1, "b")), ("PassD", (0, "a"))])), + (10, OrderedDict([("PassA", (0, 1)), ("PassB", (2, 3)), ("PassC", (1, "b")), ("PassD", (0, "a"))])), + (11, OrderedDict([("PassA", (1, 2)), ("PassB", (2, 3)), ("PassC", (1, "b")), ("PassD", (0, "a"))])), + (12, OrderedDict([("PassA", (0, 1)), ("PassB", (0, 1)), ("PassC", (0, "a")), ("PassD", (1, "b"))])), + (13, OrderedDict([("PassA", (1, 2)), ("PassB", (0, 1)), ("PassC", (0, "a")), ("PassD", (1, "b"))])), + (14, OrderedDict([("PassA", (0, 1)), ("PassB", (1, 2)), ("PassC", (0, "a")), ("PassD", (1, "b"))])), + (15, OrderedDict([("PassA", (1, 2)), ("PassB", (1, 2)), ("PassC", (0, "a")), ("PassD", (1, "b"))])), + (16, OrderedDict([("PassA", (0, 1)), ("PassB", (2, 3)), ("PassC", (0, "a")), ("PassD", (1, "b"))])), + (17, OrderedDict([("PassA", (1, 2)), ("PassB", (2, 3)), ("PassC", (0, "a")), ("PassD", (1, "b"))])), + (18, OrderedDict([("PassA", (0, 1)), ("PassB", (0, 1)), ("PassC", (1, "b")), ("PassD", (1, "b"))])), + (19, OrderedDict([("PassA", (1, 2)), ("PassB", (0, 1)), ("PassC", (1, "b")), ("PassD", (1, "b"))])), + (20, OrderedDict([("PassA", (0, 1)), ("PassB", (1, 2)), ("PassC", (1, "b")), ("PassD", (1, "b"))])), + (21, OrderedDict([("PassA", (1, 2)), ("PassB", (1, 2)), ("PassC", (1, "b")), ("PassD", (1, "b"))])), + (22, OrderedDict([("PassA", (0, 1)), ("PassB", 
(2, 3)), ("PassC", (1, "b")), ("PassD", (1, "b"))])), + (23, OrderedDict([("PassA", (1, 2)), ("PassB", (2, 3)), ("PassC", (1, "b")), ("PassD", (1, "b"))])), + (24, OrderedDict([("PassA", (0, 1)), ("PassB", (0, 1)), ("PassC", (0, "a")), ("PassD", (2, "c"))])), + (25, OrderedDict([("PassA", (1, 2)), ("PassB", (0, 1)), ("PassC", (0, "a")), ("PassD", (2, "c"))])), + (26, OrderedDict([("PassA", (0, 1)), ("PassB", (1, 2)), ("PassC", (0, "a")), ("PassD", (2, "c"))])), + (27, OrderedDict([("PassA", (1, 2)), ("PassB", (1, 2)), ("PassC", (0, "a")), ("PassD", (2, "c"))])), + (28, OrderedDict([("PassA", (0, 1)), ("PassB", (2, 3)), ("PassC", (0, "a")), ("PassD", (2, "c"))])), + (29, OrderedDict([("PassA", (1, 2)), ("PassB", (2, 3)), ("PassC", (0, "a")), ("PassD", (2, "c"))])), + (30, OrderedDict([("PassA", (0, 1)), ("PassB", (0, 1)), ("PassC", (1, "b")), ("PassD", (2, "c"))])), + (31, OrderedDict([("PassA", (1, 2)), ("PassB", (0, 1)), ("PassC", (1, "b")), ("PassD", (2, "c"))])), + (32, OrderedDict([("PassA", (0, 1)), ("PassB", (1, 2)), ("PassC", (1, "b")), ("PassD", (2, "c"))])), + (33, OrderedDict([("PassA", (1, 2)), ("PassB", (1, 2)), ("PassC", (1, "b")), ("PassD", (2, "c"))])), + (34, OrderedDict([("PassA", (0, 1)), ("PassB", (2, 3)), ("PassC", (1, "b")), ("PassD", (2, "c"))])), + (35, OrderedDict([("PassA", (1, 2)), ("PassB", (2, 3)), ("PassC", (1, "b")), ("PassD", (2, "c"))])), + ] + assert actual == expected diff --git a/test/unit_test/search/samplers/test_tpe_sampler.py b/test/unit_test/search/samplers/test_tpe_sampler.py new file mode 100644 index 000000000..192df28cb --- /dev/null +++ b/test/unit_test/search/samplers/test_tpe_sampler.py @@ -0,0 +1,168 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +from collections import OrderedDict + +from olive.search.samplers.tpe_sampler import TPESampler +from olive.search.search_parameter import Categorical, Conditional +from olive.search.search_space import SearchSpace + +# ruff: noqa: PD011 + + +class TestTPESampler: + def test_iteration(self): + search_space = SearchSpace( + [ + ("PassA", Categorical([1, 2])), + ("PassB", Categorical([1, 2, 3])), + ("PassC", Categorical(["a", "b"])), + ("PassD", Categorical(["a", "b", "c"])), + ] + ) + + config = {"seed": 101, "max_samples": 50} + sampler = TPESampler(search_space, config=config) + + count = 0 + actual = [] + while not sampler.should_stop: + actual.append(sampler.suggest()) + count += 1 + + actual = [(search_point.index, search_point.values) for search_point in actual] + expected = [ + (5, OrderedDict({"PassA": (1, 2), "PassB": (2, 3), "PassC": (0, "a"), "PassD": (0, "a")})), + (16, OrderedDict({"PassA": (0, 1), "PassB": (2, 3), "PassC": (0, "a"), "PassD": (1, "b")})), + (34, OrderedDict({"PassA": (0, 1), "PassB": (2, 3), "PassC": (1, "b"), "PassD": (2, "c")})), + (21, OrderedDict({"PassA": (1, 2), "PassB": (1, 2), "PassC": (1, "b"), "PassD": (1, "b")})), + (25, OrderedDict({"PassA": (1, 2), "PassB": (0, 1), "PassC": (0, "a"), "PassD": (2, "c")})), + (22, OrderedDict({"PassA": (0, 1), "PassB": (2, 3), "PassC": (1, "b"), "PassD": (1, "b")})), + (13, OrderedDict({"PassA": (1, 2), "PassB": (0, 1), "PassC": (0, "a"), "PassD": (1, "b")})), + (20, OrderedDict({"PassA": (0, 1), "PassB": (1, 2), "PassC": (1, "b"), "PassD": (1, "b")})), + (32, OrderedDict({"PassA": (0, 1), "PassB": (1, 2), "PassC": (1, "b"), "PassD": (2, "c")})), + (9, OrderedDict({"PassA": (1, 2), "PassB": (1, 2), "PassC": (1, "b"), "PassD": (0, "a")})), + (35, OrderedDict({"PassA": (1, 2), "PassB": (2, 3), "PassC": (1, "b"), "PassD": (2, "c")})), + (10, OrderedDict({"PassA": (0, 1), "PassB": (2, 3), "PassC": (1, "b"), "PassD": 
(0, "a")})), + (4, OrderedDict({"PassA": (0, 1), "PassB": (2, 3), "PassC": (0, "a"), "PassD": (0, "a")})), + (28, OrderedDict({"PassA": (0, 1), "PassB": (2, 3), "PassC": (0, "a"), "PassD": (2, "c")})), + (30, OrderedDict({"PassA": (0, 1), "PassB": (0, 1), "PassC": (1, "b"), "PassD": (2, "c")})), + (27, OrderedDict({"PassA": (1, 2), "PassB": (1, 2), "PassC": (0, "a"), "PassD": (2, "c")})), + (12, OrderedDict({"PassA": (0, 1), "PassB": (0, 1), "PassC": (0, "a"), "PassD": (1, "b")})), + (33, OrderedDict({"PassA": (1, 2), "PassB": (1, 2), "PassC": (1, "b"), "PassD": (2, "c")})), + (24, OrderedDict({"PassA": (0, 1), "PassB": (0, 1), "PassC": (0, "a"), "PassD": (2, "c")})), + (1, OrderedDict({"PassA": (1, 2), "PassB": (0, 1), "PassC": (0, "a"), "PassD": (0, "a")})), + (3, OrderedDict({"PassA": (1, 2), "PassB": (1, 2), "PassC": (0, "a"), "PassD": (0, "a")})), + (31, OrderedDict({"PassA": (1, 2), "PassB": (0, 1), "PassC": (1, "b"), "PassD": (2, "c")})), + (6, OrderedDict({"PassA": (0, 1), "PassB": (0, 1), "PassC": (1, "b"), "PassD": (0, "a")})), + (11, OrderedDict({"PassA": (1, 2), "PassB": (2, 3), "PassC": (1, "b"), "PassD": (0, "a")})), + (15, OrderedDict({"PassA": (1, 2), "PassB": (1, 2), "PassC": (0, "a"), "PassD": (1, "b")})), + (18, OrderedDict({"PassA": (0, 1), "PassB": (0, 1), "PassC": (1, "b"), "PassD": (1, "b")})), + (17, OrderedDict({"PassA": (1, 2), "PassB": (2, 3), "PassC": (0, "a"), "PassD": (1, "b")})), + (2, OrderedDict({"PassA": (0, 1), "PassB": (1, 2), "PassC": (0, "a"), "PassD": (0, "a")})), + (0, OrderedDict({"PassA": (0, 1), "PassB": (0, 1), "PassC": (0, "a"), "PassD": (0, "a")})), + (14, OrderedDict({"PassA": (0, 1), "PassB": (1, 2), "PassC": (0, "a"), "PassD": (1, "b")})), + (7, OrderedDict({"PassA": (1, 2), "PassB": (0, 1), "PassC": (1, "b"), "PassD": (0, "a")})), + (29, OrderedDict({"PassA": (1, 2), "PassB": (2, 3), "PassC": (0, "a"), "PassD": (2, "c")})), + (8, OrderedDict({"PassA": (0, 1), "PassB": (1, 2), "PassC": (1, "b"), "PassD": (0, "a")})), 
+ (26, OrderedDict({"PassA": (0, 1), "PassB": (1, 2), "PassC": (0, "a"), "PassD": (2, "c")})), + (23, OrderedDict({"PassA": (1, 2), "PassB": (2, 3), "PassC": (1, "b"), "PassD": (1, "b")})), + (19, OrderedDict({"PassA": (1, 2), "PassB": (0, 1), "PassC": (1, "b"), "PassD": (1, "b")})), + ] + + assert count == 36 + assert actual == expected + assert len({spi for spi, _ in expected}) == len(expected) + + def test_suggest(self): + search_space = SearchSpace( + [ + ("conversion", Categorical([SearchSpace([])])), + ("transformers_optimization", Categorical([SearchSpace([])])), + ( + "quantization", + Categorical( + [ + SearchSpace( + [ + ("quant_mode", Categorical(["dynamic", "static"])), + ("weight_type", Categorical(["QInt8", "QUInt8"])), + ( + "quant_format", + Conditional( + parents=("quant_mode",), + support={("static",): Categorical(["QOperator", "QDQ"])}, + default=Conditional.get_ignored_choice(), + ), + ), + ( + "activation_type", + Conditional( + parents=("quant_mode", "quant_format", "weight_type"), + support={ + ("static", "QDQ", "QInt8"): Categorical(["QInt8"]), + ("static", "QDQ", "QUInt8"): Categorical(["QUInt8"]), + ("static", "QOperator", "QUInt8"): Categorical(["QUInt8"]), + ("static", "QOperator", "QInt8"): Conditional.get_invalid_choice(), + }, + default=Conditional.get_ignored_choice(), + ), + ), + ( + "prepare_qnn_config", + Conditional( + parents=("quant_mode",), + support={ + ("static",): Categorical([False]), + ("dynamic",): Conditional.get_ignored_choice(), + }, + default=Conditional.get_invalid_choice(), + ), + ), + ( + "qnn_extra_options", + Conditional( + parents=("quant_mode",), + support={ + ("static",): Categorical([None]), + ("dynamic",): Conditional.get_ignored_choice(), + }, + default=Conditional.get_invalid_choice(), + ), + ), + ( + "MatMulConstBOnly", + Conditional( + parents=("quant_mode",), + support={ + ("dynamic",): Categorical([True]), + ("static",): Categorical([False]), + }, + default=Conditional.get_invalid_choice(), + ), + ), 
+ ] + ) + ] + ), + ), + ( + "session_params_tuning", + Categorical( + [ + SearchSpace( + [("providers_list", Categorical(["OpenVINOExecutionProvider", "CPUExecutionProvider"]))] + ) + ] + ), + ), + ] + ) + + config = {"seed": 101, "max_samples": 500} + sampler = TPESampler(search_space, config=config) + + while not sampler.should_stop: + sp = sampler.suggest() + assert sp == search_space[sp.index] diff --git a/test/unit_test/search/test_search_results.py b/test/unit_test/search/test_search_results.py new file mode 100644 index 000000000..c6f2a14ca --- /dev/null +++ b/test/unit_test/search/test_search_results.py @@ -0,0 +1,106 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +from olive.evaluator.metric_result import MetricResult +from olive.search.search_results import SearchResults + +# ruff: noqa: PD011 +# pylint: disable=W0212 + + +class TestSearchResults: + def test_empty(self): + results = SearchResults() + results.sort() + + assert results._sorted_indicies == [] + assert results.get_next_best_result(-1) == (None, None, None) + + def test_sort(self): + objectives2 = { + "accuracy-accuracy_custom": {"goal": 0.7220000147819519, "higher_is_better": True, "priority": 1}, + "latency-avg": {"goal": 24.044427, "higher_is_better": False, "priority": 2}, + } + objectives3 = { + **objectives2, + "latency-max": {"goal": 30, "higher_is_better": False, "priority": 3}, + } + + signal1 = MetricResult.parse_obj( + { + "accuracy-accuracy_custom": { + "value": 0.75, + "priority": 1, + "higher_is_better": True, + }, + "latency-avg": {"value": 4.5, "priority": 2, "higher_is_better": False}, + } + ) + signal2 = MetricResult.parse_obj( + { + "accuracy-accuracy_custom": { + "value": 0.78, + "priority": 1, + "higher_is_better": True, + }, + "latency-avg": {"value": 55.9, 
"priority": 2, "higher_is_better": False}, + } + ) + signal3 = MetricResult.parse_obj( + { + "accuracy-accuracy_custom": { + "value": 0.76, + "priority": 1, + "higher_is_better": True, + }, + "latency-avg": {"value": 53.0, "priority": 2, "higher_is_better": False}, + "latency-max": {"value": 60.1, "priority": 3, "higher_is_better": False}, + } + ) + + signals = [ + ( + objectives2, + signal1, + ["model_id_1"], + ), + None, + None, + ( + objectives2, + signal2, + ["model_id_2"], + ), + None, + None, + None, + ( + objectives3, + signal3, + ["model_id_3"], + ), + ] + + results = SearchResults() + for i, signal in enumerate(signals): + if signal: + results.record_feedback_signal(i, *signal) + + results.sort() + + assert results._sorted_indicies == [3, 7, 0] + + next_best_spi = -1 + actual_order = [] + while next_best_spi is not None: + next_best_spi, spi, model_ids = results.get_next_best_result(next_best_spi) + if next_best_spi is not None: + actual_order.append((next_best_spi, spi, model_ids)) + + assert actual_order == [ + (0, 3, ["model_id_2"]), + (1, 7, ["model_id_3"]), + (2, 0, ["model_id_1"]), + ] diff --git a/test/unit_test/search/test_search_space.py b/test/unit_test/search/test_search_space.py new file mode 100644 index 000000000..9035b4cc4 --- /dev/null +++ b/test/unit_test/search/test_search_space.py @@ -0,0 +1,1963 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +from collections import OrderedDict + +from olive.search.search_parameter import Categorical, Conditional, SpecialParamValue +from olive.search.search_point import SearchPoint +from olive.search.search_space import SearchSpace + +# ruff: noqa: PD011 + + +class TestSearchSpace: + def test_length_categoricals(self): + search_space = SearchSpace( + [ + ("Pass/0/SP1", Categorical([1, 2])), + ("Pass/0/SP2", Categorical([1, 2, 3])), + ("Pass/0/SP3", Categorical([1, 2, 3, 4])), + ("Pass/0/SP4", Categorical([1, 2, 3, 4, 5])), + ("Pass/0/SP5", Categorical(["a", "b"])), + ("Pass/0/SP6", Categorical(["a", "b", "c"])), + ("Pass/0/SP7", Categorical(["a", "b", "c", "d"])), + ("Pass/0/SP8", Categorical(["a", "b", "c", "d", "e"])), + ] + ) + assert len(search_space) == 14400 + + def test_length_with_conditionals(self): + search_space = SearchSpace( + [ + ("Pass/0/SP1", Categorical([1, 2, 3])), + ("Pass/0/SP2", Categorical(["x", "y", "z"])), + ( + "Pass/0/SP3", + Conditional( + parents=("Pass/0/SP1",), + support={(1,): Categorical(["a"]), (3,): Categorical(["a", "b", "c"])}, + default=Conditional.get_ignored_choice(), + ), + ), + ("Pass/0/SP4", Categorical(["4a", "4b", "4c"])), + ( + "Pass/0/SP5", + Conditional( + parents=("Pass/0/SP1", "Pass/0/SP2"), + support={ + (1, "x"): Categorical(["1x", "x1"]), + (2, "y"): Categorical(["2y", "y2"]), + (3, "z"): Categorical(["3z", "z3"]), + }, + default=Conditional.get_invalid_choice(), + ), + ), + ("Pass/0/SP6", Categorical(["5x", "5y", "5z"])), + ] + ) + assert len(search_space) == 486 + + def test_length_with_search_spaces(self): + search_space = SearchSpace( + [ + ( + "Pass/0", + SearchSpace( + [ + ("Pass/0/SP1", Categorical([1, 2])), + ("Pass/0/SP2", Categorical([1, 2, 3])), + ] + ), + ), + ( + "Pass/1", + SearchSpace( + [ + ("Pass/1/SP1", Categorical(["a", "b"])), + ("Pass/1/SP2", Categorical(["a", "b", "c"])), + ] + ), + ), + ] + ) + assert len(search_space) 
== 36 + + def test_length(self): + search_space = SearchSpace( + [ + ( + "PassA", + Categorical( + [SearchSpace([("PA0_SP1", Categorical([1, 2, 3])), ("PA0_SP2", Categorical([4, 5, 6]))])] + ), + ), + ( + "PassB", + Categorical( + [ + SearchSpace( + [("PB0_SP1", Categorical(["a", "b", "c"])), ("PB0_SP2", Categorical(["x", "y", "z"]))] + ), + SearchSpace( + [("PB0_SP1", Categorical(["x", "y", "z"])), ("PB0_SP2", Categorical(["a", "b", "c"]))] + ), + ] + ), + ), + ("PassC", SearchSpace([("PC0_SP1", Categorical([3, 2, 1])), ("PC0_SP2", Categorical([9, 8, 7]))])), + ( + "PassD", + SearchSpace( + [ + ("PD0_SP1", Categorical([1, 2, 3])), + ("PD0_SP2", Categorical(["x", "y", "z"])), + ( + "PD0_SP3", + Conditional( + parents=("PD0_SP1",), + support={(1,): Categorical(["a"]), (3,): Categorical(["a", "b", "c"])}, + default=Conditional.get_ignored_choice(), + ), + ), + ( + "PD0_SP4", + Conditional( + parents=("PD0_SP1", "PD0_SP2"), + support={ + (1, "x"): Categorical(["1x", "x1"]), + (2, "y"): Categorical(["2y", "y2"]), + (3, "z"): Categorical(["3y", "y3"]), + }, + default=Conditional.get_invalid_choice(), + ), + ), + ( + "PD0_SP5", + Conditional( + parents=("PD0_SP3", "PD0_SP4"), + support={ + ("a", "1x"): SearchSpace( + [ + ("PD0_SP5_SP1", Categorical(["a1x", "x1a"])), + ("PD0_SP5_SP2", Categorical(["x1a", "a1x"])), + ] + ), + ("c", "3y"): SearchSpace( + [ + ("PD0_SP5_SP3", Categorical(["c3y", "y3c"])), + ("PD0_SP5_SP4", Categorical(["y3c", "3cy"])), + ] + ), + }, + default=Conditional.get_invalid_choice(), + ), + ), + ] + ), + ), + ] + ) + # PassA:length = 9 + # PassB:length = 18 + # PassB[0]:length = 9 + # PassB[1]:length = 9 + # PassC:length = 9 + # PassD:length = 216 + # PassD[0].PD0_SP1:length = 3 + # PassD[0].PD0_SP2:length = 3 + # PassD[0].PD0_SP3:length = 3 + # PassD[0].PD0_SP4:length = 2 + # PassD[0].PD0_SP5:length = 4 + assert len(search_space) == 314928 # 9 * 18 * 9 * 216 + + def test_empty(self): + search_space = SearchSpace([]) + assert len(search_space) == 1 + 
+ actual = [(search_point.index, search_point.values) for search_point in search_space] + expected = [(0, OrderedDict())] + assert actual == expected + + search_space = SearchSpace( + [ + ("PassA", Categorical([SearchSpace([]), SearchSpace([]), SearchSpace([])])), + ("PassB", Categorical([SearchSpace([]), SearchSpace([])])), + ("PassC", Categorical([SearchSpace([])])), + ] + ) + assert len(search_space) == 6 + + actual = [(search_point.index, search_point.values) for search_point in search_space] + expected = [ + ( + 0, + OrderedDict( + [("PassA", (0, OrderedDict())), ("PassB", (0, OrderedDict())), ("PassC", (0, OrderedDict()))] + ), + ), + ( + 1, + OrderedDict( + [("PassA", (1, OrderedDict())), ("PassB", (0, OrderedDict())), ("PassC", (0, OrderedDict()))] + ), + ), + ( + 2, + OrderedDict( + [("PassA", (2, OrderedDict())), ("PassB", (0, OrderedDict())), ("PassC", (0, OrderedDict()))] + ), + ), + ( + 3, + OrderedDict( + [("PassA", (0, OrderedDict())), ("PassB", (1, OrderedDict())), ("PassC", (0, OrderedDict()))] + ), + ), + ( + 4, + OrderedDict( + [("PassA", (1, OrderedDict())), ("PassB", (1, OrderedDict())), ("PassC", (0, OrderedDict()))] + ), + ), + ( + 5, + OrderedDict( + [("PassA", (2, OrderedDict())), ("PassB", (1, OrderedDict())), ("PassC", (0, OrderedDict()))] + ), + ), + ] + assert actual == expected + + def test_iteration_with_categoricals(self): + search_space = SearchSpace( + [ + ("Pass/0/SP1", Categorical([1, 2, 3])), + ("Pass/0/SP2", Categorical(["a", "b"])), + ("Pass/0/SP3", Categorical([10, 20, 30])), + ("Pass/0/SP4", Categorical(["x", "y", "z"])), + ] + ) + + actual = [(search_point.index, search_point.values) for search_point in search_space] + expected = [ + ( + 0, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 1, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (0, "x")), + ] + ), 
+ ), + ( + 2, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 3, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 4, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 5, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 6, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 7, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 8, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 9, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 10, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 11, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 12, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 13, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 14, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 15, + OrderedDict( + 
[ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 16, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 17, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 18, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 19, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 20, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 21, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 22, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 23, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 24, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 25, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 26, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 27, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 28, + OrderedDict( + [ + ("Pass/0/SP1", (1, 
2)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 29, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 30, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 31, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 32, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 33, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 34, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 35, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 36, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 37, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 38, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 39, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 40, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 41, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", 
(1, "b")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 42, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 43, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 44, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 45, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 46, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 47, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 48, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 49, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 50, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 51, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 52, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 53, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ] + assert actual == expected + + def test_iteration_with_conditionals(self): + search_space 
= SearchSpace( + [ + ("Pass/0/SP1", Categorical([1, 2, 3])), + ( + "Pass/0/SP2", + Conditional( + parents=("Pass/0/SP1",), + support={(1,): Categorical(["a"]), (3,): Categorical(["a", "b", "c"])}, + default=Conditional.get_ignored_choice(), + ), + ), + ("Pass/0/SP3", Categorical(["x", "y", "z"])), + ( + "Pass/0/SP4", + Conditional( + parents=("Pass/0/SP1", "Pass/0/SP3"), + support={ + (1, "x"): Categorical(["1x", "x1"]), + (2, "y"): Categorical(["2y", "y2"]), + (3, "z"): Categorical(["3z", "z3"]), + }, + default=Conditional.get_invalid_choice(), + ), + ), + ] + ) + + actual = [(search_point.index, search_point.values) for search_point in search_space] + expected = [ + ( + 0, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (0, "1x")), + ] + ), + ), + ( + 1, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (0, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 2, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 3, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (1, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (0, "1x")), + ] + ), + ), + ( + 4, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (1, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 5, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 6, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (2, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (0, "1x")), + ] + ), + ), + ( + 7, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (2, 
SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 8, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (2, "c")), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 9, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 10, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (0, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (0, "2y")), + ] + ), + ), + ( + 11, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 12, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (1, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 13, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (1, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (0, "2y")), + ] + ), + ), + ( + 14, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 15, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (2, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 16, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (2, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (0, "2y")), + ] + ), + ), + ( + 17, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (2, "c")), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 18, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (0, "a")), 
+ ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 19, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (0, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 20, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (0, "3z")), + ] + ), + ), + ( + 21, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (1, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 22, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (1, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 23, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (0, "3z")), + ] + ), + ), + ( + 24, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (2, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 25, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (2, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 26, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (2, "c")), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (0, "3z")), + ] + ), + ), + ( + 27, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (1, "x1")), + ] + ), + ), + ( + 28, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (0, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 29, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (0, 
"x")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 30, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (1, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (1, "x1")), + ] + ), + ), + ( + 31, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (1, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 32, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 33, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (2, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (1, "x1")), + ] + ), + ), + ( + 34, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (2, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 35, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (2, "c")), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 36, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 37, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (0, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (1, "y2")), + ] + ), + ), + ( + 38, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 39, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (1, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 40, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (1, SpecialParamValue.IGNORED)), + 
("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (1, "y2")), + ] + ), + ), + ( + 41, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 42, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (2, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 43, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (2, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (1, "y2")), + ] + ), + ), + ( + 44, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (2, "c")), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 45, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 46, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (0, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 47, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (1, "z3")), + ] + ), + ), + ( + 48, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (1, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 49, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (1, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 50, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (1, "z3")), + ] + ), + ), + ( + 51, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (2, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (2, "z")), 
+ ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 52, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (2, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 53, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (2, "c")), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (1, "z3")), + ] + ), + ), + ] + assert actual == expected + + def test_iteration_with_search_spaces(self): + search_space = SearchSpace( + [ + ( + "Pass/0", + SearchSpace( + [ + ("Pass/0/SP1", Categorical([1, 2])), + ("Pass/0/SP2", Categorical([1, 2, 3])), + ] + ), + ), + ( + "Pass/1", + SearchSpace( + [ + ("Pass/1/SP1", Categorical(["a", "b"])), + ("Pass/1/SP2", Categorical(["a", "b", "c"])), + ] + ), + ), + ] + ) + assert len(search_space) == 36 + + actual = [(search_point.index, search_point.values) for search_point in search_space] + expected = [ + ( + 0, + OrderedDict( + [ + ("Pass/0", (0, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (0, 1))]))), + ("Pass/1", (0, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (0, "a"))]))), + ] + ), + ), + ( + 1, + OrderedDict( + [ + ("Pass/0", (1, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (0, 1))]))), + ("Pass/1", (0, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (0, "a"))]))), + ] + ), + ), + ( + 2, + OrderedDict( + [ + ("Pass/0", (2, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (1, 2))]))), + ("Pass/1", (0, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (0, "a"))]))), + ] + ), + ), + ( + 3, + OrderedDict( + [ + ("Pass/0", (3, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (1, 2))]))), + ("Pass/1", (0, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (0, "a"))]))), + ] + ), + ), + ( + 4, + OrderedDict( + [ + ("Pass/0", (4, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (2, 3))]))), + ("Pass/1", (0, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (0, "a"))]))), + ] + ), + ), + ( 
+ 5, + OrderedDict( + [ + ("Pass/0", (5, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (2, 3))]))), + ("Pass/1", (0, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (0, "a"))]))), + ] + ), + ), + ( + 6, + OrderedDict( + [ + ("Pass/0", (0, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (0, 1))]))), + ("Pass/1", (1, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (0, "a"))]))), + ] + ), + ), + ( + 7, + OrderedDict( + [ + ("Pass/0", (1, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (0, 1))]))), + ("Pass/1", (1, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (0, "a"))]))), + ] + ), + ), + ( + 8, + OrderedDict( + [ + ("Pass/0", (2, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (1, 2))]))), + ("Pass/1", (1, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (0, "a"))]))), + ] + ), + ), + ( + 9, + OrderedDict( + [ + ("Pass/0", (3, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (1, 2))]))), + ("Pass/1", (1, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (0, "a"))]))), + ] + ), + ), + ( + 10, + OrderedDict( + [ + ("Pass/0", (4, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (2, 3))]))), + ("Pass/1", (1, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (0, "a"))]))), + ] + ), + ), + ( + 11, + OrderedDict( + [ + ("Pass/0", (5, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (2, 3))]))), + ("Pass/1", (1, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (0, "a"))]))), + ] + ), + ), + ( + 12, + OrderedDict( + [ + ("Pass/0", (0, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (0, 1))]))), + ("Pass/1", (2, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (1, "b"))]))), + ] + ), + ), + ( + 13, + OrderedDict( + [ + ("Pass/0", (1, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (0, 1))]))), + ("Pass/1", (2, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (1, "b"))]))), + ] + ), + ), + ( + 14, + OrderedDict( + [ + ("Pass/0", (2, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (1, 
2))]))), + ("Pass/1", (2, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (1, "b"))]))), + ] + ), + ), + ( + 15, + OrderedDict( + [ + ("Pass/0", (3, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (1, 2))]))), + ("Pass/1", (2, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (1, "b"))]))), + ] + ), + ), + ( + 16, + OrderedDict( + [ + ("Pass/0", (4, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (2, 3))]))), + ("Pass/1", (2, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (1, "b"))]))), + ] + ), + ), + ( + 17, + OrderedDict( + [ + ("Pass/0", (5, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (2, 3))]))), + ("Pass/1", (2, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (1, "b"))]))), + ] + ), + ), + ( + 18, + OrderedDict( + [ + ("Pass/0", (0, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (0, 1))]))), + ("Pass/1", (3, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (1, "b"))]))), + ] + ), + ), + ( + 19, + OrderedDict( + [ + ("Pass/0", (1, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (0, 1))]))), + ("Pass/1", (3, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (1, "b"))]))), + ] + ), + ), + ( + 20, + OrderedDict( + [ + ("Pass/0", (2, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (1, 2))]))), + ("Pass/1", (3, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (1, "b"))]))), + ] + ), + ), + ( + 21, + OrderedDict( + [ + ("Pass/0", (3, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (1, 2))]))), + ("Pass/1", (3, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (1, "b"))]))), + ] + ), + ), + ( + 22, + OrderedDict( + [ + ("Pass/0", (4, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (2, 3))]))), + ("Pass/1", (3, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (1, "b"))]))), + ] + ), + ), + ( + 23, + OrderedDict( + [ + ("Pass/0", (5, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (2, 3))]))), + ("Pass/1", (3, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (1, "b"))]))), + ] + 
), + ), + ( + 24, + OrderedDict( + [ + ("Pass/0", (0, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (0, 1))]))), + ("Pass/1", (4, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (2, "c"))]))), + ] + ), + ), + ( + 25, + OrderedDict( + [ + ("Pass/0", (1, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (0, 1))]))), + ("Pass/1", (4, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (2, "c"))]))), + ] + ), + ), + ( + 26, + OrderedDict( + [ + ("Pass/0", (2, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (1, 2))]))), + ("Pass/1", (4, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (2, "c"))]))), + ] + ), + ), + ( + 27, + OrderedDict( + [ + ("Pass/0", (3, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (1, 2))]))), + ("Pass/1", (4, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (2, "c"))]))), + ] + ), + ), + ( + 28, + OrderedDict( + [ + ("Pass/0", (4, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (2, 3))]))), + ("Pass/1", (4, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (2, "c"))]))), + ] + ), + ), + ( + 29, + OrderedDict( + [ + ("Pass/0", (5, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (2, 3))]))), + ("Pass/1", (4, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (2, "c"))]))), + ] + ), + ), + ( + 30, + OrderedDict( + [ + ("Pass/0", (0, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (0, 1))]))), + ("Pass/1", (5, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (2, "c"))]))), + ] + ), + ), + ( + 31, + OrderedDict( + [ + ("Pass/0", (1, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (0, 1))]))), + ("Pass/1", (5, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (2, "c"))]))), + ] + ), + ), + ( + 32, + OrderedDict( + [ + ("Pass/0", (2, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (1, 2))]))), + ("Pass/1", (5, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (2, "c"))]))), + ] + ), + ), + ( + 33, + OrderedDict( + [ + ("Pass/0", (3, OrderedDict([("Pass/0/SP1", (1, 2)), 
("Pass/0/SP2", (1, 2))]))), + ("Pass/1", (5, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (2, "c"))]))), + ] + ), + ), + ( + 34, + OrderedDict( + [ + ("Pass/0", (4, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (2, 3))]))), + ("Pass/1", (5, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (2, "c"))]))), + ] + ), + ), + ( + 35, + OrderedDict( + [ + ("Pass/0", (5, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (2, 3))]))), + ("Pass/1", (5, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (2, "c"))]))), + ] + ), + ), + ] + assert actual == expected + + def test_get_item_with_categoricals(self): + search_space = SearchSpace( + [ + ("Pass/0/SP1", Categorical([1, 2, 3])), + ("Pass/0/SP2", Categorical(["a", "b"])), + ("Pass/0/SP3", Categorical([10, 20, 30])), + ("Pass/0/SP4", Categorical(["x", "y", "z"])), + ] + ) + + actual_5 = search_space[5] + expected_5 = SearchPoint( + 5, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ) + assert actual_5 == expected_5 + + actual_45 = search_space[45] + expected_45 = SearchPoint( + 45, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ) + assert actual_45 == expected_45 + + def test_get_item_with_conditionals(self): + search_space = SearchSpace( + [ + ("Pass/0/SP1", Categorical([1, 2, 3])), + ( + "Pass/0/SP2", + Conditional( + parents=("Pass/0/SP1",), + support={(1,): Categorical(["a"]), (3,): Categorical(["a", "b", "c"])}, + default=Conditional.get_ignored_choice(), + ), + ), + ("Pass/0/SP3", Categorical(["x", "y", "z"])), + ( + "Pass/0/SP4", + Conditional( + parents=("Pass/0/SP1", "Pass/0/SP3"), + support={ + (1, "x"): Categorical(["1x", "x1"]), + (2, "y"): Categorical(["2y", "y2"]), + (3, "z"): Categorical(["3z", "z3"]), + }, + default=Conditional.get_invalid_choice(), + ), + ), + ] + ) + + actual_5 = search_space[5] + 
expected_5 = SearchPoint( + 5, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ) + assert actual_5 == expected_5 + + actual_17 = search_space[17] + expected_17 = SearchPoint( + 17, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (2, "c")), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ) + assert actual_17 == expected_17 + + def test_get_item_with_search_spaces(self): + search_space = SearchSpace( + [ + ( + "Pass/0", + SearchSpace( + [ + ("Pass/0/SP1", Categorical([1, 2])), + ("Pass/0/SP2", Categorical([1, 2, 3])), + ] + ), + ), + ( + "Pass/1", + SearchSpace( + [ + ("Pass/1/SP1", Categorical(["a", "b"])), + ("Pass/1/SP2", Categorical(["a", "b", "c"])), + ] + ), + ), + ] + ) + assert len(search_space) == 36 + + actual_5 = search_space[5] + expected_5 = SearchPoint( + 5, + OrderedDict( + [ + ("Pass/0", (5, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (2, 3))]))), + ("Pass/1", (0, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (0, "a"))]))), + ] + ), + ) + assert actual_5 == expected_5 + + actual_27 = search_space[27] + expected_27 = SearchPoint( + 27, + OrderedDict( + [ + ("Pass/0", (3, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (1, 2))]))), + ("Pass/1", (4, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (2, "c"))]))), + ] + ), + ) + assert actual_27 == expected_27 diff --git a/test/unit_test/search/test_search_strategy.py b/test/unit_test/search/test_search_strategy.py new file mode 100644 index 000000000..5bc4f43cb --- /dev/null +++ b/test/unit_test/search/test_search_strategy.py @@ -0,0 +1,849 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +import re +from collections import OrderedDict + +import pytest + +from olive.evaluator.metric_result import MetricResult +from olive.search.search_parameter import Categorical +from olive.search.search_strategy import SearchStrategy, SearchStrategyConfig, SearchStrategyExecutionOrder + +# pylint: disable=protected-access +# ruff: noqa: PD011 + + +class TestSearchStrategy: + @pytest.mark.parametrize( + "execution_order", [SearchStrategyExecutionOrder.JOINT, SearchStrategyExecutionOrder.PASS_BY_PASS] + ) + def test_initialize(self, execution_order, tmpdir): + config = SearchStrategyConfig(execution_order=execution_order, sampler="random") + space_config = OrderedDict( + [ + ( + "PassA", + [ + {"PA0_SP1": Categorical(["A01a", "A01b"]), "PA0_SP2": Categorical(["A02a", "A02b"])}, + {"PA1_SP1": Categorical(["A11a", "A11b"])}, + {}, + {}, + ], + ), + ("PassB", [{"PB0_SP1": Categorical(["B01a", "B01b"])}, {}]), + ("PassC", [{}, {}]), + ] + ) + + strategy = SearchStrategy(config) + strategy.initialize(space_config, "whatever", {}) + + actual = re.sub(r"[\n\t\s]*", "", str(strategy._search_spaces)) + + if execution_order == SearchStrategyExecutionOrder.JOINT: + assert len(strategy._search_spaces) == 1 + + expected = re.sub( + r"[\n\t\s]*", + "", + """[ + SearchSpace([ + ('PassA', Categorical([ + SearchSpace([ + ('PA0_SP1', Categorical(['A01a', 'A01b'])), + ('PA0_SP2', Categorical(['A02a', 'A02b'])) + ], 4), + SearchSpace([('PA1_SP1', Categorical(['A11a', 'A11b']))], 2), + SearchSpace([], 1), + SearchSpace([], 1) + ]) + ), + ('PassB', Categorical([ + SearchSpace([('PB0_SP1', Categorical(['B01a', 'B01b']))], 2), + SearchSpace([], 1) + ]) + ), + ('PassC', Categorical([SearchSpace([], 1), SearchSpace([], 1)])) + ], 48) + ]""", + ) + + elif execution_order == SearchStrategyExecutionOrder.PASS_BY_PASS: + assert len(strategy._search_spaces) == 3 + + expected = re.sub( + r"[\n\t\s]*", + "", + """[ + 
SearchSpace([ + ('PassA', Categorical([ + SearchSpace([ + ('PA0_SP1', Categorical(['A01a', 'A01b'])), + ('PA0_SP2', Categorical(['A02a', 'A02b'])) + ], 4), + SearchSpace([('PA1_SP1', Categorical(['A11a', 'A11b']))], 2), + SearchSpace([], 1), + SearchSpace([], 1) + ]) + )], 8), + SearchSpace([ + ('PassB', Categorical([ + SearchSpace([('PB0_SP1', Categorical(['B01a', 'B01b']))], 2), + SearchSpace([], 1) + ]) + )], 3), + SearchSpace([('PassC', Categorical([SearchSpace([], 1), SearchSpace([], 1)]))], 2) + ]""", + ) + else: + expected = None + pytest.fail("Unsupported execution_order") + + assert actual == expected + + def test_iteration_joint(self, tmpdir): + config = SearchStrategyConfig(execution_order="joint", sampler="sequential") + space_config = OrderedDict( + [ + ( + "PassA", + [ + {"PA0_SP1": Categorical(["A01a", "A01b"]), "PA0_SP2": Categorical(["A02a", "A02b"])}, + {"PA1_SP1": Categorical(["A11a", "A11b"])}, + {}, + {}, + ], + ), + ("PassB", [{"PB0_SP1": Categorical(["B01a", "B01b"])}, {}]), + ("PassC", [{}, {}]), + ] + ) + + strategy = SearchStrategy(config) + strategy.initialize(space_config, "whatever", {}) + + actual = [(sample.search_point.index, sample.passes_configs) for sample in strategy] + expected = [ + ( + 0, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01a"), ("PA0_SP2", "A02a")]))] + ), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01a")]))]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 1, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01b"), ("PA0_SP2", "A02a")]))] + ), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01a")]))]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 2, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01a"), ("PA0_SP2", "A02b")]))] + ), + "PassB": OrderedDict([("index", 0), 
("params", OrderedDict([("PB0_SP1", "B01a")]))]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 3, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01b"), ("PA0_SP2", "A02b")]))] + ), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01a")]))]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 4, + { + "PassA": OrderedDict([("index", 1), ("params", OrderedDict([("PA1_SP1", "A11a")]))]), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01a")]))]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 5, + { + "PassA": OrderedDict([("index", 1), ("params", OrderedDict([("PA1_SP1", "A11b")]))]), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01a")]))]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 6, + { + "PassA": OrderedDict([("index", 2), ("params", OrderedDict())]), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01a")]))]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 7, + { + "PassA": OrderedDict([("index", 3), ("params", OrderedDict())]), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01a")]))]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 8, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01a"), ("PA0_SP2", "A02a")]))] + ), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01b")]))]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 9, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01b"), ("PA0_SP2", "A02a")]))] + ), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01b")]))]), + "PassC": OrderedDict([("index", 0), ("params", 
OrderedDict())]), + }, + ), + ( + 10, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01a"), ("PA0_SP2", "A02b")]))] + ), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01b")]))]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 11, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01b"), ("PA0_SP2", "A02b")]))] + ), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01b")]))]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 12, + { + "PassA": OrderedDict([("index", 1), ("params", OrderedDict([("PA1_SP1", "A11a")]))]), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01b")]))]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 13, + { + "PassA": OrderedDict([("index", 1), ("params", OrderedDict([("PA1_SP1", "A11b")]))]), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01b")]))]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 14, + { + "PassA": OrderedDict([("index", 2), ("params", OrderedDict())]), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01b")]))]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 15, + { + "PassA": OrderedDict([("index", 3), ("params", OrderedDict())]), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01b")]))]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 16, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01a"), ("PA0_SP2", "A02a")]))] + ), + "PassB": OrderedDict([("index", 1), ("params", OrderedDict())]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 17, + { + "PassA": OrderedDict( + [("index", 0), ("params", 
OrderedDict([("PA0_SP1", "A01b"), ("PA0_SP2", "A02a")]))] + ), + "PassB": OrderedDict([("index", 1), ("params", OrderedDict())]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 18, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01a"), ("PA0_SP2", "A02b")]))] + ), + "PassB": OrderedDict([("index", 1), ("params", OrderedDict())]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 19, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01b"), ("PA0_SP2", "A02b")]))] + ), + "PassB": OrderedDict([("index", 1), ("params", OrderedDict())]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 20, + { + "PassA": OrderedDict([("index", 1), ("params", OrderedDict([("PA1_SP1", "A11a")]))]), + "PassB": OrderedDict([("index", 1), ("params", OrderedDict())]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 21, + { + "PassA": OrderedDict([("index", 1), ("params", OrderedDict([("PA1_SP1", "A11b")]))]), + "PassB": OrderedDict([("index", 1), ("params", OrderedDict())]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 22, + { + "PassA": OrderedDict([("index", 2), ("params", OrderedDict())]), + "PassB": OrderedDict([("index", 1), ("params", OrderedDict())]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 23, + { + "PassA": OrderedDict([("index", 3), ("params", OrderedDict())]), + "PassB": OrderedDict([("index", 1), ("params", OrderedDict())]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 24, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01a"), ("PA0_SP2", "A02a")]))] + ), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01a")]))]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 25, 
+ { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01b"), ("PA0_SP2", "A02a")]))] + ), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01a")]))]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 26, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01a"), ("PA0_SP2", "A02b")]))] + ), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01a")]))]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 27, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01b"), ("PA0_SP2", "A02b")]))] + ), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01a")]))]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 28, + { + "PassA": OrderedDict([("index", 1), ("params", OrderedDict([("PA1_SP1", "A11a")]))]), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01a")]))]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 29, + { + "PassA": OrderedDict([("index", 1), ("params", OrderedDict([("PA1_SP1", "A11b")]))]), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01a")]))]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 30, + { + "PassA": OrderedDict([("index", 2), ("params", OrderedDict())]), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01a")]))]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 31, + { + "PassA": OrderedDict([("index", 3), ("params", OrderedDict())]), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01a")]))]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 32, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01a"), 
("PA0_SP2", "A02a")]))] + ), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01b")]))]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 33, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01b"), ("PA0_SP2", "A02a")]))] + ), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01b")]))]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 34, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01a"), ("PA0_SP2", "A02b")]))] + ), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01b")]))]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 35, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01b"), ("PA0_SP2", "A02b")]))] + ), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01b")]))]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 36, + { + "PassA": OrderedDict([("index", 1), ("params", OrderedDict([("PA1_SP1", "A11a")]))]), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01b")]))]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 37, + { + "PassA": OrderedDict([("index", 1), ("params", OrderedDict([("PA1_SP1", "A11b")]))]), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01b")]))]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 38, + { + "PassA": OrderedDict([("index", 2), ("params", OrderedDict())]), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01b")]))]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 39, + { + "PassA": OrderedDict([("index", 3), ("params", OrderedDict())]), + "PassB": OrderedDict([("index", 0), ("params", 
OrderedDict([("PB0_SP1", "B01b")]))]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 40, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01a"), ("PA0_SP2", "A02a")]))] + ), + "PassB": OrderedDict([("index", 1), ("params", OrderedDict())]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 41, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01b"), ("PA0_SP2", "A02a")]))] + ), + "PassB": OrderedDict([("index", 1), ("params", OrderedDict())]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 42, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01a"), ("PA0_SP2", "A02b")]))] + ), + "PassB": OrderedDict([("index", 1), ("params", OrderedDict())]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 43, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01b"), ("PA0_SP2", "A02b")]))] + ), + "PassB": OrderedDict([("index", 1), ("params", OrderedDict())]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 44, + { + "PassA": OrderedDict([("index", 1), ("params", OrderedDict([("PA1_SP1", "A11a")]))]), + "PassB": OrderedDict([("index", 1), ("params", OrderedDict())]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 45, + { + "PassA": OrderedDict([("index", 1), ("params", OrderedDict([("PA1_SP1", "A11b")]))]), + "PassB": OrderedDict([("index", 1), ("params", OrderedDict())]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 46, + { + "PassA": OrderedDict([("index", 2), ("params", OrderedDict())]), + "PassB": OrderedDict([("index", 1), ("params", OrderedDict())]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 47, + { + "PassA": OrderedDict([("index", 3), ("params", 
OrderedDict())]), + "PassB": OrderedDict([("index", 1), ("params", OrderedDict())]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ] + + assert actual == expected + + def test_iteration_pass_by_pass(self, tmpdir): + config = SearchStrategyConfig(execution_order="pass-by-pass", sampler="sequential") + space_config = OrderedDict( + [ + ( + "PassA", + [ + {"PA0_SP1": Categorical(["A01a", "A01b"]), "PA0_SP2": Categorical(["A02a", "A02b"])}, + {"PA1_SP1": Categorical(["A11a", "A11b"])}, + {}, + {}, + ], + ), + ("PassB", [{"PB0_SP1": Categorical(["B01a", "B01b"])}, {}]), + ("PassC", [{}, {}]), + ] + ) + objectives = OrderedDict( + [ + ( + "PassA", + [ + { + "accuracy-accuracy_custom": { + "goal": 0.70, + "higher_is_better": True, + "priority": 1, + }, + "latency-avg": {"goal": 24.0, "higher_is_better": False, "priority": 2}, + }, + { + "accuracy-accuracy_custom": { + "goal": 0.85, + "higher_is_better": True, + "priority": 1, + }, + "latency-avg": {"goal": 23.0, "higher_is_better": False, "priority": 2}, + }, + { + "accuracy-accuracy_custom": { + "goal": 0.80, + "higher_is_better": True, + "priority": 1, + }, + "latency-avg": {"goal": 25.0, "higher_is_better": False, "priority": 2}, + }, + { + "accuracy-accuracy_custom": { + "goal": 0.75, + "higher_is_better": True, + "priority": 1, + }, + "latency-avg": {"goal": 20.0, "higher_is_better": False, "priority": 2}, + }, + ], + ), + ( + "PassB", + [ + { + "accuracy-accuracy_custom": { + "goal": 0.70, + "higher_is_better": True, + "priority": 1, + }, + "latency-avg": {"goal": 40.0, "higher_is_better": False, "priority": 2}, + "latency-max": {"goal": 64.0, "higher_is_better": False, "priority": 2}, + }, + { + "accuracy-accuracy_custom": { + "goal": 0.80, + "higher_is_better": True, + "priority": 1, + }, + "latency-avg": {"goal": 45.0, "higher_is_better": False, "priority": 2}, + "latency-max": {"goal": 72.0, "higher_is_better": False, "priority": 2}, + }, + ], + ), + ( + "PassC", + [ + { + 
"accuracy-accuracy_custom": { + "goal": 0.89, + "higher_is_better": True, + "priority": 1, + }, + "latency-avg": {"goal": 16.0, "higher_is_better": False, "priority": 2}, + }, + { + "accuracy-accuracy_custom": { + "goal": 0.92, + "higher_is_better": True, + "priority": 1, + }, + "latency-avg": {"goal": 14.0, "higher_is_better": False, "priority": 2}, + }, + ], + ), + ] + ) + + signal1 = MetricResult.parse_obj( + { + "accuracy-accuracy_custom": { + "value": 0.96, + "priority": 1, + "higher_is_better": True, + }, + "latency-avg": {"value": 4.5, "priority": 2, "higher_is_better": False}, + } + ) + signal2 = MetricResult.parse_obj( + { + "accuracy-accuracy_custom": { + "value": 0.72, + "priority": 1, + "higher_is_better": True, + }, + "latency-avg": {"value": 55.9, "priority": 2, "higher_is_better": False}, + } + ) + signal3 = MetricResult.parse_obj( + { + "accuracy-accuracy_custom": { + "value": 0.73, + "priority": 1, + "higher_is_better": True, + }, + "latency-avg": {"value": 8.9, "priority": 2, "higher_is_better": False}, + } + ) + signal4 = MetricResult.parse_obj( + { + "accuracy-accuracy_custom": { + "value": 0.91, + "priority": 1, + "higher_is_better": True, + }, + "latency-avg": {"value": 13.9, "priority": 2, "higher_is_better": False}, + } + ) + signal5 = MetricResult.parse_obj( + { + "accuracy-accuracy_custom": { + "value": 0.82, + "priority": 1, + "higher_is_better": True, + }, + "latency-avg": {"value": 53.0, "priority": 2, "higher_is_better": False}, + "latency-max": {"value": 60.1, "priority": 2, "higher_is_better": False}, + } + ) + signal6 = MetricResult.parse_obj( + { + "accuracy-accuracy_custom": { + "value": 0.76, + "priority": 1, + "higher_is_better": True, + }, + "latency-avg": {"value": 58.0, "priority": 2, "higher_is_better": False}, + "latency-max": {"value": 55.2, "priority": 3, "higher_is_better": False}, + } + ) + + signals = { + "PassA": [ + None, + ( + signal1, + ["model_id_1"], + False, + ), + None, + None, + ( + signal2, + ["model_id_2"], 
+ False, + ), + ( + signal3, + ["model_id_3"], + True, + ), + ], + "PassB": [ + None, + None, + ( + signal5, + ["model_id_5"], + False, + ), + None, + None, + ( + signal6, + ["model_id_6"], + False, + ), + ], + "PassC": [ + None, + None, + None, + ( + signal4, + ["model_id_4"], + False, + ), + ], + } + + strategy = SearchStrategy(config) + strategy.initialize(space_config, "whatever", objectives) + + actual = [] + for sample in strategy: + actual.append((sample.search_point.index, sample.passes_configs, sample.model_ids)) + + pass_name = next(iter(sample.search_point.values.keys())) + if sample.search_point.index < len(signals[pass_name]) and signals[pass_name][sample.search_point.index]: + strategy.record_feedback_signal( + sample.search_point.index, *signals[pass_name][sample.search_point.index] + ) + + expected = [ + ( + 0, + OrderedDict( + {"PassA": OrderedDict({"index": 0, "params": OrderedDict({"PA0_SP1": "A01a", "PA0_SP2": "A02a"})})} + ), + ["whatever"], + ), + ( + 1, + OrderedDict( + {"PassA": OrderedDict({"index": 0, "params": OrderedDict({"PA0_SP1": "A01b", "PA0_SP2": "A02a"})})} + ), + ["whatever"], + ), + ( + 2, + OrderedDict( + {"PassA": OrderedDict({"index": 0, "params": OrderedDict({"PA0_SP1": "A01a", "PA0_SP2": "A02b"})})} + ), + ["whatever"], + ), + ( + 3, + OrderedDict( + {"PassA": OrderedDict({"index": 0, "params": OrderedDict({"PA0_SP1": "A01b", "PA0_SP2": "A02b"})})} + ), + ["whatever"], + ), + ( + 4, + OrderedDict({"PassA": OrderedDict({"index": 1, "params": OrderedDict({"PA1_SP1": "A11a"})})}), + ["whatever"], + ), + ( + 5, + OrderedDict({"PassA": OrderedDict({"index": 1, "params": OrderedDict({"PA1_SP1": "A11b"})})}), + ["whatever"], + ), + (6, OrderedDict({"PassA": OrderedDict({"index": 2, "params": OrderedDict()})}), ["whatever"]), + (7, OrderedDict({"PassA": OrderedDict({"index": 3, "params": OrderedDict()})}), ["whatever"]), + ( + 0, + OrderedDict({"PassB": OrderedDict({"index": 0, "params": OrderedDict({"PB0_SP1": "B01a"})})}), + 
["model_id_1"], + ), + ( + 1, + OrderedDict({"PassB": OrderedDict({"index": 0, "params": OrderedDict({"PB0_SP1": "B01b"})})}), + ["model_id_1"], + ), + (2, OrderedDict({"PassB": OrderedDict({"index": 1, "params": OrderedDict()})}), ["model_id_1"]), + (0, OrderedDict({"PassC": OrderedDict({"index": 0, "params": OrderedDict()})}), ["model_id_5"]), + (1, OrderedDict({"PassC": OrderedDict({"index": 1, "params": OrderedDict()})}), ["model_id_5"]), + ] + + assert actual == expected diff --git a/test/unit_test/utils.py b/test/unit_test/utils.py index f3f670c08..e4f000796 100644 --- a/test/unit_test/utils.py +++ b/test/unit_test/utils.py @@ -4,6 +4,7 @@ # -------------------------------------------------------------------------- import os from pathlib import Path +from typing import Any, Dict, Tuple, Union from unittest.mock import MagicMock import numpy as np @@ -17,7 +18,7 @@ from olive.evaluator.metric import AccuracySubType, LatencySubType, Metric, MetricType from olive.evaluator.metric_config import MetricGoal from olive.model import HfModelHandler, ModelConfig, ONNXModelHandler, PyTorchModelHandler -from olive.passes.olive_pass import create_pass_from_dict +from olive.passes.olive_pass import Pass, create_pass_from_dict ONNX_MODEL_PATH = Path(__file__).absolute().parent / "dummy_model.onnx" @@ -235,12 +236,15 @@ def get_throughput_metric(*lat_subtype, user_config=None): ) -def get_onnxconversion_pass(ignore_pass_config=True, target_opset=13): +def get_onnxconversion_pass( + ignore_pass_config=True, target_opset=13 +) -> Union[Pass, Tuple[Pass, Dict[str, Any]]]: + """Create an OnnxConversion pass; unless ignore_pass_config, also return its to_json(check_object=True) output.""" from olive.passes.onnx.conversion import OnnxConversion onnx_conversion_config = {"target_opset": target_opset} p = create_pass_from_dict(OnnxConversion, onnx_conversion_config) - return p if ignore_pass_config else (p, p.to_json(check_object=True)["config"]) + return p if ignore_pass_config else (p, p.to_json(check_object=True)) def get_onnx_dynamic_quantization_pass(disable_search=False): 
diff --git a/test/unit_test/workflows/test_run_config.py b/test/unit_test/workflows/test_run_config.py index 75674e6cb..35fe011cf 100644 --- a/test/unit_test/workflows/test_run_config.py +++ b/test/unit_test/workflows/test_run_config.py @@ -150,42 +150,19 @@ def test_get_module_path(self, pass_type, is_onnx): assert pass_module.startswith("olive.passes.onnx") == is_onnx @pytest.mark.parametrize( - ("passes", "pass_flows", "is_onnx"), + ("passes", "is_onnx"), [ - (None, None, True), - ( - { - "lora": {"type": "LoRA"}, - }, - None, - False, - ), - ( - { - "lora": {"type": "LoRA"}, - "quantization": {"type": "IncQuantization"}, - }, - None, - True, - ), - ( - { - "lora": {"type": "LoRA"}, - "quantization": {"type": "IncQuantization"}, - }, - [["lora"]], - False, - ), + (None, True), + ({"lora": {"type": "LoRA"}}, False), + ({"lora": {"type": "LoRA"}, "quantization": {"type": "IncQuantization"}}, True), ], ) - def test_is_execution_provider_required(self, passes, pass_flows, is_onnx): + def test_is_execution_provider_required(self, passes, is_onnx): with self.user_script_config_file.open() as f: user_script_config = json.load(f) if passes: user_script_config["passes"] = passes - if pass_flows: - user_script_config["pass_flows"] = pass_flows run_config = RunConfig.parse_obj(user_script_config) result = is_execution_provider_required(run_config, self.package_config) @@ -212,7 +189,7 @@ def setup(self): }, } ], - "passes": {"tuning": {"type": "OrtSessionParamsTuning"}}, + "passes": {"tuning": [{"type": "OrtSessionParamsTuning"}]}, "engine": {"evaluate_input_model": False}, } @@ -286,10 +263,10 @@ def test_auto_insert_trust_remote_code( ) def test_str_to_data_config(self, data_config_str): config_dict = deepcopy(self.template) - config_dict["passes"]["tuning"]["data_config"] = data_config_str + config_dict["passes"]["tuning"][0]["data_config"] = data_config_str run_config = RunConfig.parse_obj(config_dict) - pass_data_config = 
run_config.passes["tuning"].config["data_config"] + pass_data_config = run_config.passes["tuning"][0].config["data_config"] if data_config_str is None: assert pass_data_config is None else: @@ -305,9 +282,11 @@ def setup(self): "type": "OnnxModel", }, "passes": { - "tuning": { - "type": "IncQuantization", - } + "tuning": [ + { + "type": "IncQuantization", + } + ] }, "evaluate_input_model": False, } @@ -319,19 +298,19 @@ def setup(self): (None, "SEARCHABLE_VALUES", False), (None, "dummy_approach", True), ( - {"execution_order": "joint", "search_algorithm": "exhaustive"}, + {"execution_order": "joint", "sampler": "sequential"}, "SEARCHABLE_VALUES", - True, + False, ), ], ) def test_pass_config_(self, search_strategy, approach, is_valid): config_dict = self.template.copy() config_dict["search_strategy"] = search_strategy - config_dict["passes"]["tuning"]["approach"] = approach + config_dict["passes"]["tuning"][0]["approach"] = approach if not is_valid: with pytest.raises(ValueError): # noqa: PT011 RunConfig.parse_obj(config_dict) else: config = RunConfig.parse_obj(config_dict) - assert config.passes["tuning"].config["approach"] == approach + assert config.passes["tuning"][0].config["approach"] == approach diff --git a/test/unit_test/workflows/test_workflow_run.py b/test/unit_test/workflows/test_workflow_run.py index ccdcd843c..f4a4ce6e8 100644 --- a/test/unit_test/workflows/test_workflow_run.py +++ b/test/unit_test/workflows/test_workflow_run.py @@ -1,3 +1,4 @@ +from copy import deepcopy from pathlib import Path from test.unit_test.utils import ( get_pytorch_model, @@ -100,7 +101,7 @@ def test_run_without_ep(mock_model_to_json, mock_model_from_json, mock_run, conf with user_script.open("w"): pass - config = config_test + config = deepcopy(config_test) config["passes"]["qat"]["config"]["user_script"] = str(user_script) config["engine"]["cache_dir"] = str(tmp_path / "cache") config["engine"]["output_dir"] = str(tmp_path / "output")