Skip to content

Commit

Permalink
Merge branch 'feature/remove-forceful-precomputation-behaviour' into training/cogview4-the_simpsons
Browse files Browse the repository at this point in the history
  • Loading branch information
a-r-r-o-w committed Mar 8, 2025
2 parents 064101e + 788a73b commit edf988c
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 23 deletions.
14 changes: 6 additions & 8 deletions finetrainers/data/precomputation.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def consume(
if drop_samples:
raise ValueError("Cannot cache and drop samples at the same time.")

for i in tqdm(range(self._num_items), desc=f"Rank {self._rank}", total=self._num_items):
for i in range(self._num_items):
if use_cached_samples:
item = self._cached_samples[i]
else:
Expand Down Expand Up @@ -102,7 +102,7 @@ def consume_once(
if drop_samples:
raise ValueError("Cannot cache and drop samples at the same time.")

for i in tqdm(range(self._num_items), desc=f"Rank {self._rank}", total=self._num_items):
for i in range(self._num_items):
if use_cached_samples:
item = self._cached_samples[i]
else:
Expand Down Expand Up @@ -241,10 +241,10 @@ def __init__(self, rank: int, data_type: str, buffer: "InMemoryDataBuffer") -> N
self._requires_data = False

def __iter__(self) -> Iterable[Dict[str, Any]]:
while length := self._buffer.get_length(self._data_type) > 0:
yield self._buffer.get(self._data_type)
if length == 1:
while (length := self._buffer.get_length(self._data_type)) > 0:
if length <= 1:
self._requires_data = True
yield self._buffer.get(self._data_type)

def __len__(self) -> int:
return self._buffer.get_length(self._data_type)
Expand All @@ -269,10 +269,8 @@ def __init__(self, rank: int, data_type: str, buffer: "InMemoryDataBuffer") -> N
self._requires_data = False

def __iter__(self) -> Iterable[Dict[str, Any]]:
assert len(self) > 0, "No data available in the buffer."
while True:
if self._buffer.get_length(self._data_type) == 0:
self._requires_data = True
break
item = self._buffer.get(self._data_type)
yield item
self._buffer.add(self._data_type, item)
Expand Down
35 changes: 20 additions & 15 deletions finetrainers/trainer/sft_trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@


class SFTTrainer:
# fmt: off
_all_component_names = ["tokenizer", "tokenizer_2", "tokenizer_3", "text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "unet", "vae", "scheduler"]
_condition_component_names = ["tokenizer", "tokenizer_2", "tokenizer_3", "text_encoder", "text_encoder_2", "text_encoder_3"]
_latent_component_names = ["vae"]
_diffusion_component_names = ["transformer", "unet", "scheduler"]
# fmt: on

def __init__(self, args: "BaseArgs", model_specification: "ModelSpecification") -> None:
self.args = args
self.state = State()
Expand Down Expand Up @@ -810,24 +817,16 @@ def _move_components_to_device(
component.to(device)

def _set_components(self, components: Dict[str, Any]) -> None:
# fmt: off
component_names = ["tokenizer", "tokenizer_2", "tokenizer_3", "text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "unet", "vae", "scheduler"]
# fmt: on

for component_name in component_names:
for component_name in self._all_component_names:
existing_component = getattr(self, component_name, None)
new_component = components.get(component_name, existing_component)
setattr(self, component_name, new_component)

def _delete_components(self, component_names: Optional[List[str]] = None) -> None:
if component_names is None:
# fmt: off
component_names = ["tokenizer", "tokenizer_2", "tokenizer_3", "text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "unet", "vae", "scheduler"]
# fmt: on

component_names = self._all_component_names
for component_name in component_names:
setattr(self, component_name, None)

utils.free_memory()
utils.synchronize_device()

Expand Down Expand Up @@ -879,19 +878,25 @@ def _init_pipeline(self, final_validation: bool = False) -> DiffusionPipeline:
self._move_components_to_device(list(components.values()))
return pipeline

def _prepare_data(self, preprocessor: data.PrecomputedDistributedDataPreprocessor, data_iterator):
def _prepare_data(
self,
preprocessor: Union[data.InMemoryDistributedDataPreprocessor, data.PrecomputedDistributedDataPreprocessor],
data_iterator,
):
if not self.args.enable_precomputation:
logger.info(
"Precomputation disabled. Loading in-memory data loaders. All components will be loaded on GPUs."
)

if not self._are_condition_models_loaded:
logger.info(
"Precomputation disabled. Loading in-memory data loaders. All components will be loaded on GPUs."
)
condition_components = self.model_specification.load_condition_models()
latent_components = self.model_specification.load_latent_models()
all_components = {**condition_components, **latent_components}
self._set_components(all_components)
self._move_components_to_device(list(all_components.values()))
utils._enable_vae_memory_optimizations(self.vae, self.args.enable_slicing, self.args.enable_tiling)
else:
condition_components = {k: v for k in self._condition_component_names if (v := getattr(self, k, None))}
latent_components = {k: v for k in self._latent_component_names if (v := getattr(self, k, None))}

condition_iterator = preprocessor.consume(
"condition",
Expand Down

0 comments on commit edf988c

Please sign in to comment.