From 00415b621118283ae5dd71a86fe258cb12fc96d9 Mon Sep 17 00:00:00 2001
From: shaahji <96227573+shaahji@users.noreply.github.com>
Date: Thu, 13 Feb 2025 16:22:22 -0800
Subject: [PATCH] Use concrete config classes instead of generic dictionary (#1567)

## Use concrete config classes instead of generic dictionary

* Pass now accepts a concrete BasePassConfig object instead of a generic dict for config
* Pass._run_for_config similarly accepts a concrete BasePassConfig object
* Pass.generate_config returns a concrete BasePassConfig object
* Removed "disable_search" argument from Pass.validate_config
* Removed Pass.serialize_config. Use Pass.config.to_json or Pass.config.dict instead
* Removed Pass._config_class. Use Pass.config.__class__ instead
* Updated all passes and tests to use the input config as an object of type BasePassConfig instead of a generic dict

A short usage sketch of the new config-object pattern appears after the diff.

## Checklist before requesting a review

- [ ] Add unit tests for this change.
- [x] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [x] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this change to be included in the release notes.
- [ ] Is this PR including examples changes? If yes, please remember to update [example documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md) in a follow-up PR.

## (Optional) Issue link

---
 olive/cache.py | 4 +-
 olive/engine/engine.py | 7 +-
 olive/passes/olive_pass.py | 44 +++++------
 .../onnx/append_pre_post_processing_ops.py | 27 ++++---
 olive/passes/onnx/bnb_quantization.py | 12 +--
 olive/passes/onnx/common.py | 8 +-
 olive/passes/onnx/conversion.py | 52 ++++++-------
 olive/passes/onnx/dynamic_to_fixed_shape.py | 14 ++--
 olive/passes/onnx/extract_adapters.py | 28 +++----
 olive/passes/onnx/float16_conversion.py | 9 ++-
 olive/passes/onnx/graph_surgeries.py | 6 +-
 olive/passes/onnx/inc_quantization.py | 15 ++--
 olive/passes/onnx/insert_beam_search.py | 50 +++++++------
 olive/passes/onnx/io_datatype_converter.py | 14 ++--
 olive/passes/onnx/mixed_precision.py | 13 ++--
 .../passes/onnx/mixed_precision_overrides.py | 26 +++----
 olive/passes/onnx/mnb_to_qdq.py | 30 ++++----
 olive/passes/onnx/model_builder.py | 25 +++----
 olive/passes/onnx/moe_experts_distributor.py | 16 ++--
 olive/passes/onnx/nvmo_quantization.py | 28 +++----
 olive/passes/onnx/optimum_conversion.py | 28 +++----
 olive/passes/onnx/optimum_merging.py | 8 +-
 olive/passes/onnx/peephole_optimizer.py | 6 +-
 olive/passes/onnx/qnn/qnn_preprocess.py | 22 +++---
 olive/passes/onnx/quantization.py | 74 +++++++++----------
 olive/passes/onnx/session_params_tuning.py | 20 ++---
 olive/passes/onnx/split.py | 6 +-
 olive/passes/onnx/transformer_optimization.py | 35 ++++-----
 olive/passes/onnx/vitis_ai_quantization.py | 14 ++--
 olive/passes/openvino/conversion.py | 22 +++---
 olive/passes/openvino/quantization.py | 32 ++++----
 olive/passes/pass_config.py | 2 +-
 olive/passes/pytorch/autoawq.py | 30 ++++----
 olive/passes/pytorch/capture_split_info.py | 34 ++++-----
 olive/passes/pytorch/gptq.py | 33 +++++----
 olive/passes/pytorch/lora.py | 70 +++++++++---------
 olive/passes/pytorch/merge_adapter_weights.py | 8 +-
 .../pytorch/quantization_aware_training.py | 33 ++++-----
 olive/passes/pytorch/rotate.py | 18 ++---
 olive/passes/pytorch/slicegpt.py | 9 +--
 olive/passes/pytorch/sparsegpt.py | 29 ++++----
 olive/passes/pytorch/tensor_parallel.py | 9 ++-
 olive/passes/pytorch/torch_trt_conversion.py | 18 ++---
 olive/passes/qnn/context_binary_generator.py | 14 ++--
olive/passes/qnn/conversion.py | 18 ++--- olive/passes/qnn/model_lib_generator.py | 6 +- olive/passes/snpe/conversion.py | 8 +- olive/passes/snpe/quantization.py | 8 +- olive/passes/snpe/snpe_to_onnx.py | 10 +-- olive/systems/docker/docker_system.py | 3 +- test/requirements-test.txt | 3 +- test/unit_test/engine/test_engine.py | 34 +++++---- .../evaluator/test_olive_evaluator.py | 4 +- .../passes/common/test_user_script.py | 2 +- .../passes/onnx/test_optimum_conversion.py | 4 +- .../passes/onnx/test_session_params_tuning.py | 6 +- .../onnx/test_transformer_optimization.py | 5 +- .../test_python_environment_system.py | 1 - test/unit_test/utils.py | 9 +-- 59 files changed, 547 insertions(+), 576 deletions(-) diff --git a/olive/cache.py b/olive/cache.py index 0029c52b8..1739e7256 100644 --- a/olive/cache.py +++ b/olive/cache.py @@ -154,7 +154,9 @@ def __init__(self, cache_config: Union[CacheConfig, Dict]): self.update_shared_cache = cache_config.update_shared_cache @staticmethod - def get_run_json(pass_name: int, pass_config: dict, input_model_id: str, accelerator_spec: "AcceleratorSpec"): + def get_run_json( + pass_name: str, pass_config: Dict[str, Any], input_model_id: str, accelerator_spec: "AcceleratorSpec" + ) -> Dict[str, Any]: accelerator_spec = str(accelerator_spec) if accelerator_spec else None return { "input_model_id": input_model_id, diff --git a/olive/engine/engine.py b/olive/engine/engine.py index 79f42777b..cc88f0821 100644 --- a/olive/engine/engine.py +++ b/olive/engine/engine.py @@ -382,7 +382,7 @@ def _get_search_space_config(self, accelerator_spec: "AcceleratorSpec"): space_config[pass_name] = pass_params_config = [] for pass_config in passes_configs: pass_cls = self.olive_config.import_pass_module(pass_config.type) - _, search_params = pass_cls.get_config_params(accelerator_spec, pass_config.config, False) + _, _, search_params = pass_cls.get_config_params(accelerator_spec, pass_config.config, False) pass_params_config.append(search_params) return space_config @@ -689,7 +689,7 @@ def _run_pass( # check whether the config is valid pass_cls: Type[Pass] = self.olive_config.import_pass_module(pass_config.type) - if not pass_cls.validate_config(pass_config.config, accelerator_spec, self.search_strategy is None): + if not pass_cls.validate_config(pass_config.config, accelerator_spec): logger.warning("Invalid config, pruned.") logger.debug(pass_config) # no need to record in footprint since there was no run and thus no valid/failed model @@ -699,12 +699,11 @@ def _run_pass( return INVALID_CONFIG, None p: Pass = pass_cls(accelerator_spec, pass_config.config, self.get_host_device()) - pass_config = p.serialize_config(pass_config.config, check_object=True) + pass_config = p.config.to_json() output_model_config = None # load run from cache if it exists run_accel = None if p.is_accelerator_agnostic(accelerator_spec) else accelerator_spec - output_model_id = self.cache.get_output_model_id(pass_type_name, pass_config, input_model_id, run_accel) run_cache = self.cache.load_run_from_model_id(output_model_id) if run_cache: diff --git a/olive/passes/olive_pass.py b/olive/passes/olive_pass.py index 24135cfb0..b0b51af0a 100644 --- a/olive/passes/olive_pass.py +++ b/olive/passes/olive_pass.py @@ -58,7 +58,7 @@ def __init_subclass__(cls, **kwargs) -> None: def __init__( self, accelerator_spec: AcceleratorSpec, - config: Dict[str, Any], + config: Type[BasePassConfig], host_device=None, ): """Initialize the pass. 
@@ -66,30 +66,24 @@ def __init__( :param accelerator_spec: the accelerator spec for the pass. :type accelerator_spec: AcceleratorSpec :param config: the configuration representing search space. - :type config: Dict[str, Any] + :type config: Type[BasePassConfig] :param host_device: the host device for the pass. :type host_device: Optional[str] """ assert accelerator_spec is not None, "Please specify the accelerator spec for the pass." assert config is not None, "Please specify the configuration for the pass." - # NOTE: The :disable_search argument isn't impactful here since the search isn't - # dependent on it. The same parameter in :generate_config is what decides - # how search points are handled. HEre, Using default values for each config - # parameter in the config class keeps it simple. - config_class, default_config = self.get_config_class(accelerator_spec, True) - + self.config = config self.accelerator_spec = accelerator_spec self.host_device = host_device - self._config_class = config_class - self.config = config - self._user_module_loader = UserModuleLoader(self.config.get("user_script"), self.config.get("script_dir")) + if hasattr(self.config, "user_script") and hasattr(self.config, "script_dir"): + self._user_module_loader = UserModuleLoader(self.config.user_script, self.config.script_dir) # Params that are paths [(param_name, required)] self.path_params = [ (param, param_config.required, param_config.category) - for param, param_config in default_config.items() + for param, param_config in self.default_config(accelerator_spec).items() if param_config.category in (ParamCategory.PATH, ParamCategory.DATA) ] @@ -110,7 +104,7 @@ def get_config_params( accelerator_spec: AcceleratorSpec, config: Optional[Dict[str, Any]] = None, disable_search: Optional[bool] = False, - ) -> Tuple[Dict[str, Any], Dict[str, SearchParameter]]: + ) -> Tuple[Type[BasePassConfig], Dict[str, Any], Dict[str, SearchParameter]]: """Generate search space for the pass.""" assert accelerator_spec is not None, "Please specify the accelerator spec for the pass" config = config or {} @@ -125,7 +119,7 @@ def get_config_params( # Generate the search space by using both default value and default search value and user provided config config = validate_config(config, config_class) config = cls._resolve_config(config, default_config) - return cls._init_fixed_and_search_params(config, default_config) + return config_class, *cls._init_fixed_and_search_params(config, default_config) @classmethod def generate_config( @@ -134,16 +128,16 @@ def generate_config( config: Optional[Dict[str, Any]] = None, point: Optional[Dict[str, Any]] = None, disable_search: Optional[bool] = False, - ) -> Dict[str, Any]: + ) -> Type[BasePassConfig]: """Get the configuration for the pass at a specific point in the search space.""" assert accelerator_spec is not None, "Please specify the accelerator spec for the pass" point = point or {} - fixed_values, search_params = cls.get_config_params(accelerator_spec, config, disable_search) + config_class, fixed_values, search_params = cls.get_config_params(accelerator_spec, config, disable_search) assert ( set(point.keys()).intersection(set(search_params.keys())) == point.keys() ), "Search point is not in the search space." 
- return {**fixed_values, **search_params, **point} + return config_class.parse_obj({**fixed_values, **search_params, **point}) @classmethod def _identify_search_values( @@ -202,9 +196,8 @@ def default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConf @classmethod def validate_config( cls, - config: Dict[str, Any], + config: Type[BasePassConfig], accelerator_spec: AcceleratorSpec, - disable_search: Optional[bool] = False, ) -> bool: """Validate the input config for the pass.""" return True @@ -310,17 +303,13 @@ def _carry_forward_additional_files(input_model: OliveModelHandler, output_model output_model_attributes["additional_files"] = sorted(output_model_additional_files) output_model.model_attributes = output_model_attributes - def serialize_config(self, config: Dict[str, Any], check_object: bool = False) -> str: - """Serialize the configuration.""" - return self._config_class(**config).to_json(check_object) - def to_json(self, check_object: bool = False) -> Dict[str, Any]: """Convert the pass to json.""" return { "type": self.__class__.__name__, "accelerator": self.accelerator_spec.to_json(), "host_device": self.host_device, - "config": self.serialize_config(self.config, check_object), + "config": self.config.to_json(check_object), } @classmethod @@ -464,7 +453,7 @@ def _resolve_search_parameter(cls, param: SearchParameter, fixed_params: Dict[st @classmethod def _resolve_config( cls, - input_config: Union[Dict[str, Any], BasePassConfig], + input_config: Union[Dict[str, Any], Type[BasePassConfig]], default_config: Dict[str, PassConfigParam], ) -> Dict[str, Any]: """Resolve config to BasePassConfig.""" @@ -480,7 +469,7 @@ def _initialize(self): @abstractmethod def _run_for_config( - self, model: OliveModelHandler, config: Dict[str, Any], output_model_path: str + self, model: OliveModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> OliveModelHandler: """Run the pass on the model with the given configuration.""" raise NotImplementedError @@ -502,6 +491,7 @@ def create_pass(self): pass_cls = Pass.registry[self.type.lower()] accelerator_spec = AcceleratorSpec(**self.accelerator) # pylint: disable=not-a-mapping + self.config = pass_cls.generate_config(accelerator_spec, self.config) return pass_cls(accelerator_spec, self.config, self.host_device) @@ -518,5 +508,5 @@ def create_pass_from_dict( if accelerator_spec is None: accelerator_spec = DEFAULT_CPU_ACCELERATOR - config = pass_cls.generate_config(accelerator_spec, config, disable_search=disable_search) + config: Type[BasePassConfig] = pass_cls.generate_config(accelerator_spec, config, disable_search=disable_search) return pass_cls(accelerator_spec, config, host_device) diff --git a/olive/passes/onnx/append_pre_post_processing_ops.py b/olive/passes/onnx/append_pre_post_processing_ops.py index ab6d08010..4e03b6586 100644 --- a/olive/passes/onnx/append_pre_post_processing_ops.py +++ b/olive/passes/onnx/append_pre_post_processing_ops.py @@ -4,7 +4,7 @@ # -------------------------------------------------------------------------- import tempfile from pathlib import Path -from typing import Any, Callable, Dict, List, Union +from typing import Any, Callable, Dict, List, Type, Union import onnx from packaging import version @@ -17,7 +17,7 @@ from olive.passes import Pass from olive.passes.onnx.common import get_external_data_config, model_proto_to_olive_model from olive.passes.onnx.pipeline import TENSOR_TYPE_MAP -from olive.passes.pass_config import PassConfigParam +from olive.passes.pass_config import 
BasePassConfig, PassConfigParam class PrePostProcessorInput(ConfigBase): @@ -81,13 +81,13 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, Dict[st return config def _run_for_config( - self, model: ONNXModelHandler, config: Dict[str, Any], output_model_path: str + self, model: ONNXModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> ONNXModelHandler: from onnxruntime import __version__ as OrtVersion output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name) - tool_command = config.get("tool_command") + tool_command = config.tool_command if tool_command: if tool_command == "whisper": from onnxruntime_extensions import __version__ as ortext_version @@ -98,9 +98,9 @@ def _run_for_config( from olive.passes.utils.whisper_prepost import add_pre_post_processing_to_model - tool_command_args = config.get("tool_command_args") or {} + tool_command_args = config.tool_command_args or {} onnx_model = add_pre_post_processing_to_model( - model.load_model(), config["target_opset"], **tool_command_args + model.load_model(), config.target_opset, **tool_command_args ) else: # Use the pre-defined helper to add pre/post processing to model. @@ -108,7 +108,7 @@ def _run_for_config( # ORT 1.14 and later support ONNX opset 18, which added antialiasing to the Resize operator. # Results are much better when that can be used. Minimum opset is 16. - onnx_opset = config.get("target_opset") + onnx_opset = config.target_opset if version.parse(OrtVersion) >= version.parse("1.14.0"): onnx_opset = 18 @@ -123,7 +123,7 @@ def _run_for_config( "tool_command must be a callable or a string defined in onnxruntime_extensions.tools" ) from None - kwargs = config.get("tool_command_args") or {} + kwargs = config.tool_command_args or {} kwargs["onnx_opset"] = onnx_opset # add the processing commands to the model @@ -143,32 +143,31 @@ def _run_for_config( olive_model.use_ort_extensions = True return olive_model - def _run_prepost_pipeline(self, model: ONNXModelHandler, config: Dict[str, Any]): + def _run_prepost_pipeline(self, model: ONNXModelHandler, config: Type[BasePassConfig]): from onnxruntime_extensions.tools.pre_post_processing import PrePostProcessor from olive.passes.onnx.pipeline.step_utils import create_named_value, parse_steps # Initialize pre/post step instance list pre_steps = [] - pre = config.get("pre") + pre = config.pre model_proto = model.load_model() if pre: steps = parse_steps(model_proto, pre) pre_steps = [self.create_step_from_config(step_name, step_param) for step_name, step_param in steps] post_steps = [] - post = config.get("post") + post = config.post if post: steps = parse_steps(model_proto, post) post_steps = [self.create_step_from_config(step_name, step_param) for step_name, step_param in steps] # Initialize PrePostProcessor instance - config_obj = self._config_class(**config) - input_param = config_obj.tool_command_args + input_param = config.tool_command_args assert isinstance(input_param, list) assert all(isinstance(i, PrePostProcessorInput) for i in input_param) inputs = [create_named_value(i.name, TENSOR_TYPE_MAP[i.data_type], i.shape) for i in input_param] - pipeline = PrePostProcessor(inputs, config_obj.target_opset) + pipeline = PrePostProcessor(inputs, config.target_opset) if pre_steps: pipeline.add_pre_processing(pre_steps) diff --git a/olive/passes/onnx/bnb_quantization.py b/olive/passes/onnx/bnb_quantization.py index ea818f180..f28c09feb 100644 --- a/olive/passes/onnx/bnb_quantization.py +++ 
b/olive/passes/onnx/bnb_quantization.py @@ -5,7 +5,7 @@ import logging import re from pathlib import Path -from typing import Any, Dict, List +from typing import Dict, List, Type import onnx from packaging import version @@ -15,7 +15,7 @@ from olive.model.utils import resolve_onnx_path from olive.passes import Pass from olive.passes.onnx.common import get_external_data_config, model_proto_to_olive_model -from olive.passes.pass_config import PassConfigParam +from olive.passes.pass_config import BasePassConfig, PassConfigParam logger = logging.getLogger(__name__) @@ -48,7 +48,7 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon return config def _run_for_config( - self, model: ONNXModelHandler, config: Dict[str, Any], output_model_path: str + self, model: ONNXModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> ONNXModelHandler: from onnxruntime import __version__ as OrtVersion @@ -60,8 +60,8 @@ def _run_for_config( output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name) - quant_type = config["quant_type"] - quantized_modules = config["quantized_modules"] + quant_type = config.quant_type + quantized_modules = config.quantized_modules if model.model_attributes: quantized_modules = quantized_modules or model.model_attributes.get("quantized_modules") @@ -87,7 +87,7 @@ def _run_for_config( onnx_model = model.load_model() # get nodes to exclude from quantization - nodes_to_exclude = config["nodes_to_exclude"] or [] + nodes_to_exclude = config.nodes_to_exclude or [] # find all MatMul nodes in the graph matmul_nodes = self._find_matmul_nodes(onnx_model.graph) diff --git a/olive/passes/onnx/common.py b/olive/passes/onnx/common.py index 936dc5f67..7459e9b20 100644 --- a/olive/passes/onnx/common.py +++ b/olive/passes/onnx/common.py @@ -5,13 +5,13 @@ import logging import re from pathlib import Path -from typing import Dict, Optional, Union +from typing import Any, Dict, Optional, Type, Union import onnx from olive.model import ONNXModelHandler from olive.passes.onnx.onnx_dag import OnnxDAG -from olive.passes.pass_config import PassConfigParam +from olive.passes.pass_config import BasePassConfig, PassConfigParam from olive.resource_path import LocalFile, LocalFolder logger = logging.getLogger(__name__) @@ -139,7 +139,7 @@ def model_proto_to_file( def model_proto_to_olive_model( model_proto: onnx.ModelProto, output_model_path: Union[str, Path], - external_data_config: dict, + external_data_config: Union[Dict[str, Any], Type[BasePassConfig]], check_model: bool = False, external_initializers_file_name: Optional[str] = None, constant_inputs_file_name: Optional[str] = None, @@ -163,6 +163,8 @@ def model_proto_to_olive_model( "size_threshold", "convert_attribute", ] + if not isinstance(external_data_config, dict): + external_data_config = external_data_config.dict() has_external_data = model_proto_to_file( model_proto, output_model_path, **{k: external_data_config[k] for k in config_keys if k in external_data_config} ) diff --git a/olive/passes/onnx/conversion.py b/olive/passes/onnx/conversion.py index a32dc88cf..368b66b2f 100644 --- a/olive/passes/onnx/conversion.py +++ b/olive/passes/onnx/conversion.py @@ -7,7 +7,7 @@ import tempfile from copy import deepcopy from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Type, Union import onnx import torch @@ -27,7 +27,7 @@ from olive.model.utils import resolve_onnx_path from olive.passes import Pass 
from olive.passes.onnx.common import get_external_data_config, model_proto_to_olive_model -from olive.passes.pass_config import PassConfigParam, get_user_script_data_config +from olive.passes.pass_config import BasePassConfig, PassConfigParam, get_user_script_data_config logger = logging.getLogger(__name__) @@ -111,12 +111,12 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon def _run_for_config( self, model: Union[DistributedHfModelHandler, HfModelHandler, PyTorchModelHandler], - config: Dict[str, Any], + config: Type[BasePassConfig], output_model_path: str, ) -> Union[DistributedOnnxModelHandler, ONNXModelHandler]: output_model = self._run_for_config_internal(model, config, output_model_path) - if isinstance(model, HfModelHandler) and config["save_metadata_for_token_generation"]: + if isinstance(model, HfModelHandler) and config.save_metadata_for_token_generation: # output_model can only be an ONNXModelHandler output_dir = output_model.change_model_path_to_dir() @@ -131,14 +131,14 @@ def _run_for_config( def _run_for_config_internal( self, model: Union[DistributedHfModelHandler, HfModelHandler, PyTorchModelHandler], - config: Dict[str, Any], + config: Type[BasePassConfig], output_model_path: str, ) -> Union[DistributedOnnxModelHandler, ONNXModelHandler]: # get the device to use for conversion # default to "cpu" for PyTorchModelHandler and "cuda" for DistributedHfModel - device = config["device"] or "cpu" + device = config.device or "cpu" # get the dtype to use for conversion - torch_dtype = resolve_torch_dtype(config["torch_dtype"]) if config["torch_dtype"] else None + torch_dtype = resolve_torch_dtype(config.torch_dtype) if config.torch_dtype else None if torch_dtype == torch.float16 and device == "cpu": logger.debug( "Converting model to float16 on CPU. This might fail for some models. If the conversion fails or model" @@ -147,7 +147,7 @@ def _run_for_config_internal( ) if isinstance(model, DistributedHfModelHandler): - if not config["device"]: + if not config.device: device = "cuda" return self._convert_distributed_model_on_device(model, config, output_model_path, device, torch_dtype) @@ -159,7 +159,7 @@ def _export_pytorch_model( pytorch_model: torch.nn.Module, dummy_inputs, io_config, - config: Dict[str, Any], + config: Type[BasePassConfig], device: Union[str, torch.device], torch_dtype: Optional[torch.dtype] = None, tempdir: Optional[Union[Path, str]] = None, @@ -191,7 +191,7 @@ def _export_pytorch_model( if isinstance(pytorch_model, torch.jit.RecursiveScriptModule): pytorch_model = TraceModelWrapper(pytorch_model) - pytorch_model = make_export_compatible_peft(pytorch_model, merge_weights=config["merge_adapter_weights"]) + pytorch_model = make_export_compatible_peft(pytorch_model, merge_weights=config.merge_adapter_weights) pytorch_model = make_export_compatible_quant(pytorch_model) # cast to dtype, want all modules including lora layers and quant linears in the same dtype if torch_dtype: @@ -201,12 +201,12 @@ def _export_pytorch_model( assert io_config is not None, "Cannot get io_config for the model." 
io_config = validate_config(io_config, IoConfig) # If dynamic is False, set dynamic_axes and dynamic_shapes to None - if not config["dynamic"]: + if not config.dynamic: io_config.dynamic_axes = None io_config.dynamic_shapes = None onnx_model = None - if config["use_dynamo_exporter"]: + if config.use_dynamo_exporter: # Take the "release" version so that dev builds like 2.5.0dev1234 are treated as 2.5.0 torch_version = version.parse(torch.__version__).release # The "legacy dynamo" is the torch.onnx_dynamo_export API @@ -256,7 +256,7 @@ def _export_pytorch_model( dummy_inputs, tmp_model_path, # needed for fallback=True kwargs=dummy_kwargs, - opset_version=config["target_opset"], + opset_version=config.target_opset, input_names=io_config.input_names, output_names=io_config.output_names, dynamic_axes=io_config.dynamic_axes, @@ -280,7 +280,7 @@ def _export_pytorch_model( dummy_inputs, tmp_model_path, export_params=True, - opset_version=config["target_opset"], + opset_version=config.target_opset, input_names=io_config.input_names, output_names=io_config.output_names, dynamic_axes=io_config.dynamic_axes, @@ -347,7 +347,7 @@ def _prepare_hf_model( new_load_kwargs["torch_dtype"] = torch_dtype model_attributes["torch_dtype"] = str(torch_dtype).replace("torch.", "") - if load_kwargs.quantization_method == "bitsandbytes" and load_kwargs.quantization_config["load_in_4bit"]: + if load_kwargs.quantization_method == "bitsandbytes" and load_kwargs.quantization_config.load_in_4bit: logger.debug( "Bitsandbytes 4bit quantization is not supported for conversion. The quantization config is removed" " from the load kwargs. Use OnnxBnb4Quantization pass after conversion to quantize the" @@ -378,7 +378,7 @@ def _prepare_hf_model( def _convert_model_on_device( self, model: Union[HfModelHandler, PyTorchModelHandler], - config: Dict[str, Any], + config: Type[BasePassConfig], output_model_path: str, device: str, torch_dtype: Optional[torch.dtype] = None, @@ -417,15 +417,15 @@ def _convert_model_on_device( @staticmethod def _get_dummy_inputs( - model: Union[HfModelHandler, PyTorchModelHandler], config: Dict[str, Any] + model: Union[HfModelHandler, PyTorchModelHandler], config: Type[BasePassConfig] ) -> Union[Dict, Tuple]: """Get dummy inputs for the model.""" return model.get_dummy_inputs( filter_hook=( - model.merge_kv_cache_hook if config["use_dynamo_exporter"] else model.merge_kv_cache_to_tuple_hook + model.merge_kv_cache_hook if config.use_dynamo_exporter else model.merge_kv_cache_to_tuple_hook ), filter_hook_kwargs={ - "past_kv_names": config["past_key_value_name"], + "past_kv_names": config.past_key_value_name, }, ) @@ -466,7 +466,7 @@ def _export_ranked_model(params): olive_pytorch_model = input_model.load_model(local_rank) dummy_inputs = OnnxConversion._get_dummy_inputs(olive_pytorch_model, pass_config) - io_config = None if pass_config["use_dynamo_exporter"] else olive_pytorch_model.io_config + io_config = None if pass_config.use_dynamo_exporter else olive_pytorch_model.io_config pytorch_model = olive_pytorch_model.prepare_session(rank=local_rank) ranked_onnx_modelproto = OnnxConversion._export_pytorch_model( @@ -489,7 +489,7 @@ def _export_ranked_model(params): def _convert_distributed_model_on_device( self, model: DistributedHfModelHandler, - config: Dict[str, Any], + config: Type[BasePassConfig], output_model_path: str, device: str, torch_dtype: Optional[torch.dtype] = None, @@ -514,7 +514,7 @@ def _convert_distributed_model_on_device( for rank in range(world_size) ] - max_parallel_jobs = min(world_size, 
config["parallel_jobs"] or multiprocessing.cpu_count()) + max_parallel_jobs = min(world_size, config.parallel_jobs or multiprocessing.cpu_count()) if max_parallel_jobs <= 1: results = [OnnxConversion._export_ranked_model(_) for _ in params] else: @@ -549,18 +549,18 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon return config def _run_for_config( - self, model: ONNXModelHandler, config: Dict[str, Any], output_model_path: str + self, model: ONNXModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> ONNXModelHandler: output_model_path = resolve_onnx_path(output_model_path) # since external data is saved in a separate file, we need to load the model to get the opset version model_proto = onnx.load(model.model_path, load_external_data=False) model_opset_version = model_proto.opset_import[0].version - if model_opset_version == config["target_opset"]: - logger.info("Model is already in target opset version %s.", config["target_opset"]) + if model_opset_version == config.target_opset: + logger.info("Model is already in target opset version %s.", config.target_opset) return model - converted_model_proto = onnx.version_converter.convert_version(model_proto, config["target_opset"]) + converted_model_proto = onnx.version_converter.convert_version(model_proto, config.target_opset) # copy the external data of original model to the new model dst_init_map = {init.name: init for init in converted_model_proto.graph.initializer} for src_init in model_proto.graph.initializer: diff --git a/olive/passes/onnx/dynamic_to_fixed_shape.py b/olive/passes/onnx/dynamic_to_fixed_shape.py index 2c343c3f9..9b74ee7fa 100644 --- a/olive/passes/onnx/dynamic_to_fixed_shape.py +++ b/olive/passes/onnx/dynamic_to_fixed_shape.py @@ -4,7 +4,7 @@ # -------------------------------------------------------------------------- import logging -from typing import TYPE_CHECKING, Any, Callable, Dict, List +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Type from olive.common.pydantic_v1 import root_validator from olive.hardware import AcceleratorSpec @@ -12,7 +12,7 @@ from olive.model.utils import resolve_onnx_path from olive.passes.olive_pass import Pass from olive.passes.onnx.common import get_external_data_config, model_proto_to_olive_model -from olive.passes.pass_config import PassConfigParam +from olive.passes.pass_config import BasePassConfig, PassConfigParam if TYPE_CHECKING: from onnx import ModelProto @@ -66,7 +66,7 @@ def _validators(cls) -> Dict[str, Callable[..., Any]]: def _run_for_config( self, model: ONNXModelHandler, - config: Dict[str, Any], + config: Type[BasePassConfig], output_model_path: str, ) -> ONNXModelHandler: from onnxruntime.tools.onnx_model_utils import make_dim_param_fixed, make_input_shape_fixed @@ -74,11 +74,11 @@ def _run_for_config( onnx_model = model.load_model() output_model_path = resolve_onnx_path(output_model_path) - if config["dim_param"]: - for param, value in zip(config["dim_param"], config["dim_value"]): + if config.dim_param: + for param, value in zip(config.dim_param, config.dim_value): make_dim_param_fixed(onnx_model.graph, param, value) - elif config["input_name"]: - for name, shape in zip(config["input_name"], config["input_shape"]): + elif config.input_name: + for name, shape in zip(config.input_name, config.input_shape): make_input_shape_fixed(onnx_model.graph, name, shape) # update the output shapes to make them fixed # onnxruntime.tools.onnx_model_utils.fix_output_shapes cannot handle models > 2GB diff --git 
a/olive/passes/onnx/extract_adapters.py b/olive/passes/onnx/extract_adapters.py index f290a8371..58fd2f7f4 100644 --- a/olive/passes/onnx/extract_adapters.py +++ b/olive/passes/onnx/extract_adapters.py @@ -6,7 +6,7 @@ import re from copy import deepcopy from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Dict, Type import numpy as np import onnx @@ -18,7 +18,7 @@ from olive.passes import Pass from olive.passes.onnx.common import LORA_NAME_PATTERNS, get_external_data_config, model_proto_to_olive_model from olive.passes.onnx.onnx_dag import OnnxDAG -from olive.passes.pass_config import PassConfigParam +from olive.passes.pass_config import BasePassConfig, PassConfigParam if TYPE_CHECKING: from numpy.typing import NDArray @@ -74,7 +74,7 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon return config def _run_for_config( - self, model: ONNXModelHandler, config: Dict[str, Any], output_model_path: str + self, model: ONNXModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> ONNXModelHandler: output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name) @@ -181,8 +181,8 @@ def _run_for_config( for node_name in nodes_to_remove: dag.remove_node(node_name) - if config["make_inputs"]: - if quant_modules and config["dynamic_lora_r"]: + if config.make_inputs: + if quant_modules and config.dynamic_lora_r: # MatMulNBits has static K,N dimensions which are set as attributes # No use case for DequantizeLinear with dynamic lora_r logger.info("Quantized modules do not support dynamic_lora_r. Ignoring.") @@ -196,15 +196,15 @@ def _run_for_config( dag.update() # save the weights - weights_path = save_weights(weights, Path(output_model_path).parent / "adapter_weights", config["save_format"]) + weights_path = save_weights(weights, Path(output_model_path).parent / "adapter_weights", config.save_format) # save the model output_model = model_proto_to_olive_model( dag.model, output_model_path, config, - external_initializers_file_name=weights_path.name if not config["make_inputs"] else None, - constant_inputs_file_name=weights_path.name if config["make_inputs"] else None, + external_initializers_file_name=weights_path.name if not config.make_inputs else None, + constant_inputs_file_name=weights_path.name if config.make_inputs else None, ) output_model.model_attributes = deepcopy(model.model_attributes) or {} # add adapter weights to the model attributes @@ -214,7 +214,7 @@ def _run_for_config( additional_files.append(str(weights_path)) # save information about the weights in the model attributes weights_info = {name: [list(value.shape), str(value.dtype)] for name, value in weights.items()} - if not config["make_inputs"]: + if not config.make_inputs: output_model.model_attributes["external_initializers"] = weights_info else: output_model.model_attributes["constant_inputs"] = weights_info @@ -282,14 +282,16 @@ def _externalize_initializer(cls, dag: OnnxDAG, weights: Dict[str, "NDArray"], o dag.add_initializer(new_initializer, dag.get_io(old_name).graph_idx) @classmethod - def _make_dynamic_optional(cls, dag: OnnxDAG, weights: Dict[str, "NDArray"], name: str, config: Dict[str, Any]): + def _make_dynamic_optional( + cls, dag: OnnxDAG, weights: Dict[str, "NDArray"], name: str, config: Type[BasePassConfig] + ): """Make the input dynamic and optional.""" if "quant" in name: # dynamic shape not supported for quantized modules # cannot have empty tensor as default values, so create default 
initializers of the same shape # scales must be zero to make the dequantized weights zero # quant weight and zeros points also made zero to be clean and consistent - if config["optional_inputs"]: + if config.optional_inputs: initializer_proto = onnx.numpy_helper.from_array(np.zeros_like(weights[name]), name) dag.add_initializer(initializer_proto, 0, keep_input=True) @@ -299,11 +301,11 @@ def _make_dynamic_optional(cls, dag: OnnxDAG, weights: Dict[str, "NDArray"], nam dim_idx = 1 if "lora_A" in name else 0 # make the input dynamic - if config["dynamic_lora_r"]: + if config.dynamic_lora_r: dag.make_input_dim_dynamic(name, dim_idx, "lora_r") # create default initializer with the lora_r dimension set to 0 - if config["optional_inputs"]: + if config.optional_inputs: shape = list(weights[name].shape) shape[dim_idx] = 0 initializer_proto = onnx.numpy_helper.from_array(np.zeros(shape, dtype=weights[name].dtype), name) diff --git a/olive/passes/onnx/float16_conversion.py b/olive/passes/onnx/float16_conversion.py index ff0134fed..4953d2297 100644 --- a/olive/passes/onnx/float16_conversion.py +++ b/olive/passes/onnx/float16_conversion.py @@ -3,14 +3,14 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- from pathlib import Path -from typing import Any, Dict, List +from typing import Dict, List, Type from olive.hardware.accelerator import AcceleratorSpec from olive.model import ONNXModelHandler from olive.model.utils import resolve_onnx_path from olive.passes import Pass from olive.passes.onnx.common import get_external_data_config, model_proto_to_olive_model -from olive.passes.pass_config import PassConfigParam +from olive.passes.pass_config import BasePassConfig, PassConfigParam class OnnxFloatToFloat16(Pass): @@ -47,7 +47,7 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon return config def _run_for_config( - self, model: ONNXModelHandler, config: Dict[str, Any], output_model_path: str + self, model: ONNXModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> ONNXModelHandler: from onnxruntime.transformers.onnx_model import OnnxModel @@ -56,9 +56,10 @@ def _run_for_config( # using the float16 converter from onnxruntime since it is regularly updated # and can handle large models (>2GB) as well as ort contrib ops ort_onnx_model = OnnxModel(model.load_model()) + config_dict = config.dict() ort_onnx_model.convert_float_to_float16( **{ - key: config[key] + key: config_dict[key] for key in [ "min_positive_val", "max_finite_val", diff --git a/olive/passes/onnx/graph_surgeries.py b/olive/passes/onnx/graph_surgeries.py index 2ad31ee70..2f93fc6fd 100644 --- a/olive/passes/onnx/graph_surgeries.py +++ b/olive/passes/onnx/graph_surgeries.py @@ -21,7 +21,7 @@ from olive.passes import Pass from olive.passes.onnx.common import get_external_data_config, model_proto_to_olive_model from olive.passes.onnx.onnx_dag import OnnxDAG -from olive.passes.pass_config import PassConfigParam +from olive.passes.pass_config import BasePassConfig, PassConfigParam logger = logging.getLogger(__name__) @@ -693,11 +693,11 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon } def _run_for_config( - self, model: ONNXModelHandler, config: Dict[str, Any], output_model_path: str + self, model: ONNXModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> ONNXModelHandler: output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name) - surgeries = 
config["surgeries"] + surgeries = config.surgeries onnx_model = model.load_model() for surgery in surgeries: logger.info("Applying surgery: %s", surgery) diff --git a/olive/passes/onnx/inc_quantization.py b/olive/passes/onnx/inc_quantization.py index d697ad2ab..97b251420 100644 --- a/olive/passes/onnx/inc_quantization.py +++ b/olive/passes/onnx/inc_quantization.py @@ -7,7 +7,7 @@ import tempfile from copy import deepcopy from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Type, Union from packaging import version @@ -23,7 +23,7 @@ from olive.model.utils import resolve_onnx_path from olive.passes import Pass from olive.passes.onnx.common import get_external_data_config, model_has_adapters, model_proto_to_olive_model -from olive.passes.pass_config import PassConfigParam +from olive.passes.pass_config import BasePassConfig, PassConfigParam from olive.search.search_parameter import Boolean, Categorical, Conditional logger = logging.getLogger(__name__) @@ -445,7 +445,7 @@ def _set_woq_config(self, run_config): return {"bits": bits, "group_size": group_size, "scheme": scheme, "algorithm": algo} def _run_for_config( - self, model: ONNXModelHandler, config: Dict[str, Any], output_model_path: str + self, model: ONNXModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> ONNXModelHandler: if model_has_adapters(model.model_path): logger.info("Model has adapters which should not be quantized. Returning the model without quantization.") @@ -467,18 +467,17 @@ def _run_for_config( import neural_compressor assert not ( - config["approach"] == "weight_only" - and version.parse(neural_compressor.__version__) < version.parse("2.3.0") + config.approach == "weight_only" and version.parse(neural_compressor.__version__) < version.parse("2.3.0") ), "Require neural-compressor >= 2.3.0 to support weight only quantization." # start with a copy of the config - run_config = deepcopy(config) + run_config = config.dict() require_dataloader = run_config["approach"] == "static" or ( run_config["approach"] == "weight_only" and run_config["weight_only_config"]["algorithm"].upper() in {"GPTQ", "AWQ"} ) if require_dataloader: - assert config["data_config"], "data_config is required for {} quantization.".format(run_config["approach"]) + assert config.data_config, "data_config is required for {} quantization.".format(run_config["approach"]) output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name) @@ -515,7 +514,7 @@ def _run_for_config( inc_calib_dataloader = None if require_dataloader: - data_config = validate_config(config["data_config"], DataConfig) + data_config = validate_config(config.data_config, DataConfig) # inc quantization's calibration dataloader requires: # 1. input: (input, label) # 2. the dataloader should have the attributes of "__iter__" and "batch_size" diff --git a/olive/passes/onnx/insert_beam_search.py b/olive/passes/onnx/insert_beam_search.py index 4466e030d..5ec34580f 100644 --- a/olive/passes/onnx/insert_beam_search.py +++ b/olive/passes/onnx/insert_beam_search.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. 
# -------------------------------------------------------------------------- import logging -from typing import Any, Dict +from typing import Dict, Type from onnx import ModelProto, TensorProto, helper from packaging import version @@ -13,7 +13,7 @@ from olive.model.utils import resolve_onnx_path from olive.passes import Pass from olive.passes.onnx.common import get_external_data_config, model_proto_to_olive_model -from olive.passes.pass_config import PassConfigParam +from olive.passes.pass_config import BasePassConfig, PassConfigParam logger = logging.getLogger(__name__) @@ -90,7 +90,13 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon return config def chain_model( - self, model_A: ModelProto, model_A_name: str, model_B: ModelProto, model_B_name: str, model_config, options + self, + model_A: ModelProto, + model_A_name: str, + model_B: ModelProto, + model_B_name: str, + model_config, + options: Type[BasePassConfig], ): from onnxruntime import __version__ as OrtVersion from onnxruntime.transformers.convert_generation import get_shared_initializers @@ -105,24 +111,24 @@ def chain_model( model_B.graph.name = f"{model_B_name} subgraph" beam_inputs = [ - "input_features_fp16" if options["fp16"] else "input_features", + "input_features_fp16" if options.fp16 else "input_features", "max_length", "min_length", "num_beams", "num_return_sequences", - "length_penalty_fp16" if options["fp16"] else "length_penalty", - "repetition_penalty_fp16" if options["fp16"] else "repetition_penalty", - "vocab_mask" if (version_1_16 and options["use_vocab_mask"]) else "", - "prefix_vocab_mask" if (version_1_16 and options["use_prefix_vocab_mask"]) else "", + "length_penalty_fp16" if options.fp16 else "length_penalty", + "repetition_penalty_fp16" if options.fp16 else "repetition_penalty", + "vocab_mask" if (version_1_16 and options.use_vocab_mask) else "", + "prefix_vocab_mask" if (version_1_16 and options.use_prefix_vocab_mask) else "", "" if version_1_16 else "attention_mask", ] if version_1_16: - beam_inputs.extend(["decoder_input_ids" if options["use_forced_decoder_ids"] else ""]) - beam_inputs.extend(["logits_processor" if options["use_logits_processor"] else ""]) + beam_inputs.extend(["decoder_input_ids" if options.use_forced_decoder_ids else ""]) + beam_inputs.extend(["logits_processor" if options.use_logits_processor else ""]) if version_1_17_1: beam_inputs.extend(["", ""]) beam_inputs.extend( - [("temperature_fp16" if options["fp16"] else "temperature") if options["use_temperature"] else ""] + [("temperature_fp16" if options.fp16 else "temperature") if options.use_temperature else ""] ) # remove empty string from the end of beam_inputs # otherwise, the model gives error when the last input is empty @@ -131,7 +137,7 @@ def chain_model( # Cast input features to fp16 if required graph_nodes = [] - if options["fp16"]: + if options.fp16: input_features_cast_node = helper.make_node( "Cast", inputs=["input_features"], @@ -155,7 +161,7 @@ def chain_model( ) graph_nodes.extend([input_features_cast_node, len_pen_cast_node, rep_pen_cast_node]) - if version_1_17_1 and options["use_temperature"]: + if version_1_17_1 and options.use_temperature: temperature_cast_node = helper.make_node( "Cast", inputs=["temperature"], @@ -172,7 +178,7 @@ def chain_model( helper.make_attribute("eos_token_id", model_config["eos_token_id"]), helper.make_attribute("pad_token_id", model_config["pad_token_id"]), helper.make_attribute("decoder_start_token_id", model_config["decoder_start_token_id"]), - 
helper.make_attribute("no_repeat_ngram_size", options["no_repeat_ngram_size"]), + helper.make_attribute("no_repeat_ngram_size", options.no_repeat_ngram_size), helper.make_attribute("early_stopping", True), helper.make_attribute("model_type", 2), ] @@ -236,29 +242,29 @@ def chain_model( ) graph_inputs.append(attention_mask) else: - if options["use_vocab_mask"]: + if options.use_vocab_mask: vocab_mask = helper.make_tensor_value_info( "vocab_mask", TensorProto.INT32, [model_config["vocab_size"]] ) graph_inputs.append(vocab_mask) - if options["use_prefix_vocab_mask"]: + if options.use_prefix_vocab_mask: prefix_vocab_mask = helper.make_tensor_value_info( "prefix_vocab_mask", TensorProto.INT32, ["batch_size", model_config["vocab_size"]] ) graph_inputs.append(prefix_vocab_mask) - if options["use_forced_decoder_ids"]: + if options.use_forced_decoder_ids: decoder_input_ids = helper.make_tensor_value_info( "decoder_input_ids", TensorProto.INT32, ["batch_size", "initial_sequence_length"] ) graph_inputs.append(decoder_input_ids) - if options["use_logits_processor"]: + if options.use_logits_processor: logits_processor = helper.make_tensor_value_info("logits_processor", TensorProto.INT32, [1]) graph_inputs.append(logits_processor) - if version_1_17_1 and options["use_temperature"]: + if version_1_17_1 and options.use_temperature: temperature = helper.make_tensor_value_info("temperature", TensorProto.FLOAT, [1]) graph_inputs.append(temperature) @@ -269,7 +275,7 @@ def chain_model( graph_outputs = [sequences] # Replace MultiHeadAttention with DecoderMaskedMultiHeadAttention for CUDA EP inference - if options["use_gpu"] and version_1_16: + if options.use_gpu and version_1_16: from onnxruntime.transformers.convert_generation import ( update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha as update_decoder_with_ort, ) @@ -316,7 +322,7 @@ def add_attention_mask(self, model: ModelProto): model.graph.input.insert(1, mask) def _run_for_config( - self, model: OliveModelHandler, config: Dict[str, Any], output_model_path: str + self, model: OliveModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> ONNXModelHandler: from onnxruntime import __version__ as OrtVersion from onnxruntime.transformers import onnx_model as ort_onnx_model @@ -334,7 +340,7 @@ def _run_for_config( # version check version_1_16 = version.parse(OrtVersion) >= version.parse("1.16.0") - if not version_1_16 and config["use_forced_decoder_ids"]: + if not version_1_16 and config.use_forced_decoder_ids: logger.warning( "use_forced_decoder_ids is not supported in ONNX Runtime versions < 1.16.0. Will be ignored." 
) diff --git a/olive/passes/onnx/io_datatype_converter.py b/olive/passes/onnx/io_datatype_converter.py index 1c022dbf9..50f2f2ec7 100644 --- a/olive/passes/onnx/io_datatype_converter.py +++ b/olive/passes/onnx/io_datatype_converter.py @@ -6,7 +6,7 @@ import re from collections import defaultdict from pathlib import Path -from typing import Any, Dict, Optional +from typing import Dict, Optional, Type import onnx @@ -15,7 +15,7 @@ from olive.model.utils import resolve_onnx_path from olive.passes import Pass from olive.passes.onnx.common import get_external_data_config, model_proto_to_olive_model -from olive.passes.pass_config import PassConfigParam +from olive.passes.pass_config import BasePassConfig, PassConfigParam logger = logging.getLogger(__name__) @@ -139,7 +139,7 @@ def _verify_elem_type(self, elem_type): ) def _run_for_config( - self, model: ONNXModelHandler, config: Dict[str, Any], output_model_path: str + self, model: ONNXModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> ONNXModelHandler: from onnxruntime.transformers.onnx_model import OnnxModel @@ -153,11 +153,11 @@ def _run_for_config( self.create_io_mapping(ort_onnx_model.model.graph, i_map, o_map) pat = None - if config["name_pattern"]: - pat = re.compile(config["name_pattern"]) + if config.name_pattern: + pat = re.compile(config.name_pattern) - source_dtype = config["source_dtype"] - target_dtype = config["target_dtype"] + source_dtype = config.source_dtype + target_dtype = config.target_dtype self._verify_elem_type(source_dtype) self._verify_elem_type(target_dtype) diff --git a/olive/passes/onnx/mixed_precision.py b/olive/passes/onnx/mixed_precision.py index 96210bdc8..82d3711df 100644 --- a/olive/passes/onnx/mixed_precision.py +++ b/olive/passes/onnx/mixed_precision.py @@ -4,7 +4,7 @@ # -------------------------------------------------------------------------- import logging from pathlib import Path -from typing import Any, Dict, List +from typing import Dict, List, Type from onnx import ValueInfoProto @@ -13,7 +13,7 @@ from olive.model.utils import resolve_onnx_path from olive.passes import Pass from olive.passes.onnx.common import get_external_data_config, model_proto_to_olive_model -from olive.passes.pass_config import PassConfigParam +from olive.passes.pass_config import BasePassConfig, PassConfigParam logger = logging.getLogger(__name__) @@ -37,7 +37,7 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon return config def _run_for_config( - self, model: ONNXModelHandler, config: Dict[str, Any], output_model_path: str + self, model: ONNXModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> ONNXModelHandler: """Convert model to mixed precision. @@ -46,7 +46,7 @@ def _run_for_config( """ from onnxruntime.transformers.float16 import float_to_float16_max_diff - op_block_list = config["op_block_list"] + op_block_list = config.op_block_list op_full_set = {node.op_type for node in model.nodes()} fp32_op_set = set(op_block_list) fp16_op_set = op_full_set.difference(fp32_op_set) @@ -75,7 +75,7 @@ def _run_for_config( # we can deduce that the weights are stored in float16 precision. max_diff = float_to_float16_max_diff(initializer) logger.debug("max diff of converting weights in last MatMul node %s: %s", node.name, max_diff) - is_weight_fp16_precision = max_diff < config["atol"] + is_weight_fp16_precision = max_diff < config.atol else: logger.warning("Failed to find MatMul node for logits. 
Found %s of node %s", node.op_type, node.name) @@ -99,8 +99,7 @@ def _run_for_config( model=model.load_model(), use_symbolic_shape_infer=True, **parameters ) output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name) - config = self._config_class(**config) - return model_proto_to_olive_model(fp16_model, output_model_path, config.dict()) + return model_proto_to_olive_model(fp16_model, output_model_path, config) def _convert_float_to_float16(self, model, use_symbolic_shape_infer=True, **kwargs): """Convert a model to half (default) or mixed precision. diff --git a/olive/passes/onnx/mixed_precision_overrides.py b/olive/passes/onnx/mixed_precision_overrides.py index fb458b95f..3976cc26e 100644 --- a/olive/passes/onnx/mixed_precision_overrides.py +++ b/olive/passes/onnx/mixed_precision_overrides.py @@ -6,14 +6,14 @@ import json from logging import getLogger from pathlib import Path -from typing import Any, Dict, Union +from typing import Dict, Type, Union from olive.hardware import AcceleratorSpec from olive.model import ONNXModelHandler from olive.model.utils import resolve_onnx_path from olive.passes.olive_pass import Pass from olive.passes.onnx.common import get_external_data_config, model_proto_to_file -from olive.passes.pass_config import PassConfigParam +from olive.passes.pass_config import BasePassConfig, PassConfigParam logger = getLogger(__name__) @@ -54,7 +54,7 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon def _run_for_config( self, model: ONNXModelHandler, - config: Dict[str, Any], + config: Type[BasePassConfig], output_model_path: str, ) -> ONNXModelHandler: """Run for config. @@ -84,10 +84,10 @@ def _run_for_config( from onnxruntime.quantization.onnx_model import ONNXModel overrides_content = {} - if isinstance(config["overrides_config"], dict): - overrides_content = config["overrides_config"] - elif isinstance(config["overrides_config"], str): - overrides_config_path = Path(config["overrides_config"]) + if isinstance(config.overrides_config, dict): + overrides_content = config.overrides_config + elif isinstance(config.overrides_config, str): + overrides_config_path = Path(config.overrides_config) with overrides_config_path.open() as f: overrides_content = json.load(f) else: @@ -137,7 +137,7 @@ def handle_conflict(tensor_name, node) -> bool: # If certain initializer tensor makes conflict, we do not convert it, but rather add it to conflict_data # which we analyze later - element_wise_binary_ops = config["element_wise_binary_ops"] or ["Add", "Sub", "Mul", "Div"] + element_wise_binary_ops = config.element_wise_binary_ops or ["Add", "Sub", "Mul", "Div"] for node in onnx_model.graph.node: if node.op_type in element_wise_binary_ops: @@ -195,11 +195,11 @@ def handle_conflict(tensor_name, node) -> bool: model_proto_to_file( onnx_model.model, output_model_path, - save_as_external_data=config["save_as_external_data"], - all_tensors_to_one_file=config["all_tensors_to_one_file"], - external_data_name=config["external_data_name"], - size_threshold=config["size_threshold"], - convert_attribute=config["convert_attribute"], + save_as_external_data=config.save_as_external_data, + all_tensors_to_one_file=config.all_tensors_to_one_file, + external_data_name=config.external_data_name, + size_threshold=config.size_threshold, + convert_attribute=config.convert_attribute, ) overrides_jsonable = { diff --git a/olive/passes/onnx/mnb_to_qdq.py b/olive/passes/onnx/mnb_to_qdq.py index bafafbc8d..4d3f9cf86 100644 --- 
a/olive/passes/onnx/mnb_to_qdq.py +++ b/olive/passes/onnx/mnb_to_qdq.py @@ -5,7 +5,7 @@ import logging import math from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Dict, Type import numpy as np import onnx @@ -16,7 +16,7 @@ from olive.passes import Pass from olive.passes.onnx.common import get_external_data_config, model_proto_to_olive_model from olive.passes.onnx.onnx_dag import OnnxDAG -from olive.passes.pass_config import PassConfigParam +from olive.passes.pass_config import BasePassConfig, PassConfigParam if TYPE_CHECKING: from numpy.typing import NDArray @@ -68,7 +68,7 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon } def _run_for_config( - self, model: ONNXModelHandler, config: Dict[str, Any], output_model_path: str + self, model: ONNXModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> ONNXModelHandler: output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name) @@ -78,9 +78,9 @@ def _run_for_config( dag.remove_identity_nodes() # if matmulnbits zero point is the following, then the zero point is not needed in the DQ node - default_mnb_zp = 8 if config["use_int4"] else 0 - int_np_dtype = np.int8 if config["use_int4"] else np.uint8 - int_elem_type = onnx.TensorProto.INT4 if config["use_int4"] else onnx.TensorProto.UINT4 + default_mnb_zp = 8 if config.use_int4 else 0 + int_np_dtype = np.int8 if config.use_int4 else np.uint8 + int_elem_type = onnx.TensorProto.INT4 if config.use_int4 else onnx.TensorProto.UINT4 # set of nodes to exclude from the conversion nodes_to_exclude = set(config["nodes_to_exclude"] or []) @@ -157,15 +157,15 @@ def _run_for_config( qi = qi.flatten() # skip if is a no-op zero point - if not config["add_zero_point"] and new_qi_name.endswith(".qzeros") and np.all(qi == default_mnb_zp): + if not config.add_zero_point and new_qi_name.endswith(".qzeros") and np.all(qi == default_mnb_zp): continue - if not config["use_transpose_op"]: + if not config.use_transpose_op: # becomes K X N qi = qi.T if qi.dtype == np.uint8: - if config["use_int4"]: + if config.use_int4: # no worries about making signed since the values only use 4 bits qi = qi.astype(np.int8) # subtract 8 to make it signed @@ -189,11 +189,9 @@ def _run_for_config( dq_inputs.append(new_qi_name) # DQ default zp is 0 but MatMulNBits is 8, so we need to add a zero tensor with all 8s # no need to add for int4 if add_zero_point is False - if len(dq_inputs) == 2 and (config["add_zero_point"] or not config["use_int4"]): + if len(dq_inputs) == 2 and (config.add_zero_point or not config.use_int4): zp_name = f"{dq_name}.qzeros" - zp_shape = ( - [N] if is_per_axis else ([N, num_k_blocks] if config["use_transpose_op"] else [num_k_blocks, N]) - ) + zp_shape = [N] if is_per_axis else ([N, num_k_blocks] if config.use_transpose_op else [num_k_blocks, N]) zp_tensor = onnx.helper.make_tensor( zp_name, int_elem_type, @@ -227,16 +225,16 @@ def _run_for_config( block_size=None if is_per_axis else block_size, # for some reason block_wise and per-axis appear to use swapped axis # flip the axis if it is per-axis - axis=(1 if config["use_transpose_op"] else 0) ^ (1 if is_per_axis else 0), + axis=(1 if config.use_transpose_op else 0) ^ (1 if is_per_axis else 0), ) ) new_value_infos.append( onnx.helper.make_tensor_value_info( - dq_output, float_elem_type, shape=[N, K] if config["use_transpose_op"] else [K, N] + dq_output, float_elem_type, shape=[N, K] if config.use_transpose_op else [K, N] ) ) - if 
config["use_transpose_op"]: + if config.use_transpose_op: # Transpose transpose_name = self._get_new_node_name(dag, node_name, "Transpose") transpose_output = f"{transpose_name}/output_0" diff --git a/olive/passes/onnx/model_builder.py b/olive/passes/onnx/model_builder.py index 7e7b9cf04..5345cd469 100644 --- a/olive/passes/onnx/model_builder.py +++ b/olive/passes/onnx/model_builder.py @@ -8,7 +8,7 @@ import json import logging from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Type, Union import onnx import transformers @@ -19,6 +19,7 @@ from olive.model.utils import resolve_onnx_path from olive.passes import Pass from olive.passes.olive_pass import PassConfigParam +from olive.passes.pass_config import BasePassConfig logger = logging.getLogger(__name__) @@ -95,16 +96,12 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon @classmethod def validate_config( cls, - config: Dict[str, Any], + config: Type[BasePassConfig], accelerator_spec: AcceleratorSpec, - disable_search: Optional[bool] = False, ) -> bool: - if not super().validate_config(config, accelerator_spec, disable_search): + if not super().validate_config(config, accelerator_spec): return False - config_cls, _ = cls.get_config_class(accelerator_spec, disable_search) - config = config_cls(**config) - # if device is GPU, but user choose CPU EP, the is_cpu should be True if (config.precision == ModelBuilder.Precision.FP16) and not ( accelerator_spec.accelerator_type == Device.GPU @@ -124,13 +121,13 @@ def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool: def _run_for_config( self, model: Union[HfModelHandler, ONNXModelHandler], - config: Dict[str, Any], + config: Type[BasePassConfig], output_model_path: str, ) -> ONNXModelHandler: from onnxruntime_genai.models.builder import create_model - precision = config["precision"] - metadata_only = config["metadata_only"] + precision = config.precision + metadata_only = config.metadata_only if metadata_only: if not isinstance(model, ONNXModelHandler): @@ -167,12 +164,12 @@ def _run_for_config( extra_args["adapter_path"] = model.adapter_path if config.get("int4_block_size"): - if int(config["int4_block_size"]) not in [16, 32, 64, 128, 256]: + if int(config.int4_block_size) not in [16, 32, 64, 128, 256]: raise ValueError("Invalid int4_block_size. 
Accepted values: 16/32/64/128/256.") - extra_args["int4_block_size"] = config["int4_block_size"] + extra_args["int4_block_size"] = config.int4_block_size if config.get("int4_accuracy_level"): - extra_args["int4_accuracy_level"] = config["int4_accuracy_level"].value + extra_args["int4_accuracy_level"] = config.int4_accuracy_level.value # args that are only checked for presence, not value for arg in ["exclude_embeds", "exclude_lm_head"]: @@ -225,7 +222,7 @@ def _run_for_config( with open(genai_config_filepath) as istrm: genai_config = json.load(istrm) - genai_config["search"] = {**(genai_config.get("search") or {}), **(config.get("search") or {})} + genai_config["search"] = {**(genai_config.get("search") or {}), **(config.search or {})} with open(genai_config_filepath, "w") as ostrm: json.dump(genai_config, ostrm, indent=4) diff --git a/olive/passes/onnx/moe_experts_distributor.py b/olive/passes/onnx/moe_experts_distributor.py index 2a996a21e..483cbcdac 100644 --- a/olive/passes/onnx/moe_experts_distributor.py +++ b/olive/passes/onnx/moe_experts_distributor.py @@ -10,7 +10,7 @@ import pprint from abc import abstractmethod from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Tuple, Type, Union import numpy as np import onnx @@ -22,7 +22,7 @@ from olive.model import DistributedOnnxModelHandler, ONNXModelHandler from olive.passes import Pass from olive.passes.onnx.common import get_external_data_config -from olive.passes.pass_config import PassConfigParam +from olive.passes.pass_config import BasePassConfig, PassConfigParam if TYPE_CHECKING: from onnxruntime.transformers.onnx_model import OnnxModel @@ -394,7 +394,7 @@ def _validators(cls) -> Dict[str, Callable]: return {"validate_distributor_config": validator("world_size", allow_reuse=True)(cls._validate_world_size)} def _run_for_config( - self, model: ONNXModelHandler, config: Dict[str, Any], output_model_path: str + self, model: ONNXModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> DistributedOnnxModelHandler: # huggingface/tokenizers: The current process just got forked, after parallelism has already been used. # Disabling parallelism to avoid deadlocks... 
@@ -403,18 +403,18 @@ def _run_for_config( # - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false) os.environ["TOKENIZERS_PARALLELISM"] = "false" - matcher = MoEExpertDistributionPatternMatcherA(config["world_size"], model.model_path) + matcher = MoEExpertDistributionPatternMatcherA(config.world_size, model.model_path) experts, num_experts = matcher.identify_experts(output_model_path) matcher.distribute( experts, num_experts, output_model_path, - use_external_data_format=config["save_as_external_data"], - all_tensors_to_one_file=config["all_tensors_to_one_file"], - parallel_jobs=config["parallel_jobs"] or multiprocessing.cpu_count(), + use_external_data_format=config.save_as_external_data, + all_tensors_to_one_file=config.all_tensors_to_one_file, + parallel_jobs=config.parallel_jobs or multiprocessing.cpu_count(), ) return DistributedOnnxModelHandler( model_path=str(Path(output_model_path).with_suffix("")), model_name_pattern=DistributedOnnxModelHandler.DEFAULT_RANKED_MODEL_NAME_FORMAT, - num_ranks=config["world_size"], + num_ranks=config.world_size, ) diff --git a/olive/passes/onnx/nvmo_quantization.py b/olive/passes/onnx/nvmo_quantization.py index f3d88b09c..a98f0655b 100644 --- a/olive/passes/onnx/nvmo_quantization.py +++ b/olive/passes/onnx/nvmo_quantization.py @@ -4,7 +4,7 @@ # -------------------------------------------------------------------------- import logging from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Type, Union import onnx import torch @@ -19,7 +19,7 @@ from olive.model.utils import resolve_onnx_path from olive.passes import Pass from olive.passes.onnx.common import model_proto_to_olive_model -from olive.passes.pass_config import PassConfigParam +from olive.passes.pass_config import BasePassConfig, PassConfigParam from olive.search.search_parameter import Categorical logger = logging.getLogger(__name__) @@ -76,16 +76,12 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon @classmethod def validate_config( cls, - config: Dict[str, Any], + config: Type[BasePassConfig], accelerator_spec: AcceleratorSpec, - disable_search: Optional[bool] = False, ) -> bool: - if not super().validate_config(config, accelerator_spec, disable_search): + if not super().validate_config(config, accelerator_spec): return False - config_cls, _ = cls.get_config_class(accelerator_spec, disable_search) - config = config_cls(**config) - # Validate Precision if config.precision != NVModelOptQuantization.Precision.INT4: logger.error("Only INT4 quantization is supported.") @@ -120,14 +116,14 @@ def validate_config( return True - def initialize_quant_config(self, config: Dict[str, Any]) -> Dict[str, Any]: + def initialize_quant_config(self, config: Type[BasePassConfig]) -> Dict[str, Any]: # Check if 'tokenizer_dir' is provided and not empty - random_calib = config.get("random_calib_data", False) + random_calib = config.random_calib_data or False if not random_calib: # Prepare calibration inputs only if tokenizer_dir is specified calib_inputs = self.get_calib_inputs( dataset_name="cnn", - model_name=config["tokenizer_dir"], + model_name=config.tokenizer_dir, cache_dir="./cache", calib_size=32, batch_size=1, @@ -146,10 +142,10 @@ def initialize_quant_config(self, config: Dict[str, Any]) -> Dict[str, Any]: # Return a dictionary containing necessary configuration for quantization return { - "algorithm": config.get("algorithm", self.Algorithm.AWQ.value), - "precision": config.get("precision", 
self.Precision.INT4.value), - "calibration_method": config.get("calibration", self.Calibration.AWQ_CLIP.value), - "tokenizer_dir": config.get("tokenizer_dir", ""), + "algorithm": config.algorithm or self.Algorithm.AWQ.value, + "precision": config.precision or self.Precision.INT4.value, + "calibration_method": config.calibration or self.Calibration.AWQ_CLIP.value, + "tokenizer_dir": config.tokenizer_dir or "", "calibration_data_reader": calib_inputs, } @@ -384,7 +380,7 @@ def convert_opset_to_21_proto(self, model_proto: ModelProto) -> ModelProto: return model_proto def _run_for_config( - self, model: OliveModelHandler, config: Dict[str, Any], output_model_path: str + self, model: OliveModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> OliveModelHandler: try: diff --git a/olive/passes/onnx/optimum_conversion.py b/olive/passes/onnx/optimum_conversion.py index 153bf6825..8ccad93e4 100644 --- a/olive/passes/onnx/optimum_conversion.py +++ b/olive/passes/onnx/optimum_conversion.py @@ -5,12 +5,12 @@ import logging from copy import deepcopy from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Dict, List, Type, Union from olive.hardware.accelerator import AcceleratorSpec from olive.model import CompositeModelHandler, HfModelHandler, ONNXModelHandler from olive.passes import Pass -from olive.passes.pass_config import PassConfigParam, get_user_script_data_config +from olive.passes.pass_config import BasePassConfig, PassConfigParam, get_user_script_data_config logger = logging.getLogger(__name__) @@ -51,16 +51,12 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon @classmethod def validate_config( cls, - config: Dict[str, Any], + config: Type[BasePassConfig], accelerator_spec: AcceleratorSpec, - disable_search: Optional[bool] = False, ) -> bool: - if not super().validate_config(config, accelerator_spec, disable_search): + if not super().validate_config(config, accelerator_spec): return False - config_cls, _ = cls.get_config_class(accelerator_spec, disable_search) - config = config_cls(**config) - if config.fp16 and config.device != "cuda": logger.info("OptimumConversion: fp16 is set to True, but device is not set to cuda.") return False @@ -68,18 +64,18 @@ def validate_config( return True def _run_for_config( - self, model: HfModelHandler, config: Dict[str, Any], output_model_path: str + self, model: HfModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> Union[ONNXModelHandler, CompositeModelHandler]: from optimum import version as optimum_version from optimum.exporters.onnx import main_export as export_optimum_model from packaging import version - extra_args = deepcopy(config["extra_args"]) or {} + extra_args = deepcopy(config.extra_args) or {} extra_args.update( { - "opset": config["target_opset"], - "fp16": config["fp16"], - "device": config["device"], + "opset": config.target_opset, + "fp16": config.fp16, + "device": config.device, } ) if model.load_kwargs and "trust_remote_code" not in extra_args: @@ -104,11 +100,11 @@ def _run_for_config( # check the exported components exported_models = [name.stem for name in Path(output_model_path).iterdir() if name.suffix == ".onnx"] - if config["components"]: + if config.components: assert all( - component in exported_models for component in config["components"] + component in exported_models for component in config.components ), f"Components {config['components']} are not exported. Only {exported_models} are exported." 
- components = config["components"] or exported_models + components = config.components or exported_models logger.debug("Exported models are: %s. Returning components: %s.", exported_models, components) # if there is only one component, return it directly diff --git a/olive/passes/onnx/optimum_merging.py b/olive/passes/onnx/optimum_merging.py index 4d601a594..a9834a15a 100644 --- a/olive/passes/onnx/optimum_merging.py +++ b/olive/passes/onnx/optimum_merging.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -from typing import Any, Dict, Union +from typing import Dict, Type, Union from onnx import ModelProto @@ -11,7 +11,7 @@ from olive.model.utils import resolve_onnx_path from olive.passes import Pass from olive.passes.onnx.common import get_external_data_config, model_proto_to_olive_model -from olive.passes.pass_config import PassConfigParam +from olive.passes.pass_config import BasePassConfig, PassConfigParam class OptimumMerging(Pass): @@ -43,7 +43,7 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon return config def _run_for_config( - self, model: CompositeModelHandler, config: Dict[str, Any], output_model_path: str + self, model: CompositeModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> Union[ONNXModelHandler, CompositeModelHandler]: import onnxruntime @@ -65,7 +65,7 @@ def new_byte_size_func(_): merged_model = merge_decoders( model.model_components[0].model_path, model.model_components[1].model_path, - strict=config["strict"], + strict=config.strict, ) finally: ModelProto.ByteSize = prev_byte_size_func diff --git a/olive/passes/onnx/peephole_optimizer.py b/olive/passes/onnx/peephole_optimizer.py index 4a6af4241..e50d89560 100644 --- a/olive/passes/onnx/peephole_optimizer.py +++ b/olive/passes/onnx/peephole_optimizer.py @@ -4,7 +4,7 @@ # -------------------------------------------------------------------------- import logging from pathlib import Path -from typing import Any, Dict +from typing import Dict, Type import numpy as np import onnx @@ -15,7 +15,7 @@ from olive.model.utils import resolve_onnx_path from olive.passes import Pass from olive.passes.onnx.common import get_external_data_config, model_proto_to_olive_model -from olive.passes.pass_config import PassConfigParam +from olive.passes.pass_config import BasePassConfig, PassConfigParam logger = logging.getLogger(__name__) @@ -259,7 +259,7 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon return get_external_data_config() def _run_for_config( - self, model: ONNXModelHandler, config: Dict[str, Any], output_model_path: str + self, model: ONNXModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> ONNXModelHandler: output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name) diff --git a/olive/passes/onnx/qnn/qnn_preprocess.py b/olive/passes/onnx/qnn/qnn_preprocess.py index a8c7c30bf..654a98c11 100644 --- a/olive/passes/onnx/qnn/qnn_preprocess.py +++ b/olive/passes/onnx/qnn/qnn_preprocess.py @@ -5,14 +5,14 @@ import logging from pathlib import Path -from typing import Any, Dict +from typing import Dict, Type from olive.hardware import AcceleratorSpec from olive.model import ONNXModelHandler from olive.model.utils import resolve_onnx_path from olive.passes.olive_pass import Pass from olive.passes.onnx.common import get_external_data_config -from 
olive.passes.pass_config import PassConfigParam +from olive.passes.pass_config import BasePassConfig, PassConfigParam logger = logging.getLogger(__name__) @@ -75,7 +75,7 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon def _run_for_config( self, model: ONNXModelHandler, - config: Dict[str, Any], + config: Type[BasePassConfig], output_model_path: str, ) -> ONNXModelHandler: from onnxruntime import __version__ as OrtVersion @@ -85,17 +85,17 @@ def _run_for_config( raise RuntimeError("QNNPreprocess only supports ONNXRuntime version 1.17.0 or later") output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name) - external_data_location = config["external_data_name"] or f"{Path(output_model_path).name}.data" + external_data_location = config.external_data_name or f"{Path(output_model_path).name}.data" # only 1.18.0 or later adds the following parameters extra_kwargs = { - "save_as_external_data": config["save_as_external_data"], - "all_tensors_to_one_file": config["all_tensors_to_one_file"], - "external_data_size_threshold": config["size_threshold"], + "save_as_external_data": config.save_as_external_data, + "all_tensors_to_one_file": config.all_tensors_to_one_file, + "external_data_size_threshold": config.size_threshold, "external_data_location": external_data_location, - "external_data_convert_attribute": config["convert_attribute"], - "inputs_to_make_channel_last": config["inputs_to_make_channel_last"], - "outputs_to_make_channel_last": config["outputs_to_make_channel_last"], + "external_data_convert_attribute": config.convert_attribute, + "inputs_to_make_channel_last": config.inputs_to_make_channel_last, + "outputs_to_make_channel_last": config.outputs_to_make_channel_last, } if version.parse(OrtVersion) < version.parse("1.18.0"): removed_config = [ @@ -114,7 +114,7 @@ def _run_for_config( modified = qnn_preprocess_model( model.model_path, output_model_path, - fuse_layernorm=config["fuse_layernorm"], + fuse_layernorm=config.fuse_layernorm, **extra_kwargs, ) if not modified: diff --git a/olive/passes/onnx/quantization.py b/olive/passes/onnx/quantization.py index da39ab65f..09b9c0dc0 100644 --- a/olive/passes/onnx/quantization.py +++ b/olive/passes/onnx/quantization.py @@ -7,7 +7,7 @@ from copy import deepcopy from functools import partial from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Callable, Dict, List, Type, Union import onnx from packaging import version @@ -27,7 +27,7 @@ model_proto_to_file, model_proto_to_olive_model, ) -from olive.passes.pass_config import PassConfigParam +from olive.passes.pass_config import BasePassConfig, PassConfigParam from olive.resource_path import LocalFile from olive.search.search_parameter import Boolean, Categorical, Conditional, ConditionalDefault @@ -256,7 +256,7 @@ def get_calibration_dataloader(config): - data_config = validate_config(config["data_config"], DataConfig) + data_config = validate_config(config.data_config, DataConfig) return data_config.to_data_container().create_calibration_dataloader() @@ -327,16 +327,12 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon @classmethod def validate_config( cls, - config: Dict[str, Any], + config: Type[BasePassConfig], accelerator_spec: AcceleratorSpec, - disable_search: Optional[bool] = False, ) -> bool: - if not super().validate_config(config, accelerator_spec, disable_search): + if not super().validate_config(config, accelerator_spec): return False - 
config_cls, _ = cls.get_config_class(accelerator_spec, disable_search) - config = config_cls(**config) - if config.quant_mode == "static": if ( config.weight_type == "QInt8" @@ -355,7 +351,7 @@ def validate_config( return True def _run_for_config( - self, model: ONNXModelHandler, config: Dict[str, Any], output_model_path: str + self, model: ONNXModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> ONNXModelHandler: if model_has_adapters(model.model_path): logger.info("Model has adapters which should not be quantized. Returning the model without quantization.") @@ -366,21 +362,21 @@ def _run_for_config( from onnxruntime.quantization.calibrate import CalibrationMethod # start with a copy of the config - run_config = deepcopy(config) + run_config = config.dict() is_static = run_config["quant_mode"] == "static" if is_static: - assert config["data_config"], "data_config is required for static quantization." + assert config.data_config, "data_config is required for static quantization." # whether to prepare qnn config # we do the version check here and not in `validate_config` since search point validation # is done by the engine. Unless the host is local system, the ort version of the host is # not known by the engine when the search point is validated. - if config["prepare_qnn_config"] and version.parse(OrtVersion) < version.parse("1.17.0"): + if config.prepare_qnn_config and version.parse(OrtVersion) < version.parse("1.17.0"): raise OlivePassError("prepare_qnn_config is only supported by onnxruntime>=1.17.0") output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name) # extra config - extra_options = deepcopy(config["extra_options"]) if config["extra_options"] else {} + extra_options = deepcopy(config.extra_options) or {} # keys in extra_options that are already exposed intersection = set(extra_options.keys()).intersection(set(_exposed_extra_options_config.keys())) if intersection: @@ -465,7 +461,7 @@ def _run_for_config( if is_static: # get the dataloader dataloader = get_calibration_dataloader(config) - if config["prepare_qnn_config"]: + if config.prepare_qnn_config: import inspect from onnxruntime.quantization.execution_providers.qnn import get_qnn_qdq_config @@ -474,10 +470,10 @@ def _run_for_config( if version.parse(OrtVersion) >= version.parse("1.18.0"): symmetric_options = { - "activation_symmetric": config["ActivationSymmetric"], - "weight_symmetric": config["WeightSymmetric"], + "activation_symmetric": config.ActivationSymmetric, + "weight_symmetric": config.WeightSymmetric, } - qnn_extra_options = config["qnn_extra_options"] or {} + qnn_extra_options = config.qnn_extra_options or {} if init_overrides := _get_qnn_init_overrides(model, config): qnn_extra_options["init_overrides"] = init_overrides qnn_config = get_qnn_qdq_config( @@ -720,13 +716,13 @@ def _validators(cls) -> Dict[str, Callable]: } def _run_for_config( - self, model: ONNXModelHandler, config: Dict[str, Any], output_model_path: str + self, model: ONNXModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> ONNXModelHandler: - if model_has_adapters(model.model_path) and config["algorithm"] not in {None, "DEFAULT"}: + if model_has_adapters(model.model_path) and config.algorithm not in {None, "DEFAULT"}: logger.info( "Model has adapters which should only be quantized with algorithm=None or DEFAULT. Got %s. 
Returning" " the model without quantization.", - config["algorithm"], + config.algorithm, ) return model @@ -741,16 +737,16 @@ def _run_for_config( weight_only_quant_config_class = None weight_only_quant_config = None - algo_config = deepcopy(config["weight_only_quant_configs"] or {}) + algo_config = deepcopy(config.weight_only_quant_configs) or {} if version.parse(OrtVersion) >= version.parse("1.17.0"): from onnxruntime.quantization.matmul_4bits_quantizer import ( GPTQWeightOnlyQuantConfig, RTNWeightOnlyQuantConfig, ) - if config["algorithm"] == "RTN": + if config.algorithm == "RTN": weight_only_quant_config_class = RTNWeightOnlyQuantConfig - elif config["algorithm"] == "GPTQ": + elif config.algorithm == "GPTQ": if "block_size" in algo_config and version.parse(OrtVersion) < version.parse("1.18.0"): # ort 1.17.0+ uses blocksize instead of block_size :( algo_config["blocksize"] = algo_config["block_size"] @@ -764,30 +760,30 @@ def _run_for_config( HQQWeightOnlyQuantConfig, ) - if config["algorithm"] == "DEFAULT": + if config.algorithm == "DEFAULT": weight_only_quant_config_class = DefaultWeightOnlyQuantConfig - elif config["algorithm"] == "HQQ": + elif config.algorithm == "HQQ": weight_only_quant_config_class = HQQWeightOnlyQuantConfig - elif config["algorithm"] in ("HQQ", "DEFAULT"): + elif config.algorithm in ("HQQ", "DEFAULT"): raise ValueError("HQQ and DEFAULT algorithm are only supported in onnxruntime >= 1.18.0") if weight_only_quant_config_class: weight_only_quant_config = weight_only_quant_config_class(**algo_config) quant = MatMul4BitsQuantizer( model.load_model(), - block_size=config["block_size"], - is_symmetric=config["is_symmetric"], - nodes_to_exclude=config["nodes_to_exclude"], - accuracy_level=config["accuracy_level"], + block_size=config.block_size, + is_symmetric=config.is_symmetric, + nodes_to_exclude=config.nodes_to_exclude, + accuracy_level=config.accuracy_level, algo_config=weight_only_quant_config, ) else: # TODO(trajep): remove this block once we migrate customer to onnxruntime>=1.17.0 all quant = MatMul4BitsQuantizer( model.load_model(), - block_size=config["block_size"], - is_symmetric=config["is_symmetric"], - nodes_to_exclude=config["nodes_to_exclude"], + block_size=config.block_size, + is_symmetric=config.is_symmetric, + nodes_to_exclude=config.nodes_to_exclude, ) quant.process() # topologically sort the graph at the end since previous optimizations may have broken it @@ -864,13 +860,13 @@ def _append_first_op_types_to_quantize_list( return op_types_to_quantize -def _get_qnn_init_overrides(model_handler: ONNXModelHandler, config: Dict[str, Any]): +def _get_qnn_init_overrides(model_handler: ONNXModelHandler, config: Type[BasePassConfig]): # get qnn overrides from the input model model_attributes = model_handler.model_attributes or {} mp_init_overrides = model_attributes.get("mixed_precision_overrides") or {} init_overrides = {} - config["qnn_extra_options"] = config["qnn_extra_options"] or {} - if mp_init_overrides and "init_overrides" not in config["qnn_extra_options"]: + config.qnn_extra_options = config.qnn_extra_options or {} + if mp_init_overrides and "init_overrides" not in config.qnn_extra_options: from onnxruntime.quantization import QuantType # use QuantType to get the quantization type @@ -879,11 +875,11 @@ def _get_qnn_init_overrides(model_handler: ONNXModelHandler, config: Dict[str, A for tensor, quant_types in mp_init_overrides.items() } # add `convert_outputs` to the TensorQuantOverridesHelper - convert_outputs = config.get("convert_outputs") or {} 
+ convert_outputs = config.convert_outputs or {} for output_name, output_convert_type in convert_outputs.items(): init_overrides[output_name] = init_overrides.get(output_name, [{}]) init_overrides[output_name][0]["quant_type"] = init_overrides[output_name][0].get( "quant_type" - ) or QuantType.from_string(config.get("activation_type", "QUInt8")) + ) or QuantType.from_string(config.activation_type or "QUInt8") init_overrides[output_name][0]["convert"] = {"quant_type": QuantType.from_string(output_convert_type)} return init_overrides diff --git a/olive/passes/onnx/session_params_tuning.py b/olive/passes/onnx/session_params_tuning.py index 701c35080..baa392111 100644 --- a/olive/passes/onnx/session_params_tuning.py +++ b/olive/passes/onnx/session_params_tuning.py @@ -7,7 +7,7 @@ import logging import tempfile from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Type, Union import onnxruntime as ort @@ -22,7 +22,7 @@ from olive.hardware.accelerator import AcceleratorLookup, AcceleratorSpec from olive.model import ONNXModelHandler from olive.passes import Pass -from olive.passes.pass_config import PassConfigParam, get_user_script_data_config +from olive.passes.pass_config import BasePassConfig, PassConfigParam, get_user_script_data_config from olive.search.search_parameter import Categorical logger = logging.getLogger(__name__) @@ -173,19 +173,15 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon @classmethod def validate_config( cls, - config: Dict[str, Any], + config: Type[BasePassConfig], accelerator_spec: AcceleratorSpec, - disable_search: Optional[bool] = False, ) -> bool: """Validate the search point for the pass.""" - if not super().validate_config(config, accelerator_spec, disable_search): + if not super().validate_config(config, accelerator_spec): return False - config_cls, _ = cls.get_config_class(accelerator_spec, disable_search) - config_cls.__config__.extra = Extra.allow - config = config_cls(**config) - # Rename the search parameters with atomic/singular names for clarity + config.__class__.__config__.extra = Extra.allow config.execution_provider = config.providers_list config.provider_options = config.provider_options_list config.execution_mode = config.execution_mode_list @@ -264,11 +260,11 @@ def validate_config( return True def _run_for_config( - self, model: ONNXModelHandler, config: Dict[str, Any], output_model_path: str + self, model: ONNXModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> ONNXModelHandler: # Rename the search parameters with atomic/singular names for clarity - self._config_class.__config__.extra = Extra.allow - config = self._config_class(**config) + config.__class__.__config__.extra = Extra.allow + config = config.copy() config.execution_provider = config.providers_list config.provider_options = config.provider_options_list config.execution_mode = config.execution_mode_list diff --git a/olive/passes/onnx/split.py b/olive/passes/onnx/split.py index cdaf61305..0f8d7501c 100644 --- a/olive/passes/onnx/split.py +++ b/olive/passes/onnx/split.py @@ -6,7 +6,7 @@ from collections import defaultdict from copy import deepcopy from pathlib import Path -from typing import Any, Dict, Optional +from typing import Dict, Optional, Type import numpy as np import onnx @@ -17,7 +17,7 @@ from olive.passes import Pass from olive.passes.onnx.common import get_external_data_config, model_proto_to_olive_model from olive.passes.onnx.onnx_dag import OnnxDAG -from 
olive.passes.pass_config import PassConfigParam +from olive.passes.pass_config import BasePassConfig, PassConfigParam logger = logging.getLogger(__name__) @@ -32,7 +32,7 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon } def _run_for_config( - self, model: ONNXModelHandler, config: Dict[str, Any], output_model_path: str + self, model: ONNXModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> CompositeModelHandler: model_proto = model.load_model() diff --git a/olive/passes/onnx/transformer_optimization.py b/olive/passes/onnx/transformer_optimization.py index 11188c553..931c2d689 100644 --- a/olive/passes/onnx/transformer_optimization.py +++ b/olive/passes/onnx/transformer_optimization.py @@ -3,9 +3,8 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- import logging -from copy import deepcopy from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Type, Union import onnx @@ -17,7 +16,7 @@ from olive.model.utils import resolve_onnx_path from olive.passes import Pass from olive.passes.onnx.common import get_external_data_config, model_proto_to_olive_model -from olive.passes.pass_config import PassConfigParam +from olive.passes.pass_config import BasePassConfig, PassConfigParam if TYPE_CHECKING: from onnxruntime.transformers.onnx_model import OnnxModel @@ -139,19 +138,15 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon @classmethod def validate_config( cls, - config: Dict[str, Any], + config: Type[BasePassConfig], accelerator_spec: AcceleratorSpec, - disable_search: Optional[bool] = False, ) -> bool: - if not super().validate_config(config, accelerator_spec, disable_search): + if not super().validate_config(config, accelerator_spec): return False from onnxruntime import __version__ as OrtVersion from packaging import version - config_cls, _ = cls.get_config_class(accelerator_spec, disable_search) - config = config_cls(**config) - if config.float16: if accelerator_spec.execution_provider == "TensorrtExecutionProvider": logger.info( @@ -210,14 +205,14 @@ def _set_fusion_options(run_config: Dict[str, Any]): run_config["optimization_options"] = fusion_options def _run_for_config( - self, model: ONNXModelHandler, config: Dict[str, Any], output_model_path: str + self, model: ONNXModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> ONNXModelHandler: from onnxruntime.transformers import optimizer as transformers_optimizer - num_kv_heads = config["num_key_value_heads"] + num_kv_heads = config.num_key_value_heads # start with a copy of the config - run_config = deepcopy(config) + run_config = config.dict() keys_to_remove = [ "float16", "keep_io_types", @@ -268,7 +263,7 @@ def _run_for_config( output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name) - optimization_options = config["optimization_options"] + optimization_options = config.optimization_options if optimization_options: self._set_fusion_options(run_config) @@ -296,22 +291,22 @@ def _run_for_config( optimizer = transformers_optimizer.optimize_model(input=model.model_path, **run_config) - if config["float16"]: + if config.float16: optimizer.convert_float_to_float16( - keep_io_types=config["keep_io_types"], - op_block_list=config["force_fp32_ops"], - node_block_list=config["force_fp32_nodes"], - force_fp16_inputs=config["force_fp16_inputs"], + 
keep_io_types=config.keep_io_types, + op_block_list=config.force_fp32_ops, + node_block_list=config.force_fp32_nodes, + force_fp16_inputs=config.force_fp16_inputs, ) - if config["use_gqa"]: + if config.use_gqa: world_size = model.model_attributes.get("world_size", 1) if model.model_attributes is not None else 1 optimizer = self._replace_mha_with_gqa(optimizer, kv_num_heads=num_kv_heads, world_size=world_size) optimizer.prune_graph() # add allow_remove_graph_inputs to pass config optimizer.update_graph(allow_remove_graph_inputs=True) - if config["input_int32"]: + if config.input_int32: optimizer.change_graph_inputs_to_int32() # Topologically sort the graph at the end since previous optimizations may have broken it diff --git a/olive/passes/onnx/vitis_ai_quantization.py b/olive/passes/onnx/vitis_ai_quantization.py index bd0088cac..dc3bd87f8 100644 --- a/olive/passes/onnx/vitis_ai_quantization.py +++ b/olive/passes/onnx/vitis_ai_quantization.py @@ -6,7 +6,7 @@ import tempfile from copy import deepcopy from pathlib import Path -from typing import Any, Dict, Union +from typing import Dict, Type, Union import onnx @@ -23,7 +23,7 @@ model_proto_to_file, model_proto_to_olive_model, ) -from olive.passes.pass_config import PassConfigParam +from olive.passes.pass_config import BasePassConfig, PassConfigParam from olive.resource_path import LocalFile from olive.search.search_parameter import Boolean, Categorical, Conditional @@ -236,7 +236,7 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon } def _run_for_config( - self, model: ONNXModelHandler, config: Dict[str, Any], output_model_path: str + self, model: ONNXModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> ONNXModelHandler: if model_has_adapters(model.model_path): logger.info("Model has adapters which should not be quantized. Returning the model without quantization.") @@ -248,12 +248,12 @@ def _run_for_config( from olive.passes.onnx.vitis_ai.quant_utils import PowerOfTwoMethod # start with a copy of the config - run_config = deepcopy(config) + run_config = config.dict() output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name) # extra config - extra_options = deepcopy(config["extra_options"]) if config["extra_options"] else {} + extra_options = deepcopy(config.extra_options) if config.extra_options else {} # keys in extra_options that are already exposed intersection = set(extra_options.keys()).intersection(set(_exposed_extra_options_config.keys())) if intersection: @@ -315,8 +315,8 @@ def _run_for_config( # get the dataloader dataloader = None - if config["data_config"]: - data_config = validate_config(config["data_config"], DataConfig) + if config.data_config: + data_config = validate_config(config.data_config, DataConfig) dataloader = data_config.to_data_container().create_calibration_dataloader() execution_provider = self.accelerator_spec.execution_provider diff --git a/olive/passes/openvino/conversion.py b/olive/passes/openvino/conversion.py index d1b4c3216..3c35fee18 100644 --- a/olive/passes/openvino/conversion.py +++ b/olive/passes/openvino/conversion.py @@ -3,13 +3,13 @@ # Licensed under the MIT License. 
# -------------------------------------------------------------------------- from pathlib import Path -from typing import Any, Callable, Dict, List, Union +from typing import Callable, Dict, List, Type, Union from olive.constants import Framework from olive.hardware.accelerator import AcceleratorSpec from olive.model import HfModelHandler, ONNXModelHandler, OpenVINOModelHandler, PyTorchModelHandler from olive.passes import Pass -from olive.passes.pass_config import PassConfigParam, get_user_script_data_config +from olive.passes.pass_config import BasePassConfig, PassConfigParam, get_user_script_data_config class OpenVINOConversion(Pass): @@ -65,7 +65,7 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon def _run_for_config( self, model: Union[HfModelHandler, PyTorchModelHandler, ONNXModelHandler], - config: Dict[str, Any], + config: Type[BasePassConfig], output_model_path: str, ) -> OpenVINOModelHandler: try: @@ -81,18 +81,18 @@ def _run_for_config( input_model = model.load_model() example_input = None - if config.get("example_input_func"): - example_input = self._user_module_loader.call_object(config["example_input_func"]) + if config.example_input_func: + example_input = self._user_module_loader.call_object(config.example_input_func) input_shape = None - if config.get("input"): - config_input = config["input"] + if config.input: + config_input = config.input if isinstance(config_input, List): input_shape = config_input else: input_shape = self._user_module_loader.call_object(config_input) - extra_configs = config.get("extra_configs") or {} + extra_configs = config.extra_configs or {} args = { "input_model": input_model, "input": input_shape, @@ -103,10 +103,8 @@ def _run_for_config( ov_model = ov.convert_model(**args) model_name = "ov_model" - output_dir = Path(output_model_path) / config.get("output_model", model_name) + output_dir = Path(output_model_path) / (config.output_model or model_name) # Save as ov model - ov.save_model( - ov_model, output_model=output_dir.with_suffix(".xml"), compress_to_fp16=config["compress_to_fp16"] - ) + ov.save_model(ov_model, output_model=output_dir.with_suffix(".xml"), compress_to_fp16=config.compress_to_fp16) return OpenVINOModelHandler(model_path=output_model_path) diff --git a/olive/passes/openvino/quantization.py b/olive/passes/openvino/quantization.py index 8afd53d5c..f489051a7 100644 --- a/olive/passes/openvino/quantization.py +++ b/olive/passes/openvino/quantization.py @@ -4,7 +4,7 @@ # -------------------------------------------------------------------------- import logging from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Union +from typing import TYPE_CHECKING, Callable, Dict, List, Type, Union from olive.common.config_utils import validate_config from olive.common.utils import StrEnumBase @@ -13,7 +13,7 @@ from olive.model import OliveModelHandler from olive.model.handler import OpenVINOModelHandler from olive.passes import Pass -from olive.passes.pass_config import ParamCategory, PassConfigParam, get_user_script_data_config +from olive.passes.pass_config import BasePassConfig, ParamCategory, PassConfigParam, get_user_script_data_config if TYPE_CHECKING: from openvino import CompiledModel @@ -149,8 +149,8 @@ def _get_nncf_dataset(self, config): raise ImportError("Please install olive-ai[openvino] to use OpenVINO pass") from None data_loader = None - if config["data_config"]: - data_config = validate_config(config["data_config"], DataConfig) + if config.data_config: + 
data_config = validate_config(config.data_config, DataConfig) data_loader = data_config.to_data_container().create_dataloader() def transform_fn(data_item): @@ -172,16 +172,14 @@ def _get_extra_params(config): } extra_params = {} - extra_params["model_type"] = nncf.ModelType.Transformer if config.get("model_type") == "TRANSFORMER" else None + extra_params["model_type"] = nncf.ModelType.Transformer if config.model_type == "TRANSFORMER" else None extra_params["preset"] = ( - nncf.QuantizationPreset.PERFORMANCE - if config.get("preset") == "PERFORMANCE" - else nncf.QuantizationPreset.MIXED + nncf.QuantizationPreset.PERFORMANCE if config.preset == "PERFORMANCE" else nncf.QuantizationPreset.MIXED ) - extra_params["target_device"] = device_map.get(config.get("target_device"), nncf.TargetDevice.ANY) + extra_params["target_device"] = device_map.get(config.target_device, nncf.TargetDevice.ANY) - if config.get("ignored_scope"): - kwargs = {config.get("ignored_scope_type"): config.get("ignored_scope")} + if config.ignored_scope: + kwargs = {config.ignored_scope_type: config.ignored_scope} extra_params["ignored_scopes"] = nncf.IgnoredScope(**kwargs) return extra_params @@ -190,7 +188,7 @@ def _get_extra_params(config): class OpenVINOQuantization(OpenVINOQuantizationBase): def _run_for_config( - self, model: OpenVINOModelHandler, config: Dict[str, Any], output_model_path: str + self, model: OpenVINOModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> OpenVINOModelHandler: try: import nncf @@ -248,7 +246,7 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon return config def _run_for_config( - self, model: OliveModelHandler, config: Dict[str, Any], output_model_path: str + self, model: OliveModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> OliveModelHandler: try: import nncf @@ -263,19 +261,19 @@ def _run_for_config( extra_params = self._get_extra_params(config) validate_func = ( - self._user_module_loader.load_object(config["validation_func"]) - if config.get("validation_func") + self._user_module_loader.load_object(config.validation_func) + if config.validation_func else _default_validate_func ) - drop_type = nncf.DropType.ABSOLUTE if config["drop_type"] == "ABSOLUTE" else nncf.DropType.RELATIVE + drop_type = nncf.DropType.ABSOLUTE if config.drop_type == "ABSOLUTE" else nncf.DropType.RELATIVE quantized_model = nncf.quantize_with_accuracy_control( model, calibration_dataset=calibration_dataset, validation_dataset=validation_dataset, validation_fn=validate_func, - max_drop=config["max_drop"], + max_drop=config.max_drop, drop_type=drop_type, **extra_params ) diff --git a/olive/passes/pass_config.py b/olive/passes/pass_config.py index 27144e420..624af33c7 100644 --- a/olive/passes/pass_config.py +++ b/olive/passes/pass_config.py @@ -115,7 +115,7 @@ class AbstractPassConfig(NestedConfig): """Base class for pass configuration.""" type: str = Field(description="The type of the pass.") - config: Dict[str, Any] = Field( + config: Union[Dict[str, Any], Type[BasePassConfig]] = Field( None, description=( "The configuration of the pass. Values for required parameters must be provided. 
For optional parameters," diff --git a/olive/passes/pytorch/autoawq.py b/olive/passes/pytorch/autoawq.py index 3e83b2858..8176ef77d 100644 --- a/olive/passes/pytorch/autoawq.py +++ b/olive/passes/pytorch/autoawq.py @@ -4,7 +4,7 @@ # -------------------------------------------------------------------------- import logging from copy import deepcopy -from typing import Any, Dict, Union +from typing import Any, Dict, Type, Union import torch from packaging import version @@ -14,7 +14,7 @@ from olive.hardware.accelerator import AcceleratorSpec from olive.model import HfModelHandler from olive.passes import Pass -from olive.passes.pass_config import PassConfigParam, get_user_script_data_config +from olive.passes.pass_config import BasePassConfig, PassConfigParam, get_user_script_data_config from olive.passes.pytorch.common import inherit_hf_from_hf logger = logging.getLogger(__name__) @@ -108,20 +108,22 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon } @torch.no_grad() - def _run_for_config(self, model: HfModelHandler, config: Dict[str, Any], output_model_path: str) -> HfModelHandler: + def _run_for_config( + self, model: HfModelHandler, config: Type[BasePassConfig], output_model_path: str + ) -> HfModelHandler: from awq import AutoAWQForCausalLM if not torch.cuda.is_available(): raise ValueError("Please use GPU to run AWQ quantization.") data_kwargs = {} - if config["data_config"]: + if config.data_config: # set default values for data config data_kwargs.update( { - "calib_data": config["data_config"].load_dataset_params.get("data_name"), - "split": config["data_config"].load_dataset_params.get("split"), - "text_column": config["data_config"].pre_process_params.get("input_cols"), + "calib_data": config.data_config.load_dataset_params.get("data_name"), + "split": config.data_config.load_dataset_params.get("split"), + "text_column": config.data_config.pre_process_params.get("input_cols"), } ) @@ -145,14 +147,14 @@ def _run_for_config(self, model: HfModelHandler, config: Dict[str, Any], output_ awq_model.quantize( tokenizer, quant_config={ - "zero_point": config["zero_point"], - "q_group_size": config["q_group_size"], - "w_bit": config["w_bit"], - "version": config["version"], - "modules_to_not_convert": config["modules_to_not_convert"], + "zero_point": config.zero_point, + "q_group_size": config.q_group_size, + "w_bit": config.w_bit, + "version": config.version, + "modules_to_not_convert": config.modules_to_not_convert, }, - duo_scaling=config["duo_scaling"], - export_compatible=config["export_compatible"], + duo_scaling=config.duo_scaling, + export_compatible=config.export_compatible, **data_kwargs, ) diff --git a/olive/passes/pytorch/capture_split_info.py b/olive/passes/pytorch/capture_split_info.py index 71e666ee7..aaffa224e 100644 --- a/olive/passes/pytorch/capture_split_info.py +++ b/olive/passes/pytorch/capture_split_info.py @@ -6,7 +6,7 @@ import logging from copy import deepcopy from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Dict, Type, Union import numpy as np @@ -15,7 +15,7 @@ from olive.hardware.accelerator import AcceleratorSpec from olive.model import HfModelHandler, PyTorchModelHandler from olive.passes import Pass -from olive.passes.pass_config import ParamCategory, PassConfigParam +from olive.passes.pass_config import BasePassConfig, ParamCategory, PassConfigParam logger = logging.getLogger(__name__) @@ -65,16 +65,12 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon 
@classmethod def validate_config( cls, - config: Dict[str, Any], + config: Type[BasePassConfig], accelerator_spec: AcceleratorSpec, - disable_search: Optional[bool] = False, ) -> bool: - if not super().validate_config(config, accelerator_spec, disable_search): + if not super().validate_config(config, accelerator_spec): return False - config_cls, _ = cls.get_config_class(accelerator_spec, disable_search) - config = config_cls(**config) - if config.num_splits is None and config.cost_model is None: logger.info("One of num_splits or cost_model is required.") return False @@ -90,12 +86,12 @@ def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool: return False def _run_for_config( - self, model: Union[HfModelHandler, PyTorchModelHandler], config: Dict[str, Any], output_model_path: str + self, model: Union[HfModelHandler, PyTorchModelHandler], config: Type[BasePassConfig], output_model_path: str ) -> Union[HfModelHandler, PyTorchModelHandler]: split_assignments = None - if config["num_splits"]: + if config.num_splits: split_assignments = self.split_using_num_splits(model, config) - elif config["cost_model"]: + elif config.cost_model: split_assignments = self.split_using_cost_model(model, config) else: raise ValueError("One of num_splits or cost_model is required.") @@ -109,12 +105,12 @@ def _run_for_config( return output_model def split_using_num_splits( - self, model: Union[HfModelHandler, PyTorchModelHandler], config: Dict[str, Any] + self, model: Union[HfModelHandler, PyTorchModelHandler], config: Type[BasePassConfig] ) -> Dict[str, int]: # consider loading with meta device to avoid loading the weights loaded_model = model.load_model(cache_model=False) - block_to_split = config["block_to_split"] + block_to_split = config.block_to_split # check for None specifically since "" is a valid value if block_to_split is None and isinstance(model, HfModelHandler): model_wrapper = ModelWrapper.from_model(loaded_model) @@ -127,21 +123,21 @@ def split_using_num_splits( block_members = [child_name for child_name, _ in block.named_children()] split_assignments = {} - for split_idx, split_members in enumerate(np.array_split(block_members, config["num_splits"])): + for split_idx, split_members in enumerate(np.array_split(block_members, config.num_splits)): for child_name in split_members: split_assignments[f"{block_to_split}.{child_name}".lstrip(".")] = split_idx return split_assignments def split_using_cost_model( - self, model: Union[HfModelHandler, PyTorchModelHandler], config: Dict[str, Any] + self, model: Union[HfModelHandler, PyTorchModelHandler], config: Type[BasePassConfig] ) -> Dict[str, int]: if self.accelerator_spec.memory is None: raise ValueError("Accelerator memory is required to split using cost model.") # will only care about the number of bytes for now module_to_cost = {} - with open(config["cost_model"]) as f: + with open(config.cost_model) as f: reader = csv.DictReader(f) for row in reader: module_to_cost[row["module"]] = (int(row["num_params"]), int(row["num_bytes"]), int(row["num_flops"])) @@ -149,12 +145,12 @@ def split_using_cost_model( loaded_model = model.load_model(cache_model=False) modules_to_exclude = set() - if config["exclude_embeds"] and isinstance(model, HfModelHandler): + if config.exclude_embeds and isinstance(model, HfModelHandler): model_wrapper = ModelWrapper.from_model(loaded_model) modules_to_exclude.update(model_wrapper.get_embeds()[1]) - elif config["exclude_embeds"]: + elif config.exclude_embeds: modules_to_exclude.add("model.embed_tokens") - if 
config["exclude_lm_head"]: + if config.exclude_lm_head: modules_to_exclude.add("lm_head") split_assignments = {} diff --git a/olive/passes/pytorch/gptq.py b/olive/passes/pytorch/gptq.py index bb42cbf52..3630af65b 100644 --- a/olive/passes/pytorch/gptq.py +++ b/olive/passes/pytorch/gptq.py @@ -7,7 +7,7 @@ from argparse import Namespace from copy import deepcopy from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Type, Union import torch from packaging import version @@ -21,7 +21,7 @@ from olive.model import HfModelHandler, PyTorchModelHandler from olive.model.utils.path_utils import normalize_path_suffix from olive.passes import Pass -from olive.passes.pass_config import PassConfigParam, get_user_script_data_config +from olive.passes.pass_config import BasePassConfig, PassConfigParam, get_user_script_data_config from olive.passes.pytorch.common import inherit_hf_from_hf, inherit_pytorch_from_pytorch logger = logging.getLogger(__name__) @@ -103,7 +103,7 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon @torch.no_grad() def _run_for_config( - self, model: Union[HfModelHandler, PyTorchModelHandler], config: Dict[str, Any], output_model_path: str + self, model: Union[HfModelHandler, PyTorchModelHandler], config: Type[BasePassConfig], output_model_path: str ) -> PyTorchModelHandler: from auto_gptq import BaseQuantizeConfig, __version__ from auto_gptq.modeling import BaseGPTQForCausalLM @@ -139,13 +139,13 @@ def _run_for_config( model_wrapper = ModelWrapper.from_model(pytorch_model) quantize_config = BaseQuantizeConfig( - bits=config["bits"], - group_size=config["group_size"], - damp_percent=config["damp_percent"], - static_groups=config["static_groups"], - true_sequential=config["true_sequential"], - desc_act=config["desc_act"], - sym=config["sym"], + bits=config.bits, + group_size=config.group_size, + damp_percent=config.damp_percent, + static_groups=config.static_groups, + true_sequential=config.true_sequential, + desc_act=config.desc_act, + sym=config.sym, # this is so that the weight gets saved as "model.safetensors" model_file_base_name="model", ) @@ -154,9 +154,10 @@ def _run_for_config( quantized_model: BaseGPTQForCausalLM = model_class(pytorch_model, False, quantize_config) for key in ["outside_layer_modules", "inside_layer_modules", "layers_block_name"]: - if config[key]: + v = getattr(config, key, None) + if v: # user provided value - setattr(quantized_model, key, config[key]) + setattr(quantized_model, key, v) elif model_type in GPTQ_CAUSAL_LM_MODEL_MAP: # gptq supports the model type pass @@ -175,9 +176,9 @@ def _run_for_config( qlinear_class = dynamically_import_QuantLinear( use_triton=False, - desc_act=config["desc_act"], - group_size=config["group_size"], - bits=config["bits"], + desc_act=config.desc_act, + group_size=config.group_size, + bits=config.bits, disable_exllama=False, disable_exllamav2=True, ) @@ -231,7 +232,7 @@ def get_dataset( self, model: Union[HfModelHandler, PyTorchModelHandler], config: Dict[str, Any] ) -> List[Dict[str, Any]]: """Get the dataset for quantization.""" - data_config = config["data_config"] + data_config = config.data_config if not data_config and isinstance(model, HfModelHandler): data_config = self.get_calibration_data_config( model.model_name_or_path, trust_remote_code=model.get_load_kwargs().get("trust_remote_code", None) diff --git a/olive/passes/pytorch/lora.py b/olive/passes/pytorch/lora.py index eb7740b83..26ab16a81 100644 --- 
a/olive/passes/pytorch/lora.py +++ b/olive/passes/pytorch/lora.py @@ -15,12 +15,11 @@ from copy import deepcopy from functools import partial from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union import transformers from packaging import version -from olive.common.config_utils import ConfigBase from olive.common.hf.mappings import MODELS_TO_LORA_TARGET_MODULES_MAPPING from olive.common.hf.utils import get_peft_task_type_from_task from olive.common.pydantic_v1 import Field, validator @@ -32,6 +31,7 @@ from olive.model.config.hf_config import HfLoadKwargs from olive.passes import Pass from olive.passes.olive_pass import PassConfigParam +from olive.passes.pass_config import BasePassConfig from olive.passes.pytorch.train_utils import ( BaseHFTrainingArguments, count_trainable_parameters, @@ -168,7 +168,7 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon } @classmethod - def check_dependencies(cls, config: ConfigBase, is_qlora: bool = False): + def check_dependencies(cls, config: Type[BasePassConfig], is_qlora: bool = False): """Check dependencies for the pass.""" # bitsandbytes quantization only supported after transformers 4.30.0 if is_qlora and version.parse(transformers.__version__) < version.parse("4.30.0"): @@ -223,9 +223,7 @@ def collate_batch(batch: List[Dict], tokenizer: "PreTrainedTokenizer") -> Dict[s return new_batch @staticmethod - def get_datasets( - config: ConfigBase, - ) -> Tuple["Dataset", Optional["Dataset"]]: + def get_datasets(config: Type[BasePassConfig]) -> Tuple["Dataset", Optional["Dataset"]]: """Load training and evaluation datasets.""" # we return dataset.Dataset object since the trainer works better with it # load training dataset @@ -238,16 +236,14 @@ def get_datasets( return train_dataset, eval_dataset - def _run_for_config(self, model: HfModelHandler, config: Dict[str, Any], output_model_path: str) -> HfModelHandler: + def _run_for_config( + self, model: HfModelHandler, config: Type[BasePassConfig], output_model_path: str + ) -> HfModelHandler: return self._run_lora_training(model, config, output_model_path) def _run_lora_training( - self, model: HfModelHandler, config: Dict[str, Any], output_model_path: str, use_dora: bool = False + self, model: HfModelHandler, config: Type[BasePassConfig], output_model_path: str, use_dora: bool = False ) -> HfModelHandler: - # convert config to pass config class - # this will validate the config and convert to the correct types - config = self._config_class(**config) - # check dependencies self.check_dependencies(config) @@ -275,7 +271,9 @@ def _run_lora_training( pytorch_model, model.get_hf_tokenizer(), config, deepcopy(model), output_model_path ) - def load_base_pytorch_model(self, model_handler: HfModelHandler, config: ConfigBase, **kwargs) -> "PreTrainedModel": + def load_base_pytorch_model( + self, model_handler: HfModelHandler, config: Type[BasePassConfig], **kwargs + ) -> "PreTrainedModel": """Load a base PyTorch model for fine-tuning. :param model_handler: The input model handler. 
@@ -299,7 +297,7 @@ def load_base_pytorch_model(self, model_handler: HfModelHandler, config: ConfigB def init_adapters( self, model: "PreTrainedModel", - config: ConfigBase, + config: Type[BasePassConfig], *, task: Optional[str] = None, use_loftq: Optional[bool] = False, @@ -334,7 +332,7 @@ def init_adapters( def enable_lora( self, model: "PreTrainedModel", - config: ConfigBase, + config: Type[BasePassConfig], task: Optional[str] = None, use_dora: bool = False, adapter_path: Optional[str] = None, @@ -384,7 +382,7 @@ def train_and_save_new_model( self, model: "PeftModel", tokenizer: "PreTrainedTokenizer", - config: ConfigBase, + config: Type[BasePassConfig], output_model: HfModelHandler, output_model_path: str, ) -> HfModelHandler: @@ -506,7 +504,9 @@ def get_target_modules(model: HfModelHandler) -> Optional[List[str]]: return None @staticmethod - def get_peft_model(model: "PreTrainedModel", config: ConfigBase, config_kwargs: Dict = None) -> "PeftModel": + def get_peft_model( + model: "PreTrainedModel", config: Type[BasePassConfig], config_kwargs: Dict = None + ) -> "PeftModel": """Get the PEFT model for LoRA fine-tuning.""" from peft import LoraConfig, LoraRuntimeConfig, get_peft_model @@ -530,7 +530,9 @@ def get_peft_model(model: "PreTrainedModel", config: ConfigBase, config_kwargs: class DoRA(LoRA): """Run DoRA fine-tuning on a Hugging Face PyTorch model.""" - def _run_for_config(self, model: HfModelHandler, config: Dict[str, Any], output_model_path: str) -> HfModelHandler: + def _run_for_config( + self, model: HfModelHandler, config: Type[BasePassConfig], output_model_path: str + ) -> HfModelHandler: return self._run_lora_training(model, config, output_model_path, use_dora=True) @@ -590,7 +592,9 @@ class LoHa(LoRAVariant): """Run LoHa fine-tuning on a Hugging Face PyTorch model.""" @staticmethod - def get_peft_model(model: "PreTrainedModel", config: ConfigBase, config_kwargs: Dict = None) -> "PeftModel": + def get_peft_model( + model: "PreTrainedModel", config: Type[BasePassConfig], config_kwargs: Dict = None + ) -> "PeftModel": """Get the PEFT model for LoHa fine-tuning.""" from peft import LoHaConfig, get_peft_model @@ -614,9 +618,9 @@ def get_peft_model(model: "PreTrainedModel", config: ConfigBase, config_kwargs: return get_peft_model(model, config) @classmethod - def check_dependencies(cls, config: ConfigBase, is_qlora: bool = False): + def check_dependencies(cls, config: Type[BasePassConfig], is_qlora: bool = False): """Check dependencies for the pass.""" - super().check_dependencies(config) + super().check_dependencies(config, is_qlora=is_qlora) from peft import __version__ as peft_version @@ -651,7 +655,9 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon return config @staticmethod - def get_peft_model(model: "PreTrainedModel", config: ConfigBase, config_kwargs: Dict = None) -> "PeftModel": + def get_peft_model( + model: "PreTrainedModel", config: Type[BasePassConfig], config_kwargs: Dict = None + ) -> "PeftModel": """Get the PEFT model for LoKr fine-tuning.""" from peft import LoKrConfig, get_peft_model @@ -678,9 +684,9 @@ def get_peft_model(model: "PreTrainedModel", config: ConfigBase, config_kwargs: return get_peft_model(model, config) @classmethod - def check_dependencies(cls, config: ConfigBase, is_qlora: bool = False): + def check_dependencies(cls, config: Type[BasePassConfig], is_qlora: bool = False): """Check dependencies for the pass.""" - super().check_dependencies(config) + super().check_dependencies(config, is_qlora=is_qlora) from 
peft import __version__ as peft_version @@ -715,11 +721,9 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon config.update(super()._default_config(accelerator_spec)) return config - def _run_for_config(self, model: HfModelHandler, config: Dict[str, Any], output_model_path: str) -> HfModelHandler: - # convert config to pass config class - # this will validate the config and convert to the correct types - config = self._config_class(**config) - + def _run_for_config( + self, model: HfModelHandler, config: Type[BasePassConfig], output_model_path: str + ) -> HfModelHandler: # check dependencies self.check_dependencies(config, is_qlora=True) @@ -748,7 +752,7 @@ def _run_for_config(self, model: HfModelHandler, config: Dict[str, Any], output_ @abstractmethod def get_quant_model( - self, model: HfModelHandler, config: ConfigBase, output_model_path: str + self, model: HfModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> Tuple[HfModelHandler, "PreTrainedModel", Dict, List[str]]: """Get the model handler, LoRA model for QLoRA fine-tuning. @@ -785,7 +789,7 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon return config def get_quant_model( - self, model: HfModelHandler, config: ConfigBase, output_model_path: str + self, model: HfModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> Tuple[HfModelHandler, "PreTrainedModel", Dict, List[str]]: """Get the model handler, LoRA model for QLoRA fine-tuning. @@ -836,7 +840,7 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon return config @classmethod - def check_dependencies(cls, config: ConfigBase, is_qlora: bool = False): + def check_dependencies(cls, config: Type[BasePassConfig], is_qlora: bool = False): """Check dependencies for the pass.""" super().check_dependencies(config, is_qlora=is_qlora) @@ -847,7 +851,7 @@ def check_dependencies(cls, config: ConfigBase, is_qlora: bool = False): raise ImportError(f"Please install peft >= 0.7.0 to use {cls.__name__} pass.") def get_quant_model( - self, model: HfModelHandler, config: ConfigBase, output_model_path: str + self, model: HfModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> Tuple[HfModelHandler, "PreTrainedModel", Dict, List[str]]: """Get the model handler, LoRA model for QLoRA fine-tuning. 
diff --git a/olive/passes/pytorch/merge_adapter_weights.py b/olive/passes/pytorch/merge_adapter_weights.py index c53dd6d7d..c674c40b8 100644 --- a/olive/passes/pytorch/merge_adapter_weights.py +++ b/olive/passes/pytorch/merge_adapter_weights.py @@ -4,14 +4,14 @@ # -------------------------------------------------------------------------- import logging from copy import deepcopy -from typing import Any, Dict +from typing import Dict, Type import torch from olive.hardware.accelerator import AcceleratorSpec from olive.model import HfModelHandler from olive.passes import Pass -from olive.passes.pass_config import PassConfigParam +from olive.passes.pass_config import BasePassConfig, PassConfigParam from olive.passes.pytorch.common import inherit_hf_from_hf logger = logging.getLogger(__name__) @@ -25,7 +25,9 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon return {} @torch.no_grad() - def _run_for_config(self, model: HfModelHandler, config: Dict[str, Any], output_model_path: str) -> HfModelHandler: + def _run_for_config( + self, model: HfModelHandler, config: Type[BasePassConfig], output_model_path: str + ) -> HfModelHandler: if not model.adapter_path: raise RuntimeError( "No adapter path found in the model. Please check your input " diff --git a/olive/passes/pytorch/quantization_aware_training.py b/olive/passes/pytorch/quantization_aware_training.py index 6fdce4679..ae6c484ec 100755 --- a/olive/passes/pytorch/quantization_aware_training.py +++ b/olive/passes/pytorch/quantization_aware_training.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Union +from typing import Callable, Dict, Iterable, List, Type, Union from olive.common.config_utils import validate_config from olive.data.config import DataConfig @@ -11,7 +11,7 @@ from olive.model import PyTorchModelHandler from olive.passes import Pass from olive.passes.olive_pass import ParamCategory, PassConfigParam -from olive.passes.pass_config import get_user_script_data_config +from olive.passes.pass_config import BasePassConfig, get_user_script_data_config class QuantizationAwareTraining(Pass): @@ -93,26 +93,25 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon } def _run_for_config( - self, model: PyTorchModelHandler, config: Dict[str, Any], output_model_path: str + self, model: PyTorchModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> PyTorchModelHandler: from olive.passes.pytorch.qat_utils import QatTrainer - qat_trainer_config = self._config_class(**config) if Path(output_model_path).suffix != ".pt": output_model_path += ".pt" - if config["train_data_config"]: - qat_trainer_config.train_data_config = validate_config(config["train_data_config"], DataConfig) - if config["val_data_config"]: - qat_trainer_config.val_data_config = validate_config(config["val_data_config"], DataConfig) - if config["training_loop_func"]: - qat_trainer_config.training_loop_func = self._user_module_loader.load_object(config["training_loop_func"]) - if config["ptl_module"]: - qat_trainer_config.ptl_module = self._user_module_loader.load_object(config["ptl_module"]) - if config["ptl_data_module"]: - qat_trainer_config.ptl_data_module = self._user_module_loader.load_object(config["ptl_data_module"]) - if config["qconfig_func"]: - qat_trainer_config.qconfig_func = self._user_module_loader.load_object(config["qconfig_func"]) 
+ if config.train_data_config: + config.train_data_config = validate_config(config.train_data_config, DataConfig) + if config.val_data_config: + config.val_data_config = validate_config(config.val_data_config, DataConfig) + if config.training_loop_func: + config.training_loop_func = self._user_module_loader.load_object(config.training_loop_func) + if config.ptl_module: + config.ptl_module = self._user_module_loader.load_object(config.ptl_module) + if config.ptl_data_module: + config.ptl_data_module = self._user_module_loader.load_object(config.ptl_data_module) + if config.qconfig_func: + config.qconfig_func = self._user_module_loader.load_object(config.qconfig_func) - qat_trainer = QatTrainer(model, qat_trainer_config, output_model_path) + qat_trainer = QatTrainer(model, config, output_model_path) return qat_trainer.execute_local() diff --git a/olive/passes/pytorch/rotate.py b/olive/passes/pytorch/rotate.py index 40a04dcb3..d402d6ede 100644 --- a/olive/passes/pytorch/rotate.py +++ b/olive/passes/pytorch/rotate.py @@ -6,7 +6,7 @@ import tempfile from copy import deepcopy from functools import partial -from typing import Any, Dict, Iterable, Optional, Tuple, Union +from typing import Any, Dict, Iterable, Optional, Tuple, Type, Union import torch from torch import nn @@ -18,7 +18,7 @@ from olive.hardware.accelerator import AcceleratorSpec from olive.model import HfModelHandler from olive.passes import Pass -from olive.passes.pass_config import PassConfigParam +from olive.passes.pass_config import BasePassConfig, PassConfigParam from olive.passes.pytorch.common import inherit_hf_from_hf from olive.passes.pytorch.train_utils import ( BaseHFTrainingArguments, @@ -241,8 +241,10 @@ class QuaRot(RotateBase): """ @torch.no_grad() - def _run_for_config(self, model: HfModelHandler, config: Dict[str, Any], output_model_path: str) -> HfModelHandler: - model_wrapper, _, save_replacements = self.rotate_model(model, config["rotate_mode"], config["seed"]) + def _run_for_config( + self, model: HfModelHandler, config: Type[BasePassConfig], output_model_path: str + ) -> HfModelHandler: + model_wrapper, _, save_replacements = self.rotate_model(model, config.rotate_mode, config.seed) # save the model model_wrapper.save_model(output_model_path, replacements=save_replacements) @@ -328,20 +330,18 @@ def _run_for_config(self, model: HfModelHandler, config: Dict[str, Any], output_ from olive.passes.pytorch.sgdg import SGDG - training_args = HFTrainingArguments.parse_obj(config["training_args"] or {}) + training_args = HFTrainingArguments.parse_obj(config.training_args or {}) # rotate the model model_wrapper, rotation_params, save_replacements = self.rotate_model( - model, config["rotate_mode"], config["seed"], training_args + model, config.rotate_mode, config.seed, training_args ) # add activation quantization to the layer linear modules replace_submodules( model_wrapper.get_layers(False), RotateLinear, - partial( - ActQuantLinear, bits=config["a_bits"], symmetric=config["a_symmetric"], per_token=config["a_per_token"] - ), + partial(ActQuantLinear, bits=config.a_bits, symmetric=config.a_symmetric, per_token=config.a_per_token), ) save_replacements = [(ActQuantLinear, lambda x: x.linear), *save_replacements] diff --git a/olive/passes/pytorch/slicegpt.py b/olive/passes/pytorch/slicegpt.py index 4a4a2968a..05de9f51b 100644 --- a/olive/passes/pytorch/slicegpt.py +++ b/olive/passes/pytorch/slicegpt.py @@ -6,7 +6,7 @@ import json import logging import sys -from typing import Any, Dict, Union +from typing import Dict, Type, 
Union import torch from torch.utils.data import DataLoader, SubsetRandomSampler @@ -19,6 +19,7 @@ from olive.model.utils.path_utils import normalize_path_suffix from olive.passes import Pass from olive.passes.olive_pass import PassConfigParam +from olive.passes.pass_config import BasePassConfig from olive.passes.pytorch.common import inherit_pytorch_from_hf logger = logging.getLogger(__name__) @@ -77,7 +78,7 @@ def _default_config(accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigPa @torch.no_grad() def _run_for_config( - self, model: HfModelHandler, config: Dict[str, Any], output_model_path: str + self, model: HfModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> PyTorchModelHandler: if sys.version_info < (3, 10): raise ValueError("SliceGPT requires python3.10 or higher") @@ -90,10 +91,6 @@ def _run_for_config( model_handler = model model = None - # convert config to pass config class - # this will validate the config and convert to the correct types - config = self._config_class(**config) - model_adapter, _ = get_model_and_tokenizer(model_handler.model_name_or_path) model_handler.model = model_adapter.model model = model_handler.load_model() diff --git a/olive/passes/pytorch/sparsegpt.py b/olive/passes/pytorch/sparsegpt.py index 9ebfe46eb..067a84702 100644 --- a/olive/passes/pytorch/sparsegpt.py +++ b/olive/passes/pytorch/sparsegpt.py @@ -7,7 +7,7 @@ # https://arxiv.org/abs/2301.00774 # ------------------------------------------------------------------------- import logging -from typing import Any, Dict, List, Union +from typing import Dict, List, Type, Union import torch @@ -18,6 +18,7 @@ from olive.model import HfModelHandler from olive.passes import Pass from olive.passes.olive_pass import PassConfigParam +from olive.passes.pass_config import BasePassConfig from olive.passes.pytorch.common import inherit_hf_from_hf from olive.passes.pytorch.sparsegpt_utils import ( SparseGPTModule, @@ -92,22 +93,24 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon } @torch.no_grad() - def _run_for_config(self, model: HfModelHandler, config: Dict[str, Any], output_model_path: str) -> HfModelHandler: + def _run_for_config( + self, model: HfModelHandler, config: Type[BasePassConfig], output_model_path: str + ) -> HfModelHandler: model_type = model.model_attributes["model_type"] if model_type not in supported_models: raise ValueError(f"Unsupported model type: {model_type}. Supported types: {supported_models}") # get sparsity mode and parameters - if isinstance(config["sparsity"], float): - assert 0 <= config["sparsity"] <= 1, "Sparsity must be in [0,1]." - elif isinstance(config["sparsity"], list): - assert len(config["sparsity"]) == 2, "Sparsity must be a float or a list of two integers." - mode = "unstructured" if isinstance(config["sparsity"], float) else "structured" - sparsity = config["sparsity"] + if isinstance(config.sparsity, float): + assert 0 <= config.sparsity <= 1, "Sparsity must be in [0,1]." + elif isinstance(config.sparsity, list): + assert len(config.sparsity) == 2, "Sparsity must be a float or a list of two integers." 
+ mode = "unstructured" if isinstance(config.sparsity, float) else "structured" + sparsity = config.sparsity n, m = sparsity if mode == "structured" else [0, 0] # get device to use for computations - device = config["device"] + device = config.device if device == "auto": device = "cuda" if torch.cuda.is_available() else "cpu" logger.debug( @@ -115,7 +118,7 @@ def _run_for_config(self, model: HfModelHandler, config: Dict[str, Any], output_ ) # load_data - data_config = validate_config(config["data_config"], DataConfig) + data_config = validate_config(config.data_config, DataConfig) dataloader = data_config.to_data_container().create_dataloader() logger.debug("Data loaded. Number of batches: %d", len(dataloader)) @@ -138,8 +141,8 @@ def _run_for_config(self, model: HfModelHandler, config: Dict[str, Any], output_ outputs = torch.zeros_like(inputs) # prune layers - min_layer, max_layer = validate_min_max_layers(config["min_layer"], config["max_layer"], len(layers)) - layer_name_filter = config["layer_name_filter"] or [] + min_layer, max_layer = validate_min_max_layers(config.min_layer, config.max_layer, len(layers)) + layer_name_filter = config.layer_name_filter or [] if isinstance(layer_name_filter, str): layer_name_filter = [layer_name_filter] # loop over layers @@ -181,7 +184,7 @@ def handler(_, inputs, output): losses = {} for name, sparse_gpt_module in sparge_gpt_modules.items(): loss = sparse_gpt_module.prune( - mode, sparsity, n, m, blocksize=config["blocksize"], percdamp=config["percdamp"] + mode, sparsity, n, m, blocksize=config.blocksize, percdamp=config.percdamp ) losses[name] = loss sparse_gpt_module.free() diff --git a/olive/passes/pytorch/tensor_parallel.py b/olive/passes/pytorch/tensor_parallel.py index 1a28e4dbd..bd587f656 100644 --- a/olive/passes/pytorch/tensor_parallel.py +++ b/olive/passes/pytorch/tensor_parallel.py @@ -9,7 +9,7 @@ import multiprocessing from abc import abstractmethod from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Dict +from typing import TYPE_CHECKING, Callable, Dict, Type from olive.common.config_utils import ParamCategory from olive.common.pydantic_v1 import validator @@ -17,6 +17,7 @@ from olive.model import DistributedHfModelHandler, HfModelHandler from olive.passes import Pass from olive.passes.olive_pass import PassConfigParam +from olive.passes.pass_config import BasePassConfig from olive.passes.pytorch.common import inherit_distributed_hf_from_hf if TYPE_CHECKING: @@ -144,11 +145,11 @@ def _generate_one(params): return 1 # Return 1 for success. 
def _run_for_config( - self, model: HfModelHandler, config: Dict[str, Any], output_model_path: str + self, model: HfModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> DistributedHfModelHandler: import torch - world_size = int(config["world_size"]) + world_size = int(config.world_size) output_model_path = Path(output_model_path) output_model_path.mkdir(parents=True, exist_ok=True) @@ -163,7 +164,7 @@ def _run_for_config( for rank in range(world_size) ] - max_parallel_jobs = min(world_size, config["parallel_jobs"] or multiprocessing.cpu_count()) + max_parallel_jobs = min(world_size, config.parallel_jobs or multiprocessing.cpu_count()) if max_parallel_jobs <= 1: results = [PyTorchTensorParallel._generate_one(_) for _ in params] else: diff --git a/olive/passes/pytorch/torch_trt_conversion.py b/olive/passes/pytorch/torch_trt_conversion.py index efb955fe1..0d0aa5ca0 100644 --- a/olive/passes/pytorch/torch_trt_conversion.py +++ b/olive/passes/pytorch/torch_trt_conversion.py @@ -4,7 +4,7 @@ # -------------------------------------------------------------------------- import logging from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Dict, List, Type, Union import torch @@ -16,6 +16,7 @@ from olive.model import HfModelHandler, PyTorchModelHandler from olive.passes import Pass from olive.passes.olive_pass import PassConfigParam +from olive.passes.pass_config import BasePassConfig from olive.passes.pytorch.common import inherit_pytorch_from_hf from olive.passes.pytorch.sparsegpt_utils import get_layer_submodules, supported_models, validate_min_max_layers @@ -65,11 +66,10 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon @classmethod def validate_config( cls, - config: Dict[str, Any], + config: Type[BasePassConfig], accelerator_spec: AcceleratorSpec, - disable_search: Optional[bool] = False, ) -> bool: - if not super().validate_config(config, accelerator_spec, disable_search): + if not super().validate_config(config, accelerator_spec): return False # since the run will leverage the host device to move the model to device, @@ -81,7 +81,7 @@ def validate_config( @torch.no_grad() def _run_for_config( - self, model: HfModelHandler, config: Dict[str, Any], output_model_path: str + self, model: HfModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> PyTorchModelHandler: from olive.passes.pytorch.trt_utils import compile_trt_model @@ -94,7 +94,7 @@ def _run_for_config( device = "cuda" # load_data - data_config = validate_config(config["data_config"], DataConfig) + data_config = validate_config(config.data_config, DataConfig) first_batch = data_config.to_data_container().get_first_batch()[0] first_batch = tensor_data_to_device(first_batch, device=device) batch_size = first_batch["input_ids"].shape[0] @@ -105,7 +105,7 @@ def _run_for_config( # move model to device pytorch_model.to(device=device) # convert model to fp16 if needed - if config["float16"]: + if config.float16: pytorch_model = pytorch_model.to(dtype=torch.float16) # disable cache use_cache = pytorch_model.config.use_cache @@ -121,8 +121,8 @@ def _run_for_config( layers = model_wrapper.get_layers(False) # get layer information - min_layer, max_layer = validate_min_max_layers(config["min_layer"], config["max_layer"], len(layers)) - layer_name_filter = config["layer_name_filter"] or [] + min_layer, max_layer = validate_min_max_layers(config.min_layer, config.max_layer, len(layers)) + layer_name_filter = config.layer_name_filter or [] 
if isinstance(layer_name_filter, str): layer_name_filter = [layer_name_filter] # layer information storage diff --git a/olive/passes/qnn/context_binary_generator.py b/olive/passes/qnn/context_binary_generator.py index db8e286dd..7c1a4b5b5 100644 --- a/olive/passes/qnn/context_binary_generator.py +++ b/olive/passes/qnn/context_binary_generator.py @@ -6,14 +6,14 @@ import logging import platform from pathlib import Path -from typing import Any, Dict, Union +from typing import Dict, Type, Union from olive.common.constants import OS from olive.constants import ModelFileFormat from olive.hardware import AcceleratorSpec from olive.model import QNNModelHandler, SNPEModelHandler from olive.passes.olive_pass import Pass -from olive.passes.pass_config import PassConfigParam +from olive.passes.pass_config import BasePassConfig, PassConfigParam from olive.platform_sdk.qualcomm.runner import QNNSDKRunner logger = logging.getLogger(__name__) @@ -51,7 +51,7 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon def _run_for_config( self, model: Union[QNNModelHandler, SNPEModelHandler], - config: Dict[str, Any], + config: Type[BasePassConfig], output_model_path: str, ) -> QNNModelHandler: if platform.system() == OS.WINDOWS: @@ -60,7 +60,7 @@ def _run_for_config( main_cmd = "qnn-context-binary-generator" runner = QNNSDKRunner(use_dev_tools=True) - extra_args = config["extra_args"] or "" + extra_args = config.extra_args or "" model_arg = f"--model {model.model_path}" if model.model_file_format == ModelFileFormat.SNPE_DLC and "--dlc_path" not in extra_args: @@ -77,7 +77,7 @@ def _run_for_config( # TODO(trajep): find .so file in the same directory as the model output_model_path = Path(output_model_path).resolve() - binary_file = config["binary_file"] + binary_file = config.binary_file if not binary_file: binary_file = output_model_path.with_suffix(".serialized").name @@ -86,7 +86,7 @@ def _run_for_config( cmd_list = [ main_cmd, model_arg, - f"--backend {config['backend']}", + f"--backend {config.backend}", f"--output_dir {output_model_path}", f"--binary_file {binary_file}" if binary_file else "", extra_args, @@ -96,5 +96,5 @@ def _run_for_config( return QNNModelHandler( output_model_full_path, model_file_format=ModelFileFormat.QNN_SERIALIZED_BIN, - model_attributes={"backend": config["backend"]}, + model_attributes={"backend": config.backend}, ) diff --git a/olive/passes/qnn/conversion.py b/olive/passes/qnn/conversion.py index 66828900b..d988e73b1 100644 --- a/olive/passes/qnn/conversion.py +++ b/olive/passes/qnn/conversion.py @@ -5,7 +5,7 @@ import platform from pathlib import Path -from typing import Any, Dict, List, Union +from typing import Dict, List, Type, Union from olive.common.constants import OS from olive.constants import ModelFileFormat @@ -13,7 +13,7 @@ from olive.model import ONNXModelHandler, PyTorchModelHandler, QNNModelHandler, TensorFlowModelHandler from olive.model.utils import normalize_path_suffix from olive.passes.olive_pass import Pass -from olive.passes.pass_config import PassConfigParam +from olive.passes.pass_config import BasePassConfig, PassConfigParam from olive.platform_sdk.qualcomm.runner import QNNSDKRunner @@ -67,7 +67,7 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon def _run_for_config( self, model: Union[TensorFlowModelHandler, PyTorchModelHandler, ONNXModelHandler], - config: Dict[str, Any], + config: Type[BasePassConfig], output_model_path: str, ) -> QNNModelHandler: if isinstance(model, 
TensorFlowModelHandler): @@ -90,15 +90,15 @@ def _run_for_config( # get input dim from io_config input_dims = None - if config.get("input_dim"): - input_dims = map(str.split, config["input_dim"]) + if config.input_dim: + input_dims = map(str.split, config.input_dim) elif model.io_config: input_dims_tuple = zip(model.io_config["input_names"], model.io_config["input_shapes"]) input_dims = [[name, ",".join(map(str, shape))] for name, shape in input_dims_tuple] out_nodes = None - if config.get("out_node"): - out_nodes = config["out_node"] + if config.out_node: + out_nodes = config.out_node elif model.io_config: out_nodes = model.io_config["output_names"] @@ -117,8 +117,8 @@ def _run_for_config( if out_nodes: for o in out_nodes: cmd_list.extend(["--out_node", o]) - if config["extra_args"]: - cmd_list.extend(config["extra_args"].split()) + if config.extra_args: + cmd_list.extend(config.extra_args.split()) runner.run(cmd_list) return QNNModelHandler(output_model_path, model_file_format=ModelFileFormat.QNN_CPP) diff --git a/olive/passes/qnn/model_lib_generator.py b/olive/passes/qnn/model_lib_generator.py index 993541e12..85098223e 100644 --- a/olive/passes/qnn/model_lib_generator.py +++ b/olive/passes/qnn/model_lib_generator.py @@ -6,14 +6,14 @@ import logging import platform from pathlib import Path -from typing import Any, Dict +from typing import Dict, Type from olive.common.constants import OS from olive.constants import ModelFileFormat from olive.hardware import AcceleratorSpec from olive.model import QNNModelHandler from olive.passes.olive_pass import Pass -from olive.passes.pass_config import PassConfigParam +from olive.passes.pass_config import BasePassConfig, PassConfigParam from olive.platform_sdk.qualcomm.runner import QNNSDKRunner logger = logging.getLogger(__name__) @@ -48,7 +48,7 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon def _run_for_config( self, model: QNNModelHandler, - config: Dict[str, Any], + config: Type[BasePassConfig], output_model_path: str, ) -> QNNModelHandler: main_cmd = "qnn-model-lib-generator" diff --git a/olive/passes/snpe/conversion.py b/olive/passes/snpe/conversion.py index c398ac7cf..5f3f010ba 100644 --- a/olive/passes/snpe/conversion.py +++ b/olive/passes/snpe/conversion.py @@ -3,13 +3,13 @@ # Licensed under the MIT License. 
# -------------------------------------------------------------------------- from pathlib import Path -from typing import Any, Callable, Dict, List, Union +from typing import Callable, Dict, List, Type, Union from olive.common.pydantic_v1 import validator from olive.hardware.accelerator import AcceleratorSpec from olive.model import ONNXModelHandler, SNPEModelHandler, TensorFlowModelHandler from olive.passes.olive_pass import Pass -from olive.passes.pass_config import PassConfigParam +from olive.passes.pass_config import BasePassConfig, PassConfigParam from olive.platform_sdk.qualcomm.constants import InputLayout, InputType from olive.platform_sdk.qualcomm.snpe.tools.dev import get_dlc_io_config, to_dlc from olive.resource_path import LocalFile @@ -95,12 +95,12 @@ def _validators(cls) -> Dict[str, Callable]: def _run_for_config( self, model: Union[ONNXModelHandler, TensorFlowModelHandler], - config: Dict[str, Any], + config: Type[BasePassConfig], output_model_path: str, ) -> SNPEModelHandler: if Path(output_model_path).suffix != ".dlc": output_model_path += ".dlc" to_dlc(model.model_path, model.framework, config, output_model_path) - io_config = get_dlc_io_config(output_model_path, config["input_names"], config["output_names"]) + io_config = get_dlc_io_config(output_model_path, config.input_names, config.output_names) return SNPEModelHandler(model_path=LocalFile({"path": output_model_path}), **io_config) diff --git a/olive/passes/snpe/quantization.py b/olive/passes/snpe/quantization.py index 398cfed21..27158cf77 100644 --- a/olive/passes/snpe/quantization.py +++ b/olive/passes/snpe/quantization.py @@ -3,14 +3,14 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- from pathlib import Path -from typing import Any, Dict, List, Union +from typing import Dict, List, Type, Union from olive.common.config_utils import validate_config from olive.data.config import DataConfig from olive.hardware.accelerator import AcceleratorSpec from olive.model import SNPEModelHandler from olive.passes.olive_pass import Pass -from olive.passes.pass_config import PassConfigParam +from olive.passes.pass_config import BasePassConfig, PassConfigParam from olive.platform_sdk.qualcomm.snpe.tools.dev import quantize_dlc from olive.platform_sdk.qualcomm.utils.data_loader import FileListCommonDataLoader, FileListDataLoader from olive.resource_path import LocalFile @@ -63,12 +63,12 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon } def _run_for_config( - self, model: SNPEModelHandler, config: Dict[str, Any], output_model_path: str + self, model: SNPEModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> SNPEModelHandler: if Path(output_model_path).suffix != ".dlc": output_model_path += ".dlc" - data_config = validate_config(config["data_config"], DataConfig) + data_config = validate_config(config.data_config, DataConfig) dataloader = data_config.to_data_container().create_dataloader() # convert dataloader to FileListDataLoader if it is not already diff --git a/olive/passes/snpe/snpe_to_onnx.py b/olive/passes/snpe/snpe_to_onnx.py index bf64c957c..39dcab599 100644 --- a/olive/passes/snpe/snpe_to_onnx.py +++ b/olive/passes/snpe/snpe_to_onnx.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. 
# -------------------------------------------------------------------------- -from typing import Any, Callable, Dict +from typing import Callable, Dict, Type from olive.common.pydantic_v1 import validator from olive.hardware.accelerator import AcceleratorSpec @@ -10,7 +10,7 @@ from olive.model.utils import resolve_onnx_path from olive.passes.olive_pass import Pass from olive.passes.onnx.common import get_external_data_config, model_proto_to_olive_model -from olive.passes.pass_config import PassConfigParam +from olive.passes.pass_config import BasePassConfig, PassConfigParam from olive.platform_sdk.qualcomm.constants import SNPEDevice from olive.platform_sdk.qualcomm.snpe.tools.dev import dlc_to_onnx @@ -51,14 +51,12 @@ def _validators(cls) -> Dict[str, Callable]: } def _run_for_config( - self, model: SNPEModelHandler, config: Dict[str, Any], output_model_path: str + self, model: SNPEModelHandler, config: Type[BasePassConfig], output_model_path: str ) -> ONNXModelHandler: - config = self._config_class(**config) - output_model_path = resolve_onnx_path(output_model_path) # create a onnx model that wraps the dlc binary in a node onnx_model = dlc_to_onnx(model.model_path, config.dict(), **model.io_config) # save the model to the output path and return the model - return model_proto_to_olive_model(onnx_model, output_model_path, config.dict()) + return model_proto_to_olive_model(onnx_model, output_model_path, config) diff --git a/olive/systems/docker/docker_system.py b/olive/systems/docker/docker_system.py index 47d84d0cf..69e7e8468 100644 --- a/olive/systems/docker/docker_system.py +++ b/olive/systems/docker/docker_system.py @@ -361,8 +361,9 @@ def _run_container( def _create_data_mounts_for_pass(self, container_root_path: Path, the_pass: "Pass"): mounts = {} mount_strs = [] + config_dict = the_pass.config.dict() for param, _, category in the_pass.path_params: - param_val = the_pass.config.get(param) + param_val = config_dict.get(param) if category == ParamCategory.DATA and param_val: mount = str(container_root_path / param) mounts[param] = mount diff --git a/test/requirements-test.txt b/test/requirements-test.txt index ea8515ac0..a0d24429f 100644 --- a/test/requirements-test.txt +++ b/test/requirements-test.txt @@ -2,7 +2,8 @@ accelerate azure-ai-ml azure-identity azure-storage-blob -azureml-evaluate-mlflow>=0.0.60 +# azureml.evaluate.mlflow.hftransformers is deprecated in 0.0.66 and above +azureml-evaluate-mlflow>=0.0.60, <0.0.66 azureml-fsspec # Pin azureml-metrics[all] greater than 0.0.26 to avoid breaking change in azureml-evaluate-mlflow azureml-metrics[all]>=0.0.26 diff --git a/test/unit_test/engine/test_engine.py b/test/unit_test/engine/test_engine.py index 4e7293721..85ad2ea61 100644 --- a/test/unit_test/engine/test_engine.py +++ b/test/unit_test/engine/test_engine.py @@ -41,8 +41,7 @@ class TestEngine: def test_register(self, tmpdir): # setup - p = get_onnxconversion_pass() - name = p.__class__.__name__ + name = OnnxConversion.__name__ host = SystemConfig(type=SystemType.Local) evaluator_config = OliveEvaluatorConfig(metrics=[get_accuracy_metric(AccuracySubType.ACCURACY_SCORE)]) @@ -145,17 +144,19 @@ def test_run(self, mock_local_system, tmp_path): engine = Engine(**options) p_name = "converter" - p1, p1_run_config = get_onnxconversion_pass(ignore_pass_config=False, target_opset=13) - p2, p2_run_config = get_onnxconversion_pass(ignore_pass_config=False, target_opset=14) - p1_run_config = RunPassConfig(**p1_run_config) - p2_run_config = RunPassConfig(**p2_run_config) - 
engine.set_input_passes_configs({p_name: [p1_run_config, p2_run_config]}) - - p1_pass_config = p1.serialize_config(p1_run_config.config, check_object=True) - p2_pass_config = p2.serialize_config(p2_run_config.config, check_object=True) + p1: OnnxConversion = get_onnxconversion_pass(target_opset=13) + p2: OnnxConversion = get_onnxconversion_pass(target_opset=14) + engine.set_input_passes_configs( + { + p_name: [ + RunPassConfig.from_json(p1.to_json(check_object=True)), + RunPassConfig.from_json(p2.to_json(check_object=True)), + ] + } + ) model_ids = [ - engine.cache.get_output_model_id(p1.__class__.__name__, p1_pass_config, input_model_id), - engine.cache.get_output_model_id(p2.__class__.__name__, p2_pass_config, input_model_id), + engine.cache.get_output_model_id(p1.__class__.__name__, p1.config.to_json(), input_model_id), + engine.cache.get_output_model_id(p2.__class__.__name__, p2.config.to_json(), input_model_id), ] expected_res = { model_id: { @@ -257,9 +258,9 @@ def test_run_no_search(self, mock_local_system_init, tmp_path): } engine = Engine(**options) - _, p_config = get_onnxconversion_pass(ignore_pass_config=False, target_opset=13) - engine.register(OnnxConversion, config=p_config) accelerator_spec = DEFAULT_CPU_ACCELERATOR + p_config = OnnxConversion.generate_config(accelerator_spec, {"target_opset": 13}).dict() + engine.register(OnnxConversion, config=p_config) output_model_id = engine.cache.get_output_model_id( "OnnxConversion", p_config, model_config.get_model_id(), accelerator_spec @@ -331,7 +332,8 @@ def test_run_output_model(self, search_strategy, tmp_path): "evaluator": evaluator_config, } engine = Engine(**options) - _, p_config = get_onnxconversion_pass(ignore_pass_config=False, target_opset=13) + accelerator_spec = DEFAULT_CPU_ACCELERATOR + p_config = OnnxConversion.generate_config(accelerator_spec, {"target_opset": 13}).dict() engine.register(OnnxConversion, config=p_config) # output model to output_dir output_dir = tmp_path / "output_dir" @@ -339,7 +341,7 @@ def test_run_output_model(self, search_strategy, tmp_path): # execute engine.run( model_config, - [DEFAULT_CPU_ACCELERATOR], + [accelerator_spec], output_dir=output_dir, ) diff --git a/test/unit_test/evaluator/test_olive_evaluator.py b/test/unit_test/evaluator/test_olive_evaluator.py index 1a29b96e0..9618f6056 100644 --- a/test/unit_test/evaluator/test_olive_evaluator.py +++ b/test/unit_test/evaluator/test_olive_evaluator.py @@ -464,7 +464,7 @@ def test_is_accuracy_drop_tolerance(self, metric_args, is_accuracy_drop_toleranc def test_valid_custom_type_validation(self, registry_get_mock, import_user_module_mock): registry_get_mock.return_value = MagicMock() OliveEvaluatorConfig.from_json({"type": "test_evaluator"}) - registry_get_mock.called_once_with("test_evaluator") + registry_get_mock.assert_called_once_with("test_evaluator") @patch("olive.common.import_lib.import_user_module") @patch("olive.evaluator.registry.Registry.get") @@ -474,4 +474,4 @@ def test_invalid_custom_type_validation(self, registry_get_mock, import_user_mod with pytest.raises(ValidationError): OliveEvaluatorConfig.from_json({"type": "test_evaluator"}) - registry_get_mock.called_once_with("test_evaluator") + registry_get_mock.assert_called_once_with("test_evaluator") diff --git a/test/unit_test/passes/common/test_user_script.py b/test/unit_test/passes/common/test_user_script.py index c1460dea4..671a3289c 100644 --- a/test/unit_test/passes/common/test_user_script.py +++ b/test/unit_test/passes/common/test_user_script.py @@ -10,4 +10,4 @@ class 
TestUserScriptConfig: def test_no_config(self): config = OrtSessionParamsTuning.generate_config(DEFAULT_CPU_ACCELERATOR, disable_search=True) assert config - assert OrtSessionParamsTuning.validate_config(config, DEFAULT_CPU_ACCELERATOR, disable_search=True) + assert OrtSessionParamsTuning.validate_config(config, DEFAULT_CPU_ACCELERATOR) diff --git a/test/unit_test/passes/onnx/test_optimum_conversion.py b/test/unit_test/passes/onnx/test_optimum_conversion.py index aecda404b..0a0b53e54 100644 --- a/test/unit_test/passes/onnx/test_optimum_conversion.py +++ b/test/unit_test/passes/onnx/test_optimum_conversion.py @@ -86,13 +86,13 @@ def test_optimum_configs(config, is_valid, tmp_path): output_folder = tmp_path if not is_valid: - assert p.validate_config(config, None) is False + assert p.validate_config(p.config, None) is False with pytest.raises( ValueError, match="FP16 export is supported only when exporting on GPU. Please pass the option `--device cuda`.", ): p.run(input_model, output_folder) else: - assert p.validate_config(config, None) + assert p.validate_config(p.config, None) is True onnx_model = p.run(input_model, output_folder) assert Path(onnx_model.model_path).exists() diff --git a/test/unit_test/passes/onnx/test_session_params_tuning.py b/test/unit_test/passes/onnx/test_session_params_tuning.py index c47ed4512..fb3b58f02 100644 --- a/test/unit_test/passes/onnx/test_session_params_tuning.py +++ b/test/unit_test/passes/onnx/test_session_params_tuning.py @@ -69,14 +69,14 @@ def test_ort_session_params_tuning_with_customized_configs(mock_run, config): # assert if "providers_list" not in config: assert ( - mock_run.call_args.args[1]["providers_list"] == "CPUExecutionProvider" + mock_run.call_args.args[1].providers_list == "CPUExecutionProvider" ), "providers_list is not set correctly as ['CPUExecutionProvider'] by default when user does not specify it" if "device" not in config: assert ( - mock_run.call_args.args[1]["device"] == "cpu" + mock_run.call_args.args[1].device == "cpu" ), "device is not set correctly as cpu by default when user does not specify it" for k, v in config.items(): - assert mock_run.call_args.args[1][k] == v, f"{k} is not set correctly as {v}" + assert getattr(mock_run.call_args.args[1], k) == v, f"{k} is not set correctly as {v}" @pytest.mark.parametrize( diff --git a/test/unit_test/passes/onnx/test_transformer_optimization.py b/test/unit_test/passes/onnx/test_transformer_optimization.py index 52a7c6819..423a4b565 100644 --- a/test/unit_test/passes/onnx/test_transformer_optimization.py +++ b/test/unit_test/passes/onnx/test_transformer_optimization.py @@ -4,7 +4,6 @@ # -------------------------------------------------------------------------- import logging import shutil -from copy import deepcopy from test.unit_test.utils import ONNX_MODEL_PATH, get_onnx_model from unittest.mock import MagicMock, patch @@ -23,7 +22,7 @@ def test_fusion_options(): config = {"model_type": "bart", "optimization_options": {"use_multi_head_attention": True}} config = OrtTransformersOptimization.generate_config(DEFAULT_CPU_ACCELERATOR, config, disable_search=True) transformer_optimization = OrtTransformersOptimization(DEFAULT_CPU_ACCELERATOR, config, True) - run_config = deepcopy(config) + run_config = config.dict() del ( run_config["float16"], run_config["input_int32"], @@ -76,7 +75,7 @@ def test_invalid_ep_config(use_gpu, fp16, accelerator_spec, mock_inferece_sessio config = {"model_type": "bert", "use_gpu": use_gpu, "float16": fp16} config = 
OrtTransformersOptimization.generate_config(accelerator_spec, config, disable_search=True) p = OrtTransformersOptimization(accelerator_spec, config, True) - is_pruned = not p.validate_config(config, accelerator_spec, disable_search=True) + is_pruned = not p.validate_config(config, accelerator_spec) if accelerator_spec.execution_provider == "CPUExecutionProvider": if fp16 and use_gpu: diff --git a/test/unit_test/systems/python_environment/test_python_environment_system.py b/test/unit_test/systems/python_environment/test_python_environment_system.py index 04efb6233..4087c6247 100644 --- a/test/unit_test/systems/python_environment/test_python_environment_system.py +++ b/test/unit_test/systems/python_environment/test_python_environment_system.py @@ -154,7 +154,6 @@ def test_run_pass(self, mock_model_config_parse_obj, mock__run_command): dummy_config = dummy_pass_config["config"] expected_pass_config = {"type": "DummyPass", "config": dummy_config} the_pass.to_json.return_value = dummy_pass_config - the_pass.serialize_config.return_value = dummy_config # mock return value mock_return_value = {"dummy_output_model_key": "dummy_output_model_value"} diff --git a/test/unit_test/utils.py b/test/unit_test/utils.py index e4f000796..0e83bfd44 100644 --- a/test/unit_test/utils.py +++ b/test/unit_test/utils.py @@ -4,7 +4,7 @@ # -------------------------------------------------------------------------- import os from pathlib import Path -from typing import Any, Dict, Tuple, Type, Union +from typing import Type from unittest.mock import MagicMock import numpy as np @@ -236,14 +236,11 @@ def get_throughput_metric(*lat_subtype, user_config=None): ) -def get_onnxconversion_pass( - ignore_pass_config=True, target_opset=13 -) -> Union[Type[Pass], Tuple[Type[Pass], Dict[str, Any]]]: +def get_onnxconversion_pass(target_opset=13) -> Type[Pass]: from olive.passes.onnx.conversion import OnnxConversion onnx_conversion_config = {"target_opset": target_opset} - p = create_pass_from_dict(OnnxConversion, onnx_conversion_config) - return p if ignore_pass_config else (p, p.to_json(check_object=True)) + return create_pass_from_dict(OnnxConversion, onnx_conversion_config) def get_onnx_dynamic_quantization_pass(disable_search=False):