@@ -103,23 +103,7 @@ def setup(self, trainer: Union[pl.Trainer, fl.Fabric], model=None):
103
103
if isinstance (trainer , fl .Fabric ):
104
104
raise NotImplementedError ("Fabric is not supported yet." )
105
105
106
- trainer_ckpt_path = self .get_trainer_ckpt_path (model )
107
- if trainer_ckpt_path :
108
- trainer .ckpt_path = trainer_ckpt_path
109
- trainer .checkpoint_callback .last_model_path = trainer_ckpt_path
110
- # Load artifacts
111
- if getattr (self .restore_config , 'load_artifacts' , False ):
112
- if isinstance (trainer_ckpt_path , AdapterPath ):
113
- # load tokenizer from the base model during peft resume, in case the first peft checkpoint
114
- # is deleted before the current peft checkpoint is saved
115
- context_path = trainer_ckpt_path .base_model_path / "context"
116
- if not context_path .exists ():
117
- context_path = trainer_ckpt_path .base_model_path
118
- else :
119
- context_path = self .get_context_path (model )
120
- model = _try_restore_tokenizer (model , context_path )
121
-
122
- elif self .restore_config :
106
+ if self .restore_config :
123
107
new_path = self ._extract_path (
124
108
model = model ,
125
109
path = self .restore_config .path ,
@@ -139,6 +123,21 @@ def setup(self, trainer: Union[pl.Trainer, fl.Fabric], model=None):
139
123
140
124
_try_restore_tokenizer (model , context_path )
141
125
126
+ elif (trainer_ckpt_path := self .get_trainer_ckpt_path (model )) is not None :
127
+ trainer .ckpt_path = trainer_ckpt_path
128
+ trainer .checkpoint_callback .last_model_path = trainer_ckpt_path
129
+ # Load artifacts
130
+ if getattr (self .restore_config , 'load_artifacts' , False ):
131
+ if isinstance (trainer_ckpt_path , AdapterPath ):
132
+ # load tokenizer from the base model during peft resume, in case the first peft checkpoint
133
+ # is deleted before the current peft checkpoint is saved
134
+ context_path = trainer_ckpt_path .base_model_path / "context"
135
+ if not context_path .exists ():
136
+ context_path = trainer_ckpt_path .base_model_path
137
+ else :
138
+ context_path = self .get_context_path (model )
139
+ model = _try_restore_tokenizer (model , context_path )
140
+
142
141
def _extract_path (
143
142
self , model : Optional [io .ConnectorMixin ], path : str , adapter_path : Optional [str ] = None
144
143
) -> BasePath :
0 commit comments