fix (#307)
a-r-r-o-w authored Mar 9, 2025
1 parent 60ab7f0 commit cd8527a
Showing 2 changed files with 29 additions and 26 deletions.
23 changes: 13 additions & 10 deletions finetrainers/trainer/sft_trainer/trainer.py
@@ -918,15 +918,17 @@ def _prepare_data(
else:
logger.info("Precomputed condition & latent data exhausted. Loading & preprocessing new data.")

- parallel_backend = self.state.parallel_backend
- train_state = self.state.train_state
- self.checkpointer.save(
- train_state.step,
- force=True,
- _device=parallel_backend.device,
- _is_main_process=parallel_backend.is_main_process,
- )
- self._delete_components(component_names=["transformer", "unet"])
+ # TODO(aryan): This needs to be revisited. For some reason, the tests did not detect that self.transformer
+ # had become None after this but should have been loaded back from the checkpoint.
+ # parallel_backend = self.state.parallel_backend
+ # train_state = self.state.train_state
+ # self.checkpointer.save(
+ # train_state.step,
+ # force=True,
+ # _device=parallel_backend.device,
+ # _is_main_process=parallel_backend.is_main_process,
+ # )
+ # self._delete_components(component_names=["transformer", "unet"])

if self.args.precomputation_once:
consume_fn = preprocessor.consume_once
@@ -967,7 +969,8 @@ def _prepare_data(
self._delete_components(component_names)
del latent_components, component_names, component_modules

- self.checkpointer.load()
+ # self.checkpointer.load()
+ # self.transformer = self.checkpointer.states["model"].model[0]

return condition_iterator, latent_iterator

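The hunks above are the checkpoint round-trip that the TODO refers to: the trainer previously saved a checkpoint and dropped the denoiser before precomputation, and a matching reload was meant to restore it afterwards. As a sketch only, assembled from the removed and commented-out lines shown in this diff (not active code in this commit), the intended flow is:

# Sketch of the intended flow; reconstructed from the lines shown in this diff,
# not active code in this commit.
parallel_backend = self.state.parallel_backend
train_state = self.state.train_state

# Persist the current training state so the large modules can be dropped safely.
self.checkpointer.save(
    train_state.step,
    force=True,
    _device=parallel_backend.device,
    _is_main_process=parallel_backend.is_main_process,
)
# Free the denoiser to make room for condition/latent precomputation.
self._delete_components(component_names=["transformer", "unet"])

# ... condition and latent precomputation runs here ...

# Reload the checkpoint and reattach the module so self.transformer is not None,
# which is the gap the TODO says the tests failed to catch.
self.checkpointer.load()
self.transformer = self.checkpointer.states["model"].model[0]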
32 changes: 16 additions & 16 deletions tests/trainer/test_sft_trainer.py
@@ -113,55 +113,55 @@ def get_args(self) -> BaseArgs:
args.target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
return args

- @parameterized("enable_precomputation", [False, True])
+ @parameterized.expand([(False,), (True,)])
def test___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
args = self.get_args()
args.dp_degree = 1
args.batch_size = 1
args.enable_precomputation = enable_precomputation
self._test_training(args)

- @parameterized("enable_precomputation", [False, True])
+ @parameterized.expand([(False,), (True,)])
def test___dp_degree_1___batch_size_2(self, enable_precomputation: bool):
args = self.get_args()
args.dp_degree = 1
args.batch_size = 2
args.enable_precomputation = enable_precomputation
self._test_training(args)

- @parameterized("enable_precomputation", [False, True])
+ @parameterized.expand([(False,), (True,)])
def test___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
args = self.get_args()
args.dp_degree = 2
args.batch_size = 1
args.enable_precomputation = enable_precomputation
self._test_training(args)

- @parameterized("enable_precomputation", [False, True])
+ @parameterized.expand([(False,), (True,)])
def test___dp_degree_2___batch_size_2(self, enable_precomputation: bool):
args = self.get_args()
args.dp_degree = 2
args.batch_size = 2
args.enable_precomputation = enable_precomputation
self._test_training(args)

- @parameterized("enable_precomputation", [False, True])
+ @parameterized.expand([(False,), (True,)])
def test___dp_shards_2___batch_size_1(self, enable_precomputation: bool):
args = self.get_args()
args.dp_shards = 2
args.batch_size = 1
args.enable_precomputation = enable_precomputation
self._test_training(args)

- @parameterized("enable_precomputation", [False, True])
+ @parameterized.expand([(False,), (True,)])
def test___dp_shards_2___batch_size_2(self, enable_precomputation: bool):
args = self.get_args()
args.dp_shards = 2
args.batch_size = 1
args.enable_precomputation = enable_precomputation
self._test_training(args)

- @parameterized("enable_precomputation", [False, True])
+ @parameterized.expand([(False,), (True,)])
def test___dp_degree_2___dp_shards_2___batch_size_1(self, enable_precomputation: bool):
args = self.get_args()
args.dp_degree = 2
@@ -170,7 +170,7 @@ def test___dp_degree_2___dp_shards_2___batch_size_1(self, enable_precomputation:
args.enable_precomputation = enable_precomputation
self._test_training(args)

- @parameterized("enable_precomputation", [False, True])
+ @parameterized.expand([(False,), (True,)])
def test___tp_degree_2___batch_size_2(self, enable_precomputation: bool):
args = self.get_args()
args.tp_degree = 2
@@ -186,55 +186,55 @@ def get_args(self) -> BaseArgs:
args.training_type = TrainingType.FULL_FINETUNE
return args

- @parameterized("enable_precomputation", [False, True])
+ @parameterized.expand([(False,), (True,)])
def test___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
args = self.get_args()
args.dp_degree = 1
args.batch_size = 1
args.enable_precomputation = enable_precomputation
self._test_training(args)

- @parameterized("enable_precomputation", [False, True])
+ @parameterized.expand([(False,), (True,)])
def test___dp_degree_1___batch_size_2(self, enable_precomputation: bool):
args = self.get_args()
args.dp_degree = 1
args.batch_size = 2
args.enable_precomputation = enable_precomputation
self._test_training(args)

- @parameterized("enable_precomputation", [False, True])
+ @parameterized.expand([(False,), (True,)])
def test___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
args = self.get_args()
args.dp_degree = 2
args.batch_size = 1
args.enable_precomputation = enable_precomputation
self._test_training(args)

- @parameterized("enable_precomputation", [False, True])
+ @parameterized.expand([(False,), (True,)])
def test___dp_degree_2___batch_size_2(self, enable_precomputation: bool):
args = self.get_args()
args.dp_degree = 2
args.batch_size = 2
args.enable_precomputation = enable_precomputation
self._test_training(args)

- @parameterized("enable_precomputation", [False, True])
+ @parameterized.expand([(False,), (True,)])
def test___dp_shards_2___batch_size_1(self, enable_precomputation: bool):
args = self.get_args()
args.dp_shards = 2
args.batch_size = 1
args.enable_precomputation = enable_precomputation
self._test_training(args)

- @parameterized("enable_precomputation", [False, True])
+ @parameterized.expand([(False,), (True,)])
def test___dp_shards_2___batch_size_2(self, enable_precomputation: bool):
args = self.get_args()
args.dp_shards = 2
args.batch_size = 1
args.enable_precomputation = enable_precomputation
self._test_training(args)

- @parameterized("enable_precomputation", [False, True])
+ @parameterized.expand([(False,), (True,)])
def test___dp_degree_2___dp_shards_2___batch_size_1(self, enable_precomputation: bool):
args = self.get_args()
args.dp_degree = 2
@@ -243,7 +243,7 @@ def test___dp_degree_2___dp_shards_2___batch_size_1(self, enable_precomputation:
args.enable_precomputation = enable_precomputation
self._test_training(args)

- @parameterized("enable_precomputation", [False, True])
+ @parameterized.expand([(False,), (True,)])
def test___tp_degree_2___batch_size_2(self, enable_precomputation: bool):
args = self.get_args()
args.tp_degree = 2
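On the test side, every parameterized test switches to @parameterized.expand, the form of the parameterized library that generates one test method per tuple on unittest-style test classes and passes the tuple's elements as positional arguments. A minimal standalone illustration (an assumed example, not code from this repository):

# Minimal illustration of parameterized.expand; an assumed standalone example,
# not taken from this repository.
import unittest

from parameterized import parameterized


class ExampleTest(unittest.TestCase):
    @parameterized.expand([(False,), (True,)])
    def test_flag(self, enable_precomputation: bool):
        # One test is generated per tuple; its single element arrives as the
        # positional argument `enable_precomputation`.
        self.assertIn(enable_precomputation, (False, True))


if __name__ == "__main__":
    unittest.main()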
