LoRA/QLoRA: Mixed precision for torch_dtype=float16, Support onnxruntime-training (#722)

## Describe your changes
Changes in this PR:
- Full `float16` training is unstable, so the LoRA passes now use
mixed-precision training (`fp16=True` in the trainer) when `torch_dtype`
is `float16`.
- `compute_dtype` for the quantized modules is a separate config
parameter to allow for flexibility. If not provided, it defaults to the
same dtype as `torch_dtype`.
- Support for `onnxruntime-training`.

Also adds an example of QLoRA with `onnxruntime-training`.
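For reference, a minimal sketch of the intended dtype handling (the helper name `resolve_training_dtypes` is illustrative, not the actual code in this PR):

```python
from typing import Optional

import torch

def resolve_training_dtypes(
    torch_dtype: torch.dtype, compute_dtype: Optional[torch.dtype] = None
):
    # compute_dtype for the quantized modules defaults to torch_dtype
    # when it is not explicitly configured.
    compute_dtype = compute_dtype or torch_dtype
    # Full float16 training is unstable, so torch_dtype=float16 turns on
    # mixed precision (fp16=True in the trainer) instead of training the
    # whole model in half precision.
    fp16 = torch_dtype == torch.float16
    return compute_dtype, fp16

compute_dtype, fp16 = resolve_training_dtypes(torch.float16)
# fp16 is then forwarded to transformers.TrainingArguments(..., fp16=fp16)
```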

## Checklist before requesting a review
- [ ] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [ ] Lint and apply fixes to your code by running `lintrunner -a`.
- [ ] Is this a user-facing change? If yes, give a description of this
change to be included in the release notes.

## (Optional) Issue link
jambayk authored Nov 14, 2023
1 parent 0616b51 commit a4b9e51
Showing 10 changed files with 215 additions and 21 deletions.
1 change: 1 addition & 0 deletions examples/llama2/llama2_qlora.json
@@ -79,6 +79,7 @@
"gradient_accumulation_steps": 1,
"max_steps": 1500,
"logging_steps": 100,
"save_steps": 100,
"evaluation_strategy": "steps",
"adam_beta2": 0.999,
"max_grad_norm": 0.3,
13 changes: 13 additions & 0 deletions examples/open_llama/README.md
@@ -102,6 +102,19 @@ Note: You must be logged in to HuggingFace using `huggingface-cli login` to down

Requirements file: [requirements-lora.txt](requirements-lora.txt)

**Train using ONNX Runtime Training**

You can also train the model using [ONNX Runtime Training](https://techcommunity.microsoft.com/t5/ai-machine-learning-blog/onnx-runtime-training-technical-deep-dive/ba-p/1398310).

The relevant config file is [open_llama_qlora_ort_tinycodes.json](open_llama_qlora_ort_tinycodes.json).

Requirements file: [requirements-qlora-ort.txt](requirements-qlora-ort.txt)

It also requires the latest nightly build of `onnxruntime-training`:
```bash
python -m pip uninstall -y onnxruntime onnxruntime-gpu ort-nightly ort-nightly-gpu
python -m pip install onnxruntime-training --pre --upgrade --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/
```
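Under the hood, enabling `use_ort_trainer` runs the training loop through ONNX Runtime; with `optimum` the pattern looks roughly like this (a sketch, not the exact Olive code; `model` and `train_dataset` stand in for the prepared PEFT model and tokenized dataset):

```python
from optimum.onnxruntime import ORTTrainer, ORTTrainingArguments

# ORTTrainer/ORTTrainingArguments are drop-in replacements for the
# transformers equivalents that execute training through ONNX Runtime.
training_args = ORTTrainingArguments(
    output_dir="outputs",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    max_steps=1500,
)
trainer = ORTTrainer(
    model=model,                  # assumed: the quantized PEFT model
    args=training_args,
    train_dataset=train_dataset,  # assumed: the tokenized tiny-codes split
)
trainer.train()
```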

### Optimizing Open Llama Model with Azure Arc
This workflow optimizes the Open Llama model on Azure ML compute and evaluates the output models on your device. Please connect your device to Azure Arc by following these instructions: [Self-hosted Kubernetes cluster](https://microsoft.github.io/Olive/tutorials/azure_arc.html)

1 change: 1 addition & 0 deletions examples/open_llama/open_llama_lora_tinycodes.json
@@ -49,6 +49,7 @@
"gradient_accumulation_steps": 1,
"max_steps": 500,
"logging_steps": 100,
"save_steps": 100,
"evaluation_strategy": "steps",
"adam_beta2": 0.999,
"max_grad_norm": 0.3,
78 changes: 78 additions & 0 deletions examples/open_llama/open_llama_qlora_ort_tinycodes.json
@@ -0,0 +1,78 @@
{
    "input_model": {
        "type": "PyTorchModel",
        "config": {
            "hf_config": {
                "model_name": "openlm-research/open_llama_7b_v2",
                "task": "text-generation"
            }
        }
    },
    "data_configs": {
        "tiny-codes-train": {
            "name": "tiny-codes-train",
            "type": "HuggingfaceContainer",
            "user_script": "lora_user_script.py",
            "components": {
                "load_dataset": {
                    "type": "load_tiny_code_dataset"
                }
            },
            "params_config": {
                "data_name": "nampdn-ai/tiny-codes",
                "split": "train",
                "component_kwargs": {
                    "load_dataset": {
                        "language": "Python",
                        "token": true
                    },
                    "pre_process_data": {
                        "dataset_type": "corpus",
                        "corpus_strategy": "join",
                        "text_template": "### Question: {prompt} \n### Answer: {response}",
                        "source_max_len": 1024
                    }
                }
            }
        }
    },
    "passes": {
        "qlora": {
            "type": "QLoRA",
            "config": {
                "use_ort_trainer": true,
                "torch_dtype": "float32",
                "lora_dropout": 0.1,
                "train_data_config": "tiny-codes-train",
                "eval_dataset_size": 1024,
                "training_args": {
                    "per_device_train_batch_size": 1,
                    "per_device_eval_batch_size": 1,
                    "gradient_accumulation_steps": 16,
                    "gradient_checkpointing": false,
                    "max_steps": 1500,
                    "logging_steps": 100,
                    "save_steps": 100,
                    "evaluation_strategy": "steps",
                    "adam_beta2": 0.999,
                    "max_grad_norm": 0.3,
                    "load_best_model_at_end": true
                }
            }
        }
    },
    "engine": {
        "log_severity_level": 0,
        "search_strategy": false,
        "evaluate_input_model": false,
        "target": {
            "type": "LocalSystem",
            "config": {
                "accelerators": ["gpu"]
            }
        },
        "execution_providers": ["CPUExecutionProvider"],
        "cache_dir": "cache",
        "output_dir": "models/open_llama_qlora_ort_tinycodes"
    }
}
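Assuming the standard Olive entry point, the new workflow can be launched directly from Python (or via the equivalent `python -m olive.workflows.run --config ...` CLI):

```python
from olive.workflows import run as olive_run

# runs the QLoRA + ONNX Runtime Training workflow defined in the config above
olive_run("open_llama_qlora_ort_tinycodes.json")
```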
1 change: 1 addition & 0 deletions examples/open_llama/open_llama_qlora_tinycodes.json
@@ -49,6 +49,7 @@
"gradient_accumulation_steps": 1,
"max_steps": 1500,
"logging_steps": 100,
"save_steps": 100,
"evaluation_strategy": "steps",
"adam_beta2": 0.999,
"max_grad_norm": 0.3,
8 changes: 8 additions & 0 deletions examples/open_llama/requirements-qlora-ort.txt
@@ -0,0 +1,8 @@
-r requirements.txt
# the latest version of accelerate reports a deepcopy error
accelerate==0.23.0
# the latest version of bitsandbytes has a new quant_state format
bitsandbytes==0.41.1
optimum
peft
scikit-learn
1 change: 1 addition & 0 deletions examples/phi/phi_qlora_tinycodes.json
@@ -54,6 +54,7 @@
"gradient_checkpointing": false,
"max_steps": 1500,
"logging_steps": 100,
"save_steps": 100,
"evaluation_strategy": "steps",
"adam_beta2": 0.999,
"max_grad_norm": 0.3,
17 changes: 17 additions & 0 deletions olive/extra_dependencies.json
@@ -31,5 +31,22 @@
    ],
    "torch-tensorrt": [
        "torch-tensorrt"
    ],
    "lora": [
        "accelerate",
        "peft"
    ],
    "qlora": [
        "accelerate",
        "bitsandbytes",
        "peft"
    ],
    "qlora-ort": [
        "accelerate",
        "bitsandbytes",
        "onnxruntime-training",
        "optimum",
        "peft",
        "torch-ort"
    ]
}
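Assuming these groups surface as pip extras of the `olive-ai` package (setup reads this file to build `extras_require`), the new dependencies can be installed with, e.g., `pip install olive-ai[qlora-ort]`.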