Add DoRA pass
xiaoyu-work committed Feb 5, 2025
1 parent cd3cefb commit b210e65
Showing 5 changed files with 54 additions and 5 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/options.md
@@ -412,6 +412,7 @@ Please also find the detailed options from following table for each pass:
| [LoHa](../../reference/pass.rst#_loha) | Run LoHa fine-tuning on a Hugging Face PyTorch model. |
| [LoKr](../../reference/pass.rst#_lokr) | Run LoKr fine-tuning on a Hugging Face PyTorch model. |
| [QLoRA](../../reference/pass.rst#_qlora) | Run QLoRA fine-tuning on a Hugging Face PyTorch model. |
| [DoRA](../../reference/pass.rst#_dora) | Run DoRA fine-tuning on a Hugging Face PyTorch model. |
| [LoftQ](../../reference/pass.rst#_loftq) | Run LoftQ fine-tuning on a Hugging Face PyTorch model. |
| [QuantizationAwareTraining](../../reference/pass.rst#_onnx_quantization_aware_training) | Run quantization aware training on PyTorch model. |
| [OpenVINOConversion](../../reference/pass.rst#_openvino_conversion) | Converts PyTorch, ONNX or TensorFlow Model to OpenVino Model. |
6 changes: 6 additions & 0 deletions docs/source/reference/pass.rst
@@ -209,6 +209,12 @@ QLoRA
-----
.. autoconfigclass:: olive.passes.QLoRA

.. _dora:

DoRA
-----
.. autoconfigclass:: olive.passes.DoRA

.. _loftq:

LoftQ
9 changes: 8 additions & 1 deletion olive/olive_config.json
@@ -292,6 +292,13 @@
"supported_precisions": [ "*" ],
"extra_dependencies": [ "bnb", "lora" ]
},
"DoRA": {
"module_path": "olive.passes.pytorch.lora.DoRA",
"supported_providers": [ "*" ],
"supported_accelerators": [ "*" ],
"supported_precisions": [ "*" ],
"extra_dependencies": [ "bnb", "lora" ]
},
"QuantizationAwareTraining": {
"module_path": "olive.passes.pytorch.quantization_aware_training.QuantizationAwareTraining",
"supported_providers": [ "*" ],
@@ -387,7 +394,7 @@
"flash-attn": [ "flash_attn" ],
"gpu": [ "onnxruntime-gpu" ],
"inc": [ "neural-compressor" ],
"lora": [ "accelerate>=0.30.0", "peft", "scipy" ],
"lora": [ "accelerate>=0.30.0", "peft>=0.12.0", "scipy" ],
"nvmo": [ "nvidia-modelopt", "onnx-graphsurgeon", "datasets>=2.14.4", "cppimport==22.8.2" ],
"openvino": [ "openvino==2023.2.0", "nncf==2.7.0", "numpy<2.0" ],
"optimum": [ "optimum" ],
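For context, the registry entry above wires `DoRA` to `olive.passes.pytorch.lora.DoRA` with the same `bnb`/`lora` extras as QLoRA, and the `lora` extra now requires `peft>=0.12.0`. Below is a minimal sketch of how a workflow might select the new pass; the `"type"` key and flattened option layout follow common Olive config conventions, and every value other than the pass name and the `ephemeral_gpu_offload` option is a hypothetical placeholder, not something taken from this commit.

```python
# Hypothetical workflow fragment, written as a Python dict mirroring an Olive
# JSON config. Only "DoRA" as the pass type and "ephemeral_gpu_offload" are
# grounded in this commit; the other keys and values are illustrative.
workflow_passes = {
    "finetune": {
        "type": "DoRA",
        "train_data_config": "my_train_data",  # name of a data config defined elsewhere in the workflow
        "ephemeral_gpu_offload": False,        # new option added in this commit
    }
}
```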
30 changes: 27 additions & 3 deletions olive/passes/pytorch/lora.py
@@ -144,6 +144,9 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon
"see 'https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices'"
),
),
"ephemeral_gpu_offload": PassConfigParam(
type_=bool, default_value=False, description="Ephemeral GPU offload"
),
# data parameters
"train_data_config": PassConfigParam(
type_=Union[DataConfig, Dict],
@@ -236,6 +239,11 @@ def get_datasets(
return train_dataset, eval_dataset

def _run_for_config(self, model: HfModelHandler, config: Dict[str, Any], output_model_path: str) -> HfModelHandler:
return self._run_lora_training(model, config, output_model_path)

def _run_lora_training(
self, model: HfModelHandler, config: Dict[str, Any], output_model_path: str, use_dora: bool = False
) -> HfModelHandler:
# convert config to pass config class
# this will validate the config and convert to the correct types
config = self._config_class(**config)
@@ -260,7 +268,7 @@ def _run_for_config(self, model: HfModelHandler, config: Dict[str, Any], output_
# disable exllama (gptq pass disables it)

# add lora modules
pytorch_model = self.enable_lora(pytorch_model, config, model.task)
pytorch_model = self.enable_lora(pytorch_model, config, model.task, use_dora=use_dora)

# train and return new model
return self.train_and_save_new_model(
@@ -295,13 +303,15 @@ def init_adapters(
*,
task: Optional[str] = None,
use_loftq: Optional[bool] = False,
use_dora: Optional[bool] = False,
) -> "PeftModel":
"""Initialize LoRA adapters.
:param model: The Hugging Face PyTorch model to add LoRA adapters to.
:param config: The config for the pass run.
:param task: The task type of the model.
:param use_loftq: Whether to use LoftQ to initialize weights.
:param use_dora: Whether to use DoRA to initialize weights.
:return: The LoRA model.
"""
config_kwargs = {}
@@ -312,6 +322,10 @@
"init_lora_weights": "loftq",
"loftq_config": LoftQConfig(loftq_bits=4, loftq_iter=config.loftq_iter),
}
if use_dora:
config_kwargs = {
"use_dora": True,
}
if task:
config_kwargs.update({"task_type": get_peft_task_type_from_task(task, fail_on_not_found=True)})

@@ -322,6 +336,7 @@ def enable_lora(
model: "PreTrainedModel",
config: ConfigBase,
task: Optional[str] = None,
use_dora: bool = False,
adapter_path: Optional[str] = None,
) -> "PeftModel":
"""Enable LoRA fine-tuning on a Hugging Face PyTorch model.
@@ -333,6 +348,7 @@
:param model: The Hugging Face PyTorch model to enable LoRA fine-tuning on.
:param config: The config for the pass run.
:param task: The task type of the model.
:param use_dora: Whether to use DoRA to train adapters.
:param adapter_path: Path to the adapter weights. If None, will initialize new adapters.
:return: The LoRA model.
"""
@@ -351,7 +367,7 @@

if not adapter_path:
logger.debug("Initializing LoRA adapters from config")
lora_model = self.init_adapters(model, config, task=task)
lora_model = self.init_adapters(model, config, task=task, use_dora=use_dora)
else:
from peft import PeftModel

@@ -492,7 +508,7 @@ def get_target_modules(model: HfModelHandler) -> Optional[List[str]]:
@staticmethod
def get_peft_model(model: "PreTrainedModel", config: ConfigBase, config_kwargs: Dict = None) -> "PeftModel":
"""Get the PEFT model for LoRA fine-tuning."""
from peft import LoraConfig, get_peft_model
from peft import LoraConfig, LoraRuntimeConfig, get_peft_model

if config_kwargs is None:
config_kwargs = {}
@@ -504,12 +520,20 @@ def get_peft_model(model: "PreTrainedModel", config: ConfigBase, config_kwargs:
target_modules=config.target_modules,
bias="none",
modules_to_save=config.modules_to_save,
runtime_config=LoraRuntimeConfig(ephemeral_gpu_offload=config.ephemeral_gpu_offload),
**config_kwargs,
)

return get_peft_model(model, lora_config)


class DoRA(LoRA):
"""Run DoRA fine-tuning on a Hugging Face PyTorch model."""

def _run_for_config(self, model: HfModelHandler, config: Dict[str, Any], output_model_path: str) -> HfModelHandler:
return self._run_lora_training(model, config, output_model_path, use_dora=True)


class LoRAVariant(LoRA):
"""Run LoRA variant fine-tuning on a Hugging Face PyTorch model."""

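The new `DoRA` class is a thin wrapper: it reuses `_run_lora_training` and only flips `use_dora=True`, which `init_adapters` forwards into the PEFT `LoraConfig`. DoRA (Weight-Decomposed Low-Rank Adaptation) decomposes each adapted weight into a magnitude and a direction component and applies the low-rank update to the direction. As a rough guide to what the pass ends up configuring, here is a standalone PEFT-only sketch; the model name, rank, and target modules are placeholders, not values used by Olive, and it assumes `peft>=0.12.0` (matching the bumped requirement) plus `transformers` installed.

```python
# Minimal PEFT-only sketch of the configuration the DoRA pass sets up.
# Placeholder model, rank, and target modules; not taken from this commit.
from peft import LoraConfig, LoraRuntimeConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # placeholder model

dora_config = LoraConfig(
    r=16,                                 # placeholder rank
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],  # placeholder target modules
    use_dora=True,                        # the flag init_adapters sets for the DoRA pass
    runtime_config=LoraRuntimeConfig(ephemeral_gpu_offload=False),
    task_type="CAUSAL_LM",
)

peft_model = get_peft_model(base, dora_config)
peft_model.print_trainable_parameters()
```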
13 changes: 12 additions & 1 deletion test/unit_test/passes/pytorch/test_lora.py
@@ -12,7 +12,7 @@
from olive.data.template import huggingface_data_config_template
from olive.model import HfModelHandler
from olive.passes.olive_pass import create_pass_from_dict
from olive.passes.pytorch.lora import LoftQ, LoHa, LoKr, LoRA, QLoRA
from olive.passes.pytorch.lora import DoRA, LoftQ, LoHa, LoKr, LoRA, QLoRA

# pylint: disable=redefined-outer-name

@@ -138,5 +138,16 @@ def test_lokr(tmp_path):
LoKr, tmp_path, torch_dtype="float16", training_args={"remove_unused_columns": False, "save_safetensors": False}
)

assert Path(out.get_resource("adapter_path")).exists()


@pytest.mark.skipif(
platform.system() == OS.WINDOWS or not torch.cuda.is_available(),
reason="bitsandbytes requires Linux GPU.",
)
def test_dora(tmp_path):
# execute
out = run_finetuning(DoRA, tmp_path, torch_dtype="float32")

# assert
assert Path(out.get_resource("adapter_path")).exists()
