From a803b4f4ceaadef15b1440de1af0e65f9cd747ee Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 6 Mar 2025 10:31:54 +0100 Subject: [PATCH] update --- finetrainers/data/dataset.py | 20 ++++++++++++++++++- finetrainers/functional/image.py | 2 +- .../models/cogview4/base_specification.py | 3 ++- 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/finetrainers/data/dataset.py b/finetrainers/data/dataset.py index 289dc1e..27de878 100644 --- a/finetrainers/data/dataset.py +++ b/finetrainers/data/dataset.py @@ -757,9 +757,27 @@ def _initialize_hub_dataset(dataset_name: str, dataset_type: str, infinite: bool return _initialize_data_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite) elif _has_data_file_caption_file_lists(repo_file_list, remote=True): return _initialize_data_file_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite) - else: + + has_tar_files = any(file.endswith(".tar") for file in repo_file_list) + if has_tar_files: return _initialize_webdataset(dataset_name, dataset_type, infinite) + # TODO(aryan): handle parquet + # TODO(aryan): This should be improved + caption_files = [pathlib.Path(file).name for file in repo_file_list if file.endswith(".txt")] + if len(caption_files) < MAX_PRECOMPUTABLE_ITEMS_LIMIT: + try: + dataset_root = snapshot_download(dataset_name, repo_type="dataset") + if dataset_type == "image": + dataset = ImageFolderDataset(dataset_root, infinite=infinite) + else: + dataset = VideoFolderDataset(dataset_root, infinite=infinite) + return dataset + except Exception: + pass + + raise ValueError(f"Could not load dataset {dataset_name} from the HF Hub") + def _initialize_data_caption_file_dataset_from_hub( dataset_name: str, dataset_type: str, infinite: bool = False diff --git a/finetrainers/functional/image.py b/finetrainers/functional/image.py index 8b644e4..8d96662 100644 --- a/finetrainers/functional/image.py +++ b/finetrainers/functional/image.py @@ -22,7 +22,7 @@ def resize_crop_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tenso def bicubic_resize_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor: - return F.interpolate(image, size=size, mode="bicubic", align_corners=False) + return F.interpolate(image.unsqueeze(0), size=size, mode="bicubic", align_corners=False)[0] def find_nearest_resolution_image(image: torch.Tensor, resolution_buckets: List[Tuple[int, int]]) -> Tuple[int, int]: diff --git a/finetrainers/models/cogview4/base_specification.py b/finetrainers/models/cogview4/base_specification.py index cdd6da9..7ac023b 100644 --- a/finetrainers/models/cogview4/base_specification.py +++ b/finetrainers/models/cogview4/base_specification.py @@ -335,7 +335,8 @@ def forward( )[0] target = FF.flow_match_target(noise, latents) - return pred, target, sigmas + # return pred, target, sigmas + return pred, target, shifted_sigmas def validation( self,