    def _get_data_iter(self):
        """Return an iterator over the underlying dataset, resuming from the
        checkpointed position.

        When no samples have been consumed yet (``_sample_index == 0``),
        iterate from the beginning; otherwise skip the already-seen samples so
        that stateful resumption continues where the previous run stopped.
        """
        if self._sample_index == 0:
            return iter(self._data)
        # NOTE(review): the earlier guard that returned an empty iterator when
        # `_sample_index >= len(self._data)` (for map-style `datasets.Dataset`)
        # was removed — this now relies on `.skip()` past the end behaving
        # gracefully. Confirm against the pinned `datasets` version.
        return iter(self._data.skip(self._sample_index))
def _initialize_local_dataset(dataset_name_or_root: str, dataset_type: str, infinite: bool = False):
    """Create a dataset from a local directory by inspecting its structure.

    Resolution order:
      1. Exactly one ``metadata.json``/``metadata.jsonl``/``metadata.csv``
         file selects the folder-style datasets.
      2. Side-by-side ``<name>.txt`` caption files next to same-stem
         image/video files select the caption-file-pair datasets.
      3. ``prompts.txt``-style list files select the file-list datasets.

    Args:
        dataset_name_or_root: Path to the dataset directory.
        dataset_type: Either ``"image"`` or ``"video"``.
        infinite: Whether the returned dataset loops forever.

    Raises:
        ValueError: If multiple metadata files exist, or if no supported
            layout is detected.
    """
    root = pathlib.Path(dataset_name_or_root)
    supported_metadata_files = ["metadata.json", "metadata.jsonl", "metadata.csv"]
    metadata_files = [root / metadata_file for metadata_file in supported_metadata_files]
    metadata_files = [metadata_file for metadata_file in metadata_files if metadata_file.exists()]

    if len(metadata_files) > 1:
        raise ValueError("Found multiple metadata files. Please ensure there is only one metadata file.")

    # Layout 1: a single metadata file drives the folder-style datasets.
    if len(metadata_files) == 1:
        if dataset_type == "image":
            dataset = ImageFolderDataset(root.as_posix(), infinite=infinite)
        else:
            dataset = VideoFolderDataset(root.as_posix(), infinite=infinite)
        return dataset

    # Layout 2: per-sample caption files paired with data files by stem.
    if _has_data_caption_file_pairs(root, remote=False):
        if dataset_type == "image":
            dataset = ImageCaptionFilePairDataset(root.as_posix(), infinite=infinite)
        else:
            dataset = VideoCaptionFilePairDataset(root.as_posix(), infinite=infinite)
    # Layout 3: one captions list file plus one data-files list file.
    elif _has_data_file_caption_file_lists(root, remote=False):
        if dataset_type == "image":
            dataset = ImageFileCaptionFileListDataset(root.as_posix(), infinite=infinite)
        else:
            dataset = VideoFileCaptionFileListDataset(root.as_posix(), infinite=infinite)
    else:
        raise ValueError(
            f"Could not find any supported dataset structure in the directory {root}. Please open an issue at "
            f"https://github.com/a-r-r-o-w/finetrainers with information about your dataset structure and we will "
            f"help you set it up."
        )
    return dataset


def _initialize_hub_dataset(dataset_name: str, dataset_type: str, infinite: bool = False):
    """Create a dataset from a HF Hub repository by inspecting its file list.

    Caption-file-pair and file-list layouts are snapshot-downloaded and wrapped
    in the matching local dataset classes; anything else is assumed to be a
    webdataset.
    """
    repo_file_list = list_repo_files(dataset_name, repo_type="dataset")
    if _has_data_caption_file_pairs(repo_file_list, remote=True):
        return _initialize_data_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)
    elif _has_data_file_caption_file_lists(repo_file_list, remote=True):
        return _initialize_data_file_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)
    else:
        return _initialize_webdataset(dataset_name, dataset_type, infinite)


def _initialize_data_caption_file_dataset_from_hub(
    dataset_name: str, dataset_type: str, infinite: bool = False
) -> torch.utils.data.IterableDataset:
    """Download a caption-file-pair dataset from the Hub and wrap it in the
    matching iterable dataset class.
    """
    logger.info(f"Downloading dataset {dataset_name} from the HF Hub")
    # snapshot_download caches locally and returns the local root directory.
    dataset_root = snapshot_download(dataset_name, repo_type="dataset")
    if dataset_type == "image":
        return ImageCaptionFilePairDataset(dataset_root, infinite=infinite)
    else:
        return VideoCaptionFilePairDataset(dataset_root, infinite=infinite)
def _has_data_caption_file_pairs(root: Union[pathlib.Path, List[str]], remote: bool = False) -> bool:
    """Detect the "caption file pair" layout: a ``<name>.txt`` caption living
    next to a ``<name>.<ext>`` image/video file with the same stem.

    Args:
        root: Local dataset directory as a ``pathlib.Path`` when ``remote`` is
            False, or the flat list of repo-relative filename strings (as
            returned by ``list_repo_files``) when ``remote`` is True.
        remote: Whether ``root`` is a Hub repo file listing.

    Returns:
        True if at least one caption/data file pair exists.
    """
    # TODO(aryan): this logic can be improved
    supported_extensions = [*constants.SUPPORTED_IMAGE_FILE_EXTENSIONS, *constants.SUPPORTED_VIDEO_FILE_EXTENSIONS]
    if not remote:
        caption_files = utils.find_files(root.as_posix(), "*.txt", depth=0)
        for caption_file in caption_files:
            caption_file = pathlib.Path(caption_file)
            for extension in supported_extensions:
                data_filename = caption_file.with_suffix(f".{extension}")
                if data_filename.exists():
                    return True
        return False
    else:
        # Fix: `root` holds plain strings here, so the previous code raised
        # AttributeError on `str.with_suffix` and then compared a Path object
        # against a list of strings (which could never match).
        repo_files = set(root)
        caption_files = [file for file in repo_files if file.endswith(".txt")]
        for caption_file in caption_files:
            for extension in supported_extensions:
                data_filename = pathlib.PurePosixPath(caption_file).with_suffix(f".{extension}").as_posix()
                if data_filename in repo_files:
                    return True
        return False


def _has_data_file_caption_file_lists(root: Union[pathlib.Path, List[str]], remote: bool = False) -> bool:
    """Detect the "file list" layout: one captions list file (e.g.
    ``prompts.txt``) plus one data list file (e.g. ``images.txt``/``videos.txt``)
    naming the data files line by line.

    Args:
        root: Local dataset directory (``remote=False``) or the flat list of
            repo filenames (``remote=True``).
        remote: Whether ``root`` is a Hub repo file listing.
    """
    # TODO(aryan): this logic can be improved
    if not remote:
        file_list = {x.name for x in root.iterdir()}
    else:
        file_list = set(root)
    has_caption_files = any(file in file_list for file in COMMON_CAPTION_FILES)
    has_video_files = any(file in file_list for file in COMMON_VIDEO_FILES)
    has_image_files = any(file in file_list for file in COMMON_IMAGE_FILES)
    return has_caption_files and (has_video_files or has_image_files)


def _read_caption_from_file(filename: str) -> str:
    """Read a caption file and return its content with surrounding whitespace stripped."""
    with open(filename, "r") as f:
        return f.read().strip()


def _preprocess_image(image: "PIL.Image.Image") -> torch.Tensor:
    """Convert a PIL image to a float CHW tensor scaled to [-1, 1]."""
    image = image.convert("RGB")
    image = np.array(image).astype(np.float32)
    image = torch.from_numpy(image)
    # HWC -> CHW; [0, 255] -> [-1, 1]
    image = image.permute(2, 0, 1).contiguous() / 127.5 - 1.0
    return image


def _preprocess_video(video: "decord.VideoReader") -> torch.Tensor:
    """Decode all frames of a video into a float (F, C, H, W) tensor scaled to [-1, 1].

    NOTE(review): assumes the decord bridge is configured so that
    ``get_batch`` returns an object supporting ``.permute`` (torch bridge) —
    confirm upstream.
    """
    video = video.get_batch(list(range(len(video))))
    video = video.permute(0, 3, 1, 2).contiguous()
    video = video.float() / 127.5 - 1.0
    return video
class DatasetTesterMixin:
    """Mixin that materializes ``directory_structure`` inside a fresh temporary
    directory before every test and removes it afterwards.

    Subclasses must set ``num_data_files`` and ``directory_structure``;
    metadata-driven layouts additionally need ``metadata_extension``.
    """

    num_data_files = None  # how many data/caption entries to generate
    directory_structure = None  # files/dirs (relative paths) to create under tmpdir
    caption = "A cat ruling the world"
    metadata_extension = None  # extension used in metadata.csv / metadata.jsonl rows

    def setUp(self):
        from finetrainers.data.dataset import COMMON_CAPTION_FILES, COMMON_IMAGE_FILES, COMMON_VIDEO_FILES

        if self.num_data_files is None:
            raise ValueError("num_data_files is not defined")
        if self.directory_structure is None:
            # Fix: the message previously said "dataset_structure", which is
            # not the name of the attribute being checked.
            raise ValueError("directory_structure is not defined")

        self.tmpdir = tempfile.TemporaryDirectory()

        for item in self.directory_structure:
            # TODO(aryan): this should be improved
            if item in COMMON_CAPTION_FILES:
                # Captions list file: one caption per line, one per data file.
                data_file = pathlib.Path(self.tmpdir.name) / item
                with open(data_file.as_posix(), "w") as f:
                    for _ in range(self.num_data_files):
                        f.write(f"{self.caption}\n")
            elif item in COMMON_IMAGE_FILES:
                # Image list file: relative paths of the image files.
                data_file = pathlib.Path(self.tmpdir.name) / item
                with open(data_file.as_posix(), "w") as f:
                    for i in range(self.num_data_files):
                        f.write(f"images/{i}.jpg\n")
            elif item in COMMON_VIDEO_FILES:
                # Video list file: relative paths of the video files.
                data_file = pathlib.Path(self.tmpdir.name) / item
                with open(data_file.as_posix(), "w") as f:
                    for i in range(self.num_data_files):
                        f.write(f"videos/{i}.mp4\n")
            elif item == "metadata.csv":
                data_file = pathlib.Path(self.tmpdir.name) / item
                with open(data_file.as_posix(), "w") as f:
                    f.write("file_name,caption\n")
                    for i in range(self.num_data_files):
                        f.write(f"{i}.{self.metadata_extension},{self.caption}\n")
            elif item == "metadata.jsonl":
                data_file = pathlib.Path(self.tmpdir.name) / item
                with open(data_file.as_posix(), "w") as f:
                    for i in range(self.num_data_files):
                        f.write(f'{{"file_name": "{i}.{self.metadata_extension}", "caption": "{self.caption}"}}\n')
            elif item.endswith(".txt"):
                # A standalone per-sample caption file.
                data_file = pathlib.Path(self.tmpdir.name) / item
                with open(data_file.as_posix(), "w") as f:
                    f.write(self.caption)
            elif item.endswith(".jpg") or item.endswith(".png"):
                data_file = pathlib.Path(self.tmpdir.name) / item
                Image.new("RGB", (64, 64)).save(data_file.as_posix())
            elif item.endswith(".mp4"):
                data_file = pathlib.Path(self.tmpdir.name) / item
                export_to_video([Image.new("RGB", (64, 64))] * 4, data_file.as_posix(), fps=2)
            else:
                # Anything else is treated as a subdirectory (e.g. "images/").
                data_file = pathlib.Path(self.tmpdir.name, item)
                data_file.mkdir(exist_ok=True, parents=True)

    def tearDown(self):
        self.tmpdir.cleanup()


class ImageDatasetTesterMixin(DatasetTesterMixin):
    # metadata.csv / metadata.jsonl rows point at .jpg files.
    metadata_extension = "jpg"


class VideoDatasetTesterMixin(DatasetTesterMixin):
    # metadata.csv / metadata.jsonl rows point at .mp4 files.
    metadata_extension = "mp4"
class ImageCaptionFilePairDatasetFastTests(ImageDatasetTesterMixin, unittest.TestCase):
    """Tests for the `<name>.jpg` + `<name>.txt` caption-pair layout."""

    num_data_files = 3
    directory_structure = [
        "0.jpg",
        "1.jpg",
        "2.jpg",
        "0.txt",
        "1.txt",
        "2.txt",
    ]

    def setUp(self):
        super().setUp()
        self.dataset = ImageCaptionFilePairDataset(self.tmpdir.name, infinite=False)

    def test_getitem(self):
        iterator = iter(self.dataset)
        for _ in range(self.num_data_files):
            item = next(iterator)
            self.assertEqual(item["caption"], self.caption)
            self.assertTrue(torch.is_tensor(item["image"]))
            self.assertEqual(item["image"].shape, (3, 64, 64))

    def test_initialize_dataset(self):
        dataset = initialize_dataset(self.tmpdir.name, "image", infinite=False)
        self.assertIsInstance(dataset, ImageCaptionFilePairDataset)


class ImageFileCaptionFileListDatasetFastTests(ImageDatasetTesterMixin, unittest.TestCase):
    """Tests for the `prompts.txt` + `images.txt` file-list layout."""

    num_data_files = 3
    directory_structure = [
        "prompts.txt",
        "images.txt",
        "images/",
        "images/0.jpg",
        "images/1.jpg",
        "images/2.jpg",
    ]

    def setUp(self):
        super().setUp()
        self.dataset = ImageFileCaptionFileListDataset(self.tmpdir.name, infinite=False)

    def test_getitem(self):
        iterator = iter(self.dataset)
        # Consistency: iterate num_data_files instead of a hard-coded 3.
        for _ in range(self.num_data_files):
            item = next(iterator)
            self.assertEqual(item["caption"], self.caption)
            self.assertTrue(torch.is_tensor(item["image"]))
            self.assertEqual(item["image"].shape, (3, 64, 64))

    def test_initialize_dataset(self):
        dataset = initialize_dataset(self.tmpdir.name, "image", infinite=False)
        self.assertIsInstance(dataset, ImageFileCaptionFileListDataset)


class ImageFolderDatasetFastTests___CSV(ImageDatasetTesterMixin, unittest.TestCase):
    """Tests for the metadata.csv folder layout."""

    num_data_files = 3
    directory_structure = [
        "metadata.csv",
        "0.jpg",
        "1.jpg",
        "2.jpg",
    ]

    def setUp(self):
        super().setUp()
        self.dataset = ImageFolderDataset(self.tmpdir.name, infinite=False)

    def test_getitem(self):
        iterator = iter(self.dataset)
        # Consistency: iterate num_data_files instead of a hard-coded 3.
        for _ in range(self.num_data_files):
            item = next(iterator)
            self.assertIn("caption", item)
            self.assertEqual(item["caption"], self.caption)
            self.assertTrue(torch.is_tensor(item["image"]))

    def test_initialize_dataset(self):
        dataset = initialize_dataset(self.tmpdir.name, "image", infinite=False)
        self.assertIsInstance(dataset, ImageFolderDataset)


class ImageFolderDatasetFastTests___JSONL(ImageDatasetTesterMixin, unittest.TestCase):
    """Tests for the metadata.jsonl folder layout."""

    num_data_files = 3
    directory_structure = [
        "metadata.jsonl",
        "0.jpg",
        "1.jpg",
        "2.jpg",
    ]

    def setUp(self):
        super().setUp()
        self.dataset = ImageFolderDataset(self.tmpdir.name, infinite=False)

    def test_getitem(self):
        iterator = iter(self.dataset)
        # Consistency: iterate num_data_files instead of a hard-coded 3.
        for _ in range(self.num_data_files):
            item = next(iterator)
            self.assertIn("caption", item)
            self.assertEqual(item["caption"], self.caption)
            self.assertTrue(torch.is_tensor(item["image"]))

    def test_initialize_dataset(self):
        dataset = initialize_dataset(self.tmpdir.name, "image", infinite=False)
        self.assertIsInstance(dataset, ImageFolderDataset)
class VideoCaptionFilePairDatasetFastTests(VideoDatasetTesterMixin, unittest.TestCase):
    """Tests for the `<name>.mp4` + `<name>.txt` caption-pair layout."""

    num_data_files = 3
    directory_structure = [
        "0.mp4",
        "1.mp4",
        "2.mp4",
        "0.txt",
        "1.txt",
        "2.txt",
    ]

    def setUp(self):
        super().setUp()
        self.dataset = VideoCaptionFilePairDataset(self.tmpdir.name, infinite=False)

    def test_getitem(self):
        iterator = iter(self.dataset)
        for _ in range(self.num_data_files):
            item = next(iterator)
            self.assertEqual(item["caption"], self.caption)
            self.assertTrue(torch.is_tensor(item["video"]))
            self.assertEqual(len(item["video"]), 4)
            self.assertEqual(item["video"][0].shape, (3, 64, 64))

    def test_initialize_dataset(self):
        dataset = initialize_dataset(self.tmpdir.name, "video", infinite=False)
        self.assertIsInstance(dataset, VideoCaptionFilePairDataset)


class VideoFileCaptionFileListDatasetFastTests(VideoDatasetTesterMixin, unittest.TestCase):
    """Tests for the `prompts.txt` + `videos.txt` file-list layout."""

    num_data_files = 3
    directory_structure = [
        "prompts.txt",
        "videos.txt",
        "videos/",
        "videos/0.mp4",
        "videos/1.mp4",
        "videos/2.mp4",
    ]

    def setUp(self):
        super().setUp()
        self.dataset = VideoFileCaptionFileListDataset(self.tmpdir.name, infinite=False)

    def test_getitem(self):
        iterator = iter(self.dataset)
        # Consistency: iterate num_data_files instead of a hard-coded 3.
        for _ in range(self.num_data_files):
            item = next(iterator)
            self.assertEqual(item["caption"], self.caption)
            self.assertTrue(torch.is_tensor(item["video"]))
            self.assertEqual(len(item["video"]), 4)
            self.assertEqual(item["video"][0].shape, (3, 64, 64))

    def test_initialize_dataset(self):
        dataset = initialize_dataset(self.tmpdir.name, "video", infinite=False)
        self.assertIsInstance(dataset, VideoFileCaptionFileListDataset)


class VideoFolderDatasetFastTests___CSV(VideoDatasetTesterMixin, unittest.TestCase):
    """Tests for the metadata.csv folder layout (videos)."""

    num_data_files = 3
    directory_structure = [
        "metadata.csv",
        "0.mp4",
        "1.mp4",
        "2.mp4",
    ]

    def setUp(self):
        super().setUp()
        self.dataset = VideoFolderDataset(self.tmpdir.name, infinite=False)

    def test_getitem(self):
        iterator = iter(self.dataset)
        # Consistency: iterate num_data_files instead of a hard-coded 3.
        for _ in range(self.num_data_files):
            item = next(iterator)
            self.assertIn("caption", item)
            self.assertEqual(item["caption"], self.caption)
            self.assertTrue(torch.is_tensor(item["video"]))
            self.assertEqual(len(item["video"]), 4)
            self.assertEqual(item["video"][0].shape, (3, 64, 64))

    def test_initialize_dataset(self):
        dataset = initialize_dataset(self.tmpdir.name, "video", infinite=False)
        self.assertIsInstance(dataset, VideoFolderDataset)


class VideoFolderDatasetFastTests___JSONL(VideoDatasetTesterMixin, unittest.TestCase):
    """Tests for the metadata.jsonl folder layout (videos)."""

    num_data_files = 3
    directory_structure = [
        "metadata.jsonl",
        "0.mp4",
        "1.mp4",
        "2.mp4",
    ]

    def setUp(self):
        super().setUp()
        self.dataset = VideoFolderDataset(self.tmpdir.name, infinite=False)

    def test_getitem(self):
        iterator = iter(self.dataset)
        # Consistency: iterate num_data_files instead of a hard-coded 3.
        for _ in range(self.num_data_files):
            item = next(iterator)
            self.assertIn("caption", item)
            self.assertEqual(item["caption"], self.caption)
            self.assertTrue(torch.is_tensor(item["video"]))
            self.assertEqual(len(item["video"]), 4)
            self.assertEqual(item["video"][0].shape, (3, 64, 64))

    def test_initialize_dataset(self):
        dataset = initialize_dataset(self.tmpdir.name, "video", infinite=False)
        self.assertIsInstance(dataset, VideoFolderDataset)


class ImageWebDatasetFastTests(unittest.TestCase):
    # TODO(aryan): setup a dummy dataset
    pass