Skip to content

Commit

Permalink
Delay pass creation until time to run
Browse files Browse the repository at this point in the history
The engine now avoids instantiating a Pass unless it is actually going to be run.
To avoid early instantiation, only a FullPassConfig is passed around, and the Pass is created from it when it is time to run.
  • Loading branch information
shaahji committed Feb 14, 2025
1 parent 5a616d8 commit 1b6bcca
Show file tree
Hide file tree
Showing 23 changed files with 204 additions and 137 deletions.
23 changes: 13 additions & 10 deletions olive/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from olive.logging import enable_filelog
from olive.model import ModelConfig
from olive.package_config import OlivePackageConfig
from olive.passes.olive_pass import FullPassConfig
from olive.search.search_sample import SearchSample
from olive.search.search_strategy import SearchStrategy, SearchStrategyConfig
from olive.systems.common import SystemType
Expand Down Expand Up @@ -682,28 +683,27 @@ def _run_pass(
"""Run a pass on the input model."""
run_start_time = datetime.now().timestamp()

pass_config: RunPassConfig = self.computed_passes_configs[pass_name]
pass_type_name = pass_config.type
run_pass_config: RunPassConfig = self.computed_passes_configs[pass_name]
pass_type_name = run_pass_config.type

logger.info("Running pass %s:%s", pass_name, pass_type_name)

# check whether the config is valid
pass_cls: Type[Pass] = self.olive_config.import_pass_module(pass_config.type)
if not pass_cls.validate_config(pass_config.config, accelerator_spec):
pass_cls: Type[Pass] = self.olive_config.import_pass_module(run_pass_config.type)
if not pass_cls.validate_config(run_pass_config.config, accelerator_spec):
logger.warning("Invalid config, pruned.")
logger.debug(pass_config)
logger.debug(run_pass_config)
# no need to record in footprint since there was no run and thus no valid/failed model
# invalid configs are also not cached since the same config can be valid for other accelerator specs
# a pass can be accelerator agnostic but still have accelerator specific invalid configs
# this helps reusing cached models for different accelerator specs
return INVALID_CONFIG, None

p: Pass = pass_cls(accelerator_spec, pass_config.config, self.get_host_device())
pass_config = p.config.to_json()
pass_config = run_pass_config.config.to_json()
output_model_config = None

# load run from cache if it exists
run_accel = None if p.is_accelerator_agnostic(accelerator_spec) else accelerator_spec
run_accel = None if pass_cls.is_accelerator_agnostic(accelerator_spec) else accelerator_spec
output_model_id = self.cache.get_output_model_id(pass_type_name, pass_config, input_model_id, run_accel)
run_cache = self.cache.load_run_from_model_id(output_model_id)
if run_cache:
Expand Down Expand Up @@ -734,15 +734,18 @@ def _run_pass(
input_model_config = self.cache.prepare_resources_for_local(input_model_config)

try:
if p.run_on_target:
if pass_cls.run_on_target:
if self.target.system_type == SystemType.IsolatedORT:
logger.warning(
"Cannot run pass %s on IsolatedORT target, will use the host to run the pass.", pass_name
)
else:
host = self.target

output_model_config = host.run_pass(p, input_model_config, output_model_path)
full_pass_config = FullPassConfig.from_run_pass_config(
run_pass_config, accelerator_spec, self.get_host_device()
)
output_model_config = host.run_pass(full_pass_config, input_model_config, output_model_path)
except OlivePassError:
logger.exception("Pass run_pass failed")
output_model_config = FAILED_CONFIG
Expand Down
57 changes: 36 additions & 21 deletions olive/passes/olive_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
import logging
import shutil
from abc import ABC, abstractmethod
from copy import deepcopy
from pathlib import Path
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union, get_args
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union, get_args

from olive.common.config_utils import ParamCategory, validate_config
from olive.common.pydantic_v1 import BaseModel, ValidationError, create_model
from olive.common.pydantic_v1 import BaseModel, ValidationError, create_model, validator
from olive.common.user_module_loader import UserModuleLoader
from olive.data.config import DataConfig
from olive.hardware import DEFAULT_CPU_ACCELERATOR, AcceleratorSpec
from olive.hardware import DEFAULT_CPU_ACCELERATOR, AcceleratorSpec, Device
from olive.model import CompositeModelHandler, DistributedOnnxModelHandler, OliveModelHandler, ONNXModelHandler
from olive.passes.pass_config import (
AbstractPassConfig,

Check failure

Code scanning / CodeQL

Module-level cyclic import Error

'AbstractPassConfig' may not be defined if module olive.passes.pass_config is imported before module olive.passes.olive_pass, as the definition of AbstractPassConfig occurs after the cyclic import of olive.passes.olive_pass.
Expand All @@ -32,6 +33,9 @@
)
from olive.search.utils import cyclic_search_space, order_search_parameters

if TYPE_CHECKING:
from olive.engine.config import RunPassConfig

Check failure

Code scanning / CodeQL

Module-level cyclic import Error

'RunPassConfig' may not be defined if module olive.engine.config is imported before module olive.passes.olive_pass, as the definition of RunPassConfig occurs after the cyclic import of olive.passes.olive_pass.

logger = logging.getLogger(__name__)

# ruff: noqa: B027
Expand Down Expand Up @@ -80,17 +84,10 @@ def __init__(
if hasattr(self.config, "user_script") and hasattr(self.config, "script_dir"):
self._user_module_loader = UserModuleLoader(self.config.user_script, self.config.script_dir)

# Params that are paths [(param_name, required)]
self.path_params = [
(param, param_config.required, param_config.category)
for param, param_config in self.default_config(accelerator_spec).items()
if param_config.category in (ParamCategory.PATH, ParamCategory.DATA)
]

self._initialized = False

@staticmethod
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
@classmethod
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
"""Whether the pass is accelerator agnostic. If True, the pass will be reused for all accelerators.
The default value is True. The subclass could choose to override this method to return False by using the
Expand Down Expand Up @@ -482,17 +479,35 @@ class FullPassConfig(AbstractPassConfig):
reconstruct the pass from the JSON file.
"""

accelerator: Dict[str, str] = None
host_device: Optional[str] = None
accelerator: Optional[AcceleratorSpec] = None
host_device: Optional[Device] = None

def create_pass(self):
if not isinstance(self.accelerator, dict):
raise ValueError(f"accelerator must be a dict, got {self.accelerator}")
@validator("accelerator", pre=True)
def validate_accelerator(cls, v):
    """Coerce the raw ``accelerator`` input into an ``AcceleratorSpec``.

    Runs as a pre-validator, i.e. before the field's own type validation.
    ``None`` is passed through unchanged: the field is declared as
    ``Optional[AcceleratorSpec]`` with a default of ``None``, so rejecting an
    explicitly supplied ``None`` would make a config that was created without
    an accelerator fail when it is rebuilt from its own dumped fields.

    Args:
        v: The raw value supplied for ``accelerator``.

    Returns:
        ``None``, the given ``AcceleratorSpec`` instance, or a new
        ``AcceleratorSpec`` built from the dict's items as keyword arguments.

    Raises:
        ValueError: If ``v`` is not ``None``, an ``AcceleratorSpec`` or a dict.
    """
    if v is None or isinstance(v, AcceleratorSpec):
        return v
    if isinstance(v, dict):
        return AcceleratorSpec(**v)
    raise ValueError("Invalid accelerator input.")

pass_cls = Pass.registry[self.type.lower()]
accelerator_spec = AcceleratorSpec(**self.accelerator) # pylint: disable=not-a-mapping
self.config = pass_cls.generate_config(accelerator_spec, self.config)
return pass_cls(accelerator_spec, self.config, self.host_device)
def create_pass(self) -> Pass:
    """Instantiate the Pass described by this config.

    Delegates to the parent class's ``create_pass_with_args``, handing over
    the ``accelerator`` and ``host_device`` stored on this config.
    """
    accelerator_spec, device = self.accelerator, self.host_device
    return super().create_pass_with_args(accelerator_spec, device)

@staticmethod
def from_run_pass_config(
    run_pass_config: Union[Dict[str, Any], "RunPassConfig"],
    accelerator: "AcceleratorSpec",
    host_device: Device = None,
) -> "FullPassConfig":
    """Build a ``FullPassConfig`` from a run-pass config plus placement info.

    Args:
        run_pass_config: Either a plain dict or an object exposing ``dict()``.
            A dict input is deep-copied, so the caller's mapping is never
            modified.
        accelerator: Value stored under the ``"accelerator"`` key.
        host_device: Value stored under the ``"host_device"`` key.

    Returns:
        The result of validating the merged mapping against ``FullPassConfig``.
    """
    if isinstance(run_pass_config, dict):
        merged = deepcopy(run_pass_config)
    else:
        merged = run_pass_config.dict()

    # Placement info always overrides whatever the source config carried.
    merged["accelerator"] = accelerator
    merged["host_device"] = host_device
    return validate_config(merged, FullPassConfig)


# TODO(myguo): deprecate or remove this function by explicitly specifying the accelerator_spec in the arguments
Expand Down
4 changes: 2 additions & 2 deletions olive/passes/onnx/inc_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,8 @@
class IncQuantization(Pass):
"""Quantize ONNX model with Intel® Neural Compressor."""

@staticmethod
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
@classmethod
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
    """Report that this pass is not accelerator agnostic.

    Always returns False; ``accelerator_spec`` is not consulted.
    """
    return False

Expand Down
4 changes: 2 additions & 2 deletions olive/passes/onnx/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ def validate_config(
return False
return True

@staticmethod
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
@classmethod
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
    """Report that this pass is not accelerator agnostic.

    Always returns False; ``accelerator_spec`` is not consulted.
    """
    return False

def _run_for_config(
Expand Down
4 changes: 2 additions & 2 deletions olive/passes/onnx/optimum_merging.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ class OptimumMerging(Pass):

_accepts_composite_model = True

@staticmethod
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
@classmethod
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
    """Report that this pass is not accelerator agnostic.

    Always returns False; ``accelerator_spec`` is not consulted.
    """
    return False

Expand Down
4 changes: 2 additions & 2 deletions olive/passes/onnx/session_params_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def get_thread_affinity_nums(affinity_str):
class OrtSessionParamsTuning(Pass):
"""Optimize ONNX Runtime inference settings."""

@staticmethod
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
@classmethod
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
    """Report that this pass is not accelerator agnostic.

    Always returns False; ``accelerator_spec`` is not consulted.
    """
    return False

Expand Down
4 changes: 2 additions & 2 deletions olive/passes/onnx/transformer_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ class OrtTransformersOptimization(Pass):
# using a Linux machine which doesn't support onnxruntime-directml package.
# It is enough for the pass to fail if `opt_level` > 0 and the host doesn't have the required packages.

@staticmethod
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
@classmethod
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
"""Override this method to return False by using the accelerator spec information."""
from onnxruntime import __version__ as OrtVersion
from packaging import version
Expand Down
4 changes: 2 additions & 2 deletions olive/passes/onnx/vitis_ai_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,8 @@ def _initialize(self):
super()._initialize()
self.tmp_dir = tempfile.TemporaryDirectory(prefix="olive_vaiq_tmp")

@staticmethod
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
@classmethod
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
    """Report that this pass is not accelerator agnostic.

    Always returns False; ``accelerator_spec`` is not consulted.
    """
    return False

Expand Down
21 changes: 20 additions & 1 deletion olive/passes/pass_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from pathlib import Path
from typing import Any, Callable, ClassVar, Dict, List, Optional, Set, Type, Union
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Optional, Set, Type, Union

from olive.common.config_utils import (
ConfigBase,
Expand All @@ -20,6 +20,10 @@
from olive.resource_path import validate_resource_path
from olive.search.search_parameter import SearchParameter, SpecialParamValue, json_to_search_parameter

if TYPE_CHECKING:
from olive.hardware.accelerator import AcceleratorSpec
from olive.passes.olive_pass import Pass

Check failure

Code scanning / CodeQL

Module-level cyclic import Error

'Pass' may not be defined if module olive.passes.olive_pass is imported before module olive.passes.pass_config, as the definition of Pass occurs after the cyclic import of olive.passes.pass_config.
'Pass' may not be defined if module olive.passes.olive_pass is imported before module olive.passes.pass_config, as the definition of Pass occurs after the cyclic import of olive.passes.pass_config.


class PassParamDefault(StrEnumBase):
"""Default values for passes."""
Expand Down Expand Up @@ -128,6 +132,21 @@ class AbstractPassConfig(NestedConfig):
def validate_type(cls, v):
return validate_lowercase(v)

@validator("config", pre=True, always=True)
def validate_config(cls, v):
    """Normalize a missing or empty ``config`` to an empty dict.

    Any falsy input is replaced by ``{}``; any other value is returned
    unchanged.
    """
    if not v:
        return {}
    return v

def create_pass_with_args(self, accelerator: "AcceleratorSpec", host_device: Device) -> "Pass":
"""Create a Pass."""
from olive.package_config import OlivePackageConfig

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module olive.package_config begins an import cycle.

olive_config = OlivePackageConfig.load_default_config()
pass_cls: Type[Pass] = olive_config.import_pass_module(self.type)
self.config = pass_cls.generate_config(
accelerator, self.config if isinstance(self.config, dict) else self.config.dict()
)
return pass_cls(accelerator, self.config, host_device)


def create_config_class(
pass_type: str,
Expand Down
4 changes: 2 additions & 2 deletions olive/passes/pytorch/capture_split_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def validate_config(

return True

@staticmethod
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
@classmethod
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
    """Report that this pass is not accelerator agnostic.

    Always returns False; ``accelerator_spec`` is not consulted.
    """
    return False

def _run_for_config(
Expand Down
19 changes: 13 additions & 6 deletions olive/systems/azureml/aml_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@

if TYPE_CHECKING:
from olive.hardware.accelerator import AcceleratorSpec
from olive.passes.olive_pass import Pass
from olive.passes.olive_pass import FullPassConfig


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -243,23 +243,30 @@ def _assert_not_none(self, obj):
if obj is None:
raise ValueError(f"{obj.__class__.__name__} is missing in the inputs!")

def run_pass(self, the_pass: "Pass", model_config: ModelConfig, output_model_path: str) -> ModelConfig:
def run_pass(
self,
full_pass_config: "FullPassConfig",
model_config: "ModelConfig",
output_model_path: str,
) -> ModelConfig:
"""Run the pass on the model."""
ml_client = self.azureml_client_config.create_client()

# serialize pass
pass_config = the_pass.to_json(check_object=True)
# serialize config
serialized_pass_config = full_pass_config.to_json(check_object=True)

with tempfile.TemporaryDirectory() as tempdir:
pipeline_job = self._create_pipeline_for_pass(tempdir, model_config.to_json(check_object=True), pass_config)
pipeline_job = self._create_pipeline_for_pass(
tempdir, model_config.to_json(check_object=True), serialized_pass_config
)

# submit job
named_outputs_dir = self._run_job(
ml_client,
pipeline_job,
"olive-pass",
tempdir,
tags={"Pass": pass_config["type"]},
tags={"Pass": serialized_pass_config["type"]},
output_name="pipeline_output",
)
pipeline_output_path = named_outputs_dir / "pipeline_output"
Expand Down
Loading

0 comments on commit 1b6bcca

Please sign in to comment.