From aa3cc70d3c7b0122a96df52170df9d51bae428f5 Mon Sep 17 00:00:00 2001 From: Xiaoyu Date: Wed, 12 Feb 2025 11:35:57 +0000 Subject: [PATCH] Add extract adapters support for dora and loha --- .../model-opt-and-transform/onnx.md | 2 + olive/cli/generate_adapter.py | 10 +- olive/passes/onnx/common.py | 52 ++- olive/passes/onnx/extract_adapters.py | 422 +++++++++++++----- .../passes/onnx/test_extract_adapters.py | 155 ++++--- 5 files changed, 465 insertions(+), 176 deletions(-) diff --git a/docs/source/how-to/configure-workflows/model-opt-and-transform/onnx.md b/docs/source/how-to/configure-workflows/model-opt-and-transform/onnx.md index 20788cb7b..b93fa7cba 100644 --- a/docs/source/how-to/configure-workflows/model-opt-and-transform/onnx.md +++ b/docs/source/how-to/configure-workflows/model-opt-and-transform/onnx.md @@ -278,6 +278,7 @@ for an example implementation of `"user_script.py"` and `"calib_data_config/data LoRA, QLoRA and related techniques allow us to fine-tune a pre-trained model by adding a small number of trainable matrices called adapters. The same base model can be used for multiple tasks by adding different adapters for each task. To support using multiple adapters with the same optimized onnx model, the `ExtractAdapters` pass extracts the adapter weights from the model and saves them to a separate file. The model graph is then modified in one of the following ways: - Adapter weights are set as external tensors pointing to a non-existent file. The onnx model is thus invalid by itself as it cannot be loaded. In order to create an inference session using this model, the adapter weights must be added to a session options object using `add_initializer` or `add_external_initializers`. - Adapter weights are converted into model inputs. The onnx model is valid. During inference, the adapter weights must be provided as part of the inputs. We call them constant inputs here since these weights don't change between runs when using one set of adapters. + - The `ExtractAdapters` pass also supports `DoRA` and `LoHa` adapters. Add `adapter_type` to the configuration to specify the adapter type. The default value is `lora`. ### Example Configuration a. As external initializers ```json { "type": "ExtractAdapters", + "adapter_type": "dora", "make_inputs": false } ``` diff --git a/olive/cli/generate_adapter.py b/olive/cli/generate_adapter.py index f28555c53..a28bd60a0 100644 --- a/olive/cli/generate_adapter.py +++ b/olive/cli/generate_adapter.py @@ -18,6 +18,7 @@ update_shared_cache_options, ) from olive.common.utils import WeightsFileFormat, set_nested_dict_value +from olive.passes.onnx.common import AdapterType class GenerateAdapterCommand(BaseOliveCLICommand): @@ -29,7 +30,13 @@ def register_subcommand(parser: ArgumentParser): # Model options add_input_model_options(sub_parser, enable_onnx=True, default_output_path="optimized-model") - + sub_parser.add_argument( + "--adapter_type", + type=AdapterType, + default=AdapterType.LORA, + choices=[el.value for el in AdapterType], + help=f"Type of adapters to extract. 
Default is {AdapterType.LORA}.", + ) sub_parser.add_argument( "--adapter_format", type=str, @@ -56,6 +63,7 @@ def _get_run_config(self, tempdir: str) -> Dict: to_replace = [ ("input_model", input_model_config), (("passes", "e", "save_format"), self.args.adapter_format), + (("passes", "e", "adapter_type"), self.args.adapter_type), ("output_dir", self.args.output_path), ("log_severity_level", self.args.log_level), ] diff --git a/olive/passes/onnx/common.py b/olive/passes/onnx/common.py index 936dc5f67..3e4fd1d1a 100644 --- a/olive/passes/onnx/common.py +++ b/olive/passes/onnx/common.py @@ -9,6 +9,7 @@ import onnx +from olive.common.utils import StrEnumBase from olive.model import ONNXModelHandler from olive.passes.onnx.onnx_dag import OnnxDAG from olive.passes.pass_config import PassConfigParam @@ -17,6 +18,12 @@ logger = logging.getLogger(__name__) +class AdapterType(StrEnumBase): + LORA = "lora" + DORA = "dora" + LOHA = "loha" + + def get_external_data_config() -> Dict[str, PassConfigParam]: return { "save_as_external_data": PassConfigParam( @@ -192,18 +199,53 @@ def model_proto_to_olive_model( for name in ["default_0", "default_0_1", "default", "default_1", "lora_A", "lora_B"] for matmul in ["MatMul", "MatMul_Q4"] ] +LOHA_NAME_PATTERNS = [f".*[./]{name}[./]default" for name in ["hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b"]] + +DORA_NAME_PATTERNS = [ + f".*{pattern}$" + for pattern in [ + "default/Add", + "default/default/MatMul", + "default/default_1/MatMul", + "default/default/MatMul_Q4", + "default/default_1/MatMul_Q4", + ] +] # TODO(jambayk): considering matching by subgraph pattern, more involved but more reliable -def model_has_adapters(model_path: Union[str, Path]) -> bool: +def model_has_adapters(model_path: Union[str, Path], adapter_type: AdapterType = AdapterType.LORA) -> bool: """Check if the model has adapters. :param model_path: The path to the model. :return: True if the model has adapters, False otherwise. 
""" dag = OnnxDAG(onnx.load(model_path, load_external_data=False)) - for node_name in dag.get_node_names(): - op_type = dag.get_node_op_type(node_name) - if op_type in {"MatMul", "MatMulNBits"} and any(re.match(pattern, node_name) for pattern in LORA_NAME_PATTERNS): - return True + if adapter_type == AdapterType.LOHA and is_loha_model(dag): + return True + else: + for node_name in dag.get_node_names(): + op_type = dag.get_node_op_type(node_name) + if (adapter_type == AdapterType.LORA and is_lora_node(op_type, node_name)) or ( + adapter_type == AdapterType.DORA and is_dora_node(op_type, node_name) + ): + return True + return False + + +def is_dora_node(op_type: str, node_name: str) -> bool: + return op_type in {"MatMul", "MatMulNBits", "Add"} and any( + re.match(pattern, node_name) for pattern in DORA_NAME_PATTERNS + ) + + +def is_lora_node(op_type: str, node_name: str) -> bool: + return op_type in {"MatMul", "MatMulNBits"} and any(re.match(pattern, node_name) for pattern in LORA_NAME_PATTERNS) + + +def is_loha_model(dag: OnnxDAG) -> bool: + for graph in dag.graphs: + for initializer in graph.initializer: + if any(re.match(pattern, initializer.name) for pattern in LOHA_NAME_PATTERNS): + return True return False diff --git a/olive/passes/onnx/extract_adapters.py b/olive/passes/onnx/extract_adapters.py index f290a8371..9ab3e09d0 100644 --- a/olive/passes/onnx/extract_adapters.py +++ b/olive/passes/onnx/extract_adapters.py @@ -6,7 +6,7 @@ import re from copy import deepcopy from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any, Dict, List, Set import numpy as np import onnx @@ -16,7 +16,14 @@ from olive.model import ONNXModelHandler from olive.model.utils import resolve_onnx_path from olive.passes import Pass -from olive.passes.onnx.common import LORA_NAME_PATTERNS, get_external_data_config, model_proto_to_olive_model +from olive.passes.onnx.common import ( + DORA_NAME_PATTERNS, + LOHA_NAME_PATTERNS, + LORA_NAME_PATTERNS, + AdapterType, + get_external_data_config, + model_proto_to_olive_model, +) from olive.passes.onnx.onnx_dag import OnnxDAG from olive.passes.pass_config import PassConfigParam @@ -39,6 +46,11 @@ class ExtractAdapters(Pass): @classmethod def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]: config = { + "adapter_type": PassConfigParam( + type_=AdapterType, + default_value=AdapterType.LORA, + description=f"Type of adapter to extract. 
Valid values are {AdapterType.__members__.values()}.", + ), "make_inputs": PassConfigParam( type_=bool, default_value=True, @@ -85,108 +97,21 @@ def _run_for_config( # dictionary to store adapter weights weights = {} - # keep track of float and quantized modules - float_modules = set() - quant_modules = set() - - # nodes to remove at the end - nodes_to_remove = set() - for node_name in dag.get_node_names(): - op_type = dag.get_node_op_type(node_name) - if op_type not in {"MatMul", "MatMulNBits"} or not any( - re.match(pattern, node_name) for pattern in LORA_NAME_PATTERNS - ): - # not a lora module - continue - - # new name for the float weight - new_weight_name = self._create_new_weight_name(node_name) - # new names for quantized weight and parameters - # zero point is optional if symmetric - quantized_suffices = [".quant.weight", ".quant.scale", ".quant.zero_point"] - new_quantized_names = [new_weight_name.replace(".weight", suffix) for suffix in quantized_suffices] - - if op_type == "MatMul": - # float or QDQ quantized - # original weight name - old_weight_name = dag.get_node_inputs(node_name)[1] - - if dag.is_input(old_weight_name): - # nothing to do here - continue - elif dag.is_initializer(old_weight_name): - # weight is an float initializer - # create initializer with new weight name - self._externalize_initializer(dag, weights, old_weight_name, new_weight_name) - - # change input to the new name - dag.replace_node_input(node_name, old_weight_name, new_weight_name) - - # add the module to the float modules - float_modules.add(new_weight_name.replace(".weight", "")) - elif dag.get_node_op_type(dag.get_producer(old_weight_name)) == "DequantizeLinear": - # weight is QDQ quantized - # get the dequantize node - old_dequantize_name = dag.get_producer(old_weight_name) - old_dequantize_node = dag.get_node(old_dequantize_name) - - # zero point is optional so we keep track of used inputs - used_inputs = [] - # create new initializers for the dequantize node - for old_input, new_input in zip(old_dequantize_node.inputs, new_quantized_names): - self._externalize_initializer(dag, weights, old_input, new_input) - used_inputs.append(new_input) - - # create a new dequantize node - # NOTE: We could directly modify the original dequantize node but this assumes that the dequantize - # node is not used elsewhere - # this cannot be guaranteed (for instance, if the float model has lora modules with same weights, - # they might all share the same dequantize node) - new_dequantize_proto = onnx.NodeProto() - new_dequantize_proto.CopyFrom(old_dequantize_node.proto) - # change node name - new_dequantize_proto.name = new_weight_name.replace("weight", "dequantize") - # change input names - for i, new_input in enumerate(used_inputs): - new_dequantize_proto.input[i] = new_input - # change output name - new_dequantize_proto.output[0] = new_weight_name - - # add new dequantize node - dag.add_node(new_dequantize_proto, old_dequantize_node.graph_idx) - - # replace input to the new name - dag.replace_node_input(node_name, old_weight_name, new_weight_name) - - # add old dequantize node to remove - nodes_to_remove.add(old_dequantize_name) - - # add the module to the quant modules - quant_modules.add(new_weight_name.replace(".weight", ".quant")) - elif op_type == "MatMulNBits": - # weight is Nbits quantized - # create empty initializers and change node inputs - for old_input, new_input in zip(dag.get_node_inputs(node_name)[1:], new_quantized_names): - self._externalize_initializer(dag, weights, old_input, new_input) - 
dag.replace_node_input(node_name, old_input, new_input) - # add the module to the quant modules - quant_modules.add(new_weight_name.replace(".weight", ".quant")) + if config["adapter_type"] == AdapterType.LORA: + weights = self._extract_adapter(dag, config, adapter_type=AdapterType.LORA) + elif config["adapter_type"] == AdapterType.DORA: + weights = self._extract_adapter(dag, config, adapter_type=AdapterType.DORA) + elif config["adapter_type"] == AdapterType.LOHA: + weights = self._extract_loha_adapter(dag, config) + else: + raise ValueError(f"Unsupported adapter type: {config['adapter_type']}") if not weights: - logger.info("No lora modules found in the model. Returning the original model.") + logger.info("No %s modules found in the model. Returning the original model.", config["adapter_type"]) return model - # remove old dequantize nodes - for node_name in nodes_to_remove: - dag.remove_node(node_name) - if config["make_inputs"]: - if quant_modules and config["dynamic_lora_r"]: - # MatMulNBits has static K,N dimensions which are set as attributes - # No use case for DequantizeLinear with dynamic lora_r - logger.info("Quantized modules do not support dynamic_lora_r. Ignoring.") - # create inputs for the weights for weight_name in weights: dag.convert_initializer_to_input(weight_name) @@ -220,22 +145,287 @@ def _run_for_config( output_model.model_attributes["constant_inputs"] = weights_info return output_model + def _extract_loha_adapter(self, dag: OnnxDAG, config: Dict[str, Any]): + """Extract LoHa adapter weights from all graphs in the ONNX model. + + This version supports both normal float initializers and QDQ (DequantizeLinear) chains. + + LoHa training adds 4 trainable initializers: + hada_w1_a.default + hada_w1_b.default -> MatMul -> ... + hada_w2_a.default + hada_w2_b.default -> MatMul -> ... + + Quantization (MatMulNBits) quantizes b initializers: + hada_w1_a.default + hada_w1_b.default_Q4 + hada_w1_b.default_scales -> MatMulNBits -> ... + hada_w2_a.default + hada_w2_b.default_Q4 + hada_w2_b.default_scales -> MatMulNBits -> ... + + QDQ quantizes b initializers: + DequantizeLinear (x + x_scale + x_zero_point) + hada_w1_a.default -> MatMul -> ... + DequantizeLinear (x + x_scale + x_zero_point) + hada_w2_a.default -> MatMul -> ... 
+ """ + weights = {} + nodes_to_remove = set() + float_modules = set() + quant_modules = set() + + for graph in dag.graphs: + initializers_to_process = list(graph.initializer) + + for initializer in initializers_to_process: + old_initializer_name = initializer.name + + if any(re.match(pattern, initializer.name) for pattern in LOHA_NAME_PATTERNS): + new_initializer_name = self._create_new_weight_name(old_initializer_name, AdapterType.LOHA) + consumer_name = dag.get_consumers(old_initializer_name)[0] + if dag.get_node_op_type(consumer_name) == "MatMulNBits": + new_initializer_name = new_initializer_name + ".quant" + quant_modules.add(new_initializer_name) + else: + float_modules.add(new_initializer_name.replace(".weight", "")) + self._process_initializer(dag, consumer_name, weights, old_initializer_name, new_initializer_name) + + # check if the 2nd weight is quantized + node_inputs = dag.get_node_inputs(consumer_name) + if len(node_inputs) < 2: + continue + sec_weight_name = dag.get_node_inputs(consumer_name)[1] + if dag.is_initializer(sec_weight_name): + continue + sec_weight_new_name = sec_weight_name + ".weight" + producer_op_name = dag.get_producer(sec_weight_name) + producer_op_type = dag.get_node_op_type(producer_op_name) + + if producer_op_type == "DequantizeLinear": + quant_suffixes = [".quant.weight", ".quant.scale", ".quant.zero_point"] + new_quant_names = [sec_weight_new_name.replace(".weight", suf) for suf in quant_suffixes] + self._process_dequantizelinear( + dag, + consumer_name, + weights, + sec_weight_name, + sec_weight_new_name, + new_quant_names, + nodes_to_remove, + ) + for node_name in nodes_to_remove: + dag.remove_node(node_name) + + if config["make_inputs"] and quant_modules and config["dynamic_lora_r"]: + # No use case for DequantizeLinear with dynamic lora_r + logger.info("Quantized modules do not support dynamic_lora_r. Ignoring.") + + return weights + + def _extract_adapter(self, dag: OnnxDAG, config: Dict[str, Any], adapter_type: AdapterType = AdapterType.LORA): + """Extract adapter weights for either LoRA or Dora from an ONNX model. + + LoRA: + output + default (lora_A) -> MatMul -> ... + output + default_1 (lora_B) -> MatMul -> ... + + DoRA: + Besides LoRA A and LoRA B, DoRA also has a learnable magnitude vector M (dora_M): + W' = mV + dV = mV + AB + AB + dora_M -> Add -> ... 
+ """ + # dictionary to store adapter weights + weights = {} + # keep track of float and quantized modules + float_modules = set() + quant_modules = set() + + # nodes to remove at the end + nodes_to_remove = set() + + # lora and dora modules have different name patterns and valid ops + patterns = None + valid_ops = None + if adapter_type == AdapterType.LORA: + patterns = LORA_NAME_PATTERNS + valid_ops = {"MatMul", "MatMulNBits"} + if adapter_type == AdapterType.DORA: + patterns = DORA_NAME_PATTERNS + valid_ops = {"MatMul", "MatMulNBits", "Add"} + + for node_name in dag.get_node_names(): + op_type = dag.get_node_op_type(node_name) + if op_type not in valid_ops or not any(re.match(pattern, node_name) for pattern in patterns): + # not a lora module + continue + + # new name for the float weight + new_weight_name = self._create_new_weight_name(node_name, adapter_type) + # new names for quantized weight and parameters + # zero point is optional if symmetric + quantized_suffices = [".quant.weight", ".quant.scale", ".quant.zero_point"] + new_quantized_names = [new_weight_name.replace(".weight", suffix) for suffix in quantized_suffices] + + if op_type == "Add": + self._process_add_node(dag, node_name, weights, new_weight_name, float_modules) + elif op_type == "MatMul": + self._process_matmul_node( + dag, + node_name, + weights, + new_weight_name, + float_modules, + quant_modules, + new_quantized_names, + nodes_to_remove, + ) + elif op_type == "MatMulNBits": + self._process_matmulnbits_node( + dag, node_name, weights, new_weight_name, quant_modules, new_quantized_names + ) + + for node_name in nodes_to_remove: + dag.remove_node(node_name) + + if config["make_inputs"] and quant_modules and config["dynamic_lora_r"]: + # MatMulNBits has static K,N dimensions which are set as attributes + # No use case for DequantizeLinear with dynamic lora_r + logger.info("Quantized modules do not support dynamic_lora_r. 
Ignoring.") + + return weights + + def _process_add_node( + self, dag: OnnxDAG, node_name: str, weights: Dict[str, Any], new_weight_name: str, float_modules: Set[str] + ): + old_weight_name = dag.get_node_inputs(node_name)[0] + self._process_initializer(dag, node_name, weights, old_weight_name, new_weight_name) + # add the module to the float modules + float_modules.add(new_weight_name.replace(".weight", "")) + + def _process_matmul_node( + self, + dag: OnnxDAG, + node_name: str, + weights: Dict[str, Any], + new_weight_name: str, + float_modules: Set[str], + quant_modules: Set[str], + new_quantized_names: List[str], + nodes_to_remove: Set[str], + ): + # float or QDQ quantized + # original weight name + old_weight_name = dag.get_node_inputs(node_name)[1] + if dag.is_input(old_weight_name): + # nothing to do here + return + if dag.is_initializer(old_weight_name): + self._process_initializer(dag, node_name, weights, old_weight_name, new_weight_name) + + # add the module to the float modules + float_modules.add(new_weight_name.replace(".weight", "")) + elif dag.get_node_op_type(dag.get_producer(old_weight_name)) == "DequantizeLinear": + self._process_dequantizelinear( + dag, node_name, weights, old_weight_name, new_weight_name, new_quantized_names, nodes_to_remove + ) + + # add the module to the quant modules + quant_modules.add(new_weight_name.replace(".weight", ".quant")) + + def _process_matmulnbits_node( + self, + dag: OnnxDAG, + node_name: str, + weights: Dict[str, Any], + new_weight_name: str, + quant_modules: Set[str], + new_quantized_names: List[str], + ): + # weight is Nbits quantized + # create empty initializers and change node inputs + for old_input, new_input in zip(dag.get_node_inputs(node_name)[1:], new_quantized_names): + self._process_initializer(dag, node_name, weights, old_input, new_input) + + # add the module to the quant modules + quant_modules.add(new_weight_name.replace(".weight", ".quant")) + + def _process_initializer( + self, dag: OnnxDAG, node_name: str, weights: Dict[str, Any], old_input: str, new_input: str + ): + # create initializer with new weight name + self._externalize_initializer(dag, weights, old_input, new_input) + # change input to the new name + dag.replace_node_input(node_name, old_input, new_input) + + def _process_dequantizelinear( + self, + dag: OnnxDAG, + node_name: str, + weights: Dict[str, Any], + old_weight_name: str, + new_weight_name: str, + new_quantized_names: List[str], + nodes_to_remove: Set[str], + ): + # weight is QDQ quantized + # get the dequantize node + old_dequantize_name = dag.get_producer(old_weight_name) + old_dequantize_node = dag.get_node(old_dequantize_name) + + # zero point is optional so we keep track of used inputs + used_inputs = [] + # create new initializers for the dequantize node + for old_input, new_input in zip(old_dequantize_node.inputs, new_quantized_names): + self._externalize_initializer(dag, weights, old_input, new_input) + used_inputs.append(new_input) + + # create a new dequantize node + # NOTE: We could directly modify the original dequantize node but this assumes that the dequantize + # node is not used elsewhere + # this cannot be guaranteed (for instance, if the float model has lora modules with same weights, + # they might all share the same dequantize node) + new_dequantize_proto = onnx.NodeProto() + new_dequantize_proto.CopyFrom(old_dequantize_node.proto) + # change node name + new_dequantize_proto.name = new_weight_name.replace("weight", "dequantize") + # change input names + for i, new_input in 
enumerate(used_inputs): + new_dequantize_proto.input[i] = new_input + # change output name + new_dequantize_proto.output[0] = new_weight_name + + # add new dequantize node + dag.add_node(new_dequantize_proto, old_dequantize_node.graph_idx) + + # replace input to the new name + dag.replace_node_input(node_name, old_weight_name, new_weight_name) + + # add old dequantize node to remove + nodes_to_remove.add(old_dequantize_name) + @staticmethod - def _create_new_weight_name(old_name: str) -> str: + def _create_new_weight_name(old_name: str, adapter_type: AdapterType = AdapterType.LORA) -> str: """Create new weight name based on old name. - The new weight name is of the form model.layers.0.self_attn.q_proj.lora_A.quant.weight + LORA: the new weight name is of the form model.layers.0.self_attn.q_proj.lora_A.quant.weight + DORA: the new weight name is of the form model.layers.0.self_attn.q_proj.dora_A.weight for MatMul and + model.layers.0.self_attn.q_proj.dora_M.weight for Mul """ weight_name = old_name[1:] if old_name.startswith("/") else old_name - matmul_name = weight_name.split("/")[-1] - return ( - weight_name.replace("/", ".") - .replace("default.", "lora_A.") - .replace("default_1.", "lora_B.") - .replace("default_0.", "lora_A.") - .replace("default_0_1.", "lora_B.") - .replace(matmul_name, "weight") - ) + op = weight_name.split("/")[-1] + if adapter_type == AdapterType.LORA: + return ( + weight_name.replace("/", ".") + .replace("default.", "lora_A.") + .replace("default_1.", "lora_B.") + .replace("default_0.", "lora_A.") + .replace("default_0_1.", "lora_B.") + .replace(op, "weight") + ) + if adapter_type == AdapterType.DORA: + return ( + weight_name.replace("/", ".") + .replace("default.default.", "dora_A.") # For MatMul + .replace("default.default_1.", "dora_B.") # For MatMul + .replace("default.", "dora_M.") # For Add + .replace(op, "weight") + ) + if adapter_type == AdapterType.LOHA: + return weight_name.replace("default", "weight") + raise ValueError(f"Unsupported adapter type: {adapter_type}") @staticmethod def _copy_initializer(old_initializer: onnx.TensorProto, new_name: str) -> onnx.TensorProto: @@ -295,8 +485,18 @@ def _make_dynamic_optional(cls, dag: OnnxDAG, weights: Dict[str, "NDArray"], nam return - # lora r dimension index - dim_idx = 1 if "lora_A" in name else 0 + # Determine which dimension should be made dynamic + dim_idx = 1 + if config["adapter_type"] == AdapterType.LORA: + dim_idx = 1 if "lora_A" in name else 0 + elif config["adapter_type"] == AdapterType.DORA: + dim_idx = 0 if "dora_B" in name else 1 + elif config["adapter_type"] == AdapterType.LOHA: + # LoHa uses multiple Hadamard product matrices + if "hada_w1_a" in name or "hada_w2_a" in name: + dim_idx = 0 # For the first matrix in Hadamard products + else: + dim_idx = 1 # For the second matrix in Hadamard products # make the input dynamic if config["dynamic_lora_r"]: diff --git a/test/unit_test/passes/onnx/test_extract_adapters.py b/test/unit_test/passes/onnx/test_extract_adapters.py index 019480790..71938a2b4 100644 --- a/test/unit_test/passes/onnx/test_extract_adapters.py +++ b/test/unit_test/passes/onnx/test_extract_adapters.py @@ -11,14 +11,15 @@ import pytest from onnxruntime.quantization.calibrate import CalibrationDataReader from packaging import version -from peft import LoraConfig, get_peft_model +from peft import LoHaConfig, LoraConfig, get_peft_model +from peft.tuners.loha import LoHaLayer from peft.tuners.lora import LoraLayer from transformers import AutoModelForCausalLM from olive.common.utils 
import WeightsFileFormat, find_submodules, load_weights from olive.model import HfModelHandler, ONNXModelHandler from olive.passes.olive_pass import create_pass_from_dict -from olive.passes.onnx.common import model_has_adapters +from olive.passes.onnx.common import AdapterType, model_has_adapters from olive.passes.onnx.conversion import OnnxConversion from olive.passes.onnx.extract_adapters import ExtractAdapters from olive.passes.onnx.quantization import OnnxMatMul4Quantizer @@ -48,29 +49,52 @@ def get_calib_data_loader(dummy_input): @pytest.fixture(name="input_model_info", scope="module") -def input_model_info_fixture(tmp_path_factory): - # this tmp_path exists for the duration of the test session - # module is scope is used to ensure that the fixture is only created once +def input_model_info_fixture(tmp_path_factory, request): tmp_path = tmp_path_factory.mktemp("extract-adapters-test") model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM" + adapter_type = request.param + use_dora = adapter_type == AdapterType.DORA + pytorch_model = AutoModelForCausalLM.from_pretrained(model_name) # init_lora_weights are set so that lora_B weights are not all zeros # if they are all zeros, the exported onnx model uses identity node as input to lora_B - peft_model = get_peft_model(pytorch_model, LoraConfig(init_lora_weights=False)) # keep track of all lora modules - all_lora_modules = [ - m.replace("base_model.model.", "") for m in find_submodules(peft_model, LoraLayer, full_name=True) or [] - ] + if adapter_type == AdapterType.LOHA: + peft_model = get_peft_model(pytorch_model, LoHaConfig(init_weights=False, target_modules="all-linear")) + all_modules = [ + m.replace("base_model.model.", "") for m in find_submodules(peft_model, LoHaLayer, full_name=True) or [] + ] + adapter_suffix = ["hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b"] + else: + peft_model = get_peft_model(pytorch_model, LoraConfig(init_lora_weights=False, use_dora=use_dora)) + all_modules = [ + m.replace("base_model.model.", "") for m in find_submodules(peft_model, LoraLayer, full_name=True) or [] + ] + adapter_suffix = ["dora_A", "dora_B", "dora_M"] if use_dora else ["lora_A", "lora_B"] + # names of float weights - all_weights = [f"{m}.{lora_i}.weight" for m in all_lora_modules for lora_i in ["lora_A", "lora_B"]] - # names of quantized weights + all_weights = [f"{m}.{suffix}.weight" for m in all_modules for suffix in adapter_suffix] all_quant_weights = [ w.replace(".weight", suffix) for w in all_weights for suffix in [".quant.weight", ".quant.scale", ".quant.zero_point"] + if not w.endswith("M.weight") # Dora Add is not quantized ] + if adapter_type == AdapterType.DORA: + all_quant_weights += [f"{m}.dora_M.weight" for m in all_modules] + if adapter_type == AdapterType.LOHA: + int4_weight = [ + "hada_w1_a.default", + "hada_w1_b.default_Q4", + "hada_w1_b.default_scales", + "hada_w2_a.default", + "hada_w2_b.default_Q4", + "hada_w2_b.default_scales", + ] + weights = [f"{m}.{suffix.replace('default', 'weight')}" for m in all_modules for suffix in int4_weight] + all_quant_weights = [w + ".quant" for w in weights] # dump adapters adapters_path = tmp_path / "pytorch-adapters" @@ -102,6 +126,7 @@ def input_model_info_fixture(tmp_path_factory): olive_int4_onnx_model = matmul4_quantizer.run(olive_onnx_model, str(tmp_path / "int4-onnx")) return { + "adapter_type": adapter_type, "float": {"onnx_model": olive_onnx_model, "all_weights": all_weights}, # "qdq": { # "onnx_model": olive_qdq_onnx_model, @@ -115,26 +140,62 @@ def 
input_model_info_fixture(tmp_path_factory): } +@pytest.mark.parametrize("input_model_info", [AdapterType.LORA, AdapterType.DORA, AdapterType.LOHA], indirect=True) @pytest.mark.parametrize("model_type", [None, "float", "int4"]) def test_model_has_adapters(input_model_info, model_type): + model_info = input_model_info + adapter_type = model_info["adapter_type"] + if model_type is None: - assert not model_has_adapters(get_onnx_model().model_path) + assert not model_has_adapters(get_onnx_model().model_path, adapter_type) else: - assert model_has_adapters(input_model_info[model_type]["onnx_model"].model_path) + assert model_has_adapters(model_info[model_type]["onnx_model"].model_path, adapter_type) + + +@pytest.mark.parametrize("input_model_info", [AdapterType.LORA], indirect=True) +@pytest.mark.parametrize("quantize_int4", [1, 0]) +@pytest.mark.parametrize("adapter_format", [el.value for el in WeightsFileFormat]) +def test_convert_adapters_command(tmp_path, adapter_format, quantize_int4, input_model_info): + if adapter_format == WeightsFileFormat.ONNX_ADAPTER and version.parse(ort.__version__) < version.parse("1.20"): + pytest.skip("ONNX_ADAPTER format is only supported in onnxruntime 1.20+") + + from olive.cli.launcher import main as cli_main + + # args + suffix = ".npz" if adapter_format == WeightsFileFormat.NUMPY else f".{adapter_format}" + exported_adapters_path = tmp_path / f"exported-adapters.{suffix}" + args = [ + "convert-adapters", + "--adapter_path", + str(input_model_info["adapter_path"]), + "--output_path", + str(exported_adapters_path), + "--adapter_format", + str(adapter_format), + ] + if quantize_int4: + args.append("--quantize_int4") + # execute + cli_main(args) + # assert + assert Path(exported_adapters_path).is_file() + weight_dtype = "int4" if quantize_int4 else "float" + assert set(input_model_info[weight_dtype]["all_weights"]) == set(load_weights(exported_adapters_path)) + + +@pytest.mark.parametrize("input_model_info", [AdapterType.LORA, AdapterType.DORA, AdapterType.LOHA], indirect=True) @pytest.mark.parametrize("model_type", ["float", "qdq", "int4"]) -def test_extract_adapters_as_initializers(tmp_path, input_model_info, model_type): +def test_extract_adapters(tmp_path, model_type, input_model_info): if model_type == "qdq": pytest.skip("QDQ model test is disabled due to flaky quantization failure") + adapter_type = input_model_info["adapter_type"] - # setup - p = create_pass_from_dict( - ExtractAdapters, {"make_inputs": False, "save_format": WeightsFileFormat.NUMPY}, disable_search=True - ) - output_folder = tmp_path / "extracted-adapters" + pass_config = {"make_inputs": False, "save_format": WeightsFileFormat.NUMPY, "adapter_type": adapter_type} - # execute + p = create_pass_from_dict(ExtractAdapters, pass_config, disable_search=True) + output_folder = tmp_path / "extracted-adapters" extracted_model: ONNXModelHandler = p.run(input_model_info[model_type]["onnx_model"], output_folder) # assert @@ -154,60 +215,36 @@ def test_extract_adapters_as_initializers(tmp_path, input_model_info, model_type assert seen_weights == expected_weights +@pytest.mark.parametrize("input_model_info", [AdapterType.LORA, AdapterType.DORA, AdapterType.LOHA], indirect=True) @pytest.mark.parametrize("model_type", ["float", "qdq", "int4"]) @pytest.mark.parametrize("save_format", [el.value for el in WeightsFileFormat]) -def test_extract_adapters_as_inputs(tmp_path, input_model_info, save_format, model_type): +def test_extract_adapters_as_inputs(tmp_path, save_format, model_type, input_model_info): if 
model_type == "qdq": pytest.skip("QDQ model test is disabled due to flaky quantization failure") if save_format == WeightsFileFormat.ONNX_ADAPTER and version.parse(ort.__version__) < version.parse("1.20"): pytest.skip("ONNX_ADAPTER format is only supported in onnxruntime 1.20+") + adapter_type = input_model_info["adapter_type"] + if adapter_type == AdapterType.DORA and model_type == "int4": + pytest.skip("DORA model test is disabled for int4 model") - # setup - p = create_pass_from_dict(ExtractAdapters, {"save_format": save_format}, disable_search=True) + # Create the configuration for the pass + pass_config = {"save_format": save_format, "adapter_type": adapter_type} + p = create_pass_from_dict(ExtractAdapters, pass_config, disable_search=True) output_folder = tmp_path / "extracted-adapters" - # execute + # Execute the pass extracted_model: ONNXModelHandler = p.run(input_model_info[model_type]["onnx_model"], output_folder) io_config = extracted_model.io_config - # assert + # Assertions assert Path(extracted_model.model_path).is_file() assert Path(extracted_model.constant_inputs_path).is_file() + expected_weights = set(input_model_info[model_type]["all_weights"]) - # all lora weights should be extracted as constant inputs + + # Check if all expected weights are extracted assert expected_weights == set(extracted_model.model_attributes["constant_inputs"]) assert expected_weights == set(load_weights(extracted_model.constant_inputs_path)) - # ensure all constant inputs are marked as such - assert all(i in io_config["input_names"] for i in expected_weights) - -@pytest.mark.parametrize("quantize_int4", [1, 0]) -@pytest.mark.parametrize("adapter_format", [el.value for el in WeightsFileFormat]) -def test_convert_adapters_command(tmp_path, input_model_info, adapter_format, quantize_int4): - if adapter_format == WeightsFileFormat.ONNX_ADAPTER and version.parse(ort.__version__) < version.parse("1.20"): - pytest.skip("ONNX_ADAPTER format is only supported in onnxruntime 1.20+") - - from olive.cli.launcher import main as cli_main - - # args - suffix = ".npz" if adapter_format == WeightsFileFormat.NUMPY else f".{adapter_format}" - exported_adapters_path = tmp_path / f"exported-adapters.{suffix}" - args = [ - "convert-adapters", - "--adapter_path", - str(input_model_info["adapter_path"]), - "--output_path", - str(exported_adapters_path), - "--adapter_format", - str(adapter_format), - ] - if quantize_int4: - args.append("--quantize_int4") - - # execute - cli_main(args) - - # assert - assert Path(exported_adapters_path).is_file() - weight_dtype = "int4" if quantize_int4 else "float" - assert set(input_model_info[weight_dtype]["all_weights"]) == set(load_weights(exported_adapters_path)) + # Ensure all constant inputs are present in the input names + assert all(i in io_config["input_names"] for i in expected_weights)