Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
sayakpaul committed Feb 24, 2025
1 parent ef42b05 commit c745d22
Show file tree
Hide file tree
Showing 11 changed files with 25 additions and 23 deletions.
8 changes: 4 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
check_dirs := finetrainers tests examples

quality:
ruff check $(check_dirs)
ruff format --check $(check_dirs)
ruff check $(check_dirs) --exclude examples/_legacy
ruff format --check $(check_dirs) --exclude examples/_legacy

style:
ruff check $(check_dirs) --fix
ruff format $(check_dirs)
ruff check $(check_dirs) --fix --exclude examples/_legacy
ruff format $(check_dirs) --exclude examples/_legacy
2 changes: 1 addition & 1 deletion examples/_legacy/cogvideox/cogvideox_image_to_video_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from diffusers.optimization import get_scheduler
from diffusers.training_utils import cast_training_params
from diffusers.utils import convert_unet_state_dict_to_peft, export_to_video, load_image
from diffusers.utils import export_to_video, load_image
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from huggingface_hub import create_repo, upload_folder
from torch.utils.data import DataLoader
Expand Down
2 changes: 1 addition & 1 deletion examples/_legacy/cogvideox/cogvideox_text_to_video_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
print_memory,
reset_memory,
unwrap_model,
) # isort:skip
)


logger = get_logger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion examples/_legacy/cogvideox/cogvideox_text_to_video_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
print_memory,
reset_memory,
unwrap_model,
) # isort:skip
)


logger = get_logger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion examples/_legacy/cogvideox/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,4 +425,4 @@ def __iter__(self):
random.shuffle(bucket)
yield bucket
del self.buckets[fhw]
self.buckets[fhw] = []
self.buckets[fhw] = []
2 changes: 1 addition & 1 deletion examples/_legacy/mochi-1/args.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Default values taken from
Default values taken from
https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner/configs/lora.yaml
when applicable.
"""
Expand Down
2 changes: 1 addition & 1 deletion examples/_legacy/mochi-1/dataset_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __getitem__(self, idx):
def process_videos(directory):
dir_path = Path(directory)
mp4_files = [str(f) for f in dir_path.glob("**/*.mp4") if not f.name.endswith(".recon.mp4")]
assert mp4_files, f"No mp4 files found"
assert mp4_files, "No mp4 files found"

dataset = LatentEmbedDataset(mp4_files)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
Expand Down
5 changes: 3 additions & 2 deletions examples/_legacy/mochi-1/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/embed_captions.py
"""

from pathlib import Path

import click
import torch
import torchvision
from pathlib import Path
from diffusers import AutoencoderKLMochi, MochiPipeline
from transformers import T5EncoderModel, T5Tokenizer
from tqdm.auto import tqdm
from transformers import T5EncoderModel, T5Tokenizer


def encode_videos(model: torch.nn.Module, vid_path: Path, shape: str):
Expand Down
17 changes: 9 additions & 8 deletions examples/_legacy/mochi-1/text_to_video_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@
# limitations under the License.

import gc
import random
from glob import glob
import math
import os
import torch.nn.functional as F
import numpy as np
import random
from glob import glob
from pathlib import Path
from typing import Any, Dict, Tuple, List
from typing import Any, Dict, List, Tuple

import numpy as np
import torch
import torch.nn.functional as F
import wandb
from diffusers import FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
Expand All @@ -37,9 +37,10 @@


from args import get_args # isort:skip

from dataset_simple import LatentEmbedDataset

import sys

from utils import print_memory, reset_memory # isort:skip


Expand Down Expand Up @@ -100,7 +101,7 @@ def save_model_card(
```py
from diffusers import MochiPipeline
from diffusers.utils import export_to_video
import torch
import torch
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview")
pipe.load_lora_weights("CHANGE_ME")
Expand Down Expand Up @@ -315,7 +316,7 @@ def main(args):
optimizer = torch.optim.AdamW(transformer_lora_parameters, lr=args.learning_rate, weight_decay=args.weight_decay)

# Dataset and DataLoader
train_vids = list(sorted(glob(f"{args.data_root}/*.mp4")))
train_vids = sorted(glob(f"{args.data_root}/*.mp4"))
train_vids = [v for v in train_vids if not v.endswith(".recon.mp4")]
print(f"Found {len(train_vids)} training videos in {args.data_root}")
assert len(train_vids) > 0, f"No training data found in {args.data_root}"
Expand Down
2 changes: 1 addition & 1 deletion examples/_legacy/mochi-1/trim_and_crop_videos.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/trim_and_crop_videos.py
"""

from pathlib import Path
import shutil
from pathlib import Path

import click
from moviepy.editor import VideoFileClip
Expand Down
4 changes: 2 additions & 2 deletions examples/_legacy/mochi-1/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import gc
import inspect
from typing import Optional, Tuple, Union
from typing import Union

import torch


logger = get_logger(__name__)

def reset_memory(device: Union[str, torch.device]) -> None:
Expand Down

0 comments on commit c745d22

Please sign in to comment.