From 024fd06e4511c17d65e6fcbaeeec5a9500e0f052 Mon Sep 17 00:00:00 2001 From: shaahji Date: Thu, 30 Jan 2025 13:04:25 -0800 Subject: [PATCH 1/2] Fix issues with CI builds * Relax version of bitsandbytes * Add triton to requirements * Few fixes for using newer version of torch --- olive/model/handler/pytorch.py | 2 +- olive/olive_config.json | 4 ++-- test/requirements-test-gpu.txt | 4 ++-- test/unit_test/model/test_pytorch_model.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/olive/model/handler/pytorch.py b/olive/model/handler/pytorch.py index e87c9d39b..cb4b53c4b 100644 --- a/olive/model/handler/pytorch.py +++ b/olive/model/handler/pytorch.py @@ -151,7 +151,7 @@ def load_model(self, rank: int = None, cache_model: bool = True) -> "torch.nn.Mo elif self.model_file_format == ModelFileFormat.PYTORCH_TORCH_SCRIPT: model = torch.jit.load(self.model_path) elif self.model_file_format == ModelFileFormat.PYTORCH_ENTIRE_MODEL: - model = torch.load(self.model_path) + model = torch.load(self.model_path, weights_only=False) elif self.model_file_format == ModelFileFormat.PYTORCH_SLICE_GPT_MODEL: model = self._load_slicegpt_model() elif self.model_file_format == ModelFileFormat.PYTORCH_STATE_DICT: diff --git a/olive/olive_config.json b/olive/olive_config.json index 0027ef697..1bee9abfe 100644 --- a/olive/olive_config.json +++ b/olive/olive_config.json @@ -363,13 +363,13 @@ "extra_dependencies": { "auto-opt": [ "optimum" ], "azureml": [ "azure-ai-ml>=1.11.1", "azure-keyvault-secrets", "azure-identity", "azureml-fsspec" ], - "bnb": [ "bitsandbytes" ], + "bnb": [ "bitsandbytes", "triton" ], "capture-onnx-graph": [ "onnxruntime-genai", "optimum" ], "cpu": [ "onnxruntime" ], "directml": [ "onnxruntime-directml" ], "docker": [ "docker" ], "shared-cache": [ "azure-identity", "azure-storage-blob" ], - "finetune": [ "onnxruntime-genai", "optimum", "accelerate>=0.30.0", "peft", "scipy", "bitsandbytes" ], + "finetune": [ "onnxruntime-genai", "optimum", "accelerate>=0.30.0", "peft", "scipy", "bitsandbytes", "triton" ], "flash-attn": [ "flash_attn" ], "gpu": [ "onnxruntime-gpu" ], "inc": [ "neural-compressor" ], diff --git a/test/requirements-test-gpu.txt b/test/requirements-test-gpu.txt index e8b9592f6..65674e827 100644 --- a/test/requirements-test-gpu.txt +++ b/test/requirements-test-gpu.txt @@ -1,5 +1,5 @@ -r requirements-test.txt auto-gptq autoawq -# only available on Linux currently -bitsandbytes==0.43.3 +bitsandbytes +triton diff --git a/test/unit_test/model/test_pytorch_model.py b/test/unit_test/model/test_pytorch_model.py index ac26454e1..bee910886 100644 --- a/test/unit_test/model/test_pytorch_model.py +++ b/test/unit_test/model/test_pytorch_model.py @@ -58,7 +58,7 @@ def test_load_from_path(torch_load): model = PyTorchModelHandler(model_path="test_path") assert model.load_model() == "dummy_pytorch_model" - torch_load.assert_called_once_with("test_path") + torch_load.assert_called_once_with("test_path", weights_only=False) @patch("olive.model.handler.pytorch.UserModuleLoader") From cbea4293197359a71258d8c4c063816a6c8e335a Mon Sep 17 00:00:00 2001 From: shaahji Date: Thu, 16 Jan 2025 16:03:27 -0800 Subject: [PATCH 2/2] Implement pass search Reimplement search logic to include passes in search space. 
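For context on the `torch.load` change in the first patch: recent PyTorch releases (2.6 and later) default `weights_only` to `True`, which refuses to unpickle full `nn.Module` objects, so loading a `PYTORCH_ENTIRE_MODEL` file has to opt out explicitly. A minimal, self-contained sketch of the behavior (not part of the patch):

```python
# Minimal sketch, not part of the patch: why PYTORCH_ENTIRE_MODEL loading
# needs weights_only=False on newer torch releases.
import torch
from torch import nn

torch.save(nn.Linear(4, 2), "entire_model.pt")  # pickles the whole module

# With the newer default weights_only=True, this load would raise an
# UnpicklingError; weights_only=False restores full unpickling.
model = torch.load("entire_model.pt", weights_only=False)
assert isinstance(model, nn.Module)
```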
--- .gitignore | 1 + .lintrunner.toml | 31 - docs/architecture.md | 36 +- docs/source/extending/design.md | 28 +- .../how-to/configure-workflows/auto-opt.md | 16 +- .../configure-workflows/model-packaging.md | 8 +- docs/source/reference/index.rst | 8 +- docs/source/reference/options.md | 65 +- docs/source/reference/search-algorithm.rst | 30 - docs/source/reference/search-samplers.rst | 30 + docs/source/why-olive.md | 2 +- {olive/strategy => examples}/__init__.py | 0 examples/bert/README.md | 4 +- examples/bert/bert.py | 27 + ...a_gpu.json => bert_cuda_gpu.template.json} | 6 +- examples/bert/bert_inc_dynamic_ptq_cpu.json | 2 +- examples/bert/bert_inc_ptq_cpu.json | 2 +- examples/bert/bert_ptq_cpu.json | 2 +- examples/bert/notebook/bert_auto_opt_gpu.json | 2 +- examples/bert/notebook/multi_ep_search.ipynb | 10 +- examples/bert/user_script.py | 14 +- examples/deberta/deberta.json | 2 +- examples/directml/llm/config_llm.json | 2 +- .../config_text_encoder.json | 1 - .../config_text_encoder_2.json | 1 - .../stable_diffusion_xl/config_unet.json | 1 - .../config_vae_decoder.json | 1 - .../config_vae_encoder.json | 1 - .../stable_diffusion_xl.py | 10 +- examples/llama2/llama2.py | 35 +- examples/llama2/llama2_lmeval.json | 2 +- examples/llama2/llama2_model_builder.py | 7 +- ...nfig_cpu.json => config_cpu.template.json} | 0 ...nfig_gpu.json => config_gpu.template.json} | 0 ..._ep.json => config_multi_ep.template.json} | 0 .../llama2/notebook/llama2_multiep/llama2.py | 32 + examples/mistral/mistral.py | 2 +- examples/mistral/mistral_fp16.json | 1 - examples/mistral/mistral_int4.json | 1 - examples/mobilenet/prepare_config.py | 32 +- examples/mobilenet/raw_qnn_sdk_template.json | 4 +- examples/phi2/phi2.py | 17 +- examples/phi3/phi3.py | 7 +- examples/phi3/phi3_template.json | 1 - examples/resnet/resnet_dynamic_ptq_cpu.json | 2 +- examples/resnet/resnet_ptq_cpu.json | 2 +- .../resnet/resnet_ptq_cpu_aml_dataset.json | 2 +- examples/resnet/resnet_static_ptq_cpu.json | 2 +- .../config_safety_checker.json | 1 - .../stable_diffusion/config_text_encoder.json | 1 - examples/stable_diffusion/config_unet.json | 1 - .../stable_diffusion/config_vae_decoder.json | 1 - .../stable_diffusion/config_vae_encoder.json | 1 - examples/stable_diffusion/sd_utils/ort.py | 20 +- examples/stable_diffusion/sd_utils/ov.py | 2 +- examples/stable_diffusion/stable_diffusion.py | 5 +- examples/test/__init__.py | 4 + examples/test/azureml/__init__.py | 4 + .../test/azureml/test_bert_ptq_cpu_aml.py | 2 +- examples/test/azureml/test_llama2.py | 6 +- .../test/azureml/test_resnet_ptq_cpu_aml.py | 6 +- examples/test/local/__init__.py | 4 + examples/test/local/test_bert_cuda_gpu.py | 11 +- examples/test/local/test_bert_ptq_cpu.py | 6 +- .../test/local/test_bert_ptq_cpu_docker.py | 6 +- examples/test/local/test_llama2.py | 6 +- examples/test/local/test_qnn_tooklit.py | 9 +- examples/test/local/test_resnet_ptq_cpu.py | 6 +- examples/test/local/test_resnet_qat.py | 6 +- examples/test/utils.py | 12 +- olive/cache.py | 2 +- olive/cli/auto_opt.py | 2 +- olive/cli/base.py | 12 +- olive/cli/quantize.py | 4 +- olive/engine/config.py | 40 +- olive/engine/engine.py | 326 ++- olive/evaluator/olive_evaluator.py | 1 + olive/passes/olive_pass.py | 58 +- olive/passes/onnx/inc_quantization.py | 2 +- olive/passes/onnx/nvmo_quantization.py | 2 +- olive/passes/onnx/quantization.py | 2 +- olive/passes/onnx/session_params_tuning.py | 14 +- olive/passes/onnx/vitis_ai_quantization.py | 2 +- olive/passes/pass_config.py | 2 +- olive/passes/pytorch/lora.py | 2 +- 
olive/passes/snpe/quantization.py | 2 +- olive/search/__init__.py | 4 + olive/search/samplers/__init__.py | 12 + olive/search/samplers/optuna_sampler.py | 144 ++ olive/search/samplers/random_sampler.py | 61 + olive/search/samplers/search_sampler.py | 69 + olive/search/samplers/sequential_sampler.py | 46 + .../samplers}/tpe_sampler.py | 4 +- .../{strategy => search}/search_parameter.py | 2 +- olive/search/search_point.py | 70 + olive/search/search_results.py | 124 ++ olive/search/search_sample.py | 61 + olive/search/search_space.py | 230 ++ olive/search/search_strategy.py | 345 +++ olive/{strategy => search}/utils.py | 2 +- olive/strategy/search_algorithm/__init__.py | 12 - olive/strategy/search_algorithm/exhaustive.py | 33 - .../search_algorithm/optuna_sampler.py | 101 - .../search_algorithm/random_sampler.py | 58 - .../search_algorithm/search_algorithm.py | 59 - olive/strategy/search_results.py | 131 -- olive/strategy/search_space.py | 111 - olive/strategy/search_strategy.py | 282 --- olive/systems/azureml/aml_system.py | 7 +- olive/systems/docker/docker_system.py | 13 +- olive/systems/olive_system.py | 7 +- olive/workflows/run/config.py | 130 +- olive/workflows/run/run.py | 86 +- .../packaging/test_packaging_generator.py | 2 +- test/unit_test/engine/test_engine.py | 57 +- .../passes/common/test_user_script.py | 4 +- .../passes/test_pass_serialization.py | 3 +- .../search/samplers/test_random_sampler.py | 101 + .../samplers/test_sequential_sampler.py | 93 + .../search/samplers/test_tpe_sampler.py | 168 ++ test/unit_test/search/test_search_results.py | 106 + test/unit_test/search/test_search_space.py | 1963 +++++++++++++++++ test/unit_test/search/test_search_strategy.py | 849 +++++++ test/unit_test/utils.py | 9 +- test/unit_test/workflows/test_run_config.py | 55 +- test/unit_test/workflows/test_workflow_run.py | 3 +- 126 files changed, 5177 insertions(+), 1525 deletions(-) delete mode 100644 docs/source/reference/search-algorithm.rst create mode 100644 docs/source/reference/search-samplers.rst rename {olive/strategy => examples}/__init__.py (100%) create mode 100644 examples/bert/bert.py rename examples/bert/{bert_cuda_gpu.json => bert_cuda_gpu.template.json} (89%) rename examples/llama2/notebook/llama2_multiep/{config_cpu.json => config_cpu.template.json} (100%) rename examples/llama2/notebook/llama2_multiep/{config_gpu.json => config_gpu.template.json} (100%) rename examples/llama2/notebook/llama2_multiep/{config_multi_ep.json => config_multi_ep.template.json} (100%) create mode 100644 examples/llama2/notebook/llama2_multiep/llama2.py create mode 100644 olive/search/__init__.py create mode 100644 olive/search/samplers/__init__.py create mode 100644 olive/search/samplers/optuna_sampler.py create mode 100644 olive/search/samplers/random_sampler.py create mode 100644 olive/search/samplers/search_sampler.py create mode 100644 olive/search/samplers/sequential_sampler.py rename olive/{strategy/search_algorithm => search/samplers}/tpe_sampler.py (92%) rename olive/{strategy => search}/search_parameter.py (99%) create mode 100644 olive/search/search_point.py create mode 100644 olive/search/search_results.py create mode 100644 olive/search/search_sample.py create mode 100644 olive/search/search_space.py create mode 100644 olive/search/search_strategy.py rename olive/{strategy => search}/utils.py (97%) delete mode 100644 olive/strategy/search_algorithm/__init__.py delete mode 100644 olive/strategy/search_algorithm/exhaustive.py delete mode 100644 olive/strategy/search_algorithm/optuna_sampler.py 
delete mode 100644 olive/strategy/search_algorithm/random_sampler.py delete mode 100644 olive/strategy/search_algorithm/search_algorithm.py delete mode 100644 olive/strategy/search_results.py delete mode 100644 olive/strategy/search_space.py delete mode 100644 olive/strategy/search_strategy.py create mode 100644 test/unit_test/search/samplers/test_random_sampler.py create mode 100644 test/unit_test/search/samplers/test_sequential_sampler.py create mode 100644 test/unit_test/search/samplers/test_tpe_sampler.py create mode 100644 test/unit_test/search/test_search_results.py create mode 100644 test/unit_test/search/test_search_space.py create mode 100644 test/unit_test/search/test_search_strategy.py diff --git a/.gitignore b/.gitignore index e2c5a77c4..bc06032a2 100644 --- a/.gitignore +++ b/.gitignore @@ -126,6 +126,7 @@ celerybeat.pid # Environments .env .venv +.vs env/ venv/ ENV/ diff --git a/.lintrunner.toml b/.lintrunner.toml index 2b4cbbf02..d197dda38 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -117,37 +117,6 @@ init_command = [ '--requirement=requirements-dev.txt' ] -[[linter]] -code = 'MYPY' -include_patterns = [ - '**/*.py', - '**/*.pyi' ] -exclude_patterns = [ - 'examples/pytorch/*.py' ] -command = [ - 'python', - '-m', - 'lintrunner_adapters', - 'run', - 'mypy_linter', - '--config=pyproject.toml', - '--show-notes', - '--show-disable', - '--', - '@{{PATHSFILE}}' ] -init_command = [ - 'python', - '-m', - 'lintrunner_adapters', - 'run', - 'pip_init', - '--dry-run={{DRYRUN}}', - 'mypy==1.0.0' ] - [[linter]] code = 'NOQA' include_patterns = ['**/*.py', '**/*.pyi'] diff --git a/docs/architecture.md b/docs/architecture.md index c3de402fa..29efb7c06 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -13,12 +13,12 @@ This document describes the Olive components, and some implementation details. T - [Search](#search) - [Search Parameter](#searchparameter) - [Search Space](#searchspace) - - [Search Algorithm](#searchalgorithm) + - [Search Sampler](#searchsampler) - [Search Results](#searchresults) -- [Search Strategy](#search-strategy) - - [Execution order](#execution-order) - - [User Interface](#user-interface) - - [Implementation](#implementation) + - [Search Strategy](#search-strategy) + - [Execution order](#execution-order) + - [User Interface](#user-interface) + - [Implementation](#implementation) - [System](#system) - [OliveSystem Class](#olivesystem-class) - [Data Container](#data-container) @@ -109,9 +109,9 @@ The engine maintains a cache directory with three sub-directories: - `mlflow`: stores mlflow model files. ## Search -Olive workflows support search parameters which are optimized using search algorithms. +Olive workflows support search parameters which are optimized using various execution orders. -At the most basic level is `SearchParameter` which describes the options for search parameters. `SearchSpace` combines search parameters for one or more passes and `SearchAlgorithm` provides different sampling algorithms to search for the best parameter configuration (search point) from the search space. +At the most basic level is `SearchParameter` which describes the options for search parameters. `SearchSpace` combines search parameters for one or more passes and `SearchSampler` provides different sampling algorithms to search for the best parameter configuration (search point) from the search space. ### SearchParameter A search parameter defines a discrete categorical distribution.
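As an illustration of the paragraph above, a hedged sketch follows; the `Categorical` class name and its import path are assumptions inferred from the `olive/search/search_parameter.py` rename in this patch:

```python
# Hypothetical sketch: a search parameter as a discrete categorical
# distribution. Categorical and its import path are assumed from the
# olive/search/search_parameter.py rename in this patch.
from olive.search.search_parameter import Categorical

# A sampler picks one of these discrete choices per search point.
weight_type = Categorical(["QInt8", "QUInt8"])
```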
@@ -122,7 +122,7 @@ There are two types of search parameters: **Note:** - There cannot be any cyclic parent child dependencies. -Search algorithms order the search parameters topologically so that the parents are sampled before the children. +Search space orders the search parameters topologically so that the parents are sampled before the children. ### SearchSpace Search space combines search parameters from one or more passes and provides methods to iterate over the search space (`iterate`) or generate random samples (`random_sample`). @@ -139,18 +139,18 @@ The corresponding conceptual search space is the space of all possible parameter {“pass_id/space_name”: {“param_name”: param_value}} ``` -### SearchAlgorithm -Search algorithm operates over a search space and provides samples/trials (search points) from the search space to execute and evaluate. +### SearchSampler +A sampling algorithm queries a search space and provides samples/trials (search points) from it to evaluate. -Each search algorithm provides the methods: +Each search sampler provides the methods: - `suggest`: returns a search point to execute and evaluate. The algorithm can sample a search point based on the evaluation results for previously suggested points. - `report`: report evaluation results for a search point. The search point can also be pruned if it contains invalid pass configs or failed during execution/evaluation. -The following search algorithms have been implemented: -- `ExhaustiveSearchAlgorithm`: Exhaustively iterates over the search space. -- `RandomSearchAlgorithm`: Randomly samples points from the search space without replacement. -- `OptunaSearchAlgorithm`: Abstract base class for algorithms built using `optuna` samplers. This class cannot be used directly - - `TPESearchAlgorithm`: Uses optuna `TPESampler`. +The following sampling algorithms have been implemented: +- `SequentialSampler`: Sequentially iterates over the search space. +- `RandomSampler`: Randomly samples points from the search space. +- `OptunaSampler`: Abstract base class for algorithms built using `optuna` samplers. + - `TPESampler`: Uses optuna's `TPESampler`. ### SearchResults `SearchResults` stores evaluation results for samples made from a search space and provides tools to analyze and select the best search point/s. @@ -159,10 +159,10 @@ Results are reported using the `record` method. Currently `best_search_point` selects the best search point by maximizing/minimizing metrics using tie breaking. We intend to provide different model selection strategies for both single and multi-objective optimization. -## Search Strategy +## SearchStrategy Search strategy provides an optimization pipeline that finds the best search point from the search space of one or more passes. -It consists of two sub-components – `execution_order` and `search_algorithm`. Search algorithm has been covered in the previous section. +It consists of two sub-components – `execution_order` and `sampler`. Sampling algorithms have been covered in the previous section. ### Execution Order The execution order defines the order in which the passes are optimized. diff --git a/docs/source/extending/design.md b/docs/source/extending/design.md index 2ac6445bc..5af2641cf 100644 --- a/docs/source/extending/design.md +++ b/docs/source/extending/design.md @@ -5,8 +5,7 @@ that are composed to construct a model optimization workflow. The workflow which is run by the **Engine** is composed of **Passes** that are executed in a specific order.
Each Pass is responsible for performing a specific optimization on the model. Each Pass might have a set of parameters that can be tuned to achieve the best metrics, say accuracy and latency, that are evaluated by the respective **Evaluator**. -The Engine employs a **Search Strategy** that uses a **Search Algorithm** to auto-tune each Pass one by one or set of Passes -together. +The Engine employs a **Search Strategy** that uses a **Search Sampler** to auto-tune each Pass one by one or a set of Passes together. Each Pass can be run on any host **System** and its output model can be evaluated on the desired target **System**. @@ -71,22 +70,23 @@ created and **registered** along with their host system and evaluators if any. The engine also maintains a cache directory to cache pass runs, models and evaluations. ## Search Strategy -Search strategy provides an optimization pipeline that finds the best search point from the search space of one or more passes. +Search strategy provides an optimization pipeline that finds the best search point from the search space built from one or more passes and the search parameters within each of those passes. `include_pass_params` controls whether or not each pass's search parameters are included in the search. `max_iter` and `max_time` can be configured for finer control. -It consists of two sub-components – `execution_order` and `search_algorithm`. +It consists of two sub-components – `execution_order` and `sampler`. ### Execution Order -The execution order defines the order in which the passes are optimized. +The execution order defines the order in which the search space is traversed. Currently, we support two execution orders: -- `joint`: The search spaces of all passes are combined and searched together to find the best search point. Each search point -that is evaluated has parameters for the search parameters of all passes. -- `pass-by-pass`: The search space of each pass is searched and optimized independently in order. +- `joint`: The search spaces of all passes and their corresponding parameters are combined and searched together to find the best search point. Each search point consists of values for all search parameters of at most one pass in each pass group. +- `pass-by-pass`: The search space of each pass group is searched and optimized independently in order. -### Search Algorithm -Search algorithm operates over a search space and provides samples/trials (search points) from the search space to execute and evaluate. +### Search Sampler +Search sampler provides samples/trials (search points) from the search space to evaluate. Each search point consists of values for all search parameters across all passes within the pass group. -The following search algorithms have been implemented: -- `exhaustive`: Exhaustively iterates over the search space. -- `random`: Randomly samples points from the search space without replacement. -- `tpe`: ample using TPE (Tree-structured Parzen Estimator) algorithm to sample from the search space. +The following sampling algorithms have been implemented: +- `sequential`: Sequentially iterates over the search space. +- `random`: Randomly samples points from the search space. +- `tpe`: Samples the search space using the TPE (Tree-structured Parzen Estimator) algorithm. + +Each of these samplers can be used for an exhaustive search by setting the `max_samples` field to zero.
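To make the new `search_strategy` schema concrete, here is a sketch that patches one of the example workflow configs and runs it; it assumes execution from `examples/bert`, where `bert_ptq_cpu.json` lives:

```python
# Sketch: load an example workflow config, switch it to the search_strategy
# schema described above, then run it with Olive.
import json
from pathlib import Path

from olive.workflows import run as olive_run

config = json.loads(Path("bert_ptq_cpu.json").read_text())
config["search_strategy"] = {
    "execution_order": "joint",  # search all pass search spaces together
    "sampler": "tpe",            # or "sequential" / "random"
    "max_samples": 3,            # 0 would make the sampler exhaustive
    "seed": 0,
}
olive_run(config)
```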
diff --git a/docs/source/how-to/configure-workflows/auto-opt.md b/docs/source/how-to/configure-workflows/auto-opt.md index 98bc03822..c0a6ba268 100644 --- a/docs/source/how-to/configure-workflows/auto-opt.md +++ b/docs/source/how-to/configure-workflows/auto-opt.md @@ -41,11 +41,9 @@ Here is a simple example of Auto Optimizer configuration, the item which is not "engine": { "search_strategy": { "execution_order": "joint", - "search_algorithm": "tpe", - "search_algorithm_config": { - "num_samples": 1, - "seed": 0 - } + "sampler": "tpe", + "max_samples": 1, + "seed": 0 }, "evaluator": "common_evaluator", "cache_dir": "cache", @@ -149,8 +147,8 @@ Here is another quick comparison between Auto Optimizer and manual settings. }, "search_strategy": { "execution_order": "joint", - "search_algorithm": "tpe", - "num_samples": 1, + "sampler": "tpe", + "max_samples": 1, "seed": 0 }, "evaluator": "common_evaluator", @@ -261,8 +259,8 @@ Here is another quick comparison between Auto Optimizer and manual settings. ], "search_strategy": { "execution_order": "joint", - "search_algorithm": "tpe", - "num_samples": 1, + "sampler": "tpe", + "max_samples": 1, "seed": 0 }, "evaluator": "common_evaluator", diff --git a/docs/source/how-to/configure-workflows/model-packaging.md b/docs/source/how-to/configure-workflows/model-packaging.md index e6d447ece..3e76a4ec7 100644 --- a/docs/source/how-to/configure-workflows/model-packaging.md +++ b/docs/source/how-to/configure-workflows/model-packaging.md @@ -197,11 +197,9 @@ You can add different types `PackagingConfig` as a list to Engine configurations "engine": { "search_strategy": { "execution_order": "joint", - "search_algorithm": "tpe", - "search_algorithm_config": { - "num_samples": 5, - "seed": 0 - } + "sampler": "tpe", + "max_samples": 5, + "seed": 0 }, "evaluator": "common_evaluator", "host": "local_system", diff --git a/docs/source/reference/index.rst b/docs/source/reference/index.rst index 2777f8d34..07ad3a168 100644 --- a/docs/source/reference/index.rst +++ b/docs/source/reference/index.rst @@ -33,11 +33,11 @@ Reference :octicon:`arrow-right;1em;sd-text-info` `Pass `_ .. grid-item-card:: - **Search Algorithm** + **Search Samplers** - Configure search strategies. + Configure search samplers. - :octicon:`arrow-right;1em;sd-text-info` `Search Algorithm `_ + :octicon:`arrow-right;1em;sd-text-info` `Search Samplers `_ .. toctree:: @@ -48,4 +48,4 @@ Reference options model pass - search-algorithm + search-samplers diff --git a/docs/source/reference/options.md b/docs/source/reference/options.md index 322dc7144..3407a7cc3 100644 --- a/docs/source/reference/options.md +++ b/docs/source/reference/options.md @@ -446,41 +446,6 @@ Please also find the detailed options from following table for each pass: } ``` -## Pass Flows Information - -`pass_flows: List[List[str]]` - -This is a list of list of pass names. Each list of pass names is a pass flow which will be executed in order. -When `pass_flows` is not specified, the passes are executed in the order of the `passes` dictionary. 
- -### Example - -```json -"passes": { - "onnx_conversion": { - "type": "OnnxConversion", - "target_opset": 13 - }, - "transformers_optimization": { - "type": "OrtTransformersOptimization", - "model_type": "bert", - "num_heads": 12, - "hidden_size": 768, - "float16": true - }, - "onnx_quantization": { - "type": "OnnxQuantization", - "data_config": "calib_data_coonfig", - "weight_type": "QUInt8" - } -}, -"pass_flows": [ - ["onnx_conversion", "transformers_optimization"], - ["onnx_conversion", "transformers_optimization", "onnx_quantization"], - ["onnx_conversion", "onnx_quantization"], -] -``` - ## Engine Information `engine: [Dict]` This is a dictionary that contains the information of the engine. Its fields can be @@ -491,15 +456,13 @@ This is a dictionary that contains the information of the engine. Its fields can - `execution_order: [str]` The execution order of the optimizations of passes. The options are `pass-by-pass` and `joint`. - - `search_algorithm: [str]` The search algorithm of the engine. The available search algorithms are `exhaustive`, `random` and `tpe`. + - `sampler: [str]` The search sampler to use while traversing the search space. The available samplers are `random`, `sequential` and `tpe`. - - `search_algorithm_config: [Dict]` The configuration of the search algorithm. The configuration of the search algorithm depends on - the search algorithm. Its fields can be provided directly to the parent dictionary. + - `sampler_config: [Dict]` The configuration of the sampler. The options depend on the chosen sampler. Its fields can be provided directly to the parent dictionary. - - `output_model_num: [int]` The number of output models from the engine based on metric priority. If not specified, the engine will output all qualified models. + - `stop_when_goals_met: [Boolean]` This decides whether to stop the search when the metric goals, if any, are met. This is `false` by default. - - `stop_when_goals_met: [Boolean]` This decides whether to stop the search when the metric goals, if any, are met. This is `false` by - default. + - `include_pass_params: [Boolean]` Includes individual pass parameters when building the search space. Defaults to `true`. - `max_iter: [int]` The maximum number of iterations of the search. Only valid for `joint` execution order. By default, there is no maximum number of iterations. @@ -511,8 +474,8 @@ This is a dictionary that contains the information of the engine. Its fields can have empty search spaces. The output of the final pass will be evaluated if there is a valid evaluator. The output of the engine will be the output model of the final pass and its evaluation result. - If `search_strategy` is `true`, the search strategy will be the default search strategy. The default search strategy is `exhaustive` search - algorithm with `joint` execution order. + If `search_strategy` is `true`, the search strategy will be the default search strategy. The default search strategy is the `sequential` search + sampler with `joint` execution order. - `evaluate_input_model: [Boolean]` In this mode, the engine will evaluate the input model using the engine's evaluator and return the results. If the engine has no evaluator, it will skip the evaluation. This is `true` by default. @@ -556,13 +519,15 @@ This is a dictionary that contains the information of the engine. Its fields can - `log_to_file: [Boolean]`, `false` by default. This decides whether to log to file. If `true`, the log will be stored in a olive-.log file under the current working directory.
-Please find the detailed config options from following table for each search algorithm: +Please find the detailed config options from the following table for each search sampler: + +Note that if `max_samples` is set to zero, each of the samplers below will be exhaustive. -| Algorithm | Description | +| Sampler | Description | |:----------|:-------------| -| [exhaustive](../../reference/search-algorithm.rst#_exhaustive_search_algorithm) | Iterates over the entire search space | -| [random](../../reference/search-algorithm.rst#_random_search_algorithm) | Samples random points from the search space with or without replacement | -| [tpe](../../reference/search-algorithm.rst#_tpe_search_algorithm) | Sample using TPE (Tree-structured Parzen Estimator) algorithm. | +| [random](../../reference/search-samplers.rst#_random_sampler) | Samples random points from the search space | +| [sequential](../../reference/search-samplers.rst#_sequential_sampler) | Iterates over the entire search space sequentially | +| [tpe](../../reference/search-samplers.rst#_tpe_sampler) | Samples using the TPE (Tree-structured Parzen Estimator) algorithm. | ### Example ```json "engine": { "search_strategy": { "execution_order": "joint", - "search_algorithm": "tpe", - "num_samples": 5, + "sampler": "tpe", + "max_samples": 5, "seed": 0 }, "evaluator": "common_evaluator", diff --git a/docs/source/reference/search-algorithm.rst b/docs/source/reference/search-algorithm.rst deleted file mode 100644 index 59f8731df..000000000 --- a/docs/source/reference/search-algorithm.rst +++ /dev/null @@ -1,30 +0,0 @@ -SearchAlgorithms -================================= - -The following search algorithms are available in Olive. - -Each algorithm is followed by a description of the algorithm and a list of the its configuration options. - -.. _exhaustive_search_algorithm: - -ExhaustiveSearchAlgorithm ------------------------- -**Name:** :code:`"exhaustive"` - -.. autoconfigclass:: olive.strategy.search_algorithm.ExhaustiveSearchAlgorithm - -.. _random_search_algorithm: - -RandomSearchAlgorithm --------------------- -**Name:** :code:`"random"` - -.. autoconfigclass:: olive.strategy.search_algorithm.RandomSearchAlgorithm - -.. _tpe_search_algorithm: - -TPESearchAlgorithm ------------------ -**Name:** :code:`"tpe"` - -.. autoconfigclass:: olive.strategy.search_algorithm.TPESearchAlgorithm diff --git a/docs/source/reference/search-samplers.rst b/docs/source/reference/search-samplers.rst new file mode 100644 index 000000000..4c43010d6 --- /dev/null +++ b/docs/source/reference/search-samplers.rst @@ -0,0 +1,30 @@ +Samplers +================================= + +The following sampling algorithms are available in Olive. + +Each sampler is followed by a description of the algorithm and a list of its configuration options. + +.. _sequential_sampler: + +SequentialSampler +----------------- +**Name:** :code:`"sequential"` + +.. autoconfigclass:: olive.search.samplers.SequentialSampler + +.. _random_sampler: + +RandomSampler +------------- +**Name:** :code:`"random"` + +.. autoconfigclass:: olive.search.samplers.RandomSampler + +.. _tpe_sampler: + +TPESampler +---------- +**Name:** :code:`"tpe"` + +..
autoconfigclass:: olive.search.samplers.TPESampler diff --git a/docs/source/why-olive.md b/docs/source/why-olive.md index 267241746..1e07637c5 100644 --- a/docs/source/why-olive.md +++ b/docs/source/why-olive.md @@ -16,7 +16,7 @@ Olive (**O**NNX **LIVE**) is a cutting-edge model optimization toolkit with an a The input to Olive is typically a PyTorch or Hugging Face model, and the output is an optimized ONNX model that is executed on a device (deployment target) running the ONNX runtime. Olive will optimize the model for the deployment target's AI accelerator (NPU, GPU, CPU) provided by a hardware vendor such as Qualcomm, AMD, Nvidia, or Intel. -Olive executes a *workflow*, which is an ordered sequence of individual model optimization tasks called *passes* - example passes include model compression, graph capture, quantization, and graph optimization. Each pass has a set of parameters that can be tuned to achieve the best metrics, such as accuracy and latency, that are evaluated by the respective *evaluator*. Olive employs a *search strategy* that uses a *search algorithm* to auto-tune each pass individually or a set of passes together. +Olive executes a *workflow*, which is an ordered sequence of individual model optimization tasks called *passes* - example passes include model compression, graph capture, quantization, and graph optimization. Each pass has a set of parameters that can be tuned to achieve the best metrics, such as accuracy and latency, that are evaluated by the respective *evaluator*. Olive employs a *search strategy* that uses a *search sampler* to auto-tune each pass individually or a set of passes together. ``` ## Benefits of using Olive diff --git a/olive/strategy/__init__.py b/examples/__init__.py similarity index 100% rename from olive/strategy/__init__.py rename to examples/__init__.py diff --git a/examples/bert/README.md b/examples/bert/README.md index fc994726d..1f32016e4 100644 --- a/examples/bert/README.md +++ b/examples/bert/README.md @@ -90,8 +90,10 @@ Config file: [bert_qat_customized_train_loop_cpu.json](bert_qat_customized_train ### BERT optimization with CUDA/TensorRT on GPU This workflow performs BERT optimization on GPU with CUDA/TensorRT. It performs the optimization pipeline: 1. CUDA: `CUDAExecutionProvider` + - *PyTorch Model -> Onnx Model -> ONNX Runtime performance tuning* + Run: [bert.py](bert.py) - *PyTorch Model -> Onnx Model -> Transformers Optimized Onnx Model with fp16 -> ONNX Runtime performance tuning* - Config file: [bert_cuda_gpu.json](bert_cuda_gpu.json) + Run: [bert.py](bert.py) --optimize 2. TensorRT: `TensorrtExecutionProvider` - *PyTorch Model -> Onnx Model -> ONNX Runtime performance tuning with trt_fp16_enable* Config file: [bert_trt_gpu.json](bert_trt_gpu.json) diff --git a/examples/bert/bert.py b/examples/bert/bert.py new file mode 100644 index 000000000..cda726c3b --- /dev/null +++ b/examples/bert/bert.py @@ -0,0 +1,27 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +import argparse +import json +from pathlib import Path + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--optimize", + action="store_true", + help="If set, run transformers optimization pass", + ) + args = parser.parse_args() + + input_filename = "bert_cuda_gpu.template.json" + with Path(input_filename).open("r") as f: + config = json.load(f) + + if not args.optimize: + del config["passes"]["transformers_optimization"] + + output_filename = input_filename.replace(".template", "") + with Path(output_filename).open("w") as strm: + json.dump(config, fp=strm, indent=4) diff --git a/examples/bert/bert_cuda_gpu.json b/examples/bert/bert_cuda_gpu.template.json similarity index 89% rename from examples/bert/bert_cuda_gpu.json rename to examples/bert/bert_cuda_gpu.template.json index 85aaad56a..0eae68317 100644 --- a/examples/bert/bert_cuda_gpu.json +++ b/examples/bert/bert_cuda_gpu.template.json @@ -51,11 +51,7 @@ "transformers_optimization": { "type": "OrtTransformersOptimization", "float16": true }, "session_params_tuning": { "type": "OrtSessionParamsTuning", "data_config": "glue_mrpc", "io_bind": true } }, - "pass_flows": [ - [ "conversion", "transformers_optimization", "session_params_tuning" ], - [ "conversion", "session_params_tuning" ] - ], - "search_strategy": { "execution_order": "joint", "search_algorithm": "tpe", "num_samples": 3, "seed": 0 }, + "search_strategy": { "execution_order": "joint", "sampler": "tpe", "max_samples": 3, "seed": 0 }, "host": "local_system", "target": "local_system", "evaluator": "common_evaluator", diff --git a/examples/bert/bert_inc_dynamic_ptq_cpu.json b/examples/bert/bert_inc_dynamic_ptq_cpu.json index 9d9f52cfa..889beff3c 100644 --- a/examples/bert/bert_inc_dynamic_ptq_cpu.json +++ b/examples/bert/bert_inc_dynamic_ptq_cpu.json @@ -34,7 +34,7 @@ "transformers_optimization": { "type": "OrtTransformersOptimization", "model_type": "bert" }, "dynamic_quantization": { "type": "IncDynamicQuantization" } }, - "search_strategy": { "execution_order": "joint", "search_algorithm": "exhaustive" }, + "search_strategy": { "execution_order": "joint", "sampler": "sequential" }, "evaluator": "common_evaluator", "cache_dir": "cache", "output_dir": "models/bert_inc_dynamic_ptq_cpu" diff --git a/examples/bert/bert_inc_ptq_cpu.json b/examples/bert/bert_inc_ptq_cpu.json index 09558de46..82220ac84 100644 --- a/examples/bert/bert_inc_ptq_cpu.json +++ b/examples/bert/bert_inc_ptq_cpu.json @@ -64,7 +64,7 @@ } } }, - "search_strategy": { "execution_order": "joint", "search_algorithm": "exhaustive" }, + "search_strategy": { "execution_order": "joint", "sampler": "sequential" }, "evaluator": "common_evaluator", "cache_dir": "cache", "output_dir": "models/bert_inc_ptq_cpu" diff --git a/examples/bert/bert_ptq_cpu.json b/examples/bert/bert_ptq_cpu.json index 5a717acf1..625415925 100644 --- a/examples/bert/bert_ptq_cpu.json +++ b/examples/bert/bert_ptq_cpu.json @@ -67,7 +67,7 @@ }, "session_params_tuning": { "type": "OrtSessionParamsTuning", "data_config": "glue_mrpc" } }, - "search_strategy": { "execution_order": "joint", "search_algorithm": "tpe", "num_samples": 3, "seed": 0 }, + "search_strategy": { "execution_order": "joint", "sampler": "tpe", "max_samples": 3, "seed": 0 }, "evaluator": "common_evaluator", "host": "local_system", "target": "local_system", diff --git a/examples/bert/notebook/bert_auto_opt_gpu.json b/examples/bert/notebook/bert_auto_opt_gpu.json index 
f1ed055a5..929626f46 100644 --- a/examples/bert/notebook/bert_auto_opt_gpu.json +++ b/examples/bert/notebook/bert_auto_opt_gpu.json @@ -43,7 +43,7 @@ ] } }, - "search_strategy": { "execution_order": "joint", "search_algorithm": "tpe", "num_samples": 1, "seed": 0 }, + "search_strategy": { "execution_order": "joint", "sampler": "tpe", "max_samples": 1, "seed": 0 }, "evaluator": "common_evaluator", "host": "local_system", "target": "local_system", diff --git a/examples/bert/notebook/multi_ep_search.ipynb b/examples/bert/notebook/multi_ep_search.ipynb index fcedbfa1b..fd9ccfe70 100644 --- a/examples/bert/notebook/multi_ep_search.ipynb +++ b/examples/bert/notebook/multi_ep_search.ipynb @@ -141,17 +141,15 @@ "#### Engine and search strategy\n", "\n", "Engine is used to manage the optimization process where we run optimization on host device, and run evaluation on target device.\n", - "Search strategy is used to search the optimal optimization among different EPs. In this notebook, we use `joint` as the `execution_order` and `tpe` as the `search_algorithm`. We set the `num_samples` to 1 and `seed` to 0.\n", + "Search strategy is used to search the optimal optimization among different EPs. In this notebook, we use `joint` as the `execution_order` and `tpe` as the `sampler`. We set the `max_samples` to 1 and `seed` to 0.\n", "\n", "```json\n", "\"engine\": {\n", " \"search_strategy\": {\n", " \"execution_order\": \"joint\",\n", - " \"search_algorithm\": \"tpe\",\n", - " \"search_algorithm_config\": {\n", - " \"num_samples\": 1,\n", - " \"seed\": 0\n", - " }\n", + " \"sampler\": \"tpe\",\n", + " \"max_samples\": 1,\n", + " \"seed\": 0\n", " },\n", " \"evaluator\": \"common_evaluator\",\n", " \"host\": \"local_system\",\n", diff --git a/examples/bert/user_script.py b/examples/bert/user_script.py index 480b57276..f1bd60919 100644 --- a/examples/bert/user_script.py +++ b/examples/bert/user_script.py @@ -2,13 +2,17 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. 
# -------------------------------------------------------------------------- + +# pylint: disable=attribute-defined-outside-init, protected-access, ungrouped-imports +# This file is only used by bert_inc_ptq_cpu, bert_qat_customized_train_loop_cpu + import copy import numpy as np import torch import torchmetrics import transformers -from datasets import load_dataset, load_metric +from datasets import load_dataset from datasets.utils import logging as datasets_logging from neural_compressor.data import DefaultDataLoader from torch.utils.data import Dataset @@ -26,12 +30,14 @@ from olive.data.registry import Registry from olive.model import OliveModelHandler +try: + from datasets import load_metric +except ImportError: + from evaluate import load as load_metric + datasets_logging.disable_progress_bar() datasets_logging.set_verbosity_error() -# pylint: disable=attribute-defined-outside-init, protected-access -# This file is only used by bert_inc_ptq_cpu, bert_qat_customized_train_loop_cpu - # ------------------------------------------------------------------------- # Common Dataset diff --git a/examples/deberta/deberta.json b/examples/deberta/deberta.json index 025a0c2cf..56c8a60ae 100644 --- a/examples/deberta/deberta.json +++ b/examples/deberta/deberta.json @@ -46,7 +46,7 @@ "quantization": { "type": "OnnxQuantization", "data_config": "glue_mnli_matched" }, "session_params_tuning": { "type": "OrtSessionParamsTuning", "data_config": "glue_mnli_matched" } }, - "search_strategy": { "execution_order": "joint", "search_algorithm": "tpe", "num_samples": 3, "seed": 0 }, + "search_strategy": { "execution_order": "joint", "sampler": "tpe", "max_samples": 3, "seed": 0 }, "evaluator": "common_evaluator", "cache_dir": "cache", "output_dir": "models/microsoft-deberta" diff --git a/examples/directml/llm/config_llm.json b/examples/directml/llm/config_llm.json index a052fc3dc..8b1178df6 100644 --- a/examples/directml/llm/config_llm.json +++ b/examples/directml/llm/config_llm.json @@ -90,7 +90,7 @@ } } }, - "search_strategy": { "execution_order": "joint", "search_algorithm": "exhaustive" }, + "search_strategy": { "execution_order": "joint", "sampler": "sequential" }, "evaluator": "common_evaluator", "evaluate_input_model": false, "host": "local_system", diff --git a/examples/directml/stable_diffusion_xl/config_text_encoder.json b/examples/directml/stable_diffusion_xl/config_text_encoder.json index c131054a9..a205209f7 100644 --- a/examples/directml/stable_diffusion_xl/config_text_encoder.json +++ b/examples/directml/stable_diffusion_xl/config_text_encoder.json @@ -112,7 +112,6 @@ "keep_io_types": true } }, - "pass_flows": [ [ "convert", "optimize" ] ], "evaluator": "common_evaluator", "evaluate_input_model": false, "host": "local_system", diff --git a/examples/directml/stable_diffusion_xl/config_text_encoder_2.json b/examples/directml/stable_diffusion_xl/config_text_encoder_2.json index dffc5de9f..a26dd0833 100644 --- a/examples/directml/stable_diffusion_xl/config_text_encoder_2.json +++ b/examples/directml/stable_diffusion_xl/config_text_encoder_2.json @@ -152,7 +152,6 @@ "keep_io_types": true } }, - "pass_flows": [ [ "convert", "optimize" ] ], "evaluator": "common_evaluator", "evaluate_input_model": false, "host": "local_system", diff --git a/examples/directml/stable_diffusion_xl/config_unet.json b/examples/directml/stable_diffusion_xl/config_unet.json index 6e38b41c0..9b2c513af 100644 --- a/examples/directml/stable_diffusion_xl/config_unet.json +++ b/examples/directml/stable_diffusion_xl/config_unet.json 
@@ -96,7 +96,6 @@ "keep_io_types": true } }, - "pass_flows": [ [ "convert", "optimize" ] ], "evaluator": "common_evaluator", "evaluate_input_model": false, "host": "local_system", diff --git a/examples/directml/stable_diffusion_xl/config_vae_decoder.json b/examples/directml/stable_diffusion_xl/config_vae_decoder.json index ad22ac2b9..c6c3c493e 100644 --- a/examples/directml/stable_diffusion_xl/config_vae_decoder.json +++ b/examples/directml/stable_diffusion_xl/config_vae_decoder.json @@ -108,7 +108,6 @@ "use_gpu": true } }, - "pass_flows": [ [ "convert", "optimize" ] ], "evaluator": "common_evaluator", "evaluate_input_model": false, "host": "local_system", diff --git a/examples/directml/stable_diffusion_xl/config_vae_encoder.json b/examples/directml/stable_diffusion_xl/config_vae_encoder.json index ef889387f..fc4c8fd1d 100644 --- a/examples/directml/stable_diffusion_xl/config_vae_encoder.json +++ b/examples/directml/stable_diffusion_xl/config_vae_encoder.json @@ -87,7 +87,6 @@ "keep_io_types": true } }, - "pass_flows": [ [ "convert", "optimize" ] ], "evaluator": "common_evaluator", "evaluate_input_model": false, "host": "local_system", diff --git a/examples/directml/stable_diffusion_xl/stable_diffusion_xl.py b/examples/directml/stable_diffusion_xl/stable_diffusion_xl.py index 8bb69dca4..ac5d592ea 100644 --- a/examples/directml/stable_diffusion_xl/stable_diffusion_xl.py +++ b/examples/directml/stable_diffusion_xl/stable_diffusion_xl.py @@ -7,6 +7,7 @@ import shutil import sys import warnings +from collections import OrderedDict from pathlib import Path from typing import Dict @@ -289,9 +290,10 @@ def run_inference( def update_config_with_provider(config: Dict, provider: str, is_fp16: bool) -> Dict: + used_passes = set() if provider == "dml": # DirectML EP is the default, so no need to update config. - return config + used_passes = {"convert", "optimize"} elif provider == "cuda": if version.parse(OrtVersion) < version.parse("1.17.0"): # disable skip_group_norm fusion since there is a shape inference bug which leads to invalid models @@ -299,12 +301,14 @@ def update_config_with_provider(config: Dict, provider: str, is_fp16: bool) -> D # keep model fully in fp16 if use_fp16_fixed_vae is set if is_fp16: config["passes"]["optimize_cuda"].update({"float16": True, "keep_io_types": False}) - config["pass_flows"] = [["convert", "optimize_cuda"]] + used_passes = {"convert", "optimize_cuda"} config["systems"]["local_system"]["accelerators"][0]["execution_providers"] = ["CUDAExecutionProvider"] - return config else: raise ValueError(f"Unsupported provider: {provider}") + config["passes"] = OrderedDict([(k, v) for k, v in config["passes"].items() if k in used_passes]) + return config + def optimize( model_id: str, diff --git a/examples/llama2/llama2.py b/examples/llama2/llama2.py index 388655764..ef5abfc30 100644 --- a/examples/llama2/llama2.py +++ b/examples/llama2/llama2.py @@ -48,8 +48,14 @@ def get_args(raw_args): help="Whether to use GQA(grouped query attention) instead of MHA(multi-head attention). Only supported on gpu.", ) parser.add_argument( - "--use_gptq", - action="store_true", + "--quantize", + choices=["gptq", "blockwise", "dynamic"], + required=False, + help="Quantization method to use.", + ) + parser.add_argument( + "--precision", + choices=["fp16", "fp32"], required=False, help="Precision of the optimized model.
Only supported on gpu.", ) @@ -103,6 +109,12 @@ def main(raw_args=None): if args.use_gqa and not args.gpu: raise ValueError("GQA is only supported on gpu.") + if args.gpu and args.quantize == "dynamic": + raise ValueError("Dynamic quantization is only supported on CPU.") + + if args.gptq and not args.gpu: + raise ValueError("GPTQ is only supported on gpu.") + if args.qlora: template_json, config_name = get_qlora_config() else: @@ -173,17 +185,24 @@ def get_general_config(args): gqa = "gqa" if args.use_gqa else "mha" config_name = f"llama2_{device}_{gqa}" - # add pass flows - if not args.use_gptq: - template_json["pass_flows"] = [flow for flow in SUPPORTED_WORKFLOWS[device] if "gptq" not in flow[0]] - else: - template_json["pass_flows"] = [flow for flow in SUPPORTED_WORKFLOWS[device] if "gptq" in flow[0]] + precision = args.precision or ("fp16" if args.gpu else "fp32") + + # add pass names + used_passes = {"conversion_merged"} + used_passes.add("transformers_optimization_fp16" if precision == "fp16" else "transformers_optimization_fp32") + + if args.quantize == "gptq": + used_passes.add("gptq_quant_int4") + auto_gptq_logger = logging.getLogger("auto_gptq") auto_gptq_logger.addHandler(logging.StreamHandler(sys.stdout)) auto_gptq_logger.setLevel(logging.INFO) + elif args.quantize == "blockwise": + used_passes.add("blockwise_quant_int4") + elif args.quantize == "dynamic": + used_passes.add("onnx_dynamic_quant_int8") # remove unused passes and set gqa related configs - used_passes = {pass_name for pass_flow in SUPPORTED_WORKFLOWS[device] for pass_name in pass_flow} for pass_name in list(template_json["passes"].keys()): if pass_name not in used_passes: del template_json["passes"][pass_name] diff --git a/examples/llama2/llama2_lmeval.json b/examples/llama2/llama2_lmeval.json index b6630d25d..54ee88e5f 100644 --- a/examples/llama2/llama2_lmeval.json +++ b/examples/llama2/llama2_lmeval.json @@ -57,7 +57,7 @@ } }, "auto_optimizer_config": { "opt_level": 0, "disable_auto_optimizer": true, "precision": "fp16" }, - "search_strategy": { "execution_order": "joint", "search_algorithm": "exhaustive" }, + "search_strategy": { "execution_order": "joint", "sampler": "sequential" }, "evaluator": "evaluator", "host": "local_system", "target": "local_system", diff --git a/examples/llama2/llama2_model_builder.py b/examples/llama2/llama2_model_builder.py index f613f2c37..40263ba74 100644 --- a/examples/llama2/llama2_model_builder.py +++ b/examples/llama2/llama2_model_builder.py @@ -5,6 +5,7 @@ import argparse import json +from collections import OrderedDict from olive.workflows import run as olive_run @@ -38,10 +39,8 @@ def main(raw_args=None): template_json = json.loads(template_json_str) # add pass flows - if args.metadata_only: - template_json["pass_flows"] = [["conversion", "metadata"]] - else: - template_json["pass_flows"] = [["builder", "session_params_tuning"]] + used_passes = {"conversion", "metadata"} if args.metadata_only else {"builder", "session_params_tuning"} + template_json["passes"] = OrderedDict([(k, v) for k, v in template_json["passes"].items() if k in used_passes]) template_json["output_dir"] = f"models/{model_name}" # dump config diff --git a/examples/llama2/notebook/llama2_multiep/config_cpu.json b/examples/llama2/notebook/llama2_multiep/config_cpu.template.json similarity index 100% rename from examples/llama2/notebook/llama2_multiep/config_cpu.json rename to examples/llama2/notebook/llama2_multiep/config_cpu.template.json diff --git a/examples/llama2/notebook/llama2_multiep/config_gpu.json 
b/examples/llama2/notebook/llama2_multiep/config_gpu.template.json similarity index 100% rename from examples/llama2/notebook/llama2_multiep/config_gpu.json rename to examples/llama2/notebook/llama2_multiep/config_gpu.template.json diff --git a/examples/llama2/notebook/llama2_multiep/config_multi_ep.json b/examples/llama2/notebook/llama2_multiep/config_multi_ep.template.json similarity index 100% rename from examples/llama2/notebook/llama2_multiep/config_multi_ep.json rename to examples/llama2/notebook/llama2_multiep/config_multi_ep.template.json diff --git a/examples/llama2/notebook/llama2_multiep/llama2.py b/examples/llama2/notebook/llama2_multiep/llama2.py new file mode 100644 index 000000000..069ee8dbb --- /dev/null +++ b/examples/llama2/notebook/llama2_multiep/llama2.py @@ -0,0 +1,32 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import argparse +import json +from pathlib import Path + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--device", + choices=["cpu", "gpu", "multi_ep"], + help="Device to use", + ) + parser.add_argument( + "--quantize", + action="store_true", + help="If set, run the blockwise int4 quantization pass", + ) + args = parser.parse_args() + + input_filename = f"config_{args.device}.template.json" + with Path(input_filename).open("r") as f: + config = json.load(f) + + if not args.quantize: + del config["passes"]["blockwise_quant_int4"] + + output_filename = input_filename.replace(".template", "") + with Path(output_filename).open("w") as strm: + json.dump(config, fp=strm, indent=4) diff --git a/examples/mistral/mistral.py b/examples/mistral/mistral.py index 39ec7dd2b..2fdaa8450 100644 --- a/examples/mistral/mistral.py +++ b/examples/mistral/mistral.py @@ -108,7 +108,7 @@ def main(raw_args=None): optimized_model_dir = ( script_dir / config["output_dir"] - / "-".join(config["pass_flows"][0]) + / "-".join(config["passes"].keys()) / f"{config['output_name']}_{accelerator}-{ep_header}_model" ) diff --git a/examples/mistral/mistral_fp16.json b/examples/mistral/mistral_fp16.json index 926d13e6e..364e51630 100644 --- a/examples/mistral/mistral_fp16.json +++ b/examples/mistral/mistral_fp16.json @@ -45,7 +45,6 @@ "enable_profiling": false } }, - "pass_flows": [ [ "convert", "optimize", "session_params_tuning" ] ], "evaluate_input_model": false, "evaluator": "common_evaluator", "host": "local_system", diff --git a/examples/mistral/mistral_int4.json b/examples/mistral/mistral_int4.json index adfe1949f..16f89acbd 100644 --- a/examples/mistral/mistral_int4.json +++ b/examples/mistral/mistral_int4.json @@ -57,7 +57,6 @@ "all_tensors_to_one_file": true } }, - "pass_flows": [ [ "convert", "optimize", "quantization" ] ], "evaluate_input_model": false, "evaluator": "common_evaluator", "host": "local_system", diff --git a/examples/mobilenet/prepare_config.py b/examples/mobilenet/prepare_config.py index db299bdea..756eed3b3 100644 --- a/examples/mobilenet/prepare_config.py +++ b/examples/mobilenet/prepare_config.py @@ -2,27 +2,32 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License.
# -------------------------------------------------------------------------- +import argparse import json import platform +from collections import OrderedDict from pathlib import Path from olive.common.constants import OS -def raw_qnn_config(): +def raw_qnn_config(mode: str): # pylint: disable=redefined-outer-name - with Path("./raw_qnn_sdk_template.json").open("r") as f: + with Path("raw_qnn_sdk_template.json").open("r") as f: raw_qnn_config = json.load(f) sys_platform = platform.system() + if mode == "convert": + used_passes = {"converter", "build_model_lib"} + elif mode == "quantize": + used_passes = {"quantization", "build_model_lib"} + else: + raise ValueError(f"Unsupported mode: {mode}") + if sys_platform == OS.LINUX: - raw_qnn_config["passes"]["qnn_context_binary"] = { - "type": "QNNContextBinaryGenerator", - "backend": "libQnnHtp.so", - } - raw_qnn_config["pass_flows"].append(["converter", "build_model_lib", "qnn_context_binary"]) + used_passes.add("qnn_context_binary") raw_qnn_config["passes"]["build_model_lib"]["lib_targets"] = "x86_64-linux-clang" elif sys_platform == OS.WINDOWS: raw_qnn_config["passes"]["build_model_lib"]["lib_targets"] = "x86_64-windows-msvc" @@ -33,11 +38,20 @@ def raw_qnn_config(): elif sys_platform == OS.LINUX: metric_config["user_config"]["inference_settings"]["qnn"]["backend"] = "libQnnCpu" + raw_qnn_config["passes"] = OrderedDict([(k, v) for k, v in raw_qnn_config["passes"].items() if k in used_passes]) + with Path("raw_qnn_sdk_config.json").open("w") as f: json_str = json.dumps(raw_qnn_config, indent=4) - input_file_path = Path("./data/eval/input_order.txt").resolve().as_posix() + input_file_path = Path("data/eval/input_order.txt").resolve().as_posix() f.write(json_str.replace("", str(input_file_path))) if __name__ == "__main__": - raw_qnn_config() + parser = argparse.ArgumentParser() + parser.add_argument( + "--mode", + choices=["convert", "quantize"], + help="Mode selection", + ) + args = parser.parse_args() + raw_qnn_config(args.mode) diff --git a/examples/mobilenet/raw_qnn_sdk_template.json b/examples/mobilenet/raw_qnn_sdk_template.json index c94d40837..e4f5d1eeb 100644 --- a/examples/mobilenet/raw_qnn_sdk_template.json +++ b/examples/mobilenet/raw_qnn_sdk_template.json @@ -46,9 +46,9 @@ "dynamic_shape_to_fixed": { "type": "DynamicToFixedShape", "dim_param": [ "batch_size" ], "dim_value": [ 1 ] }, "converter": { "type": "QNNConversion" }, "quantization": { "type": "QNNConversion", "extra_args": "--input_list " }, - "build_model_lib": { "type": "QNNModelLibGenerator", "lib_targets": "x86_64-linux-clang" } + "build_model_lib": { "type": "QNNModelLibGenerator", "lib_targets": "x86_64-linux-clang" }, + "qnn_context_binary": { "type": "QNNContextBinaryGenerator", "backend": "libQnnHtp.so" } }, - "pass_flows": [ [ "converter", "build_model_lib" ], [ "quantization", "build_model_lib" ] ], "log_severity_level": 0, "host": "local_system", "target": "local_system", diff --git a/examples/phi2/phi2.py b/examples/phi2/phi2.py index 1f9370a6a..bd2ef06d7 100644 --- a/examples/phi2/phi2.py +++ b/examples/phi2/phi2.py @@ -186,25 +186,23 @@ def main(raw_args=None): legacy_optimization_setting(template_json) # add pass flows - pass_flows = [[]] + used_passes = set() if args.finetune_method: - pass_flows[0].append(args.finetune_method) + used_passes.add(args.finetune_method) # torch fine tuning does not require execution provider, just set it to CUDAExecutionProvider update_accelerator(template_json, "gpu") if args.slicegpt: -
pass_flows[0].extend(SUPPORTED_WORKFLOWS["slicegpt"][0]) + used_passes.update(SUPPORTED_WORKFLOWS["slicegpt"][0]) update_accelerator(template_json, "gpu") del template_json["input_model"]["io_config"] if model_type: - pass_flows[0].extend(SUPPORTED_WORKFLOWS[model_type][0]) - template_json["pass_flows"] = pass_flows + used_passes.update(SUPPORTED_WORKFLOWS[model_type][0]) if args.optimum_optimization: legacy_optimization_setting(template_json) - for pass_flow in template_json["pass_flows"]: - pass_flow[0] = "optimum_convert" - if "session_params_tuning" in pass_flow: - pass_flow.remove("session_params_tuning") + used_passes.discard("convert") + used_passes.discard("session_params_tuning") + used_passes.add("optimum_convert") if "cuda" in model_type: update_accelerator(template_json, "gpu") @@ -217,7 +215,6 @@ def main(raw_args=None): template_json["evaluate_input_model"] = False del template_json["evaluator"] - used_passes = {pass_name for pass_flow in pass_flows for pass_name in pass_flow} for pass_name in list(template_json["passes"].keys()): if pass_name not in used_passes: del template_json["passes"][pass_name] diff --git a/examples/phi3/phi3.py b/examples/phi3/phi3.py index 16277dd82..4d2e114b7 100644 --- a/examples/phi3/phi3.py +++ b/examples/phi3/phi3.py @@ -191,7 +191,10 @@ def use_passes(template_json, *passes): else: del template_json["data_configs"] - template_json["pass_flows"] = [passes] + for pass_name in set(template_json["passes"].keys()): + if pass_name not in passes: + template_json["passes"].pop(pass_name, None) + return template_json @@ -223,7 +226,7 @@ def generate_config(args): if args.tune_session_params: passes_to_use.append("tune_session_params") - template_json["search_strategy"] = {"execution_order": "joint", "search_algorithm": "exhaustive"} + template_json["search_strategy"] = {"execution_order": "joint", "sampler": "sequential"} template_json["evaluator"] = "common_evaluator" else: del template_json["evaluators"] diff --git a/examples/phi3/phi3_template.json b/examples/phi3/phi3_template.json index e5508d878..04e7151cf 100644 --- a/examples/phi3/phi3_template.json +++ b/examples/phi3/phi3_template.json @@ -105,7 +105,6 @@ "execution_mode_list": [ 0, 1 ] } }, - "pass_flows": [ [ "" ] ], "cache_dir": "cache", "output_dir": "models", "host": "local_system", diff --git a/examples/resnet/resnet_dynamic_ptq_cpu.json b/examples/resnet/resnet_dynamic_ptq_cpu.json index 17a2c65d6..9351eadf8 100644 --- a/examples/resnet/resnet_dynamic_ptq_cpu.json +++ b/examples/resnet/resnet_dynamic_ptq_cpu.json @@ -56,7 +56,7 @@ "onnx_dynamic_quantization": { "type": "OnnxDynamicQuantization", "weight_type": "QUInt8" }, "session_params_tuning": { "type": "OrtSessionParamsTuning", "data_config": "cifar10_data_config" } }, - "search_strategy": { "execution_order": "joint", "search_algorithm": "exhaustive" }, + "search_strategy": { "execution_order": "joint", "sampler": "sequential" }, "host": "local_system", "target": "local_system", "evaluator": "common_evaluator", diff --git a/examples/resnet/resnet_ptq_cpu.json b/examples/resnet/resnet_ptq_cpu.json index 1c0922d3e..fdfaa07d4 100644 --- a/examples/resnet/resnet_ptq_cpu.json +++ b/examples/resnet/resnet_ptq_cpu.json @@ -72,7 +72,7 @@ }, "session_params_tuning": { "type": "OrtSessionParamsTuning", "data_config": "cifar10_data_config" } }, - "search_strategy": { "execution_order": "joint", "search_algorithm": "exhaustive" }, + "search_strategy": { "execution_order": "joint", "sampler": "sequential" }, "host": "local_system",
"target": "local_system", "evaluator": "common_evaluator", diff --git a/examples/resnet/resnet_ptq_cpu_aml_dataset.json b/examples/resnet/resnet_ptq_cpu_aml_dataset.json index 1ee41bb03..7a05b4f38 100644 --- a/examples/resnet/resnet_ptq_cpu_aml_dataset.json +++ b/examples/resnet/resnet_ptq_cpu_aml_dataset.json @@ -81,7 +81,7 @@ }, "session_params_tuning": { "type": "OrtSessionParamsTuning", "data_config": "cifar10_data_config" } }, - "search_strategy": { "execution_order": "joint", "search_algorithm": "exhaustive" }, + "search_strategy": { "execution_order": "joint", "sampler": "sequential" }, "evaluator": "common_evaluator", "cache_dir": "cache", "output_dir": "models/resnet_ptq_cpu" diff --git a/examples/resnet/resnet_static_ptq_cpu.json b/examples/resnet/resnet_static_ptq_cpu.json index 65e7b2950..771d8e5f4 100644 --- a/examples/resnet/resnet_static_ptq_cpu.json +++ b/examples/resnet/resnet_static_ptq_cpu.json @@ -65,7 +65,7 @@ "data_config": "session_params_tuning_data_config" } }, - "search_strategy": { "execution_order": "joint", "search_algorithm": "exhaustive" }, + "search_strategy": { "execution_order": "joint", "sampler": "sequential" }, "host": "local_system", "target": "local_system", "evaluator": "common_evaluator", diff --git a/examples/stable_diffusion/config_safety_checker.json b/examples/stable_diffusion/config_safety_checker.json index fc36d6c8c..33bb6fcd4 100644 --- a/examples/stable_diffusion/config_safety_checker.json +++ b/examples/stable_diffusion/config_safety_checker.json @@ -88,7 +88,6 @@ "keep_io_types": false } }, - "pass_flows": [ [ "convert", "optimize" ] ], "log_severity_level": 0, "evaluator": "common_evaluator", "evaluate_input_model": false, diff --git a/examples/stable_diffusion/config_text_encoder.json b/examples/stable_diffusion/config_text_encoder.json index 3747e4f80..baa6de7d3 100644 --- a/examples/stable_diffusion/config_text_encoder.json +++ b/examples/stable_diffusion/config_text_encoder.json @@ -85,7 +85,6 @@ "keep_io_types": false } }, - "pass_flows": [ [ "convert", "optimize" ] ], "log_severity_level": 0, "evaluator": "common_evaluator", "evaluate_input_model": false, diff --git a/examples/stable_diffusion/config_unet.json b/examples/stable_diffusion/config_unet.json index dfad5e88c..2ee455b9b 100644 --- a/examples/stable_diffusion/config_unet.json +++ b/examples/stable_diffusion/config_unet.json @@ -100,7 +100,6 @@ "keep_io_types": false } }, - "pass_flows": [ [ "convert", "optimize" ] ], "log_severity_level": 0, "evaluator": "common_evaluator", "evaluate_input_model": false, diff --git a/examples/stable_diffusion/config_vae_decoder.json b/examples/stable_diffusion/config_vae_decoder.json index 362f49cb9..5288528c9 100644 --- a/examples/stable_diffusion/config_vae_decoder.json +++ b/examples/stable_diffusion/config_vae_decoder.json @@ -92,7 +92,6 @@ "keep_io_types": false } }, - "pass_flows": [ [ "convert", "optimize" ] ], "log_severity_level": 0, "evaluator": "common_evaluator", "evaluate_input_model": false, diff --git a/examples/stable_diffusion/config_vae_encoder.json b/examples/stable_diffusion/config_vae_encoder.json index 61e46d298..a038ecdcb 100644 --- a/examples/stable_diffusion/config_vae_encoder.json +++ b/examples/stable_diffusion/config_vae_encoder.json @@ -87,7 +87,6 @@ "keep_io_types": false } }, - "pass_flows": [ [ "convert", "optimize" ] ], "log_severity_level": 0, "evaluator": "common_evaluator", "evaluate_input_model": false, diff --git a/examples/stable_diffusion/sd_utils/ort.py b/examples/stable_diffusion/sd_utils/ort.py 
index 01eab8923..f5852c442 100644 --- a/examples/stable_diffusion/sd_utils/ort.py +++ b/examples/stable_diffusion/sd_utils/ort.py @@ -12,18 +12,30 @@ from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline from onnxruntime import __version__ as OrtVersion from packaging import version -from sd_utils import config +from sd_utils import config as sd_config from olive.model import ONNXModelHandler # ruff: noqa: TID252, T201 +def update_dml_config(config_dml: Dict): + used_passes = {"convert", "optimize"} + for pass_name in set(config_dml["passes"].keys()): + if pass_name not in used_passes: + config_dml["passes"].pop(pass_name, None) + config_dml["systems"]["local_system"]["accelerators"][0]["execution_providers"] = ["DmlExecutionProvider"] + return config_dml + + def update_cuda_config(config_cuda: Dict): if version.parse(OrtVersion) < version.parse("1.17.0"): # disable skip_group_norm fusion since there is a shape inference bug which leads to invalid models config_cuda["passes"]["optimize_cuda"]["optimization_options"] = {"enable_skip_group_norm": False} - config_cuda["pass_flows"] = [["convert", "optimize_cuda"]] + used_passes = {"convert", "optimize_cuda"} + for pass_name in set(config_cuda["passes"].keys()): + if pass_name not in used_passes: + config_cuda["passes"].pop(pass_name, None) config_cuda["systems"]["local_system"]["accelerators"][0]["execution_providers"] = ["CUDAExecutionProvider"] return config_cuda @@ -134,8 +146,8 @@ def get_ort_pipeline(model_dir, common_args, ort_args, guidance_scale): batch_size = common_args.batch_size image_size = common_args.image_size provider = common_args.provider - vae_sample_size = config.vae_sample_size - unet_sample_size = config.unet_sample_size + vae_sample_size = sd_config.vae_sample_size + unet_sample_size = sd_config.unet_sample_size if static_dims: hidden_batch_size = batch_size if (guidance_scale == 0.0) else batch_size * 2 diff --git a/examples/stable_diffusion/sd_utils/ov.py b/examples/stable_diffusion/sd_utils/ov.py index 95905227f..5c17e228c 100644 --- a/examples/stable_diffusion/sd_utils/ov.py +++ b/examples/stable_diffusion/sd_utils/ov.py @@ -442,7 +442,7 @@ def get_timesteps(self, num_inference_steps: int, strength: float): def update_ov_config(config: Dict): - config["pass_flows"] = [["ov_convert"]] + config["passes"] = {"ov_convert": config["passes"]["ov_convert"]} config["search_strategy"] = False config["systems"]["local_system"]["accelerators"][0]["execution_providers"] = ["CPUExecutionProvider"] del config["evaluators"] diff --git a/examples/stable_diffusion/stable_diffusion.py b/examples/stable_diffusion/stable_diffusion.py index 2cd45826a..8cc0f6a40 100644 --- a/examples/stable_diffusion/stable_diffusion.py +++ b/examples/stable_diffusion/stable_diffusion.py @@ -179,8 +179,9 @@ def on_generate_click(): def update_config_with_provider(config: Dict, provider: str): if provider == "dml": - # DirectML EP is the default, so no need to update config. - return config + from sd_utils.ort import update_dml_config + + return update_dml_config(config) elif provider == "cuda": from sd_utils.ort import update_cuda_config diff --git a/examples/test/__init__.py b/examples/test/__init__.py index e69de29bb..862c45ce3 100644 --- a/examples/test/__init__.py +++ b/examples/test/__init__.py @@ -0,0 +1,4 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- diff --git a/examples/test/azureml/__init__.py b/examples/test/azureml/__init__.py index e69de29bb..862c45ce3 100644 --- a/examples/test/azureml/__init__.py +++ b/examples/test/azureml/__init__.py @@ -0,0 +1,4 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- diff --git a/examples/test/azureml/test_bert_ptq_cpu_aml.py b/examples/test/azureml/test_bert_ptq_cpu_aml.py index 0e9384ebf..f01378d61 100644 --- a/examples/test/azureml/test_bert_ptq_cpu_aml.py +++ b/examples/test/azureml/test_bert_ptq_cpu_aml.py @@ -27,7 +27,7 @@ def setup(): ], ) def test_bert(olive_test_knob): - # olive_config: (config_json_path, search_algorithm, execution_order, system) + # olive_config: (config_json_path, sampler, execution_order, system) # bert_ptq_cpu.json: use huggingface model id # bert_ptq_cpu_aml.json: use aml model path from olive.workflows import run as olive_run diff --git a/examples/test/azureml/test_llama2.py b/examples/test/azureml/test_llama2.py index 85da88e71..1e5c17fa2 100644 --- a/examples/test/azureml/test_llama2.py +++ b/examples/test/azureml/test_llama2.py @@ -22,15 +22,15 @@ def setup(): os.chdir(get_example_dir("llama2")) -@pytest.mark.parametrize("search_algorithm", [False]) +@pytest.mark.parametrize("sampler", [False]) @pytest.mark.parametrize("execution_order", [None]) @pytest.mark.parametrize("system", ["local_system"]) @pytest.mark.parametrize("cache_config", [None, {"account_name": account_name, "container_name": container_name}]) @pytest.mark.parametrize("olive_json", ["llama2_qlora.json"]) -def test_llama2(search_algorithm, execution_order, system, cache_config, olive_json): +def test_llama2(sampler, execution_order, system, cache_config, olive_json): from olive.workflows import run as olive_run - olive_config = patch_config(olive_json, search_algorithm, execution_order, system, is_gpu=False, hf_token=True) + olive_config = patch_config(olive_json, sampler, execution_order, system, is_gpu=False, hf_token=True) # reduce qlora steps for faster test olive_config["passes"]["f"]["training_args"]["max_steps"] = 5 diff --git a/examples/test/azureml/test_resnet_ptq_cpu_aml.py b/examples/test/azureml/test_resnet_ptq_cpu_aml.py index 8550a0e24..402eff8d6 100644 --- a/examples/test/azureml/test_resnet_ptq_cpu_aml.py +++ b/examples/test/azureml/test_resnet_ptq_cpu_aml.py @@ -23,7 +23,7 @@ def setup(): retry_func(run_subprocess, kwargs={"cmd": "python prepare_model_data.py", "check": True}) -@pytest.mark.parametrize("search_algorithm", ["random"]) +@pytest.mark.parametrize("sampler", ["random"]) @pytest.mark.parametrize("execution_order", ["pass-by-pass"]) @pytest.mark.parametrize("system", ["aml_system"]) @pytest.mark.parametrize( @@ -40,10 +40,10 @@ def setup(): version.parse(OrtVersion) == version.parse("1.16.0"), reason="resnet is not supported in ORT 1.16.0 caused by https://github.com/microsoft/onnxruntime/issues/17627", ) -def test_resnet(search_algorithm, execution_order, system, olive_json): +def test_resnet(sampler, execution_order, system, olive_json): from olive.workflows import run as olive_run - olive_config = patch_config(olive_json, search_algorithm, execution_order, system) + olive_config = patch_config(olive_json, sampler, execution_order, system) footprint = olive_run(olive_config, 
tempdir=os.environ.get("OLIVE_TEMPDIR", None)) check_output(footprint) diff --git a/examples/test/local/__init__.py b/examples/test/local/__init__.py index e69de29bb..862c45ce3 100644 --- a/examples/test/local/__init__.py +++ b/examples/test/local/__init__.py @@ -0,0 +1,4 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- diff --git a/examples/test/local/test_bert_cuda_gpu.py b/examples/test/local/test_bert_cuda_gpu.py index a5c6b0d82..698d72695 100644 --- a/examples/test/local/test_bert_cuda_gpu.py +++ b/examples/test/local/test_bert_cuda_gpu.py @@ -6,6 +6,8 @@ import pytest +from olive.common.utils import retry_func, run_subprocess + from ..utils import check_output, get_example_dir, patch_config @@ -16,14 +18,17 @@ def setup(): @pytest.mark.skip(reason="Disable failing tests") -@pytest.mark.parametrize("search_algorithm", ["tpe"]) +@pytest.mark.parametrize("sampler", ["tpe"]) @pytest.mark.parametrize("execution_order", ["joint", "pass-by-pass"]) @pytest.mark.parametrize("system", ["local_system"]) @pytest.mark.parametrize("olive_json", ["bert_cuda_gpu.json"]) -def test_bert(search_algorithm, execution_order, system, olive_json): +@pytest.mark.parametrize("cmd_args", ["", "--optimize"]) +def test_bert(sampler, execution_order, system, olive_json, cmd_args): from olive.workflows import run as olive_run - olive_config = patch_config(olive_json, search_algorithm, execution_order, system, is_gpu=True) + retry_func(run_subprocess, kwargs={"cmd": f"python bert.py {cmd_args}", "check": True}) + + olive_config = patch_config(olive_json, sampler, execution_order, system, is_gpu=True) olive_config["passes"]["session_params_tuning"]["enable_cuda_graph"] = False footprint = olive_run(olive_config, tempdir=os.environ.get("OLIVE_TEMPDIR", None)) diff --git a/examples/test/local/test_bert_ptq_cpu.py b/examples/test/local/test_bert_ptq_cpu.py index ee2975c44..9c2e9bdab 100644 --- a/examples/test/local/test_bert_ptq_cpu.py +++ b/examples/test/local/test_bert_ptq_cpu.py @@ -15,14 +15,14 @@ def setup(): os.chdir(get_example_dir("bert")) -@pytest.mark.parametrize("search_algorithm", ["tpe"]) +@pytest.mark.parametrize("sampler", ["tpe"]) @pytest.mark.parametrize("execution_order", ["joint"]) @pytest.mark.parametrize("system", ["local_system"]) @pytest.mark.parametrize("olive_json", ["bert_ptq_cpu.json"]) -def test_bert(search_algorithm, execution_order, system, olive_json): +def test_bert(sampler, execution_order, system, olive_json): from olive.workflows import run as olive_run - olive_config = patch_config(olive_json, search_algorithm, execution_order, system) + olive_config = patch_config(olive_json, sampler, execution_order, system) footprint = olive_run(olive_config, tempdir=os.environ.get("OLIVE_TEMPDIR", None)) check_output(footprint) diff --git a/examples/test/local/test_bert_ptq_cpu_docker.py b/examples/test/local/test_bert_ptq_cpu_docker.py index 1f812cd9c..27958daab 100644 --- a/examples/test/local/test_bert_ptq_cpu_docker.py +++ b/examples/test/local/test_bert_ptq_cpu_docker.py @@ -18,17 +18,17 @@ def setup(): os.chdir(get_example_dir("bert")) -@pytest.mark.parametrize("search_algorithm", ["tpe"]) +@pytest.mark.parametrize("sampler", ["tpe"]) @pytest.mark.parametrize("execution_order", ["joint"]) @pytest.mark.parametrize("system", ["docker_system"]) @pytest.mark.parametrize("olive_json", 
["bert_ptq_cpu.json"]) -def test_bert(search_algorithm, execution_order, system, olive_json): +def test_bert(sampler, execution_order, system, olive_json): if system == "docker_system" and platform.system() == OS.WINDOWS: pytest.skip("Skip Linux containers on Windows host test case.") from olive.workflows import run as olive_run - olive_config = patch_config(olive_json, search_algorithm, execution_order, system) + olive_config = patch_config(olive_json, sampler, execution_order, system) footprint = olive_run(olive_config, tempdir=os.environ.get("OLIVE_TEMPDIR", None)) check_output(footprint) diff --git a/examples/test/local/test_llama2.py b/examples/test/local/test_llama2.py index 2ec2b1691..ee47a475d 100644 --- a/examples/test/local/test_llama2.py +++ b/examples/test/local/test_llama2.py @@ -16,16 +16,16 @@ def setup(): os.chdir(get_example_dir("llama2")) -@pytest.mark.parametrize("search_algorithm", [False]) +@pytest.mark.parametrize("sampler", [False]) @pytest.mark.parametrize("execution_order", [None]) @pytest.mark.parametrize("system", ["local_system"]) @pytest.mark.parametrize("olive_json", ["llama2_qlora.json"]) -def test_llama2(search_algorithm, execution_order, system, olive_json): +def test_llama2(sampler, execution_order, system, olive_json): from onnxruntime import __version__ as ort_version from olive.workflows import run as olive_run - olive_config = patch_config(olive_json, search_algorithm, execution_order, system) + olive_config = patch_config(olive_json, sampler, execution_order, system) # replace meta-llama with open-llama version of the model # doesn't require login diff --git a/examples/test/local/test_qnn_tooklit.py b/examples/test/local/test_qnn_tooklit.py index 2fe96d1bf..4886e6faa 100644 --- a/examples/test/local/test_qnn_tooklit.py +++ b/examples/test/local/test_qnn_tooklit.py @@ -24,7 +24,7 @@ def setup(self, tmp_path): os.environ["QNN_SDK_ROOT"] = f"{download_qc_toolkit(tmp_path, 'qnn')}/opt/qcom/aistack" os.environ["CONDA_INSTALLER"] = download_conda_installer(tmp_path) - def _setup_resource(self, use_olive_env): + def _setup_resource(self, use_olive_env, mode): """Setups any state specific to the execution of the given module.""" example_dir = get_example_dir("mobilenet") os.chdir(example_dir) @@ -68,13 +68,14 @@ def _setup_resource(self, use_olive_env): + os.environ["PATH"] ) retry_func(run_subprocess, kwargs={"cmd": "python download_files.py", "check": True}) - retry_func(run_subprocess, kwargs={"cmd": "python prepare_config.py", "check": True}) + retry_func(run_subprocess, kwargs={"cmd": f"python prepare_config.py --mode {mode}", "check": True}) @pytest.mark.parametrize("use_olive_env", [True, False]) - def test_mobilenet_qnn(self, use_olive_env): + @pytest.mark.parametrize("mode", ["convert", "quantize"]) + def test_mobilenet_qnn(self, use_olive_env, mode): from olive.workflows import run as olive_run - self._setup_resource(use_olive_env) + self._setup_resource(use_olive_env, mode) footprint = olive_run("raw_qnn_sdk_config.json", tempdir=os.environ.get("OLIVE_TEMPDIR", None)) check_output(footprint) diff --git a/examples/test/local/test_resnet_ptq_cpu.py b/examples/test/local/test_resnet_ptq_cpu.py index 0f7528167..5bccfe731 100644 --- a/examples/test/local/test_resnet_ptq_cpu.py +++ b/examples/test/local/test_resnet_ptq_cpu.py @@ -23,7 +23,7 @@ def setup(): retry_func(run_subprocess, kwargs={"cmd": "python prepare_model_data.py", "check": True}) -@pytest.mark.parametrize("search_algorithm", ["random"]) +@pytest.mark.parametrize("sampler", ["random"]) 
@pytest.mark.parametrize("execution_order", ["pass-by-pass"]) @pytest.mark.parametrize("system", ["local_system"]) @pytest.mark.parametrize( @@ -40,10 +40,10 @@ def setup(): version.parse(OrtVersion) == version.parse("1.16.0"), reason="resnet is not supported in ORT 1.16.0 caused by https://github.com/microsoft/onnxruntime/issues/17627", ) -def test_resnet(search_algorithm, execution_order, system, olive_json): +def test_resnet(sampler, execution_order, system, olive_json): from olive.workflows import run as olive_run - olive_config = patch_config(olive_json, search_algorithm, execution_order, system) + olive_config = patch_config(olive_json, sampler, execution_order, system) footprint = olive_run(olive_config, tempdir=os.environ.get("OLIVE_TEMPDIR", None)) check_output(footprint) diff --git a/examples/test/local/test_resnet_qat.py b/examples/test/local/test_resnet_qat.py index afdfc9efc..8c745876e 100644 --- a/examples/test/local/test_resnet_qat.py +++ b/examples/test/local/test_resnet_qat.py @@ -21,16 +21,16 @@ def setup(): retry_func(run_subprocess, kwargs={"cmd": "python prepare_model_data.py", "check": True}) -@pytest.mark.parametrize("search_algorithm", ["random"]) +@pytest.mark.parametrize("sampler", ["random"]) @pytest.mark.parametrize("execution_order", ["pass-by-pass"]) @pytest.mark.parametrize("system", ["local_system"]) @pytest.mark.parametrize( "olive_json", ["resnet_qat_default_train_loop_cpu.json", "resnet_qat_lightning_module_cpu.json"] ) -def test_resnet(search_algorithm, execution_order, system, olive_json): +def test_resnet(sampler, execution_order, system, olive_json): from olive.workflows import run as olive_run - olive_config = patch_config(olive_json, search_algorithm, execution_order, system) + olive_config = patch_config(olive_json, sampler, execution_order, system) footprint = olive_run(olive_config, tempdir=os.environ.get("OLIVE_TEMPDIR", None)) check_output(footprint) diff --git a/examples/test/utils.py b/examples/test/utils.py index 77e67e247..462e07705 100644 --- a/examples/test/utils.py +++ b/examples/test/utils.py @@ -35,7 +35,7 @@ def assert_metrics(footprints): def patch_config( config_json_path: str, - search_algorithm: str, + sampler: str, execution_order: str, system: str, is_gpu: bool = False, @@ -50,15 +50,15 @@ def patch_config( olive_config["clean_cache"] = True # update search strategy - if not search_algorithm: + if not sampler: olive_config["search_strategy"] = False else: olive_config["search_strategy"] = { - "search_algorithm": search_algorithm, + "sampler": sampler, "execution_order": execution_order, } - if search_algorithm in ("random", "tpe"): - olive_config["search_strategy"].update({"num_samples": 3, "seed": 0}) + if sampler in ("random", "tpe"): + olive_config["search_strategy"].update({"max_samples": 3, "seed": 0}) update_azureml_config(olive_config) if system == "aml_system": @@ -76,7 +76,7 @@ def patch_config( # as our docker image is big, we need to reduce the agent size to avoid timeout # for the docker system test, we skip to search for transformers optimization as # it is tested in other olive system tests - olive_config["search_strategy"]["num_samples"] = 2 + olive_config["search_strategy"]["max_samples"] = 2 return olive_config diff --git a/olive/cache.py b/olive/cache.py index 23384adcd..0029c52b8 100644 --- a/olive/cache.py +++ b/olive/cache.py @@ -280,7 +280,7 @@ def get_output_model_id( input_model_id: str, accelerator_spec: "AcceleratorSpec" = None, ): - run_json = self.get_run_json(pass_name, pass_config, input_model_id, 
accelerator_spec) + run_json = self.get_run_json(pass_name.lower(), pass_config, input_model_id, accelerator_spec) return hash_dict(run_json)[:8] def get_cache_dir(self) -> Path: diff --git a/olive/cli/auto_opt.py b/olive/cli/auto_opt.py index 664fae4cc..3d75469b0 100644 --- a/olive/cli/auto_opt.py +++ b/olive/cli/auto_opt.py @@ -412,7 +412,7 @@ def _get_passes_config(self, config: Dict[str, Any], olive_config: OlivePackageC TEMPLATE = { "input_model": {"type": "HfModel"}, "auto_optimizer_config": {}, - "search_strategy": {"execution_order": "joint", "search_algorithm": "tpe", "num_samples": 5, "seed": 0}, + "search_strategy": {"execution_order": "joint", "sampler": "tpe", "max_samples": 5, "seed": 0}, "systems": { "local_system": { "type": "LocalSystem", diff --git a/olive/cli/base.py b/olive/cli/base.py index 9ccbca71f..320dbd8e9 100644 --- a/olive/cli/base.py +++ b/olive/cli/base.py @@ -621,16 +621,16 @@ def add_search_options(sub_parser: ArgumentParser): "--enable_search", type=str, default=None, - const="exhaustive", + const="sequential", nargs="?", - choices=["exhaustive", "tpe", "random"], + choices=["random", "sequential", "tpe"], help=( "Enable search to produce optimal model for the given criteria. " - "Optionally provide search algorithm from available choices. " - "Use exhastive search algorithm by default." + "Optionally provide a sampler from the available choices. " + "Uses the sequential sampler by default." ), ) - search_strategy_group.add_argument("--seed", type=int, default=0, help="Random seed for search algorithm") + search_strategy_group.add_argument("--seed", type=int, default=0, help="Random seed for the search sampler") def update_search_options(args, config): @@ -641,7 +641,7 @@ def update_search_options(args, config): "search_strategy", { "execution_order": "joint", - "search_algorithm": args.enable_search, + "sampler": args.enable_search, "seed": args.seed, }, ), diff --git a/olive/cli/quantize.py b/olive/cli/quantize.py index f11e54e23..16298ed27 100644 --- a/olive/cli/quantize.py +++ b/olive/cli/quantize.py @@ -7,6 +7,7 @@ # ruff: noqa: RUF012 from argparse import ArgumentParser +from collections import OrderedDict from copy import deepcopy from typing import Any, Dict @@ -127,7 +128,6 @@ def _get_run_config(self, tempdir: str) -> Dict[str, Any]: self.args.implementation = [self.args.implementation] to_replace = [ - ("pass_flows", [self.args.implementation]), (("passes", "awq", "w_bit"), precision), (("passes", "gptq", "bits"), precision), (("passes", "bnb4", "quant_type"), precision), @@ -143,6 +143,7 @@ def _get_run_config(self, tempdir: str) -> Dict[str, Any]: if v is not None: set_nested_dict_value(config, k, v) + config["passes"] = OrderedDict([(k, v) for k, v in config["passes"].items() if k in self.args.implementation]) return config def run(self): @@ -183,7 +184,6 @@ def run(self): # "inc_static": {"type": "IncStaticQuantization", "data_config": "default_data_config"}, # "vitis": {"type": "VitisAIQuantization", "data_config": "default_data_config"}, }, - "pass_flows": [], "output_dir": "models", "host": "local_system", "target": "local_system", diff --git a/olive/engine/config.py b/olive/engine/config.py index a64d5d189..db820b952 100644 --- a/olive/engine/config.py +++ b/olive/engine/config.py @@ -5,9 +5,10 @@ from typing import Union from olive.common.config_utils import ConfigBase -from olive.common.pydantic_v1 import Extra +from olive.common.pydantic_v1 import Extra, Field from olive.evaluator.olive_evaluator import OliveEvaluatorConfig -from 
olive.strategy.search_strategy import SearchStrategyConfig +from olive.passes.pass_config import AbstractPassConfig +from olive.search.search_strategy import SearchStrategyConfig from olive.systems.system_config import SystemConfig # pass search-point was pruned due to failed run @@ -25,3 +26,38 @@ class EngineConfig(ConfigBase, extra=Extra.forbid): evaluator: OliveEvaluatorConfig = None plot_pareto_frontier: bool = False no_artifacts: bool = False + + +class RunPassConfig(AbstractPassConfig): + """Pass configuration for an Olive workflow. + + This is the configuration for a single pass in an Olive workflow. It includes the pass type and config along + with optional host and evaluator overrides. + + Example: + .. code-block:: json + + { + "type": "OlivePass", + "config": { + "param1": "value1", + "param2": "value2" + } + } + + """ + + host: Union[SystemConfig, str] = Field( + None, + description=( + "Host system for the pass. If it is a string, must refer to a system config under the `systems` section." + " If not provided, use the engine's host system." + ), + ) + evaluator: Union[OliveEvaluatorConfig, str] = Field( + None, + description=( + "Evaluator for the pass. If it is a string, must refer to an evaluator config under the `evaluators`" + " section. If not provided, use the engine's evaluator." + ), + ) diff --git a/olive/engine/engine.py b/olive/engine/engine.py index e0f208836..f7c824364 100644 --- a/olive/engine/engine.py +++ b/olive/engine/engine.py @@ -10,12 +10,12 @@ from contextlib import contextmanager from datetime import datetime from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union from olive.cache import CacheConfig, OliveCache from olive.common.config_utils import validate_config from olive.common.constants import DEFAULT_WORKFLOW_ID, LOCAL_INPUT_MODEL_ID -from olive.engine.config import FAILED_CONFIG, INVALID_CONFIG, PRUNED_CONFIGS +from olive.engine.config import FAILED_CONFIG, INVALID_CONFIG, PRUNED_CONFIGS, RunPassConfig from olive.engine.footprint import Footprint, FootprintNode, FootprintNodeMetric, get_best_candidate_node from olive.engine.packaging.packaging_generator import generate_output_artifacts from olive.evaluator.metric import Metric @@ -26,8 +26,8 @@ from olive.logging import enable_filelog from olive.model import ModelConfig from olive.package_config import OlivePackageConfig -from olive.strategy.search_parameter import SearchParameter -from olive.strategy.search_strategy import SearchStrategy, SearchStrategyConfig +from olive.search.search_sample import SearchSample +from olive.search.search_strategy import SearchStrategy, SearchStrategyConfig from olive.systems.common import SystemType from olive.systems.system_config import SystemConfig from olive.systems.utils import create_managed_system_with_cache @@ -35,7 +35,7 @@ if TYPE_CHECKING: from olive.engine.packaging.packaging_config import PackagingConfig from olive.passes.olive_pass import Pass - from olive.systems.olive_system import OliveSystem + from olive.search.search_parameter import SearchParameter logger = logging.getLogger(__name__) @@ -84,10 +84,9 @@ def __init__( self.skip_saving_artifacts = no_artifacts self.azureml_client_config = azureml_client_config - self.pass_run_configs: Dict[str, Dict[str, Any]] = OrderedDict() - self.pass_flows: List[List[str]] = [] - self.search_spaces: List[List[Tuple[str, Dict[str, SearchParameter]]]] = [] - self.footprints = defaultdict(Footprint) + self.input_passes_configs: 
Dict[str, List[RunPassConfig]] = OrderedDict() + self.computed_passes_configs: Dict[str, RunPassConfig] = OrderedDict() + self.footprints: Dict[AcceleratorSpec, Footprint] = defaultdict(Footprint) self._initialized = False @@ -108,18 +107,17 @@ def initialize(self, log_to_file: bool = False, log_severity_level: int = 1): if self.evaluator_config: self.evaluator_config = self.cache.prepare_resources_for_local(self.evaluator_config) - for pass_run_config in self.pass_run_configs.values(): - if pass_run_config["evaluator"]: - pass_run_config["evaluator"] = self.cache.prepare_resources_for_local(pass_run_config["evaluator"]) + for passes_configs in self.input_passes_configs.values(): + for pass_config in passes_configs: + if pass_config.evaluator: + pass_config.evaluator = self.cache.prepare_resources_for_local(pass_config.evaluator) - for pass_run_config in self.pass_run_configs.values(): - host_type = pass_run_config["host"].system_type if pass_run_config["host"] else self.host_config.type - if host_type != SystemType.AzureML: - pass_run_config["input_config"] = self.cache.prepare_resources_for_local( - pass_run_config["input_config"] - ) + for passes_configs in self.input_passes_configs.values(): + for pass_config in passes_configs: + host_type = pass_config.host.system_type if pass_config.host else self.host_config.type + if host_type != SystemType.AzureML: + pass_config.config = self.cache.prepare_resources_for_local(pass_config.config) - self.set_pass_flows(self.pass_flows) self._initialized = True def register( @@ -127,12 +125,12 @@ def register( pass_type: Union[Type["Pass"], str], config: Dict[str, Any] = None, name: str = None, - host: "OliveSystem" = None, + host: SystemConfig = None, evaluator_config: OliveEvaluatorConfig = None, ): """Register a pass configuration so that it could be instantiated and executed later.""" if name: - assert name not in self.pass_run_configs, f"Pass with name {name} already registered" + assert name not in self.input_passes_configs, f"Pass with name {name} already registered" else: idx = 0 while True: @@ -140,26 +138,22 @@ def register( if idx > 0: name = f"{name}_{idx}" idx += 1 - if name not in self.pass_run_configs: + if name not in self.input_passes_configs: break pass_type_name = pass_type if isinstance(pass_type, str) else pass_type.__name__ - logger.debug("Registering pass %s", pass_type_name) - self.pass_run_configs[name] = { - "type": pass_type_name, - "input_config": config or {}, - "host": host, - "evaluator": evaluator_config, - } - - def set_pass_flows(self, pass_flows: List[List[str]] = None): - """Construct pass flows from a list of pass names. - - Args: - pass_flows: a list of pass names, each pass name is a string. + logger.debug("Registering pass %s:%s", name, pass_type_name) + self.input_passes_configs[name] = [ + RunPassConfig( + type=pass_type_name, + config=config or {}, + host=host, + evaluator=evaluator_config, + ) + ] - """ - self.pass_flows = pass_flows or [list(self.pass_run_configs.keys())] + def set_input_passes_configs(self, pass_configs: Dict[str, List[RunPassConfig]]): + self.input_passes_configs = pass_configs def run( self, @@ -193,7 +187,7 @@ def run( output_dir/...: output model files 2. Multiple accelerator specs: - output_dir/{acclerator_spec}/...: Same as 1 but for each accelerator spec + output_dir/{accelerator_spec}/...: Same as 1 but for each accelerator spec output_dir/...: output model files No search mode: @@ -209,7 +203,7 @@ def run( output_dir/...: output model files 2. 
Multiple accelerator specs - output_dir/{acclerator_spec}/...: Same as 1 but for each accelerator spec + output_dir/{accelerator_spec}/...: Same as 1 but for each accelerator spec output_dir/...: output model files """ @@ -248,7 +242,7 @@ def run( run_history = self.footprints[accelerator_spec].summarize_run_history() self.dump_run_history(run_history, output_subdirs[accelerator_spec] / "run_history.txt") - if packaging_config and self.pass_run_configs: + if packaging_config and self.input_passes_configs: # TODO(trajep): should we support packaging pytorch model? logger.info("Package top ranked %d models as artifacts", sum(len(f.nodes) for f in outputs.values())) generate_output_artifacts( @@ -264,7 +258,7 @@ def run( # TODO(team): refactor output structure # Do not change condition order. For no search, values of outputs are MetricResult # Consolidate the output structure for search and no search mode - if outputs and self.pass_run_configs and not next(iter(outputs.values())).check_empty_nodes(): + if outputs and self.input_passes_configs and not next(iter(outputs.values())).check_empty_nodes(): best_node: FootprintNode = get_best_candidate_node(outputs, self.footprints) self.cache.save_model(model_id=best_node.model_id, output_dir=output_dir, overwrite=True) if len(accelerator_output_dir_list) > 1 and self.skip_saving_artifacts: @@ -280,12 +274,6 @@ def run_accelerator( evaluate_input_model: bool, accelerator_spec: "AcceleratorSpec", ): - # Setup pass configs - self._setup_pass_configs(accelerator_spec) - - # generate search space - self._setup_search_spaces(accelerator_spec) - # hash the input model input_model_id = input_model_config.get_model_id() if input_model_id == LOCAL_INPUT_MODEL_ID and self.cache.enable_shared_cache: @@ -300,32 +288,29 @@ def run_accelerator( try: if evaluate_input_model and not self.evaluator_config: logger.debug("evaluate_input_model is True but no evaluator provided. 
Skipping input model evaluation.") + elif evaluate_input_model: results = self._evaluate_model( input_model_config, input_model_id, self.evaluator_config, accelerator_spec ) logger.info("Input model evaluation results: %s", results) + if not self.skip_saving_artifacts: results_path = output_dir / "input_model_metrics.json" with results_path.open("w") as f: json.dump(results.to_json(), f, indent=4) logger.info("Saved evaluation results of input model to %s", results_path) - if not self.pass_run_configs: + if not self.input_passes_configs: logger.debug("No passes registered, return input model evaluation results.") return results if self.search_strategy: logger.debug("Running Olive in search mode ...") - output_footprint = self.run_search( - input_model_config, - input_model_id, - accelerator_spec, - output_dir, - ) + output_footprint = self._run_search(input_model_config, input_model_id, accelerator_spec, output_dir) else: logger.debug("Running Olive in no-search mode ...") - output_footprint = self.run_no_search(input_model_config, input_model_id, accelerator_spec, output_dir) + output_footprint = self._run_no_search(input_model_config, input_model_id, accelerator_spec, output_dir) except EXCEPTIONS_TO_RAISE: raise except Exception: @@ -343,29 +328,16 @@ def get_host_device(self): # for host device, we will always use the first accelerator device return self.host_config.config.accelerators[0].device if self.host_config.config.accelerators else None - def _setup_pass_configs(self, accelerator_spec: "AcceleratorSpec"): - disable_search = self.search_strategy is None - for pass_run_config in self.pass_run_configs.values(): - pass_cls: Type[Pass] = self.olive_config.import_pass_module(pass_run_config["type"]) - pass_run_config["config"] = pass_cls.generate_config( - accelerator_spec, pass_run_config["input_config"], disable_search - ) - - def _setup_search_spaces(self, accelerator_spec: "AcceleratorSpec"): - self.search_spaces.clear() - if self.search_strategy is None: - return + def _compute_no_search_pass_configs(self, accelerator_spec: "AcceleratorSpec"): + self.computed_passes_configs.clear() + for name, passes_configs in self.input_passes_configs.items(): + pass_config = validate_config(passes_configs[0].dict(), RunPassConfig) - for pass_flow in self.pass_flows: - pass_search_spaces: List[Tuple[str, Dict[str, SearchParameter]]] = [] - for pass_name in pass_flow: - pass_run_config = self.pass_run_configs[pass_name] - pass_search_spaces.append( - (pass_name, {k: v for k, v in pass_run_config["config"].items() if isinstance(v, SearchParameter)}) - ) - self.search_spaces.append(pass_search_spaces) + pass_cls: Type[Pass] = self.olive_config.import_pass_module(pass_config.type) + pass_config.config = pass_cls.generate_config(accelerator_spec, pass_config.config, {}, True) + self.computed_passes_configs[name] = pass_config - def run_no_search( + def _run_no_search( self, input_model_config: ModelConfig, input_model_id: str, @@ -375,46 +347,92 @@ def run_no_search( """Run all the registered Olive pass flows in no-search mode.""" output_model_dir = Path(output_dir) - output_model_ids = [] - for pass_flow in self.pass_flows: - # search point is empty since there is no search - passes_to_run = [(pass_name, {}) for pass_name in pass_flow] - - # run all the passes in the pass flow - logger.debug("Running %s with no search ...", pass_flow) - should_prune, signal, model_ids = self._run_passes( - passes_to_run, - input_model_config, - input_model_id, - accelerator_spec, - ) + # Compute pass configs + 
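# In no-search mode, only the first registered config variant of each pass is used, + # with all searchable parameters resolved to fixed values. + 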
self._compute_no_search_pass_configs(accelerator_spec) - if should_prune: - failed_pass = pass_flow[len(model_ids)] - logger.warning( - "Flow %s is pruned due to failed or invalid config for pass '%s'", pass_flow, failed_pass - ) - continue + # run all the passes in the pass flow + pass_flow = list(self.computed_passes_configs.keys()) + logger.debug("Running %s with no search ...", pass_flow) + should_prune, signal, model_ids = self._run_passes(input_model_config, input_model_id, accelerator_spec) - # use output_model_dir if there is only one pass flow - # else output_model_dir/pass_flow - flow_output_dir = output_model_dir / "-".join(pass_flow) if len(self.pass_flows) > 1 else output_model_dir - flow_output_dir.mkdir(parents=True, exist_ok=True) + if should_prune: + failed_pass = pass_flow[len(model_ids)] + logger.warning("Flow %s is pruned due to failed or invalid config for pass '%s'", pass_flow, failed_pass) + return Footprint() - if signal is not None and not self.skip_saving_artifacts: - results_path = flow_output_dir / "metrics.json" - with open(results_path, "w") as f: - json.dump(signal.to_json(), f, indent=4) - logger.info("Saved evaluation results of output model to %s", results_path) + # write to output_model_dir directly if the flow has a single pass + # else to output_model_dir/<pass names joined with "-"> + flow_output_dir = output_model_dir / "-".join(pass_flow) if len(pass_flow) > 1 else output_model_dir + flow_output_dir.mkdir(parents=True, exist_ok=True) - output_model_ids.append(model_ids[-1]) + if signal is not None and not self.skip_saving_artifacts: + results_path = flow_output_dir / "metrics.json" + with open(results_path, "w") as f: + json.dump(signal.to_json(), f, indent=4) + logger.info("Saved evaluation results of output model to %s", results_path) - output_footprints = self.footprints[accelerator_spec].create_footprints_by_model_ids(output_model_ids) + output_footprints = self.footprints[accelerator_spec].create_footprints_by_model_ids([model_ids[-1]]) if not self.skip_saving_artifacts: output_footprints.to_file(output_dir / "output_footprints.json") return output_footprints - def run_search( + def _get_search_space_config(self, accelerator_spec: "AcceleratorSpec"): + space_config: Dict[str, List[Dict[str, SearchParameter]]] = OrderedDict() + for pass_name, passes_configs in self.input_passes_configs.items(): + space_config[pass_name] = pass_params_config = [] + for pass_config in passes_configs: + pass_cls = self.olive_config.import_pass_module(pass_config.type) + _, search_params = pass_cls.get_config_params(accelerator_spec, pass_config.config, False) + pass_params_config.append(search_params) + return space_config + + def _get_search_space_objectives( + self, + input_model_config: ModelConfig, + input_model_id: str, + accelerator_spec: "AcceleratorSpec", + ) -> Dict[str, List[Dict[str, Any]]]: + objectives_by_pass_name: Dict[str, List[Dict[str, Any]]] = {} + objectives_by_evaluator_name: Dict[str, Dict[str, Any]] = {} + for pass_name, passes_configs in self.input_passes_configs.items(): + objectives_by_pass_name[pass_name] = passes_objectives = [] + for pass_config in passes_configs: + evaluator_config = pass_config.evaluator or self.evaluator_config + if evaluator_config.name not in objectives_by_evaluator_name: + objectives_by_evaluator_name[evaluator_config.name] = self.resolve_objectives( + input_model_config, input_model_id, evaluator_config.metrics, accelerator_spec + ) + passes_objectives.append(objectives_by_evaluator_name[evaluator_config.name]) + + accelerator_objectives: Dict[str, 
Any] = {} + for objectives in objectives_by_evaluator_name.values(): + accelerator_objectives.update(objectives) + self.footprints[accelerator_spec].record_objective_dict(accelerator_objectives) + return objectives_by_pass_name + + def _compute_search_pass_configs(self, accelerator_spec: "AcceleratorSpec", sample: SearchSample): + self.computed_passes_configs.clear() + sample_passes_configs = sample.passes_configs + if not sample_passes_configs: + return + + disable_pass_params_search = not self.search_strategy.config.include_pass_params + for pass_name, passes_configs in self.input_passes_configs.items(): + if pass_name in sample_passes_configs: + sample_pass_config = sample_passes_configs[pass_name] + pass_config = passes_configs[sample_pass_config["index"]] + pass_config = validate_config(pass_config.dict(), RunPassConfig) + + pass_cls = self.olive_config.import_pass_module(pass_config.type) + pass_config.config = pass_cls.generate_config( + accelerator_spec, + pass_config.config, + sample_pass_config["params"], + disable_pass_params_search, + ) + self.computed_passes_configs[pass_name] = pass_config + + def _run_search( self, input_model_config: ModelConfig, input_model_id: str, @@ -422,55 +440,34 @@ output_dir: Path, ): - """Run all the registered Olive passes in search model where search strategy is not None.""" + """Run all the registered Olive passes in search mode where search strategy is not None.""" - # get objective_dict - evaluator_config = self.evaluator_for_pass(list(self.pass_run_configs.keys())[-1]) - - if evaluator_config is None: - raise ValueError("No evaluator provided for the last pass") - else: - objective_dict = self.resolve_objectives( - input_model_config, input_model_id, evaluator_config.metrics, accelerator_spec - ) - self.footprints[accelerator_spec].record_objective_dict(objective_dict) - # initialize the search strategy - self.search_strategy.initialize(self.search_spaces, input_model_id, objective_dict) - output_model_num = self.search_strategy.get_output_model_num() - - # record start time - start_time = time.time() - iter_num = 0 - while True: - iter_num += 1 - - # get the next step - next_step = self.search_strategy.next_step() + search_space_config = self._get_search_space_config(accelerator_spec) + search_space_objectives = self._get_search_space_objectives( + input_model_config, input_model_id, accelerator_spec + ) + self.search_strategy.initialize(search_space_config, input_model_id, search_space_objectives) - # if no more steps, break - if next_step is None: - break + for sample in self.search_strategy: # pylint: disable=not-an-iterable + self._compute_search_pass_configs(accelerator_spec, sample) - # get the model id of the first input model - model_id = next_step["model_id"] - model_config = input_model_config if model_id == input_model_id else self._load_model(model_id) + if self.computed_passes_configs: + # get the model id of the first input model + model_id = sample.model_ids[0] + model_config = input_model_config if model_id == input_model_id else self._load_model(model_id) - logger.debug("Step %d with search point %s ...", iter_num, next_step["search_point"]) + logger.info( + "Step %d with search point %s ...", self.search_strategy.iteration_count, sample.search_point + ) - # run all the passes in the step - should_prune, signal, model_ids = self._run_passes( - next_step["passes"], - model_config, - model_id, - accelerator_spec, - ) + # run all the passes in the step + should_prune, signal, model_ids = self._run_passes(model_config, model_id, accelerator_spec) + else: + should_prune, signal, model_ids = True, None, [] # 
record feedback signal - self.search_strategy.record_feedback_signal(next_step["search_point"], signal, model_ids, should_prune) - - time_diff = time.time() - start_time - self.search_strategy.check_exit_criteria(iter_num, time_diff, signal) + self.search_strategy.record_feedback_signal(sample.search_point.index, signal, model_ids, should_prune) - return self.create_pareto_frontier_footprints(accelerator_spec, output_model_num, output_dir) + return self.create_pareto_frontier_footprints(accelerator_spec, None, output_dir) def create_pareto_frontier_footprints( self, accelerator_spec: "AcceleratorSpec", output_model_num: int, output_dir: Path @@ -598,13 +595,13 @@ def resolve_goals( return resolved_goals - def host_for_pass(self, pass_name: str) -> "OliveSystem": - host: SystemConfig = self.pass_run_configs[pass_name]["host"] - return host or self.host + def host_for_pass(self, pass_name: str) -> SystemConfig: + host: SystemConfig = self.computed_passes_configs[pass_name].host + return host.create_system() if host else self.host def evaluator_for_pass(self, pass_name: str) -> OliveEvaluatorConfig: """Return evaluator for the given pass.""" - return self.pass_run_configs[pass_name]["evaluator"] or self.evaluator_config + return self.computed_passes_configs[pass_name].evaluator or self.evaluator_config def _cache_model(self, model_id: str, model: Union[ModelConfig, str], check_object: bool = True): # TODO(trajep): move model/pass run/evaluation cache into footprints @@ -623,7 +620,6 @@ def _load_model(self, model_id: str) -> Union[ModelConfig, str]: def _run_passes( self, - passes: List[Tuple[str, Dict[str, Any]]], model_config: ModelConfig, model_id: str, accelerator_spec: "AcceleratorSpec", @@ -637,10 +633,9 @@ def _run_passes( model_ids = [] pass_name = None - for pass_name, pass_search_point in passes: + for pass_name in self.computed_passes_configs: model_config, model_id = self._run_pass( pass_name, - pass_search_point, model_config, model_id, accelerator_spec, @@ -663,7 +658,7 @@ def _run_passes( else: logger.info("Run model evaluation for the final model...") signal = self._evaluate_model(model_config, model_id, evaluator_config, accelerator_spec) - logger.debug("Signal: %s", signal) + logger.debug("Signal: %s, %s", signal, model_ids) else: signal = None logger.warning("Skipping evaluation as model was pruned") @@ -673,22 +668,21 @@ def _run_passes( def _run_pass( self, pass_name: str, - pass_search_point: Dict[str, Any], input_model_config: ModelConfig, input_model_id: str, accelerator_spec: "AcceleratorSpec", ): """Run a pass on the input model.""" run_start_time = datetime.now().timestamp() - pass_run_config: Dict[str, Any] = self.pass_run_configs[pass_name] - pass_type_name = pass_run_config["type"] - logger.info("Running pass %s:%s %s", pass_name, pass_type_name, pass_search_point) + pass_config: RunPassConfig = self.computed_passes_configs[pass_name] + pass_type_name = pass_config.type + + logger.info("Running pass %s:%s", pass_name, pass_type_name) # check whether the config is valid - pass_cls: Type[Pass] = self.olive_config.import_pass_module(pass_run_config["type"]) - pass_config = pass_cls.config_at_search_point(pass_search_point, accelerator_spec, pass_run_config["config"]) - if not pass_cls.validate_config(pass_config, accelerator_spec, self.search_strategy is None): + pass_cls: Type[Pass] = self.olive_config.import_pass_module(pass_config.type) + if not pass_cls.validate_config(pass_config.config, accelerator_spec, self.search_strategy is None): logger.warning("Invalid 
config, pruned.") logger.debug(pass_config) # no need to record in footprint since there was no run and thus no valid/failed model @@ -697,12 +691,13 @@ def _run_pass( # this helps reusing cached models for different accelerator specs return INVALID_CONFIG, None - p: Pass = pass_cls(accelerator_spec, pass_config, self.get_host_device()) - pass_config = p.serialize_config(pass_config, check_object=True) + p: Pass = pass_cls(accelerator_spec, pass_config.config, self.get_host_device()) + pass_config = p.serialize_config(pass_config.config, check_object=True) output_model_config = None # load run from cache if it exists run_accel = None if p.is_accelerator_agnostic(accelerator_spec) else accelerator_spec + output_model_id = self.cache.get_output_model_id(pass_type_name, pass_config, input_model_id, run_accel) run_cache = self.cache.load_run_from_model_id(output_model_id) if run_cache: @@ -869,6 +864,7 @@ def create_system(config: "SystemConfig", accelerator_spec): target_start_time = time.time() self.target = create_system(self.target_config, accelerator_spec) logger.info("Target system created in %f seconds", time.time() - target_start_time) + if not self.host: host_accelerators = self.host_config.config.accelerators if host_accelerators and host_accelerators[0].execution_providers: diff --git a/olive/evaluator/olive_evaluator.py b/olive/evaluator/olive_evaluator.py index d23b85748..2503503fb 100644 --- a/olive/evaluator/olive_evaluator.py +++ b/olive/evaluator/olive_evaluator.py @@ -1141,6 +1141,7 @@ def evaluate( class OliveEvaluatorConfig(NestedConfig): _nested_field_name = "type_args" + name: str = None type: str = None type_args: Dict = Field(default_factory=dict) diff --git a/olive/passes/olive_pass.py b/olive/passes/olive_pass.py index 1f9cbb061..24135cfb0 100644 --- a/olive/passes/olive_pass.py +++ b/olive/passes/olive_pass.py @@ -23,15 +23,14 @@ create_config_class, ) from olive.resource_path import ResourcePath -from olive.strategy.search_parameter import ( +from olive.search.search_parameter import ( Categorical, Conditional, ConditionalDefault, SearchParameter, SpecialParamValue, ) -from olive.strategy.search_space import SearchSpace -from olive.strategy.utils import cyclic_search_space, order_search_parameters +from olive.search.utils import cyclic_search_space, order_search_parameters logger = logging.getLogger(__name__) @@ -106,12 +105,12 @@ def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool: return True @classmethod - def generate_config( + def get_config_params( cls, accelerator_spec: AcceleratorSpec, config: Optional[Dict[str, Any]] = None, disable_search: Optional[bool] = False, - ) -> Dict[str, Any]: + ) -> Tuple[Dict[str, Any], Dict[str, SearchParameter]]: """Generate search space for the pass.""" assert accelerator_spec is not None, "Please specify the accelerator spec for the pass" config = config or {} @@ -125,10 +124,26 @@ def generate_config( # Generate the search space by using both default value and default search value and user provided config config = validate_config(config, config_class) - config = cls._resolve_config(config, default_config) - fixed_values, search_params = cls._init_fixed_and_search_params(config, default_config) - return {**fixed_values, **search_params} + return cls._init_fixed_and_search_params(config, default_config) + + @classmethod + def generate_config( + cls, + accelerator_spec: AcceleratorSpec, + config: Optional[Dict[str, Any]] = None, + point: Optional[Dict[str, Any]] = None, + disable_search: Optional[bool] = 
False, + ) -> Dict[str, Any]: + """Get the configuration for the pass at a specific point in the search space.""" + assert accelerator_spec is not None, "Please specify the accelerator spec for the pass" + + point = point or {} + fixed_values, search_params = cls.get_config_params(accelerator_spec, config, disable_search) + assert ( + set(point.keys()).intersection(set(search_params.keys())) == point.keys() + ), "Search point is not in the search space." + return {**fixed_values, **search_params, **point} @classmethod def _identify_search_values( @@ -184,29 +199,6 @@ def default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConf ), f"{param} ending with data_config must be of type DataConfig." return config - @classmethod - def config_at_search_point( - cls, - point: Dict[str, Any], - accelerator_spec: AcceleratorSpec, - config: Optional[Dict[str, Any]] = None, - ) -> Dict[str, Any]: - """Get the configuration for the pass at a specific point in the search space.""" - assert accelerator_spec is not None, "Please specify the accelerator spec for the pass" - - # Get the config class with default search value - config_class, default_config = cls.get_config_class(accelerator_spec) - - # Replace user-provided values with Categorical if user intended to search - config = cls._identify_search_values(config or {}, default_config) - - # Generate the search space by using both default value and default search value and user provided config - config = validate_config(config, config_class) - config = cls._resolve_config(config, default_config) - fixed_values, search_params = cls._init_fixed_and_search_params(config, default_config) - assert set(point.keys()) == set(search_params.keys()), "Search point is not in the search space." - return {**fixed_values, **search_params, **point} - @classmethod def validate_config( cls, @@ -451,8 +443,6 @@ def _init_fixed_and_search_params( value = str(Path(value).resolve()) fixed_params[key] = value assert not cyclic_search_space(search_space), "Search space is cyclic." - # TODO(jambayk): better error message, e.g. which parameters are invalid, how they are invalid - assert SearchSpace({"search_space": search_space}).size() > 0, "There are no valid points in the search space." 
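+ # fixed_params holds parameter values resolved to constants; search_space holds the remaining searchable parameters 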
return fixed_params, search_space @classmethod @@ -528,5 +518,5 @@ def create_pass_from_dict( if accelerator_spec is None: accelerator_spec = DEFAULT_CPU_ACCELERATOR - config = pass_cls.generate_config(accelerator_spec, config, disable_search) + config = pass_cls.generate_config(accelerator_spec, config, disable_search=disable_search) return pass_cls(accelerator_spec, config, host_device) diff --git a/olive/passes/onnx/inc_quantization.py b/olive/passes/onnx/inc_quantization.py index 0ff18ddd6..d697ad2ab 100644 --- a/olive/passes/onnx/inc_quantization.py +++ b/olive/passes/onnx/inc_quantization.py @@ -24,7 +24,7 @@ from olive.passes import Pass from olive.passes.onnx.common import get_external_data_config, model_has_adapters, model_proto_to_olive_model from olive.passes.pass_config import PassConfigParam -from olive.strategy.search_parameter import Boolean, Categorical, Conditional +from olive.search.search_parameter import Boolean, Categorical, Conditional logger = logging.getLogger(__name__) diff --git a/olive/passes/onnx/nvmo_quantization.py b/olive/passes/onnx/nvmo_quantization.py index b44feb65f..f3d88b09c 100644 --- a/olive/passes/onnx/nvmo_quantization.py +++ b/olive/passes/onnx/nvmo_quantization.py @@ -20,7 +20,7 @@ from olive.passes import Pass from olive.passes.onnx.common import model_proto_to_olive_model from olive.passes.pass_config import PassConfigParam -from olive.strategy.search_parameter import Categorical +from olive.search.search_parameter import Categorical logger = logging.getLogger(__name__) diff --git a/olive/passes/onnx/quantization.py b/olive/passes/onnx/quantization.py index 63c555ed6..da39ab65f 100644 --- a/olive/passes/onnx/quantization.py +++ b/olive/passes/onnx/quantization.py @@ -29,7 +29,7 @@ ) from olive.passes.pass_config import PassConfigParam from olive.resource_path import LocalFile -from olive.strategy.search_parameter import Boolean, Categorical, Conditional, ConditionalDefault +from olive.search.search_parameter import Boolean, Categorical, Conditional, ConditionalDefault logger = logging.getLogger(__name__) diff --git a/olive/passes/onnx/session_params_tuning.py b/olive/passes/onnx/session_params_tuning.py index 9e69f4ec1..701c35080 100644 --- a/olive/passes/onnx/session_params_tuning.py +++ b/olive/passes/onnx/session_params_tuning.py @@ -23,7 +23,7 @@ from olive.model import ONNXModelHandler from olive.passes import Pass from olive.passes.pass_config import PassConfigParam, get_user_script_data_config -from olive.strategy.search_parameter import Categorical +from olive.search.search_parameter import Categorical logger = logging.getLogger(__name__) @@ -114,25 +114,25 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon "providers_list": PassConfigParam( type_=str, default_value=execution_provider, - searchable_values=Categorical(AcceleratorLookup.get_execution_providers_for_device(device)), + search_defaults=Categorical(AcceleratorLookup.get_execution_providers_for_device(device)), description="Execution providers framework list to execute the ONNX models.", ), "provider_options_list": PassConfigParam( type_=Dict[str, Any], default_value={}, - searchable_values=Categorical([{}]), + search_defaults=Categorical([{}]), description="Execution provider options to execute the ONNX models.", ), "execution_mode_list": PassConfigParam( type_=int, default_value=None, - searchable_values=Categorical([None]), + search_defaults=Categorical([None]), description="Parallelism list between operators.", ), "opt_level_list": 
PassConfigParam( type_=int, default_value=None, - searchable_values=Categorical([None]), + search_defaults=Categorical([None]), description="Optimization level list for ONNX model.", ), "trt_fp16_enable": PassConfigParam( @@ -141,13 +141,13 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon "intra_thread_num_list": PassConfigParam( type_=int, default_value=None, - searchable_values=Categorical([None]), + search_defaults=Categorical([None]), description="List of intra thread number for test.", ), "inter_thread_num_list": PassConfigParam( type_=int, default_value=None, - searchable_values=Categorical([None]), + search_defaults=Categorical([None]), description="List of inter thread number for test.", ), "extra_session_config": PassConfigParam( diff --git a/olive/passes/onnx/vitis_ai_quantization.py b/olive/passes/onnx/vitis_ai_quantization.py index c61324619..bd0088cac 100644 --- a/olive/passes/onnx/vitis_ai_quantization.py +++ b/olive/passes/onnx/vitis_ai_quantization.py @@ -25,7 +25,7 @@ ) from olive.passes.pass_config import PassConfigParam from olive.resource_path import LocalFile -from olive.strategy.search_parameter import Boolean, Categorical, Conditional +from olive.search.search_parameter import Boolean, Categorical, Conditional logger = logging.getLogger(__name__) diff --git a/olive/passes/pass_config.py b/olive/passes/pass_config.py index 82695f9d9..27144e420 100644 --- a/olive/passes/pass_config.py +++ b/olive/passes/pass_config.py @@ -18,7 +18,7 @@ from olive.hardware.accelerator import Device from olive.hardware.constants import DEVICE_TO_EXECUTION_PROVIDERS from olive.resource_path import validate_resource_path -from olive.strategy.search_parameter import SearchParameter, SpecialParamValue, json_to_search_parameter +from olive.search.search_parameter import SearchParameter, SpecialParamValue, json_to_search_parameter class PassParamDefault(StrEnumBase): diff --git a/olive/passes/pytorch/lora.py b/olive/passes/pytorch/lora.py index 8135d5a40..06ef23fd8 100644 --- a/olive/passes/pytorch/lora.py +++ b/olive/passes/pytorch/lora.py @@ -37,7 +37,7 @@ load_hf_base_model, prepare_model_for_finetuning, ) -from olive.strategy.search_parameter import Categorical +from olive.search.search_parameter import Categorical if TYPE_CHECKING: import torch diff --git a/olive/passes/snpe/quantization.py b/olive/passes/snpe/quantization.py index a7cd97f57..398cfed21 100644 --- a/olive/passes/snpe/quantization.py +++ b/olive/passes/snpe/quantization.py @@ -14,7 +14,7 @@ from olive.platform_sdk.qualcomm.snpe.tools.dev import quantize_dlc from olive.platform_sdk.qualcomm.utils.data_loader import FileListCommonDataLoader, FileListDataLoader from olive.resource_path import LocalFile -from olive.strategy.search_parameter import Boolean +from olive.search.search_parameter import Boolean class SNPEQuantization(Pass): diff --git a/olive/search/__init__.py b/olive/search/__init__.py new file mode 100644 index 000000000..862c45ce3 --- /dev/null +++ b/olive/search/__init__.py @@ -0,0 +1,4 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# --------------------------------------------------------------------------
diff --git a/olive/search/samplers/__init__.py b/olive/search/samplers/__init__.py
new file mode 100644
index 000000000..70e45d124
--- /dev/null
+++ b/olive/search/samplers/__init__.py
@@ -0,0 +1,12 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+from olive.search.samplers.random_sampler import RandomSampler
+from olive.search.samplers.search_sampler import SearchSampler
+from olive.search.samplers.sequential_sampler import SequentialSampler
+from olive.search.samplers.tpe_sampler import TPESampler
+
+REGISTRY = SearchSampler.registry
+
+__all__ = ["REGISTRY", "RandomSampler", "SearchSampler", "SequentialSampler", "TPESampler"]
diff --git a/olive/search/samplers/optuna_sampler.py b/olive/search/samplers/optuna_sampler.py
new file mode 100644
index 000000000..7d42b9310
--- /dev/null
+++ b/olive/search/samplers/optuna_sampler.py
@@ -0,0 +1,144 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+from abc import abstractmethod
+from collections import OrderedDict
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+
+import optuna
+
+from olive.common.config_utils import ConfigBase, ConfigParam
+from olive.search.samplers.search_sampler import SearchSampler
+from olive.search.search_parameter import Categorical, Conditional, SearchParameter
+from olive.search.search_point import SearchPoint
+from olive.search.search_space import SearchSpace
+
+if TYPE_CHECKING:
+    from optuna.trial import Trial
+
+    from olive.evaluator.metric_result import MetricResult
+
+
+optuna.logging.set_verbosity(optuna.logging.WARNING)
+
+
+class OptunaSampler(SearchSampler):
+    """Optuna sampler for search sampling."""
+
+    name = "optuna"
+
+    @classmethod
+    def _default_config(cls) -> Dict[str, ConfigParam]:
+        return {
+            **super()._default_config(),
+            "seed": ConfigParam(type_=int, default_value=1, description="Seed for the rng."),
+        }
+
+    def __init__(
+        self,
+        search_space: SearchSpace,
+        config: Optional[Union[Dict[str, Any], ConfigBase]] = None,
+    ):
+        super().__init__(search_space, config)
+
+        # Initialize the searcher
+        self._sampler = self._create_sampler()
+        # TODO(olivedev): There is no absolute direction to set.
+        # directions = ["maximize" if hib else "minimize" for hib in self._higher_is_betters]
+        # self._study = optuna.create_study(directions=directions, sampler=self._sampler)
+        self._study = optuna.create_study(sampler=self._sampler)
+        self._num_samples_suggested = 0
+        self._search_point_index_to_trial_id = {}
+
+    @property
+    def num_samples_suggested(self) -> int:
+        """Returns the number of samples suggested so far."""
+        return self._num_samples_suggested
+
+    @abstractmethod
+    def _create_sampler(self) -> optuna.samplers.BaseSampler:
+        """Create the sampler."""
+
+    def suggest(self) -> SearchPoint:
+        """Suggest a new configuration to try."""
+        if self.should_stop:
+            return None
+
+        # The optuna.BaseSampler seems to be returning duplicates. Avoid returning
+        # duplicates by querying repeatedly to find the next sample. It won't be
+        # an infinite loop because if all samples had been processed, the
+        # self.should_stop check above would succeed.
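+        #
+        # For illustration only, the ask/tell protocol this loop builds on is roughly
+        # (hypothetical parameter name and objective value, not part of this change):
+        #   trial = study.ask()
+        #   choice = trial.suggest_categorical("param", [0, 1, 2])
+        #   study.tell(trial.number, objective_value)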
+        while True:
+            trial = self._study.ask()
+            values: Dict[str, Tuple[int, Any]] = OrderedDict()
+            spi, _, values = self._get_search_point_values("", "", self._search_space, trial, values)
+            if spi not in self._search_point_index_to_trial_id:
+                break
+
+        self._search_point_index_to_trial_id[spi] = trial.number
+        self._num_samples_suggested += 1
+        return SearchPoint(spi, values)
+
+    def _get_search_point_values(
+        self,
+        prefix: str,
+        name: str,
+        param: Union[SearchParameter, SearchSpace],
+        trial: "Trial",
+        values: Dict[str, Tuple[int, Any]],
+    ) -> Tuple[int, int, Union[Dict[str, Any], Any]]:
+        if isinstance(param, SearchParameter):
+            suggestion_name = f"{prefix}__{name}" if prefix else name
+
+            if isinstance(param, Categorical):
+                suggestions = param.get_support()
+            elif isinstance(param, Conditional):
+                parent_values = {parent: values[parent][1] for parent in param.parents}
+                suggestions = param.get_support_with_args(parent_values)
+                max_length = max(len(support.get_support()) for support in param.support.values())
+                suggestions += param.default.get_support() * (max_length - len(suggestions))
+                suggestion_name += "___" + " ".join(str(v) for v in parent_values.values())
+
+            suggestion_len = len(suggestions)
+            suggestion_index = trial.suggest_categorical(suggestion_name, list(range(suggestion_len)))
+            suggestion = suggestions[suggestion_index]
+
+            if isinstance(suggestion, (SearchParameter, SearchSpace)):
+                suggestion_index, suggestion_len, _ = self._get_search_point_values(
+                    prefix, name, suggestion, trial, values
+                )
+            else:
+                values[name] = suggestion_index, suggestion
+
+            return suggestion_index, suggestion_len, suggestion
+
+        elif isinstance(param, SearchSpace):
+            child_values = OrderedDict()
+            indices_lengths = []
+            for child_name, child_param in param.parameters:
+                child_index, suggestions_len, _ = self._get_search_point_values(
+                    prefix, child_name, child_param, trial, child_values
+                )
+                indices_lengths.append((child_index, suggestions_len))
+            values[name] = (0, child_values)
+
+            spi = 0
+            for child_index, suggestions_len in reversed(indices_lengths):
+                spi *= suggestions_len
+                spi += child_index
+
+            return spi, len(param), child_values
+
+        else:
+            raise ValueError(f"Unsupported parameter type: {type(param)}")
+
+    def record_feedback_signal(
+        self, search_point_index: int, objectives: Dict[str, dict], signal: "MetricResult", should_prune: bool = False
+    ):
+        trial_id = self._search_point_index_to_trial_id[search_point_index]
+        if should_prune:
+            self._study.tell(trial_id, state=optuna.trial.TrialState.PRUNED)
+        else:
+            values = [signal[objective].value for objective in objectives]
+            self._study.tell(trial_id, values)
diff --git a/olive/search/samplers/random_sampler.py b/olive/search/samplers/random_sampler.py
new file mode 100644
index 000000000..a0f8b4099
--- /dev/null
+++ b/olive/search/samplers/random_sampler.py
@@ -0,0 +1,61 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# -------------------------------------------------------------------------- +from random import Random +from typing import TYPE_CHECKING, Any, Dict, Optional, Union + +from olive.common.config_utils import ConfigBase, ConfigParam +from olive.search.samplers.search_sampler import SearchSampler +from olive.search.search_space import SearchSpace + +if TYPE_CHECKING: + from olive.search.search_point import SearchPoint + + +class RandomSampler(SearchSampler): + """Random sampler. Samples random points from the search space.""" + + name = "random" + + @classmethod + def _default_config(cls) -> Dict[str, ConfigParam]: + return { + **super()._default_config(), + "seed": ConfigParam(type_=int, default_value=1, description="Seed for the rng."), + } + + def __init__( + self, + search_space: SearchSpace, + config: Optional[Union[Dict[str, Any], ConfigBase]] = None, + ): + super().__init__(search_space, config) + + self._rng = Random(self.config.seed) + self._search_points = [None] * len(self._search_space) + self._available = list(range(len(self._search_space))) + + def reset_rng(self): + """Reset the random number generator.""" + self._rng = Random(self.config.seed) + + @property + def num_samples_suggested(self) -> int: + """Returns the number of samples suggested so far.""" + return len(self._search_space) - len(self._available) + + @property + def should_stop(self) -> bool: + """Check if the searcher should stop at the current trial.""" + return super().should_stop or (len(self._available) == 0) + + def suggest(self) -> "SearchPoint": + """Suggest a new configuration to try.""" + if self.should_stop: + return None + + index = self._available[self._rng.randrange(len(self._available))] + self._available.remove(index) + self._search_points[index] = self._search_space[index] + return self._search_points[index] diff --git a/olive/search/samplers/search_sampler.py b/olive/search/samplers/search_sampler.py new file mode 100644 index 000000000..bf05e7650 --- /dev/null +++ b/olive/search/samplers/search_sampler.py @@ -0,0 +1,69 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from abc import abstractmethod +from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Type, Union + +from olive.common.auto_config import AutoConfigClass +from olive.common.config_utils import ConfigBase, ConfigParam +from olive.search.search_space import SearchSpace + +if TYPE_CHECKING: + from olive.evaluator.metric_result import MetricResult + from olive.search.search_point import SearchPoint + + +class SearchSampler(AutoConfigClass): + """Abstract base class for searchers.""" + + registry: ClassVar[Dict[str, Type["SearchSampler"]]] = {} + + @classmethod + def _default_config(cls) -> Dict[str, ConfigParam]: + return { + "max_samples": ConfigParam( + type_=int, + default_value=0, + description="Maximum number of samples to suggest. 
Search exhaustively if set to zero.",
+            ),
+        }
+
+    def __init__(
+        self,
+        search_space: SearchSpace,
+        config: Optional[Union[Dict[str, Any], ConfigBase]] = None,
+    ):
+        super().__init__(config)
+
+        self._search_space = search_space
+
+    @property
+    @abstractmethod
+    def num_samples_suggested(self) -> int:
+        """Returns the number of samples suggested so far."""
+        return 0
+
+    @property
+    def max_samples(self) -> int:
+        """Returns the maximum number of samples to suggest."""
+        return self.config.max_samples
+
+    @property
+    def should_stop(self) -> bool:
+        """Check if the searcher should stop at the current trial."""
+        return (
+            (len(self._search_space) == 0)
+            or (self.num_samples_suggested >= len(self._search_space))
+            or ((self.max_samples > 0) and (self.num_samples_suggested >= self.max_samples))
+        )
+
+    @abstractmethod
+    def suggest(self) -> "SearchPoint":
+        """Suggest a new configuration to try."""
+        return None
+
+    def record_feedback_signal(
+        self, search_point_index: int, objectives: Dict[str, dict], signal: "MetricResult", should_prune: bool = False
+    ):
+        """Report the result of a configuration."""
diff --git a/olive/search/samplers/sequential_sampler.py b/olive/search/samplers/sequential_sampler.py
new file mode 100644
index 000000000..7c8b30d0e
--- /dev/null
+++ b/olive/search/samplers/sequential_sampler.py
@@ -0,0 +1,46 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+
+from olive.common.config_utils import ConfigBase
+from olive.search.samplers.search_sampler import SearchSampler
+from olive.search.search_space import SearchSpace
+
+if TYPE_CHECKING:
+    from olive.search.search_point import SearchPoint
+
+
+class SequentialSampler(SearchSampler):
+    """Sequential sampler. Suggests search points in sequential order."""
+
+    name = "sequential"
+
+    @classmethod
+    def _default_config(cls):
+        return super()._default_config()
+
+    def __init__(
+        self,
+        search_space: SearchSpace,
+        config: Optional[Union[Dict[str, Any], ConfigBase]] = None,
+    ):
+        super().__init__(search_space, config)
+
+        self._index = 0
+
+    @property
+    def num_samples_suggested(self) -> int:
+        """Returns the number of samples suggested so far."""
+        return self._index
+
+    def suggest(self) -> "SearchPoint":
+        """Suggest a new configuration to try."""
+        if self.should_stop:
+            return None
+
+        index = self._index
+        self._index += 1
+
+        return self._search_space[index]
diff --git a/olive/strategy/search_algorithm/tpe_sampler.py b/olive/search/samplers/tpe_sampler.py
similarity index 92%
rename from olive/strategy/search_algorithm/tpe_sampler.py
rename to olive/search/samplers/tpe_sampler.py
index 0954043db..f52f2831f 100644
--- a/olive/strategy/search_algorithm/tpe_sampler.py
+++ b/olive/search/samplers/tpe_sampler.py
@@ -7,10 +7,10 @@
 import optuna
 
 from olive.common.config_utils import ConfigParam
-from olive.strategy.search_algorithm.optuna_sampler import OptunaSearchAlgorithm
+from olive.search.samplers.optuna_sampler import OptunaSampler
 
 
-class TPESearchAlgorithm(OptunaSearchAlgorithm):
+class TPESampler(OptunaSampler):
     """Sample using TPE (Tree-structured Parzen Estimator) algorithm.
 
     Uses optuna TPESampler underneath.
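+
+    For illustration, a search strategy config could select this sampler with
+    hypothetical values such as:
+
+        {"execution_order": "joint", "sampler": "tpe", "seed": 0, "max_samples": 16}
+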
Refer to https://optuna.readthedocs.io/en/stable/reference/samplers/generated/optuna.samplers.TPESampler.html
diff --git a/olive/strategy/search_parameter.py b/olive/search/search_parameter.py
similarity index 99%
rename from olive/strategy/search_parameter.py
rename to olive/search/search_parameter.py
index e1adfc50b..f3aa2ef48 100644
--- a/olive/strategy/search_parameter.py
+++ b/olive/search/search_parameter.py
@@ -36,7 +36,7 @@ class SpecialParamValue(StrEnumBase):
     """Special values for parameters.
 
     IGNORED: the parameter gets the value "OLIVE_IGNORED_PARAM_VALUE". The pass might ignore this parameter.
-    INVALID: Any search point with this value is invalid. The search algorithm will not suggest such a search point.
+    INVALID: Any search point with this value is invalid. The search strategy will not suggest such a search point.
     """
 
     IGNORED = "OLIVE_IGNORED_PARAM_VALUE"
diff --git a/olive/search/search_point.py b/olive/search/search_point.py
new file mode 100644
index 000000000..e5450c04a
--- /dev/null
+++ b/olive/search/search_point.py
@@ -0,0 +1,70 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+from collections import OrderedDict
+from dataclasses import dataclass
+from typing import Any, Dict, Tuple
+
+from olive.search.search_parameter import SpecialParamValue
+
+# ruff: noqa: PD011
+
+
+@dataclass
+class SearchPoint:
+    """Search point from a search space.
+
+    A search point is uniquely identified by an index and contains the corresponding values.
+    Each value is a tuple of an integer and the chosen value, where the integer is the index
+    of that choice among all possible choices.
+
+    For example, for choices of ["a", "b", "c"], the value [0, "a"] means 0 is the index
+    and "a" is the corresponding choice; for the value [2, "c"], 2 is the index and "c"
+    is the corresponding choice.
+    """
+
+    def __init__(self, index: int, values: Dict[str, Tuple[int, Any]]):
+        self.index = index
+        self.values = values
+
+    def _format(self, arg: Any) -> str:
+        """Return a string representation of the input arg.
Builds the representation recursively."""
+        return (
+            "{" + ", ".join(f"{k}[{i}]: {self._format(v)}" for k, (i, v) in arg.items()) + "}"
+            if isinstance(arg, OrderedDict)
+            else str(arg)
+        )
+
+    def __repr__(self):
+        """Return a string representation."""
+        return f"SearchPoint({self.index}, {self._format(self.values)})"
+
+    def __eq__(self, other):
+        """Return true if this instance is the same as the input one."""
+        return (
+            (self.index == other.index) and (self.values == other.values) if isinstance(other, SearchPoint) else False
+        )
+
+    def is_valid(self) -> bool:
+        """Return true if none of the values in the hierarchy are invalid."""
+
+        def _is_valid(values: Dict[str, Tuple[int, Any]]) -> bool:
+            for v in values.values():
+                if isinstance(v, OrderedDict):
+                    if not _is_valid(v):
+                        return False
+                elif v == SpecialParamValue.INVALID:
+                    return False
+            return True
+
+        return _is_valid(self.values)
+
+    def to_json(self):
+        """Return a json representation."""
+        return {"index": self.index, "values": self.values}
+
+    @classmethod
+    def from_json(cls, json_dict):
+        """Create a SearchPoint object from a json representation."""
+        return cls(json_dict["index"], json_dict["values"])
diff --git a/olive/search/search_results.py b/olive/search/search_results.py
new file mode 100644
index 000000000..eb9e587d9
--- /dev/null
+++ b/olive/search/search_results.py
@@ -0,0 +1,124 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import sys
+from typing import TYPE_CHECKING, Any, Dict, List, Tuple
+
+import numpy as np
+
+if TYPE_CHECKING:
+    from olive.evaluator.metric_result import MetricResult
+
+
+class SearchResults:
+    def __init__(self):
+        self._results: List[Tuple["MetricResult", List[str], Dict[str, Any]]] = []
+        self._sorted_indices: List[int] = []
+
+    def record_feedback_signal(
+        self, search_point_index: int, objectives: Dict[str, dict], result: "MetricResult", model_ids: List[str]
+    ):
+        """Record the evaluation result of a search point."""
+        self._results += [None] * ((search_point_index + 1) - len(self._results))
+        self._results[search_point_index] = (result, model_ids, objectives)
+
+    def meets_goals(self, search_point_index: int) -> bool:
+        """Check if the result satisfies the constraints."""
+        if search_point_index >= len(self._results):
+            return False
+
+        if not self._results[search_point_index]:
+            return False
+
+        result, _, objectives = self._results[search_point_index]
+        goals = {name: obj["goal"] for name, obj in objectives.items() if obj.get("goal") is not None}
+        if not goals:
+            return True  # if goals are not set, always return True
+
+        # multiplier for each objective and goals
+        multipliers = {
+            name: 1 if objective.get("higher_is_better", False) else -1 for name, objective in objectives.items()
+        }
+        return all((multipliers[obj] * result[obj].value) >= (multipliers[obj] * goal) for obj, goal in goals.items())
+
+    def sort(self, apply_goals: bool = False):
+        indices, results = self._get_results_list(apply_goals)
+        if not results:
+            self._sorted_indices = indices
+            return False
+
+        # sort by objectives, leftmost objective has highest priority
+        # flip the order of the objectives since np.lexsort prioritizes the last column
+        # negate the results since np.lexsort sorts in ascending order
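+        # For example (illustrative values): given rows [[0.8, 10], [0.9, 5], [0.8, 12]]
+        # where higher is better in both columns, the sorted order of indices would be
+        # [1, 2, 0]: the best first objective wins; ties fall through to the next column.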
+        results = -np.flip(np.array(results), 1)
+        sorted_indices = np.lexsort(results.T)
+        self._sorted_indices = [indices[i] for i in sorted_indices]
+
+        return True
+
+    def get_next_best_result(self, start_index: int) -> Tuple[int, int, List[str]]:
+        assert start_index is not None, "Expecting an index, got None"
+
+        if start_index < -1:
+            return None, None, None
+
+        next_best_index = start_index + 1
+        if next_best_index >= len(self._sorted_indices):
+            return None, None, None
+
+        _, model_ids, _ = self._results[self._sorted_indices[next_best_index]]
+        return next_best_index, self._sorted_indices[next_best_index], model_ids
+
+    def _get_results_list(self, apply_goals: bool = False) -> Tuple[List[int], List[float]]:
+        """Return the results as a tuple of indices and values.
+
+        Values are multiplied by the objective multiplier so that higher is better for all objectives.
+        """
+        all_objectives = {}
+        for spi, entry in enumerate(self._results):
+            if entry and (not apply_goals or self.meets_goals(spi)):
+                _, _, objectives = entry
+                for name in objectives:
+                    if name in all_objectives:
+                        assert all_objectives[name] == objectives[name].get(
+                            "higher_is_better", False
+                        ), "Conflicting values for higher_is_better across same named objectives"
+                    else:
+                        all_objectives[name] = objectives[name].get("higher_is_better", False)
+
+        indices = []
+        values = []
+        if not all_objectives:
+            # If no objectives, then use the indices of the valid results in no specific order
+            indices = [spi for spi, entry in enumerate(self._results) if entry]
+            return indices, values
+
+        # NOTE: The values array needs to be packed, but a simple loop through each entry
+        # could possibly create a jagged array if the numbers of objectives differ.
+
+        for spi, entry in enumerate(self._results):
+            if entry and (not apply_goals or self.meets_goals(spi)):
+                result, _, objectives = entry
+                if objectives:
+                    indices.append(spi)
+                    v = []
+                    for name, hib in all_objectives.items():
+                        if name in objectives:
+                            v.append((1 if hib else -1) * result[name].value)
+                        else:
+                            v.append(-sys.maxsize - 1 if hib else sys.maxsize)
+                    values.append(v)
+
+        return indices, values
+
+    def to_json(self):
+        """Return a json representation of the search results."""
+        return {"results": self._results}
+
+    @classmethod
+    def from_json(cls, json_dict):
+        """Create a SearchResults object from a json representation."""
+        search_results = cls()
+        search_results._results = json_dict["results"]
+        return search_results
diff --git a/olive/search/search_sample.py b/olive/search/search_sample.py
new file mode 100644
index 000000000..2aaa9ee82
--- /dev/null
+++ b/olive/search/search_sample.py
@@ -0,0 +1,61 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+from collections import OrderedDict
+from dataclasses import dataclass
+from typing import Any, Dict, List
+
+from olive.search.search_parameter import SpecialParamValue
+from olive.search.search_point import SearchPoint
+
+# ruff: noqa: PD011
+
+
+@dataclass
+class SearchSample:
+    """Search step result from search strategy.
+
+    Includes the search point and the input model ids to use to process the search point.
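+
+    For illustration (hypothetical names), a sample for a single pass could be
+    constructed as SearchSample(SearchPoint(3, {"quant": (0, {"bits": (1, 8)})}), ["model_2"]).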
+    """
+
+    def __init__(self, search_point: SearchPoint, model_ids: List[str]):
+        self.search_point = search_point
+        self.model_ids = model_ids
+
+    def __repr__(self):
+        """Return the string representation."""
+        return f"SearchSample({self.search_point.index}, {self.passes_configs}, {self.model_ids})"
+
+    @property
+    def passes_configs(self) -> Dict[str, Any]:
+        """Return the pass config that can be merged with the workflow config.
+
+        If any value in the hierarchy is SpecialParamValue.INVALID, the return value is None.
+        If any value in the hierarchy is SpecialParamValue.IGNORED, the return value excludes these parameters.
+        """
+        passes_configs = OrderedDict()
+        for pass_name, (pass_index, params) in self.search_point.values.items():
+            passes_configs[pass_name] = OrderedDict(
+                [
+                    ("index", pass_index),
+                    ("params", OrderedDict()),
+                ]
+            )
+
+            for param_name, (_, param_value) in params.items():
+                if param_value == SpecialParamValue.INVALID:
+                    return None  # Prune out invalid configurations
+                elif param_value != SpecialParamValue.IGNORED:
+                    passes_configs[pass_name]["params"][param_name] = param_value
+
+        return passes_configs
+
+    def to_json(self):
+        """Return a json representation."""
+        return {"search_point": self.search_point.to_json(), "model_ids": self.model_ids}
+
+    @classmethod
+    def from_json(cls, json_dict):
+        """Create a SearchSample object from a json representation."""
+        return cls(SearchPoint.from_json(json_dict["search_point"]), json_dict["model_ids"])
diff --git a/olive/search/search_space.py b/olive/search/search_space.py
new file mode 100644
index 000000000..caa7fae17
--- /dev/null
+++ b/olive/search/search_space.py
@@ -0,0 +1,230 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+from collections import OrderedDict
+from typing import Any, Dict, Generator, List, Tuple, Union
+
+from olive.search.search_parameter import Categorical, Conditional, SearchParameter
+from olive.search.search_point import SearchPoint
+from olive.search.utils import order_search_parameters
+
+
+class SearchSpace:
+    """Search space for sampling.
+
+    While a SearchParameter represents a tree (leaf nodes are the possible choices),
+    SearchSpace represents multiple trees (each tree with its own possible choices).
+    Each intermediate node in the SearchSpace could itself be either a SearchParameter
+    or a SearchSpace, generating its own possible choices.
+
+    Note that the length in this context is the same as the number of possible unique choices.
+
+    An "index" represents the index across all of the possible permutations of the leaf nodes.
+    A result of indexing is deterministic only if the lengths of all the intermediate nodes
+    in the tree are constant.
+
+    Catch: Conditional search parameters don't have a constant length since the length depends on
+    the fixed values chosen so far. To circumvent the limitation, the maximum of all of
+    the possible lengths of the choices is used i.e. ignoring the parents. The default value
+    of the parameter is used to pad the suggestions. This will generate duplicate search points
+    which will be discarded if the default is invalid, or would early-out during evaluation from
+    the cache.
+
+    Indexing logic:
+        Consider the following search space,
+            [
+                ["A", ["a"]],
+                ["B", ["b", "c"]],
+                ["C", ["d", "e", "f"]]
+            ]
+
+        The length of this search space would be 6 i.e.
len(A) * len(B) * len(C) = 1 * 2 * 3 = 6.
+
+        Given the index, to compute the values of each search point at that index,
+        - for each param, p, in search space
+          - Index of choice for that parameter would be index % len(p)
+          - Update index to index / len(p)
+
+        Here's a list of all possible choices -
+            index | values
+              0   | {"A": [0, "a"], "B": [0, "b"], "C": [0, "d"]}
+              1   | {"A": [0, "a"], "B": [1, "c"], "C": [0, "d"]}
+              2   | {"A": [0, "a"], "B": [0, "b"], "C": [1, "e"]}
+              3   | {"A": [0, "a"], "B": [1, "c"], "C": [1, "e"]}
+              4   | {"A": [0, "a"], "B": [0, "b"], "C": [2, "f"]}
+              5   | {"A": [0, "a"], "B": [1, "c"], "C": [2, "f"]}
+
+        For index = 0,
+            At parameter "A", index = 0 % len(A) = 0, carry-forward = 0 / len(A) = 0 => [0, "a"]
+            At parameter "B", index = 0 % len(B) = 0, carry-forward = 0 / len(B) = 0 => [0, "b"]
+            At parameter "C", index = 0 % len(C) = 0, carry-forward = 0 / len(C) = 0 => [0, "d"]
+
+        For index = 1,
+            At parameter "A", index = 1 % len(A) = 0, carry-forward = 1 / len(A) = 1 => [0, "a"]
+            At parameter "B", index = 1 % len(B) = 1, carry-forward = 1 / len(B) = 0 => [1, "c"]
+            At parameter "C", index = 0 % len(C) = 0, carry-forward = 0 / len(C) = 0 => [0, "d"]
+
+        For index = 2,
+            At parameter "A", index = 2 % len(A) = 0, carry-forward = 2 / len(A) = 2 => [0, "a"]
+            At parameter "B", index = 2 % len(B) = 0, carry-forward = 2 / len(B) = 1 => [0, "b"]
+            At parameter "C", index = 1 % len(C) = 1, carry-forward = 1 / len(C) = 0 => [1, "e"]
+
+        For index = 3,
+            At parameter "A", index = 3 % len(A) = 0, carry-forward = 3 / len(A) = 3 => [0, "a"]
+            At parameter "B", index = 3 % len(B) = 1, carry-forward = 3 / len(B) = 1 => [1, "c"]
+            At parameter "C", index = 1 % len(C) = 1, carry-forward = 1 / len(C) = 0 => [1, "e"]
+
+        For index = 4,
+            At parameter "A", index = 4 % len(A) = 0, carry-forward = 4 / len(A) = 4 => [0, "a"]
+            At parameter "B", index = 4 % len(B) = 0, carry-forward = 4 / len(B) = 2 => [0, "b"]
+            At parameter "C", index = 2 % len(C) = 2, carry-forward = 2 / len(C) = 0 => [2, "f"]
+
+        For index = 5,
+            At parameter "A", index = 5 % len(A) = 0, carry-forward = 5 / len(A) = 5 => [0, "a"]
+            At parameter "B", index = 5 % len(B) = 1, carry-forward = 5 / len(B) = 2 => [1, "c"]
+            At parameter "C", index = 2 % len(C) = 2, carry-forward = 2 / len(C) = 0 => [2, "f"]
+
+        The logic can be extrapolated to any number of parameters each with any number of choices
+        as long as the number of choices remains constant at any given parameter.
+    """
+
+    def __init__(self, parameters: List[Tuple[str, Union[SearchParameter, "SearchSpace"]]]):
+        assert len(parameters) == len(
+            {name for name, _ in parameters}
+        ), "Parameter names in search space should be unique."
+
+        self._parameters = self._order_search_space(parameters)
+
+        # Consider the following basic search space scenarios -
+        #
+        # [["a"]]                  => len = 1
+        # [["a"], [1]]             => len = 1
+        # [["a", "b"], [1]]        => len = 2
+        # [["a", "b"], [1, 2]]     => len = 4
+        # [["a", "b"], [1, 2, 3]]  => len = 6
+        #
+        # Extrapolating to n parameters with Mi search points each, the total
+        # search points would be the product of number of search points in each.
+        # So, for a search space with n parameters, [M1, M2, M3, ... Mn],
+        # len(SearchSpace) = len(M1) * len(M2) * len(M3) * ....
* len(Mn) + + self._length = 1 + for _, param in self._parameters: + self._length *= SearchSpace.get_param_length(param) + + @property + def parameters(self) -> List[Tuple[str, Union[SearchParameter, "SearchSpace"]]]: + """Return the parameters of this search space.""" + return self._parameters + + def __repr__(self): + """Return the string representation of this search space.""" + return f"SearchSpace({self._parameters}, {self._length})" + + def __len__(self) -> int: + """Return the length i.e. total number of possible search points in this search space.""" + return self._length + + def __iter__(self) -> Generator[SearchPoint, None, None]: + """Iterate search points in this search space.""" + for index in range(self._length): + yield self[index] + + def __getitem__(self, index: int) -> SearchPoint: + """Return search point by index.""" + assert index < self._length + return SearchPoint(index, self.get_sample_point_values(index)) + + def _order_search_space(self, parameters: List[Tuple[str, SearchParameter]]) -> List[Tuple[str, SearchParameter]]: + """Order the search space by topological order of parameters for each pass_id/space_name.""" + unordered_nodes = dict(parameters) + ordered_nodes = order_search_parameters(unordered_nodes) + return [(name, unordered_nodes[name]) for name in ordered_nodes] + + def get_sample_point_values(self, index: int) -> Dict[str, Tuple[int, Any]]: + """Iterate parameters of this search space to generate a search point.""" + assert index < self._length + + values = OrderedDict() + for name, param in self._parameters: + index, values[name] = SearchSpace.get_suggestion(param, index, values) + return values + + @staticmethod + def get_param_length(param: Any) -> int: + """Return the length (computed recursively) of the input parameter.""" + if isinstance(param, SearchParameter): + if isinstance(param, Categorical): + return sum( + ( + SearchSpace.get_param_length(suggestion) + if isinstance(suggestion, (SearchParameter, SearchSpace)) + else 1 + ) + for suggestion in param.get_support() + ) + + elif isinstance(param, Conditional): + # For conditional search parameters, length is computed based on the support + # that has the most choices. See explanation above. + return max(SearchSpace.get_param_length(support) for support in param.support.values()) + + elif isinstance(param, SearchSpace): + return len(param) + + return 0 + + @staticmethod + def get_param_suggestions(param: Any, values: Dict[str, Any]) -> Union[List[Any], "SearchSpace"]: + """Return the suggestions for the input param based on the values chosen so far.""" + if isinstance(param, SearchParameter): + if isinstance(param, Categorical): + return param.get_support() + + elif isinstance(param, Conditional): + parent_values = {k: values[k][1] for k in param.parents} + suggestions = param.get_support_with_args(parent_values) + # Pad the suggestions to maximum length using the default value of the param. + max_length = max(SearchSpace.get_param_length(support) for support in param.support.values()) + suggestions += param.default.get_support() * (max_length - len(suggestions)) + return suggestions + + elif isinstance(param, SearchSpace): + return param + + return [] + + @staticmethod + def get_suggestion(param: Any, index: int, values: Dict[str, Any]) -> Tuple[int, Tuple[int, Any]]: + """Recursively, compute the values for the input param based on the index. + + Each entry is a tuple of the index in the list of suggestions for that param and the corresponding choice. 
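+
+        For example (illustrative): for a plain Categorical(["x", "y"]), index 1 yields
+        (0, (1, "y")), while index 3 wraps around to (1, (1, "y")), where the first
+        element is the carry-forward for the next parameter.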
+ """ + length = SearchSpace.get_param_length(param) + + if index < length: + if isinstance(param, SearchParameter): + suggestions = SearchSpace.get_param_suggestions(param, values) + + for i, suggestion in enumerate(suggestions): + if isinstance(suggestion, (SearchParameter, SearchSpace)): + suggestion_length = SearchSpace.get_param_length(suggestion) + if index < suggestion_length: + _, (_, i_suggestion) = SearchSpace.get_suggestion(suggestion, index, values) + return 0, (i, i_suggestion) + else: + index -= suggestion_length + elif index > 0: + index -= 1 + else: + return 0, (i, suggestion) + + elif isinstance(param, SearchSpace): + return 0, (index, param.get_sample_point_values(index)) + + else: + return index, param + + _, suggestion = SearchSpace.get_suggestion(param, index % length, values) + return index // length, suggestion diff --git a/olive/search/search_strategy.py b/olive/search/search_strategy.py new file mode 100644 index 000000000..38373dc7c --- /dev/null +++ b/olive/search/search_strategy.py @@ -0,0 +1,345 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import logging +import time +from copy import deepcopy +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple, Union + +from olive.common.config_utils import CaseInsensitiveEnum, ConfigBase, NestedConfig, validate_config, validate_enum +from olive.common.pydantic_v1 import validator +from olive.search.samplers import REGISTRY, SearchSampler +from olive.search.search_parameter import Categorical +from olive.search.search_results import SearchResults +from olive.search.search_sample import SearchSample +from olive.search.search_space import SearchSpace + +if TYPE_CHECKING: + from olive.evaluator.metric_result import MetricResult + from olive.search.search_parameter import SearchParameter + +logger = logging.getLogger(__name__) + +# ruff: noqa: PD011 + + +class SearchStrategyExecutionOrder(CaseInsensitiveEnum): + JOINT = "joint" + PASS_BY_PASS = "pass-by-pass" + + +class SearchStrategyConfig(NestedConfig): + _nested_field_name = "sampler_config" + execution_order: Union[str, SearchStrategyExecutionOrder] = None + sampler: str = None + sampler_config: ConfigBase = None + output_model_num: int = None + stop_when_goals_met: bool = False + max_iter: int = None + max_time: int = None + include_pass_params: bool = True + + @validator("execution_order", pre=True) + def _validate_execution_order(cls, v): + return validate_enum(SearchStrategyExecutionOrder, v) + + @validator("sampler", pre=True) + def _validate_sampler(cls, v): + if v not in REGISTRY: + raise ValueError(f"Unknown sampler: {v}") + return v + + @validator("sampler_config", pre=True, always=True) + def _validate_sampler_config(cls, v, values): + if "sampler" not in values: + raise ValueError("Invalid sampler") + + config_class = REGISTRY[values["sampler"]].get_config_class() + return validate_config(v, config_class) + + @validator("stop_when_goals_met", "max_iter", "max_time", pre=True) + def _validate_stop_when_goals_met(cls, v, values, field): + if "execution_order" not in values: + raise ValueError("Invalid execution_order") + if v and values["execution_order"] != SearchStrategyExecutionOrder.JOINT: + logger.info("%s is only supported for joint execution order. 
Ignoring...", field.name)
+            return field.default
+        return v
+
+
+@dataclass
+class SearchWalkState:
+    """A simple data class to hold the state while traversing the search space.
+
+    Each instance of this class holds data (sampler to use, results of evaluation, etc.)
+    for a single search space.
+    """
+
+    def __init__(self, path: List[int], model_ids: List[str], sampler: SearchSampler, results: SearchResults):
+        # Unique identification for the state.
+        self.path: List[int] = deepcopy(path)
+
+        # Sampler to use for the relevant/owning search space
+        self.sampler: SearchSampler = sampler
+
+        # Result of evaluating generated samples
+        self.results: SearchResults = results
+
+        # Input model ids to be used for processing the generated sample
+        self.model_ids: List[str] = model_ids
+
+        # Once the search space has exhausted all its samples, the results
+        # are sorted to find the order in which to move to the next search
+        # space. This is the index into those sorted results.
+        self.best_result_index = -1
+
+
+class SearchStrategy:
+    def __init__(self, config: Union[Dict[str, Any], SearchStrategyConfig]):
+        self.config: SearchStrategyConfig = validate_config(config, SearchStrategyConfig)
+
+        # Initialization variables
+        self._search_spaces: List[SearchSpace] = None
+        self._objectives: Dict[str, Dict[str, Any]] = None
+        self._init_model_id: str = None
+
+        # State variables
+        self._path: List[int] = None
+        self._state: Dict[Tuple, SearchWalkState] = None
+
+        # self._iteration_count and self._num_samples_suggested include counts across all search spaces.
+        # For specific counts, query the sampler corresponding to the specific search space
+        # i.e. SearchWalkState.sampler.num_samples_suggested.
+        # Also, note that the iteration count includes invalid search points, but num_samples_suggested doesn't.
+        # Invalid search points are automatically discarded during iteration.
+
+        self._start_time: float = 0
+        self._iteration_count: int = 0
+        self._num_samples_suggested: int = 0
+
+        self._initialized: bool = False
+
+    def initialize(
+        self,
+        space_config: Dict[str, List[Dict[str, "SearchParameter"]]],
+        init_model_id: str,
+        objectives: Dict[str, List[Dict[str, Dict[str, Any]]]],
+    ):
+        """Initialize the search strategy.
+
+        space_config: Ordered dictionary of format {pass_name: [{param_name: SearchParameter}]}
+        init_model_id: Input model id to use to start searching.
+        objectives: Dictionary of format {pass_name: [{objective_name: {"higher_is_better": bool, "goal": float}}]}
+
+        Depending on the execution order, we could generate either a single search space (for joint mode)
+        or multiple search spaces (for pass-by-pass mode). However, the logic in how we process these
+        search spaces does not differ.
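+
+        For illustration (hypothetical pass and parameter names), a joint-mode space_config
+        could look like:
+
+            {"quantize": [{"per_channel": Boolean(), "algorithm": Categorical(["a", "b"])}]}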
+        """
+        # for the first search group, init_model_id must be provided
+        if not init_model_id:
+            raise ValueError("init_model_id must be provided for search")
+
+        if self.config.execution_order == SearchStrategyExecutionOrder.JOINT:
+            self._search_spaces = [
+                SearchSpace(
+                    [
+                        (pass_name, Categorical([SearchSpace(list(params.items())) for params in passes]))
+                        for pass_name, passes in space_config.items()
+                    ]
+                )
+            ]
+        elif self.config.execution_order == SearchStrategyExecutionOrder.PASS_BY_PASS:
+            self._search_spaces = [
+                SearchSpace([(pass_name, Categorical([SearchSpace(list(params.items())) for params in passes]))])
+                for pass_name, passes in space_config.items()
+            ]
+        else:
+            raise ValueError(f"Unsupported execution order: {self.config.execution_order}")
+
+        if objectives:
+            for pass_name, pass_params in space_config.items():
+                pass_objectives = objectives.get(pass_name)
+                assert not pass_objectives or len(pass_objectives) == len(
+                    pass_params
+                ), "Expect none or all passes to have objectives"
+
+        self._objectives = objectives or {}
+        self._init_model_id = init_model_id
+
+        # Note that the state variables will be initialized at start of iteration.
+        self._initialized = True
+
+    @property
+    def search_spaces(self):
+        """Returns the list of search spaces."""
+        return self._search_spaces
+
+    @property
+    def iteration_count(self) -> int:
+        """Returns the number of iterations so far across all search spaces."""
+        return self._iteration_count
+
+    @property
+    def start_time(self) -> float:
+        """Returns the start time of current iteration."""
+        return self._start_time
+
+    @property
+    def elapsed_time(self) -> float:
+        """Returns elapsed time of the current iteration."""
+        return (time.time() - self._start_time) if self._start_time else 0
+
+    @property
+    def num_samples_suggested(self) -> int:
+        """Returns the number of samples suggested so far across all search spaces."""
+        return self._num_samples_suggested
+
+    @property
+    def max_samples(self) -> int:
+        """Returns the maximum number of samples."""
+        count = 1
+        for space in self._search_spaces:
+            count *= len(space)
+        return count
+
+    def __iter__(self) -> Generator[SearchSample, None, None]:
+        # Initialize the state variables
+        self._path = []
+        search_space = self._search_spaces[len(self._path)]
+        self._state = {
+            tuple(self._path): SearchWalkState(
+                self._path,
+                [self._init_model_id],
+                self._create_sampler(search_space),
+                self._create_results(),
+            )
+        }
+
+        self._start_time = time.time()
+        self._iteration_count = 0
+        self._num_samples_suggested = 0
+
+        while True:
+            state = self._state[tuple(self._path)]
+
+            while not state.sampler.should_stop:
+                self._iteration_count += 1
+                search_point = state.sampler.suggest()
+
+                # Discard invalid search points
+                if not search_point.is_valid():
+                    continue
+
+                self._num_samples_suggested += 1
+                yield SearchSample(search_point, state.model_ids)
+
+                # If this is the last pass in the walk, evaluate the model to see if all goals are met.
+                if (
+                    self.config.stop_when_goals_met
+                    and (len(self._path) == (len(self._search_spaces) - 1))
+                    and state.results.meets_goals(search_point.index)
+                ):
+                    return None
+
+            # Check if any of the global stop criteria are met.
+            # NOTE: Search will run at least one step before stopping.
+            if self.should_stop:
+                return None
+
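+            # For illustration: in pass-by-pass order with two passes, the loop above
+            # first exhausts the search space of pass one, sorts its results, and only
+            # then steps down to pass two, seeded with the best model found so far.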
+            # Try stepping down the tree, and if that fails, try stepping up. If both fail, we are done
+            if not self._step_down() and not self._step_up():
+                return None
+
+    def _create_sampler(self, search_space: SearchSpace) -> SearchSampler:
+        """Create a search sampler."""
+        if self.config.sampler not in REGISTRY:
+            raise ValueError(f"Unsupported search sampler: {self.config.sampler}")
+
+        return REGISTRY[self.config.sampler](search_space, self.config.sampler_config)
+
+    def _create_results(self) -> SearchResults:
+        """Create and return a search result."""
+        return SearchResults()
+
+    def _initialize_step(self) -> bool:
+        state = self._state[tuple(self._path)]
+
+        # NOTE: Two possible scenarios for pass-by-pass mode -
+        #   1. All search points are evaluated for each search space before moving down the tree.
+        #   2. Evaluate search points until we find a suitable candidate and move down. If failed at end, move
+        #      up to continue finding next candidate.
+        # Implementing option 1 currently i.e. all search points are evaluated for each search space before
+        # moving down. Logic here can be customized to support the other if need be.
+
+        # Get the next best result index
+        state.best_result_index, next_search_point, next_model_ids = state.results.get_next_best_result(
+            state.best_result_index
+        )
+        if state.best_result_index is not None:
+            self._path.append(next_search_point)
+            search_space = self._search_spaces[len(self._path)]
+            self._state[tuple(self._path)] = SearchWalkState(
+                self._path,
+                next_model_ids,
+                self._create_sampler(search_space),
+                self._create_results(),
+            )
+            return True
+
+        return False
+
+    def _step_up(self) -> bool:
+        """Step back to the previous search space on stack to evaluate based on the next best sample."""
+        if not self._path:
+            return False
+
+        self._path.pop()
+
+        # Results here are already sorted, don't have to do it again!
+        return self._initialize_step()
+
+    def _step_down(self) -> bool:
+        """Step down to the next search space in queue."""
+        if len(self._path) == (len(self._search_spaces) - 1):
+            return False
+
+        state = self._state[tuple(self._path)]
+
+        # Current state is potentially modified, sort the collected results again!
+        if not state.results.sort(apply_goals=True):
+            logger.warning(
+                "No models in path %s met the goals. 
Sorting the models without applying goals...", self._path + ) + state.results.sort(apply_goals=False) + + return self._initialize_step() + + def record_feedback_signal( + self, + search_point_index: int, + signal: "MetricResult", + model_ids: List[str], + should_prune: bool = False, + ): + """Record the feedback signal for the given search point.""" + assert self._initialized, "Search strategy is not initialized" + + state = self._state[tuple(self._path)] + search_space = self._search_spaces[len(self._path)] + search_point = search_space[search_point_index] + pass_name, _ = search_space.parameters[-1] + pass_index, _ = search_point.values[pass_name] + passes_objectives = self._objectives.get(pass_name, []) + objectives = passes_objectives[pass_index] if pass_index < len(passes_objectives) else {} + state.results.record_feedback_signal(search_point_index, objectives, signal, model_ids) + state.sampler.record_feedback_signal(search_point_index, objectives, signal, should_prune) + + @property + def should_stop(self): + """Check if the search should stop.""" + # NOTE: Individual goal criteria is checked at the end of each step in the iteration loop + return ((self.config.max_iter is not None) and (self._iteration_count > self.config.max_iter)) or ( + (self.config.max_time is not None) and (self.elapsed_time > self.config.max_time) + ) diff --git a/olive/strategy/utils.py b/olive/search/utils.py similarity index 97% rename from olive/strategy/utils.py rename to olive/search/utils.py index 4839b7701..2fe2a4efa 100644 --- a/olive/strategy/utils.py +++ b/olive/search/utils.py @@ -4,7 +4,7 @@ # -------------------------------------------------------------------------- from typing import Dict, List, Set, Tuple -from olive.strategy.search_parameter import Conditional, SearchParameter +from olive.search.search_parameter import Conditional, SearchParameter class DirectedGraph: diff --git a/olive/strategy/search_algorithm/__init__.py b/olive/strategy/search_algorithm/__init__.py deleted file mode 100644 index 8d965b1e2..000000000 --- a/olive/strategy/search_algorithm/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -from olive.strategy.search_algorithm.exhaustive import ExhaustiveSearchAlgorithm -from olive.strategy.search_algorithm.random_sampler import RandomSearchAlgorithm -from olive.strategy.search_algorithm.search_algorithm import SearchAlgorithm -from olive.strategy.search_algorithm.tpe_sampler import TPESearchAlgorithm - -REGISTRY = SearchAlgorithm.registry - -__all__ = ["REGISTRY", "ExhaustiveSearchAlgorithm", "RandomSearchAlgorithm", "SearchAlgorithm", "TPESearchAlgorithm"] diff --git a/olive/strategy/search_algorithm/exhaustive.py b/olive/strategy/search_algorithm/exhaustive.py deleted file mode 100644 index 508520d28..000000000 --- a/olive/strategy/search_algorithm/exhaustive.py +++ /dev/null @@ -1,33 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -from typing import Any, Dict - -from olive.strategy.search_algorithm.search_algorithm import SearchAlgorithm - - -class ExhaustiveSearchAlgorithm(SearchAlgorithm): - """Exhaustive Search Algorithm. 
Does a grid search over the search space.""" - - name = "exhaustive" - - @classmethod - def _default_config(cls): - return {} - - def initialize(self): - """Initialize the searcher.""" - # pylint: disable=attribute-defined-outside-init - self._iterator = self._search_space.iterate() - - def suggest(self) -> Dict[str, Dict[str, Any]]: - """Suggest a new configuration to try.""" - try: - return next(self._iterator) - except StopIteration: - return None - - def report(self, search_point: Dict[str, Dict[str, Any]], result: Dict[str, Any], should_prune: bool = False): - """Report the result of a configuration.""" - return diff --git a/olive/strategy/search_algorithm/optuna_sampler.py b/olive/strategy/search_algorithm/optuna_sampler.py deleted file mode 100644 index f969bf5a0..000000000 --- a/olive/strategy/search_algorithm/optuna_sampler.py +++ /dev/null @@ -1,101 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -from abc import abstractmethod -from typing import TYPE_CHECKING, Any, Dict, Tuple - -import optuna - -from olive.common.config_utils import ConfigParam -from olive.common.utils import hash_dict -from olive.strategy.search_algorithm.search_algorithm import SearchAlgorithm -from olive.strategy.search_parameter import Categorical, Conditional, SpecialParamValue - -if TYPE_CHECKING: - from olive.evaluator.metric_result import MetricResult - - -optuna.logging.set_verbosity(optuna.logging.WARNING) - - -class OptunaSearchAlgorithm(SearchAlgorithm): - """Optuna sampler for search algorithms.""" - - name = "optuna_sampler" - - @classmethod - def _default_config(cls) -> Dict[str, ConfigParam]: - return { - "num_samples": ConfigParam(type_=int, default_value=1, description="Number of samples to suggest."), - "seed": ConfigParam(type_=int, default_value=1, description="Seed for the rng."), - } - - def initialize(self): - # pylint: disable=attribute-defined-outside-init - """Initialize the searcher.""" - self._sampler = self._create_sampler() - directions = ["maximize" if higher_is_better else "minimize" for higher_is_better in self._higher_is_betters] - self._study = optuna.create_study(directions=directions, sampler=self._sampler) - self._trial_ids = {} - self._num_samples_suggested = 0 - - def should_stop(self): - return ( - (self._search_space.empty() and self._num_samples_suggested > 0) - or (self._num_samples_suggested >= self.config.num_samples) - or super().should_stop() - ) - - @abstractmethod - def _create_sampler(self) -> optuna.samplers.BaseSampler: - """Create the sampler.""" - - def suggest(self) -> Dict[str, Dict[str, Any]]: - """Suggest a new configuration to try.""" - if self.should_stop(): - return None - - trial, search_point, invalid = self._get_trial() - if invalid: - self._study.tell(trial.number, state=optuna.trial.TrialState.PRUNED) - return self.suggest() - - # save history - search_point_hash = hash_dict(search_point) - self._trial_ids[search_point_hash] = trial.number - - self._num_samples_suggested += 1 - - return search_point - - def _get_trial(self) -> Tuple[optuna.trial.Trial, Dict[str, Dict[str, Any]]]: - """Get a trial from the study.""" - trial = self._study.ask() - search_point = self._search_space.empty_search_point() - invalid_trial = False - for space_name, param_name, param in self._search_space.iter_params(): - if space_name not in search_point: - 
search_point[space_name] = {} - suggestion_name = f"{space_name}___{param_name}" - if isinstance(param, Categorical): - search_point[space_name][param_name] = trial.suggest_categorical(suggestion_name, param.get_support()) - elif isinstance(param, Conditional): - parent_vals = {parent: search_point[space_name][parent] for parent in param.parents} - options = param.get_support_with_args(parent_vals) - parent_vals_name = "_".join([f"{v}" for _, v in parent_vals.items()]) - suggestion_name = f"{space_name}___{param_name}___{parent_vals_name}" - search_point[space_name][param_name] = trial.suggest_categorical(suggestion_name, options) - else: - raise ValueError(f"Unsupported parameter type: {type(param)}") - invalid_trial = invalid_trial or (search_point[space_name][param_name] == SpecialParamValue.INVALID) - return trial, search_point, invalid_trial - - def report(self, search_point: Dict[str, Dict[str, Any]], result: "MetricResult", should_prune: bool = False): - search_point_hash = hash_dict(search_point) - trial_id = self._trial_ids[search_point_hash] - if should_prune: - self._study.tell(trial_id, state=optuna.trial.TrialState.PRUNED) - else: - objectives = [result[objective].value for objective in self._objectives] - self._study.tell(trial_id, objectives) diff --git a/olive/strategy/search_algorithm/random_sampler.py b/olive/strategy/search_algorithm/random_sampler.py deleted file mode 100644 index 0f4c66f4d..000000000 --- a/olive/strategy/search_algorithm/random_sampler.py +++ /dev/null @@ -1,58 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -from typing import Any, Dict - -from olive.common.config_utils import ConfigParam -from olive.strategy.search_algorithm.search_algorithm import SearchAlgorithm - - -class RandomSearchAlgorithm(SearchAlgorithm): - """Random Searcher. 
Samples random points from the search space with or without replacement.""" - - name = "random" - - @classmethod - def _default_config(cls) -> Dict[str, ConfigParam]: - return { - "num_samples": ConfigParam(type_=int, default_value=1, description="Number of samples to suggest."), - "seed": ConfigParam(type_=int, default_value=1, description="Seed for the rng."), - "with_replacement": ConfigParam(type_=bool, default_value=False, description="Sample with replacement."), - } - - def initialize(self): - """Initialize the searcher.""" - # pylint: disable=attribute-defined-outside-init - self._search_space.set_seed(self.config.seed) - if not self.config.with_replacement: - self._options = list(self._search_space.iterate()) - self._num_samples_suggested = 0 - - def should_stop(self): - should_stop = (self._search_space.empty() and self._num_samples_suggested > 0) or ( - self._num_samples_suggested >= self.config.num_samples - ) - if not self.config.with_replacement: - should_stop = should_stop or (len(self._options) == 0) - return should_stop or super().should_stop() - - def suggest(self) -> Dict[str, Dict[str, Any]]: - """Suggest a new configuration to try.""" - if self.should_stop(): - return None - - if self.config.with_replacement: - # sample a randrom point from the search space with replacement - search_point = self._search_space.random_sample() - else: - # sample a random point from the search space without replacement - search_point = self._search_space.rng.choice(self._options) - self._options.remove(search_point) - - self._num_samples_suggested += 1 - - return search_point - - def report(self, search_point: Dict[str, Dict[str, Any]], result: Dict[str, Any], should_prune: bool = False): - """Report the result of a configuration.""" diff --git a/olive/strategy/search_algorithm/search_algorithm.py b/olive/strategy/search_algorithm/search_algorithm.py deleted file mode 100644 index 6321755ec..000000000 --- a/olive/strategy/search_algorithm/search_algorithm.py +++ /dev/null @@ -1,59 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
-# -------------------------------------------------------------------------- -from abc import abstractmethod -from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Type, Union - -from olive.common.auto_config import AutoConfigClass -from olive.common.config_utils import ConfigBase -from olive.strategy.search_space import SearchSpace - -if TYPE_CHECKING: -    from olive.strategy.search_parameter import SearchParameter - - -class SearchAlgorithm(AutoConfigClass): -    """Abstract base class for searchers.""" - -    name: str = None -    registry: ClassVar[Dict[str, Type["SearchAlgorithm"]]] = {} - -    def __init__( -        self, -        search_space: Dict[str, Dict[str, "SearchParameter"]], -        objectives: Optional[List[str]] = None, -        higher_is_betters: Optional[List[bool]] = None, -        config: Optional[Union[Dict[str, Any], ConfigBase]] = None, -    ): -        # search space -        self._search_space = SearchSpace(search_space) -        if self._search_space.size() == 0: -            raise ValueError("There are no valid points in the search space.") - -        # objectives and directions -        objectives = objectives or [] -        higher_is_betters = higher_is_betters or [] -        assert len(objectives) == len(higher_is_betters), "Number of objectives must match number of higher_is_betters" -        self._objectives = objectives -        self._higher_is_betters = higher_is_betters - -        super().__init__(config) - -    @abstractmethod -    def initialize(self): -        """Initialize the searcher.""" - -    def should_stop(self): -        """Check if the search should stop.""" -        return False - -    @abstractmethod -    def suggest(self) -> Dict[str, Dict[str, Any]]: -        """Suggest a new configuration to try.""" - -    @abstractmethod -    def report( -        self, search_point: Dict[str, Dict[str, Any]], result: Dict[str, Union[float, int]], should_prune: bool = False -    ): -        """Report the result of a configuration.""" diff --git a/olive/strategy/search_results.py b/olive/strategy/search_results.py deleted file mode 100644 index cfb9b4fd7..000000000 --- a/olive/strategy/search_results.py +++ /dev/null @@ -1,131 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License.
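The removed base class reduces to a suggest/report contract that the engine drives in a loop. A minimal standalone rendition of that contract (hypothetical names, not the Olive classes):

from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

class Sampler(ABC):
    @abstractmethod
    def suggest(self) -> Optional[Dict[str, Any]]:
        """Return the next point to evaluate, or None when exhausted."""

    @abstractmethod
    def report(self, point: Dict[str, Any], result: Dict[str, float], should_prune: bool = False):
        """Feed an evaluation result back so adaptive samplers can learn from it."""

def drive(sampler, evaluate):
    # the engine-side loop: keep asking until the sampler runs dry
    while (point := sampler.suggest()) is not None:
        sampler.report(point, evaluate(point))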
-# -------------------------------------------------------------------------- -from copy import deepcopy -from typing import TYPE_CHECKING, Any, Dict, List, Tuple - -import numpy as np - -from olive.common.utils import hash_dict - -if TYPE_CHECKING: -    from olive.evaluator.metric_result import MetricResult - - -class SearchResults: -    def __init__( -        self, -        objective_dict: Dict[str, dict], -        init_model_history: Dict[str, Any] = None, -    ): -        self.objective_dict = objective_dict -        # objectives and directions of optimization -        self.objectives = list(objective_dict.keys()) -        self.higher_is_betters = [objective_dict[obj]["higher_is_better"] for obj in self.objectives] -        # multiplier for each objective -        self.obj_mul = {obj: 1 if hib else -1 for obj, hib in zip(self.objectives, self.higher_is_betters)} - -        # objective goal values -        self.goals = {} -        for name, obj in self.objective_dict.items(): -            if obj["goal"] is not None: -                self.goals[name] = obj["goal"] - -        # Record of the search path that led to the init model -        # Of the form {"search_point": ..., "result": ..., "model_ids": ...} -        self.init_model_history = init_model_history - -        # search results state -        self.search_point_hash_table = {} -        self.results = {} -        self.model_ids = {} - -    def record(self, search_point: Dict[str, Dict[str, Any]], result: "MetricResult", model_ids: List[str]): -        """Record the result of a configuration.""" -        search_point_hash = hash_dict(search_point) -        self.search_point_hash_table[search_point_hash] = deepcopy(search_point) -        self.results[search_point_hash] = deepcopy(result) -        self.model_ids[search_point_hash] = model_ids - -    def check_goals(self, result: "MetricResult") -> bool: -        """Check if the result satisfies the constraints.""" -        # if goals are not set, return True always -        if not self.goals: -            return True - -        for obj, goal in self.goals.items(): -            if self.obj_mul[obj] * result[obj].value < self.obj_mul[obj] * goal: -                return False -        return True - -    def sort_search_points(self, objectives: List[str] = None, apply_goals: bool = False) -> List[str]: -        """Return the search points sorted by the objectives.""" -        # TODO(trajep): this function only works for pass-by-pass execution order, but only returns the first model -        # with the best latency result, which is not correct. One pass may generate multiple models. -        if objectives is None: -            objectives = self.objectives -        else: -            assert set(objectives).issubset(self.objectives) - -        results, search_point_hashes = self._get_results_list(objectives, apply_goals) -        if not results: -            return None, None, None - -        # sort by objectives, leftmost objective has the highest priority -        # flip the order of the objectives since np.lexsort prioritizes the last column -        # negate the results since np.lexsort sorts in ascending order -        results = -np.flip(np.array(results), 1) -        sorted_indices = np.lexsort(results.T) -        sorted_hashes = [search_point_hashes[i] for i in sorted_indices] - -        # get model ids -        sorted_model_ids = [self.model_ids[point_hash] for point_hash in sorted_hashes] -        sorted_results = [self.results[point_hash] for point_hash in sorted_hashes] -        # TODO(jambayk): this will be done using helper later -        sorted_search_points = [self.search_point_hash_table[point_hash] for point_hash in sorted_hashes] -        return sorted_model_ids, sorted_search_points, sorted_results - -    def _get_results_list( -        self, objectives: List[str] = None, apply_goals: bool = False -    ) -> Tuple[List[List[float]], List[str]]: -        """Return the results as a list of lists.
- - Values are multiplied by the objective multiplier so that higher is better for all objectives. - """ - if objectives is None: - objectives = self.objectives - else: - assert set(objectives).issubset(self.objectives) - - search_point_hashes = [] - results = [] - for search_point_hash in self.results: - result = self.results[search_point_hash] - if not result: - continue - if apply_goals and not self.check_goals(result): - continue - search_point_hashes.append(search_point_hash) - results.append([self.obj_mul[obj] * result[obj].value for obj in objectives]) - - return results, search_point_hashes - - def to_json(self): - """Return a json representation of the search results.""" - return { - "objective_dict": self.objective_dict, - "init_model_history": self.init_model_history, - "results": self.results, - "model_ids": self.model_ids, - "search_point_hash_table": self.search_point_hash_table, - } - - @classmethod - def from_json(cls, json_dict): - """Create a SearchResults object from a json representation.""" - search_results = cls(json_dict["objective_dict"], json_dict["init_model_history"]) - search_results.search_point_hash_table = json_dict["search_point_hash_table"] - search_results.results = json_dict["results"] - search_results.model_ids = json_dict["model_ids"] - return search_results diff --git a/olive/strategy/search_space.py b/olive/strategy/search_space.py deleted file mode 100644 index e29d26e59..000000000 --- a/olive/strategy/search_space.py +++ /dev/null @@ -1,111 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -from copy import deepcopy -from random import Random -from typing import Any, Dict, Iterator, List, Optional, Tuple - -from olive.strategy.search_parameter import Categorical, Conditional, SearchParameter, SpecialParamValue -from olive.strategy.utils import order_search_parameters - - -class SearchSpace: - """Search space for a search algorithm.""" - - def __init__(self, search_space: Dict[str, Dict[str, SearchParameter]], seed: Optional[int] = 1): - # search_space is dictionary of format: {"pass_id/space_name": {"param_name": SearchParameter} - self._search_space = deepcopy(search_space) - self._iter_order = self._order_search_space(self._search_space) - self._empty_search_point = {space_name: {} for space_name in self._search_space} - self._seed = seed - self.rng = Random(self._seed) - - def _order_search_space(self, search_space) -> List[Tuple[str, str]]: - """Order the search space by topological order of parameters for each pass_id/space_name.""" - full_iter_order = [] - for space_name, space_item in search_space.items(): - iter_order = order_search_parameters(space_item) - full_iter_order.extend([(space_name, param_name) for param_name in iter_order]) - return full_iter_order - - def set_seed(self, seed: int): - """Set the random seed for the search space.""" - self._seed = seed - self.reset_rng() - - def reset_rng(self): - """Reset the random number generator.""" - self.rng = Random(self._seed) - - def random_sample(self) -> Dict[str, Dict[str, Any]]: - """Sample a random configuration from the search space.""" - # initialize search point - search_point = deepcopy(self._empty_search_point) - - # sample from search space - for space_name, param_name in self._iter_order: - param = self._search_space[space_name][param_name] - options = [] - if 
isinstance(param, Conditional): - parent_vals = {parent: search_point[space_name][parent] for parent in param.parents} - options = param.get_support_with_args(parent_vals) - elif isinstance(param, Categorical): - options = param.get_support() - search_point[space_name][param_name] = self.rng.choice(options) - if search_point[space_name][param_name] == SpecialParamValue.INVALID: - return self.random_sample() - - return search_point - - def _iterate_util( - self, full_iter_order: List[Tuple[str, str]], search_point: Dict[str, Dict[str, Any]], index: int - ) -> Iterator[Dict[str, Dict[str, Any]]]: - if index == len(full_iter_order): - yield deepcopy(search_point) - return - - space_name, param_name = full_iter_order[index] - param = self._search_space[space_name][param_name] - - if isinstance(param, Conditional): - parent_vals = {parent: search_point[space_name][parent] for parent in param.parents} - options = param.get_support_with_args(parent_vals) - elif isinstance(param, Categorical): - options = param.get_support() - else: - return - - for option in options: - if option == SpecialParamValue.INVALID: - continue - search_point[space_name][param_name] = option - yield from self._iterate_util(full_iter_order, search_point, index + 1) - - def iterate(self) -> Iterator[Dict[str, Dict[str, Any]]]: - """Iterate over all possible configurations in the search space.""" - # initialize search point - search_point = deepcopy(self._empty_search_point) - - # iterate over search space - yield from self._iterate_util(self._iter_order, search_point, 0) - - def empty(self) -> bool: - """Check if the search space is empty.""" - return all(not v for v in self._search_space.values()) - - def size(self) -> int: - """Get the size of the search space.""" - size = 0 - for _ in self.iterate(): - size += 1 - return size - - def empty_search_point(self) -> Dict[str, Dict[str, Any]]: - """Get an empty search point.""" - return deepcopy(self._empty_search_point) - - def iter_params(self) -> Iterator[Tuple[str, str, SearchParameter]]: - """Iterate over the search parameters in topological order.""" - for space_name, param_name in self._iter_order: - yield space_name, param_name, self._search_space[space_name][param_name] diff --git a/olive/strategy/search_strategy.py b/olive/strategy/search_strategy.py deleted file mode 100644 index 677e8a81d..000000000 --- a/olive/strategy/search_strategy.py +++ /dev/null @@ -1,282 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
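The removed SearchSpace.iterate() is a depth-first walk over parameters in topological order, where a conditional parameter's options depend on already-fixed parents and invalid branches are pruned. A small standalone sketch with a hypothetical two-parameter space:

INVALID = object()  # stand-in for SpecialParamValue.INVALID

# hypothetical space: activation_type's options depend on the chosen quant_format
SPACE = {
    "quant_format": lambda point: ["QOperator", "QDQ"],
    "activation_type": lambda point: (
        ["QUInt8", INVALID] if point["quant_format"] == "QOperator" else ["QInt8", "QUInt8"]
    ),
}

def iterate(names, point=None):
    """Depth-first enumeration in topological order, pruning INVALID branches."""
    point = dict(point or {})
    if not names:
        yield point
        return
    name, rest = names[0], names[1:]
    for option in SPACE[name](point):
        if option is INVALID:
            continue  # an invalid combination is skipped, not yielded
        yield from iterate(rest, {**point, name: option})

# list(iterate(list(SPACE))) -> 3 valid points in a deterministic order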
-# -------------------------------------------------------------------------- -import logging -from copy import deepcopy -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union - -from olive.common.config_utils import ConfigBase, NestedConfig, validate_config -from olive.common.pydantic_v1 import validator -from olive.strategy.search_algorithm import REGISTRY, SearchAlgorithm -from olive.strategy.search_results import SearchResults - -if TYPE_CHECKING: - from olive.evaluator.metric_result import MetricResult - from olive.strategy.search_parameter import SearchParameter - -logger = logging.getLogger(__name__) - -_VALID_EXECUTION_ORDERS = ("joint", "pass-by-pass") - -# pylint: disable=attribute-defined-outside-init - - -class SearchStrategyConfig(NestedConfig): - _nested_field_name = "search_algorithm_config" - execution_order: str - search_algorithm: str - search_algorithm_config: ConfigBase = None - output_model_num: int = None - stop_when_goals_met: bool = False - max_iter: int = None - max_time: int = None - - @validator("execution_order", pre=True) - def _validate_execution_order(cls, v): - if v not in _VALID_EXECUTION_ORDERS: - raise ValueError(f"Unknown execution order: {v}") - return v - - @validator("search_algorithm", pre=True) - def _validate_search_algorithm(cls, v): - if v not in REGISTRY: - raise ValueError(f"Unknown search algorithm: {v}") - return v - - @validator("search_algorithm_config", pre=True, always=True) - def _validate_search_algorithm_config(cls, v, values): - if "search_algorithm" not in values: - raise ValueError("Invalid search_algorithm") - - config_class = REGISTRY[values["search_algorithm"]].get_config_class() - return validate_config(v, config_class) - - @validator("stop_when_goals_met", "max_iter", "max_time", pre=True) - def _validate_stop_when_goals_met(cls, v, values, field): - if "execution_order" not in values: - raise ValueError("Invalid execution_order") - if v and values["execution_order"] != "joint": - logger.info("%s is only supported for joint execution order. Ignoring...", field.name) - return field.default - return v - - -class SearchStrategy: - def __init__(self, config: Union[Dict[str, Any], SearchStrategyConfig]): - self._config = validate_config(config, SearchStrategyConfig) - self._initialized = False - self.exit_criteria_met = False - - def initialize( - self, - pass_flows_search_spaces: List[List[Tuple[str, Dict[str, "SearchParameter"]]]], - init_model_id: str, - objective_dict: Dict[str, dict], - ): - """Initialize the search strategy. 
- - pass_flows_search_spaces: list of list of tuples of format (search_space_name, {param_name: SearchParameter}) - objective_dict: dictionary of format {objective_name: {"higher_is_better": bool, "goal": float}} - """ - self._objective_dict = objective_dict - - # search spaces - self._spaces_order = [[pass_ss[0] for pass_ss in pass_flow_ss] for pass_flow_ss in pass_flows_search_spaces] - self._spaces_dict = {} - for pass_flow_ss in pass_flows_search_spaces: - for pass_ss in pass_flow_ss: - self._spaces_dict[pass_ss[0]] = pass_ss[1] - - # search space dictionaries for pass are grouped based on execution_order - self._spaces_groups = self._group_search_spaces(self._spaces_order) - # sub spaces group in pass-by-pass execution order - self._pass_by_pass_sg = None - - self._done_spaces_groups = [] - self._active_spaces_group = None - - # state - self._searchers: Dict[Any, SearchAlgorithm] = {} - self._search_results: Dict[Any, SearchResults] = {} - self._init_model_ids: Dict[Any, str] = {} - self.init_model_id = init_model_id - self._best_search_points = {} - - # initialize the first search space - self._next_search_group(init_model_id) - - self._initialized = True - - def _group_search_spaces(self, search_space_names: List[List]): - """Group search spaces based on execution order.""" - # joint: all passes grouped together - # pass-by-pass: each pass is a separate group - if self._config.execution_order == "joint": - search_spaces_groups = search_space_names - elif self._config.execution_order == "pass-by-pass": - # run pass-by-pass for each pass flow which is defined as a list of registered passes - search_spaces_groups = [] - for pass_flow_ss in search_space_names: - pass_flow_groups = [[pass_ss] for pass_ss in pass_flow_ss] - search_spaces_groups.append(pass_flow_groups) - else: - raise ValueError(f"Unknown execution order: {self._config.execution_order}") - - return search_spaces_groups - - def _next_search_group(self, init_model_id: Optional[str] = None) -> Optional[SearchAlgorithm]: - """Get the next search space group and initialize the search algorithm.""" - # if there is no more search space group, return None - # 1. joint: no more flows(self._space_groups) - # 2. 
pass-by-pass: no more flows(self._space_groups) and no more passes(self._pass_by_pass_sg) -        if not (self._spaces_groups or self._pass_by_pass_sg): -            self._active_spaces_group = None -            return None - -        # for the first search group, init_model_id must be provided -        if init_model_id is None and self._active_spaces_group is None: -            raise ValueError("init_model_id must be provided for the first search group") - -        if self._config.execution_order == "joint": -            next_sg = self._next_search_group_joint(init_model_id) -        elif self._config.execution_order == "pass-by-pass": -            next_sg = self._next_search_group_pass_by_pass(init_model_id) -        else: -            raise ValueError(f"Invalid execution order {self._config.execution_order}") -        return next_sg - -    def _next_search_group_pass_by_pass(self, init_model_id: Optional[str] = None) -> Optional[SearchAlgorithm]: -        # passes are exhausted or empty for current flow, try next pass flow -        if not self._pass_by_pass_sg: -            self._pass_by_pass_sg = self._spaces_groups.pop(0) -            self._active_spaces_group = None -            init_model_id = self.init_model_id - -        # get the best model from last space group -        if self._active_spaces_group is not None: -            self._done_spaces_groups.append(self._active_spaces_group) -            # legacy, will update once search results has info function -            sorted_model_ids, sorted_search_points, sorted_results = self._search_results[ -                tuple(self._active_spaces_group) -            ].sort_search_points(apply_goals=True) -            if sorted_model_ids is None: -                logger.warning( -                    "No models in this search group %s met the goals. Sorting the models without applying goals...", -                    self._active_spaces_group, -                ) -                sorted_model_ids, sorted_search_points, sorted_results = self._search_results[ -                    tuple(self._active_spaces_group) -                ].sort_search_points(apply_goals=False) -            # TODO(trajep): this is a hack to get the best search point for the current search space group -            # it totally works for the joint execution order, but not for pass-by-pass -            if sorted_search_points and sorted_results: -                best_search_point = ( -                    sorted_search_points[0], -                    list(sorted_results[0].values()), -                    sorted_model_ids[0], -                ) -                self._best_search_points[tuple(self._active_spaces_group)] = best_search_point -                init_model_id = best_search_point[2][-1] - -        if init_model_id is None and self._active_spaces_group is not None: -            raise ValueError( -                f"The previous search group {self._active_spaces_group} has no output models that were created and" -                " evaluated successfully. Cannot continue."
-            ) - -        # set up next search group -        # if it is the first run in this flow, init_model_id should be input model id -        # otherwise, it should be the best model id from last search group -        self._active_spaces_group = self._pass_by_pass_sg.pop(0) -        self._searchers[tuple(self._active_spaces_group)] = self._create_searcher(self._active_spaces_group) -        self._search_results[tuple(self._active_spaces_group)] = SearchResults(self._objective_dict) -        self._init_model_ids[tuple(self._active_spaces_group)] = init_model_id -        return self._active_spaces_group - -    def _next_search_group_joint(self, init_model_id: Optional[str] = None) -> Optional[SearchAlgorithm]: -        init_model_id = init_model_id or self.init_model_id -        # get the first pass flow -        # for "joint" mode, init_model_id should be input_model_id -        sg = self._spaces_groups.pop(0) -        self._searchers[tuple(sg)] = self._create_searcher(sg) -        self._search_results[tuple(sg)] = SearchResults(self._objective_dict) -        self._init_model_ids[tuple(sg)] = init_model_id -        self._active_spaces_group = sg -        return self._active_spaces_group - -    def _create_searcher(self, search_space_names: List[str]) -> SearchAlgorithm: -        """Create a search algorithm.""" -        search_spaces_dict = {space_name: deepcopy(self._spaces_dict[space_name]) for space_name in search_space_names} -        objectives = list(self._objective_dict.keys()) -        higher_is_betters = [self._objective_dict[objective]["higher_is_better"] for objective in objectives] -        if self._config.search_algorithm in REGISTRY: -            searcher = REGISTRY[self._config.search_algorithm]( -                search_spaces_dict, objectives, higher_is_betters, self._config.search_algorithm_config -            ) -            searcher.initialize() -        else: -            raise ValueError(f"Unknown search algorithm: {self._config.search_algorithm}") -        return searcher - -    def next_step(self) -> Optional[Dict[str, Any]]: -        """Get the next step in the search.""" -        if not self._initialized: -            raise ValueError("Search strategy is not initialized") - -        if self.exit_criteria_met: -            self._next_search_group() - -        # if there is no active searcher, we are done -        if self._active_spaces_group is None: -            return None - -        # get the next search point from the active searcher -        search_point = self._searchers[tuple(self._active_spaces_group)].suggest() -        # if there are no more search points, move to the next search space group -        if search_point is None: -            self._next_search_group() -            return self.next_step() - -        return { -            "search_point": search_point, -            "model_id": self._init_model_ids[tuple(self._active_spaces_group)], -            "passes": [(space_name, search_point[space_name]) for space_name in self._active_spaces_group], -        } - -    def record_feedback_signal( -        self, -        search_point: Dict[str, Dict[str, Any]], -        signal: "MetricResult", -        model_ids: List[str], -        should_prune: bool = False, -    ): -        """Record the feedback signal for the given search point.""" -        if not self._initialized: -            raise ValueError("Search strategy is not initialized") -        self._search_results[tuple(self._active_spaces_group)].record(search_point, signal, model_ids) -        self._searchers[tuple(self._active_spaces_group)].report(search_point, signal, should_prune) - -    def check_exit_criteria(self, iter_num, time_diff, metric_signal): -        """Check if the olive search_strategy should exit.""" -        self.exit_criteria_met = False -        if not self._config.stop_when_goals_met: -            # early stopping is disabled when stop_when_goals_met is False; skip the remaining checks -            return -        # early exit is not supported for pass-by-pass execution order currently -        if
self._config.execution_order == "pass-by-pass": - return - if self._config.max_iter is not None and iter_num > self._config.max_iter: - self.exit_criteria_met = True - return - if self._config.max_time is not None and time_diff > self._config.max_time: - self.exit_criteria_met = True - return - if metric_signal == {}: - return - self.exit_criteria_met = self._config.stop_when_goals_met and self._search_results[ - tuple(self._active_spaces_group) - ].check_goals(metric_signal) - - def get_output_model_num(self): - return self._config.output_model_num diff --git a/olive/systems/azureml/aml_system.py b/olive/systems/azureml/aml_system.py index a286dbbf2..76634908d 100644 --- a/olive/systems/azureml/aml_system.py +++ b/olive/systems/azureml/aml_system.py @@ -243,12 +243,7 @@ def _assert_not_none(self, obj): if obj is None: raise ValueError(f"{obj.__class__.__name__} is missing in the inputs!") - def run_pass( - self, - the_pass: "Pass", - model_config: ModelConfig, - output_model_path: str, - ) -> ModelConfig: + def run_pass(self, the_pass: "Pass", model_config: ModelConfig, output_model_path: str) -> ModelConfig: """Run the pass on the model.""" ml_client = self.azureml_client_config.create_client() diff --git a/olive/systems/docker/docker_system.py b/olive/systems/docker/docker_system.py index 29560fe41..47d84d0cf 100644 --- a/olive/systems/docker/docker_system.py +++ b/olive/systems/docker/docker_system.py @@ -103,22 +103,13 @@ def __init__( _print_docker_logs(e.build_log, logging.ERROR) raise - def run_pass( - self, - the_pass: "Pass", - model_config: "ModelConfig", - output_model_path: str, - ) -> "ModelConfig": + def run_pass(self, the_pass: "Pass", model_config: "ModelConfig", output_model_path: str) -> "ModelConfig": """Run the pass on the model.""" with tempfile.TemporaryDirectory() as tempdir: return self._run_pass_container(Path(tempdir), the_pass, model_config, output_model_path) def _run_pass_container( - self, - workdir: Path, - the_pass: "Pass", - model_config: "ModelConfig", - output_model_path: str, + self, workdir: Path, the_pass: "Pass", model_config: "ModelConfig", output_model_path: str ) -> "ModelConfig": pass_config = the_pass.to_json(check_object=True) diff --git a/olive/systems/olive_system.py b/olive/systems/olive_system.py index 7a39a1b90..6c3339d9b 100644 --- a/olive/systems/olive_system.py +++ b/olive/systems/olive_system.py @@ -41,12 +41,7 @@ def __init__( self.hf_token = hf_token @abstractmethod - def run_pass( - self, - the_pass: "Pass", - model_config: "ModelConfig", - output_model_path: str, - ) -> "ModelConfig": + def run_pass(self, the_pass: "Pass", model_config: "ModelConfig", output_model_path: str) -> "ModelConfig": """Run the pass on the model at a specific point in the search space.""" raise NotImplementedError diff --git a/olive/workflows/run/config.py b/olive/workflows/run/config.py index 5265e58e0..50b7c9dd6 100644 --- a/olive/workflows/run/config.py +++ b/olive/workflows/run/config.py @@ -15,50 +15,16 @@ from olive.data.config import DataComponentConfig, DataConfig from olive.data.container.dummy_data_container import TRANSFORMER_DUMMY_DATA_CONTAINER from olive.data.container.huggingface_container import HuggingfaceContainer -from olive.engine import Engine, EngineConfig +from olive.engine import Engine +from olive.engine.config import EngineConfig, RunPassConfig from olive.engine.packaging.packaging_config import PackagingConfig from olive.evaluator.olive_evaluator import OliveEvaluatorConfig from olive.model import ModelConfig -from 
olive.passes.pass_config import AbstractPassConfig, PassParamDefault +from olive.passes.pass_config import PassParamDefault from olive.resource_path import AZUREML_RESOURCE_TYPES from olive.systems.system_config import SystemConfig -class RunPassConfig(AbstractPassConfig): -    """Pass configuration for Olive workflow. - -    This is the configuration for a single pass in Olive workflow. It includes configurations for pass type, config, -    etc. - -    Example: -    .. code-block:: json - -        { -            "type": "OlivePass", -            "config": { -                "param1": "value1", -                "param2": "value2" -            } -        } - -    """ - -    host: Union[SystemConfig, str] = Field( -        None, -        description=( -            "Host system for the pass. If it is a string, must refer to a system config under `systems` section. If not" -            " provided, use the engine's host system." -        ), -    ) -    evaluator: Union[OliveEvaluatorConfig, str] = Field( -        None, -        description=( -            "Evaluator for the pass. If it is a string, must refer to an evaluator config under `evaluators` section." -            " If not provided, use the engine's evaluator." -        ), -    ) - - class RunEngineConfig(EngineConfig): evaluate_input_model: bool = True output_dir: Union[Path, str] = None @@ -138,15 +104,7 @@ class RunConfig(NestedConfig): " no-search or auto-optimizer mode based on whether passes field is provided." ), ) -    passes: Dict[str, RunPassConfig] = Field(default_factory=dict, description="Pass configurations.") -    pass_flows: List[List[str]] = Field( -        None, -        description=( -            "Pass flows. Each member must be a list of pass names from `passes` field. If provided," -            " each flow will be run sequentially. If not provided, all passes will be run as a single flow in the order" -            " of `passes` field." -        ), -    ) +    passes: Dict[str, List[RunPassConfig]] = Field(None, description="Pass configurations.") auto_optimizer_config: AutoOptimizerConfig = Field( default_factory=AutoOptimizerConfig, description="Auto optimizer configuration. Only valid when passes field is empty or not provided.", ) @@ -155,11 +113,25 @@ None, description="Workflow host. None by default. If provided, the workflow will be run on the specified host." ) +    @root_validator(pre=True) +    def patch_evaluators(cls, values): +        if "evaluators" in values: +            for name, evaluator_config in values["evaluators"].items(): +                evaluator_config["name"] = name +        return values + +    @root_validator(pre=True) +    def patch_passes(cls, values): +        if "passes" in values: +            for name, passes_config in values["passes"].items(): +                if isinstance(passes_config, dict): +                    values["passes"][name] = [passes_config] +        return values + @root_validator(pre=True) def insert_azureml_client(cls, values): values = convert_configs_to_dicts(values) _insert_azureml_client(values, values.get("azureml_client")) -        return values @validator("data_configs", pre=True) @@ -237,43 +209,45 @@ def validate_evaluators(cls, v, values): def validate_engine(cls, v, values): v = _resolve_system(v, values, "host") v = _resolve_system(v, values, "target") +        if v.get("search_strategy") and not v.get("evaluator"): +            raise ValueError( +                "Can't search without a valid evaluator config. " +                "Either provide a valid evaluator config or disable search."
+ ) return _resolve_evaluator(v, values) @validator("passes", pre=True, each_item=True) def validate_pass_host_evaluator(cls, v, values): - v = _resolve_system(v, values, "host") - return _resolve_evaluator(v, values) + for i, _ in enumerate(v): + v[i] = _resolve_system(v[i], values, "host") + v[i] = _resolve_evaluator(v[i], values) + return v @validator("passes", pre=True, each_item=True) def validate_pass_search(cls, v, values): if "engine" not in values: raise ValueError("Invalid engine") - # validate first to gather config params - v = validate_config(v, RunPassConfig).dict() - - if not v.get("config"): - return v + for i, _ in enumerate(v): + # validate first to gather config params + v[i] = iv = validate_config(v[i], RunPassConfig).dict() - searchable_configs = set() - for param_name in v["config"]: - if v["config"][param_name] == PassParamDefault.SEARCHABLE_VALUES: - searchable_configs.add(param_name) + if iv.get("config"): + _resolve_all_data_configs(iv["config"], values) - resolve_all_data_configs(v["config"], values) + searchable_configs = set() + for param_name in iv["config"]: + if iv["config"][param_name] == PassParamDefault.SEARCHABLE_VALUES: + searchable_configs.add(param_name) - if not values["engine"].search_strategy and searchable_configs: - raise ValueError( - f"You cannot disable search for {v['type']} and" - f" set {searchable_configs} to SEARCHABLE_VALUES at the same time." - " Please remove SEARCHABLE_VALUES or enable search(needs search strategy configs)." - ) + if not values["engine"].search_strategy and searchable_configs: + raise ValueError( + f"You cannot disable search for {iv['type']} and" + f" set {searchable_configs} to SEARCHABLE_VALUES at the same time." + " Please remove SEARCHABLE_VALUES or enable search (needs search strategy configs)." 
+ ) return v - @validator("pass_flows", pre=True) - def validate_pass_flows(cls, v, values): - return v or [] - @validator("workflow_host", pre=True) def validate_workflow_host(cls, v, values): if v is None: @@ -281,15 +255,17 @@ def validate_workflow_host(cls, v, values): return _resolve_config(values, v) -def resolve_all_data_configs(config, values): +def _resolve_all_data_configs(config, values): """Recursively traverse the config dictionary to resolve all 'data_config' keys.""" - for param_name, param_value in config.items(): - if param_name.endswith("data_config"): - _resolve_data_config(config, values, param_name) - continue - - if isinstance(param_value, dict): - resolve_all_data_configs(param_value, values) + if isinstance(config, dict): + for param_name, param_value in config.items(): + if param_name.endswith("data_config"): + _resolve_data_config(config, values, param_name) + else: + _resolve_all_data_configs(param_value, values) + elif isinstance(config, list): + for element in config: + _resolve_all_data_configs(element, values) def _insert_azureml_client(config, azureml_client): diff --git a/olive/workflows/run/run.py b/olive/workflows/run/run.py index 67c0b691b..4f339160c 100644 --- a/olive/workflows/run/run.py +++ b/olive/workflows/run/run.py @@ -8,14 +8,17 @@ import sys from copy import deepcopy from pathlib import Path -from typing import Generator, List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Union from olive.common.utils import set_tempdir from olive.logging import set_default_logger_severity, set_ort_logger_severity, set_verbosity_info from olive.package_config import OlivePackageConfig from olive.systems.accelerator_creator import create_accelerators from olive.systems.common import SystemType -from olive.workflows.run.config import RunConfig, RunPassConfig +from olive.workflows.run.config import RunConfig + +if TYPE_CHECKING: + from olive.engine.config import RunPassConfig logger = logging.getLogger(__name__) @@ -58,18 +61,19 @@ def get_pass_extras(pass_type): # add dependencies for passes if run_config.passes: - for pass_config in run_config.passes.values(): - host = pass_config.host or run_config.engine.host - if (host and host.type == SystemType.Local) or not host: - local_packages.extend(get_pass_extras(pass_config.type)) - else: - remote_packages.extend(get_pass_extras(pass_config.type)) - if pass_config.type in ["SNPEConversion", "SNPEQuantization", "SNPEtoONNXConversion"]: - logger.info( - "Please refer to https://microsoft.github.io/Olive/tutorials/passes/snpe.html to install SNPE" - " prerequisites for pass %s", - pass_config.type, - ) + for passes_configs in run_config.passes.values(): + for pass_config in passes_configs: + host = pass_config.host or run_config.engine.host + if (host and host.type == SystemType.Local) or not host: + local_packages.extend(get_pass_extras(pass_config.type)) + else: + remote_packages.extend(get_pass_extras(pass_config.type)) + if pass_config.type in ["SNPEConversion", "SNPEQuantization", "SNPEtoONNXConversion"]: + logger.info( + "Please refer to https://microsoft.github.io/Olive/tutorials/passes/snpe.html to install SNPE" + " prerequisites for pass %s", + pass_config.type, + ) # add dependencies for engine host_type = None @@ -126,8 +130,11 @@ def get_pass_module_path(pass_type: str, package_config: OlivePackageConfig) -> def is_execution_provider_required(run_config: RunConfig, package_config: OlivePackageConfig) -> bool: - passes = get_used_passes(run_config) - return 
any(get_pass_module_path(p.type, package_config).startswith("olive.passes.onnx") for p in passes) + return any( + get_pass_module_path(p.type, package_config).startswith("olive.passes.onnx") + for passes_configs in run_config.passes.values() + for p in passes_configs + ) def run_engine(package_config: OlivePackageConfig, run_config: RunConfig): @@ -164,48 +171,31 @@ def run_engine(package_config: OlivePackageConfig, run_config: RunConfig): and run_config.auto_optimizer_config is not None and not run_config.auto_optimizer_config.disable_auto_optimizer ) - is_ep_required = auto_optimizer_enabled or is_execution_provider_required(run_config, package_config) - - # Register passes since we need to know whether they need to run on target - used_passes = list(get_used_passes(run_config)) - for pass_config in used_passes: - logger.debug("Registering pass %s", pass_config.type) - package_config.import_pass_module(pass_config.type) # check if target is not used + used_passes_configs = get_used_passes_configs(run_config) target_not_used = ( # no evaluator given (also implies no search) engine.evaluator_config is None # not using auto optimizer - and used_passes + and used_passes_configs # no pass specific evaluator # no pass needs to run on target and all( pass_config.evaluator is None and not get_run_on_target(package_config, pass_config) - for pass_config in used_passes + for pass_config in used_passes_configs ) ) + + is_ep_required = auto_optimizer_enabled or is_execution_provider_required(run_config, package_config) accelerator_specs = create_accelerators( engine.target_config, skip_supported_eps_check=target_not_used, is_ep_required=is_ep_required ) - # Initializes the passes and register it with the engine - passes_to_run = ( - {pass_name for pass_flow in run_config.pass_flows for pass_name in pass_flow} - if run_config.pass_flows - else set(run_config.passes.keys()) - ) - for pass_name in passes_to_run: - pass_run_config = run_config.passes[pass_name] - engine.register( - pass_run_config.type, - config=pass_run_config.config, - name=pass_name, - host=pass_run_config.host.create_system() if pass_run_config.host is not None else None, - evaluator_config=pass_run_config.evaluator, - ) - engine.set_pass_flows(run_config.pass_flows or [list(run_config.passes.keys())]) + # Set passes with the engine + engine.set_input_passes_configs(run_config.passes) + # run return engine.run( run_config.input_model, accelerator_specs, @@ -334,18 +324,10 @@ def get_local_ort_packages() -> List[str]: return local_ort_packages -def get_used_passes(run_config: RunConfig) -> Generator["RunPassConfig", None, None]: - if run_config.pass_flows: - passes = set() - for pass_flow in run_config.pass_flows: - for pass_name in pass_flow: - if run_config.passes[pass_name].type not in passes: - passes.add(run_config.passes[pass_name].type) - yield run_config.passes[pass_name] - elif run_config.passes: - yield from run_config.passes.values() +def get_used_passes_configs(run_config: RunConfig) -> List["RunPassConfig"]: + return [pass_config for _, pass_configs in run_config.passes.items() for pass_config in pass_configs] -def get_run_on_target(package_config: OlivePackageConfig, pass_config: RunPassConfig) -> bool: +def get_run_on_target(package_config: OlivePackageConfig, pass_config: "RunPassConfig") -> bool: pass_module_config = package_config.get_pass_module_config(pass_config.type) return pass_module_config.run_on_target diff --git a/test/unit_test/engine/packaging/test_packaging_generator.py 
b/test/unit_test/engine/packaging/test_packaging_generator.py index 98261d6c7..9e108fea6 100644 --- a/test/unit_test/engine/packaging/test_packaging_generator.py +++ b/test/unit_test/engine/packaging/test_packaging_generator.py @@ -54,7 +54,7 @@ def test_generate_zipfile_artifacts(mock_sys_getsizeof, save_as_external_data, m }, "search_strategy": { "execution_order": "joint", - "search_algorithm": "random", + "sampler": "random", }, "evaluator": evaluator_config, } diff --git a/test/unit_test/engine/test_engine.py b/test/unit_test/engine/test_engine.py index 0e4e82cff..4e7293721 100644 --- a/test/unit_test/engine/test_engine.py +++ b/test/unit_test/engine/test_engine.py @@ -19,6 +19,7 @@ from olive.data.config import DataComponentConfig, DataConfig from olive.engine import Engine +from olive.engine.config import RunPassConfig from olive.evaluator.metric import AccuracySubType from olive.evaluator.metric_result import MetricResult, joint_metric_key from olive.evaluator.olive_evaluator import OliveEvaluatorConfig @@ -28,7 +29,6 @@ from olive.passes.onnx.quantization import OnnxDynamicQuantization, OnnxStaticQuantization from olive.systems.accelerator_creator import create_accelerators from olive.systems.common import SystemType -from olive.systems.local import LocalSystem from olive.systems.system_config import LocalTargetUserConfig, SystemConfig # pylint: disable=protected-access @@ -43,7 +43,7 @@ def test_register(self, tmpdir): # setup p = get_onnxconversion_pass() name = p.__class__.__name__ - system = LocalSystem() + host = SystemConfig(type=SystemType.Local) evaluator_config = OliveEvaluatorConfig(metrics=[get_accuracy_metric(AccuracySubType.ACCURACY_SCORE)]) options = { @@ -53,19 +53,20 @@ def test_register(self, tmpdir): }, "search_strategy": { "execution_order": "joint", - "search_algorithm": "random", + "sampler": "random", }, } engine = Engine(**options) # execute - engine.register(OnnxConversion, host=system, evaluator_config=evaluator_config) + engine.register(OnnxConversion, host=host, evaluator_config=evaluator_config) # assert - assert name in engine.pass_run_configs - assert engine.pass_run_configs[name]["type"] == OnnxConversion.__name__ - assert engine.pass_run_configs[name]["host"] == system - assert engine.pass_run_configs[name]["evaluator"] == evaluator_config + assert name in engine.input_passes_configs + assert len(engine.input_passes_configs[name]) == 1 + assert engine.input_passes_configs[name][0].type == OnnxConversion.__name__.lower() + assert engine.input_passes_configs[name][0].host == host + assert engine.input_passes_configs[name][0].evaluator == evaluator_config def test_register_no_search(self, tmpdir): # setup @@ -82,7 +83,7 @@ def test_register_no_search(self, tmpdir): engine.register(OnnxDynamicQuantization) # assert - assert "OnnxDynamicQuantization" in engine.pass_run_configs + assert "OnnxDynamicQuantization" in engine.input_passes_configs def test_default_engine_run(self, tmpdir): # setup @@ -100,7 +101,7 @@ def test_default_engine_run(self, tmpdir): for fp_nodes in outputs.values(): for node in fp_nodes.nodes.values(): assert node.model_config - assert node.from_pass == "OnnxConversion" + assert node.from_pass == "onnxconversion" assert node.metrics is None, "Should not evaluate input/output model by default" @patch("olive.systems.local.LocalSystem") @@ -118,7 +119,7 @@ def test_run(self, mock_local_system, tmp_path): }, "search_strategy": { "execution_order": "joint", - "search_algorithm": "random", + "sampler": "random", }, "evaluator": 
evaluator_config, } @@ -143,16 +144,18 @@ def test_run(self, mock_local_system, tmp_path): system_object.olive_managed_env = False engine = Engine(**options) - p1_name = "converter_13" - p2_name = "converter_14" - p1, p1_config = get_onnxconversion_pass(ignore_pass_config=False, target_opset=13) - p2, p2_config = get_onnxconversion_pass(ignore_pass_config=False, target_opset=14) - engine.register(OnnxConversion, name=p1_name, config=p1_config) - engine.register(OnnxConversion, name=p2_name, config=p2_config) - engine.set_pass_flows([[p1_name], [p2_name]]) + p_name = "converter" + p1, p1_run_config = get_onnxconversion_pass(ignore_pass_config=False, target_opset=13) + p2, p2_run_config = get_onnxconversion_pass(ignore_pass_config=False, target_opset=14) + p1_run_config = RunPassConfig(**p1_run_config) + p2_run_config = RunPassConfig(**p2_run_config) + engine.set_input_passes_configs({p_name: [p1_run_config, p2_run_config]}) + + p1_pass_config = p1.serialize_config(p1_run_config.config, check_object=True) + p2_pass_config = p2.serialize_config(p2_run_config.config, check_object=True) model_ids = [ - engine.cache.get_output_model_id(p1.__class__.__name__, p1_config, input_model_id), - engine.cache.get_output_model_id(p2.__class__.__name__, p2_config, input_model_id), + engine.cache.get_output_model_id(p1.__class__.__name__, p1_pass_config, input_model_id), + engine.cache.get_output_model_id(p2.__class__.__name__, p2_pass_config, input_model_id), ] expected_res = { model_id: { @@ -214,7 +217,6 @@ def test_run_no_search_model_components(self, mock_local_system_init, tmpdir): engine = Engine(cache_config={"cache_dir": tmpdir}) engine.register(OptimumConversion) - engine.set_pass_flows() # output model to output_dir output_dir = Path(tmpdir) @@ -257,7 +259,6 @@ def test_run_no_search(self, mock_local_system_init, tmp_path): engine = Engine(**options) _, p_config = get_onnxconversion_pass(ignore_pass_config=False, target_opset=13) engine.register(OnnxConversion, config=p_config) - engine.set_pass_flows() accelerator_spec = DEFAULT_CPU_ACCELERATOR output_model_id = engine.cache.get_output_model_id( @@ -310,7 +311,7 @@ def test_run_no_search(self, mock_local_system_init, tmp_path): [ { "execution_order": "joint", - "search_algorithm": "random", + "sampler": "random", }, None, ], @@ -361,7 +362,7 @@ def test_pass_exception(self, caplog, tmpdir): }, "search_strategy": { "execution_order": "joint", - "search_algorithm": "random", + "sampler": "random", }, "evaluator": evaluator_config, } @@ -499,7 +500,7 @@ def test_pass_cache(self, mock_get_available_providers, mock_local_system_init, }, "search_strategy": { "execution_order": "joint", - "search_algorithm": "random", + "sampler": "random", }, "evaluator": evaluator_config, } @@ -563,7 +564,7 @@ def test_pass_value_error(self, caplog, tmpdir): }, "search_strategy": { "execution_order": "joint", - "search_algorithm": "random", + "sampler": "random", }, "evaluator": evaluator_config, } @@ -601,7 +602,9 @@ def test_pass_quantization_error(self, is_search, caplog, tmpdir): }, "search_strategy": { "execution_order": "joint", - "search_algorithm": "random", + "sampler": "random", + "max_samples": 1, + "seed": 1, }, "evaluator": evaluator_config, } diff --git a/test/unit_test/passes/common/test_user_script.py b/test/unit_test/passes/common/test_user_script.py index d21fe9934..c1460dea4 100644 --- a/test/unit_test/passes/common/test_user_script.py +++ b/test/unit_test/passes/common/test_user_script.py @@ -8,6 +8,6 @@ class TestUserScriptConfig: def 
test_no_config(self): - config = {} - config = OrtSessionParamsTuning.generate_config(DEFAULT_CPU_ACCELERATOR, config, True) + config = OrtSessionParamsTuning.generate_config(DEFAULT_CPU_ACCELERATOR, disable_search=True) assert config + assert OrtSessionParamsTuning.validate_config(config, DEFAULT_CPU_ACCELERATOR, disable_search=True) diff --git a/test/unit_test/passes/test_pass_serialization.py b/test/unit_test/passes/test_pass_serialization.py index 88e09906e..ea05b0e35 100644 --- a/test/unit_test/passes/test_pass_serialization.py +++ b/test/unit_test/passes/test_pass_serialization.py @@ -11,8 +11,7 @@ @pytest.mark.parametrize("host_device", [None, "cpu", "gpu"]) def test_pass_serialization(host_device): - config = {} - config = OnnxConversion.generate_config(DEFAULT_CPU_ACCELERATOR, config) + config = OnnxConversion.generate_config(DEFAULT_CPU_ACCELERATOR) onnx_conversion = OnnxConversion(DEFAULT_CPU_ACCELERATOR, config, host_device=host_device) json = onnx_conversion.to_json(True) diff --git a/test/unit_test/search/samplers/test_random_sampler.py b/test/unit_test/search/samplers/test_random_sampler.py new file mode 100644 index 000000000..6f9179a2f --- /dev/null +++ b/test/unit_test/search/samplers/test_random_sampler.py @@ -0,0 +1,101 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from collections import OrderedDict +from unittest.mock import patch + +from olive.search.samplers.random_sampler import RandomSampler +from olive.search.search_parameter import Categorical +from olive.search.search_space import SearchSpace + +# ruff: noqa: PD011 + + +class TestRandomSampler: + @patch("olive.search.search_space.SearchSpace.__getitem__") + def test_call_count(self, mock_search_space_get_item): + search_space = SearchSpace( + [ + ("PassA", Categorical([1, 2])), + ("PassB", Categorical([1, 2, 3])), + ("PassC", Categorical(["a", "b"])), + ("PassD", Categorical(["a", "b", "c"])), + ] + ) + assert len(search_space) == 36 + + config = {"max_samples": 50} + sampler = RandomSampler(search_space, config=config) + + count = 0 + while not sampler.should_stop: + sampler.suggest() + count += 1 + + assert count == 36 + assert mock_search_space_get_item.call_count == 36 + + def test_iteration(self): + search_space = SearchSpace( + [ + ("PassA", Categorical([1, 2])), + ("PassB", Categorical([1, 2, 3])), + ("PassC", Categorical(["a", "b"])), + ("PassD", Categorical(["a", "b", "c"])), + ] + ) + assert len(search_space) == 36 + + config = {"seed": 101, "max_samples": 50} + sampler = RandomSampler(search_space, config=config) + + count = 0 + actual = [] + while not sampler.should_stop: + actual.append(sampler.suggest()) + count += 1 + + actual = [(search_point.index, search_point.values) for search_point in actual] + expected = [ + (12, OrderedDict({"PassA": (0, 1), "PassB": (0, 1), "PassC": (0, "a"), "PassD": (1, "b")})), + (35, OrderedDict({"PassA": (1, 2), "PassB": (2, 3), "PassC": (1, "b"), "PassD": (2, "c")})), + (23, OrderedDict({"PassA": (1, 2), "PassB": (2, 3), "PassC": (1, "b"), "PassD": (1, "b")})), + (31, OrderedDict({"PassA": (1, 2), "PassB": (0, 1), "PassC": (1, "b"), "PassD": (2, "c")})), + (3, OrderedDict({"PassA": (1, 2), "PassB": (1, 2), "PassC": (0, "a"), "PassD": (0, "a")})), + (24, OrderedDict({"PassA": (0, 1), "PassB": (0, 1), "PassC": (0, "a"), "PassD": (2, "c")})), + (18, 
OrderedDict({"PassA": (0, 1), "PassB": (0, 1), "PassC": (1, "b"), "PassD": (1, "b")})), + (7, OrderedDict({"PassA": (1, 2), "PassB": (0, 1), "PassC": (1, "b"), "PassD": (0, "a")})), + (25, OrderedDict({"PassA": (1, 2), "PassB": (0, 1), "PassC": (0, "a"), "PassD": (2, "c")})), + (9, OrderedDict({"PassA": (1, 2), "PassB": (1, 2), "PassC": (1, "b"), "PassD": (0, "a")})), + (13, OrderedDict({"PassA": (1, 2), "PassB": (0, 1), "PassC": (0, "a"), "PassD": (1, "b")})), + (21, OrderedDict({"PassA": (1, 2), "PassB": (1, 2), "PassC": (1, "b"), "PassD": (1, "b")})), + (33, OrderedDict({"PassA": (1, 2), "PassB": (1, 2), "PassC": (1, "b"), "PassD": (2, "c")})), + (8, OrderedDict({"PassA": (0, 1), "PassB": (1, 2), "PassC": (1, "b"), "PassD": (0, "a")})), + (16, OrderedDict({"PassA": (0, 1), "PassB": (2, 3), "PassC": (0, "a"), "PassD": (1, "b")})), + (26, OrderedDict({"PassA": (0, 1), "PassB": (1, 2), "PassC": (0, "a"), "PassD": (2, "c")})), + (2, OrderedDict({"PassA": (0, 1), "PassB": (1, 2), "PassC": (0, "a"), "PassD": (0, "a")})), + (15, OrderedDict({"PassA": (1, 2), "PassB": (1, 2), "PassC": (0, "a"), "PassD": (1, "b")})), + (11, OrderedDict({"PassA": (1, 2), "PassB": (2, 3), "PassC": (1, "b"), "PassD": (0, "a")})), + (10, OrderedDict({"PassA": (0, 1), "PassB": (2, 3), "PassC": (1, "b"), "PassD": (0, "a")})), + (4, OrderedDict({"PassA": (0, 1), "PassB": (2, 3), "PassC": (0, "a"), "PassD": (0, "a")})), + (20, OrderedDict({"PassA": (0, 1), "PassB": (1, 2), "PassC": (1, "b"), "PassD": (1, "b")})), + (32, OrderedDict({"PassA": (0, 1), "PassB": (1, 2), "PassC": (1, "b"), "PassD": (2, "c")})), + (27, OrderedDict({"PassA": (1, 2), "PassB": (1, 2), "PassC": (0, "a"), "PassD": (2, "c")})), + (17, OrderedDict({"PassA": (1, 2), "PassB": (2, 3), "PassC": (0, "a"), "PassD": (1, "b")})), + (6, OrderedDict({"PassA": (0, 1), "PassB": (0, 1), "PassC": (1, "b"), "PassD": (0, "a")})), + (29, OrderedDict({"PassA": (1, 2), "PassB": (2, 3), "PassC": (0, "a"), "PassD": (2, "c")})), + (22, OrderedDict({"PassA": (0, 1), "PassB": (2, 3), "PassC": (1, "b"), "PassD": (1, "b")})), + (14, OrderedDict({"PassA": (0, 1), "PassB": (1, 2), "PassC": (0, "a"), "PassD": (1, "b")})), + (30, OrderedDict({"PassA": (0, 1), "PassB": (0, 1), "PassC": (1, "b"), "PassD": (2, "c")})), + (19, OrderedDict({"PassA": (1, 2), "PassB": (0, 1), "PassC": (1, "b"), "PassD": (1, "b")})), + (28, OrderedDict({"PassA": (0, 1), "PassB": (2, 3), "PassC": (0, "a"), "PassD": (2, "c")})), + (1, OrderedDict({"PassA": (1, 2), "PassB": (0, 1), "PassC": (0, "a"), "PassD": (0, "a")})), + (0, OrderedDict({"PassA": (0, 1), "PassB": (0, 1), "PassC": (0, "a"), "PassD": (0, "a")})), + (34, OrderedDict({"PassA": (0, 1), "PassB": (2, 3), "PassC": (1, "b"), "PassD": (2, "c")})), + (5, OrderedDict({"PassA": (1, 2), "PassB": (2, 3), "PassC": (0, "a"), "PassD": (0, "a")})), + ] + + assert count == 36 + assert actual == expected + assert len({sp for sp, _ in expected}) == len(expected) diff --git a/test/unit_test/search/samplers/test_sequential_sampler.py b/test/unit_test/search/samplers/test_sequential_sampler.py new file mode 100644 index 000000000..daa72a820 --- /dev/null +++ b/test/unit_test/search/samplers/test_sequential_sampler.py @@ -0,0 +1,93 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +from collections import OrderedDict +from unittest.mock import patch + +from olive.search.samplers.sequential_sampler import SequentialSampler +from olive.search.search_parameter import Categorical +from olive.search.search_space import SearchSpace + +# ruff: noqa: PD011 + + +class TestSequentialSampler: + @patch("olive.search.search_space.SearchSpace.__getitem__") + def test_length(self, mock_search_space_get_item): + search_space = SearchSpace( + [ + ("PassA", Categorical([1, 2])), + ("PassB", Categorical([1, 2, 3])), + ("PassC", Categorical(["a", "b"])), + ("PassD", Categorical(["a", "b", "c"])), + ] + ) + assert len(search_space) == 36 + + sampler = SequentialSampler(search_space) + + suggestions = [] + while not sampler.should_stop: + suggestions.append(sampler.suggest()) + + assert len(suggestions) == 36 + assert mock_search_space_get_item.call_count == 36 + + def test_iteration(self): + search_space = SearchSpace( + [ + ("PassA", Categorical([1, 2])), + ("PassB", Categorical([1, 2, 3])), + ("PassC", Categorical(["a", "b"])), + ("PassD", Categorical(["a", "b", "c"])), + ] + ) + assert len(search_space) == 36 + + sampler = SequentialSampler(search_space) + + actual = [] + while not sampler.should_stop: + actual.append(sampler.suggest()) + + actual = [(search_point.index, search_point.values) for search_point in actual] + expected = [ + (0, OrderedDict([("PassA", (0, 1)), ("PassB", (0, 1)), ("PassC", (0, "a")), ("PassD", (0, "a"))])), + (1, OrderedDict([("PassA", (1, 2)), ("PassB", (0, 1)), ("PassC", (0, "a")), ("PassD", (0, "a"))])), + (2, OrderedDict([("PassA", (0, 1)), ("PassB", (1, 2)), ("PassC", (0, "a")), ("PassD", (0, "a"))])), + (3, OrderedDict([("PassA", (1, 2)), ("PassB", (1, 2)), ("PassC", (0, "a")), ("PassD", (0, "a"))])), + (4, OrderedDict([("PassA", (0, 1)), ("PassB", (2, 3)), ("PassC", (0, "a")), ("PassD", (0, "a"))])), + (5, OrderedDict([("PassA", (1, 2)), ("PassB", (2, 3)), ("PassC", (0, "a")), ("PassD", (0, "a"))])), + (6, OrderedDict([("PassA", (0, 1)), ("PassB", (0, 1)), ("PassC", (1, "b")), ("PassD", (0, "a"))])), + (7, OrderedDict([("PassA", (1, 2)), ("PassB", (0, 1)), ("PassC", (1, "b")), ("PassD", (0, "a"))])), + (8, OrderedDict([("PassA", (0, 1)), ("PassB", (1, 2)), ("PassC", (1, "b")), ("PassD", (0, "a"))])), + (9, OrderedDict([("PassA", (1, 2)), ("PassB", (1, 2)), ("PassC", (1, "b")), ("PassD", (0, "a"))])), + (10, OrderedDict([("PassA", (0, 1)), ("PassB", (2, 3)), ("PassC", (1, "b")), ("PassD", (0, "a"))])), + (11, OrderedDict([("PassA", (1, 2)), ("PassB", (2, 3)), ("PassC", (1, "b")), ("PassD", (0, "a"))])), + (12, OrderedDict([("PassA", (0, 1)), ("PassB", (0, 1)), ("PassC", (0, "a")), ("PassD", (1, "b"))])), + (13, OrderedDict([("PassA", (1, 2)), ("PassB", (0, 1)), ("PassC", (0, "a")), ("PassD", (1, "b"))])), + (14, OrderedDict([("PassA", (0, 1)), ("PassB", (1, 2)), ("PassC", (0, "a")), ("PassD", (1, "b"))])), + (15, OrderedDict([("PassA", (1, 2)), ("PassB", (1, 2)), ("PassC", (0, "a")), ("PassD", (1, "b"))])), + (16, OrderedDict([("PassA", (0, 1)), ("PassB", (2, 3)), ("PassC", (0, "a")), ("PassD", (1, "b"))])), + (17, OrderedDict([("PassA", (1, 2)), ("PassB", (2, 3)), ("PassC", (0, "a")), ("PassD", (1, "b"))])), + (18, OrderedDict([("PassA", (0, 1)), ("PassB", (0, 1)), ("PassC", (1, "b")), ("PassD", (1, "b"))])), + (19, OrderedDict([("PassA", (1, 2)), ("PassB", (0, 1)), ("PassC", (1, "b")), ("PassD", (1, "b"))])), + (20, OrderedDict([("PassA", (0, 1)), ("PassB", (1, 2)), 
("PassC", (1, "b")), ("PassD", (1, "b"))])), + (21, OrderedDict([("PassA", (1, 2)), ("PassB", (1, 2)), ("PassC", (1, "b")), ("PassD", (1, "b"))])), + (22, OrderedDict([("PassA", (0, 1)), ("PassB", (2, 3)), ("PassC", (1, "b")), ("PassD", (1, "b"))])), + (23, OrderedDict([("PassA", (1, 2)), ("PassB", (2, 3)), ("PassC", (1, "b")), ("PassD", (1, "b"))])), + (24, OrderedDict([("PassA", (0, 1)), ("PassB", (0, 1)), ("PassC", (0, "a")), ("PassD", (2, "c"))])), + (25, OrderedDict([("PassA", (1, 2)), ("PassB", (0, 1)), ("PassC", (0, "a")), ("PassD", (2, "c"))])), + (26, OrderedDict([("PassA", (0, 1)), ("PassB", (1, 2)), ("PassC", (0, "a")), ("PassD", (2, "c"))])), + (27, OrderedDict([("PassA", (1, 2)), ("PassB", (1, 2)), ("PassC", (0, "a")), ("PassD", (2, "c"))])), + (28, OrderedDict([("PassA", (0, 1)), ("PassB", (2, 3)), ("PassC", (0, "a")), ("PassD", (2, "c"))])), + (29, OrderedDict([("PassA", (1, 2)), ("PassB", (2, 3)), ("PassC", (0, "a")), ("PassD", (2, "c"))])), + (30, OrderedDict([("PassA", (0, 1)), ("PassB", (0, 1)), ("PassC", (1, "b")), ("PassD", (2, "c"))])), + (31, OrderedDict([("PassA", (1, 2)), ("PassB", (0, 1)), ("PassC", (1, "b")), ("PassD", (2, "c"))])), + (32, OrderedDict([("PassA", (0, 1)), ("PassB", (1, 2)), ("PassC", (1, "b")), ("PassD", (2, "c"))])), + (33, OrderedDict([("PassA", (1, 2)), ("PassB", (1, 2)), ("PassC", (1, "b")), ("PassD", (2, "c"))])), + (34, OrderedDict([("PassA", (0, 1)), ("PassB", (2, 3)), ("PassC", (1, "b")), ("PassD", (2, "c"))])), + (35, OrderedDict([("PassA", (1, 2)), ("PassB", (2, 3)), ("PassC", (1, "b")), ("PassD", (2, "c"))])), + ] + assert actual == expected diff --git a/test/unit_test/search/samplers/test_tpe_sampler.py b/test/unit_test/search/samplers/test_tpe_sampler.py new file mode 100644 index 000000000..192df28cb --- /dev/null +++ b/test/unit_test/search/samplers/test_tpe_sampler.py @@ -0,0 +1,168 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# --------------------------------------------------------------------------
+from collections import OrderedDict
+
+from olive.search.samplers.tpe_sampler import TPESampler
+from olive.search.search_parameter import Categorical, Conditional
+from olive.search.search_space import SearchSpace
+
+# ruff: noqa: PD011
+
+
+class TestTPESampler:
+    def test_iteration(self):
+        search_space = SearchSpace(
+            [
+                ("PassA", Categorical([1, 2])),
+                ("PassB", Categorical([1, 2, 3])),
+                ("PassC", Categorical(["a", "b"])),
+                ("PassD", Categorical(["a", "b", "c"])),
+            ]
+        )
+
+        # max_samples exceeds the 36 points in the space, so sampling stops only once the space is exhausted
+        config = {"seed": 101, "max_samples": 50}
+        sampler = TPESampler(search_space, config=config)
+
+        actual = []
+        while not sampler.should_stop:
+            actual.append(sampler.suggest())
+
+        actual = [(search_point.index, search_point.values) for search_point in actual]
+        expected = [
+            (5, OrderedDict({"PassA": (1, 2), "PassB": (2, 3), "PassC": (0, "a"), "PassD": (0, "a")})),
+            (16, OrderedDict({"PassA": (0, 1), "PassB": (2, 3), "PassC": (0, "a"), "PassD": (1, "b")})),
+            (34, OrderedDict({"PassA": (0, 1), "PassB": (2, 3), "PassC": (1, "b"), "PassD": (2, "c")})),
+            (21, OrderedDict({"PassA": (1, 2), "PassB": (1, 2), "PassC": (1, "b"), "PassD": (1, "b")})),
+            (25, OrderedDict({"PassA": (1, 2), "PassB": (0, 1), "PassC": (0, "a"), "PassD": (2, "c")})),
+            (22, OrderedDict({"PassA": (0, 1), "PassB": (2, 3), "PassC": (1, "b"), "PassD": (1, "b")})),
+            (13, OrderedDict({"PassA": (1, 2), "PassB": (0, 1), "PassC": (0, "a"), "PassD": (1, "b")})),
+            (20, OrderedDict({"PassA": (0, 1), "PassB": (1, 2), "PassC": (1, "b"), "PassD": (1, "b")})),
+            (32, OrderedDict({"PassA": (0, 1), "PassB": (1, 2), "PassC": (1, "b"), "PassD": (2, "c")})),
+            (9, OrderedDict({"PassA": (1, 2), "PassB": (1, 2), "PassC": (1, "b"), "PassD": (0, "a")})),
+            (35, OrderedDict({"PassA": (1, 2), "PassB": (2, 3), "PassC": (1, "b"), "PassD": (2, "c")})),
+            (10, OrderedDict({"PassA": (0, 1), "PassB": (2, 3), "PassC": (1, "b"), "PassD": (0, "a")})),
+            (4, OrderedDict({"PassA": (0, 1), "PassB": (2, 3), "PassC": (0, "a"), "PassD": (0, "a")})),
+            (28, OrderedDict({"PassA": (0, 1), "PassB": (2, 3), "PassC": (0, "a"), "PassD": (2, "c")})),
+            (30, OrderedDict({"PassA": (0, 1), "PassB": (0, 1), "PassC": (1, "b"), "PassD": (2, "c")})),
+            (27, OrderedDict({"PassA": (1, 2), "PassB": (1, 2), "PassC": (0, "a"), "PassD": (2, "c")})),
+            (12, OrderedDict({"PassA": (0, 1), "PassB": (0, 1), "PassC": (0, "a"), "PassD": (1, "b")})),
+            (33, OrderedDict({"PassA": (1, 2), "PassB": (1, 2), "PassC": (1, "b"), "PassD": (2, "c")})),
+            (24, OrderedDict({"PassA": (0, 1), "PassB": (0, 1), "PassC": (0, "a"), "PassD": (2, "c")})),
+            (1, OrderedDict({"PassA": (1, 2), "PassB": (0, 1), "PassC": (0, "a"), "PassD": (0, "a")})),
+            (3, OrderedDict({"PassA": (1, 2), "PassB": (1, 2), "PassC": (0, "a"), "PassD": (0, "a")})),
+            (31, OrderedDict({"PassA": (1, 2), "PassB": (0, 1), "PassC": (1, "b"), "PassD": (2, "c")})),
+            (6, OrderedDict({"PassA": (0, 1), "PassB": (0, 1), "PassC": (1, "b"), "PassD": (0, "a")})),
+            (11, OrderedDict({"PassA": (1, 2), "PassB": (2, 3), "PassC": (1, "b"), "PassD": (0, "a")})),
+            (15, OrderedDict({"PassA": (1, 2), "PassB": (1, 2), "PassC": (0, "a"), "PassD": (1, "b")})),
+            (18, OrderedDict({"PassA": (0, 1), "PassB": (0, 1), "PassC": (1, "b"), "PassD": (1, "b")})),
+            (17, OrderedDict({"PassA": (1, 2), "PassB": (2, 3), "PassC": (0, "a"), "PassD": (1, "b")})),
+            (2, OrderedDict({"PassA": (0, 1), "PassB": (1, 2), "PassC": (0, "a"), "PassD": (0, "a")})),
+            (0, OrderedDict({"PassA": (0, 1), "PassB": (0, 1), "PassC": (0, "a"), "PassD": (0, "a")})),
+            (14, OrderedDict({"PassA": (0, 1), "PassB": (1, 2), "PassC": (0, "a"), "PassD": (1, "b")})),
+            (7, OrderedDict({"PassA": (1, 2), "PassB": (0, 1), "PassC": (1, "b"), "PassD": (0, "a")})),
+            (29, OrderedDict({"PassA": (1, 2), "PassB": (2, 3), "PassC": (0, "a"), "PassD": (2, "c")})),
+            (8, OrderedDict({"PassA": (0, 1), "PassB": (1, 2), "PassC": (1, "b"), "PassD": (0, "a")})),
+            (26, OrderedDict({"PassA": (0, 1), "PassB": (1, 2), "PassC": (0, "a"), "PassD": (2, "c")})),
+            (23, OrderedDict({"PassA": (1, 2), "PassB": (2, 3), "PassC": (1, "b"), "PassD": (1, "b")})),
+            (19, OrderedDict({"PassA": (1, 2), "PassB": (0, 1), "PassC": (1, "b"), "PassD": (1, "b")})),
+        ]
+
+        assert len(actual) == 36
+        assert actual == expected
+        # sanity check on the test data itself: every expected search point index is distinct
+        assert len({spi for spi, _ in expected}) == len(expected)
+
+    def test_suggest(self):
+        # a conditional space modeled on the onnx quantization and session params tuning pass parameters
+        search_space = SearchSpace(
+            [
+                ("conversion", Categorical([SearchSpace([])])),
+                ("transformers_optimization", Categorical([SearchSpace([])])),
+                (
+                    "quantization",
+                    Categorical(
+                        [
+                            SearchSpace(
+                                [
+                                    ("quant_mode", Categorical(["dynamic", "static"])),
+                                    ("weight_type", Categorical(["QInt8", "QUInt8"])),
+                                    (
+                                        "quant_format",
+                                        Conditional(
+                                            parents=("quant_mode",),
+                                            support={("static",): Categorical(["QOperator", "QDQ"])},
+                                            default=Conditional.get_ignored_choice(),
+                                        ),
+                                    ),
+                                    (
+                                        "activation_type",
+                                        Conditional(
+                                            parents=("quant_mode", "quant_format", "weight_type"),
+                                            support={
+                                                ("static", "QDQ", "QInt8"): Categorical(["QInt8"]),
+                                                ("static", "QDQ", "QUInt8"): Categorical(["QUInt8"]),
+                                                ("static", "QOperator", "QUInt8"): Categorical(["QUInt8"]),
+                                                ("static", "QOperator", "QInt8"): Conditional.get_invalid_choice(),
+                                            },
+                                            default=Conditional.get_ignored_choice(),
+                                        ),
+                                    ),
+                                    (
+                                        "prepare_qnn_config",
+                                        Conditional(
+                                            parents=("quant_mode",),
+                                            support={
+                                                ("static",): Categorical([False]),
+                                                ("dynamic",): Conditional.get_ignored_choice(),
+                                            },
+                                            default=Conditional.get_invalid_choice(),
+                                        ),
+                                    ),
+                                    (
+                                        "qnn_extra_options",
+                                        Conditional(
+                                            parents=("quant_mode",),
+                                            support={
+                                                ("static",): Categorical([None]),
+                                                ("dynamic",): Conditional.get_ignored_choice(),
+                                            },
+                                            default=Conditional.get_invalid_choice(),
+                                        ),
+                                    ),
+                                    (
+                                        "MatMulConstBOnly",
+                                        Conditional(
+                                            parents=("quant_mode",),
+                                            support={
+                                                ("dynamic",): Categorical([True]),
+                                                ("static",): Categorical([False]),
+                                            },
+                                            default=Conditional.get_invalid_choice(),
+                                        ),
+                                    ),
+                                ]
+                            )
+                        ]
+                    ),
+                ),
+                (
+                    "session_params_tuning",
+                    Categorical(
+                        [
+                            SearchSpace(
+                                [("providers_list", Categorical(["OpenVINOExecutionProvider", "CPUExecutionProvider"]))]
+                            )
+                        ]
+                    ),
+                ),
+            ]
+        )
+
+        config = {"seed": 101, "max_samples": 500}
+        sampler = TPESampler(search_space, config=config)
+
+        # every suggested point must round-trip through indexed lookup into the search space
+        while not sampler.should_stop:
+            sp = sampler.suggest()
+            assert sp == search_space[sp.index]
diff --git a/test/unit_test/search/test_search_results.py b/test/unit_test/search/test_search_results.py
new file mode 100644
index 000000000..10376eaf1
--- /dev/null
+++ b/test/unit_test/search/test_search_results.py
@@ -0,0 +1,106 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+from olive.evaluator.metric_result import MetricResult
+from olive.search.search_results import SearchResults
+
+# ruff: noqa: PD011
+# pylint: disable=protected-access
+
+
+class TestSearchResults:
+    def test_empty(self):
+        results = SearchResults()
+        results.sort()
+
+        assert results._sorted_indices == []
+        assert results.get_next_best_result(-1) == (None, None, None)
+
+    def test_sort(self):
+        objectives2 = {
+            "accuracy-accuracy_custom": {"goal": 0.7220000147819519, "higher_is_better": True, "priority": 1},
+            "latency-avg": {"goal": 24.044427, "higher_is_better": False, "priority": 2},
+        }
+        objectives3 = {
+            **objectives2,
+            "latency-max": {"goal": 30, "higher_is_better": False, "priority": 3},
+        }
+
+        signal1 = MetricResult.parse_obj(
+            {
+                "accuracy-accuracy_custom": {
+                    "value": 0.75,
+                    "priority": 1,
+                    "higher_is_better": True,
+                },
+                "latency-avg": {"value": 4.5, "priority": 2, "higher_is_better": False},
+            }
+        )
+        signal2 = MetricResult.parse_obj(
+            {
+                "accuracy-accuracy_custom": {
+                    "value": 0.78,
+                    "priority": 1,
+                    "higher_is_better": True,
+                },
+                "latency-avg": {"value": 55.9, "priority": 2, "higher_is_better": False},
+            }
+        )
+        signal3 = MetricResult.parse_obj(
+            {
+                "accuracy-accuracy_custom": {
+                    "value": 0.76,
+                    "priority": 1,
+                    "higher_is_better": True,
+                },
+                "latency-avg": {"value": 53.0, "priority": 2, "higher_is_better": False},
+                "latency-max": {"value": 60.1, "priority": 3, "higher_is_better": False},
+            }
+        )
+
+        # None entries represent search points that never produced a feedback signal
+        signals = [
+            (
+                objectives2,
+                signal1,
+                ["model_id_1"],
+            ),
+            None,
+            None,
+            (
+                objectives2,
+                signal2,
+                ["model_id_2"],
+            ),
+            None,
+            None,
+            None,
+            (
+                objectives3,
+                signal3,
+                ["model_id_3"],
+            ),
+        ]
+
+        results = SearchResults()
+        for i, signal in enumerate(signals):
+            if signal:
+                results.record_feedback_signal(i, *signal)
+
+        results.sort()
+
+        # sorted best-first by the priority-1 accuracy value: 0.78 (spi 3) > 0.76 (spi 7) > 0.75 (spi 0)
+        assert results._sorted_indices == [3, 7, 0]
+
+        next_best_spi = -1
+        actual_order = []
+        while next_best_spi is not None:
+            next_best_spi, spi, model_ids = results.get_next_best_result(next_best_spi)
+            if next_best_spi is not None:
+                actual_order.append((next_best_spi, spi, model_ids))
+
+        assert actual_order == [
+            (0, 3, ["model_id_2"]),
+            (1, 7, ["model_id_3"]),
+            (2, 0, ["model_id_1"]),
+        ]
diff --git a/test/unit_test/search/test_search_space.py b/test/unit_test/search/test_search_space.py
new file mode 100644
index 000000000..9035b4cc4
--- /dev/null
+++ b/test/unit_test/search/test_search_space.py
@@ -0,0 +1,1963 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# -------------------------------------------------------------------------- +from collections import OrderedDict + +from olive.search.search_parameter import Categorical, Conditional, SpecialParamValue +from olive.search.search_point import SearchPoint +from olive.search.search_space import SearchSpace + +# ruff: noqa: PD011 + + +class TestSearchSpace: + def test_length_categoricals(self): + search_space = SearchSpace( + [ + ("Pass/0/SP1", Categorical([1, 2])), + ("Pass/0/SP2", Categorical([1, 2, 3])), + ("Pass/0/SP3", Categorical([1, 2, 3, 4])), + ("Pass/0/SP4", Categorical([1, 2, 3, 4, 5])), + ("Pass/0/SP5", Categorical(["a", "b"])), + ("Pass/0/SP6", Categorical(["a", "b", "c"])), + ("Pass/0/SP7", Categorical(["a", "b", "c", "d"])), + ("Pass/0/SP8", Categorical(["a", "b", "c", "d", "e"])), + ] + ) + assert len(search_space) == 14400 + + def test_length_with_conditionals(self): + search_space = SearchSpace( + [ + ("Pass/0/SP1", Categorical([1, 2, 3])), + ("Pass/0/SP2", Categorical(["x", "y", "z"])), + ( + "Pass/0/SP3", + Conditional( + parents=("Pass/0/SP1",), + support={(1,): Categorical(["a"]), (3,): Categorical(["a", "b", "c"])}, + default=Conditional.get_ignored_choice(), + ), + ), + ("Pass/0/SP4", Categorical(["4a", "4b", "4c"])), + ( + "Pass/0/SP5", + Conditional( + parents=("Pass/0/SP1", "Pass/0/SP2"), + support={ + (1, "x"): Categorical(["1x", "x1"]), + (2, "y"): Categorical(["2y", "y2"]), + (3, "z"): Categorical(["3z", "z3"]), + }, + default=Conditional.get_invalid_choice(), + ), + ), + ("Pass/0/SP6", Categorical(["5x", "5y", "5z"])), + ] + ) + assert len(search_space) == 486 + + def test_length_with_search_spaces(self): + search_space = SearchSpace( + [ + ( + "Pass/0", + SearchSpace( + [ + ("Pass/0/SP1", Categorical([1, 2])), + ("Pass/0/SP2", Categorical([1, 2, 3])), + ] + ), + ), + ( + "Pass/1", + SearchSpace( + [ + ("Pass/1/SP1", Categorical(["a", "b"])), + ("Pass/1/SP2", Categorical(["a", "b", "c"])), + ] + ), + ), + ] + ) + assert len(search_space) == 36 + + def test_length(self): + search_space = SearchSpace( + [ + ( + "PassA", + Categorical( + [SearchSpace([("PA0_SP1", Categorical([1, 2, 3])), ("PA0_SP2", Categorical([4, 5, 6]))])] + ), + ), + ( + "PassB", + Categorical( + [ + SearchSpace( + [("PB0_SP1", Categorical(["a", "b", "c"])), ("PB0_SP2", Categorical(["x", "y", "z"]))] + ), + SearchSpace( + [("PB0_SP1", Categorical(["x", "y", "z"])), ("PB0_SP2", Categorical(["a", "b", "c"]))] + ), + ] + ), + ), + ("PassC", SearchSpace([("PC0_SP1", Categorical([3, 2, 1])), ("PC0_SP2", Categorical([9, 8, 7]))])), + ( + "PassD", + SearchSpace( + [ + ("PD0_SP1", Categorical([1, 2, 3])), + ("PD0_SP2", Categorical(["x", "y", "z"])), + ( + "PD0_SP3", + Conditional( + parents=("PD0_SP1",), + support={(1,): Categorical(["a"]), (3,): Categorical(["a", "b", "c"])}, + default=Conditional.get_ignored_choice(), + ), + ), + ( + "PD0_SP4", + Conditional( + parents=("PD0_SP1", "PD0_SP2"), + support={ + (1, "x"): Categorical(["1x", "x1"]), + (2, "y"): Categorical(["2y", "y2"]), + (3, "z"): Categorical(["3y", "y3"]), + }, + default=Conditional.get_invalid_choice(), + ), + ), + ( + "PD0_SP5", + Conditional( + parents=("PD0_SP3", "PD0_SP4"), + support={ + ("a", "1x"): SearchSpace( + [ + ("PD0_SP5_SP1", Categorical(["a1x", "x1a"])), + ("PD0_SP5_SP2", Categorical(["x1a", "a1x"])), + ] + ), + ("c", "3y"): SearchSpace( + [ + ("PD0_SP5_SP3", Categorical(["c3y", "y3c"])), + ("PD0_SP5_SP4", Categorical(["y3c", "3cy"])), + ] + ), + }, + default=Conditional.get_invalid_choice(), + ), + ), + ] + ), + ), + ] + 
) + # PassA:length = 9 + # PassB:length = 18 + # PassB[0]:length = 9 + # PassB[1]:length = 9 + # PassC:length = 9 + # PassD:length = 216 + # PassD[0].PD0_SP1:length = 3 + # PassD[0].PD0_SP2:length = 3 + # PassD[0].PD0_SP3:length = 3 + # PassD[0].PD0_SP4:length = 2 + # PassD[0].PD0_SP5:length = 4 + assert len(search_space) == 314928 # 9 * 18 * 9 * 216 + + def test_empty(self): + search_space = SearchSpace([]) + assert len(search_space) == 1 + + actual = [(search_point.index, search_point.values) for search_point in search_space] + expected = [(0, OrderedDict())] + assert actual == expected + + search_space = SearchSpace( + [ + ("PassA", Categorical([SearchSpace([]), SearchSpace([]), SearchSpace([])])), + ("PassB", Categorical([SearchSpace([]), SearchSpace([])])), + ("PassC", Categorical([SearchSpace([])])), + ] + ) + assert len(search_space) == 6 + + actual = [(search_point.index, search_point.values) for search_point in search_space] + expected = [ + ( + 0, + OrderedDict( + [("PassA", (0, OrderedDict())), ("PassB", (0, OrderedDict())), ("PassC", (0, OrderedDict()))] + ), + ), + ( + 1, + OrderedDict( + [("PassA", (1, OrderedDict())), ("PassB", (0, OrderedDict())), ("PassC", (0, OrderedDict()))] + ), + ), + ( + 2, + OrderedDict( + [("PassA", (2, OrderedDict())), ("PassB", (0, OrderedDict())), ("PassC", (0, OrderedDict()))] + ), + ), + ( + 3, + OrderedDict( + [("PassA", (0, OrderedDict())), ("PassB", (1, OrderedDict())), ("PassC", (0, OrderedDict()))] + ), + ), + ( + 4, + OrderedDict( + [("PassA", (1, OrderedDict())), ("PassB", (1, OrderedDict())), ("PassC", (0, OrderedDict()))] + ), + ), + ( + 5, + OrderedDict( + [("PassA", (2, OrderedDict())), ("PassB", (1, OrderedDict())), ("PassC", (0, OrderedDict()))] + ), + ), + ] + assert actual == expected + + def test_iteration_with_categoricals(self): + search_space = SearchSpace( + [ + ("Pass/0/SP1", Categorical([1, 2, 3])), + ("Pass/0/SP2", Categorical(["a", "b"])), + ("Pass/0/SP3", Categorical([10, 20, 30])), + ("Pass/0/SP4", Categorical(["x", "y", "z"])), + ] + ) + + actual = [(search_point.index, search_point.values) for search_point in search_space] + expected = [ + ( + 0, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 1, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 2, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 3, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 4, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 5, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 6, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 7, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 8, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 9, + OrderedDict( + [ + ("Pass/0/SP1", 
(0, 1)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 10, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 11, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 12, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 13, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 14, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 15, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 16, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 17, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ), + ( + 18, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 19, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 20, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 21, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 22, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 23, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 24, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 25, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 26, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 27, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 28, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 29, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 30, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 31, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 32, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (2, 
30)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 33, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 34, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 35, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (1, "y")), + ] + ), + ), + ( + 36, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 37, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 38, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 39, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 40, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 41, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 42, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 43, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 44, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 45, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 46, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 47, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 48, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 49, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 50, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 51, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 52, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ( + 53, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (2, 30)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ), + ] + assert actual == expected + + def test_iteration_with_conditionals(self): + search_space = SearchSpace( + [ + ("Pass/0/SP1", Categorical([1, 2, 3])), + ( + "Pass/0/SP2", + Conditional( + parents=("Pass/0/SP1",), + support={(1,): Categorical(["a"]), (3,): Categorical(["a", "b", "c"])}, + 
default=Conditional.get_ignored_choice(), + ), + ), + ("Pass/0/SP3", Categorical(["x", "y", "z"])), + ( + "Pass/0/SP4", + Conditional( + parents=("Pass/0/SP1", "Pass/0/SP3"), + support={ + (1, "x"): Categorical(["1x", "x1"]), + (2, "y"): Categorical(["2y", "y2"]), + (3, "z"): Categorical(["3z", "z3"]), + }, + default=Conditional.get_invalid_choice(), + ), + ), + ] + ) + + actual = [(search_point.index, search_point.values) for search_point in search_space] + expected = [ + ( + 0, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (0, "1x")), + ] + ), + ), + ( + 1, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (0, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 2, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 3, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (1, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (0, "1x")), + ] + ), + ), + ( + 4, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (1, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 5, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 6, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (2, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (0, "1x")), + ] + ), + ), + ( + 7, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (2, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 8, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (2, "c")), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 9, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 10, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (0, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (0, "2y")), + ] + ), + ), + ( + 11, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 12, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (1, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 13, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (1, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (0, "2y")), + ] + ), + ), + ( + 14, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 15, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (2, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 16, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (2, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (0, "2y")), + ] + ), + ), + ( + 17, + 
OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (2, "c")), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 18, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 19, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (0, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 20, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (0, "3z")), + ] + ), + ), + ( + 21, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (1, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 22, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (1, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 23, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (0, "3z")), + ] + ), + ), + ( + 24, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (2, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 25, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (2, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 26, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (2, "c")), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (0, "3z")), + ] + ), + ), + ( + 27, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (1, "x1")), + ] + ), + ), + ( + 28, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (0, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 29, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 30, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (1, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (1, "x1")), + ] + ), + ), + ( + 31, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (1, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 32, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 33, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (2, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (1, "x1")), + ] + ), + ), + ( + 34, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (2, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 35, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (2, "c")), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 36, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (1, 
SpecialParamValue.INVALID)), + ] + ), + ), + ( + 37, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (0, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (1, "y2")), + ] + ), + ), + ( + 38, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 39, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (1, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 40, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (1, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (1, "y2")), + ] + ), + ), + ( + 41, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 42, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (2, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 43, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (2, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (1, "y2")), + ] + ), + ), + ( + 44, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (2, "c")), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 45, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 46, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (0, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 47, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (0, "a")), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (1, "z3")), + ] + ), + ), + ( + 48, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (1, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 49, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (1, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 50, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (1, "z3")), + ] + ), + ), + ( + 51, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (2, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 52, + OrderedDict( + [ + ("Pass/0/SP1", (1, 2)), + ("Pass/0/SP2", (2, SpecialParamValue.IGNORED)), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (1, SpecialParamValue.INVALID)), + ] + ), + ), + ( + 53, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (2, "c")), + ("Pass/0/SP3", (2, "z")), + ("Pass/0/SP4", (1, "z3")), + ] + ), + ), + ] + assert actual == expected + + def test_iteration_with_search_spaces(self): + search_space = SearchSpace( + [ + ( + "Pass/0", + SearchSpace( + [ + ("Pass/0/SP1", Categorical([1, 2])), + ("Pass/0/SP2", Categorical([1, 2, 3])), + ] + ), + ), + ( + "Pass/1", + SearchSpace( + [ + ("Pass/1/SP1", Categorical(["a", "b"])), + ("Pass/1/SP2", Categorical(["a", "b", "c"])), + ] + ), + ), + ] + ) + assert len(search_space) == 36 + + actual = 
[(search_point.index, search_point.values) for search_point in search_space] + expected = [ + ( + 0, + OrderedDict( + [ + ("Pass/0", (0, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (0, 1))]))), + ("Pass/1", (0, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (0, "a"))]))), + ] + ), + ), + ( + 1, + OrderedDict( + [ + ("Pass/0", (1, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (0, 1))]))), + ("Pass/1", (0, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (0, "a"))]))), + ] + ), + ), + ( + 2, + OrderedDict( + [ + ("Pass/0", (2, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (1, 2))]))), + ("Pass/1", (0, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (0, "a"))]))), + ] + ), + ), + ( + 3, + OrderedDict( + [ + ("Pass/0", (3, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (1, 2))]))), + ("Pass/1", (0, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (0, "a"))]))), + ] + ), + ), + ( + 4, + OrderedDict( + [ + ("Pass/0", (4, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (2, 3))]))), + ("Pass/1", (0, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (0, "a"))]))), + ] + ), + ), + ( + 5, + OrderedDict( + [ + ("Pass/0", (5, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (2, 3))]))), + ("Pass/1", (0, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (0, "a"))]))), + ] + ), + ), + ( + 6, + OrderedDict( + [ + ("Pass/0", (0, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (0, 1))]))), + ("Pass/1", (1, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (0, "a"))]))), + ] + ), + ), + ( + 7, + OrderedDict( + [ + ("Pass/0", (1, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (0, 1))]))), + ("Pass/1", (1, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (0, "a"))]))), + ] + ), + ), + ( + 8, + OrderedDict( + [ + ("Pass/0", (2, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (1, 2))]))), + ("Pass/1", (1, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (0, "a"))]))), + ] + ), + ), + ( + 9, + OrderedDict( + [ + ("Pass/0", (3, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (1, 2))]))), + ("Pass/1", (1, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (0, "a"))]))), + ] + ), + ), + ( + 10, + OrderedDict( + [ + ("Pass/0", (4, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (2, 3))]))), + ("Pass/1", (1, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (0, "a"))]))), + ] + ), + ), + ( + 11, + OrderedDict( + [ + ("Pass/0", (5, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (2, 3))]))), + ("Pass/1", (1, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (0, "a"))]))), + ] + ), + ), + ( + 12, + OrderedDict( + [ + ("Pass/0", (0, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (0, 1))]))), + ("Pass/1", (2, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (1, "b"))]))), + ] + ), + ), + ( + 13, + OrderedDict( + [ + ("Pass/0", (1, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (0, 1))]))), + ("Pass/1", (2, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (1, "b"))]))), + ] + ), + ), + ( + 14, + OrderedDict( + [ + ("Pass/0", (2, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (1, 2))]))), + ("Pass/1", (2, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (1, "b"))]))), + ] + ), + ), + ( + 15, + OrderedDict( + [ + ("Pass/0", (3, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (1, 2))]))), + ("Pass/1", (2, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (1, "b"))]))), + ] + ), + ), + ( + 16, + OrderedDict( + [ + ("Pass/0", (4, OrderedDict([("Pass/0/SP1", (0, 1)), 
("Pass/0/SP2", (2, 3))]))), + ("Pass/1", (2, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (1, "b"))]))), + ] + ), + ), + ( + 17, + OrderedDict( + [ + ("Pass/0", (5, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (2, 3))]))), + ("Pass/1", (2, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (1, "b"))]))), + ] + ), + ), + ( + 18, + OrderedDict( + [ + ("Pass/0", (0, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (0, 1))]))), + ("Pass/1", (3, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (1, "b"))]))), + ] + ), + ), + ( + 19, + OrderedDict( + [ + ("Pass/0", (1, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (0, 1))]))), + ("Pass/1", (3, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (1, "b"))]))), + ] + ), + ), + ( + 20, + OrderedDict( + [ + ("Pass/0", (2, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (1, 2))]))), + ("Pass/1", (3, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (1, "b"))]))), + ] + ), + ), + ( + 21, + OrderedDict( + [ + ("Pass/0", (3, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (1, 2))]))), + ("Pass/1", (3, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (1, "b"))]))), + ] + ), + ), + ( + 22, + OrderedDict( + [ + ("Pass/0", (4, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (2, 3))]))), + ("Pass/1", (3, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (1, "b"))]))), + ] + ), + ), + ( + 23, + OrderedDict( + [ + ("Pass/0", (5, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (2, 3))]))), + ("Pass/1", (3, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (1, "b"))]))), + ] + ), + ), + ( + 24, + OrderedDict( + [ + ("Pass/0", (0, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (0, 1))]))), + ("Pass/1", (4, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (2, "c"))]))), + ] + ), + ), + ( + 25, + OrderedDict( + [ + ("Pass/0", (1, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (0, 1))]))), + ("Pass/1", (4, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (2, "c"))]))), + ] + ), + ), + ( + 26, + OrderedDict( + [ + ("Pass/0", (2, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (1, 2))]))), + ("Pass/1", (4, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (2, "c"))]))), + ] + ), + ), + ( + 27, + OrderedDict( + [ + ("Pass/0", (3, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (1, 2))]))), + ("Pass/1", (4, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (2, "c"))]))), + ] + ), + ), + ( + 28, + OrderedDict( + [ + ("Pass/0", (4, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (2, 3))]))), + ("Pass/1", (4, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (2, "c"))]))), + ] + ), + ), + ( + 29, + OrderedDict( + [ + ("Pass/0", (5, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (2, 3))]))), + ("Pass/1", (4, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (2, "c"))]))), + ] + ), + ), + ( + 30, + OrderedDict( + [ + ("Pass/0", (0, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (0, 1))]))), + ("Pass/1", (5, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (2, "c"))]))), + ] + ), + ), + ( + 31, + OrderedDict( + [ + ("Pass/0", (1, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (0, 1))]))), + ("Pass/1", (5, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (2, "c"))]))), + ] + ), + ), + ( + 32, + OrderedDict( + [ + ("Pass/0", (2, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (1, 2))]))), + ("Pass/1", (5, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (2, "c"))]))), + ] + ), + ), + ( + 33, + OrderedDict( + [ + ("Pass/0", (3, 
OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (1, 2))]))), + ("Pass/1", (5, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (2, "c"))]))), + ] + ), + ), + ( + 34, + OrderedDict( + [ + ("Pass/0", (4, OrderedDict([("Pass/0/SP1", (0, 1)), ("Pass/0/SP2", (2, 3))]))), + ("Pass/1", (5, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (2, "c"))]))), + ] + ), + ), + ( + 35, + OrderedDict( + [ + ("Pass/0", (5, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (2, 3))]))), + ("Pass/1", (5, OrderedDict([("Pass/1/SP1", (1, "b")), ("Pass/1/SP2", (2, "c"))]))), + ] + ), + ), + ] + assert actual == expected + + def test_get_item_with_categoricals(self): + search_space = SearchSpace( + [ + ("Pass/0/SP1", Categorical([1, 2, 3])), + ("Pass/0/SP2", Categorical(["a", "b"])), + ("Pass/0/SP3", Categorical([10, 20, 30])), + ("Pass/0/SP4", Categorical(["x", "y", "z"])), + ] + ) + + actual_5 = search_space[5] + expected_5 = SearchPoint( + 5, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (0, 10)), + ("Pass/0/SP4", (0, "x")), + ] + ), + ) + assert actual_5 == expected_5 + + actual_45 = search_space[45] + expected_45 = SearchPoint( + 45, + OrderedDict( + [ + ("Pass/0/SP1", (0, 1)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (1, 20)), + ("Pass/0/SP4", (2, "z")), + ] + ), + ) + assert actual_45 == expected_45 + + def test_get_item_with_conditionals(self): + search_space = SearchSpace( + [ + ("Pass/0/SP1", Categorical([1, 2, 3])), + ( + "Pass/0/SP2", + Conditional( + parents=("Pass/0/SP1",), + support={(1,): Categorical(["a"]), (3,): Categorical(["a", "b", "c"])}, + default=Conditional.get_ignored_choice(), + ), + ), + ("Pass/0/SP3", Categorical(["x", "y", "z"])), + ( + "Pass/0/SP4", + Conditional( + parents=("Pass/0/SP1", "Pass/0/SP3"), + support={ + (1, "x"): Categorical(["1x", "x1"]), + (2, "y"): Categorical(["2y", "y2"]), + (3, "z"): Categorical(["3z", "z3"]), + }, + default=Conditional.get_invalid_choice(), + ), + ), + ] + ) + + actual_5 = search_space[5] + expected_5 = SearchPoint( + 5, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (1, "b")), + ("Pass/0/SP3", (0, "x")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ) + assert actual_5 == expected_5 + + actual_17 = search_space[17] + expected_17 = SearchPoint( + 17, + OrderedDict( + [ + ("Pass/0/SP1", (2, 3)), + ("Pass/0/SP2", (2, "c")), + ("Pass/0/SP3", (1, "y")), + ("Pass/0/SP4", (0, SpecialParamValue.INVALID)), + ] + ), + ) + assert actual_17 == expected_17 + + def test_get_item_with_search_spaces(self): + search_space = SearchSpace( + [ + ( + "Pass/0", + SearchSpace( + [ + ("Pass/0/SP1", Categorical([1, 2])), + ("Pass/0/SP2", Categorical([1, 2, 3])), + ] + ), + ), + ( + "Pass/1", + SearchSpace( + [ + ("Pass/1/SP1", Categorical(["a", "b"])), + ("Pass/1/SP2", Categorical(["a", "b", "c"])), + ] + ), + ), + ] + ) + assert len(search_space) == 36 + + actual_5 = search_space[5] + expected_5 = SearchPoint( + 5, + OrderedDict( + [ + ("Pass/0", (5, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (2, 3))]))), + ("Pass/1", (0, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (0, "a"))]))), + ] + ), + ) + assert actual_5 == expected_5 + + actual_27 = search_space[27] + expected_27 = SearchPoint( + 27, + OrderedDict( + [ + ("Pass/0", (3, OrderedDict([("Pass/0/SP1", (1, 2)), ("Pass/0/SP2", (1, 2))]))), + ("Pass/1", (4, OrderedDict([("Pass/1/SP1", (0, "a")), ("Pass/1/SP2", (2, "c"))]))), + ] + ), + ) + assert actual_27 == expected_27 diff --git 
a/test/unit_test/search/test_search_strategy.py b/test/unit_test/search/test_search_strategy.py new file mode 100644 index 000000000..5bc4f43cb --- /dev/null +++ b/test/unit_test/search/test_search_strategy.py @@ -0,0 +1,849 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import re +from collections import OrderedDict + +import pytest + +from olive.evaluator.metric_result import MetricResult +from olive.search.search_parameter import Categorical +from olive.search.search_strategy import SearchStrategy, SearchStrategyConfig, SearchStrategyExecutionOrder + +# pylint: disable=protected-access +# ruff: noqa: PD011 + + +class TestSearchStrategy: + @pytest.mark.parametrize( + "execution_order", [SearchStrategyExecutionOrder.JOINT, SearchStrategyExecutionOrder.PASS_BY_PASS] + ) + def test_initialize(self, execution_order, tmpdir): + config = SearchStrategyConfig(execution_order=execution_order, sampler="random") + space_config = OrderedDict( + [ + ( + "PassA", + [ + {"PA0_SP1": Categorical(["A01a", "A01b"]), "PA0_SP2": Categorical(["A02a", "A02b"])}, + {"PA1_SP1": Categorical(["A11a", "A11b"])}, + {}, + {}, + ], + ), + ("PassB", [{"PB0_SP1": Categorical(["B01a", "B01b"])}, {}]), + ("PassC", [{}, {}]), + ] + ) + + strategy = SearchStrategy(config) + strategy.initialize(space_config, "whatever", {}) + + actual = re.sub(r"[\n\t\s]*", "", str(strategy._search_spaces)) + + if execution_order == SearchStrategyExecutionOrder.JOINT: + assert len(strategy._search_spaces) == 1 + + expected = re.sub( + r"[\n\t\s]*", + "", + """[ + SearchSpace([ + ('PassA', Categorical([ + SearchSpace([ + ('PA0_SP1', Categorical(['A01a', 'A01b'])), + ('PA0_SP2', Categorical(['A02a', 'A02b'])) + ], 4), + SearchSpace([('PA1_SP1', Categorical(['A11a', 'A11b']))], 2), + SearchSpace([], 1), + SearchSpace([], 1) + ]) + ), + ('PassB', Categorical([ + SearchSpace([('PB0_SP1', Categorical(['B01a', 'B01b']))], 2), + SearchSpace([], 1) + ]) + ), + ('PassC', Categorical([SearchSpace([], 1), SearchSpace([], 1)])) + ], 48) + ]""", + ) + + elif execution_order == SearchStrategyExecutionOrder.PASS_BY_PASS: + assert len(strategy._search_spaces) == 3 + + expected = re.sub( + r"[\n\t\s]*", + "", + """[ + SearchSpace([ + ('PassA', Categorical([ + SearchSpace([ + ('PA0_SP1', Categorical(['A01a', 'A01b'])), + ('PA0_SP2', Categorical(['A02a', 'A02b'])) + ], 4), + SearchSpace([('PA1_SP1', Categorical(['A11a', 'A11b']))], 2), + SearchSpace([], 1), + SearchSpace([], 1) + ]) + )], 8), + SearchSpace([ + ('PassB', Categorical([ + SearchSpace([('PB0_SP1', Categorical(['B01a', 'B01b']))], 2), + SearchSpace([], 1) + ]) + )], 3), + SearchSpace([('PassC', Categorical([SearchSpace([], 1), SearchSpace([], 1)]))], 2) + ]""", + ) + else: + expected = None + pytest.fail("Unsupported execution_order") + + assert actual == expected + + def test_iteration_joint(self, tmpdir): + config = SearchStrategyConfig(execution_order="joint", sampler="sequential") + space_config = OrderedDict( + [ + ( + "PassA", + [ + {"PA0_SP1": Categorical(["A01a", "A01b"]), "PA0_SP2": Categorical(["A02a", "A02b"])}, + {"PA1_SP1": Categorical(["A11a", "A11b"])}, + {}, + {}, + ], + ), + ("PassB", [{"PB0_SP1": Categorical(["B01a", "B01b"])}, {}]), + ("PassC", [{}, {}]), + ] + ) + + strategy = SearchStrategy(config) + strategy.initialize(space_config, "whatever", {}) + + actual = 
[(sample.search_point.index, sample.passes_configs) for sample in strategy] + expected = [ + ( + 0, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01a"), ("PA0_SP2", "A02a")]))] + ), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01a")]))]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 1, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01b"), ("PA0_SP2", "A02a")]))] + ), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01a")]))]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 2, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01a"), ("PA0_SP2", "A02b")]))] + ), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01a")]))]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 3, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01b"), ("PA0_SP2", "A02b")]))] + ), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01a")]))]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 4, + { + "PassA": OrderedDict([("index", 1), ("params", OrderedDict([("PA1_SP1", "A11a")]))]), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01a")]))]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 5, + { + "PassA": OrderedDict([("index", 1), ("params", OrderedDict([("PA1_SP1", "A11b")]))]), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01a")]))]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 6, + { + "PassA": OrderedDict([("index", 2), ("params", OrderedDict())]), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01a")]))]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 7, + { + "PassA": OrderedDict([("index", 3), ("params", OrderedDict())]), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01a")]))]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 8, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01a"), ("PA0_SP2", "A02a")]))] + ), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01b")]))]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 9, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01b"), ("PA0_SP2", "A02a")]))] + ), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01b")]))]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 10, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01a"), ("PA0_SP2", "A02b")]))] + ), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01b")]))]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 11, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01b"), ("PA0_SP2", "A02b")]))] + ), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01b")]))]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 12, + { + "PassA": OrderedDict([("index", 1), ("params", OrderedDict([("PA1_SP1", 
"A11a")]))]), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01b")]))]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 13, + { + "PassA": OrderedDict([("index", 1), ("params", OrderedDict([("PA1_SP1", "A11b")]))]), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01b")]))]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 14, + { + "PassA": OrderedDict([("index", 2), ("params", OrderedDict())]), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01b")]))]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 15, + { + "PassA": OrderedDict([("index", 3), ("params", OrderedDict())]), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01b")]))]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 16, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01a"), ("PA0_SP2", "A02a")]))] + ), + "PassB": OrderedDict([("index", 1), ("params", OrderedDict())]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 17, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01b"), ("PA0_SP2", "A02a")]))] + ), + "PassB": OrderedDict([("index", 1), ("params", OrderedDict())]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 18, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01a"), ("PA0_SP2", "A02b")]))] + ), + "PassB": OrderedDict([("index", 1), ("params", OrderedDict())]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 19, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01b"), ("PA0_SP2", "A02b")]))] + ), + "PassB": OrderedDict([("index", 1), ("params", OrderedDict())]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 20, + { + "PassA": OrderedDict([("index", 1), ("params", OrderedDict([("PA1_SP1", "A11a")]))]), + "PassB": OrderedDict([("index", 1), ("params", OrderedDict())]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 21, + { + "PassA": OrderedDict([("index", 1), ("params", OrderedDict([("PA1_SP1", "A11b")]))]), + "PassB": OrderedDict([("index", 1), ("params", OrderedDict())]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 22, + { + "PassA": OrderedDict([("index", 2), ("params", OrderedDict())]), + "PassB": OrderedDict([("index", 1), ("params", OrderedDict())]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 23, + { + "PassA": OrderedDict([("index", 3), ("params", OrderedDict())]), + "PassB": OrderedDict([("index", 1), ("params", OrderedDict())]), + "PassC": OrderedDict([("index", 0), ("params", OrderedDict())]), + }, + ), + ( + 24, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01a"), ("PA0_SP2", "A02a")]))] + ), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01a")]))]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 25, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01b"), ("PA0_SP2", "A02a")]))] + ), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01a")]))]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, 
+ ), + ( + 26, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01a"), ("PA0_SP2", "A02b")]))] + ), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01a")]))]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 27, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01b"), ("PA0_SP2", "A02b")]))] + ), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01a")]))]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 28, + { + "PassA": OrderedDict([("index", 1), ("params", OrderedDict([("PA1_SP1", "A11a")]))]), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01a")]))]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 29, + { + "PassA": OrderedDict([("index", 1), ("params", OrderedDict([("PA1_SP1", "A11b")]))]), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01a")]))]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 30, + { + "PassA": OrderedDict([("index", 2), ("params", OrderedDict())]), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01a")]))]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 31, + { + "PassA": OrderedDict([("index", 3), ("params", OrderedDict())]), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01a")]))]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 32, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01a"), ("PA0_SP2", "A02a")]))] + ), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01b")]))]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 33, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01b"), ("PA0_SP2", "A02a")]))] + ), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01b")]))]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 34, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01a"), ("PA0_SP2", "A02b")]))] + ), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01b")]))]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 35, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01b"), ("PA0_SP2", "A02b")]))] + ), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01b")]))]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 36, + { + "PassA": OrderedDict([("index", 1), ("params", OrderedDict([("PA1_SP1", "A11a")]))]), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01b")]))]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 37, + { + "PassA": OrderedDict([("index", 1), ("params", OrderedDict([("PA1_SP1", "A11b")]))]), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01b")]))]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 38, + { + "PassA": OrderedDict([("index", 2), ("params", OrderedDict())]), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01b")]))]), + "PassC": OrderedDict([("index", 1), ("params", 
OrderedDict())]), + }, + ), + ( + 39, + { + "PassA": OrderedDict([("index", 3), ("params", OrderedDict())]), + "PassB": OrderedDict([("index", 0), ("params", OrderedDict([("PB0_SP1", "B01b")]))]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 40, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01a"), ("PA0_SP2", "A02a")]))] + ), + "PassB": OrderedDict([("index", 1), ("params", OrderedDict())]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 41, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01b"), ("PA0_SP2", "A02a")]))] + ), + "PassB": OrderedDict([("index", 1), ("params", OrderedDict())]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 42, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01a"), ("PA0_SP2", "A02b")]))] + ), + "PassB": OrderedDict([("index", 1), ("params", OrderedDict())]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 43, + { + "PassA": OrderedDict( + [("index", 0), ("params", OrderedDict([("PA0_SP1", "A01b"), ("PA0_SP2", "A02b")]))] + ), + "PassB": OrderedDict([("index", 1), ("params", OrderedDict())]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 44, + { + "PassA": OrderedDict([("index", 1), ("params", OrderedDict([("PA1_SP1", "A11a")]))]), + "PassB": OrderedDict([("index", 1), ("params", OrderedDict())]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 45, + { + "PassA": OrderedDict([("index", 1), ("params", OrderedDict([("PA1_SP1", "A11b")]))]), + "PassB": OrderedDict([("index", 1), ("params", OrderedDict())]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 46, + { + "PassA": OrderedDict([("index", 2), ("params", OrderedDict())]), + "PassB": OrderedDict([("index", 1), ("params", OrderedDict())]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ( + 47, + { + "PassA": OrderedDict([("index", 3), ("params", OrderedDict())]), + "PassB": OrderedDict([("index", 1), ("params", OrderedDict())]), + "PassC": OrderedDict([("index", 1), ("params", OrderedDict())]), + }, + ), + ] + + assert actual == expected + + def test_iteration_pass_by_pass(self, tmpdir): + config = SearchStrategyConfig(execution_order="pass-by-pass", sampler="sequential") + space_config = OrderedDict( + [ + ( + "PassA", + [ + {"PA0_SP1": Categorical(["A01a", "A01b"]), "PA0_SP2": Categorical(["A02a", "A02b"])}, + {"PA1_SP1": Categorical(["A11a", "A11b"])}, + {}, + {}, + ], + ), + ("PassB", [{"PB0_SP1": Categorical(["B01a", "B01b"])}, {}]), + ("PassC", [{}, {}]), + ] + ) + objectives = OrderedDict( + [ + ( + "PassA", + [ + { + "accuracy-accuracy_custom": { + "goal": 0.70, + "higher_is_better": True, + "priority": 1, + }, + "latency-avg": {"goal": 24.0, "higher_is_better": False, "priority": 2}, + }, + { + "accuracy-accuracy_custom": { + "goal": 0.85, + "higher_is_better": True, + "priority": 1, + }, + "latency-avg": {"goal": 23.0, "higher_is_better": False, "priority": 2}, + }, + { + "accuracy-accuracy_custom": { + "goal": 0.80, + "higher_is_better": True, + "priority": 1, + }, + "latency-avg": {"goal": 25.0, "higher_is_better": False, "priority": 2}, + }, + { + "accuracy-accuracy_custom": { + "goal": 0.75, + "higher_is_better": True, + "priority": 1, + }, + "latency-avg": {"goal": 20.0, "higher_is_better": False, "priority": 
2}, + }, + ], + ), + ( + "PassB", + [ + { + "accuracy-accuracy_custom": { + "goal": 0.70, + "higher_is_better": True, + "priority": 1, + }, + "latency-avg": {"goal": 40.0, "higher_is_better": False, "priority": 2}, + "latency-max": {"goal": 64.0, "higher_is_better": False, "priority": 2}, + }, + { + "accuracy-accuracy_custom": { + "goal": 0.80, + "higher_is_better": True, + "priority": 1, + }, + "latency-avg": {"goal": 45.0, "higher_is_better": False, "priority": 2}, + "latency-max": {"goal": 72.0, "higher_is_better": False, "priority": 2}, + }, + ], + ), + ( + "PassC", + [ + { + "accuracy-accuracy_custom": { + "goal": 0.89, + "higher_is_better": True, + "priority": 1, + }, + "latency-avg": {"goal": 16.0, "higher_is_better": False, "priority": 2}, + }, + { + "accuracy-accuracy_custom": { + "goal": 0.92, + "higher_is_better": True, + "priority": 1, + }, + "latency-avg": {"goal": 14.0, "higher_is_better": False, "priority": 2}, + }, + ], + ), + ] + ) + + signal1 = MetricResult.parse_obj( + { + "accuracy-accuracy_custom": { + "value": 0.96, + "priority": 1, + "higher_is_better": True, + }, + "latency-avg": {"value": 4.5, "priority": 2, "higher_is_better": False}, + } + ) + signal2 = MetricResult.parse_obj( + { + "accuracy-accuracy_custom": { + "value": 0.72, + "priority": 1, + "higher_is_better": True, + }, + "latency-avg": {"value": 55.9, "priority": 2, "higher_is_better": False}, + } + ) + signal3 = MetricResult.parse_obj( + { + "accuracy-accuracy_custom": { + "value": 0.73, + "priority": 1, + "higher_is_better": True, + }, + "latency-avg": {"value": 8.9, "priority": 2, "higher_is_better": False}, + } + ) + signal4 = MetricResult.parse_obj( + { + "accuracy-accuracy_custom": { + "value": 0.91, + "priority": 1, + "higher_is_better": True, + }, + "latency-avg": {"value": 13.9, "priority": 2, "higher_is_better": False}, + } + ) + signal5 = MetricResult.parse_obj( + { + "accuracy-accuracy_custom": { + "value": 0.82, + "priority": 1, + "higher_is_better": True, + }, + "latency-avg": {"value": 53.0, "priority": 2, "higher_is_better": False}, + "latency-max": {"value": 60.1, "priority": 2, "higher_is_better": False}, + } + ) + signal6 = MetricResult.parse_obj( + { + "accuracy-accuracy_custom": { + "value": 0.76, + "priority": 1, + "higher_is_better": True, + }, + "latency-avg": {"value": 58.0, "priority": 2, "higher_is_better": False}, + "latency-max": {"value": 55.2, "priority": 3, "higher_is_better": False}, + } + ) + + signals = { + "PassA": [ + None, + ( + signal1, + ["model_id_1"], + False, + ), + None, + None, + ( + signal2, + ["model_id_2"], + False, + ), + ( + signal3, + ["model_id_3"], + True, + ), + ], + "PassB": [ + None, + None, + ( + signal5, + ["model_id_5"], + False, + ), + None, + None, + ( + signal6, + ["model_id_6"], + False, + ), + ], + "PassC": [ + None, + None, + None, + ( + signal4, + ["model_id_4"], + False, + ), + ], + } + + strategy = SearchStrategy(config) + strategy.initialize(space_config, "whatever", objectives) + + actual = [] + for sample in strategy: + actual.append((sample.search_point.index, sample.passes_configs, sample.model_ids)) + + pass_name = next(iter(sample.search_point.values.keys())) + if sample.search_point.index < len(signals[pass_name]) and signals[pass_name][sample.search_point.index]: + strategy.record_feedback_signal( + sample.search_point.index, *signals[pass_name][sample.search_point.index] + ) + + expected = [ + ( + 0, + OrderedDict( + {"PassA": OrderedDict({"index": 0, "params": OrderedDict({"PA0_SP1": "A01a", "PA0_SP2": "A02a"})})} + ), + 
["whatever"], + ), + ( + 1, + OrderedDict( + {"PassA": OrderedDict({"index": 0, "params": OrderedDict({"PA0_SP1": "A01b", "PA0_SP2": "A02a"})})} + ), + ["whatever"], + ), + ( + 2, + OrderedDict( + {"PassA": OrderedDict({"index": 0, "params": OrderedDict({"PA0_SP1": "A01a", "PA0_SP2": "A02b"})})} + ), + ["whatever"], + ), + ( + 3, + OrderedDict( + {"PassA": OrderedDict({"index": 0, "params": OrderedDict({"PA0_SP1": "A01b", "PA0_SP2": "A02b"})})} + ), + ["whatever"], + ), + ( + 4, + OrderedDict({"PassA": OrderedDict({"index": 1, "params": OrderedDict({"PA1_SP1": "A11a"})})}), + ["whatever"], + ), + ( + 5, + OrderedDict({"PassA": OrderedDict({"index": 1, "params": OrderedDict({"PA1_SP1": "A11b"})})}), + ["whatever"], + ), + (6, OrderedDict({"PassA": OrderedDict({"index": 2, "params": OrderedDict()})}), ["whatever"]), + (7, OrderedDict({"PassA": OrderedDict({"index": 3, "params": OrderedDict()})}), ["whatever"]), + ( + 0, + OrderedDict({"PassB": OrderedDict({"index": 0, "params": OrderedDict({"PB0_SP1": "B01a"})})}), + ["model_id_1"], + ), + ( + 1, + OrderedDict({"PassB": OrderedDict({"index": 0, "params": OrderedDict({"PB0_SP1": "B01b"})})}), + ["model_id_1"], + ), + (2, OrderedDict({"PassB": OrderedDict({"index": 1, "params": OrderedDict()})}), ["model_id_1"]), + (0, OrderedDict({"PassC": OrderedDict({"index": 0, "params": OrderedDict()})}), ["model_id_5"]), + (1, OrderedDict({"PassC": OrderedDict({"index": 1, "params": OrderedDict()})}), ["model_id_5"]), + ] + + assert actual == expected diff --git a/test/unit_test/utils.py b/test/unit_test/utils.py index f3f670c08..e4f000796 100644 --- a/test/unit_test/utils.py +++ b/test/unit_test/utils.py @@ -4,6 +4,7 @@ # -------------------------------------------------------------------------- import os from pathlib import Path +from typing import Any, Dict, Tuple, Type, Union from unittest.mock import MagicMock import numpy as np @@ -17,7 +18,7 @@ from olive.evaluator.metric import AccuracySubType, LatencySubType, Metric, MetricType from olive.evaluator.metric_config import MetricGoal from olive.model import HfModelHandler, ModelConfig, ONNXModelHandler, PyTorchModelHandler -from olive.passes.olive_pass import create_pass_from_dict +from olive.passes.olive_pass import Pass, create_pass_from_dict ONNX_MODEL_PATH = Path(__file__).absolute().parent / "dummy_model.onnx" @@ -235,12 +236,14 @@ def get_throughput_metric(*lat_subtype, user_config=None): ) -def get_onnxconversion_pass(ignore_pass_config=True, target_opset=13): +def get_onnxconversion_pass( + ignore_pass_config=True, target_opset=13 +) -> Union[Type[Pass], Tuple[Type[Pass], Dict[str, Any]]]: from olive.passes.onnx.conversion import OnnxConversion onnx_conversion_config = {"target_opset": target_opset} p = create_pass_from_dict(OnnxConversion, onnx_conversion_config) - return p if ignore_pass_config else (p, p.to_json(check_object=True)["config"]) + return p if ignore_pass_config else (p, p.to_json(check_object=True)) def get_onnx_dynamic_quantization_pass(disable_search=False): diff --git a/test/unit_test/workflows/test_run_config.py b/test/unit_test/workflows/test_run_config.py index 75674e6cb..35fe011cf 100644 --- a/test/unit_test/workflows/test_run_config.py +++ b/test/unit_test/workflows/test_run_config.py @@ -150,42 +150,19 @@ def test_get_module_path(self, pass_type, is_onnx): assert pass_module.startswith("olive.passes.onnx") == is_onnx @pytest.mark.parametrize( - ("passes", "pass_flows", "is_onnx"), + ("passes", "is_onnx"), [ - (None, None, True), - ( - { - "lora": {"type": "LoRA"}, - 
}, - None, - False, - ), - ( - { - "lora": {"type": "LoRA"}, - "quantization": {"type": "IncQuantization"}, - }, - None, - True, - ), - ( - { - "lora": {"type": "LoRA"}, - "quantization": {"type": "IncQuantization"}, - }, - [["lora"]], - False, - ), + (None, True), + ({"lora": {"type": "LoRA"}}, False), + ({"lora": {"type": "LoRA"}, "quantization": {"type": "IncQuantization"}}, True), ], ) - def test_is_execution_provider_required(self, passes, pass_flows, is_onnx): + def test_is_execution_provider_required(self, passes, is_onnx): with self.user_script_config_file.open() as f: user_script_config = json.load(f) if passes: user_script_config["passes"] = passes - if pass_flows: - user_script_config["pass_flows"] = pass_flows run_config = RunConfig.parse_obj(user_script_config) result = is_execution_provider_required(run_config, self.package_config) @@ -212,7 +189,7 @@ def setup(self): }, } ], - "passes": {"tuning": {"type": "OrtSessionParamsTuning"}}, + "passes": {"tuning": [{"type": "OrtSessionParamsTuning"}]}, "engine": {"evaluate_input_model": False}, } @@ -286,10 +263,10 @@ def test_auto_insert_trust_remote_code( ) def test_str_to_data_config(self, data_config_str): config_dict = deepcopy(self.template) - config_dict["passes"]["tuning"]["data_config"] = data_config_str + config_dict["passes"]["tuning"][0]["data_config"] = data_config_str run_config = RunConfig.parse_obj(config_dict) - pass_data_config = run_config.passes["tuning"].config["data_config"] + pass_data_config = run_config.passes["tuning"][0].config["data_config"] if data_config_str is None: assert pass_data_config is None else: @@ -305,9 +282,11 @@ def setup(self): "type": "OnnxModel", }, "passes": { - "tuning": { - "type": "IncQuantization", - } + "tuning": [ + { + "type": "IncQuantization", + } + ] }, "evaluate_input_model": False, } @@ -319,19 +298,19 @@ def setup(self): (None, "SEARCHABLE_VALUES", False), (None, "dummy_approach", True), ( - {"execution_order": "joint", "search_algorithm": "exhaustive"}, + {"execution_order": "joint", "sampler": "sequential"}, "SEARCHABLE_VALUES", - True, + False, ), ], ) def test_pass_config_(self, search_strategy, approach, is_valid): config_dict = self.template.copy() config_dict["search_strategy"] = search_strategy - config_dict["passes"]["tuning"]["approach"] = approach + config_dict["passes"]["tuning"][0]["approach"] = approach if not is_valid: with pytest.raises(ValueError): # noqa: PT011 RunConfig.parse_obj(config_dict) else: config = RunConfig.parse_obj(config_dict) - assert config.passes["tuning"].config["approach"] == approach + assert config.passes["tuning"][0].config["approach"] == approach diff --git a/test/unit_test/workflows/test_workflow_run.py b/test/unit_test/workflows/test_workflow_run.py index ccdcd843c..f4a4ce6e8 100644 --- a/test/unit_test/workflows/test_workflow_run.py +++ b/test/unit_test/workflows/test_workflow_run.py @@ -1,3 +1,4 @@ +from copy import deepcopy from pathlib import Path from test.unit_test.utils import ( get_pytorch_model, @@ -100,7 +101,7 @@ def test_run_without_ep(mock_model_to_json, mock_model_from_json, mock_run, conf with user_script.open("w"): pass - config = config_test + config = deepcopy(config_test) config["passes"]["qat"]["config"]["user_script"] = str(user_script) config["engine"]["cache_dir"] = str(tmp_path / "cache") config["engine"]["output_dir"] = str(tmp_path / "output")
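
A note on the final hunk: `test_run_without_ep` now takes a `deepcopy` of the shared `config_test` fixture before mutating it, so one test run can no longer leak its edits into later runs that reuse the same dict. Below is a minimal sketch of the failure mode this prevents; the template name and structure are hypothetical stand-ins, not the actual fixture:

    from copy import deepcopy

    # Hypothetical stand-in for a module-level config template shared across tests.
    CONFIG_TEMPLATE = {"passes": {"qat": {"config": {"user_script": None}}}}

    def configure(user_script, copy_first):
        # Without the copy, the assignment below mutates the shared template.
        config = deepcopy(CONFIG_TEMPLATE) if copy_first else CONFIG_TEMPLATE
        config["passes"]["qat"]["config"]["user_script"] = user_script
        return config

    configure("a.py", copy_first=False)
    # The first call polluted the shared template.
    assert CONFIG_TEMPLATE["passes"]["qat"]["config"]["user_script"] == "a.py"

    CONFIG_TEMPLATE["passes"]["qat"]["config"]["user_script"] = None  # reset
    configure("b.py", copy_first=True)
    # With deepcopy, the template is untouched.
    assert CONFIG_TEMPLATE["passes"]["qat"]["config"]["user_script"] is None

Note that a shallow copy (`config_test.copy()`) would not be enough here: the nested `passes` and `engine` dicts would still be shared with the template, which is why the fix reaches for `deepcopy`.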