Commit 64c66f9 (1 parent: f4c4f38)

Add DoRA pass

File tree: 10 files changed, +102 -15 lines

docs/source/reference/options.md (+1)

@@ -410,6 +410,7 @@ Please also find the detailed options from following table for each pass:
 | [SplitModel](../../reference/pass.rst#_split_model) | Split an ONNX model into multiple smaller sub-models based on predefined assignments. |
 | [LoRA](../../reference/pass.rst#_lora) | Run LoRA fine-tuning on a Hugging Face PyTorch model. |
 | [QLoRA](../../reference/pass.rst#_qlora) | Run QLoRA fine-tuning on a Hugging Face PyTorch model. |
+| [DoRA](../../reference/pass.rst#_dora) | Run DoRA fine-tuning on a Hugging Face PyTorch model. |
 | [LoftQ](../../reference/pass.rst#_loftq) | Run LoftQ fine-tuning on a Hugging Face PyTorch model. |
 | [QuantizationAwareTraining](../../reference/pass.rst#_onnx_quantization_aware_training) | Run quantization aware training on PyTorch model. |
 | [OpenVINOConversion](../../reference/pass.rst#_openvino_conversion) | Converts PyTorch, ONNX or TensorFlow Model to OpenVino Model. |
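For readers of the new table row: DoRA (Weight-Decomposed Low-Rank Adaptation) extends LoRA by splitting each adapted weight into a trainable magnitude and a direction that receives the low-rank update. Sketched in the notation of the DoRA paper, where m is the learned magnitude vector, W_0 the frozen pretrained weight, BA the LoRA factors, and the norm is taken column-wise:

```latex
% DoRA reparameterization (Liu et al., 2024): magnitude times normalized direction
W' = m \cdot \frac{W_0 + B A}{\lVert W_0 + B A \rVert_c}
```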

docs/source/reference/pass.rst (+6)

@@ -197,6 +197,12 @@ QLoRA
 -----
 .. autoconfigclass:: olive.passes.QLoRA

+.. _dora:
+
+DoRA
+-----
+.. autoconfigclass:: olive.passes.DoRA
+
 .. _loftq:

 LoftQ

examples/llama2/README.md (+3 -3)

@@ -6,7 +6,7 @@ Sample use cases of Olive to optimize a [Llama2](https://huggingface.co/meta-lla
 - [Optimization Workflows](#optimization-workflows)
   - [Inference optimization using ONNX Runtime Tools](#inference-optimization-using-onnx-runtime-tools)
   - [Inference optimization with ONNNX Runtime with DirectML](#inference-optimization-with-onnnx-runtime-with-directml)
-  - [Fine-tune on a code generation dataset using QLoRA and optimize using ONNX Runtime Tools](#fine-tune-on-a-code-generation-dataset-using-qlora-and-optimize-using-onnx-runtime-tools)
+  - [Fine-tune on a code generation dataset using QLoRA or DoRA and optimize using ONNX Runtime Tools](#fine-tune-on-a-code-generation-dataset-using-qlora-or-dora-and-optimize-using-onnx-runtime-tools)
   - [Inference optimization using ONNX Runtime GenAI](#inference-optimization-using-onnx-runtime-genai)
   - [Quantization using GPTQ and do text generation using ONNX Runtime with Optimum](#quantization-using-gptq-and-do-text-generation-using-onnx-runtime-with-optimum)
 - [Prerequisites](#prerequisites)

@@ -172,12 +172,12 @@ python llama2.py --model_name meta-llama/Llama-2-7b-hf --gpu --use_gqa
 python llama2.py --model_name meta-llama/Llama-2-7b-hf --gpu --use_gptq
 ```

-### Fine-tune on a code generation dataset using QLoRA and optimize using ONNX Runtime Tools
+### Fine-tune on a code generation dataset using QLoRA or DoRA and optimize using ONNX Runtime Tools

 Run the following command to execute the workflow:

 ```bash
-python llama2.py --qlora
+python llama2.py --qlora/--dora
 ```

 ### Running Workflows on the Cloud

examples/llama2/llama2.py (+13 -5)

@@ -71,6 +71,12 @@ def get_args(raw_args):
         required=False,
         help="Whether to use qlora for optimization. Only supported on gpu.",
     )
+    parser.add_argument(
+        "--dora",
+        action="store_true",
+        required=False,
+        help="Whether to use dora for optimization.",
+    )
     parser.add_argument(
         "--account_name",
         type=str,

@@ -103,8 +109,8 @@ def main(raw_args=None):
     if args.use_gqa and not args.gpu:
         raise ValueError("GQA is only supported on gpu.")

-    if args.qlora:
-        template_json, config_name = get_qlora_config()
+    if args.qlora or args.dora:
+        template_json, config_name = get_lora_config(args.qlora)
     else:
         template_json, config_name = get_general_config(args)

@@ -152,10 +158,12 @@ def get_valid_config(config, key, default=None):
     raise ValueError(f"Key {key} is required in the config file.")


-def get_qlora_config():
-    with open("llama2_qlora.json") as f:
+def get_lora_config(use_qlora):
+    with open("llama2_lora_base.json") as f:
         template_json = json.load(f)
-    return template_json, "llama2_gpu_qlora"
+    lora_type = "qlora" if use_qlora else "dora"
+    template_json["passes"]["f"]["type"] = lora_type
+    return template_json, f"llama2_gpu_{lora_type}"


 def get_general_config(args):
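Since get_args accepts a raw argument list, the new flag can also be exercised programmatically. A minimal sketch (the argument strings mirror the CLI shown in the README; running it performs the full fine-tuning workflow):

```python
# Minimal sketch: drives the new --dora branch without a shell, via the
# raw_args hook visible in get_args/main above.
from llama2 import main

main(raw_args=["--dora"])    # selects get_lora_config(use_qlora=False) -> "dora"
# main(raw_args=["--qlora"])  # the pre-existing QLoRA path, for comparison
```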
examples/llama2/llama2_qlora.json -> examples/llama2/llama2_lora_base.json (file renamed without changes)

examples/test/azureml/test_llama2.py (+1 -1)

@@ -26,7 +26,7 @@ def setup():
 @pytest.mark.parametrize("execution_order", [None])
 @pytest.mark.parametrize("system", ["local_system"])
 @pytest.mark.parametrize("cache_config", [None, {"account_name": account_name, "container_name": container_name}])
-@pytest.mark.parametrize("olive_json", ["llama2_qlora.json"])
+@pytest.mark.parametrize("olive_json", ["llama2_lora_base.json"])
 def test_llama2(search_algorithm, execution_order, system, cache_config, olive_json):
     from olive.workflows import run as olive_run

examples/test/local/test_llama2.py (+1 -1)

@@ -19,7 +19,7 @@ def setup():
 @pytest.mark.parametrize("search_algorithm", [False])
 @pytest.mark.parametrize("execution_order", [None])
 @pytest.mark.parametrize("system", ["local_system"])
-@pytest.mark.parametrize("olive_json", ["llama2_qlora.json"])
+@pytest.mark.parametrize("olive_json", ["llama2_lora_base.json"])
 def test_llama2(search_algorithm, execution_order, system, olive_json):
     from onnxruntime import __version__ as ort_version

olive/olive_config.json (+8 -1)

@@ -278,6 +278,13 @@
             "supported_precisions": [ "*" ],
             "extra_dependencies": [ "bnb", "lora" ]
         },
+        "DoRA": {
+            "module_path": "olive.passes.pytorch.lora.DoRA",
+            "supported_providers": [ "*" ],
+            "supported_accelerators": [ "*" ],
+            "supported_precisions": [ "*" ],
+            "extra_dependencies": [ "bnb", "lora" ]
+        },
         "QuantizationAwareTraining": {
             "module_path": "olive.passes.pytorch.quantization_aware_training.QuantizationAwareTraining",
             "supported_providers": [ "*" ],

@@ -367,7 +374,7 @@
         "flash-attn": [ "flash_attn" ],
         "gpu": [ "onnxruntime-gpu" ],
         "inc": [ "neural-compressor" ],
-        "lora": [ "accelerate>=0.30.0", "peft", "scipy" ],
+        "lora": [ "accelerate>=0.30.0", "peft>=0.12.0", "scipy" ],
         "nvmo": [ "nvidia-modelopt", "onnx-graphsurgeon", "datasets>=2.14.4", "cppimport==22.8.2" ],
         "openvino": [ "openvino==2023.2.0", "nncf==2.7.0", "numpy<2.0" ],
         "optimum": [ "optimum" ],

olive/passes/pytorch/lora.py (+60 -3)

@@ -369,7 +369,9 @@ def init_lora_adapters(
         task: str,
         config: ConfigBase,
         target_modules: Optional[List[str]] = None,
+        ephemeral_gpu_offload: Optional[bool] = False,
         use_loftq: Optional[bool] = False,
+        use_dora: Optional[bool] = False,
     ) -> "PeftModel":
         """Initialize LoRA adapters.

@@ -380,7 +382,7 @@
         :param use_loftq: Whether to use LoftQ to initialize weights.
         :return: The LoRA model.
         """
-        from peft import LoraConfig, get_peft_model
+        from peft import LoraConfig, LoraRuntimeConfig, get_peft_model

         lora_config_kwargs = {}
         if use_loftq:

@@ -391,6 +393,11 @@
                 "loftq_config": LoftQConfig(loftq_bits=4, loftq_iter=config.loftq_iter),
             }

+        if use_dora:
+            lora_config_kwargs = {
+                "use_dora": True,
+            }
+
         peft_task_type = get_peft_task_type_from_task(task, fail_on_not_found=True)
         lora_config = LoraConfig(
             r=config.lora_r,

@@ -400,6 +407,7 @@
             bias="none",
             task_type=peft_task_type,
             modules_to_save=config.modules_to_save,
+            runtime_config=LoraRuntimeConfig(ephemeral_gpu_offload=ephemeral_gpu_offload),
             **lora_config_kwargs,
         )

@@ -412,6 +420,8 @@ def enable_lora(
         config: ConfigBase,
         adapter_path: Optional[str] = None,
         target_modules: Optional[List[str]] = None,
+        use_dora: Optional[bool] = False,
+        ephemeral_gpu_offload: Optional[bool] = False,
     ) -> "PeftModel":
         """Enable LoRA fine-tuning on a Hugging Face PyTorch model.

@@ -425,6 +435,8 @@
         :param config: The config for the pass run.
         :param adapter_path: Path to the adapter weights. If None, will initialize new adapters.
         :param target_modules: List of modules to target for LoRA fine-tuning. Only used if adapter_path is None.
+        :param use_dora: Whether to use DoRA for optimization.
+        :param ephemeral_gpu_offload: Whether to use ephemeral GPU offload.
         :return: The LoRA model.
         """
         from peft import PeftModel

@@ -450,10 +462,19 @@
         )
         if not adapter_path:
             logger.debug("Initializing LoRA adapters from config")
-            lora_model = self.init_lora_adapters(model, task, config, target_modules=target_modules)
+            lora_model = self.init_lora_adapters(
+                model,
+                task,
+                config,
+                target_modules=target_modules,
+                use_dora=use_dora,
+                ephemeral_gpu_offload=ephemeral_gpu_offload,
+            )
         else:
             logger.debug("Loading LoRA adapters from %s", adapter_path)
-            lora_model = PeftModel.from_pretrained(model, adapter_path, is_trainable=True)
+            lora_model = PeftModel.from_pretrained(
+                model, adapter_path, is_trainable=True, ephemeral_gpu_offload=ephemeral_gpu_offload
+            )
         logger.debug(
             "The number of trainable parameters in the LoRA model: %s", self.count_trainable_parameters(lora_model)
         )

@@ -641,6 +662,42 @@ def _run_for_config(self, model: HfModelHandler, config: Dict[str, Any], output_
         )


+class DoRA(LoRABase):
+    """Run DoRA fine-tuning on a Hugging Face PyTorch model."""
+
+    @classmethod
+    def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]:
+        config = {
+            "ephemeral_gpu_offload": PassConfigParam(
+                type_=bool, default_value=False, description="Ephemeral GPU offload"
+            ),
+        }
+        config.update(super()._default_config(accelerator_spec))
+        return config
+
+    def _run_for_config(self, model: HfModelHandler, config: Dict[str, Any], output_model_path: str) -> HfModelHandler:
+        # convert config to pass config class
+        # this will validate the config and convert to the correct types
+        config = self._config_class(**config)
+
+        # check dependencies
+        self.check_dependencies(config)
+
+        # use default training args if not provided
+        config.training_args = config.training_args or HFTrainingArguments()
+
+        # get new model
+        pytorch_model = self.load_base_pytorch_model(model, config)
+
+        # add lora modules
+        pytorch_model = self.enable_lora(pytorch_model, model.task, config, use_dora=True)
+
+        # train and return new model
+        return self.train_and_save_new_model(
+            pytorch_model, model.get_hf_tokenizer(), config, deepcopy(model), output_model_path
+        )
+
+
 class QLoRABase(LoRABase):
     """Base class for QLoRA and LoftQ fine-tuning passes."""
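The pass delegates the DoRA mechanics entirely to peft: passing use_dora=True into LoraConfig switches the adapters to weight-decomposed updates, and LoraRuntimeConfig carries the new ephemeral_gpu_offload knob (hence the peft>=0.12.0 bump above). A standalone sketch of the objects init_lora_adapters assembles, with an illustrative base model and hyperparameters that are not Olive defaults:

```python
# Standalone sketch, assuming peft>=0.12.0; the model name and hyperparameters
# are illustrative, not values taken from the Olive pass config.
from peft import LoraConfig, LoraRuntimeConfig, get_peft_model
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    use_dora=True,  # weight-decomposed adapters instead of plain LoRA
    runtime_config=LoraRuntimeConfig(ephemeral_gpu_offload=False),
)
dora_model = get_peft_model(base_model, lora_config)
dora_model.print_trainable_parameters()  # trains magnitudes + low-rank factors only
```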

test/unit_test/passes/pytorch/test_lora.py (+9 -1)

@@ -12,7 +12,7 @@
 from olive.data.template import huggingface_data_config_template
 from olive.model import HfModelHandler
 from olive.passes.olive_pass import create_pass_from_dict
-from olive.passes.pytorch.lora import LoftQ, LoRA, QLoRA
+from olive.passes.pytorch.lora import DoRA, LoftQ, LoRA, QLoRA

 # pylint: disable=redefined-outer-name

@@ -112,3 +112,11 @@ def test_loftq(tmp_path):
     # assert
     assert Path(out.get_resource("model_path")).exists()
     assert Path(out.get_resource("adapter_path")).exists()
+
+
+def test_dora(tmp_path):
+    # execute
+    out = run_finetuning(DoRA, tmp_path, torch_dtype="float32")
+
+    # assert
+    assert Path(out.get_resource("adapter_path")).exists()
