From efde6d9fa5db4525d644b95aacd521024f8af5d8 Mon Sep 17 00:00:00 2001 From: Xiaoyu Date: Thu, 30 Jan 2025 14:51:01 +0000 Subject: [PATCH] Add LoHa and LoKr pass --- .../model-opt-and-transform/pytorch.md | 2 +- docs/source/reference/options.md | 1 + docs/source/reference/pass.rst | 6 + examples/llama2/llama2_lmeval.json | 2 +- examples/llama2/llama2_qlora.json | 4 +- examples/phi3/phi3_template.json | 4 +- olive/cli/finetune.py | 4 +- olive/passes/pytorch/lora.py | 304 ++++++++++++++---- olive/passes/pytorch/train_utils.py | 7 + test/unit_test/passes/pytorch/test_lora.py | 16 +- 10 files changed, 269 insertions(+), 81 deletions(-) diff --git a/docs/source/how-to/configure-workflows/model-opt-and-transform/pytorch.md b/docs/source/how-to/configure-workflows/model-opt-and-transform/pytorch.md index e4431d144..c0652a64d 100644 --- a/docs/source/how-to/configure-workflows/model-opt-and-transform/pytorch.md +++ b/docs/source/how-to/configure-workflows/model-opt-and-transform/pytorch.md @@ -14,7 +14,7 @@ This pass only supports HfModels. Please refer to [LoRA](lora) for more details ```json { "type": "LoRA", - "lora_alpha": 16, + "alpha": 16, "train_data_config": // ..., "training_args": { "learning_rate": 0.0002, diff --git a/docs/source/reference/options.md b/docs/source/reference/options.md index 322dc7144..dfe69a781 100644 --- a/docs/source/reference/options.md +++ b/docs/source/reference/options.md @@ -410,6 +410,7 @@ Please also find the detailed options from following table for each pass: | [SplitModel](../../reference/pass.rst#_split_model) | Split an ONNX model into multiple smaller sub-models based on predefined assignments. | | [LoRA](../../reference/pass.rst#_lora) | Run LoRA fine-tuning on a Hugging Face PyTorch model. | | [QLoRA](../../reference/pass.rst#_qlora) | Run QLoRA fine-tuning on a Hugging Face PyTorch model. | +| [LoHa](../../reference/pass.rst#_loha) | Run LoHa fine-tuning on a Hugging Face PyTorch model. | | [LoftQ](../../reference/pass.rst#_loftq) | Run LoftQ fine-tuning on a Hugging Face PyTorch model. | | [QuantizationAwareTraining](../../reference/pass.rst#_onnx_quantization_aware_training) | Run quantization aware training on PyTorch model. | | [OpenVINOConversion](../../reference/pass.rst#_openvino_conversion) | Converts PyTorch, ONNX or TensorFlow Model to OpenVino Model. | diff --git a/docs/source/reference/pass.rst b/docs/source/reference/pass.rst index ba080b3c2..ad5c8e343 100644 --- a/docs/source/reference/pass.rst +++ b/docs/source/reference/pass.rst @@ -197,6 +197,12 @@ QLoRA ----- .. autoconfigclass:: olive.passes.QLoRA +.. _loha: + +LoHa +----- +.. autoconfigclass:: olive.passes.LoHa + .. 
_loftq: LoftQ diff --git a/examples/llama2/llama2_lmeval.json b/examples/llama2/llama2_lmeval.json index b6630d25d..8328b25d9 100644 --- a/examples/llama2/llama2_lmeval.json +++ b/examples/llama2/llama2_lmeval.json @@ -42,7 +42,7 @@ "max_steps": 150, "logging_steps": 50.0 }, - "lora_alpha": 16, + "alpha": 16, "eval_data_config": "eval_data" } }, diff --git a/examples/llama2/llama2_qlora.json b/examples/llama2/llama2_qlora.json index 5823c2820..d5a0ae217 100644 --- a/examples/llama2/llama2_qlora.json +++ b/examples/llama2/llama2_qlora.json @@ -33,8 +33,8 @@ "max_steps": 150, "logging_steps": 50.0 }, - "lora_r": 64, - "lora_alpha": 16, + "r": 64, + "alpha": 16, "eval_data_config": "eval_data" }, "c": { diff --git a/examples/phi3/phi3_template.json b/examples/phi3/phi3_template.json index e5508d878..c100ab7e2 100644 --- a/examples/phi3/phi3_template.json +++ b/examples/phi3/phi3_template.json @@ -65,7 +65,7 @@ "type": "LoRA", "train_data_config": "tiny_codes_train", "eval_data_config": "tiny_codes_eval", - "lora_r": 64, + "r": 64, "training_args": { "per_device_train_batch_size": 1, "per_device_eval_batch_size": 1, @@ -80,7 +80,7 @@ "type": "QLoRA", "train_data_config": "tiny_codes_train", "eval_data_config": "tiny_codes_eval", - "lora_r": 64, + "r": 64, "training_args": { "per_device_train_batch_size": 1, "per_device_eval_batch_size": 1, diff --git a/olive/cli/finetune.py b/olive/cli/finetune.py index 792c7caf8..70505f39b 100644 --- a/olive/cli/finetune.py +++ b/olive/cli/finetune.py @@ -107,8 +107,8 @@ def _get_run_config(self, tempdir: str) -> Dict: ((*finetune_key, "type"), self.args.method), ((*finetune_key, "torch_dtype"), self.args.torch_dtype), ((*finetune_key, "training_args"), self.parse_training_args()), - ((*finetune_key, "lora_r"), self.args.lora_r), - ((*finetune_key, "lora_alpha"), self.args.lora_alpha), + ((*finetune_key, "r"), self.args.lora_r), + ((*finetune_key, "alpha"), self.args.lora_alpha), ("output_dir", self.args.output_path), ("log_severity_level", self.args.log_level), ] diff --git a/olive/passes/pytorch/lora.py b/olive/passes/pytorch/lora.py index 8135d5a40..54d7cf8e3 100644 --- a/olive/passes/pytorch/lora.py +++ b/olive/passes/pytorch/lora.py @@ -92,20 +92,18 @@ def validate_torch_dtype(cls, v): class LoRABase(Pass): - """Base class for LoRA and QLoRA fine-tuning passes.""" + """Base class for LoRA fine-tuning passes.""" @classmethod def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]: return { - "lora_r": PassConfigParam( + "r": PassConfigParam( type_=int, default_value=64, search_defaults=Categorical([16, 32, 64]), - description="Lora R dimension.", - ), - "lora_alpha": PassConfigParam( - type_=float, default_value=16, description="The alpha parameter for Lora scaling." + description="R dimension.", ), + "alpha": PassConfigParam(type_=float, default_value=16, description="The alpha parameter for scaling."), "lora_dropout": PassConfigParam( type_=float, default_value=0.05, description="The dropout probability for Lora layers." 
             ),
@@ -224,6 +222,38 @@ def get_datasets(
 
         return train_dataset, eval_dataset
 
+    def _run_for_config(self, model: HfModelHandler, config: Dict[str, Any], output_model_path: str) -> HfModelHandler:
+        # convert config to pass config class
+        # this will validate the config and convert to the correct types
+        config = self._config_class(**config)
+
+        # check dependencies
+        self.check_dependencies(config)
+
+        # use default training args if not provided
+        config.training_args = config.training_args or HFTrainingArguments()
+
+        # check if peft or olive has target modules for the model
+        config.target_modules = config.target_modules or self.check_target_modules(model)
+
+        # get new model
+        pytorch_model = self.load_base_pytorch_model(model, config)
+        # NOTE: quantized model support
+        # awq: requires awq cuda extension or triton for backward pass, scale must be fp16
+        # gptq: there is no custom backend. works fine when using naive dequantize + matmul
+        #    no issue with single precision. mixed precision depends on autocast as there is no input cast
+        #    gradient might not be correct when using cuda/exllama dequantize kernels
+        # we load in fp32/bf16 so cuda kernels are disabled by default. Might need extra work to
+        # disable exllama (gptq pass disables it)
+
+        # add lora modules
+        pytorch_model = self.enable_lora(pytorch_model, config, model.task)
+
+        # train and return new model
+        return self.train_and_save_new_model(
+            pytorch_model, model.get_hf_tokenizer(), config, deepcopy(model), output_model_path
+        )
+
     def load_base_pytorch_model(self, model_handler: HfModelHandler, config: ConfigBase, **kwargs) -> "PreTrainedModel":
         """Load a base PyTorch model for fine-tuning.
 
@@ -242,55 +272,41 @@ def load_base_pytorch_model(self, model_handler: HfModelHandler, config: ConfigB
         # "auto": uses all available GPUs, model parallel
         return load_hf_base_model(model_handler, torch_dtype=model_dtype, device_map="auto", **kwargs)
 
-    def init_lora_adapters(
+    def init_adapters(
         self,
         model: "PreTrainedModel",
-        task: str,
         config: ConfigBase,
-        target_modules: Optional[List[str]] = None,
+        *,
+        task: Optional[str] = None,
         use_loftq: Optional[bool] = False,
     ) -> "PeftModel":
         """Initialize LoRA adapters.
 
         :param model: The Hugging Face PyTorch model to add LoRA adapters to.
-        :param task: The task type of the model.
         :param config: The config for the pass run.
-        :param target_modules: List of modules to target for LoRA fine-tuning.
+        :param task: The task type of the model.
         :param use_loftq: Whether to use LoftQ to initialize weights.
         :return: The LoRA model.
         """
-        from peft import LoraConfig, get_peft_model
-
-        lora_config_kwargs = {}
+        config_kwargs = {}
         if use_loftq:
             from peft import LoftQConfig
 
-            lora_config_kwargs = {
+            config_kwargs = {
                 "init_lora_weights": "loftq",
                 "loftq_config": LoftQConfig(loftq_bits=4, loftq_iter=config.loftq_iter),
             }
+        if task:
+            config_kwargs.update({"task_type": get_peft_task_type_from_task(task, fail_on_not_found=True)})
 
-        peft_task_type = get_peft_task_type_from_task(task, fail_on_not_found=True)
-        lora_config = LoraConfig(
-            r=config.lora_r,
-            lora_alpha=config.lora_alpha,
-            lora_dropout=config.lora_dropout,
-            target_modules=target_modules,
-            bias="none",
-            task_type=peft_task_type,
-            modules_to_save=config.modules_to_save,
-            **lora_config_kwargs,
-        )
-
-        return get_peft_model(model, lora_config)
+        return self.get_peft_model(model, config, config_kwargs)
 
     def enable_lora(
         self,
         model: "PreTrainedModel",
-        task: str,
         config: ConfigBase,
+        task: Optional[str] = None,
         adapter_path: Optional[str] = None,
-        target_modules: Optional[List[str]] = None,
     ) -> "PeftModel":
         """Enable LoRA fine-tuning on a Hugging Face PyTorch model.
 
@@ -299,15 +315,11 @@ def enable_lora(
         Load or initialize LoRA adapters.
 
         :param model: The Hugging Face PyTorch model to enable LoRA fine-tuning on.
-        :param tokenizer: The tokenizer for the model.
-        :param task: The task type of the model.
         :param config: The config for the pass run.
+        :param task: The task type of the model.
         :param adapter_path: Path to the adapter weights. If None, will initialize new adapters.
-        :param target_modules: List of modules to target for LoRA fine-tuning. Only used if adapter_path is None.
         :return: The LoRA model.
         """
-        from peft import PeftModel
-
         logger.debug("Enabling LoRA fine-tuning")
         prepare_model_for_finetuning(model, config.training_args)
 
@@ -319,8 +331,10 @@ def enable_lora(
 
         if not adapter_path:
             logger.debug("Initializing LoRA adapters from config")
-            lora_model = self.init_lora_adapters(model, task, config, target_modules=target_modules)
+            lora_model = self.init_adapters(model, config, task=task)
         else:
+            from peft import PeftModel
+
             logger.debug("Loading LoRA adapters from %s", adapter_path)
             lora_model = PeftModel.from_pretrained(model, adapter_path, is_trainable=True)
         logger.debug("The number of trainable parameters in the LoRA model: %s", count_trainable_parameters(lora_model))
@@ -440,6 +454,39 @@ def get_torch_dtype(torch_dtype: str) -> "torch.dtype":
         assert torch_dtype in supported_dtypes, f"torch_dtype must be one of {supported_dtypes} but got {torch_dtype}"
         return resolve_torch_dtype(torch_dtype)
 
+    @staticmethod
+    def check_target_modules(model: HfModelHandler):
+        from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
+
+        model_type = model.get_hf_model_type()
+        if model_type not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING:
+            if model_type in MODELS_TO_LORA_TARGET_MODULES_MAPPING:
+                return MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_type]
+            else:
+                raise ValueError(
+                    f"Model type {model_type} is not recognized by peft or olive. Please provide 'target_modules'."
+ ) + return None + + @staticmethod + def get_peft_model(model, config, config_kwargs=None): + from peft import LoraConfig, get_peft_model + + if config_kwargs is None: + config_kwargs = {} + + lora_config = LoraConfig( + r=config.r, + lora_alpha=config.alpha, + lora_dropout=config.lora_dropout, + target_modules=config.target_modules, + bias="none", + modules_to_save=config.modules_to_save, + **config_kwargs, + ) + + return get_peft_model(model, lora_config) + class LoRA(LoRABase): """Run LoRA fine-tuning on a Hugging Face PyTorch model.""" @@ -452,47 +499,162 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon config.update(super()._default_config(accelerator_spec)) return config - def _run_for_config(self, model: HfModelHandler, config: Dict[str, Any], output_model_path: str) -> HfModelHandler: - from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING - # convert config to pass config class - # this will validate the config and convert to the correct types - config = self._config_class(**config) +class LoRAVariant(LoRABase): + """Run LoRA variant fine-tuning on a Hugging Face PyTorch model.""" - # check dependencies - self.check_dependencies(config) + @classmethod + def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]: + config = { + "rank_dropout": PassConfigParam( + type_=float, + default_value=0.0, + description="The dropout probability for rank dimension during training.", + ), + "module_dropout": PassConfigParam( + type_=float, + default_value=0.0, + description="The dropout probability for disabling LoHa modules during training.", + ), + "use_effective_conv2d": PassConfigParam( + type_=bool, + default_value=True, + description="Use parameter effective decomposition for Conv2d with ksize > 1.", + ), + "target_modules": PassConfigParam( + type_=Optional[Union[List[str], str]], + default_value="all-linear", + description="The names of the modules to apply the adapter to.", + ), + "exclude_modules": PassConfigParam( + type_=Optional[Union[List[str], str]], default_value=None, description="Modules to exclude from LoHa." + ), + "init_weights": PassConfigParam( + type_=bool, default_value=True, description="Whether to perform initialization of adapter weights." + ), + "layers_to_transform": PassConfigParam( + type_=List[int], default_value=None, description="The layer indices to transform." 
+ ), + "layers_pattern": PassConfigParam( + type_=List[str], + default_value=None, + description="The layer pattern name, used only if layers_to_transform is different from None.", + ), + "rank_pattern": PassConfigParam( + type_=Dict, + default_value={}, + description="The mapping from layer names or regexp expression " + "to ranks which are different from the default rank specified by r.", + ), + "alpha_pattern": PassConfigParam( + type_=Dict, + default_value={}, + description="The mapping from layer names or regexp expression " + "to alphas which are different from the default alpha specified by alpha.", + ), + } + config.update(super()._default_config(accelerator_spec)) + return config - # use default training args if not provided - config.training_args = config.training_args or HFTrainingArguments() - # check if peft or olive has target modules for the model - model_type = model.get_hf_model_type() - if not config.target_modules and model_type not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING: - if model_type in MODELS_TO_LORA_TARGET_MODULES_MAPPING: - config.target_modules = MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_type] - else: - raise ValueError( - f"Model type {model_type} is not recognized by peft or olive. Please provide 'target_modules'." - ) +class LoHa(LoRAVariant): + """Run LoHa fine-tuning on a Hugging Face PyTorch model.""" - # get new model - pytorch_model = self.load_base_pytorch_model(model, config) - # NOTE: quantized model support - # awq: requires awq cuda extension or triton for backward pass, scale must be fp16 - # gptq: there is no custom backend. works fine when using naive dequantize + matmul - # no issue with single precision. mix precision depends on autocast as there is no input cast - # gradient might not be correct when using cuda/exllama deqauntize kernels - # we load in fp32/bf16 so cuda kernels are disabled by default. 
Might need extra work to - # disable exllama (gptq pass disables it) + @staticmethod + def get_peft_model(model, config, config_kwargs=None): + from peft import LoHaConfig, LoHaModel + + config = LoHaConfig( + r=config.r, + alpha=config.alpha, + rank_dropout=config.rank_dropout, + module_dropout=config.module_dropout, + use_effective_conv2d=config.use_effective_conv2d, + target_modules=config.target_modules, + exclude_modules=config.exclude_modules, + init_weights=config.init_weights, + layers_to_transform=config.layers_to_transform, + layers_pattern=config.layers_pattern, + rank_pattern=config.rank_pattern, + alpha_pattern=config.alpha_pattern, + modules_to_save=config.modules_to_save, + ) - # add lora modules - pytorch_model = self.enable_lora(pytorch_model, model.task, config, target_modules=config.target_modules) + return LoHaModel(model, config, "default") - # train and return new model - return self.train_and_save_new_model( - pytorch_model, model.get_hf_tokenizer(), config, deepcopy(model), output_model_path + @classmethod + def check_dependencies(cls, config: ConfigBase, is_qlora: bool = False): + """Check dependencies for the pass.""" + super().check_dependencies(config) + + from peft import __version__ as peft_version + + # LoHa is only supported after peft 0.7.0 + if version.parse(peft_version) < version.parse("0.7.0"): + raise ImportError(f"Please install peft >= 0.7.0 to use {cls.__name__} pass.") + + +class LoKr(LoRAVariant): + """Run LoKr fine-tuning on a Hugging Face PyTorch model.""" + + @staticmethod + def get_peft_model(model, config, config_kwargs=None): + from peft import LoKrConfig, LoKrModel + + config = LoKrConfig( + r=config.r, + alpha=config.alpha, + rank_dropout=config.rank_dropout, + module_dropout=config.module_dropout, + decompose_both=config.decompose_both, + decompose_factor=config.decompose_factor, + rank_dropout_scale=config.rank_dropout_scale, + use_effective_conv2d=config.use_effective_conv2d, + target_modules=config.target_modules, + exclude_modules=config.exclude_modules, + init_weights=config.init_weights, + layers_to_transform=config.layers_to_transform, + layers_pattern=config.layers_pattern, + rank_pattern=config.rank_pattern, + alpha_pattern=config.alpha_pattern, + modules_to_save=config.modules_to_save, ) + return LoKrModel(model, config, "default") + + @classmethod + def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]: + config = { + "decompose_both": PassConfigParam( + type_=bool, + default_value=False, + description="Perform rank decomposition of left kronecker product matrix.", + ), + "decompose_factor": PassConfigParam( + type_=int, + default_value=-1, + description="Kronecker product decomposition factor.", + ), + "rank_dropout_scale": PassConfigParam( + type_=bool, + default_value=False, + description="Whether to scale the rank dropout while training.", + ), + } + config.update(super()._default_config(accelerator_spec)) + return config + + @classmethod + def check_dependencies(cls, config: ConfigBase, is_qlora: bool = False): + """Check dependencies for the pass.""" + super().check_dependencies(config) + + from peft import __version__ as peft_version + + # LoHa is only supported after peft 0.7.0 + if version.parse(peft_version) < version.parse("0.7.0"): + raise ImportError(f"Please install peft >= 0.7.0 to use {cls.__name__} pass.") + class QLoRABase(LoRABase): """Base class for QLoRA and LoftQ fine-tuning passes.""" @@ -618,7 +780,8 @@ def get_quant_model( logger.debug("Quantized modules: %s", 
quantized_modules)
 
         # enable lora fine-tuning with new lora modules
-        pytorch_model = self.enable_lora(pytorch_model, model.task, config, target_modules=quantized_modules)
+        config.target_modules = quantized_modules
+        pytorch_model = self.enable_lora(pytorch_model, config, model.task)
 
         return deepcopy(model), pytorch_model, bnb_quant_config, quantized_modules
@@ -671,7 +834,8 @@ def get_quant_model(
 
         # get loftq initialized lora model
         logger.debug("Initializing LoRA with LoftQ")
-        pytorch_model = self.init_lora_adapters(pytorch_model, model.task, config, quantized_modules, use_loftq=True)
+        config.target_modules = quantized_modules
+        pytorch_model = self.init_adapters(pytorch_model, config, task=model.task, use_loftq=True)
 
         output_model_path = Path(output_model_path)
diff --git a/olive/passes/pytorch/train_utils.py b/olive/passes/pytorch/train_utils.py
index 8b14efa70..fdc00f4e2 100644
--- a/olive/passes/pytorch/train_utils.py
+++ b/olive/passes/pytorch/train_utils.py
@@ -129,6 +129,13 @@ def prepare_model_for_finetuning(model: "PreTrainedModel", training_args: BaseHF
     :param model: The Hugging Face PyTorch model to prepare for fine-tuning.
     :param training_args: The training arguments for the model.
     """
+    if training_args.gradient_checkpointing and not model.supports_gradient_checkpointing:
+        logger.warning(
+            "gradient_checkpointing is True, but model does not support gradient checkpointing! Setting"
+            " gradient_checkpointing to False"
+        )
+        training_args.gradient_checkpointing = False
+
     for param in model.parameters():
         # freeze base model's layers
         param.requires_grad = False
diff --git a/test/unit_test/passes/pytorch/test_lora.py b/test/unit_test/passes/pytorch/test_lora.py
index cffcea60b..0903f443a 100644
--- a/test/unit_test/passes/pytorch/test_lora.py
+++ b/test/unit_test/passes/pytorch/test_lora.py
@@ -4,6 +4,7 @@
 # --------------------------------------------------------------------------
 import platform
 from pathlib import Path
+from unittest.mock import patch
 
 import pytest
 import torch
@@ -12,7 +13,7 @@
 from olive.data.template import huggingface_data_config_template
 from olive.model import HfModelHandler
 from olive.passes.olive_pass import create_pass_from_dict
-from olive.passes.pytorch.lora import LoftQ, LoRA, QLoRA
+from olive.passes.pytorch.lora import LoftQ, LoHa, LoRA, QLoRA
 
 # pylint: disable=redefined-outer-name
 
@@ -42,9 +43,9 @@ def get_pass_config(model_name, task, **kwargs):
     return {
         "train_data_config": data_config,
         # hidden sizes are 4 or 16
-        # will have invalid adapter weights since `in_features` and/or `out_features` say 64 (lora_r) even though
+        # will have invalid adapter weights since `in_features` and/or `out_features` say 64 (r) even though
         # the actual weights are 4 or 16. Bug not from our code, it's from peft
-        "lora_r": 4,
+        "r": 4,
         "training_args": {
             "per_device_train_batch_size": 1,
             "per_device_eval_batch_size": 1,
@@ -112,3 +113,12 @@ def test_loftq(tmp_path):
     # assert
     assert Path(out.get_resource("model_path")).exists()
     assert Path(out.get_resource("adapter_path")).exists()
+
+
+@patch("transformers.Trainer.train")
+def test_loha(mock_train, tmp_path):
+    # execute
+    out = run_finetuning(LoHa, tmp_path, torch_dtype="float32")
+
+    # assert
+    assert Path(out.get_resource("adapter_path")).exists()
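
Example usage (illustrative only, not part of the diff): a minimal sketch of how the new LoHa pass could be wired into a workflow, mirroring the existing LoRA/QLoRA pass configs touched by this patch. The `tiny_codes_train` data config and the training values are placeholders borrowed from `examples/phi3/phi3_template.json`; adjust them for a real run.

```json
{
    "type": "LoHa",
    "train_data_config": "tiny_codes_train",
    "r": 64,
    "alpha": 16,
    "rank_dropout": 0.0,
    "module_dropout": 0.0,
    "target_modules": "all-linear",
    "training_args": { "learning_rate": 0.0002, "per_device_train_batch_size": 1, "max_steps": 150 }
}
```

Setting `"type": "LoKr"` would select the Kronecker-product variant instead; its extra options (`decompose_both`, `decompose_factor`, `rank_dropout_scale`) fall back to the defaults defined in this patch.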