Add extract adapters support for dora and loha #1611

Open · wants to merge 1 commit into base: main
@@ -278,6 +278,7 @@ for an example implementation of `"user_script.py"` and `"calib_data_config/data
LoRA, QLoRA and related techniques allow us to fine-tune a pre-trained model by adding a small number of trainable matrices called adapters. The same base model can be used for multiple tasks by adding different adapters for each task. To support using multiple adapters with the same optimized onnx model, the `ExtractAdapters` pass extracts the adapter weights from the model and saves them to a separate file. The model graph is then modified in one of the following ways:
- Adapter weights are set as external tensors pointing to a non-existent file. The onnx model is thus invalid by itself as it cannot be loaded. In order to create an inference session using this model, the adapter weights must be added to a session options object using `add_initializer` or `add_external_initializers` (see the sketch after this list).
- Adapter weights are converted into model inputs. The onnx model is valid. During inference, the adapter weights must be provided as part of the inputs. We call them constant inputs here since these weights don't change between runs when using the same set of adapters.
- The `ExtractAdapters` pass also supports `DoRA` and `LoHa` adapters. Add `adapter_type` to the pass configuration to specify the adapter type. The default value is `lora`.
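
A minimal sketch of the first option (creating an inference session with externally supplied adapter weights); the file names and the NumPy `.npz` weights format here are assumptions for illustration:

```python
import numpy as np
import onnxruntime as ort

# Assumed file names: "model.onnx" was produced by ExtractAdapters with
# make_inputs=false and "adapter_weights.npz" holds the extracted weights.
weights = np.load("adapter_weights.npz")
names = list(weights.keys())
values = [ort.OrtValue.ortvalue_from_numpy(weights[name]) for name in names]

sess_options = ort.SessionOptions()
# Register the adapter weights so the model's dangling external tensors resolve.
sess_options.add_external_initializers(names, values)

session = ort.InferenceSession("model.onnx", sess_options=sess_options)
```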

### Example Configuration

@@ -286,6 +287,7 @@ a. As external initializers
```json
{
    "type": "ExtractAdapters",
    "adapter_type": "dora",
    "make_inputs": false
}
```
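
The same pass configuration can also be supplied programmatically through Olive's workflow runner. A minimal sketch, where the input model type, paths, and output directory are placeholder assumptions and the rest of the workflow config is omitted:

```python
from olive.workflows import run as olive_run

# Placeholder model and output locations; the pass entry mirrors the JSON example above.
config = {
    "input_model": {"type": "ONNXModel", "model_path": "model.onnx"},
    "passes": {
        "extract": {"type": "ExtractAdapters", "adapter_type": "dora", "make_inputs": False},
    },
    "output_dir": "extracted-model",
}

olive_run(config)
```
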
10 changes: 9 additions & 1 deletion olive/cli/generate_adapter.py
@@ -18,6 +18,7 @@
    update_shared_cache_options,
)
from olive.common.utils import WeightsFileFormat, set_nested_dict_value
from olive.passes.onnx.common import AdapterType


class GenerateAdapterCommand(BaseOliveCLICommand):
@@ -29,7 +30,13 @@ def register_subcommand(parser: ArgumentParser):

        # Model options
        add_input_model_options(sub_parser, enable_onnx=True, default_output_path="optimized-model")

        sub_parser.add_argument(
            "--adapter_type",
            type=AdapterType,
            default=AdapterType.LORA,
            choices=[el.value for el in AdapterType],
            help=f"Type of adapters to extract. Default is {AdapterType.LORA}.",
        )
        sub_parser.add_argument(
            "--adapter_format",
            type=str,
@@ -56,6 +63,7 @@ def _get_run_config(self, tempdir: str) -> Dict:
        to_replace = [
            ("input_model", input_model_config),
            (("passes", "e", "save_format"), self.args.adapter_format),
            (("passes", "e", "adapter_type"), self.args.adapter_type),
            ("output_dir", self.args.output_path),
            ("log_severity_level", self.args.log_level),
        ]
52 changes: 47 additions & 5 deletions olive/passes/onnx/common.py
@@ -9,6 +9,7 @@

import onnx

from olive.common.utils import StrEnumBase
from olive.model import ONNXModelHandler
from olive.passes.onnx.onnx_dag import OnnxDAG
from olive.passes.pass_config import PassConfigParam
@@ -17,6 +18,12 @@
logger = logging.getLogger(__name__)


class AdapterType(StrEnumBase):
    LORA = "lora"
    DORA = "dora"
    LOHA = "loha"


def get_external_data_config() -> Dict[str, PassConfigParam]:
    return {
        "save_as_external_data": PassConfigParam(
@@ -192,18 +199,53 @@ def model_proto_to_olive_model(
    for name in ["default_0", "default_0_1", "default", "default_1", "lora_A", "lora_B"]
    for matmul in ["MatMul", "MatMul_Q4"]
]
LOHA_NAME_PATTERNS = [f".*[./]{name}[./]default" for name in ["hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b"]]

DORA_NAME_PATTERNS = [
    f".*{pattern}$"
    for pattern in [
        "default/Add",
        "default/default/MatMul",
        "default/default_1/MatMul",
        "default/default/MatMul_Q4",
        "default/default_1/MatMul_Q4",
    ]
]


# TODO(jambayk): considering matching by subgraph pattern, more involved but more reliable
def model_has_adapters(model_path: Union[str, Path], adapter_type: AdapterType = AdapterType.LORA) -> bool:
    """Check if the model has adapters.

    :param model_path: The path to the model.
    :param adapter_type: The type of adapters to look for (lora, dora or loha). Defaults to lora.
    :return: True if the model has adapters, False otherwise.
    """
    dag = OnnxDAG(onnx.load(model_path, load_external_data=False))
    if adapter_type == AdapterType.LOHA and is_loha_model(dag):
        return True
    else:
        for node_name in dag.get_node_names():
            op_type = dag.get_node_op_type(node_name)
            if (adapter_type == AdapterType.LORA and is_lora_node(op_type, node_name)) or (
                adapter_type == AdapterType.DORA and is_dora_node(op_type, node_name)
            ):
                return True
    return False


def is_dora_node(op_type: str, node_name: str) -> bool:
    return op_type in {"MatMul", "MatMulNBits", "Add"} and any(
        re.match(pattern, node_name) for pattern in DORA_NAME_PATTERNS
    )


def is_lora_node(op_type: str, node_name: str) -> bool:
    return op_type in {"MatMul", "MatMulNBits"} and any(re.match(pattern, node_name) for pattern in LORA_NAME_PATTERNS)


def is_loha_model(dag: OnnxDAG) -> bool:
    for graph in dag.graphs:
        for initializer in graph.initializer:
            if any(re.match(pattern, initializer.name) for pattern in LOHA_NAME_PATTERNS):
                return True
    return False
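
For reference, a small usage sketch of the updated helper with the new `adapter_type` argument; the model path is an assumed placeholder:

```python
from olive.passes.onnx.common import AdapterType, model_has_adapters

# Assumed path to an optimized model that may still contain adapter nodes.
if model_has_adapters("optimized-model/model.onnx", adapter_type=AdapterType.DORA):
    print("DoRA adapter nodes detected")
```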