📶 Make engine optional to simplify the default user config (#707)
## Describe your changes
Make engine optional to simplify the default user config
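For illustration, here is a hedged sketch of what a user config can look like once `engine` is optional; the model and pass entries mirror the mock config added in this PR, and everything else falls back to engine defaults.

```python
# Illustrative minimal Olive run config with no "engine" section at all.
# Entries mirror test/unit_test/workflows/mock_data/default_engine.json added in this PR;
# with "engine" omitted, the defaults apply: local host/target, no evaluator, no search.
minimal_config = {
    "input_model": {
        "type": "PyTorchModel",
        "config": {
            "hf_config": {"model_name": "Intel/bert-base-uncased-mrpc", "task": "text-classification"}
        },
    },
    "passes": {
        "conversion": {"type": "OnnxConversion", "config": {"target_opset": 13}},
    },
}
```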

## Checklist before requesting a review
- [ ] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [ ] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this change to be included in the release notes.

## (Optional) Issue link
trajepl authored Nov 14, 2023
1 parent a4b9e51 commit 18d1203
Showing 6 changed files with 115 additions and 26 deletions.
39 changes: 19 additions & 20 deletions docs/source/overview/options.md
@@ -479,7 +479,7 @@ When `pass_flows` is not specified, the passes are executed in the order of the

This is a dictionary that contains the information of the engine. The information of the engine contains the following items:

- `search_strategy: [Dict | Boolean | None]` The search strategy of the engine. It contains the following items:
- `search_strategy: [Dict | Boolean | None]`, `None` by default. The search strategy of the engine. It contains the following items:

- `execution_order: [str]` The execution order of the optimizations of passes. The options are `pass-by-pass` and `joint`.

@@ -506,45 +506,44 @@ This is a dictionary that contains the information of the engine. The informatio
If `search_strategy` is `true`, the search strategy will be the default search strategy. The default search strategy is `exhaustive` search
algorithm with `joint` execution order.

- `evaluate_input_model: [Boolean]` In this mode, the engine will evaluate the input model using the engine's evaluator and return the results. If the engine has no evaluator, it will raise an error. This is `true` by default.
- `evaluate_input_model: [Boolean]` In this mode, the engine will evaluate the input model using the engine's evaluator and return the results. If the engine has no evaluator, it will skip the evaluation. This is `true` by default.

- `host: [str | Dict]` The host of the engine. It can be a string or a dictionary. If it is a string, it is the name of a system in `systems`.
- `host: [str | Dict | None]`, `None` by default. The host of the engine. It can be a string or a dictionary. If it is a string, it is the name of a system in `systems`.
If it is a dictionary, it contains the system information. If not specified, it is the local system.

- `target: [str | Dict]` The target to run model evaluations on. It can be a string or a dictionary. If it is a string, it is the name of
- `target: [str | Dict | None]`, `None` by default. The target to run model evaluations on. It can be a string or a dictionary. If it is a string, it is the name of
a system in `systems`. If it is a dictionary, it contains the system information. If not specified, it is the local system.

- `evaluator: [str | Dict]` The evaluator of the engine. It can be a string or a dictionary. If it is a string, it is the name of an evaluator
- `evaluator: [str | Dict | None]`, `None` by default. The evaluator of the engine. It can be a string or a dictionary. If it is a string, it is the name of an evaluator
in `evaluators`. If it is a dictionary, it contains the evaluator information. This evaluator will be used to evaluate the input model if
needed. It is also used to evaluate the output models of passes that don't have their own evaluators.
needed. It is also used to evaluate the output models of passes that don't have their own evaluators. If it is `None`, evaluation is skipped for the input model and any output models.

- `cache_dir: [str]` The directory to store the cache of the engine. If not specified, the cache will be stored in the `.olive-cache` directory
- `cache_dir: [str]`, `.olive-cache` by default. The directory to store the cache of the engine. If not specified, the cache will be stored in the `.olive-cache` directory
under the current working directory.

- `clean_cache: [Boolean]` This decides whether to clean the cache of the engine before running the engine. This is `false` by default.
- `clean_cache: [Boolean]`, `false` by default. This decides whether to clean the cache of the engine before running the engine.

- `clean_evaluation_cache: [Boolean]` This decides whether to clean the evaluation cache of the engine before running the engine. This is
`false` by default.
- `clean_evaluation_cache: [Boolean]`, `false` by default. This decides whether to clean the evaluation cache of the engine before running the engine.

- `plot_pareto_frontier` This decides whether to plot the pareto frontier of the search results. This is `false` by default.
- `plot_pareto_frontier: [Boolean]`, `false` by default. This decides whether to plot the Pareto frontier of the search results.

- `output_dir: [str]` The directory to store the output of the engine. If not specified, the output will be stored in the current working
- `output_dir: [str]`, `None` by default. The directory to store the output of the engine. If not specified, the output will be stored in the current working
directory. For a run with no search, the output is the output model of the final pass and its evaluation result. For a run with search, the
output is a json file with the search results.

- `output_name: [str]` The name of the output. This string will be used as the prefix of the output file name. If not specified, there is no
- `output_name: [str]`, `None` by default. The name of the output. This string will be used as the prefix of the output file name. If not specified, there is no
prefix.

- `packaging_config: [PackagingConfig]` Olive artifacts packaging configurations. If not specified, Olive will not package artifacts.
- `packaging_config: [PackagingConfig]`, `None` by default. Olive artifacts packaging configurations. If not specified, Olive will not package artifacts.

- `log_severity_level: [int]` The log severity level of Olive. The options are `0` for `VERBOSE`, `1` for
`INFO`, `2` for `WARNING`, `3` for `ERROR`, `4` for `FATAL`. The default value is `1` for `INFO`.
- `log_severity_level: [int]`, `1` by default. The log severity level of Olive. The options are `0` for `VERBOSE`, `1` for
`INFO`, `2` for `WARNING`, `3` for `ERROR`, `4` for `FATAL`.

- `ort_log_severity_level: [int]` The log severity level of ONNX Runtime C++ logs. The options are `0` for `VERBOSE`, `1` for
`INFO`, `2` for `WARNING`, `3` for `ERROR`, `4` for `FATAL`. The default value is `3` for `ERROR`.
- `ort_log_severity_level: [int]`, `3` by default. The log severity level of ONNX Runtime C++ logs. The options are `0` for `VERBOSE`, `1` for
`INFO`, `2` for `WARNING`, `3` for `ERROR`, `4` for `FATAL`.

- `ort_py_log_severity_level: [int]` The log severity level of ONNX Runtime Python logs. The options are `0` for `VERBOSE`, `1` for
`INFO`, `2` for `WARNING`, `3` for `ERROR`, `4` for `FATAL`. The default value is `3` for `ERROR`.
- `ort_py_log_severity_level: [int]`, `3` by default. The log severity level of ONNX Runtime Python logs. The options are `0` for `VERBOSE`, `1` for
`INFO`, `2` for `WARNING`, `3` for `ERROR`, `4` for `FATAL`.
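To make the defaults above easier to scan, here is an illustrative engine section with every option written out at its documented default; omitting the section entirely (or any individual key) is equivalent. This is a sketch expressed as a Python dict, not a snippet from the codebase.

```python
# Engine options with their documented defaults spelled out (illustrative only).
default_engine = {
    "search_strategy": None,          # no search
    "evaluate_input_model": True,     # skipped automatically if no evaluator is configured
    "host": None,                     # local system
    "target": None,                   # local system
    "evaluator": None,                # no evaluation of input or output models
    "cache_dir": ".olive-cache",
    "clean_cache": False,
    "clean_evaluation_cache": False,
    "plot_pareto_frontier": False,
    "output_dir": None,               # current working directory
    "output_name": None,              # no prefix for output file names
    "packaging_config": None,         # no artifact packaging
    "log_severity_level": 1,          # INFO
    "ort_log_severity_level": 3,      # ERROR
    "ort_py_log_severity_level": 3,   # ERROR
}
```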

Please find the detailed config options from following table for each search algorithm:

3 changes: 1 addition & 2 deletions olive/engine/engine.py
@@ -390,7 +390,7 @@ def run_accelerator(
self.footprints[accelerator_spec].record(model_id=input_model_id)

try:
if evaluate_input_model and self.no_search and not self.evaluator_config:
if evaluate_input_model and not self.evaluator_config:
logger.debug(
"evaluate_input_model is True but no evaluator provided in no-search mode. Skipping input model"
" evaluation."
@@ -399,7 +399,6 @@ def run_accelerator(
prefix_output_name = (
f"{output_name}_{accelerator_spec}_" if output_name is not None else f"{accelerator_spec}"
)
assert self.evaluator_config is not None, "evaluate_input_model is True but no evaluator provided"
results = self._evaluate_model(
input_model_config, input_model_id, data_root, self.evaluator_config, accelerator_spec
)
14 changes: 12 additions & 2 deletions olive/workflows/run/config.py
@@ -2,6 +2,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from pathlib import Path
from typing import Dict, List, Union

@@ -19,6 +20,8 @@
from olive.resource_path import AZUREML_RESOURCE_TYPES
from olive.systems.system_config import SystemConfig

logger = logging.getLogger(__name__)


class RunPassConfig(FullPassConfig):
host: SystemConfig = None
@@ -62,8 +65,8 @@ class RunConfig(ConfigBase):
data_root: str = None
data_configs: Dict[str, DataConfig] = None
evaluators: Dict[str, OliveEvaluatorConfig] = None
engine: RunEngineConfig = None
pass_flows: List[List[str]] = None
engine: RunEngineConfig
passes: Dict[str, RunPassConfig] = None

@validator("input_model", pre=True)
@@ -80,6 +83,12 @@ def insert_aml_client(cls, v, values):
v["config"]["model_path"]["config"]["azureml_client"] = values["azureml_client"]
return v

@validator("engine", pre=True, always=True)
def default_engine_config(cls, v):
if v is None:
v = {}
return v

@validator("data_configs", pre=True, always=True)
def insert_input_model_data_config(cls, v, values):
if "input_model" not in values:
@@ -150,7 +159,8 @@ def validate_engine(cls, v, values):
@validator("engine")
def validate_evaluate_input_model(cls, v):
if v.evaluate_input_model and v.evaluator is None:
raise ValueError("Evaluation only requires evaluator")
            logger.info("No evaluator is specified, skipping input model evaluation")
v.evaluate_input_model = False
return v

@validator("passes", pre=True, each_item=True)
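To make the two new validators concrete, here is a minimal standalone sketch of the same defaulting pattern in pydantic v1 style (as used by this repo); the class and field names below are simplified stand-ins, not Olive's actual classes.

```python
# Standalone sketch of the defaulting pattern above (pydantic v1 style, simplified names).
from pydantic import BaseModel, validator


class EngineConfig(BaseModel):
    evaluator: str = None
    evaluate_input_model: bool = True


class WorkflowConfig(BaseModel):
    engine: EngineConfig = None

    @validator("engine", pre=True, always=True)
    def default_engine_config(cls, v):
        # A missing or null engine section becomes an empty dict, so all defaults apply.
        return {} if v is None else v

    @validator("engine")
    def validate_evaluate_input_model(cls, v):
        # Without an evaluator, quietly disable input-model evaluation instead of erroring.
        if v.evaluate_input_model and v.evaluator is None:
            v.evaluate_input_model = False
        return v


cfg = WorkflowConfig.parse_obj({})  # no "engine" key at all
assert cfg.engine.evaluate_input_model is False
```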
17 changes: 16 additions & 1 deletion test/unit_test/engine/test_engine.py
@@ -93,6 +93,22 @@ def test_register_no_search_fail(self, tmpdir):

assert str(exc_info.value) == f"Search strategy is None but pass {name} has search space"

def test_default_engine_run(self, tmpdir):
# setup
model_config = get_pytorch_model_config()
engine = Engine({"cache_dir": tmpdir})
assert engine.no_search, "Expect no_search to be True by default"

engine.register(OnnxConversion, name="converter_13", config={"target_opset": 13}, clean_run_cache=True)
outputs = engine.run(model_config, output_dir=tmpdir)

assert outputs
for fp_nodes in outputs.values():
for node in fp_nodes.nodes.values():
assert node.model_config
assert node.from_pass == "OnnxConversion"
assert node.metrics is None, "Should not evaluate input/output model by default"

@patch("olive.systems.local.LocalSystem")
def test_run(self, mock_local_system, tmpdir):
# setup
@@ -101,7 +117,6 @@ def test_run(self, mock_local_system, tmpdir):
metric = get_accuracy_metric(AccuracySubType.ACCURACY_SCORE)
evaluator_config = OliveEvaluatorConfig(metrics=[metric])
options = {
"output_dir": tmpdir,
"output_name": "test",
"cache_dir": tmpdir,
"clean_cache": True,
60 changes: 60 additions & 0 deletions test/unit_test/workflows/mock_data/default_engine.json
@@ -0,0 +1,60 @@
{
"input_model":{
"type": "PyTorchModel",
"config": {
"hf_config": {
"model_name": "Intel/bert-base-uncased-mrpc",
"task": "text-classification",
"dataset": {
"data_name":"glue",
"subset": "mrpc",
"split": "validation",
"input_cols": ["sentence1", "sentence2"],
"label_cols": ["label"],
"batch_size": 1
}
},
"io_config" : {
"input_names": ["input_ids", "attention_mask", "token_type_ids"],
"input_shapes": [[1, 128], [1, 128], [1, 128]],
"input_types": ["int64", "int64", "int64"],
"output_names": ["output"],
"dynamic_axes": {
"input_ids": {"0": "batch_size", "1": "seq_length"},
"attention_mask": {"0": "batch_size", "1": "seq_length"},
"token_type_ids": {"0": "batch_size", "1": "seq_length"}
}
}
}
},
"systems": {
"local_system": {"type": "LocalSystem"}
},
"passes": {
"conversion": {
"type": "OnnxConversion",
"config": {
"target_opset": 13
}
},
"transformers_optimization": {
"type": "OrtTransformersOptimization",
"disable_search": true,
"config": {"model_type": "bert"}
},
"quantization": {
"type": "OnnxQuantization",
"config": {
"data_config": "__input_model_data_config__"
}
},
"perf_tuning": {
"type": "OrtPerfTuning",
"config": {
"input_names": ["input_ids", "attention_mask", "token_type_ids"],
"input_shapes": [[1, 128], [1, 128], [1, 128]],
"input_types": ["int64", "int64", "int64"]
}
}
}
}
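As a usage note, a config like this one (with no `engine` section) is consumed like any other Olive run config; the sketch below assumes Olive's standard Python entry point, which is not part of this diff.

```python
# Hedged usage sketch: running the engine-less mock config end to end.
# olive.workflows.run is assumed from Olive's general usage documentation, not from this PR.
from olive.workflows import run as olive_run

result = olive_run("test/unit_test/workflows/mock_data/default_engine.json")
```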
8 changes: 7 additions & 1 deletion test/unit_test/workflows/test_run_config.py
@@ -16,7 +16,6 @@


class TestRunConfig:
# TODO(jiapli): add more tests for different config files to test olive features
# like: Systems/Evaluation/Model and etc.
@pytest.fixture(autouse=True)
def setup(self):
@@ -132,6 +131,13 @@ def test_readymade_system(self):
cfg = RunConfig.parse_obj(user_script_config)
assert cfg.engine.target.config.accelerators == ["GPU"]

def test_default_engine(self):
default_engine_config_file = Path(__file__).parent / "mock_data" / "default_engine.json"
run_config = RunConfig.parse_file(default_engine_config_file)
assert run_config.evaluators is None
assert run_config.engine.host is None
assert run_config.engine.target is None


class TestDataConfigValidation:
@pytest.fixture(autouse=True)