Skip to content

Commit

Permalink
use image-to-video pipeline for ltx if image is provided
Browse files Browse the repository at this point in the history
  • Loading branch information
a-r-r-o-w committed Jan 25, 2025
1 parent 37156af commit 3f8c097
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
10 changes: 10 additions & 0 deletions finetrainers/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1195,6 +1195,16 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
validation_prompts = args.validation_prompts.split(args.validation_separator) if args.validation_prompts else []
validation_images = args.validation_images.split(args.validation_separator) if args.validation_images else None
validation_videos = args.validation_videos.split(args.validation_separator) if args.validation_videos else None
validation_images = (
[None if len(image.strip()) == 0 else image.strip() for image in validation_images]
if validation_images
else None
)
validation_videos = (
[None if len(video.strip()) == 0 else video.strip() for video in validation_videos]
if validation_videos
else None
)
stripped_validation_prompts = []
validation_heights = []
validation_widths = []
Expand Down
14 changes: 13 additions & 1 deletion finetrainers/models/ltx_video/specification_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,16 @@
from typing import Any, Dict, List, Optional, Tuple

import torch
from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXVideoTransformer3DModel
from diffusers import (
AutoencoderKLLTXVideo,
FlowMatchEulerDiscreteScheduler,
LTXImageToVideoPipeline,
LTXPipeline,
LTXVideoTransformer3DModel,
)
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils.import_utils import is_torch_version
from PIL.Image import Image
from transformers import AutoModel, AutoTokenizer, T5EncoderModel, T5Tokenizer

from ... import functional as FF
Expand Down Expand Up @@ -313,6 +320,7 @@ def validation(
self,
pipeline: LTXPipeline,
prompt: str,
image: Optional[Image] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_frames: Optional[int] = None,
Expand All @@ -322,8 +330,12 @@ def validation(
*args,
**kwargs,
) -> List[Tuple[str, torch.Tensor]]:
if image is not None:
pipeline = LTXImageToVideoPipeline.from_pipe(pipeline)

generation_kwargs = {
"prompt": prompt,
"image": image,
"height": height,
"width": width,
"num_frames": num_frames,
Expand Down

0 comments on commit 3f8c097

Please sign in to comment.