@@ -369,7 +369,9 @@ def init_lora_adapters(
         task: str,
         config: ConfigBase,
         target_modules: Optional[List[str]] = None,
+        ephemeral_gpu_offload: Optional[bool] = False,
         use_loftq: Optional[bool] = False,
+        use_dora: Optional[bool] = False,
     ) -> "PeftModel":
         """Initialize LoRA adapters.
 
@@ -380,7 +382,7 @@ def init_lora_adapters(
         :param use_loftq: Whether to use LoftQ to initialize weights.
         :return: The LoRA model.
         """
-        from peft import LoraConfig, get_peft_model
+        from peft import LoraConfig, LoraRuntimeConfig, get_peft_model
 
         lora_config_kwargs = {}
         if use_loftq:
@@ -391,6 +393,11 @@ def init_lora_adapters(
                 "loftq_config": LoftQConfig(loftq_bits=4, loftq_iter=config.loftq_iter),
             }
 
+        if use_dora:
+            lora_config_kwargs.update({
+                "use_dora": True,
+            })
+
         peft_task_type = get_peft_task_type_from_task(task, fail_on_not_found=True)
         lora_config = LoraConfig(
             r=config.lora_r,
@@ -400,6 +407,7 @@ def init_lora_adapters(
             bias="none",
             task_type=peft_task_type,
             modules_to_save=config.modules_to_save,
+            runtime_config=LoraRuntimeConfig(ephemeral_gpu_offload=ephemeral_gpu_offload),
             **lora_config_kwargs,
         )
 
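For context, the two new knobs map directly onto peft's public API: `use_dora` switches the adapter update to DoRA, and `runtime_config=LoraRuntimeConfig(ephemeral_gpu_offload=...)` lets peft momentarily move CPU-offloaded weights onto the GPU during the more expensive DoRA setup math. A minimal sketch of that peft-level usage, assuming a peft release that ships `LoraRuntimeConfig`; the model name, target modules, and hyperparameters are placeholders:

```python
from peft import LoraConfig, LoraRuntimeConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Any causal LM works here; this model name is just a placeholder.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],  # attention projections in OPT
    use_dora=True,  # decompose the update into magnitude and direction (DoRA)
    # Momentarily move CPU-offloaded weights to GPU for faster DoRA initialization.
    runtime_config=LoraRuntimeConfig(ephemeral_gpu_offload=True),
    task_type="CAUSAL_LM",
)
dora_model = get_peft_model(model, lora_config)
dora_model.print_trainable_parameters()
```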
@@ -412,6 +420,8 @@ def enable_lora(
         config: ConfigBase,
         adapter_path: Optional[str] = None,
         target_modules: Optional[List[str]] = None,
+        use_dora: Optional[bool] = False,
+        ephemeral_gpu_offload: Optional[bool] = False,
     ) -> "PeftModel":
         """Enable LoRA fine-tuning on a Hugging Face PyTorch model.
 
@@ -425,6 +435,8 @@ def enable_lora(
         :param config: The config for the pass run.
         :param adapter_path: Path to the adapter weights. If None, will initialize new adapters.
         :param target_modules: List of modules to target for LoRA fine-tuning. Only used if adapter_path is None.
+        :param use_dora: Whether to use DoRA (weight-decomposed low-rank adaptation) instead of plain LoRA.
+        :param ephemeral_gpu_offload: Whether to momentarily move CPU-offloaded weights to GPU for faster operations.
         :return: The LoRA model.
         """
         from peft import PeftModel
@@ -450,10 +462,19 @@ def enable_lora(
             )
         if not adapter_path:
             logger.debug("Initializing LoRA adapters from config")
-            lora_model = self.init_lora_adapters(model, task, config, target_modules=target_modules)
+            lora_model = self.init_lora_adapters(
+                model,
+                task,
+                config,
+                target_modules=target_modules,
+                use_dora=use_dora,
+                ephemeral_gpu_offload=ephemeral_gpu_offload,
+            )
         else:
             logger.debug("Loading LoRA adapters from %s", adapter_path)
-            lora_model = PeftModel.from_pretrained(model, adapter_path, is_trainable=True)
+            lora_model = PeftModel.from_pretrained(
+                model, adapter_path, is_trainable=True, ephemeral_gpu_offload=ephemeral_gpu_offload
+            )
         logger.debug(
             "The number of trainable parameters in the LoRA model: %s", self.count_trainable_parameters(lora_model)
         )
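The load path mirrors peft's `PeftModel.from_pretrained`, which accepts the same `ephemeral_gpu_offload` flag in recent releases. A hedged sketch of that call outside the pass, with the base model name and adapter path as placeholders:

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM

# Placeholder base model; device_map="auto" may offload some weights to CPU.
base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", device_map="auto")

lora_model = PeftModel.from_pretrained(
    base_model,
    "path/to/saved/adapter",     # placeholder adapter checkpoint directory
    is_trainable=True,           # keep adapter weights trainable so fine-tuning can resume
    ephemeral_gpu_offload=True,  # stream CPU-offloaded tensors through the GPU when needed
)
```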
@@ -641,6 +662,44 @@ def _run_for_config(self, model: HfModelHandler, config: Dict[str, Any], output_model_path: str) -> HfModelHandler:
         )
 
 
+class DoRA(LoRABase):
+    """Run DoRA fine-tuning on a Hugging Face PyTorch model."""
+
+    @classmethod
+    def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]:
+        config = {
+            "ephemeral_gpu_offload": PassConfigParam(
+                type_=bool, default_value=False, description="Whether to use ephemeral GPU offload."
+            ),
+        }
+        config.update(super()._default_config(accelerator_spec))
+        return config
+
+    def _run_for_config(self, model: HfModelHandler, config: Dict[str, Any], output_model_path: str) -> HfModelHandler:
+        # convert config to pass config class
+        # this will validate the config and convert to the correct types
+        config = self._config_class(**config)
+
+        # check dependencies
+        self.check_dependencies(config)
+
+        # use default training args if not provided
+        config.training_args = config.training_args or HFTrainingArguments()
+
+        # get new model
+        pytorch_model = self.load_base_pytorch_model(model, config)
+
+        # add lora modules, forwarding the pass-level ephemeral_gpu_offload setting
+        pytorch_model = self.enable_lora(
+            pytorch_model, model.task, config, use_dora=True, ephemeral_gpu_offload=config.ephemeral_gpu_offload
+        )
+
+        # train and return new model
+        return self.train_and_save_new_model(
+            pytorch_model, model.get_hf_tokenizer(), config, deepcopy(model), output_model_path
+        )
+
+
 class QLoRABase(LoRABase):
     """Base class for QLoRA and LoftQ fine-tuning passes."""
 
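As background for the new pass: DoRA (Liu et al., 2024) decomposes each pretrained weight matrix into a magnitude vector and a direction, trains a LoRA update on the direction only, and rescales the column-normalized result by the learned magnitude, i.e. W' = m · (W0 + BA) / ||W0 + BA||_c. A toy numeric sketch of that reparameterization, illustrative only and not Olive's or peft's implementation:

```python
import torch

torch.manual_seed(0)
d_out, d_in, r = 8, 8, 2

W0 = torch.randn(d_out, d_in)          # frozen pretrained weight
B = torch.zeros(d_out, r)              # LoRA factors: B starts at zero,
A = torch.randn(r, d_in)               # so the initial low-rank update is zero
m = W0.norm(p=2, dim=0, keepdim=True)  # trainable magnitude, one scalar per column

# DoRA: scale the column-normalized direction (W0 + BA) by the learned magnitude m
V = W0 + B @ A
W_dora = m * V / V.norm(p=2, dim=0, keepdim=True)

# With B = 0 the reparameterized weight recovers the pretrained weight exactly
assert torch.allclose(W_dora, W0, atol=1e-6)
```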