Memory using Validation #295

dorpxam opened this issue Mar 5, 2025 · 6 comments

dorpxam commented Mar 5, 2025

I've just updated to the latest version and started a training run (LTXV LoRA), using validation steps this time. I'm using the new bash method (from the sft/ltx_video/crush_smol_lora example). I have some details to report.

  1. The prompt used for validation gets truncated with the message: The following part of your input was truncated because 'max_sequence_length' is set to 128 tokens. Is max_sequence_length customizable?
  2. The generated validation videos use the filename format validation-step-n-n-idtoken-startoftheprompt.mp4. The problem is that if two validation prompts start with the same sequence, one validation video overwrites the other. Maybe it would be nice to replace the 'startoftheprompt' text with the index of the prompt in the list, or a hash (see the sketch at the end of this comment)?
  3. On my RTX 4080 (16 GB), I do not get an OOM during the validation steps, but the problem comes with the release of memory after validation. Here are the logs:
INFO - Memory before training start: {
    "memory_allocated": 3.911,
    "memory_reserved": 4.463,
    "max_memory_allocated": 4.021,
    "max_memory_reserved": 4.463
}

INFO - Training configuration: {
    "trainable parameters": 29360128,
    "train steps": 1000,
    "per-replica batch size": 1,
    "global batch size": 1,
    "gradient accumulation steps": 1
}

INFO - Memory before validation start: {
    "memory_allocated": 3.935,
    "memory_reserved": 6.84,
    "max_memory_allocated": 12.823,
    "max_memory_reserved": 24.469
}

INFO - Memory after validation end: {
    "memory_allocated": 13.614,
    "memory_reserved": 14.135,
    "max_memory_allocated": 15.019,
    "max_memory_reserved": 24.469
}

Training continues after the validation step, but the seconds per iteration go crazy: they jump from about 1 s/it to 10 s/it depending on the step (probably because of memory swapping between VRAM and shared RAM). I guess there is a reason for this, but is there a way in the new version to optimize this part? What makes the memory usage grow during validation and not return to its initial state, as it does during the earlier training steps?
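
For reference, here is the kind of cleanup I would naively expect after a validation pass -- a minimal sketch using plain PyTorch APIs, not finetrainers' actual internals (which I haven't checked):

import gc
import torch

def release_validation_memory(pipeline) -> None:
    # Drop the reference to the validation pipeline so its weights can be freed.
    del pipeline
    gc.collect()
    # Hand cached blocks back to the driver and reset the peak counters so the
    # next "Memory before/after" log only reflects training usage.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()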

Anyway, you rock. The new version is clean. Maybe some colors in the command-line output would be appreciated for readability, but that's just a cosmetic detail ;)
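
For point 2, here is the quick sketch I have in mind -- a hypothetical naming scheme (not the current code) that hashes the full prompt so two prompts sharing the same opening words cannot collide:

import hashlib

def validation_filename(step: int, index: int, prompt: str) -> str:
    # Hypothetical naming scheme: a short hash of the full prompt keeps the
    # filename unique even when two prompts start with the same words.
    prompt_hash = hashlib.sha256(prompt.encode("utf-8")).hexdigest()[:8]
    return f"validation-step-{step}-{index}-{prompt_hash}.mp4"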

dorpxam commented Mar 6, 2025

For information: after adding --enable_model_cpu_offload to the validation_cmd to test this option, I got a crash right after the first validation step.

2025-03-06 10:31:44,091 - finetrainers - INFO - Memory after validation end: {
    "memory_allocated": 0.897,
    "memory_reserved": 0.957,
    "max_memory_allocated": 13.091,
    "max_memory_reserved": 27.371
}
2025-03-06 10:31:44,048 - finetrainers - DEBUG - Starting training step (361/7200)
2025-03-06 10:31:44,060 - finetrainers - ERROR - Error during training: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)
wandb:
wandb:
wandb: Run history:
wandb:       global_avg_loss ▅▂▃▆▅▂▂▃▄▂▆▃▃▃▃▂▅▂▂▂▂▂▁▂▄▂▅▃▂▂▂▆▂▂█▂▃▂▅▂
wandb:       global_max_loss ▂▄▁▃▅▂▄▇▅▂█▇▅▄▅▅▃█▅▄█▅▂▂▂▁▄▁▁▂▇▂▃▂▇▃▆▄▆▄
wandb:             grad_norm ▁▁▁▂▂▁▁▁▁▁▁▁▂▁▁▁▂▁▂▁▁▁▁▁▁▂▅▁▁▂▁▃▂█▂▂▁▁▁▂
wandb: observed_data_samples ▁▁▁▁▂▂▂▃▃▃▃▄▄▄▄▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇██
wandb:   observed_num_tokens ▁▁▁▁▁▂▂▂▃▃▃▃▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇█████
wandb:
wandb: Run summary:
wandb:       global_avg_loss 0.41032
wandb:       global_max_loss 0.41032
wandb:             grad_norm 0.44087
wandb: observed_data_samples 360
wandb:   observed_num_tokens 10342400
wandb:
wandb: 🚀 View run winter-lake-9 at: https://wandb.ai/--------/finetrainers-ltxvideo/runs/--------
wandb: ⭐️ View project at: https://wandb.ai/--------/finetrainers-ltxvideo
wandb: Synced 5 W&B file(s), 3 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: logs/wandb/run-20250306_093039---------/logs
2025-03-06 10:31:47,028 - finetrainers - ERROR - An error occurred during training: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)
2025-03-06 10:31:47,038 - finetrainers - ERROR - Traceback (most recent call last):
  File "/home/dorpxam/AI/finetrainers/train.py", line 70, in main
    trainer.run()
  File "/home/dorpxam/AI/finetrainers/finetrainers/trainer/sft_trainer/trainer.py", line 88, in run
    raise e
  File "/home/dorpxam/AI/finetrainers/finetrainers/trainer/sft_trainer/trainer.py", line 83, in run
    self._train()
  File "/home/dorpxam/AI/finetrainers/finetrainers/trainer/sft_trainer/trainer.py", line 454, in _train
    pred, target, sigmas = self.model_specification.forward(
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/AI/finetrainers/finetrainers/models/ltx_video/base_specification.py", line 361, in forward
    pred = transformer(
           ^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/AI/finetrainers/finetrainers/patches/models/ltx_video/patch.py", line 67, in _patched_LTXVideoTransformer3Dforward
    temb, embedded_timestep = self.time_embed(
                              ^^^^^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/diffusers/models/normalization.py", line 266, in forward
    embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/diffusers/models/embeddings.py", line 2185, in forward
    timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/diffusers/models/embeddings.py", line 1305, in forward
    sample = self.linear_1(sample)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 125, in forward
    return F.linear(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)

/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/wandb/sdk/wandb_run.py:2310: UserWarning: Run (ro2ls1ey) is finished. The call to `_console_raw_callback` will be ignored. Please make sure that you are using an active run.
  lambda data: self._console_raw_callback("stderr", data),
Training steps:   5%|█████████▊                                                                                                                                                                                         | 360/7200 [1:00:43<19:13:39, 10.12s/it, grad_norm=0.441, global_avg_loss=0.41, global_max_loss=0.41]
+ echo -ne '-------------------- Finished executing script --------------------\n\n'
-------------------- Finished executing script --------------------
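
If it helps to narrow it down, my untested guess is that the validation pipeline's CPU offload hooks leave the transformer's weights on the CPU, so the next training step feeds CUDA inputs to CPU weights (the addmm in time_embed above). Something along these lines, run after validation, is what I would try -- a hypothetical workaround, not finetrainers' actual code:

import torch

def move_transformer_back(transformer: torch.nn.Module, device: str = "cuda:0") -> None:
    # Hypothetical workaround: if CPU offload left the weights on the CPU
    # after validation, move them back before the next training step.
    if next(transformer.parameters()).device.type == "cpu":
        transformer.to(device)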

dorpxam commented Mar 6, 2025

Another point about the new pre-processing. In the train_multires.sh file, a comment says:

Note: This is a copy of train.sh file from the same directory. For multi-resolution training, we
are utilizing the same dataset with different total frame counts. We do this with 4 copies of the
dataset, so also multiply the precomputation items with 4, in order to make sure all required embeddings
are precomputed at once.

The problem is that I have a large number of different durations in the corpus I use:

[23, 512, 768]
[27, 512, 768]
[29, 512, 768]
[31, 512, 768]
[33, 512, 768]
[37, 512, 768]
[39, 512, 768]
[41, 512, 768]
[42, 512, 768]
[46, 512, 768]
[49, 512, 768]
[53, 512, 768]
[58, 512, 768]
[62, 512, 768]
[63, 512, 768]
[71, 512, 768]
[73, 512, 768]
[82, 512, 768]
[83, 512, 768]
[96, 512, 768]
[104, 512, 768]
[113, 512, 768]
[116, 512, 768]
[168, 512, 768]
[257, 512, 768]

So 25 distinct durations; some, like 257 frames, are shared by 4 videos.

If I follow the rule indicated in the comment, for my 36-video corpus with 25 distinct durations, I would need to set --precomputation_items to 36 x 25 = 900.

The problem is that this computation is intensive and very long (much more so than with the previous code). For example, I set --precomputation_items to just 144 (36 x 4) and the current computation is insane:

Processing data on rank 0: 57%|████ | 82/144 [1:14:02<2:27:12, 142.45s/it]

The s/it varies depending on the step of course, but this is clearly a bottleneck.

I understand that precomputation can be slow, but it would be good to find a solution so that this step does not precompute more than needed. What do you think about that?

EDIT: I'm pretty sure the s/it slowdown is due to memory swapping during precomputation, caused by the number of items to precompute.

@a-r-r-o-w (Owner)

The prompt used for validation gets truncated with the message: The following part of your input was truncated because 'max_sequence_length' is set to 128 tokens. Is max_sequence_length customizable?

The base model is trained with a specific max sequence length. You could modify the code to make it higher, but there are no guarantees on how it affects quality. If a model hasn't observed data in the range it expects, it would probably not end up well. But I haven't tried it myself, so it may be worth exploring!
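
For context, the 128-token limit comes from how the prompt is tokenized for the text encoder, roughly like the sketch below (illustrative only; tokenizer, text_encoder, prompt, and device are placeholders, and the exact call site in the codebase may differ):

# Illustrative sketch, not the actual call site in finetrainers/diffusers.
text_inputs = tokenizer(
    prompt,
    padding="max_length",
    max_length=128,  # raising this beyond what the model saw in training is untested
    truncation=True,
    return_tensors="pt",
)
prompt_embeds = text_encoder(text_inputs.input_ids.to(device))[0]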

Training continues after the validation step, but the seconds per iteration go crazy: they jump from about 1 s/it to 10 s/it depending on the step (probably because of memory swapping between VRAM and shared RAM). I guess there is a reason for this, but is there a way in the new version to optimize this part? What makes the memory usage grow during validation and not return to its initial state, as it does during the earlier training steps?

Do you notice this consistently? Does it ever recover back to 1 s/it? I'm unable to replicate the behaviour in any of the examples I've shared in the examples/ directory -- they immediately resume training as fast as before validation.

Anyway, you rock. The new version is clean. Maybe some colors in the command-line output would be appreciated for readability, but that's just a cosmetic detail ;)

Awesome, thank you for the kind words! I'll think about the cosmetic information after some bigger optimizations!

For information: after adding --enable_model_cpu_offload to the validation_cmd to test this option, I got a crash right after the first validation step.

Seems like a bug. I'll try to repro and fix

I understand that precomputation can be slow, but it would be good to find a solution so that this step does not precompute more than needed. What do you think about that?

Working on adding non-precomputation based data loading soon! Also tracking in #296.

The --image_resolution_buckets and --video_resolution_buckets can accept a list of values, not just one. My intention in the example was to show that you can effectively repeat the same dataset multiple times at different resolutions. BUT, if you have multiple duration/resolution data within a single dataset, you're better off specifying the values you want to train with in the bucket-related parameters.
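
For example, you could derive the bucket list directly from the frame counts you listed. A throwaway sketch that just prints the argument (it assumes the FxHxW value syntax from the example scripts, so double-check against them):

# Throwaway sketch: print a --video_resolution_buckets argument covering the
# frame counts present in the dataset. The FxHxW value syntax is assumed from
# the example scripts -- double-check it before use.
durations = [23, 27, 29, 31, 33, 37, 39, 41, 42, 46, 49, 53, 58, 62, 63,
             71, 73, 82, 83, 96, 104, 113, 116, 168, 257]
buckets = [f"{frames}x512x768" for frames in sorted(set(durations))]
print("--video_resolution_buckets", " ".join(buckets))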

dorpxam commented Mar 8, 2025

The base model is trained with a specific max sequence length. You could modify the code to make it higher, but there are no guarantees on how it affects quality. If a model hasn't observed data in the range it expects, it would probably not end up well. But I haven't tried it myself, so it may be worth exploring!

Yes, I understand. It seems that LTXV fixes the maximum at 128 tokens. I need to adapt my automatic prompt generation to stay under that limit.

Do you notice this consistently? Does it ever recover back to 1 s/it? I'm unable to replicate the behaviour in any of the examples I've shared in the examples/ directory -- they immediately resume training as fast as before validation.

For the tests I've done, yes. But I think (instinctively) that it's a memory swapping problem: I reach the VRAM limit, CUDA starts to use shared memory (memory swap), and it seems to stay in that state. I'm not sure and don't have the expertise, but I will look for information about this soon. Maybe a little trick exists to work around the problem.

Awesome, thank you for the kind words! I'll think about the cosmetic information after some bigger optimizations!

Sure! You deserve lots of encouragement. finetrainers is an excellent project.

Seems like a bug. I'll try to repro and fix

Cool.

Working on adding non-precomputation based data loading soon! Also tracking in #296.

The --image_resolution_buckets and --video_resolution_buckets can accept a list of values, not just one. My intention in the example was to show that you can effectively repeat the same dataset multiple times at different resolutions. BUT, if you have multiple duration/resolution data within a single dataset, you're better off specifying the values you want to train with in the bucket-related parameters.

Yes. This is clearly the only problem for me now. Even with a low --precomputation_items, it's a waste of time when you want to experiment and restart over the same corpus. Anyway, I'm sure you will find a solution ;)

@a-r-r-o-w (Owner)

For the tests I've done, yes. But I think (instinctively) that it's a memory swapping problem: I reach the VRAM limit, CUDA starts to use shared memory (memory swap), and it seems to stay in that state. I'm not sure and don't have the expertise, but I will look for information about this soon. Maybe a little trick exists to work around the problem.

I've made some updates to how precomputation is done to save more VRAM, and am consistently seeing the usage either lower than or equivalent to legacy scripts with precomputation. It should hopefully help and not end up using shared memory.

Yes. This is clearly the only problem for me now. Even with a low --precomputation_items, it's a waste of time when you want to experiment and restart over the same corpus. Anyway, I'm sure you will find a solution ;)

For precomputation, if the dataset remains the same and is a local dataset, I should probably not trigger it to even start. Currently, a new run always runs precomputation if enabled with --enable_precomputation. But I need to think about this a little longer since there's a variety of cases that need handling:

  • Local dataset
    • User has decided to precompute entire dataset at once with --precomputation_once on run X, so we can re-use that in run X + 1
    • User has decided to not run full precomputation in run X, so run X + 1 will need to run on-the-fly or full precomputation again
    • User changed the dataset preprocessing config from run X to X + 1, so now will need to run precomputation again if specified
  • Remote dataset
    • If dataset can be made local by snapshot_download'ing it, follow same as above
    • If dataset is too large or in webdataset/parquet format exceeding default file limit, need to run precomputation on each run

These are some of the core cases, but I can bet there will be a lot more to cover. In general, dataset handling could use some improvements, since there are so many formats to deal with correctly.
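
One way to key the reuse decision, purely as an illustration (nothing like this exists in the codebase yet), would be to combine the dataset location with a hash of the preprocessing config, so a config change between run X and run X + 1 invalidates the cache:

import hashlib
import json

def precomputation_cache_key(dataset_path: str, preprocessing_config: dict) -> str:
    # Purely illustrative: a config change between runs produces a different
    # key, which would force precomputation to run again for that dataset.
    config_hash = hashlib.sha256(
        json.dumps(preprocessing_config, sort_keys=True).encode("utf-8")
    ).hexdigest()[:12]
    return f"{dataset_path}-{config_hash}"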

dorpxam commented Mar 9, 2025

If I understand the various cases correctly, we can reuse the precomputed conditions/latents. It would be nice, for local training, to check whether the 'precomputed' directory and files already exist and not overwrite them. From my point of view, it's always simpler for the user to delete the directory manually than to edit the script file. What do you think about this?
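
Something like this check is all I have in mind -- a hypothetical sketch; the 'precomputed' directory name is just what I see in my output folder, and the real layout may differ:

from pathlib import Path

def should_precompute(output_dir: str, enable_precomputation: bool) -> bool:
    # Hypothetical sketch: reuse what a previous run left behind instead of
    # recomputing it. Deleting the directory manually forces a fresh pass.
    precomputed_dir = Path(output_dir) / "precomputed"
    if not enable_precomputation:
        return False
    if precomputed_dir.exists() and any(precomputed_dir.iterdir()):
        return False
    return True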

EDIT:

Testing with:

dataset_cmd=(
  --dataset_config $TRAINING_DATASET_CONFIG
  --dataset_shuffle_buffer_size 24
  --precomputation_items 24
  --precomputation_once
  --enable_precomputation
)

The process starts and precomputes. But if I restart, the precomputation runs again and overwrites the previous one. If I remove --enable_precomputation after the first precomputation:

dataset_cmd=(
  --dataset_config $TRAINING_DATASET_CONFIG
  --dataset_shuffle_buffer_size 24
  --precomputation_items 24
  --precomputation_once
)

The process starts with this message:
INFO - Precomputation disabled. Loading in-memory data loaders. All components will be loaded on GPUs.
I admit this is a bit ambiguous. Is the precomputed directory reused?
