Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve local dataset loading #289

Merged
merged 2 commits into from
Mar 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions finetrainers/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from .dataloader import DPDataLoader
from .dataset import (
ImageCaptionFilePairDataset,
ImageFileCaptionFileListDataset,
ImageFolderDataset,
ImageWebDataset,
ValidationDataset,
VideoCaptionFilePairDataset,
VideoFileCaptionFileListDataset,
VideoFolderDataset,
VideoWebDataset,
combine_datasets,
Expand Down
143 changes: 92 additions & 51 deletions finetrainers/data/dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pathlib
import random
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union

import datasets
import datasets.data_files
Expand Down Expand Up @@ -30,6 +30,9 @@


MAX_PRECOMPUTABLE_ITEMS_LIMIT = 1024
COMMON_CAPTION_FILES = ["prompt.txt", "prompts.txt", "caption.txt", "captions.txt"]
COMMON_VIDEO_FILES = ["video.txt", "videos.txt"]
COMMON_IMAGE_FILES = ["image.txt", "images.txt"]


class ImageCaptionFilePairDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
Expand Down Expand Up @@ -61,10 +64,6 @@ def __init__(self, root: str, infinite: bool = False) -> None:
def _get_data_iter(self):
if self._sample_index == 0:
return iter(self._data)

if isinstance(self._data, datasets.Dataset) and self._sample_index >= len(self._data):
return iter([])

return iter(self._data.skip(self._sample_index))

def __iter__(self):
Expand Down Expand Up @@ -685,13 +684,6 @@ def state_dict(self):
def initialize_dataset(
dataset_name_or_root: str, dataset_type: str = "video", streaming: bool = True, infinite: bool = False
) -> torch.utils.data.IterableDataset:
# 1. If there is a metadata.json or metadata.jsonl or metadata.csv file, we use the
# ImageFolderDataset or VideoFolderDataset class respectively.
# 2. If there is a list of .txt files sharing the same name with image extensions, we
# use the ImageCaptionFileDataset class.
# 3. If there is a list of .txt files sharing the same name with video extensions, we
# use the VideoCaptionFileDataset class.
# 4. If there is a dataset name, we use the ImageWebDataset or VideoWebDataset class.
assert dataset_type in ["image", "video"]

try:
Expand All @@ -717,71 +709,64 @@ def wrap_iterable_dataset_for_preprocessing(
return IterableDatasetPreprocessingWrapper(dataset, dataset_type, **config)


def _read_caption_from_file(filename: str) -> str:
with open(filename, "r") as f:
return f.read().strip()


def _preprocess_image(image: PIL.Image.Image) -> torch.Tensor:
image = image.convert("RGB")
image = np.array(image).astype(np.float32)
image = torch.from_numpy(image)
image = image.permute(2, 0, 1).contiguous() / 127.5 - 1.0
return image


def _preprocess_video(video: decord.VideoReader) -> torch.Tensor:
video = video.get_batch(list(range(len(video))))
video = video.permute(0, 3, 1, 2).contiguous()
video = video.float() / 127.5 - 1.0
return video


def _initialize_local_dataset(dataset_name_or_root: str, dataset_type: str, infinite: bool = False):
root = pathlib.Path(dataset_name_or_root)
supported_metadata_files = ["metadata.json", "metadata.jsonl", "metadata.csv"]
metadata_files = [root / metadata_file for metadata_file in supported_metadata_files]
metadata_files = [metadata_file for metadata_file in metadata_files if metadata_file.exists()]

dataset = None
if len(metadata_files) == 0:
raise ValueError(
f"No metadata file found. Please ensure there is a metadata file named one of: {supported_metadata_files}."
)
elif len(metadata_files) > 1:
if len(metadata_files) > 1:
raise ValueError("Found multiple metadata files. Please ensure there is only one metadata file.")
else:

if len(metadata_files) == 1:
if dataset_type == "image":
dataset = ImageFolderDataset(root.as_posix(), infinite=infinite)
else:
dataset = VideoFolderDataset(root.as_posix(), infinite=infinite)
return dataset

if dataset is None:
if _has_data_caption_file_pairs(root, remote=False):
if dataset_type == "image":
dataset = ImageCaptionFilePairDataset(root.as_posix(), infinite=infinite)
else:
dataset = VideoCaptionFilePairDataset(root.as_posix(), infinite=infinite)
elif _has_data_file_caption_file_lists(root, remote=False):
if dataset_type == "image":
dataset = ImageFileCaptionFileListDataset(root.as_posix(), infinite=infinite)
else:
dataset = VideoFileCaptionFileListDataset(root.as_posix(), infinite=infinite)
else:
raise ValueError(
f"Could not find any supported dataset structure in the directory {root}. Please open an issue at "
f"https://github.com/a-r-r-o-w/finetrainers with information about your dataset structure and we will "
f"help you set it up."
)

return dataset


def _initialize_hub_dataset(dataset_name: str, dataset_type: str, infinite: bool = False):
_common_caption_files = ["prompt.txt", "prompts.txt", "caption.txt", "captions.txt"]
_common_video_files = ["video.txt", "videos.txt"]
_common_image_files = ["image.txt", "images.txt"]

repo_file_list = list_repo_files(dataset_name, repo_type="dataset")
has_caption_files = any(file in repo_file_list for file in _common_caption_files)
has_video_files = any(file in repo_file_list for file in _common_video_files)
has_image_files = any(file in repo_file_list for file in _common_image_files)

if has_caption_files and (has_video_files or has_image_files):
return _initialize_file_dataset(dataset_name, dataset_type, infinite)
if _has_data_caption_file_pairs(repo_file_list, remote=True):
return _initialize_data_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)
elif _has_data_file_caption_file_lists(repo_file_list, remote=True):
return _initialize_data_file_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)
else:
return _initialize_webdataset(dataset_name, dataset_type, infinite)


def _initialize_file_dataset(
def _initialize_data_caption_file_dataset_from_hub(
dataset_name: str, dataset_type: str, infinite: bool = False
) -> torch.utils.data.IterableDataset:
logger.info(f"Downloading dataset {dataset_name} from the HF Hub")
dataset_root = snapshot_download(dataset_name, repo_type="dataset")
if dataset_type == "image":
return ImageCaptionFilePairDataset(dataset_root, infinite=infinite)
else:
return VideoCaptionFilePairDataset(dataset_root, infinite=infinite)


def _initialize_data_file_caption_file_dataset_from_hub(
dataset_name: str, dataset_type: str, infinite: bool = False
) -> torch.utils.data.IterableDataset:
logger.info(f"Downloading dataset {dataset_name} from the HF Hub")
Expand All @@ -800,3 +785,59 @@ def _initialize_webdataset(
return ImageWebDataset(dataset_name, infinite=infinite)
else:
return VideoWebDataset(dataset_name, infinite=infinite)


def _has_data_caption_file_pairs(root: Union[pathlib.Path, List[str]], remote: bool = False) -> bool:
# TODO(aryan): this logic can be improved
if not remote:
caption_files = utils.find_files(root.as_posix(), "*.txt", depth=0)
for caption_file in caption_files:
caption_file = pathlib.Path(caption_file)
for extension in [*constants.SUPPORTED_IMAGE_FILE_EXTENSIONS, *constants.SUPPORTED_VIDEO_FILE_EXTENSIONS]:
data_filename = caption_file.with_suffix(f".{extension}")
if data_filename.exists():
return True
return False
else:
caption_files = [file for file in root if file.endswith(".txt")]
for caption_file in caption_files:
for extension in [*constants.SUPPORTED_IMAGE_FILE_EXTENSIONS, *constants.SUPPORTED_VIDEO_FILE_EXTENSIONS]:
data_filename = caption_file.with_suffix(f".{extension}")
if data_filename in root:
return True
return False


def _has_data_file_caption_file_lists(root: Union[pathlib.Path, List[str]], remote: bool = False) -> bool:
# TODO(aryan): this logic can be improved
if not remote:
file_list = {x.name for x in root.iterdir()}
has_caption_files = any(file in file_list for file in COMMON_CAPTION_FILES)
has_video_files = any(file in file_list for file in COMMON_VIDEO_FILES)
has_image_files = any(file in file_list for file in COMMON_IMAGE_FILES)
return has_caption_files and (has_video_files or has_image_files)
else:
has_caption_files = any(file in root for file in COMMON_CAPTION_FILES)
has_video_files = any(file in root for file in COMMON_VIDEO_FILES)
has_image_files = any(file in root for file in COMMON_IMAGE_FILES)
return has_caption_files and (has_video_files or has_image_files)


def _read_caption_from_file(filename: str) -> str:
with open(filename, "r") as f:
return f.read().strip()


def _preprocess_image(image: PIL.Image.Image) -> torch.Tensor:
image = image.convert("RGB")
image = np.array(image).astype(np.float32)
image = torch.from_numpy(image)
image = image.permute(2, 0, 1).contiguous() / 127.5 - 1.0
return image


def _preprocess_video(video: decord.VideoReader) -> torch.Tensor:
video = video.get_batch(list(range(len(video))))
video = video.permute(0, 3, 1, 2).contiguous()
video = video.float() / 127.5 - 1.0
return video
Loading