Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
a-r-r-o-w committed Mar 6, 2025
1 parent 54abf01 commit a803b4f
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 3 deletions.
20 changes: 19 additions & 1 deletion finetrainers/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,9 +757,27 @@ def _initialize_hub_dataset(dataset_name: str, dataset_type: str, infinite: bool
return _initialize_data_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)
elif _has_data_file_caption_file_lists(repo_file_list, remote=True):
return _initialize_data_file_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)
else:

has_tar_files = any(file.endswith(".tar") for file in repo_file_list)
if has_tar_files:
return _initialize_webdataset(dataset_name, dataset_type, infinite)

# TODO(aryan): handle parquet
# TODO(aryan): This should be improved
caption_files = [pathlib.Path(file).name for file in repo_file_list if file.endswith(".txt")]
if len(caption_files) < MAX_PRECOMPUTABLE_ITEMS_LIMIT:
try:
dataset_root = snapshot_download(dataset_name, repo_type="dataset")
if dataset_type == "image":
dataset = ImageFolderDataset(dataset_root, infinite=infinite)
else:
dataset = VideoFolderDataset(dataset_root, infinite=infinite)
return dataset
except Exception:
pass

raise ValueError(f"Could not load dataset {dataset_name} from the HF Hub")


def _initialize_data_caption_file_dataset_from_hub(
dataset_name: str, dataset_type: str, infinite: bool = False
Expand Down
2 changes: 1 addition & 1 deletion finetrainers/functional/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def resize_crop_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tenso


def bicubic_resize_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
return F.interpolate(image, size=size, mode="bicubic", align_corners=False)
return F.interpolate(image.unsqueeze(0), size=size, mode="bicubic", align_corners=False)[0]


def find_nearest_resolution_image(image: torch.Tensor, resolution_buckets: List[Tuple[int, int]]) -> Tuple[int, int]:
Expand Down
3 changes: 2 additions & 1 deletion finetrainers/models/cogview4/base_specification.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,8 @@ def forward(
)[0]
target = FF.flow_match_target(noise, latents)

return pred, target, sigmas
# return pred, target, sigmas
return pred, target, shifted_sigmas

def validation(
self,
Expand Down

0 comments on commit a803b4f

Please sign in to comment.