Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Delay pass creation until time to run #1621

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions olive/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from olive.logging import enable_filelog
from olive.model import ModelConfig
from olive.package_config import OlivePackageConfig
from olive.passes.olive_pass import FullPassConfig
from olive.search.search_sample import SearchSample
from olive.search.search_strategy import SearchStrategy, SearchStrategyConfig
from olive.systems.common import SystemType
Expand Down Expand Up @@ -682,28 +683,27 @@ def _run_pass(
"""Run a pass on the input model."""
run_start_time = datetime.now().timestamp()

pass_config: RunPassConfig = self.computed_passes_configs[pass_name]
pass_type_name = pass_config.type
run_pass_config: RunPassConfig = self.computed_passes_configs[pass_name]
pass_type_name = run_pass_config.type

logger.info("Running pass %s:%s", pass_name, pass_type_name)

# check whether the config is valid
pass_cls: Type[Pass] = self.olive_config.import_pass_module(pass_config.type)
if not pass_cls.validate_config(pass_config.config, accelerator_spec):
pass_cls: Type[Pass] = self.olive_config.import_pass_module(run_pass_config.type)
if not pass_cls.validate_config(run_pass_config.config, accelerator_spec):
logger.warning("Invalid config, pruned.")
logger.debug(pass_config)
logger.debug(run_pass_config)
# no need to record in footprint since there was no run and thus no valid/failed model
# invalid configs are also not cached since the same config can be valid for other accelerator specs
# a pass can be accelerator agnostic but still have accelerator specific invalid configs
# this helps reuse cached models for different accelerator specs
return INVALID_CONFIG, None

p: Pass = pass_cls(accelerator_spec, pass_config.config, self.get_host_device())
pass_config = p.config.to_json()
pass_config = run_pass_config.config.to_json()
output_model_config = None

# load run from cache if it exists
run_accel = None if p.is_accelerator_agnostic(accelerator_spec) else accelerator_spec
run_accel = None if pass_cls.is_accelerator_agnostic(accelerator_spec) else accelerator_spec
output_model_id = self.cache.get_output_model_id(pass_type_name, pass_config, input_model_id, run_accel)
run_cache = self.cache.load_run_from_model_id(output_model_id)
if run_cache:
Expand Down Expand Up @@ -734,15 +734,18 @@ def _run_pass(
input_model_config = self.cache.prepare_resources_for_local(input_model_config)

try:
if p.run_on_target:
if pass_cls.run_on_target:
if self.target.system_type == SystemType.IsolatedORT:
logger.warning(
"Cannot run pass %s on IsolatedORT target, will use the host to run the pass.", pass_name
)
else:
host = self.target

output_model_config = host.run_pass(p, input_model_config, output_model_path)
full_pass_config = FullPassConfig.from_run_pass_config(
run_pass_config, accelerator_spec, self.get_host_device()
)
output_model_config = host.run_pass(full_pass_config, input_model_config, output_model_path)
except OlivePassError:
logger.exception("Pass run_pass failed")
output_model_config = FAILED_CONFIG
Expand Down
57 changes: 36 additions & 21 deletions olive/passes/olive_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,22 @@
import logging
import shutil
from abc import ABC, abstractmethod
from copy import deepcopy
from pathlib import Path
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union, get_args
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union, get_args

from olive.common.config_utils import ParamCategory, validate_config
from olive.common.pydantic_v1 import BaseModel, ValidationError, create_model
from olive.common.pydantic_v1 import BaseModel, ValidationError, create_model, validator
from olive.common.user_module_loader import UserModuleLoader
from olive.data.config import DataConfig
from olive.hardware import DEFAULT_CPU_ACCELERATOR, AcceleratorSpec
from olive.hardware import DEFAULT_CPU_ACCELERATOR, AcceleratorSpec, Device
from olive.model import CompositeModelHandler, DistributedOnnxModelHandler, OliveModelHandler, ONNXModelHandler
from olive.passes.pass_config import (
AbstractPassConfig,

Check failure

Code scanning / CodeQL

Module-level cyclic import Error

'AbstractPassConfig' may not be defined if module olive.passes.pass_config is imported before module olive.passes.olive_pass, as the definition of AbstractPassConfig occurs after the cyclic import of olive.passes.olive_pass.
BasePassConfig,

Check failure

Code scanning / CodeQL

Module-level cyclic import Error

'BasePassConfig' may not be defined if module olive.passes.pass_config is imported before module olive.passes.olive_pass, as the definition of BasePassConfig occurs after the cyclic import of olive.passes.olive_pass.
PassConfigParam,

Check failure

Code scanning / CodeQL

Module-level cyclic import Error

'PassConfigParam' may not be defined if module olive.passes.pass_config is imported before module olive.passes.olive_pass, as the definition of PassConfigParam occurs after the cyclic import of olive.passes.olive_pass.
PassParamDefault,

Check failure

Code scanning / CodeQL

Module-level cyclic import Error

'PassParamDefault' may not be defined if module olive.passes.pass_config is imported before module olive.passes.olive_pass, as the definition of PassParamDefault occurs after the cyclic import of olive.passes.olive_pass.
create_config_class,

Check failure

Code scanning / CodeQL

Module-level cyclic import Error

'create_config_class' may not be defined if module olive.passes.pass_config is imported before module olive.passes.olive_pass, as the definition of create_config_class occurs after the cyclic import of olive.passes.olive_pass.
)
from olive.resource_path import ResourcePath
from olive.search.search_parameter import (
Expand All @@ -32,6 +33,9 @@
)
from olive.search.utils import cyclic_search_space, order_search_parameters

if TYPE_CHECKING:
from olive.engine.config import RunPassConfig

Check failure

Code scanning / CodeQL

Module-level cyclic import Error

'RunPassConfig' may not be defined if module olive.engine.config is imported before module olive.passes.olive_pass, as the definition of RunPassConfig occurs after the cyclic import of olive.passes.olive_pass.

logger = logging.getLogger(__name__)

# ruff: noqa: B027
Expand Down Expand Up @@ -80,17 +84,10 @@
if hasattr(self.config, "user_script") and hasattr(self.config, "script_dir"):
self._user_module_loader = UserModuleLoader(self.config.user_script, self.config.script_dir)

# Params that are paths [(param_name, required)]
self.path_params = [
(param, param_config.required, param_config.category)
for param, param_config in self.default_config(accelerator_spec).items()
if param_config.category in (ParamCategory.PATH, ParamCategory.DATA)
]

self._initialized = False

@staticmethod
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
@classmethod
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
"""Whether the pass is accelerator agnostic. If True, the pass will be reused for all accelerators.

The default value is True. The subclass could choose to override this method to return False by using the
Expand Down Expand Up @@ -482,17 +479,35 @@
reconstruct the pass from the JSON file.
"""

accelerator: Dict[str, str] = None
host_device: Optional[str] = None
accelerator: Optional[AcceleratorSpec] = None
host_device: Optional[Device] = None

def create_pass(self):
if not isinstance(self.accelerator, dict):
raise ValueError(f"accelerator must be a dict, got {self.accelerator}")
@validator("accelerator", pre=True)
def validate_accelerator(cls, v):
if isinstance(v, AcceleratorSpec):
return v
elif isinstance(v, dict):
return AcceleratorSpec(**v)
raise ValueError("Invalid accelerator input.")

pass_cls = Pass.registry[self.type.lower()]
accelerator_spec = AcceleratorSpec(**self.accelerator) # pylint: disable=not-a-mapping
self.config = pass_cls.generate_config(accelerator_spec, self.config)
return pass_cls(accelerator_spec, self.config, self.host_device)
def create_pass(self) -> Pass:
"""Create a Pass."""
return super().create_pass_with_args(self.accelerator, self.host_device)

@staticmethod
def from_run_pass_config(
run_pass_config: Union[Dict[str, Any], "RunPassConfig"],
accelerator: "AcceleratorSpec",
host_device: Device = None,
) -> "FullPassConfig":
config = deepcopy(run_pass_config) if isinstance(run_pass_config, dict) else run_pass_config.dict()
config.update(
{
"accelerator": accelerator,
"host_device": host_device,
}
)
return validate_config(config, FullPassConfig)


# TODO(myguo): deprecate or remove this function by explicitly specifying the accelerator_spec in the arguments
Expand Down
4 changes: 2 additions & 2 deletions olive/passes/onnx/inc_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,8 @@
class IncQuantization(Pass):
"""Quantize ONNX model with Intel® Neural Compressor."""

@staticmethod
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
@classmethod
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
"""Override this method to return False by using the accelerator spec information."""
return False

Expand Down
4 changes: 2 additions & 2 deletions olive/passes/onnx/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ def validate_config(
return False
return True

@staticmethod
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
@classmethod
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
return False

def _run_for_config(
Expand Down
4 changes: 2 additions & 2 deletions olive/passes/onnx/optimum_merging.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ class OptimumMerging(Pass):

_accepts_composite_model = True

@staticmethod
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
@classmethod
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
"""Override this method to return False by using the accelerator spec information."""
return False

Expand Down
4 changes: 2 additions & 2 deletions olive/passes/onnx/session_params_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def get_thread_affinity_nums(affinity_str):
class OrtSessionParamsTuning(Pass):
"""Optimize ONNX Runtime inference settings."""

@staticmethod
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
@classmethod
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
"""Override this method to return False by using the accelerator spec information."""
return False

Expand Down
4 changes: 2 additions & 2 deletions olive/passes/onnx/transformer_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ class OrtTransformersOptimization(Pass):
# using a Linux machine which doesn't support onnxruntime-directml package.
# It is enough for the pass to fail if `opt_level` > 0 and the host doesn't have the required packages.

@staticmethod
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
@classmethod
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
"""Override this method to return False by using the accelerator spec information."""
from onnxruntime import __version__ as OrtVersion
from packaging import version
Expand Down
4 changes: 2 additions & 2 deletions olive/passes/onnx/vitis_ai_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,8 @@ def _initialize(self):
super()._initialize()
self.tmp_dir = tempfile.TemporaryDirectory(prefix="olive_vaiq_tmp")

@staticmethod
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
@classmethod
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
"""Override this method to return False by using the accelerator spec information."""
return False

Expand Down
21 changes: 20 additions & 1 deletion olive/passes/pass_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from pathlib import Path
from typing import Any, Callable, ClassVar, Dict, List, Optional, Set, Type, Union
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Optional, Set, Type, Union

from olive.common.config_utils import (
ConfigBase,
Expand All @@ -20,6 +20,10 @@
from olive.resource_path import validate_resource_path
from olive.search.search_parameter import SearchParameter, SpecialParamValue, json_to_search_parameter

if TYPE_CHECKING:
from olive.hardware.accelerator import AcceleratorSpec
from olive.passes.olive_pass import Pass

Check failure

Code scanning / CodeQL

Module-level cyclic import Error

'Pass' may not be defined if module olive.passes.olive_pass is imported before module olive.passes.pass_config, as the definition of Pass occurs after the cyclic import of olive.passes.pass_config.
'Pass' may not be defined if module olive.passes.olive_pass is imported before module olive.passes.pass_config, as the definition of Pass occurs after the cyclic import of olive.passes.pass_config.


class PassParamDefault(StrEnumBase):
"""Default values for passes."""
Expand Down Expand Up @@ -128,6 +132,21 @@
def validate_type(cls, v):
return validate_lowercase(v)

@validator("config", pre=True, always=True)
def validate_config(cls, v):
return v or {}

def create_pass_with_args(self, accelerator: "AcceleratorSpec", host_device: Device) -> "Pass":
"""Create a Pass."""
from olive.package_config import OlivePackageConfig

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module olive.package_config begins an import cycle.

olive_config = OlivePackageConfig.load_default_config()
pass_cls: Type[Pass] = olive_config.import_pass_module(self.type)
self.config = pass_cls.generate_config(
accelerator, self.config if isinstance(self.config, dict) else self.config.dict()
)
return pass_cls(accelerator, self.config, host_device)


def create_config_class(
pass_type: str,
Expand Down
4 changes: 2 additions & 2 deletions olive/passes/pytorch/capture_split_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def validate_config(

return True

@staticmethod
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
@classmethod
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
return False

def _run_for_config(
Expand Down
19 changes: 13 additions & 6 deletions olive/systems/azureml/aml_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@

if TYPE_CHECKING:
from olive.hardware.accelerator import AcceleratorSpec
from olive.passes.olive_pass import Pass
from olive.passes.olive_pass import FullPassConfig


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -243,23 +243,30 @@ def _assert_not_none(self, obj):
if obj is None:
raise ValueError(f"{obj.__class__.__name__} is missing in the inputs!")

def run_pass(self, the_pass: "Pass", model_config: ModelConfig, output_model_path: str) -> ModelConfig:
def run_pass(
self,
full_pass_config: "FullPassConfig",
model_config: "ModelConfig",
output_model_path: str,
) -> ModelConfig:
"""Run the pass on the model."""
ml_client = self.azureml_client_config.create_client()

# serialize pass
pass_config = the_pass.to_json(check_object=True)
# serialize config
serialized_pass_config = full_pass_config.to_json(check_object=True)

with tempfile.TemporaryDirectory() as tempdir:
pipeline_job = self._create_pipeline_for_pass(tempdir, model_config.to_json(check_object=True), pass_config)
pipeline_job = self._create_pipeline_for_pass(
tempdir, model_config.to_json(check_object=True), serialized_pass_config
)

# submit job
named_outputs_dir = self._run_job(
ml_client,
pipeline_job,
"olive-pass",
tempdir,
tags={"Pass": pass_config["type"]},
tags={"Pass": serialized_pass_config["type"]},
output_name="pipeline_output",
)
pipeline_output_path = named_outputs_dir / "pipeline_output"
Expand Down
Loading
Loading