Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
a-r-r-o-w committed Mar 7, 2025
1 parent fdcae95 commit affcf24
Show file tree
Hide file tree
Showing 4 changed files with 294 additions and 64 deletions.
66 changes: 11 additions & 55 deletions tests/data/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,27 @@
import pathlib
import sys
import tempfile
import unittest

import torch
from diffusers.utils import export_to_video
from PIL import Image


project_root = pathlib.Path(__file__).resolve().parents[2]
sys.path.append(str(project_root))

import decord # noqa

from finetrainers.data import ( # noqa
from finetrainers.data import (
ImageCaptionFilePairDataset,
ImageFileCaptionFileListDataset,
ImageFolderDataset,
ValidationDataset,
VideoCaptionFilePairDataset,
VideoFileCaptionFileListDataset,
VideoFolderDataset,
VideoWebDataset,
ValidationDataset,
initialize_dataset,
)
from finetrainers.data.utils import find_files # noqa
from finetrainers.data.utils import find_files

from .utils import create_dummy_directory_structure


import decord # isort: skip


class DatasetTesterMixin:
Expand All @@ -34,56 +31,15 @@ class DatasetTesterMixin:
metadata_extension = None

def setUp(self):
from finetrainers.data.dataset import COMMON_CAPTION_FILES, COMMON_IMAGE_FILES, COMMON_VIDEO_FILES

if self.num_data_files is None:
raise ValueError("num_data_files is not defined")
if self.directory_structure is None:
raise ValueError("dataset_structure is not defined")

self.tmpdir = tempfile.TemporaryDirectory()

for item in self.directory_structure:
# TODO(aryan): this should be improved
if item in COMMON_CAPTION_FILES:
data_file = pathlib.Path(self.tmpdir.name) / item
with open(data_file.as_posix(), "w") as f:
for _ in range(self.num_data_files):
f.write(f"{self.caption}\n")
elif item in COMMON_IMAGE_FILES:
data_file = pathlib.Path(self.tmpdir.name) / item
with open(data_file.as_posix(), "w") as f:
for i in range(self.num_data_files):
f.write(f"images/{i}.jpg\n")
elif item in COMMON_VIDEO_FILES:
data_file = pathlib.Path(self.tmpdir.name) / item
with open(data_file.as_posix(), "w") as f:
for i in range(self.num_data_files):
f.write(f"videos/{i}.mp4\n")
elif item == "metadata.csv":
data_file = pathlib.Path(self.tmpdir.name) / item
with open(data_file.as_posix(), "w") as f:
f.write("file_name,caption\n")
for i in range(self.num_data_files):
f.write(f"{i}.{self.metadata_extension},{self.caption}\n")
elif item == "metadata.jsonl":
data_file = pathlib.Path(self.tmpdir.name) / item
with open(data_file.as_posix(), "w") as f:
for i in range(self.num_data_files):
f.write(f'{{"file_name": "{i}.{self.metadata_extension}", "caption": "{self.caption}"}}\n')
elif item.endswith(".txt"):
data_file = pathlib.Path(self.tmpdir.name) / item
with open(data_file.as_posix(), "w") as f:
f.write(self.caption)
elif item.endswith(".jpg") or item.endswith(".png"):
data_file = pathlib.Path(self.tmpdir.name) / item
Image.new("RGB", (64, 64)).save(data_file.as_posix())
elif item.endswith(".mp4"):
data_file = pathlib.Path(self.tmpdir.name) / item
export_to_video([Image.new("RGB", (64, 64))] * 4, data_file.as_posix(), fps=2)
else:
data_file = pathlib.Path(self.tmpdir.name, item)
data_file.mkdir(exist_ok=True, parents=True)
create_dummy_directory_structure(
self.directory_structure, self.tmpdir, self.num_data_files, self.caption, self.metadata_extension
)

def tearDown(self):
self.tmpdir.cleanup()
Expand Down
177 changes: 177 additions & 0 deletions tests/data/test_precomputation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
import tempfile
import unittest

from finetrainers.data import (
InMemoryDistributedDataPreprocessor,
PrecomputedDistributedDataPreprocessor,
VideoCaptionFilePairDataset,
initialize_preprocessor,
wrap_iterable_dataset_for_preprocessing,
)
from finetrainers.utils import find_files

from .utils import create_dummy_directory_structure


class PreprocessorFastTests(unittest.TestCase):
    """Tests for the in-memory and precomputed distributed data preprocessors.

    A tiny on-disk dataset of 3 dummy video/caption pairs is built in ``setUp``;
    each test drives one condition/latent pass through a preprocessor and checks
    the processor functions were applied and the data-exhaustion flag behaves as
    expected. The four consume variants share one driver, ``_run_consume_test``.
    """

    def setUp(self):
        self.rank = 0
        self.num_items = 3
        self.processor_fn = {
            "latent": self._latent_processor_fn,
            "condition": self._condition_processor_fn,
        }
        self.save_dir = tempfile.TemporaryDirectory()

        # Three video files with matching per-file caption .txt files — the
        # layout VideoCaptionFilePairDataset expects.
        directory_structure = [
            "0.mp4",
            "1.mp4",
            "2.mp4",
            "0.txt",
            "1.txt",
            "2.txt",
        ]
        create_dummy_directory_structure(
            directory_structure, self.save_dir, self.num_items, "a cat ruling the world", "mp4"
        )

        dataset = VideoCaptionFilePairDataset(self.save_dir.name, infinite=True)
        dataset = wrap_iterable_dataset_for_preprocessing(
            dataset,
            dataset_type="video",
            config={
                "video_resolution_buckets": [[2, 32, 32]],
                "reshape_mode": "bicubic",
            },
        )
        self.dataset = dataset

    def tearDown(self):
        self.save_dir.cleanup()

    @staticmethod
    def _latent_processor_fn(**data):
        # Crop the video to 16x16 so tests can assert on the output shape.
        video = data["video"]
        video = video[:, :, :16, :16]
        data["video"] = video
        return data

    @staticmethod
    def _condition_processor_fn(**data):
        # Append a marker suffix so tests can assert the caption was processed.
        caption = data["caption"]
        caption = caption + " surrounded by mystical aura"
        data["caption"] = caption
        return data

    def _run_consume_test(self, *, enable_precomputation: bool, consume_once: bool) -> None:
        """Drive one condition/latent pass and verify outputs.

        Args:
            enable_precomputation: Select the precomputed (on-disk) preprocessor
                instead of the in-memory one.
            consume_once: Use ``consume_once`` (data is retained, so
                ``requires_data`` stays False afterwards) instead of ``consume``
                (data is dropped, so ``requires_data`` becomes True after one
                full epoch).
        """
        data_iterator = iter(self.dataset)
        preprocessor = initialize_preprocessor(
            self.rank, self.num_items, self.processor_fn, self.save_dir.name, enable_precomputation=enable_precomputation
        )
        consume_fn = preprocessor.consume_once if consume_once else preprocessor.consume

        condition_iterator = consume_fn(
            "condition", components={}, data_iterator=data_iterator, cache_samples=True
        )
        latent_iterator = consume_fn(
            "latent", components={}, data_iterator=data_iterator, use_cached_samples=True, drop_samples=True
        )

        if enable_precomputation:
            # The precomputed preprocessor should have serialized every item to disk.
            condition_file_list = find_files(self.save_dir.name, "condition")
            latent_file_list = find_files(self.save_dir.name, "latent")
            self.assertEqual(len(condition_file_list), self.num_items)
            self.assertEqual(len(latent_file_list), self.num_items)

        self.assertFalse(preprocessor.requires_data)
        for _ in range(self.num_items):
            condition_item = next(condition_iterator)
            latent_item = next(latent_iterator)
            self.assertIn("caption", condition_item)
            self.assertIn("video", latent_item)
            self.assertEqual(condition_item["caption"], "a cat ruling the world surrounded by mystical aura")
            self.assertEqual(latent_item["video"].shape[-2:], (16, 16))
        # consume() exhausts its data after one epoch; consume_once() keeps it.
        self.assertEqual(bool(preprocessor.requires_data), not consume_once)

    def test_initialize_preprocessor(self):
        """The factory returns the preprocessor class matching the flag."""
        preprocessor = initialize_preprocessor(
            self.rank, self.num_items, self.processor_fn, self.save_dir.name, enable_precomputation=False
        )
        self.assertIsInstance(preprocessor, InMemoryDistributedDataPreprocessor)

        preprocessor = initialize_preprocessor(
            self.rank, self.num_items, self.processor_fn, self.save_dir.name, enable_precomputation=True
        )
        self.assertIsInstance(preprocessor, PrecomputedDistributedDataPreprocessor)

    def test_in_memory_preprocessor_consume(self):
        self._run_consume_test(enable_precomputation=False, consume_once=False)

    def test_in_memory_preprocessor_consume_once(self):
        self._run_consume_test(enable_precomputation=False, consume_once=True)

    def test_precomputed_preprocessor_consume(self):
        self._run_consume_test(enable_precomputation=True, consume_once=False)

    def test_precomputed_preprocessor_consume_once(self):
        self._run_consume_test(enable_precomputation=True, consume_once=True)
53 changes: 53 additions & 0 deletions tests/data/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pathlib
from typing import List

from diffusers.utils import export_to_video
from PIL import Image

from finetrainers.data.dataset import COMMON_CAPTION_FILES, COMMON_IMAGE_FILES, COMMON_VIDEO_FILES # noqa


def create_dummy_directory_structure(
    directory_structure: List[str], tmpdir, num_data_files: int, caption: str, metadata_extension: str
):
    """Populate *tmpdir* with a dummy dataset described by *directory_structure*.

    Each entry is interpreted from its name: known caption/image/video list
    files get one line per data file, ``metadata.csv``/``metadata.jsonl`` get
    one record per data file, ``.txt`` files receive the caption, image and
    video files are written as tiny dummy media, and any other entry is
    created as a (possibly nested) directory.
    """
    root = pathlib.Path(tmpdir.name)
    for entry in directory_structure:
        # TODO(aryan): this should be improved
        target = root / entry
        if entry in COMMON_CAPTION_FILES:
            target.write_text("".join(f"{caption}\n" for _ in range(num_data_files)))
        elif entry in COMMON_IMAGE_FILES:
            target.write_text("".join(f"images/{index}.jpg\n" for index in range(num_data_files)))
        elif entry in COMMON_VIDEO_FILES:
            target.write_text("".join(f"videos/{index}.mp4\n" for index in range(num_data_files)))
        elif entry == "metadata.csv":
            rows = [f"{index}.{metadata_extension},{caption}\n" for index in range(num_data_files)]
            target.write_text("file_name,caption\n" + "".join(rows))
        elif entry == "metadata.jsonl":
            records = [
                f'{{"file_name": "{index}.{metadata_extension}", "caption": "{caption}"}}\n'
                for index in range(num_data_files)
            ]
            target.write_text("".join(records))
        elif entry.endswith(".txt"):
            target.write_text(caption)
        elif entry.endswith((".jpg", ".png")):
            Image.new("RGB", (64, 64)).save(target.as_posix())
        elif entry.endswith(".mp4"):
            export_to_video([Image.new("RGB", (64, 64))] * 4, target.as_posix(), fps=2)
        else:
            target.mkdir(exist_ok=True, parents=True)
Loading

0 comments on commit affcf24

Please sign in to comment.