# Use concrete config classes instead of generic dictionary (#1567)

* `Pass` now accepts a concrete `BasePassConfig` object instead of a
  generic dict for its config
* `Pass._run_for_config` similarly accepts a concrete `BasePassConfig`
  object
* `Pass.generate_config` returns a concrete `BasePassConfig` object
* Removed the `disable_search` argument from `Pass.validate_config`
* Removed `Pass.serialize_config`; use `Pass.config.to_json` or
  `Pass.config.dict` instead
* Removed `Pass._config_class`; use `Pass.config.__class__` instead
* Updated all passes and tests to take the input config as an object of
  type `BasePassConfig` instead of a generic dict (see the migration
  sketch below)
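
Taken together, these changes replace dict-style access on pass configs with attribute access. A minimal before/after sketch of the new surface (the pass name `onnxconversion`, the `target_opset` parameter, and the exact import paths are illustrative assumptions, not taken from this diff):

```python
from olive.hardware.accelerator import DEFAULT_CPU_ACCELERATOR
from olive.passes.olive_pass import Pass

pass_cls = Pass.registry["onnxconversion"]  # illustrative: any registered pass

# generate_config now returns a BasePassConfig object rather than a dict
config = pass_cls.generate_config(DEFAULT_CPU_ACCELERATOR, {"target_opset": 14}, disable_search=True)

opset = config.target_opset    # before: config["target_opset"]
as_json = config.to_json()     # before: p.serialize_config(config, check_object=True)
as_dict = config.dict()        # plain-dict view of the config
config_cls = config.__class__  # before: Pass._config_class
```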

## Checklist before requesting a review
- [ ] Add unit tests for this change.
- [x] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [x] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this
change to be included in the release notes.
- [ ] Is this PR including examples changes? If yes, please remember to
update [example
documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md)
in a follow-up PR.

## (Optional) Issue link
shaahji authored Feb 14, 2025
1 parent 481807e commit 00415b6
Showing 59 changed files with 547 additions and 576 deletions.
### olive/cache.py (3 additions, 1 deletion)

```diff
@@ -154,7 +154,9 @@ def __init__(self, cache_config: Union[CacheConfig, Dict]):
         self.update_shared_cache = cache_config.update_shared_cache
 
     @staticmethod
-    def get_run_json(pass_name: int, pass_config: dict, input_model_id: str, accelerator_spec: "AcceleratorSpec"):
+    def get_run_json(
+        pass_name: str, pass_config: Dict[str, Any], input_model_id: str, accelerator_spec: "AcceleratorSpec"
+    ) -> Dict[str, Any]:
         accelerator_spec = str(accelerator_spec) if accelerator_spec else None
         return {
             "input_model_id": input_model_id,
```
### olive/engine/engine.py (3 additions, 4 deletions)

```diff
@@ -382,7 +382,7 @@ def _get_search_space_config(self, accelerator_spec: "AcceleratorSpec"):
             space_config[pass_name] = pass_params_config = []
             for pass_config in passes_configs:
                 pass_cls = self.olive_config.import_pass_module(pass_config.type)
-                _, search_params = pass_cls.get_config_params(accelerator_spec, pass_config.config, False)
+                _, _, search_params = pass_cls.get_config_params(accelerator_spec, pass_config.config, False)
                 pass_params_config.append(search_params)
         return space_config
 
@@ -689,7 +689,7 @@ def _run_pass(
 
         # check whether the config is valid
         pass_cls: Type[Pass] = self.olive_config.import_pass_module(pass_config.type)
-        if not pass_cls.validate_config(pass_config.config, accelerator_spec, self.search_strategy is None):
+        if not pass_cls.validate_config(pass_config.config, accelerator_spec):
             logger.warning("Invalid config, pruned.")
             logger.debug(pass_config)
             # no need to record in footprint since there was no run and thus no valid/failed model
@@ -699,12 +699,11 @@ def _run_pass(
             return INVALID_CONFIG, None
 
         p: Pass = pass_cls(accelerator_spec, pass_config.config, self.get_host_device())
-        pass_config = p.serialize_config(pass_config.config, check_object=True)
+        pass_config = p.config.to_json()
         output_model_config = None
 
         # load run from cache if it exists
         run_accel = None if p.is_accelerator_agnostic(accelerator_spec) else accelerator_spec
-
         output_model_id = self.cache.get_output_model_id(pass_type_name, pass_config, input_model_id, run_accel)
         run_cache = self.cache.load_run_from_model_id(output_model_id)
         if run_cache:
```
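
Call sites that previously unpacked two values from `get_config_params` now unpack three; the leading element is the resolved config class. A short sketch of the new unpacking (variable names are illustrative):

```python
# config_class is the new first element; fixed_values and search_params are as before
config_class, fixed_values, search_params = pass_cls.get_config_params(
    accelerator_spec, pass_config.config, False
)
```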
### olive/passes/olive_pass.py (17 additions, 27 deletions)

```diff
@@ -58,38 +58,32 @@ def __init_subclass__(cls, **kwargs) -> None:
     def __init__(
         self,
         accelerator_spec: AcceleratorSpec,
-        config: Dict[str, Any],
+        config: Type[BasePassConfig],
         host_device=None,
     ):
         """Initialize the pass.
 
         :param accelerator_spec: the accelerator spec for the pass.
         :type accelerator_spec: AcceleratorSpec
         :param config: the configuration representing search space.
-        :type config: Dict[str, Any]
+        :type config: Type[BasePassConfig]
         :param host_device: the host device for the pass.
         :type host_device: Optional[str]
         """
         assert accelerator_spec is not None, "Please specify the accelerator spec for the pass."
         assert config is not None, "Please specify the configuration for the pass."
 
-        # NOTE: The :disable_search argument isn't impactful here since the search isn't
-        #       dependent on it. The same parameter in :generate_config is what decides
-        #       how search points are handled. Here, using default values for each config
-        #       parameter in the config class keeps it simple.
-        config_class, default_config = self.get_config_class(accelerator_spec, True)
-
+        self.config = config
         self.accelerator_spec = accelerator_spec
         self.host_device = host_device
 
-        self._config_class = config_class
-        self.config = config
-        self._user_module_loader = UserModuleLoader(self.config.get("user_script"), self.config.get("script_dir"))
+        if hasattr(self.config, "user_script") and hasattr(self.config, "script_dir"):
+            self._user_module_loader = UserModuleLoader(self.config.user_script, self.config.script_dir)
 
         # Params that are paths [(param_name, required)]
         self.path_params = [
             (param, param_config.required, param_config.category)
-            for param, param_config in default_config.items()
+            for param, param_config in self.default_config(accelerator_spec).items()
             if param_config.category in (ParamCategory.PATH, ParamCategory.DATA)
         ]
 
@@ -110,7 +104,7 @@ def get_config_params(
         accelerator_spec: AcceleratorSpec,
         config: Optional[Dict[str, Any]] = None,
         disable_search: Optional[bool] = False,
-    ) -> Tuple[Dict[str, Any], Dict[str, SearchParameter]]:
+    ) -> Tuple[Type[BasePassConfig], Dict[str, Any], Dict[str, SearchParameter]]:
         """Generate search space for the pass."""
         assert accelerator_spec is not None, "Please specify the accelerator spec for the pass"
         config = config or {}
@@ -125,7 +119,7 @@
         # Generate the search space by using both default value and default search value and user provided config
         config = validate_config(config, config_class)
         config = cls._resolve_config(config, default_config)
-        return cls._init_fixed_and_search_params(config, default_config)
+        return config_class, *cls._init_fixed_and_search_params(config, default_config)
 
     @classmethod
     def generate_config(
@@ -134,16 +128,16 @@ def generate_config(
         config: Optional[Dict[str, Any]] = None,
         point: Optional[Dict[str, Any]] = None,
         disable_search: Optional[bool] = False,
-    ) -> Dict[str, Any]:
+    ) -> Type[BasePassConfig]:
         """Get the configuration for the pass at a specific point in the search space."""
         assert accelerator_spec is not None, "Please specify the accelerator spec for the pass"
 
         point = point or {}
-        fixed_values, search_params = cls.get_config_params(accelerator_spec, config, disable_search)
+        config_class, fixed_values, search_params = cls.get_config_params(accelerator_spec, config, disable_search)
         assert (
             set(point.keys()).intersection(set(search_params.keys())) == point.keys()
         ), "Search point is not in the search space."
-        return {**fixed_values, **search_params, **point}
+        return config_class.parse_obj({**fixed_values, **search_params, **point})
 
     @classmethod
     def _identify_search_values(
@@ -202,9 +196,8 @@ def default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]:
     @classmethod
     def validate_config(
         cls,
-        config: Dict[str, Any],
+        config: Type[BasePassConfig],
         accelerator_spec: AcceleratorSpec,
-        disable_search: Optional[bool] = False,
     ) -> bool:
         """Validate the input config for the pass."""
         return True
@@ -310,17 +303,13 @@ def _carry_forward_additional_files(input_model: OliveModelHandler, output_model: OliveModelHandler):
         output_model_attributes["additional_files"] = sorted(output_model_additional_files)
         output_model.model_attributes = output_model_attributes
 
-    def serialize_config(self, config: Dict[str, Any], check_object: bool = False) -> str:
-        """Serialize the configuration."""
-        return self._config_class(**config).to_json(check_object)
-
     def to_json(self, check_object: bool = False) -> Dict[str, Any]:
         """Convert the pass to json."""
         return {
             "type": self.__class__.__name__,
             "accelerator": self.accelerator_spec.to_json(),
             "host_device": self.host_device,
-            "config": self.serialize_config(self.config, check_object),
+            "config": self.config.to_json(check_object),
         }
 
     @classmethod
@@ -464,7 +453,7 @@ def _resolve_search_parameter(cls, param: SearchParameter, fixed_params: Dict[str, Any]) -> Any:
     @classmethod
     def _resolve_config(
         cls,
-        input_config: Union[Dict[str, Any], BasePassConfig],
+        input_config: Union[Dict[str, Any], Type[BasePassConfig]],
         default_config: Dict[str, PassConfigParam],
     ) -> Dict[str, Any]:
         """Resolve config to BasePassConfig."""
@@ -480,7 +469,7 @@ def _initialize(self):
 
     @abstractmethod
     def _run_for_config(
-        self, model: OliveModelHandler, config: Dict[str, Any], output_model_path: str
+        self, model: OliveModelHandler, config: Type[BasePassConfig], output_model_path: str
     ) -> OliveModelHandler:
         """Run the pass on the model with the given configuration."""
         raise NotImplementedError
@@ -502,6 +491,7 @@ def create_pass(self):
 
         pass_cls = Pass.registry[self.type.lower()]
         accelerator_spec = AcceleratorSpec(**self.accelerator)  # pylint: disable=not-a-mapping
+        self.config = pass_cls.generate_config(accelerator_spec, self.config)
         return pass_cls(accelerator_spec, self.config, self.host_device)
 
 
@@ -518,5 +508,5 @@ def create_pass_from_dict(
     if accelerator_spec is None:
         accelerator_spec = DEFAULT_CPU_ACCELERATOR
 
-    config = pass_cls.generate_config(accelerator_spec, config, disable_search=disable_search)
+    config: Type[BasePassConfig] = pass_cls.generate_config(accelerator_spec, config, disable_search=disable_search)
     return pass_cls(accelerator_spec, config, host_device)
```
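
For pass authors, the net effect is that `create_pass_from_dict` still takes a plain dict, but the constructed pass now holds a `BasePassConfig` object. A minimal sketch, assuming the `OnnxConversion` pass with its `target_opset` parameter (an illustrative choice, not part of this diff):

```python
from olive.passes.olive_pass import create_pass_from_dict
from olive.passes.onnx.conversion import OnnxConversion  # assumed module path

p = create_pass_from_dict(OnnxConversion, {"target_opset": 14}, disable_search=True)

assert not isinstance(p.config, dict)  # p.config is a BasePassConfig instance
opset = p.config.target_opset          # attribute access, not p.config["target_opset"]
```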
### olive/passes/onnx/append_pre_post_processing_ops.py (13 additions, 14 deletions)

```diff
@@ -4,7 +4,7 @@
 # --------------------------------------------------------------------------
 import tempfile
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Union
+from typing import Any, Callable, Dict, List, Type, Union
 
 import onnx
 from packaging import version
@@ -17,7 +17,7 @@
 from olive.passes import Pass
 from olive.passes.onnx.common import get_external_data_config, model_proto_to_olive_model
 from olive.passes.onnx.pipeline import TENSOR_TYPE_MAP
-from olive.passes.pass_config import PassConfigParam
+from olive.passes.pass_config import BasePassConfig, PassConfigParam
 
 
 class PrePostProcessorInput(ConfigBase):
@@ -81,13 +81,13 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, Dict[str, Any]]:
         return config
 
     def _run_for_config(
-        self, model: ONNXModelHandler, config: Dict[str, Any], output_model_path: str
+        self, model: ONNXModelHandler, config: Type[BasePassConfig], output_model_path: str
     ) -> ONNXModelHandler:
         from onnxruntime import __version__ as OrtVersion
 
         output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name)
 
-        tool_command = config.get("tool_command")
+        tool_command = config.tool_command
         if tool_command:
             if tool_command == "whisper":
                 from onnxruntime_extensions import __version__ as ortext_version
@@ -98,17 +98,17 @@
 
                 from olive.passes.utils.whisper_prepost import add_pre_post_processing_to_model
 
-                tool_command_args = config.get("tool_command_args") or {}
+                tool_command_args = config.tool_command_args or {}
                 onnx_model = add_pre_post_processing_to_model(
-                    model.load_model(), config["target_opset"], **tool_command_args
+                    model.load_model(), config.target_opset, **tool_command_args
                 )
             else:
                 # Use the pre-defined helper to add pre/post processing to model.
                 from onnxruntime_extensions.tools import add_pre_post_processing_to_model as add_ppp
 
                 # ORT 1.14 and later support ONNX opset 18, which added antialiasing to the Resize operator.
                 # Results are much better when that can be used. Minimum opset is 16.
-                onnx_opset = config.get("target_opset")
+                onnx_opset = config.target_opset
 
                 if version.parse(OrtVersion) >= version.parse("1.14.0"):
                     onnx_opset = 18
@@ -123,7 +123,7 @@
                         "tool_command must be a callable or a string defined in onnxruntime_extensions.tools"
                     ) from None
 
-                kwargs = config.get("tool_command_args") or {}
+                kwargs = config.tool_command_args or {}
                 kwargs["onnx_opset"] = onnx_opset
 
         # add the processing commands to the model
@@ -143,32 +143,31 @@
         olive_model.use_ort_extensions = True
         return olive_model
 
-    def _run_prepost_pipeline(self, model: ONNXModelHandler, config: Dict[str, Any]):
+    def _run_prepost_pipeline(self, model: ONNXModelHandler, config: Type[BasePassConfig]):
         from onnxruntime_extensions.tools.pre_post_processing import PrePostProcessor
 
         from olive.passes.onnx.pipeline.step_utils import create_named_value, parse_steps
 
         # Initialize pre/post step instance list
         pre_steps = []
-        pre = config.get("pre")
+        pre = config.pre
         model_proto = model.load_model()
         if pre:
             steps = parse_steps(model_proto, pre)
             pre_steps = [self.create_step_from_config(step_name, step_param) for step_name, step_param in steps]
 
         post_steps = []
-        post = config.get("post")
+        post = config.post
         if post:
             steps = parse_steps(model_proto, post)
             post_steps = [self.create_step_from_config(step_name, step_param) for step_name, step_param in steps]
 
         # Initialize PrePostProcessor instance
-        config_obj = self._config_class(**config)
-        input_param = config_obj.tool_command_args
+        input_param = config.tool_command_args
         assert isinstance(input_param, list)
         assert all(isinstance(i, PrePostProcessorInput) for i in input_param)
         inputs = [create_named_value(i.name, TENSOR_TYPE_MAP[i.data_type], i.shape) for i in input_param]
-        pipeline = PrePostProcessor(inputs, config_obj.target_opset)
+        pipeline = PrePostProcessor(inputs, config.target_opset)
 
         if pre_steps:
             pipeline.add_pre_processing(pre_steps)
```
### olive/passes/onnx/bnb_quantization.py (6 additions, 6 deletions)

```diff
@@ -5,7 +5,7 @@
 import logging
 import re
 from pathlib import Path
-from typing import Any, Dict, List
+from typing import Dict, List, Type
 
 import onnx
 from packaging import version
@@ -15,7 +15,7 @@
 from olive.model.utils import resolve_onnx_path
 from olive.passes import Pass
 from olive.passes.onnx.common import get_external_data_config, model_proto_to_olive_model
-from olive.passes.pass_config import PassConfigParam
+from olive.passes.pass_config import BasePassConfig, PassConfigParam
 
 logger = logging.getLogger(__name__)
 
@@ -48,7 +48,7 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]:
         return config
 
     def _run_for_config(
-        self, model: ONNXModelHandler, config: Dict[str, Any], output_model_path: str
+        self, model: ONNXModelHandler, config: Type[BasePassConfig], output_model_path: str
     ) -> ONNXModelHandler:
         from onnxruntime import __version__ as OrtVersion
 
@@ -60,8 +60,8 @@
 
         output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name)
 
-        quant_type = config["quant_type"]
-        quantized_modules = config["quantized_modules"]
+        quant_type = config.quant_type
+        quantized_modules = config.quantized_modules
         if model.model_attributes:
             quantized_modules = quantized_modules or model.model_attributes.get("quantized_modules")
 
@@ -87,7 +87,7 @@
         onnx_model = model.load_model()
 
         # get nodes to exclude from quantization
-        nodes_to_exclude = config["nodes_to_exclude"] or []
+        nodes_to_exclude = config.nodes_to_exclude or []
 
         # find all MatMul nodes in the graph
         matmul_nodes = self._find_matmul_nodes(onnx_model.graph)
```
### olive/passes/onnx/common.py (5 additions, 3 deletions)

```diff
@@ -5,13 +5,13 @@
 import logging
 import re
 from pathlib import Path
-from typing import Dict, Optional, Union
+from typing import Any, Dict, Optional, Type, Union
 
 import onnx
 
 from olive.model import ONNXModelHandler
 from olive.passes.onnx.onnx_dag import OnnxDAG
-from olive.passes.pass_config import PassConfigParam
+from olive.passes.pass_config import BasePassConfig, PassConfigParam
 from olive.resource_path import LocalFile, LocalFolder
 
 logger = logging.getLogger(__name__)
@@ -139,7 +139,7 @@ def model_proto_to_file(
 def model_proto_to_olive_model(
     model_proto: onnx.ModelProto,
     output_model_path: Union[str, Path],
-    external_data_config: dict,
+    external_data_config: Union[Dict[str, Any], Type[BasePassConfig]],
     check_model: bool = False,
     external_initializers_file_name: Optional[str] = None,
     constant_inputs_file_name: Optional[str] = None,
@@ -163,6 +163,8 @@
         "size_threshold",
         "convert_attribute",
     ]
+    if not isinstance(external_data_config, dict):
+        external_data_config = external_data_config.dict()
     has_external_data = model_proto_to_file(
         model_proto, output_model_path, **{k: external_data_config[k] for k in config_keys if k in external_data_config}
     )
```
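
With this change, `model_proto_to_olive_model` accepts either a plain dict or a config object for `external_data_config`, normalizing the latter with `.dict()` before reading keys. A hedged sketch (file paths and the key shown are illustrative):

```python
import onnx

proto = onnx.load("model.onnx")  # hypothetical input path
olive_model = model_proto_to_olive_model(
    proto,
    "out/model.onnx",
    {"save_as_external_data": False},  # a dict still works; a BasePassConfig is converted via .dict()
)
```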