
Commit

Remove forceful precomputation behaviour; Improvements to data loading (#303)

* enable online data processing without forcing precomputation

* add tests

* make style

* update

* update examples & docs
a-r-r-o-w authored Mar 8, 2025
1 parent 4be06ae commit 41c3c40
Showing 19 changed files with 666 additions and 148 deletions.
5 changes: 4 additions & 1 deletion docs/dataset/README.md
@@ -75,6 +75,8 @@
Any dataset loadable directly via [🤗 HF datasets] should work (not widely tested at the moment). We support the [`webdataset`](https://huggingface.co/docs/datasets/v3.3.2/en/image_dataset#webdataset) format for image datasets and the [`webdataset`](https://huggingface.co/docs/datasets/v3.3.2/en/video_dataset#webdataset) format for video datasets.



## Validation Dataset Format

Arguments related to validation are:
@@ -157,7 +159,8 @@

## Understanding how datasets are precomputed

There are 3 arguments related to precomputation:
There are 4 arguments related to precomputation:
- `--enable_precomputation`: Enables precomputation; the arguments below are only relevant when this flag is set. Without it, all models are loaded in memory and training takes place without first precomputing embeddings.
- `--precomputation_items`: The number of data points to precompute and store to disk at a time. This keeps training memory-efficient without exhausting disk space, as would happen if embeddings for the entire dataset(s) were precomputed at once. The default is `512` data points; use a lower value for smaller datasets. As training progresses, the precomputed data is read from disk and dispatched to the data replicas. Once all precomputed data has been used, the next batch of data points is precomputed and stored to disk in a rolling fashion.
- `--precomputation_dir`: The directory where precomputed data is stored. This is useful for resuming training from a checkpoint, as the precomputed data will be loaded from this directory. If not provided, precomputed data is stored in `--output_dir/precomputed`.
- `--precomputation_once`: If you're working with a small dataset and want to precompute all embeddings at once, set this flag. This lets you train without recomputing embeddings every time the precomputed data is exhausted. Currently, `webdataset` format loading does not support this feature, and it is also disabled for `> 1024` data points due to hard-coded logic (users can remove this limit manually for now). A combined example is shown below.
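
For illustration, a `dataset_cmd` that opts into precomputation might combine these flags as follows. This is a sketch mirroring the example `train.sh` scripts; the item count, shell variables, and directory are placeholders rather than recommended values.

```sh
dataset_cmd=(
  --dataset_config $TRAINING_DATASET_CONFIG
  --dataset_shuffle_buffer_size 10
  --enable_precomputation                         # opt in; omit this flag to train fully online
  --precomputation_items 100                      # precompute this many data points at a time
  --precomputation_dir "$OUTPUT_DIR/precomputed"  # optional; defaults to --output_dir/precomputed
  --precomputation_once                           # small datasets only: precompute everything up front
)
```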
1 change: 1 addition & 0 deletions examples/training/sft/cogvideox/crush_smol_lora/train.sh
@@ -49,6 +49,7 @@ model_cmd=(
dataset_cmd=(
--dataset_config $TRAINING_DATASET_CONFIG
--dataset_shuffle_buffer_size 10
--enable_precomputation
--precomputation_items 25
--precomputation_once
)
1 change: 1 addition & 0 deletions examples/training/sft/cogview4/raider_white_tarot/train.sh
@@ -49,6 +49,7 @@ model_cmd=(
dataset_cmd=(
--dataset_config $TRAINING_DATASET_CONFIG
--dataset_shuffle_buffer_size 32
--enable_precomputation
--precomputation_items 120
--precomputation_once
)
@@ -45,6 +45,7 @@ model_cmd=(
dataset_cmd=(
--dataset_config $TRAINING_DATASET_CONFIG
--dataset_shuffle_buffer_size 10
--enable_precomputation
--precomputation_items 10
--precomputation_once
)
1 change: 1 addition & 0 deletions examples/training/sft/ltx_video/crush_smol_lora/train.sh
@@ -49,6 +49,7 @@ model_cmd=(
dataset_cmd=(
--dataset_config $TRAINING_DATASET_CONFIG
--dataset_shuffle_buffer_size 10
--enable_precomputation
--precomputation_items 25
--precomputation_once
)
@@ -54,6 +54,7 @@ model_cmd=(
dataset_cmd=(
--dataset_config $TRAINING_DATASET_CONFIG
--dataset_shuffle_buffer_size 50
--enable_precomputation
--precomputation_items 100
--precomputation_once
)
1 change: 1 addition & 0 deletions examples/training/sft/wan/3dgs_dissolve/train.sh
@@ -49,6 +49,7 @@ model_cmd=(
dataset_cmd=(
--dataset_config $TRAINING_DATASET_CONFIG
--dataset_shuffle_buffer_size 10
--enable_precomputation
--precomputation_items 100
--precomputation_once
)
1 change: 1 addition & 0 deletions examples/training/sft/wan/crush_smol_lora/train.sh
@@ -49,6 +49,7 @@ model_cmd=(
dataset_cmd=(
--dataset_config $TRAINING_DATASET_CONFIG
--dataset_shuffle_buffer_size 10
--enable_precomputation
--precomputation_items 25
--precomputation_once
)
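
Each of the example scripts above opts into precomputation explicitly. Because the behaviour is now opt-in, a `dataset_cmd` that omits `--enable_precomputation` keeps all models in memory and processes data online; a minimal sketch (not taken from any of the updated examples):

```sh
dataset_cmd=(
  --dataset_config $TRAINING_DATASET_CONFIG
  --dataset_shuffle_buffer_size 10
  # no --enable_precomputation: embeddings are computed on the fly during training
)
```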
4 changes: 4 additions & 0 deletions finetrainers/args.py
@@ -316,6 +316,7 @@ class BaseArgs:
# Dataset arguments
dataset_config: str = None
dataset_shuffle_buffer_size: int = 1
enable_precomputation: bool = False
precomputation_items: int = 512
precomputation_dir: Optional[str] = None
precomputation_once: bool = False
@@ -420,6 +421,7 @@ def to_dict(self) -> Dict[str, Any]:
dataset_arguments = {
"dataset_config": self.dataset_config,
"dataset_shuffle_buffer_size": self.dataset_shuffle_buffer_size,
"enable_precomputation": self.enable_precomputation,
"precomputation_items": self.precomputation_items,
"precomputation_dir": self.precomputation_dir,
"precomputation_once": self.precomputation_once,
@@ -625,6 +627,7 @@ def _add_model_arguments(parser: argparse.ArgumentParser) -> None:
def _add_dataset_arguments(parser: argparse.ArgumentParser) -> None:
parser.add_argument("--dataset_config", type=str, required=True)
parser.add_argument("--dataset_shuffle_buffer_size", type=int, default=1)
parser.add_argument("--enable_precomputation", action="store_true")
parser.add_argument("--precomputation_items", type=int, default=512)
parser.add_argument("--precomputation_dir", type=str, default=None)
parser.add_argument("--precomputation_once", action="store_true")
@@ -761,6 +764,7 @@ def _map_to_args_type(args: Dict[str, Any]) -> BaseArgs:
# Dataset arguments
result_args.dataset_config = args.dataset_config
result_args.dataset_shuffle_buffer_size = args.dataset_shuffle_buffer_size
result_args.enable_precomputation = args.enable_precomputation
result_args.precomputation_items = args.precomputation_items
result_args.precomputation_dir = args.precomputation_dir or os.path.join(args.output_dir, "precomputed")
result_args.precomputation_once = args.precomputation_once
10 changes: 9 additions & 1 deletion finetrainers/data/__init__.py
@@ -14,6 +14,14 @@
initialize_dataset,
wrap_iterable_dataset_for_preprocessing,
)
from .precomputation import DistributedDataPreprocessor, PreprocessedDataIterable
from .precomputation import (
InMemoryDataIterable,
InMemoryDistributedDataPreprocessor,
InMemoryOnceDataIterable,
PrecomputedDataIterable,
PrecomputedDistributedDataPreprocessor,
PrecomputedOnceDataIterable,
initialize_preprocessor,
)
from .sampler import ResolutionSampler
from .utils import find_files
3 changes: 1 addition & 2 deletions finetrainers/data/dataset.py
@@ -758,11 +758,10 @@ def _initialize_hub_dataset(dataset_name: str, dataset_type: str, infinite: bool
elif _has_data_file_caption_file_lists(repo_file_list, remote=True):
return _initialize_data_file_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)

has_tar_files = any(file.endswith(".tar") for file in repo_file_list)
has_tar_files = any(file.endswith(".tar") or file.endswith(".parquet") for file in repo_file_list)
if has_tar_files:
return _initialize_webdataset(dataset_name, dataset_type, infinite)

# TODO(aryan): handle parquet
# TODO(aryan): This should be improved
caption_files = [pathlib.Path(file).name for file in repo_file_list if file.endswith(".txt")]
if len(caption_files) < MAX_PRECOMPUTABLE_ITEMS_LIMIT:
