Skip to content

Commit e4281fe

Browse files
committed
check restore_config first
Signed-off-by: Alexandros Koumparoulis <[email protected]>
1 parent 102bac6 commit e4281fe

File tree

1 file changed

+16
-17
lines changed

1 file changed

+16
-17
lines changed

nemo/lightning/resume.py

+16 -17
Original file line number | Diff line number | Diff line change
@@ -103,23 +103,7 @@ def setup(self, trainer: Union[pl.Trainer, fl.Fabric], model=None):
103103
if isinstance(trainer, fl.Fabric):
104104
raise NotImplementedError("Fabric is not supported yet.")
105105

106-
trainer_ckpt_path = self.get_trainer_ckpt_path(model)
107-
if trainer_ckpt_path:
108-
trainer.ckpt_path = trainer_ckpt_path
109-
trainer.checkpoint_callback.last_model_path = trainer_ckpt_path
110-
# Load artifacts
111-
if getattr(self.restore_config, 'load_artifacts', False):
112-
if isinstance(trainer_ckpt_path, AdapterPath):
113-
# load tokenizer from the base model during peft resume, in case the first peft checkpoint
114-
# is deleted before the current peft checkpoint is saved
115-
context_path = trainer_ckpt_path.base_model_path / "context"
116-
if not context_path.exists():
117-
context_path = trainer_ckpt_path.base_model_path
118-
else:
119-
context_path = self.get_context_path(model)
120-
model = _try_restore_tokenizer(model, context_path)
121-
122-
elif self.restore_config:
106+
if self.restore_config:
123107
new_path = self._extract_path(
124108
model=model,
125109
path=self.restore_config.path,
@@ -139,6 +123,21 @@ def setup(self, trainer: Union[pl.Trainer, fl.Fabric], model=None):
139123

140124
_try_restore_tokenizer(model, context_path)
141125

126+
elif (trainer_ckpt_path := self.get_trainer_ckpt_path(model)) is not None:
127+
trainer.ckpt_path = trainer_ckpt_path
128+
trainer.checkpoint_callback.last_model_path = trainer_ckpt_path
129+
# Load artifacts
130+
if getattr(self.restore_config, 'load_artifacts', False):
131+
if isinstance(trainer_ckpt_path, AdapterPath):
132+
# load tokenizer from the base model during peft resume, in case the first peft checkpoint
133+
# is deleted before the current peft checkpoint is saved
134+
context_path = trainer_ckpt_path.base_model_path / "context"
135+
if not context_path.exists():
136+
context_path = trainer_ckpt_path.base_model_path
137+
else:
138+
context_path = self.get_context_path(model)
139+
model = _try_restore_tokenizer(model, context_path)
140+
142141
def _extract_path(
143142
self, model: Optional[io.ConnectorMixin], path: str, adapter_path: Optional[str] = None
144143
) -> BasePath:

0 commit comments

Comments (0)