diff --git a/examples/llama2/llama2_qlora.json b/examples/llama2/llama2_qlora.json
index d3f9d6063..3e250a644 100644
--- a/examples/llama2/llama2_qlora.json
+++ b/examples/llama2/llama2_qlora.json
@@ -79,6 +79,7 @@
                     "gradient_accumulation_steps": 1,
                     "max_steps": 1500,
                     "logging_steps": 100,
+                    "save_steps": 100,
                     "evaluation_strategy": "steps",
                     "adam_beta2": 0.999,
                     "max_grad_norm": 0.3,
diff --git a/examples/open_llama/README.md b/examples/open_llama/README.md
index 5f5d43ed6..0397b38a7 100644
--- a/examples/open_llama/README.md
+++ b/examples/open_llama/README.md
@@ -102,6 +102,19 @@ Note: You must be logged in to HuggingFace using `huggingface-cli login` to down
 
 Requirements file: [requirements-lora.txt](requirements-lora.txt)
 
+**Train using ONNX Runtime Training**
+You can also train the model using [ONNX Runtime Training](https://techcommunity.microsoft.com/t5/ai-machine-learning-blog/onnx-runtime-training-technical-deep-dive/ba-p/1398310).
+
+The relevant config file is [open_llama_qlora_ort_tinycodes.json](open_llama_qlora_ort_tinycodes.json).
+
+Requirements file: [requirements-qlora-ort.txt](requirements-qlora-ort.txt)
+
+It also requires the latest nightly build of onnxruntime-training:
+```bash
+python -m pip uninstall -y onnxruntime onnxruntime-gpu ort-nightly ort-nightly-gpu
+python -m pip install onnxruntime-training --pre --upgrade --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/
+```
+
 ### Optimizing Open Llama Model with Azure Arc
 
 This workflow optimizes Open Llama model on Azure ML compute, and evaluate output models on your device. Please connect your device to Azure Arc by following instruction: [Self-hosted Kubernetes cluster](https://microsoft.github.io/Olive/tutorials/azure_arc.html)
diff --git a/examples/open_llama/open_llama_lora_tinycodes.json b/examples/open_llama/open_llama_lora_tinycodes.json
index 560e664ca..1ff3efa48 100644
--- a/examples/open_llama/open_llama_lora_tinycodes.json
+++ b/examples/open_llama/open_llama_lora_tinycodes.json
@@ -49,6 +49,7 @@
                     "gradient_accumulation_steps": 1,
                     "max_steps": 500,
                     "logging_steps": 100,
+                    "save_steps": 100,
                     "evaluation_strategy": "steps",
                     "adam_beta2": 0.999,
                     "max_grad_norm": 0.3,
diff --git a/examples/open_llama/open_llama_qlora_ort_tinycodes.json b/examples/open_llama/open_llama_qlora_ort_tinycodes.json
new file mode 100644
index 000000000..171aba61e
--- /dev/null
+++ b/examples/open_llama/open_llama_qlora_ort_tinycodes.json
@@ -0,0 +1,78 @@
+{
+    "input_model":{
+        "type": "PyTorchModel",
+        "config": {
+            "hf_config": {
+                "model_name": "openlm-research/open_llama_7b_v2",
+                "task": "text-generation"
+            }
+        }
+    },
+    "data_configs": {
+        "tiny-codes-train": {
+            "name": "tiny-codes-train",
+            "type": "HuggingfaceContainer",
+            "user_script": "lora_user_script.py",
+            "components": {
+                "load_dataset": {
+                    "type": "load_tiny_code_dataset"
+                }
+            },
+            "params_config": {
+                "data_name": "nampdn-ai/tiny-codes",
+                "split": "train",
+                "component_kwargs": {
+                    "load_dataset": {
+                        "language": "Python",
+                        "token": true
+                    },
+                    "pre_process_data": {
+                        "dataset_type": "corpus",
+                        "corpus_strategy": "join",
+                        "text_template": "### Question: {prompt} \n### Answer: {response}",
+                        "source_max_len": 1024
+                    }
+                }
+            }
+        }
+    },
+    "passes": {
+        "qlora": {
+            "type": "QLoRA",
+            "config": {
+                "use_ort_trainer": true,
+                "torch_dtype": "float32",
+                "lora_dropout": 0.1,
+                "train_data_config": "tiny-codes-train",
+                "eval_dataset_size": 1024,
+                "training_args": {
+                    "per_device_train_batch_size": 1,
+                    "per_device_eval_batch_size": 1,
+                    "gradient_accumulation_steps": 16,
+                    "gradient_checkpointing": false,
+                    "max_steps": 1500,
+                    "logging_steps": 100,
+                    "save_steps": 100,
+                    "evaluation_strategy": "steps",
+                    "adam_beta2": 0.999,
+                    "max_grad_norm": 0.3,
+                    "load_best_model_at_end": true
+                }
+            }
+        }
+    },
+    "engine": {
+        "log_severity_level": 0,
+        "search_strategy": false,
+        "evaluate_input_model": false,
+        "target": {
+            "type": "LocalSystem",
+            "config": {
+                "accelerators": ["gpu"]
+            }
+        },
+        "execution_providers": ["CPUExecutionProvider"],
+        "cache_dir": "cache",
+        "output_dir" : "models/open_llama_qlora_ort_tinycodes"
+    }
+}
diff --git a/examples/open_llama/open_llama_qlora_tinycodes.json b/examples/open_llama/open_llama_qlora_tinycodes.json
index 17f435a66..f14050a61 100644
--- a/examples/open_llama/open_llama_qlora_tinycodes.json
+++ b/examples/open_llama/open_llama_qlora_tinycodes.json
@@ -49,6 +49,7 @@
                     "gradient_accumulation_steps": 1,
                     "max_steps": 1500,
                     "logging_steps": 100,
+                    "save_steps": 100,
                     "evaluation_strategy": "steps",
                     "adam_beta2": 0.999,
                     "max_grad_norm": 0.3,
diff --git a/examples/open_llama/requirements-qlora-ort.txt b/examples/open_llama/requirements-qlora-ort.txt
new file mode 100644
index 000000000..aebac65a8
--- /dev/null
+++ b/examples/open_llama/requirements-qlora-ort.txt
@@ -0,0 +1,8 @@
+-r requirements.txt
+# the latest version of accelerate reports a deepcopy error
+accelerate==0.23.0
+# the latest version of bitsandbytes has a new quant_state format
+bitsandbytes==0.41.1
+optimum
+peft
+scikit-learn
diff --git a/examples/phi/phi_qlora_tinycodes.json b/examples/phi/phi_qlora_tinycodes.json
index cbd3bc8f1..21e2d0c92 100644
--- a/examples/phi/phi_qlora_tinycodes.json
+++ b/examples/phi/phi_qlora_tinycodes.json
@@ -54,6 +54,7 @@
                     "gradient_checkpointing": false,
                     "max_steps": 1500,
                     "logging_steps": 100,
+                    "save_steps": 100,
                     "evaluation_strategy": "steps",
                     "adam_beta2": 0.999,
                     "max_grad_norm": 0.3,
diff --git a/olive/extra_dependencies.json b/olive/extra_dependencies.json
index bd4f1c2f9..ff4fad1c3 100644
--- a/olive/extra_dependencies.json
+++ b/olive/extra_dependencies.json
@@ -31,5 +31,22 @@
     ],
     "torch-tensorrt": [
         "torch-tensorrt"
+    ],
+    "lora": [
+        "accelerate",
+        "peft"
+    ],
+    "qlora": [
+        "accelerate",
+        "bitsandbytes",
+        "peft"
+    ],
+    "qlora-ort": [
+        "accelerate",
+        "bitsandbytes",
+        "onnxruntime-training",
+        "optimum",
+        "peft",
+        "torch-ort"
     ]
 }
diff --git a/olive/passes/pytorch/lora.py b/olive/passes/pytorch/lora.py
index 0a80b16d0..a2ccdbdea 100644
--- a/olive/passes/pytorch/lora.py
+++ b/olive/passes/pytorch/lora.py
@@ -100,21 +100,32 @@ def validate_extra_args(cls, v):
             v = {}
         # make sure extra args are fields of transformers.Trainer
         training_args_fields = {f.name for f in dataclasses.fields(transformers.TrainingArguments) if f.init}
+        # use_module_with_loss is a field of optimum.onnxruntime.ORTTrainingArguments
+        training_args_fields.add("use_module_with_loss")
         for k in list(v):  # need a copy of the keys since we are mutating the dict
-            if k == "output_dir":
-                logger.warning(f"Extra arg {k} is not allowed. Please use `training_output_dir` instead.")
+            if k == "fp16":
+                logger.warning(f"Extra arg {k} is not allowed. Please use `torch_dtype` instead.")
                 del v[k]
             elif k not in training_args_fields:
                 logger.warning(f"Extra arg {k} is not a field of transformers.TrainingArguments. Ignoring.")
                 del v[k]
         return v
 
-    def create_training_args(self) -> transformers.TrainingArguments:
+    def create_training_args(self, use_ort_trainer: bool) -> transformers.TrainingArguments:
         args = self.dict()
         if not args["output_dir"]:
             raise ValueError("output_dir must be provided.")
         extra_args = args.pop("extra_args")
-        return transformers.TrainingArguments(**args, **extra_args)
+        if use_ort_trainer:
+            from optimum.onnxruntime import ORTTrainingArguments
+
+            training_args_cls = ORTTrainingArguments
+        else:
+            training_args_cls = transformers.TrainingArguments
+            if "use_module_with_loss" in extra_args:
+                logger.warning("use_module_with_loss is not supported by transformers.TrainingArguments. Ignoring.")
+                extra_args.pop("use_module_with_loss")
+        return training_args_cls(**args, **extra_args)
 
 
 class LoRABase(Pass):
@@ -127,6 +138,9 @@ class LoRABase(Pass):
     @staticmethod
     def _default_config(accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]:
         return {
+            "use_ort_trainer": PassConfigParam(
+                type_=bool, default_value=False, description="Whether or not to use ORTTrainer."
+            ),
             "lora_r": PassConfigParam(type_=int, default_value=64, description="Lora attention dimension."),
             "lora_alpha": PassConfigParam(
                 type_=float, default_value=16, description="The alpha parameter for Lora scaling."
@@ -146,9 +160,8 @@ def _default_config(accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigPa
                 type_=str,
                 default_value="bfloat16",
                 description=(
-                    "Data type for model weights and adapter weights. For 4bit quantized model, "
-                    "it is also the computation data type for the quantized modules. "
-                    "Should be one of `bfloat16`, `float16` or `float32`."
+                    "Data type to use for training. Should be one of `bfloat16`, `float16` or `float32`. If `float16`,"
+                    " fp16 mixed-precision training will be used."
                 ),
             ),
             "allow_tf32": PassConfigParam(
@@ -196,6 +209,25 @@ def _default_config(accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigPa
             ),
         }
 
+    def validate_search_point(
+        self, search_point: Dict[str, Any], accelerator_spec: AcceleratorSpec, with_fixed_value: bool = False
+    ) -> bool:
+        if with_fixed_value:
+            search_point = self.config_at_search_point(search_point or {})
+        if search_point.get("use_ort_trainer"):
+            if search_point.get("torch_dtype") == "bfloat16":
+                logger.info(
+                    "bfloat16 is not supported by onnxruntime-training yet. Please use a different torch_dtype."
+                )
+                return False
+            if search_point.get("training_args", {}).get("gradient_checkpointing"):
+                logger.info(
+                    "gradient_checkpointing is not supported by onnxruntime-training. Please set gradient_checkpointing"
+                    " to False."
+                )
+                return False
+        return True
+
     @staticmethod
     def collate_batch(batch: List[Dict], tokenizer: transformers.PreTrainedTokenizer) -> Dict[str, torch.Tensor]:
         """Collate a batch of samples into a padded batch of tensors.
@@ -355,6 +387,7 @@ def train_and_save_new_model(
         data_root: str,
         output_model: PyTorchModel,
         output_model_path: str,
+        torch_dtype: torch.dtype,
     ) -> PyTorchModel:
         if torch.cuda.is_available():
             allow_tf32 = torch.backends.cuda.matmul.allow_tf32
@@ -385,16 +418,27 @@ def train_and_save_new_model(
         # Plus the cleanup after error doesn't work as expected with notebooks
         with tempfile.TemporaryDirectory(prefix="olive_tmp") as temp_dir:
             if not config.training_args.output_dir:
-                logger.info("No training_output_dir provided. Using a temp dir.")
+                logger.info("No training_args.output_dir provided. Using a temp dir.")
                 config.training_args.output_dir = temp_dir
                 # set save_total_limit to 1 since the temp dir will be deleted after training
                 config.training_args.extra_args["save_total_limit"] = 1
 
+            if torch_dtype == torch.float16:
+                # use fp16 mixed precision training
+                config.training_args.extra_args["fp16"] = True
+            # create training args
+            logger.debug(f"Training args: {config.training_args.dict()}")
+
+            trainer_cls = transformers.Trainer
+            if config.use_ort_trainer:
+                from optimum.onnxruntime import ORTTrainer
+
+                trainer_cls = ORTTrainer
             # get trainer
-            trainer = transformers.Trainer(
+            trainer = trainer_cls(
                 model=model,
                 tokenizer=tokenizer,
-                args=config.training_args.create_training_args(),
+                args=config.training_args.create_training_args(config.use_ort_trainer),
                 train_dataset=train_dataset,
                 eval_dataset=eval_dataset,
                 data_collator=partial(self.collate_batch, tokenizer=tokenizer),
@@ -533,15 +577,22 @@ def _run_for_config(
         new_model = self.input_model_check(deepcopy(model))
 
         torch_dtype = self.get_torch_dtype(config.torch_dtype)
+        # will use mixed precision since full fp16 is unstable
+        model_dtype = torch_dtype if torch_dtype != torch.float16 else torch.float32
 
         # load model, reset model_loading_args and adapter_path
         model_loading_args = (
             new_model.hf_config.model_loading_args.dict() if new_model.hf_config.model_loading_args else {}
         )
-        model_loading_args.update({"torch_dtype": torch_dtype, "device_map": "auto"})
+        model_loading_args.update(
+            {"torch_dtype": model_dtype, "device_map": "auto" if not config.use_ort_trainer else None}
+        )
         new_model.hf_config.model_loading_args = HFModelLoadingArgs(**model_loading_args)
         pytorch_model = new_model.load_model()
-        pytorch_model.config.torch_dtype = torch_dtype
+        if torch.cuda.is_available() and config.use_ort_trainer:
+            # put the model on GPU since device_map is None and the model will be on CPU
+            pytorch_model.to("cuda")
+        pytorch_model.config.torch_dtype = model_dtype
 
         # tokenizer
         tokenizer = AutoTokenizer.from_pretrained(new_model.hf_config.model_name)
@@ -552,7 +603,9 @@ def _run_for_config(
         )
 
         # train and return new model
-        return self.train_and_save_new_model(pytorch_model, tokenizer, config, data_root, new_model, output_model_path)
+        return self.train_and_save_new_model(
+            pytorch_model, tokenizer, config, data_root, new_model, output_model_path, torch_dtype
+        )
 
 
 class QLoRA(LoRABase):
@@ -580,6 +633,13 @@ def _default_config(accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigPa
                 default_value="nf4",
                 description="Quantization data type to use. Should be one of `fp4` or `nf4`.",
             ),
+            "compute_dtype": PassConfigParam(
+                type_=str,
+                description=(
+                    "Computation data type for the quantized modules. If not provided, will use the same dtype as"
+                    " torch_dtype."
+                ),
+            ),
         }
         config.update(LoRABase._default_config(accelerator_spec))
         return config
@@ -595,15 +655,23 @@ def _run_for_config(
         # this will validate the config and convert to the correct types
         config = self._config_class(**config)
 
+        # MatMulBnb4 contrib op doesn't support double quantization so the trainer falls back to PythonOp
+        # which uses more memory and is slower
+        if config.use_ort_trainer and config.double_quant:
+            logger.warning(
+                "double_quant is set to True but it is inefficient with onnxruntime-training! Consider setting it to"
+                " False."
+            )
+
         # use default training args if not provided
         config.training_args = config.training_args or HFTrainingArguments()
 
         # get models and tokenizer
-        new_model, pytorch_model, tokenizer, quantized_modules = self.get_model_tokenizer(model, config)
+        new_model, pytorch_model, tokenizer, quantized_modules, torch_dtype = self.get_model_tokenizer(model, config)
 
         # train and get new model
         output_model = self.train_and_save_new_model(
-            pytorch_model, tokenizer, config, data_root, new_model, output_model_path
+            pytorch_model, tokenizer, config, data_root, new_model, output_model_path, torch_dtype
        )
         # add quantized_modules attributes
         output_model.model_attributes["quantized_modules"] = quantized_modules
@@ -611,7 +679,7 @@ def _run_for_config(
 
     def get_model_tokenizer(
         self, model: PyTorchModel, config: ConfigBase
-    ) -> Tuple[PyTorchModel, PreTrainedModel, PreTrainedTokenizer, List[str]]:
+    ) -> Tuple[PyTorchModel, PreTrainedModel, PreTrainedTokenizer, List[str], torch.dtype]:
         """Get the Olive model, PyTorch model and tokenizer for QLoRA fine-tuning."""
         import bitsandbytes as bnb
 
@@ -622,6 +690,8 @@ def get_model_tokenizer(
         new_model = self.input_model_check(deepcopy(model))
 
         torch_dtype = self.get_torch_dtype(config.torch_dtype)
+        # will use mixed precision since full fp16 is unstable
+        model_dtype = torch_dtype if torch_dtype != torch.float16 else torch.float32
 
         # load model, reset model_loading_args and adapter_path
         model_loading_args = (
@@ -629,14 +699,16 @@ def get_model_tokenizer(
         )
         model_loading_args.update(
             {
-                "torch_dtype": torch_dtype,
+                "torch_dtype": model_dtype,
                 # TODO(jambayk): Worry about `use_multi_gpu` and distributed training later
                 # this uses all available GPUs, model parallel
-                "device_map": "auto",
+                # ORTTrainer falls back to pytorch when model parallel is used
+                # use `None` device_map to only use one GPU
+                "device_map": "auto" if not config.use_ort_trainer else None,
                 "quantization_method": "bitsandbytes",
                 "quantization_config": {
                     "load_in_4bit": True,
-                    "bnb_4bit_compute_dtype": torch_dtype,
+                    "bnb_4bit_compute_dtype": self.get_torch_dtype(config.compute_dtype or config.torch_dtype),
                     "bnb_4bit_use_double_quant": config.double_quant,
                     "bnb_4bit_quant_type": config.quant_type,
                 },
@@ -644,7 +716,7 @@ def get_model_tokenizer(
         )
         new_model.hf_config.model_loading_args = HFModelLoadingArgs(**model_loading_args)
         pytorch_model = new_model.load_model()
-        pytorch_model.config.torch_dtype = torch_dtype
+        pytorch_model.config.torch_dtype = model_dtype
 
         # tokenizer
         tokenizer = AutoTokenizer.from_pretrained(new_model.hf_config.model_name)
@@ -658,4 +730,4 @@ def get_model_tokenizer(
         target_modules = find_submodules(pytorch_model, bnb.nn.Linear4bit)
         pytorch_model = self.enable_lora(pytorch_model, tokenizer, new_model.hf_config.task, config, target_modules)
 
-        return new_model, pytorch_model, tokenizer, target_modules
+        return new_model, pytorch_model, tokenizer, target_modules, torch_dtype
diff --git a/olive/workflows/run/run.py b/olive/workflows/run/run.py
index 80d2e8304..07e3e64e2 100644
--- a/olive/workflows/run/run.py
+++ b/olive/workflows/run/run.py
@@ -75,6 +75,8 @@ def dependency_setup(config):
             "OptimumConversion": extras.get("optimum"),
             "OptimumMerging": extras.get("optimum"),
             "TorchTRTConversion": extras.get("torch-tensorrt"),
+            "LoRA": extras.get("lora"),
+            "QLoRA": extras.get("qlora"),
         },
     }
     ort_packages = ["onnxruntime", "onnxruntime-directml", "onnxruntime-gpu", "onnxruntime-openvino"]
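
For context, a minimal sketch of how the new ORT QLoRA example could be exercised end to end from `examples/open_llama`. It assumes Olive's documented `python -m olive.workflows.run` entry point and its `--setup` flag (which drives the `dependency_setup` mapping touched in this patch); exact commands may vary between Olive versions.

```bash
# install the pinned example requirements (accelerate, bitsandbytes, optimum, peft, ...)
python -m pip install -r requirements-qlora-ort.txt

# install the nightly onnxruntime-training build as described in the README section of this patch

# optionally let Olive install the pass-level extras declared in olive/extra_dependencies.json
# (assumes the --setup flag of olive.workflows.run)
python -m olive.workflows.run --config open_llama_qlora_ort_tinycodes.json --setup

# run the QLoRA fine-tuning workflow with "use_ort_trainer": true
python -m olive.workflows.run --config open_llama_qlora_ort_tinycodes.json
```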