Skip to content

Commit

Permalink
[core] Ensure loading mp first (#252)
Browse files Browse the repository at this point in the history
* fix: model card info.

* always ensure loading takes place on the main process first.

* revert changes

* revert
  • Loading branch information
sayakpaul authored Jan 29, 2025
1 parent 0b4b61b commit 836ac78
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions finetrainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,12 @@ def prepare_models(self) -> None:
load_components_kwargs = self._get_load_components_kwargs()
condition_components, latent_components, diffusion_components = {}, {}, {}
if not self.args.precompute_conditions:
condition_components = self.model_config["load_condition_models"](**load_components_kwargs)
latent_components = self.model_config["load_latent_models"](**load_components_kwargs)
diffusion_components = self.model_config["load_diffusion_models"](**load_components_kwargs)
# To download the model files first on the main process (if not already present)
# and then load the cached files afterward from the other processes.
with self.state.accelerator.main_process_first():
condition_components = self.model_config["load_condition_models"](**load_components_kwargs)
latent_components = self.model_config["load_latent_models"](**load_components_kwargs)
diffusion_components = self.model_config["load_diffusion_models"](**load_components_kwargs)

components = {}
components.update(condition_components)
Expand Down Expand Up @@ -204,7 +207,8 @@ def collate_fn(batch):
logger.info("Precomputed conditions and latents not found. Running precomputation.")

# At this point, no models are loaded, so we need to load and precompute conditions and latents
condition_components = self.model_config["load_condition_models"](**self._get_load_components_kwargs())
with self.state.accelerator.main_process_first():
condition_components = self.model_config["load_condition_models"](**self._get_load_components_kwargs())
self._set_components(condition_components)
self._move_components_to_device()
self._disable_grad_for_components([self.text_encoder, self.text_encoder_2, self.text_encoder_3])
Expand Down Expand Up @@ -258,7 +262,8 @@ def collate_fn(batch):
torch.cuda.reset_peak_memory_stats(accelerator.device)

# Precompute latents
latent_components = self.model_config["load_latent_models"](**self._get_load_components_kwargs())
with self.state.accelerator.main_process_first():
latent_components = self.model_config["load_latent_models"](**self._get_load_components_kwargs())
self._set_components(latent_components)
self._move_components_to_device()
self._disable_grad_for_components([self.vae])
Expand Down Expand Up @@ -319,7 +324,8 @@ def collate_fn(batch):
def prepare_trainable_parameters(self) -> None:
logger.info("Initializing trainable parameters")

diffusion_components = self.model_config["load_diffusion_models"](**self._get_load_components_kwargs())
with self.state.accelerator.main_process_first():
diffusion_components = self.model_config["load_diffusion_models"](**self._get_load_components_kwargs())
self._set_components(diffusion_components)

components = [self.text_encoder, self.text_encoder_2, self.text_encoder_3, self.vae]
Expand Down

0 comments on commit 836ac78

Please sign in to comment.