diff --git a/olive/engine/engine.py b/olive/engine/engine.py index cc88f0821..6d02e5819 100644 --- a/olive/engine/engine.py +++ b/olive/engine/engine.py @@ -26,6 +26,7 @@ from olive.logging import enable_filelog from olive.model import ModelConfig from olive.package_config import OlivePackageConfig +from olive.passes.olive_pass import FullPassConfig from olive.search.search_sample import SearchSample from olive.search.search_strategy import SearchStrategy, SearchStrategyConfig from olive.systems.common import SystemType @@ -682,28 +683,27 @@ def _run_pass( """Run a pass on the input model.""" run_start_time = datetime.now().timestamp() - pass_config: RunPassConfig = self.computed_passes_configs[pass_name] - pass_type_name = pass_config.type + run_pass_config: RunPassConfig = self.computed_passes_configs[pass_name] + pass_type_name = run_pass_config.type logger.info("Running pass %s:%s", pass_name, pass_type_name) # check whether the config is valid - pass_cls: Type[Pass] = self.olive_config.import_pass_module(pass_config.type) - if not pass_cls.validate_config(pass_config.config, accelerator_spec): + pass_cls: Type[Pass] = self.olive_config.import_pass_module(run_pass_config.type) + if not pass_cls.validate_config(run_pass_config.config, accelerator_spec): logger.warning("Invalid config, pruned.") - logger.debug(pass_config) + logger.debug(run_pass_config) # no need to record in footprint since there was no run and thus no valid/failed model # invalid configs are also not cached since the same config can be valid for other accelerator specs # a pass can be accelerator agnostic but still have accelerator specific invalid configs # this helps reusing cached models for different accelerator specs return INVALID_CONFIG, None - p: Pass = pass_cls(accelerator_spec, pass_config.config, self.get_host_device()) - pass_config = p.config.to_json() + pass_config = run_pass_config.config.to_json() output_model_config = None # load run from cache if it exists - run_accel = None if p.is_accelerator_agnostic(accelerator_spec) else accelerator_spec + run_accel = None if pass_cls.is_accelerator_agnostic(accelerator_spec) else accelerator_spec output_model_id = self.cache.get_output_model_id(pass_type_name, pass_config, input_model_id, run_accel) run_cache = self.cache.load_run_from_model_id(output_model_id) if run_cache: @@ -734,7 +734,7 @@ def _run_pass( input_model_config = self.cache.prepare_resources_for_local(input_model_config) try: - if p.run_on_target: + if pass_cls.run_on_target: if self.target.system_type == SystemType.IsolatedORT: logger.warning( "Cannot run pass %s on IsolatedORT target, will use the host to run the pass.", pass_name @@ -742,7 +742,10 @@ def _run_pass( else: host = self.target - output_model_config = host.run_pass(p, input_model_config, output_model_path) + full_pass_config = FullPassConfig.from_run_pass_config( + run_pass_config, accelerator_spec, self.get_host_device() + ) + output_model_config = host.run_pass(full_pass_config, input_model_config, output_model_path) except OlivePassError: logger.exception("Pass run_pass failed") output_model_config = FAILED_CONFIG diff --git a/olive/passes/olive_pass.py b/olive/passes/olive_pass.py index b0b51af0a..c000016cf 100644 --- a/olive/passes/olive_pass.py +++ b/olive/passes/olive_pass.py @@ -6,14 +6,15 @@ import logging import shutil from abc import ABC, abstractmethod +from copy import deepcopy from pathlib import Path -from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union, get_args +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union, get_args from olive.common.config_utils import ParamCategory, validate_config -from olive.common.pydantic_v1 import BaseModel, ValidationError, create_model +from olive.common.pydantic_v1 import BaseModel, ValidationError, create_model, validator from olive.common.user_module_loader import UserModuleLoader from olive.data.config import DataConfig -from olive.hardware import DEFAULT_CPU_ACCELERATOR, AcceleratorSpec +from olive.hardware import DEFAULT_CPU_ACCELERATOR, AcceleratorSpec, Device from olive.model import CompositeModelHandler, DistributedOnnxModelHandler, OliveModelHandler, ONNXModelHandler from olive.passes.pass_config import ( AbstractPassConfig, @@ -32,6 +33,9 @@ ) from olive.search.utils import cyclic_search_space, order_search_parameters +if TYPE_CHECKING: + from olive.engine.config import RunPassConfig + logger = logging.getLogger(__name__) # ruff: noqa: B027 @@ -80,17 +84,10 @@ def __init__( if hasattr(self.config, "user_script") and hasattr(self.config, "script_dir"): self._user_module_loader = UserModuleLoader(self.config.user_script, self.config.script_dir) - # Params that are paths [(param_name, required)] - self.path_params = [ - (param, param_config.required, param_config.category) - for param, param_config in self.default_config(accelerator_spec).items() - if param_config.category in (ParamCategory.PATH, ParamCategory.DATA) - ] - self._initialized = False - @staticmethod - def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool: + @classmethod + def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool: """Whether the pass is accelerator agnostic. If True, the pass will be reused for all accelerators. The default value is True. The subclass could choose to override this method to return False by using the @@ -482,17 +479,35 @@ class FullPassConfig(AbstractPassConfig): reconstruct the pass from the JSON file. """ - accelerator: Dict[str, str] = None - host_device: Optional[str] = None + accelerator: Optional[AcceleratorSpec] = None + host_device: Optional[Device] = None - def create_pass(self): - if not isinstance(self.accelerator, dict): - raise ValueError(f"accelerator must be a dict, got {self.accelerator}") + @validator("accelerator", pre=True) + def validate_accelerator(cls, v): + if isinstance(v, AcceleratorSpec): + return v + elif isinstance(v, dict): + return AcceleratorSpec(**v) + raise ValueError("Invalid accelerator input.") - pass_cls = Pass.registry[self.type.lower()] - accelerator_spec = AcceleratorSpec(**self.accelerator) # pylint: disable=not-a-mapping - self.config = pass_cls.generate_config(accelerator_spec, self.config) - return pass_cls(accelerator_spec, self.config, self.host_device) + def create_pass(self) -> Pass: + """Create a Pass.""" + return super().create_pass_with_args(self.accelerator, self.host_device) + + @staticmethod + def from_run_pass_config( + run_pass_config: Union[Dict[str, Any], "RunPassConfig"], + accelerator: "AcceleratorSpec", + host_device: Device = None, + ) -> "FullPassConfig": + config = deepcopy(run_pass_config) if isinstance(run_pass_config, dict) else run_pass_config.dict() + config.update( + { + "accelerator": accelerator, + "host_device": host_device, + } + ) + return validate_config(config, FullPassConfig) # TODO(myguo): deprecate or remove this function by explicitly specify the accelerator_spec in the arguments diff --git a/olive/passes/onnx/inc_quantization.py b/olive/passes/onnx/inc_quantization.py index 97b251420..c445900bb 100644 --- a/olive/passes/onnx/inc_quantization.py +++ b/olive/passes/onnx/inc_quantization.py @@ -251,8 +251,8 @@ class IncQuantization(Pass): """Quantize ONNX model with IntelĀ® Neural Compressor.""" - @staticmethod - def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool: + @classmethod + def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool: """Override this method to return False by using the accelerator spec information.""" return False diff --git a/olive/passes/onnx/model_builder.py b/olive/passes/onnx/model_builder.py index 4b393a2a2..f719335d2 100644 --- a/olive/passes/onnx/model_builder.py +++ b/olive/passes/onnx/model_builder.py @@ -114,8 +114,8 @@ def validate_config( return False return True - @staticmethod - def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool: + @classmethod + def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool: return False def _run_for_config( diff --git a/olive/passes/onnx/optimum_merging.py b/olive/passes/onnx/optimum_merging.py index a9834a15a..03de29ea7 100644 --- a/olive/passes/onnx/optimum_merging.py +++ b/olive/passes/onnx/optimum_merging.py @@ -19,8 +19,8 @@ class OptimumMerging(Pass): _accepts_composite_model = True - @staticmethod - def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool: + @classmethod + def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool: """Override this method to return False by using the accelerator spec information.""" return False diff --git a/olive/passes/onnx/session_params_tuning.py b/olive/passes/onnx/session_params_tuning.py index baa392111..3f4ce88c0 100644 --- a/olive/passes/onnx/session_params_tuning.py +++ b/olive/passes/onnx/session_params_tuning.py @@ -79,8 +79,8 @@ def get_thread_affinity_nums(affinity_str): class OrtSessionParamsTuning(Pass): """Optimize ONNX Runtime inference settings.""" - @staticmethod - def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool: + @classmethod + def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool: """Override this method to return False by using the accelerator spec information.""" return False diff --git a/olive/passes/onnx/transformer_optimization.py b/olive/passes/onnx/transformer_optimization.py index 931c2d689..352371c1d 100644 --- a/olive/passes/onnx/transformer_optimization.py +++ b/olive/passes/onnx/transformer_optimization.py @@ -36,8 +36,8 @@ class OrtTransformersOptimization(Pass): # using a Linux machine which doesn't support onnxruntime-directml package. # It is enough for the pass to fail if `opt_level` > 0 and the host doesn't have the required packages. - @staticmethod - def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool: + @classmethod + def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool: """Override this method to return False by using the accelerator spec information.""" from onnxruntime import __version__ as OrtVersion from packaging import version diff --git a/olive/passes/onnx/vitis_ai_quantization.py b/olive/passes/onnx/vitis_ai_quantization.py index dc3bd87f8..408b1b01a 100644 --- a/olive/passes/onnx/vitis_ai_quantization.py +++ b/olive/passes/onnx/vitis_ai_quantization.py @@ -209,8 +209,8 @@ def _initialize(self): super()._initialize() self.tmp_dir = tempfile.TemporaryDirectory(prefix="olive_vaiq_tmp") - @staticmethod - def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool: + @classmethod + def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool: """Override this method to return False by using the accelerator spec information.""" return False diff --git a/olive/passes/pass_config.py b/olive/passes/pass_config.py index 624af33c7..c7bf15507 100644 --- a/olive/passes/pass_config.py +++ b/olive/passes/pass_config.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- from pathlib import Path -from typing import Any, Callable, ClassVar, Dict, List, Optional, Set, Type, Union +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Optional, Set, Type, Union from olive.common.config_utils import ( ConfigBase, @@ -20,6 +20,10 @@ from olive.resource_path import validate_resource_path from olive.search.search_parameter import SearchParameter, SpecialParamValue, json_to_search_parameter +if TYPE_CHECKING: + from olive.hardware.accelerator import AcceleratorSpec + from olive.passes.olive_pass import Pass + class PassParamDefault(StrEnumBase): """Default values for passes.""" @@ -128,6 +132,21 @@ class AbstractPassConfig(NestedConfig): def validate_type(cls, v): return validate_lowercase(v) + @validator("config", pre=True, always=True) + def validate_config(cls, v): + return v or {} + + def create_pass_with_args(self, accelerator: "AcceleratorSpec", host_device: Device) -> "Pass": + """Create a Pass.""" + from olive.package_config import OlivePackageConfig + + olive_config = OlivePackageConfig.load_default_config() + pass_cls: Type[Pass] = olive_config.import_pass_module(self.type) + self.config = pass_cls.generate_config( + accelerator, self.config if isinstance(self.config, dict) else self.config.dict() + ) + return pass_cls(accelerator, self.config, host_device) + def create_config_class( pass_type: str, diff --git a/olive/passes/pytorch/capture_split_info.py b/olive/passes/pytorch/capture_split_info.py index 0ea1674e0..11b13574d 100644 --- a/olive/passes/pytorch/capture_split_info.py +++ b/olive/passes/pytorch/capture_split_info.py @@ -81,8 +81,8 @@ def validate_config( return True - @staticmethod - def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool: + @classmethod + def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool: return False def _run_for_config( diff --git a/olive/systems/azureml/aml_system.py b/olive/systems/azureml/aml_system.py index 76634908d..fa4a58028 100644 --- a/olive/systems/azureml/aml_system.py +++ b/olive/systems/azureml/aml_system.py @@ -40,7 +40,7 @@ if TYPE_CHECKING: from olive.hardware.accelerator import AcceleratorSpec - from olive.passes.olive_pass import Pass + from olive.passes.olive_pass import FullPassConfig logger = logging.getLogger(__name__) @@ -243,15 +243,22 @@ def _assert_not_none(self, obj): if obj is None: raise ValueError(f"{obj.__class__.__name__} is missing in the inputs!") - def run_pass(self, the_pass: "Pass", model_config: ModelConfig, output_model_path: str) -> ModelConfig: + def run_pass( + self, + full_pass_config: "FullPassConfig", + model_config: "ModelConfig", + output_model_path: str, + ) -> ModelConfig: """Run the pass on the model.""" ml_client = self.azureml_client_config.create_client() - # serialize pass - pass_config = the_pass.to_json(check_object=True) + # serialize config + serialized_pass_config = full_pass_config.to_json(check_object=True) with tempfile.TemporaryDirectory() as tempdir: - pipeline_job = self._create_pipeline_for_pass(tempdir, model_config.to_json(check_object=True), pass_config) + pipeline_job = self._create_pipeline_for_pass( + tempdir, model_config.to_json(check_object=True), serialized_pass_config + ) # submit job named_outputs_dir = self._run_job( @@ -259,7 +266,7 @@ def run_pass(self, the_pass: "Pass", model_config: ModelConfig, output_model_pat pipeline_job, "olive-pass", tempdir, - tags={"Pass": pass_config["type"]}, + tags={"Pass": serialized_pass_config["type"]}, output_name="pipeline_output", ) pipeline_output_path = named_outputs_dir / "pipeline_output" diff --git a/olive/systems/docker/docker_system.py b/olive/systems/docker/docker_system.py index 69e7e8468..81d7480b9 100644 --- a/olive/systems/docker/docker_system.py +++ b/olive/systems/docker/docker_system.py @@ -8,7 +8,7 @@ import shutil import tempfile from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union import docker from docker.errors import BuildError, ContainerError @@ -18,6 +18,7 @@ from olive.evaluator.metric_result import MetricResult from olive.hardware import Device from olive.model import ModelConfig +from olive.package_config import OlivePackageConfig from olive.systems.common import AcceleratorConfig, LocalDockerConfig, SystemType from olive.systems.olive_system import OliveSystem from olive.systems.system_config import DockerTargetUserConfig @@ -26,7 +27,7 @@ from olive.evaluator.metric import Metric from olive.evaluator.olive_evaluator import OliveEvaluatorConfig from olive.hardware.accelerator import AcceleratorSpec - from olive.passes import Pass + from olive.passes.olive_pass import FullPassConfig, Pass logger = logging.getLogger(__name__) @@ -103,16 +104,23 @@ def __init__( _print_docker_logs(e.build_log, logging.ERROR) raise - def run_pass(self, the_pass: "Pass", model_config: "ModelConfig", output_model_path: str) -> "ModelConfig": + def run_pass( + self, + full_pass_config: "FullPassConfig", + model_config: "ModelConfig", + output_model_path: str, + ) -> "ModelConfig": """Run the pass on the model.""" with tempfile.TemporaryDirectory() as tempdir: - return self._run_pass_container(Path(tempdir), the_pass, model_config, output_model_path) + return self._run_pass_container(Path(tempdir), full_pass_config, model_config, output_model_path) def _run_pass_container( - self, workdir: Path, the_pass: "Pass", model_config: "ModelConfig", output_model_path: str + self, + workdir: Path, + full_pass_config: "FullPassConfig", + model_config: "ModelConfig", + output_model_path: str, ) -> "ModelConfig": - pass_config = the_pass.to_json(check_object=True) - volumes_list = [] runner_output_path = "runner_output" runner_output_name = "runner_res.json" @@ -133,11 +141,14 @@ def _run_pass_container( volumes_list.extend(model_mount_str_list) # data_dir or data_config - docker_data_paths, data_mount_str_list = self._create_data_mounts_for_pass(container_root_path, the_pass) + docker_data_paths, data_mount_str_list = self._create_data_mounts_for_pass( + container_root_path, + full_pass_config, + ) volumes_list.extend(data_mount_str_list) # mount config file - data = self._create_runner_config(model_config, pass_config, docker_model_files, docker_data_paths) + data = self._create_runner_config(model_config, full_pass_config, docker_model_files, docker_data_paths) logger.debug("Runner config is %s", data) docker_config_file, config_file_mount_str = docker_utils.create_config_file( workdir=workdir, @@ -163,7 +174,7 @@ def _run_pass_container( ) model_output_json_file = self._run_container( - runner_command, volumes_list, output_local_path, runner_output_name, the_pass.accelerator_spec + runner_command, volumes_list, output_local_path, runner_output_name, full_pass_config.accelerator ) if model_output_json_file.is_file(): with model_output_json_file.open() as f: @@ -289,7 +300,7 @@ def _create_eval_config(model_config: "ModelConfig", metrics: List["Metric"], mo @staticmethod def _create_runner_config( model_config: "ModelConfig", - pass_config: Dict[str, Any], + full_pass_config: "FullPassConfig", model_mounts: Dict[str, str], data_mounts: Dict[str, str], ): @@ -297,11 +308,11 @@ def _create_runner_config( for k, v in model_mounts.items(): model_json["config"][k] = v - pass_config_copy = copy.deepcopy(pass_config) + pass_json = full_pass_config.to_json(check_object=True) for k, v in data_mounts.items(): - pass_config_copy["config"][k] = v + pass_json["config"][k] = v - return {"model": model_json, "pass": pass_config_copy} + return {"model": model_json, "pass": pass_json} def _run_container( self, @@ -358,16 +369,25 @@ def _run_container( return Path(output_local_path) / output_name - def _create_data_mounts_for_pass(self, container_root_path: Path, the_pass: "Pass"): + def _create_data_mounts_for_pass( + self, container_root_path: Path, full_pass_config: "FullPassConfig" + ) -> Tuple[Dict[str, str], List[str]]: mounts = {} mount_strs = [] - config_dict = the_pass.config.dict() - for param, _, category in the_pass.path_params: - param_val = config_dict.get(param) - if category == ParamCategory.DATA and param_val: - mount = str(container_root_path / param) - mounts[param] = mount - mount_strs.append(f"{param_val}:{mount}") + if not full_pass_config.config: + return mounts, mount_strs + + olive_config = OlivePackageConfig.load_default_config() + pass_cls: Type[Pass] = olive_config.import_pass_module(full_pass_config.type) + + config_dict = full_pass_config.to_json().get("config") or {} + for param, param_config in pass_cls.default_config(full_pass_config.accelerator).items(): + if param_config.category == ParamCategory.DATA: + param_val = config_dict.get(param) + if param_val: + mount = str(container_root_path / param) + mounts[param] = mount + mount_strs.append(f"{param_val}:{mount}") return mounts, mount_strs diff --git a/olive/systems/docker/runner.py b/olive/systems/docker/runner.py index fe82625ec..48a3fbe36 100644 --- a/olive/systems/docker/runner.py +++ b/olive/systems/docker/runner.py @@ -12,7 +12,6 @@ from olive.common.hf.login import huggingface_login from olive.logging import set_verbosity_from_env from olive.model import ModelConfig -from olive.package_config import OlivePackageConfig from olive.passes.olive_pass import FullPassConfig logger = logging.getLogger("olive") @@ -31,10 +30,6 @@ def runner_entry(config, output_path, output_name): pass_config = config_json["pass"] - # Import the pass package configuration from the package_config - package_config = OlivePackageConfig.load_default_config() - package_config.import_pass_module(pass_config["type"]) - the_pass = FullPassConfig.from_json(pass_config).create_pass() output_model = the_pass.run(model, output_path) # save model json diff --git a/olive/systems/isolated_ort/isolated_ort_system.py b/olive/systems/isolated_ort/isolated_ort_system.py index a16d5952f..eaba9f096 100644 --- a/olive/systems/isolated_ort/isolated_ort_system.py +++ b/olive/systems/isolated_ort/isolated_ort_system.py @@ -30,7 +30,7 @@ from olive.evaluator.olive_evaluator import OliveEvaluatorConfig from olive.hardware.accelerator import AcceleratorSpec from olive.model import ModelConfig, ONNXModelHandler - from olive.passes.olive_pass import Pass + from olive.passes.olive_pass import FullPassConfig logger = logging.getLogger(__name__) @@ -63,7 +63,7 @@ def __init__( def run_pass( self, - the_pass: "Pass", + full_pass_config: "FullPassConfig", model_config: "ModelConfig", output_model_path: str, ) -> "ModelConfig": diff --git a/olive/systems/local.py b/olive/systems/local.py index 355dd747b..71de29328 100644 --- a/olive/systems/local.py +++ b/olive/systems/local.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: from olive.evaluator.metric_result import MetricResult from olive.evaluator.olive_evaluator import OliveEvaluator, OliveEvaluatorConfig - from olive.passes.olive_pass import Pass + from olive.passes.olive_pass import FullPassConfig, Pass class LocalSystem(OliveSystem): @@ -20,11 +20,12 @@ class LocalSystem(OliveSystem): def run_pass( self, - the_pass: "Pass", - model_config: ModelConfig, + full_pass_config: "FullPassConfig", + model_config: "ModelConfig", output_model_path: str, ) -> ModelConfig: """Run the pass on the model.""" + the_pass: Pass = full_pass_config.create_pass() model = model_config.create_model() output_model = the_pass.run(model, output_model_path) return ModelConfig.from_json(output_model.to_json()) diff --git a/olive/systems/olive_system.py b/olive/systems/olive_system.py index 6c3339d9b..e05e62c62 100644 --- a/olive/systems/olive_system.py +++ b/olive/systems/olive_system.py @@ -14,7 +14,7 @@ from olive.evaluator.olive_evaluator import OliveEvaluatorConfig from olive.hardware.accelerator import AcceleratorSpec from olive.model import ModelConfig - from olive.passes.olive_pass import Pass + from olive.passes.olive_pass import FullPassConfig logger = logging.getLogger(__name__) @@ -41,8 +41,13 @@ def __init__( self.hf_token = hf_token @abstractmethod - def run_pass(self, the_pass: "Pass", model_config: "ModelConfig", output_model_path: str) -> "ModelConfig": - """Run the pass on the model at a specific point in the search space.""" + def run_pass( + self, + full_pass_config: "FullPassConfig", + model_config: "ModelConfig", + output_model_path: str, + ) -> "ModelConfig": + """Create and run the pass on the model.""" raise NotImplementedError @abstractmethod diff --git a/olive/systems/python_environment/python_environment_system.py b/olive/systems/python_environment/python_environment_system.py index 52f2528f1..252314bba 100644 --- a/olive/systems/python_environment/python_environment_system.py +++ b/olive/systems/python_environment/python_environment_system.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: from olive.evaluator.olive_evaluator import OliveEvaluatorConfig from olive.hardware.accelerator import AcceleratorSpec - from olive.passes.olive_pass import Pass + from olive.passes.olive_pass import FullPassConfig logger = logging.getLogger(__name__) @@ -107,15 +107,14 @@ def _run_command(self, script_path: Path, config_jsons: Dict[str, Any], **kwargs def run_pass( self, - the_pass: "Pass", - model_config: ModelConfig, + full_pass_config: "FullPassConfig", + model_config: "ModelConfig", output_model_path: str, - ) -> ModelConfig: + ) -> "ModelConfig": """Run the pass on the model.""" - pass_config = the_pass.to_json(check_object=True) config_jsons = { "model_config": model_config.to_json(check_object=True), - "pass_config": pass_config, + "pass_config": full_pass_config.to_json(check_object=True), } output_model_json = self._run_command( self.pass_runner_path, diff --git a/test/integ_test/aml_model_test/test_aml_model.py b/test/integ_test/aml_model_test/test_aml_model.py index 04462f7da..b444185f7 100644 --- a/test/integ_test/aml_model_test/test_aml_model.py +++ b/test/integ_test/aml_model_test/test_aml_model.py @@ -7,7 +7,7 @@ from olive.azureml.azureml_client import AzureMLClientConfig from olive.model import ModelConfig -from olive.passes.olive_pass import create_pass_from_dict +from olive.passes.olive_pass import FullPassConfig from olive.passes.onnx.conversion import OnnxConversion from olive.resource_path import ResourcePath from olive.systems.azureml import AzureMLDockerConfig, AzureMLSystem @@ -39,12 +39,9 @@ def test_aml_model_pass_run(tmp_path): # ------------------------------------------------------------------ # Onnx conversion pass # config can be a dictionary - onnx_conversion_config = { - "target_opset": 13, - } onnx_model_file = tmp_path / "model.onnx" - onnx_conversion_pass = create_pass_from_dict(OnnxConversion, onnx_conversion_config) - onnx_model = aml_system.run_pass(onnx_conversion_pass, pytorch_model_config, onnx_model_file) + full_pass_config = FullPassConfig.parse_obj({"type": OnnxConversion.__name__, "config": {"target_opset": 13}}) + onnx_model = aml_system.run_pass(full_pass_config, pytorch_model_config, onnx_model_file) model_path = onnx_model.config["model_path"] if isinstance(model_path, ResourcePath): model_path = model_path.get_path() diff --git a/test/integ_test/pass_runner/test_docker_system.py b/test/integ_test/pass_runner/test_docker_system.py index 7e5454cdd..3d70bae2b 100644 --- a/test/integ_test/pass_runner/test_docker_system.py +++ b/test/integ_test/pass_runner/test_docker_system.py @@ -13,7 +13,7 @@ from olive.hardware.accelerator import DEFAULT_CPU_ACCELERATOR from olive.logging import set_default_logger_severity from olive.model.config.model_config import ModelConfig -from olive.passes.olive_pass import create_pass_from_dict +from olive.passes.olive_pass import FullPassConfig from olive.passes.onnx.session_params_tuning import OrtSessionParamsTuning @@ -33,8 +33,10 @@ def test_pass_runner(tmp_path): model_conf = ModelConfig.parse_obj({"type": "ONNXModel", "config": model_config}) set_default_logger_severity(0) - p = create_pass_from_dict(OrtSessionParamsTuning, {}, True, DEFAULT_CPU_ACCELERATOR) - output_model = docker_target.run_pass(p, model_conf, tmp_path) + full_pass_config = FullPassConfig.from_run_pass_config( + {"type": OrtSessionParamsTuning.__name__}, DEFAULT_CPU_ACCELERATOR + ) + output_model = docker_target.run_pass(full_pass_config, model_conf, tmp_path) result_eps = output_model.config["inference_settings"]["execution_provider"] assert result_eps == [DEFAULT_CPU_ACCELERATOR.execution_provider] assert output_model.config["model_path"] == model_config["model_path"] diff --git a/test/unit_test/systems/azureml/test_aml_system.py b/test/unit_test/systems/azureml/test_aml_system.py index e947ae548..413dbadb1 100644 --- a/test/unit_test/systems/azureml/test_aml_system.py +++ b/test/unit_test/systems/azureml/test_aml_system.py @@ -34,8 +34,6 @@ from olive.evaluator.olive_evaluator import OliveEvaluatorConfig from olive.hardware import DEFAULT_CPU_ACCELERATOR from olive.model import ONNXModelHandler -from olive.passes.olive_pass import create_pass_from_dict -from olive.passes.onnx.conversion import OnnxConversion from olive.resource_path import ResourcePath, ResourceType from olive.systems.azureml.aml_evaluation_runner import main as aml_evaluation_runner_main from olive.systems.azureml.aml_pass_runner import main as aml_pass_runner_main @@ -167,8 +165,9 @@ def test_run_pass(self, mock_create_pipeline, mock_retry_func, tmp_path): with dummy_config_path.open("w") as f: json.dump(dummy_config, f, indent=4) - onnx_conversion_config = {} - p = create_pass_from_dict(OnnxConversion, onnx_conversion_config) + full_pass_config = MagicMock() + full_pass_config.create_pass.return_value = MagicMock() + model_config = get_hf_model_config() output_model_path = tmp_path / "output_folder" / "output_model_path" output_model_path.mkdir(parents=True, exist_ok=True) @@ -186,11 +185,11 @@ def test_run_pass(self, mock_create_pipeline, mock_retry_func, tmp_path): with patch("olive.systems.azureml.aml_system.tempfile.TemporaryDirectory") as mock_tempdir: mock_tempdir.return_value.__enter__.return_value = output_folder # execute - actual_res = self.system.run_pass(p, model_config, output_model_path) + actual_res = self.system.run_pass(full_pass_config, model_config, output_model_path) # assert mock_create_pipeline.assert_called_once_with( - output_folder, model_config.to_json(check_object=True), p.to_json() + output_folder, model_config.to_json(check_object=True), full_pass_config.to_json() ) assert mock_retry_func.call_count == 2 ml_client.jobs.stream.assert_called_once() diff --git a/test/unit_test/systems/docker/test_docker_system.py b/test/unit_test/systems/docker/test_docker_system.py index b1eaca8dd..2d1d348b7 100644 --- a/test/unit_test/systems/docker/test_docker_system.py +++ b/test/unit_test/systems/docker/test_docker_system.py @@ -15,7 +15,7 @@ from olive.evaluator.metric_result import joint_metric_key from olive.evaluator.olive_evaluator import OliveEvaluatorConfig from olive.hardware import DEFAULT_CPU_ACCELERATOR -from olive.passes.olive_pass import create_pass_from_dict +from olive.passes.olive_pass import FullPassConfig from olive.passes.onnx.session_params_tuning import OrtSessionParamsTuning from olive.systems.common import LocalDockerConfig from olive.systems.docker.docker_system import DockerSystem @@ -226,7 +226,9 @@ def container_run(image, command=None, stdout=True, stderr=False, remove=False, ) docker_system = DockerSystem(docker_config, is_dev=True) - p = create_pass_from_dict(OrtSessionParamsTuning, {}, disable_search=True) + full_pass_config = FullPassConfig.from_run_pass_config( + {"type": OrtSessionParamsTuning.__name__}, DEFAULT_CPU_ACCELERATOR + ) output_folder = str(tmp_path / "onnx") def validate_file_or_folder(v, values, **kwargs): @@ -240,7 +242,7 @@ def is_dir_mock(self): ), patch("olive.resource_path._validate_path", side_effect=validate_file_or_folder), patch.object( Path, "is_dir", side_effect=is_dir_mock, autospec=True ): - output_model = docker_system.run_pass(p, onnx_model, output_folder) + output_model = docker_system.run_pass(full_pass_config, onnx_model, output_folder) assert output_model is not None runner_local_path = Path(__file__).resolve().parents[4] / "olive" / "systems" / "docker" / "runner.py" @@ -275,20 +277,20 @@ def test_runner_entry(self, tmp_path): from olive.systems.docker import utils as docker_utils from olive.systems.docker.runner import runner_entry as docker_runner_entry - p = create_pass_from_dict(OrtSessionParamsTuning, {}, disable_search=True) - pass_config = p.to_json(check_object=True) - - onnx_model = get_onnx_model_config() + full_pass_config = FullPassConfig.from_run_pass_config( + {"type": OrtSessionParamsTuning.__name__}, DEFAULT_CPU_ACCELERATOR + ) + onnx_model_config = get_onnx_model_config() container_root_path = Path("/olive-ws/") data = DockerSystem._create_runner_config( - onnx_model, - pass_config, - {"model_path": onnx_model.config["model_path"]}, + onnx_model_config, + full_pass_config, + {"model_path": onnx_model_config.config["model_path"]}, {}, ) docker_utils.create_config_file(tmp_path, data, container_root_path) - with patch.object(OrtSessionParamsTuning, "run", return_value=onnx_model): + with patch.object(OrtSessionParamsTuning, "run", return_value=onnx_model_config): docker_runner_entry(str(tmp_path / "config.json"), str(tmp_path), "runner_res.json") assert (tmp_path / "runner_res.json").exists() diff --git a/test/unit_test/systems/python_environment/test_python_environment_system.py b/test/unit_test/systems/python_environment/test_python_environment_system.py index 4087c6247..af104b916 100644 --- a/test/unit_test/systems/python_environment/test_python_environment_system.py +++ b/test/unit_test/systems/python_environment/test_python_environment_system.py @@ -25,6 +25,7 @@ from olive.evaluator.metric_result import MetricResult, joint_metric_key from olive.evaluator.olive_evaluator import OliveEvaluatorConfig from olive.hardware import DEFAULT_CPU_ACCELERATOR +from olive.passes.olive_pass import FullPassConfig from olive.systems.python_environment import PythonEnvironmentSystem from olive.systems.python_environment.evaluation_runner import main as evaluation_runner_main from olive.systems.python_environment.pass_runner import main as pass_runner_main @@ -143,17 +144,16 @@ def test_run_pass(self, mock_model_config_parse_obj, mock__run_command): model_config = MagicMock() dummy_model_config = {"dummy_model_key": "dummy_model_value"} model_config.to_json.return_value = dummy_model_config - the_pass = MagicMock() - dummy_pass_config = { - "type": "DummyPass", - "config": { - "dummy_param_1": "dummy_param_1_value", - "dummy_param_2": "dummy_param_2_value", - }, - } - dummy_config = dummy_pass_config["config"] - expected_pass_config = {"type": "DummyPass", "config": dummy_config} - the_pass.to_json.return_value = dummy_pass_config + dummy_full_pass_config = FullPassConfig.parse_obj( + { + "type": "DummyPass", + "config": { + "dummy_param_1": "dummy_param_1_value", + "dummy_param_2": "dummy_param_2_value", + }, + "accelerator": DEFAULT_CPU_ACCELERATOR, + } + ) # mock return value mock_return_value = {"dummy_output_model_key": "dummy_output_model_value"} @@ -165,14 +165,14 @@ def test_run_pass(self, mock_model_config_parse_obj, mock__run_command): dummy_output_model_path = "dummy_output_model_path" # execute - res = self.system.run_pass(the_pass, model_config, dummy_output_model_path) + res = self.system.run_pass(dummy_full_pass_config, model_config, dummy_output_model_path) # assert assert res == mock_output_model_config mock_model_config_parse_obj.assert_called_once_with(mock_return_value) mock__run_command.assert_called_once_with( self.system.pass_runner_path, - {"model_config": dummy_model_config, "pass_config": expected_pass_config}, + {"model_config": dummy_model_config, "pass_config": dummy_full_pass_config.to_json(check_object=True)}, tempdir=tempfile.tempdir, output_model_path=dummy_output_model_path, ) diff --git a/test/unit_test/systems/test_local.py b/test/unit_test/systems/test_local.py index cde0f2918..c8997f866 100644 --- a/test/unit_test/systems/test_local.py +++ b/test/unit_test/systems/test_local.py @@ -27,16 +27,19 @@ def setup(self): def test_run_pass(self): # setup - p = MagicMock() - p.run.return_value = PyTorchModelHandler("model_path") - olive_model = MagicMock() + full_pass_config = MagicMock() + model_config = MagicMock() + the_pass = MagicMock() + output_model_path = "output_model_path" + full_pass_config.create_pass.return_value = the_pass + the_pass.run.return_value = PyTorchModelHandler("model_path") # execute - self.system.run_pass(p, olive_model, output_model_path) + self.system.run_pass(full_pass_config, model_config, output_model_path) # assert - p.run.assert_called_once_with(olive_model.create_model(), output_model_path) + the_pass.run.assert_called_once_with(model_config.create_model(), output_model_path) METRIC_TEST_CASE: ClassVar[List[Metric]] = [ (partial(get_accuracy_metric, AccuracySubType.ACCURACY_SCORE)),