Add DoRA pass #1579

Open · wants to merge 2 commits into main
1 change: 1 addition & 0 deletions docs/source/reference/options.md
@@ -410,6 +410,7 @@ Please also find the detailed options from following table for each pass:
| [SplitModel](../../reference/pass.rst#_split_model) | Split an ONNX model into multiple smaller sub-models based on predefined assignments. |
| [LoRA](../../reference/pass.rst#_lora) | Run LoRA fine-tuning on a Hugging Face PyTorch model. |
| [QLoRA](../../reference/pass.rst#_qlora) | Run QLoRA fine-tuning on a Hugging Face PyTorch model. |
| [DoRA](../../reference/pass.rst#_dora) | Run DoRA fine-tuning on a Hugging Face PyTorch model. |
| [LoftQ](../../reference/pass.rst#_loftq) | Run LoftQ fine-tuning on a Hugging Face PyTorch model. |
| [QuantizationAwareTraining](../../reference/pass.rst#_onnx_quantization_aware_training) | Run quantization aware training on PyTorch model. |
| [OpenVINOConversion](../../reference/pass.rst#_openvino_conversion) | Converts PyTorch, ONNX or TensorFlow Model to OpenVino Model. |
6 changes: 6 additions & 0 deletions docs/source/reference/pass.rst
@@ -197,6 +197,12 @@ QLoRA
-----
.. autoconfigclass:: olive.passes.QLoRA

.. _dora:

DoRA
-----
.. autoconfigclass:: olive.passes.DoRA

.. _loftq:

LoftQ
6 changes: 3 additions & 3 deletions examples/llama2/README.md
@@ -6,7 +6,7 @@ Sample use cases of Olive to optimize a [Llama2](https://huggingface.co/meta-lla
- [Optimization Workflows](#optimization-workflows)
- [Inference optimization using ONNX Runtime Tools](#inference-optimization-using-onnx-runtime-tools)
- [Inference optimization with ONNX Runtime with DirectML](#inference-optimization-with-onnnx-runtime-with-directml)
- [Fine-tune on a code generation dataset using QLoRA and optimize using ONNX Runtime Tools](#fine-tune-on-a-code-generation-dataset-using-qlora-and-optimize-using-onnx-runtime-tools)
- [Fine-tune on a code generation dataset using QLoRA or DoRA and optimize using ONNX Runtime Tools](#fine-tune-on-a-code-generation-dataset-using-qlora-or-dora-and-optimize-using-onnx-runtime-tools)
- [Inference optimization using ONNX Runtime GenAI](#inference-optimization-using-onnx-runtime-genai)
- [Quantization using GPTQ and do text generation using ONNX Runtime with Optimum](#quantization-using-gptq-and-do-text-generation-using-onnx-runtime-with-optimum)
- [Prerequisites](#prerequisites)
@@ -172,12 +172,12 @@ python llama2.py --model_name meta-llama/Llama-2-7b-hf --gpu --use_gqa
python llama2.py --model_name meta-llama/Llama-2-7b-hf --gpu --use_gptq
```

### Fine-tune on a code generation dataset using QLoRA and optimize using ONNX Runtime Tools
### Fine-tune on a code generation dataset using QLoRA or DoRA and optimize using ONNX Runtime Tools

Run the following command to execute the workflow:

```bash
python llama2.py --qlora
python llama2.py --qlora  # or: python llama2.py --dora
```

### Running Workflows on the Cloud
18 changes: 13 additions & 5 deletions examples/llama2/llama2.py
@@ -71,6 +71,12 @@ def get_args(raw_args):
required=False,
help="Whether to use qlora for optimization. Only supported on gpu.",
)
parser.add_argument(
"--dora",
action="store_true",
required=False,
help="Whether to use dora for optimization.",
)
parser.add_argument(
"--account_name",
type=str,
@@ -103,8 +109,8 @@ def main(raw_args=None):
if args.use_gqa and not args.gpu:
raise ValueError("GQA is only supported on gpu.")

if args.qlora:
template_json, config_name = get_qlora_config()
if args.qlora or args.dora:
template_json, config_name = get_lora_config(args.qlora)
else:
template_json, config_name = get_general_config(args)

@@ -152,10 +158,12 @@ def get_valid_config(config, key, default=None):
raise ValueError(f"Key {key} is required in the config file.")


def get_qlora_config():
with open("llama2_qlora.json") as f:
def get_lora_config(use_qlora):
with open("llama2_lora_base.json") as f:
template_json = json.load(f)
return template_json, "llama2_gpu_qlora"
lora_type = "qlora" if use_qlora else "dora"
template_json["passes"]["f"]["type"] = lora_type
return template_json, f"llama2_gpu_{lora_type}"


def get_general_config(args):
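To make the new selection logic concrete, here is a small self-contained sketch of what `get_lora_config()` now does. The `"f"` pass key mirrors the indexing in the function above; the rest of `llama2_lora_base.json` is not shown in this PR, so it is treated as opaque here.

```python
import json

def select_finetune_config(use_qlora: bool):
    """Sketch mirroring get_lora_config() above: both CLI flags share the same
    template and differ only in the pass type written into the "f" pass entry."""
    with open("llama2_lora_base.json") as f:
        template_json = json.load(f)
    lora_type = "qlora" if use_qlora else "dora"
    # only the pass type changes; the remaining template fields are untouched
    template_json["passes"]["f"]["type"] = lora_type
    return template_json, f"llama2_gpu_{lora_type}"

# python llama2.py --qlora  -> workflow name "llama2_gpu_qlora"
# python llama2.py --dora   -> workflow name "llama2_gpu_dora"
```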
2 changes: 1 addition & 1 deletion examples/test/azureml/test_llama2.py
@@ -26,7 +26,7 @@ def setup():
@pytest.mark.parametrize("execution_order", [None])
@pytest.mark.parametrize("system", ["local_system"])
@pytest.mark.parametrize("cache_config", [None, {"account_name": account_name, "container_name": container_name}])
@pytest.mark.parametrize("olive_json", ["llama2_qlora.json"])
@pytest.mark.parametrize("olive_json", ["llama2_lora_base.json"])
def test_llama2(search_algorithm, execution_order, system, cache_config, olive_json):
from olive.workflows import run as olive_run

2 changes: 1 addition & 1 deletion examples/test/local/test_llama2.py
@@ -19,7 +19,7 @@ def setup():
@pytest.mark.parametrize("search_algorithm", [False])
@pytest.mark.parametrize("execution_order", [None])
@pytest.mark.parametrize("system", ["local_system"])
@pytest.mark.parametrize("olive_json", ["llama2_qlora.json"])
@pytest.mark.parametrize("olive_json", ["llama2_lora_base.json"])
def test_llama2(search_algorithm, execution_order, system, olive_json):
from onnxruntime import __version__ as ort_version

9 changes: 8 additions & 1 deletion olive/olive_config.json
@@ -278,6 +278,13 @@
"supported_precisions": [ "*" ],
"extra_dependencies": [ "bnb", "lora" ]
},
"DoRA": {
"module_path": "olive.passes.pytorch.lora.DoRA",
"supported_providers": [ "*" ],
"supported_accelerators": [ "*" ],
"supported_precisions": [ "*" ],
"extra_dependencies": [ "bnb", "lora" ]
},
"QuantizationAwareTraining": {
"module_path": "olive.passes.pytorch.quantization_aware_training.QuantizationAwareTraining",
"supported_providers": [ "*" ],
@@ -367,7 +374,7 @@
"flash-attn": [ "flash_attn" ],
"gpu": [ "onnxruntime-gpu" ],
"inc": [ "neural-compressor" ],
"lora": [ "accelerate>=0.30.0", "peft", "scipy" ],
"lora": [ "accelerate>=0.30.0", "peft>=0.12.0", "scipy" ],
"nvmo": [ "nvidia-modelopt", "onnx-graphsurgeon", "datasets>=2.14.4", "cppimport==22.8.2" ],
"openvino": [ "openvino==2023.2.0", "nncf==2.7.0", "numpy<2.0" ],
"optimum": [ "optimum" ],
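With the registration above, the DoRA pass can be referenced like any other pass in an Olive workflow config. A minimal sketch, assuming the usual layout of an `input_model` plus a `passes` dictionary; the model path, data config name, and the `"f"` pass key are illustrative placeholders, not values taken from this PR.

```python
from olive.workflows import run as olive_run

# Minimal workflow sketch; all concrete values below are placeholders.
workflow_config = {
    "input_model": {
        "type": "HfModel",
        "model_path": "meta-llama/Llama-2-7b-hf",
        "task": "text-generation",
    },
    "passes": {
        "f": {
            "type": "DoRA",                        # pass registered in olive_config.json above
            "train_data_config": "finetune_data",  # placeholder; define under "data_configs"
            "ephemeral_gpu_offload": False,        # option introduced by this PR
        }
    },
}

olive_run(workflow_config)
```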
64 changes: 61 additions & 3 deletions olive/passes/pytorch/lora.py
@@ -369,7 +369,9 @@ def init_lora_adapters(
task: str,
config: ConfigBase,
target_modules: Optional[List[str]] = None,
ephemeral_gpu_offload: Optional[bool] = False,
use_loftq: Optional[bool] = False,
use_dora: Optional[bool] = False,
) -> "PeftModel":
"""Initialize LoRA adapters.

@@ -378,9 +380,10 @@
:param config: The config for the pass run.
:param target_modules: List of modules to target for LoRA fine-tuning.
:param use_loftq: Whether to use LoftQ to initialize weights.
:param use_dora: Whether to use DoRA for optimization.
:param ephemeral_gpu_offload: Whether to use ephemeral GPU offload.
:return: The LoRA model.
"""
from peft import LoraConfig, get_peft_model
from peft import LoraConfig, LoraRuntimeConfig, get_peft_model

lora_config_kwargs = {}
if use_loftq:
@@ -391,6 +394,11 @@
"loftq_config": LoftQConfig(loftq_bits=4, loftq_iter=config.loftq_iter),
}

if use_dora:
lora_config_kwargs = {
"use_dora": True,
}

peft_task_type = get_peft_task_type_from_task(task, fail_on_not_found=True)
lora_config = LoraConfig(
r=config.lora_r,
@@ -400,6 +408,7 @@
bias="none",
task_type=peft_task_type,
modules_to_save=config.modules_to_save,
runtime_config=LoraRuntimeConfig(ephemeral_gpu_offload=ephemeral_gpu_offload),
**lora_config_kwargs,
)

@@ -412,6 +421,8 @@ def enable_lora(
config: ConfigBase,
adapter_path: Optional[str] = None,
target_modules: Optional[List[str]] = None,
use_dora: Optional[bool] = False,
ephemeral_gpu_offload: Optional[bool] = False,
) -> "PeftModel":
"""Enable LoRA fine-tuning on a Hugging Face PyTorch model.

@@ -425,6 +436,8 @@
:param config: The config for the pass run.
:param adapter_path: Path to the adapter weights. If None, will initialize new adapters.
:param target_modules: List of modules to target for LoRA fine-tuning. Only used if adapter_path is None.
:param use_dora: Whether to use DoRA for optimization.
:param ephemeral_gpu_offload: Whether to use ephemeral GPU offload.
:return: The LoRA model.
"""
from peft import PeftModel
@@ -450,10 +463,19 @@
)
if not adapter_path:
logger.debug("Initializing LoRA adapters from config")
lora_model = self.init_lora_adapters(model, task, config, target_modules=target_modules)
lora_model = self.init_lora_adapters(
model,
task,
config,
target_modules=target_modules,
use_dora=use_dora,
ephemeral_gpu_offload=ephemeral_gpu_offload,
)
else:
logger.debug("Loading LoRA adapters from %s", adapter_path)
lora_model = PeftModel.from_pretrained(model, adapter_path, is_trainable=True)
lora_model = PeftModel.from_pretrained(
model, adapter_path, is_trainable=True, ephemeral_gpu_offload=ephemeral_gpu_offload
)
logger.debug(
"The number of trainable parameters in the LoRA model: %s", self.count_trainable_parameters(lora_model)
)
@@ -641,6 +663,42 @@ def _run_for_config(self, model: HfModelHandler, config: Dict[str, Any], output_
)


class DoRA(LoRABase):
"""Run DoRA fine-tuning on a Hugging Face PyTorch model."""

@classmethod
def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]:
config = {
"ephemeral_gpu_offload": PassConfigParam(
Review comment (Contributor): Can this be used with LoRA also?

Reply (Contributor Author): I'm not sure, but from the Hugging Face documentation this option is only used for DoRA, so I didn't add it for LoRA.

type_=bool, default_value=False, description="Ephemeral GPU offload"
),
}
config.update(super()._default_config(accelerator_spec))
return config

def _run_for_config(self, model: HfModelHandler, config: Dict[str, Any], output_model_path: str) -> HfModelHandler:
# convert config to pass config class
Review comment (Contributor): This is identical to the LoRA pass except for the use_dora flag passed to enable_lora(). How could we avoid the code duplication? The comments in LoRA's _run_for_config() also have additional details that could be useful here.

Reply (Contributor Author): That's true. I'm working on LoHa and LoKr as well; let me revisit this when adding those two.

# this will validate the config and convert to the correct types
config = self._config_class(**config)

# check dependencies
self.check_dependencies(config)

# use default training args if not provided
config.training_args = config.training_args or HFTrainingArguments()

# get new model
pytorch_model = self.load_base_pytorch_model(model, config)

# add lora modules
pytorch_model = self.enable_lora(pytorch_model, model.task, config, use_dora=True)

# train and return new model
return self.train_and_save_new_model(
pytorch_model, model.get_hf_tokenizer(), config, deepcopy(model), output_model_path
)


class QLoRABase(LoRABase):
"""Base class for QLoRA and LoftQ fine-tuning passes."""

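On the duplication raised in the review thread above, one possible direction is to keep the shared flow in the base class and let DoRA flip a single class attribute. This is a hypothetical refactor sketch, not part of this PR; the names mirror this diff, and it assumes enable_lora() is the only point where DoRA diverges from LoRA.

```python
from copy import deepcopy

# Hypothetical refactor sketch. HFTrainingArguments and the helper methods used
# below (check_dependencies, load_base_pytorch_model, enable_lora,
# train_and_save_new_model) already exist in olive.passes.pytorch.lora.

class LoRABase:
    """Sketch: keep the shared fine-tuning flow in the base class."""

    # flag forwarded to enable_lora(); subclasses only flip this
    use_dora: bool = False

    def _run_for_config(self, model, config, output_model_path):
        # validate the config and convert it to the pass config class
        config = self._config_class(**config)
        # check peft/bitsandbytes dependencies
        self.check_dependencies(config)
        # use default training args if not provided
        config.training_args = config.training_args or HFTrainingArguments()
        # load a fresh copy of the base PyTorch model
        pytorch_model = self.load_base_pytorch_model(model, config)
        # the only DoRA-specific step: forward the flag into peft's LoraConfig
        pytorch_model = self.enable_lora(
            pytorch_model, model.task, config, use_dora=self.use_dora
        )
        # train and return the new model
        return self.train_and_save_new_model(
            pytorch_model, model.get_hf_tokenizer(), config, deepcopy(model), output_model_path
        )


class DoRA(LoRABase):
    """With the flag on the class, DoRA needs no _run_for_config() override."""

    use_dora = True
```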
10 changes: 9 additions & 1 deletion test/unit_test/passes/pytorch/test_lora.py
@@ -12,7 +12,7 @@
from olive.data.template import huggingface_data_config_template
from olive.model import HfModelHandler
from olive.passes.olive_pass import create_pass_from_dict
from olive.passes.pytorch.lora import LoftQ, LoRA, QLoRA
from olive.passes.pytorch.lora import DoRA, LoftQ, LoRA, QLoRA

# pylint: disable=redefined-outer-name

@@ -112,3 +112,11 @@ def test_loftq(tmp_path):
# assert
assert Path(out.get_resource("model_path")).exists()
assert Path(out.get_resource("adapter_path")).exists()


def test_dora(tmp_path):
# execute
out = run_finetuning(DoRA, tmp_path, torch_dtype="float32")

# assert
assert Path(out.get_resource("adapter_path")).exists()
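A possible follow-up test could also exercise the new ephemeral_gpu_offload option. This sketch assumes run_finetuning forwards extra keyword arguments into the pass config the same way torch_dtype is forwarded above; that helper's signature is not shown in this diff, so this is an assumption, not part of the PR.

```python
def test_dora_with_ephemeral_gpu_offload(tmp_path):
    # execute; ephemeral_gpu_offload is assumed to be forwarded into the DoRA pass config
    out = run_finetuning(DoRA, tmp_path, torch_dtype="float32", ephemeral_gpu_offload=False)

    # assert
    assert Path(out.get_resource("adapter_path")).exists()
```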