Improve local dataset loading (#289)
* update

* update
a-r-r-o-w authored Mar 4, 2025
parent 7a34bbf · commit ea69aaf
Showing 3 changed files with 334 additions and 134 deletions.
finetrainers/data/__init__.py: 2 additions & 0 deletions
@@ -2,10 +2,12 @@
 from .dataloader import DPDataLoader
 from .dataset import (
     ImageCaptionFilePairDataset,
+    ImageFileCaptionFileListDataset,
     ImageFolderDataset,
     ImageWebDataset,
     ValidationDataset,
     VideoCaptionFilePairDataset,
+    VideoFileCaptionFileListDataset,
     VideoFolderDataset,
     VideoWebDataset,
     combine_datasets,
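The two classes newly re-exported above can also be constructed directly. A minimal sketch, not part of the commit: the root path is hypothetical, and the (root, infinite=...) signature mirrors the calls in _initialize_local_dataset in the next file.

from finetrainers.data import VideoFileCaptionFileListDataset

# Expects a directory with caption/video list files such as prompts.txt and videos.txt
dataset = VideoFileCaptionFileListDataset("path/to/my_dataset", infinite=False)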
finetrainers/data/dataset.py: 92 additions & 51 deletions
@@ -1,6 +1,6 @@
 import pathlib
 import random
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import datasets
 import datasets.data_files
@@ -30,6 +30,9 @@
 
 
 MAX_PRECOMPUTABLE_ITEMS_LIMIT = 1024
+COMMON_CAPTION_FILES = ["prompt.txt", "prompts.txt", "caption.txt", "captions.txt"]
+COMMON_VIDEO_FILES = ["video.txt", "videos.txt"]
+COMMON_IMAGE_FILES = ["image.txt", "images.txt"]
 
 
 class ImageCaptionFilePairDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
@@ -61,10 +64,6 @@ def __init__(self, root: str, infinite: bool = False) -> None:
     def _get_data_iter(self):
         if self._sample_index == 0:
             return iter(self._data)
-
-        if isinstance(self._data, datasets.Dataset) and self._sample_index >= len(self._data):
-            return iter([])
-
         return iter(self._data.skip(self._sample_index))
 
     def __iter__(self):
@@ -685,13 +684,6 @@ def state_dict(self):
 def initialize_dataset(
     dataset_name_or_root: str, dataset_type: str = "video", streaming: bool = True, infinite: bool = False
 ) -> torch.utils.data.IterableDataset:
-    # 1. If there is a metadata.json or metadata.jsonl or metadata.csv file, we use the
-    #    ImageFolderDataset or VideoFolderDataset class respectively.
-    # 2. If there is a list of .txt files sharing the same name with image extensions, we
-    #    use the ImageCaptionFileDataset class.
-    # 3. If there is a list of .txt files sharing the same name with video extensions, we
-    #    use the VideoCaptionFileDataset class.
-    # 4. If there is a dataset name, we use the ImageWebDataset or VideoWebDataset class.
     assert dataset_type in ["image", "video"]
 
     try:
@@ -717,71 +709,64 @@ def wrap_iterable_dataset_for_preprocessing(
     return IterableDatasetPreprocessingWrapper(dataset, dataset_type, **config)
 
 
-def _read_caption_from_file(filename: str) -> str:
-    with open(filename, "r") as f:
-        return f.read().strip()
-
-
-def _preprocess_image(image: PIL.Image.Image) -> torch.Tensor:
-    image = image.convert("RGB")
-    image = np.array(image).astype(np.float32)
-    image = torch.from_numpy(image)
-    image = image.permute(2, 0, 1).contiguous() / 127.5 - 1.0
-    return image
-
-
-def _preprocess_video(video: decord.VideoReader) -> torch.Tensor:
-    video = video.get_batch(list(range(len(video))))
-    video = video.permute(0, 3, 1, 2).contiguous()
-    video = video.float() / 127.5 - 1.0
-    return video
-
-
 def _initialize_local_dataset(dataset_name_or_root: str, dataset_type: str, infinite: bool = False):
     root = pathlib.Path(dataset_name_or_root)
     supported_metadata_files = ["metadata.json", "metadata.jsonl", "metadata.csv"]
     metadata_files = [root / metadata_file for metadata_file in supported_metadata_files]
     metadata_files = [metadata_file for metadata_file in metadata_files if metadata_file.exists()]
 
-    dataset = None
-    if len(metadata_files) == 0:
-        raise ValueError(
-            f"No metadata file found. Please ensure there is a metadata file named one of: {supported_metadata_files}."
-        )
-    elif len(metadata_files) > 1:
+    if len(metadata_files) > 1:
         raise ValueError("Found multiple metadata files. Please ensure there is only one metadata file.")
-    else:
+
+    if len(metadata_files) == 1:
         if dataset_type == "image":
             dataset = ImageFolderDataset(root.as_posix(), infinite=infinite)
         else:
             dataset = VideoFolderDataset(root.as_posix(), infinite=infinite)
+        return dataset
 
-    if dataset is None:
+    if _has_data_caption_file_pairs(root, remote=False):
+        if dataset_type == "image":
+            dataset = ImageCaptionFilePairDataset(root.as_posix(), infinite=infinite)
+        else:
+            dataset = VideoCaptionFilePairDataset(root.as_posix(), infinite=infinite)
+    elif _has_data_file_caption_file_lists(root, remote=False):
+        if dataset_type == "image":
+            dataset = ImageFileCaptionFileListDataset(root.as_posix(), infinite=infinite)
+        else:
+            dataset = VideoFileCaptionFileListDataset(root.as_posix(), infinite=infinite)
+    else:
+        raise ValueError(
+            f"Could not find any supported dataset structure in the directory {root}. Please open an issue at "
+            f"https://github.com/a-r-r-o-w/finetrainers with information about your dataset structure and we will "
+            f"help you set it up."
+        )
 
     return dataset
 
 
 def _initialize_hub_dataset(dataset_name: str, dataset_type: str, infinite: bool = False):
-    _common_caption_files = ["prompt.txt", "prompts.txt", "caption.txt", "captions.txt"]
-    _common_video_files = ["video.txt", "videos.txt"]
-    _common_image_files = ["image.txt", "images.txt"]
-
     repo_file_list = list_repo_files(dataset_name, repo_type="dataset")
-    has_caption_files = any(file in repo_file_list for file in _common_caption_files)
-    has_video_files = any(file in repo_file_list for file in _common_video_files)
-    has_image_files = any(file in repo_file_list for file in _common_image_files)
-
-    if has_caption_files and (has_video_files or has_image_files):
-        return _initialize_file_dataset(dataset_name, dataset_type, infinite)
+    if _has_data_caption_file_pairs(repo_file_list, remote=True):
+        return _initialize_data_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)
+    elif _has_data_file_caption_file_lists(repo_file_list, remote=True):
+        return _initialize_data_file_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)
     else:
         return _initialize_webdataset(dataset_name, dataset_type, infinite)
 
 
-def _initialize_file_dataset(
+def _initialize_data_caption_file_dataset_from_hub(
     dataset_name: str, dataset_type: str, infinite: bool = False
 ) -> torch.utils.data.IterableDataset:
     logger.info(f"Downloading dataset {dataset_name} from the HF Hub")
     dataset_root = snapshot_download(dataset_name, repo_type="dataset")
     if dataset_type == "image":
         return ImageCaptionFilePairDataset(dataset_root, infinite=infinite)
     else:
         return VideoCaptionFilePairDataset(dataset_root, infinite=infinite)
 
 
+def _initialize_data_file_caption_file_dataset_from_hub(
+    dataset_name: str, dataset_type: str, infinite: bool = False
+) -> torch.utils.data.IterableDataset:
+    logger.info(f"Downloading dataset {dataset_name} from the HF Hub")
@@ -800,3 +785,59 @@ def _initialize_webdataset(
         return ImageWebDataset(dataset_name, infinite=infinite)
     else:
         return VideoWebDataset(dataset_name, infinite=infinite)
+
+
+def _has_data_caption_file_pairs(root: Union[pathlib.Path, List[str]], remote: bool = False) -> bool:
+    # TODO(aryan): this logic can be improved
+    if not remote:
+        caption_files = utils.find_files(root.as_posix(), "*.txt", depth=0)
+        for caption_file in caption_files:
+            caption_file = pathlib.Path(caption_file)
+            for extension in [*constants.SUPPORTED_IMAGE_FILE_EXTENSIONS, *constants.SUPPORTED_VIDEO_FILE_EXTENSIONS]:
+                data_filename = caption_file.with_suffix(f".{extension}")
+                if data_filename.exists():
+                    return True
+        return False
+    else:
+        caption_files = [file for file in root if file.endswith(".txt")]
+        for caption_file in caption_files:
+            # Repo file lists hold plain strings, so build a path object before swapping the suffix
+            caption_file = pathlib.PurePath(caption_file)
+            for extension in [*constants.SUPPORTED_IMAGE_FILE_EXTENSIONS, *constants.SUPPORTED_VIDEO_FILE_EXTENSIONS]:
+                data_filename = caption_file.with_suffix(f".{extension}")
+                if data_filename.as_posix() in root:
+                    return True
+        return False
+
+
+def _has_data_file_caption_file_lists(root: Union[pathlib.Path, List[str]], remote: bool = False) -> bool:
+    # TODO(aryan): this logic can be improved
+    if not remote:
+        file_list = {x.name for x in root.iterdir()}
+        has_caption_files = any(file in file_list for file in COMMON_CAPTION_FILES)
+        has_video_files = any(file in file_list for file in COMMON_VIDEO_FILES)
+        has_image_files = any(file in file_list for file in COMMON_IMAGE_FILES)
+        return has_caption_files and (has_video_files or has_image_files)
+    else:
+        has_caption_files = any(file in root for file in COMMON_CAPTION_FILES)
+        has_video_files = any(file in root for file in COMMON_VIDEO_FILES)
+        has_image_files = any(file in root for file in COMMON_IMAGE_FILES)
+        return has_caption_files and (has_video_files or has_image_files)
+
+
+def _read_caption_from_file(filename: str) -> str:
+    with open(filename, "r") as f:
+        return f.read().strip()
+
+
+def _preprocess_image(image: PIL.Image.Image) -> torch.Tensor:
+    image = image.convert("RGB")
+    image = np.array(image).astype(np.float32)
+    image = torch.from_numpy(image)
+    image = image.permute(2, 0, 1).contiguous() / 127.5 - 1.0
+    return image
+
+
+def _preprocess_video(video: decord.VideoReader) -> torch.Tensor:
+    video = video.get_batch(list(range(len(video))))
+    video = video.permute(0, 3, 1, 2).contiguous()
+    video = video.float() / 127.5 - 1.0
+    return video
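For reference, a minimal sketch (not part of the commit) of the two local directory layouts the new detection helpers recognize, and how such a dataset would be initialized. The root path and sample handling are illustrative assumptions; the layout conventions, the COMMON_* file names, and the initialize_dataset signature come from the code above, assuming initialize_dataset is re-exported from finetrainers.data (otherwise import it from finetrainers.data.dataset).

# Layout A: data/caption file pairs -> ImageCaptionFilePairDataset / VideoCaptionFilePairDataset
#   my_dataset/
#     0001.mp4
#     0001.txt   <- caption for 0001.mp4, matched by swapping the .txt suffix
#     0002.mp4
#     0002.txt
#
# Layout B: file lists -> ImageFileCaptionFileListDataset / VideoFileCaptionFileListDataset
#   my_dataset/
#     prompts.txt   <- one caption per line (any name in COMMON_CAPTION_FILES)
#     videos.txt    <- one video path per line (any name in COMMON_VIDEO_FILES)
#     ...the video files referenced by videos.txt

from finetrainers.data import initialize_dataset

# Detection order in _initialize_local_dataset: metadata file first, then
# caption-file pairs, then file lists; anything else raises a ValueError.
dataset = initialize_dataset("path/to/my_dataset", dataset_type="video", infinite=True)
sample = next(iter(dataset))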
