diff --git a/examples/training/sft/hunyuan_video/modal_labs_dissolve/train.sh b/examples/training/sft/hunyuan_video/modal_labs_dissolve/train.sh new file mode 100755 index 00000000..87ca30b0 --- /dev/null +++ b/examples/training/sft/hunyuan_video/modal_labs_dissolve/train.sh @@ -0,0 +1,158 @@ +#!/bin/bash + +set -e -x + +# export TORCH_LOGS="+dynamo,recompiles,graph_breaks" +# export TORCHDYNAMO_VERBOSE=1 +export WANDB_MODE="offline" +export NCCL_P2P_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 +export FINETRAINERS_LOG_LEVEL="DEBUG" + +# Finetrainers supports multiple backends for distributed training. Select your favourite and benchmark the differences! +# BACKEND="accelerate" +BACKEND="ptd" + +# In this setting, I'm using 2 GPUs on a 4-GPU node for training +NUM_GPUS=8 +CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + +# Check the JSON files for the expected JSON format +TRAINING_DATASET_CONFIG="examples/training/sft/hunyuan_video/modal_labs_dissolve/training.json" +VALIDATION_DATASET_FILE="examples/training/sft/hunyuan_video/modal_labs_dissolve/validation.json" + +# Depending on how many GPUs you have available, choose your degree of parallelism and technique! +DDP_1="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 1 --cp_degree 1 --tp_degree 1" +DDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 4 --dp_shards 1 --cp_degree 1 --tp_degree 1" +FSDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 2 --cp_degree 1 --tp_degree 1" +FSDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 4 --cp_degree 1 --tp_degree 1" +HSDP_2_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 2 --cp_degree 1 --tp_degree 1" +HSDP_4_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 4 --dp_shards 2 --cp_degree 1 --tp_degree 1" + +# Parallel arguments +parallel_cmd=( + $HSDP_4_2 +) + +# Model arguments +model_cmd=( + --model_name "hunyuan_video" + --pretrained_model_name_or_path "hunyuanvideo-community/HunyuanVideo" +) + +# Dataset arguments +dataset_cmd=( + --dataset_config $TRAINING_DATASET_CONFIG + --dataset_shuffle_buffer_size 10 + --precomputation_items 10 + --precomputation_once +) + +# Dataloader arguments +dataloader_cmd=( + --dataloader_num_workers 0 +) + +# Diffusion arguments +diffusion_cmd=( + --flow_weighting_scheme "logit_normal" +) + +# Training arguments +# We target just the attention projections layers for LoRA training here. 
+# You can modify as you please and target any layer (regex is supported) +training_cmd=( + --training_type "lora" + --seed 42 + --batch_size 1 + --train_steps 3000 + --rank 32 + --lora_alpha 32 + --target_modules "(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0|add_q_proj|add_k_proj|add_v_proj|to_add_out)" + --gradient_accumulation_steps 1 + --gradient_checkpointing + --checkpointing_steps 500 + --checkpointing_limit 2 + # --resume_from_checkpoint 3000 + --enable_slicing + --enable_tiling +) + +# Optimizer arguments +optimizer_cmd=( + --optimizer "adamw" + --lr 3e-5 + --lr_scheduler "constant_with_warmup" + --lr_warmup_steps 1000 + --lr_num_cycles 1 + --beta1 0.9 + --beta2 0.99 + --weight_decay 1e-4 + --epsilon 1e-8 + --max_grad_norm 1.0 +) + +# Validation arguments +validation_cmd=( + --validation_dataset_file "$VALIDATION_DATASET_FILE" + --validation_steps 500 +) + +# Miscellaneous arguments +miscellaneous_cmd=( + --tracker_name "finetrainers-hunyuanvideo" + --output_dir "/fsx/aryan/lora-training/hunyuanvideo" + --init_timeout 600 + --nccl_timeout 600 + --report_to "wandb" +) + +# Execute the training script +if [ "$BACKEND" == "accelerate" ]; then + + ACCELERATE_CONFIG_FILE="" + if [ "$NUM_GPUS" == 1 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml" + elif [ "$NUM_GPUS" == 2 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_2.yaml" + elif [ "$NUM_GPUS" == 4 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_4.yaml" + elif [ "$NUM_GPUS" == 8 ]; then + ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_8.yaml" + fi + + accelerate launch --config_file "$ACCELERATE_CONFIG_FILE" --gpu_ids $CUDA_VISIBLE_DEVICES train.py \ + "${parallel_cmd[@]}" \ + "${model_cmd[@]}" \ + "${dataset_cmd[@]}" \ + "${dataloader_cmd[@]}" \ + "${diffusion_cmd[@]}" \ + "${training_cmd[@]}" \ + "${optimizer_cmd[@]}" \ + "${validation_cmd[@]}" \ + "${miscellaneous_cmd[@]}" + +elif [ "$BACKEND" == "ptd" ]; then + + export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES + + torchrun \ + --standalone \ + --nnodes=1 \ + --nproc_per_node=$NUM_GPUS \ + --rdzv_backend c10d \ + --rdzv_endpoint="localhost:0" \ + train.py \ + "${parallel_cmd[@]}" \ + "${model_cmd[@]}" \ + "${dataset_cmd[@]}" \ + "${dataloader_cmd[@]}" \ + "${diffusion_cmd[@]}" \ + "${training_cmd[@]}" \ + "${optimizer_cmd[@]}" \ + "${validation_cmd[@]}" \ + "${miscellaneous_cmd[@]}" +fi + +echo -ne "-------------------- Finished executing script --------------------\n\n" diff --git a/examples/training/sft/hunyuan_video/modal_labs_dissolve/training.json b/examples/training/sft/hunyuan_video/modal_labs_dissolve/training.json new file mode 100644 index 00000000..3d211b06 --- /dev/null +++ b/examples/training/sft/hunyuan_video/modal_labs_dissolve/training.json @@ -0,0 +1,24 @@ +{ + "datasets": [ + { + "data_root": "modal-labs/dissolve", + "dataset_type": "video", + "id_token": "MODAL_DISSOLVE", + "video_resolution_buckets": [ + [49, 480, 768] + ], + "reshape_mode": "bicubic", + "remove_common_llm_caption_prefixes": true + }, + { + "data_root": "modal-labs/dissolve", + "dataset_type": "video", + "id_token": "MODAL_DISSOLVE", + "video_resolution_buckets": [ + [81, 480, 768] + ], + "reshape_mode": "bicubic", + "remove_common_llm_caption_prefixes": true + } + ] +} \ No newline at end of file diff --git a/examples/training/sft/hunyuan_video/modal_labs_dissolve/validation.json b/examples/training/sft/hunyuan_video/modal_labs_dissolve/validation.json new file mode 100644 index 00000000..4e50723d --- 
/dev/null +++ b/examples/training/sft/hunyuan_video/modal_labs_dissolve/validation.json @@ -0,0 +1,76 @@ +{ + "data": [ + { + "caption": "MODAL_DISSOLVE A meticulously detailed, antique-style vase, featuring mottled beige and brown hues and two small handles, sits centrally on a dark brown circular pedestal. The vase, seemingly made of clay or porcelain, begins to dissolve from the bottom up. The disintegration process is rapid but not explosive, with a cloud of fine, light tan dust forming and rising in a swirling, almost ethereal column that expands outwards before slowly descending. The dust particles are individually visible as they float, and the overall effect is one of delicate disintegration rather than shattering. Finally, only the empty pedestal and the intricately patterned marble floor remain.", + "image_path": null, + "video_path": null, + "num_inference_steps": 30, + "height": 480, + "width": 768, + "num_frames": 49 + }, + { + "caption": "MODAL_DISSOLVE Close-up view of a sloth resting on a thick tree branch within a dense, sun-dappled forest. The sloth's body, initially clearly defined, begins to subtly disintegrate. The process starts with a light dusting of particles from its lower back and rump. This quickly intensifies, with a visible cloud of fine, sparkling dust billowing outwards as the sloth's form gradually vanishes. The dissolution proceeds in a wave-like manner, moving from rear to front. The head and arms are the last parts to disappear, leaving only scattered motes of dust that slowly disperse amongst the leaves, blending seamlessly with the forest environment. The overall effect is dreamlike and ethereal.", + "image_path": null, + "video_path": null, + "num_inference_steps": 30, + "height": 480, + "width": 768, + "num_frames": 49 + }, + { + "caption": "MODAL_DISSOLVE High-resolution video depicting the complete digital dissolution of an orange Porsche 911 GT3 RS within a garage environment. The car's dissolution proceeds in three discernible stages: (1) Initial shimmering along the car's edges and body panels, creating a subtle, high-frequency displacement effect. (2) Rapid disintegration of the vehicle into a dense cloud of primarily orange and black particles, varying in size and opacity; particle motion exhibits both outward and swirling movements. (3) Complete disappearance of the car, leaving behind only a remaining, smaller, seemingly fiery-textured rubber duck model. The overall effect resembles a controlled explosion or rapid combustion, creating a dynamic, visually complex transformation. The garage's lighting and shadows remain consistent throughout the dissolution, providing clear visual contrast.", + "image_path": null, + "video_path": null, + "num_inference_steps": 30, + "height": 480, + "width": 768, + "num_frames": 49 + }, + { + "caption": "MODAL_DISSOLVE High-resolution video depicting the complete disintegration of a white origami crane. The disintegration process is initiated at the head of the crane and proceeds in a generally downward direction. The disintegration manifests as the rapid breakdown of paper fibers into a cloud of fine particulate matter. The particle size appears consistent, with a texture similar to very fine powder. The rate of disintegration increases over time, resulting in a visually dynamic and texturally complex effect. The background consists of a dark-stained wooden surface, providing a high-contrast setting that highlights the white particles' dispersal and movement. 
The final state shows only residual particulate matter scattered sparsely on the surface.", + "image_path": null, + "video_path": null, + "num_inference_steps": 30, + "height": 480, + "width": 768, + "num_frames": 49 + }, + { + "caption": "MODAL_DISSOLVE A meticulously detailed, antique-style vase, featuring mottled beige and brown hues and two small handles, sits centrally on a dark brown circular pedestal. The vase, seemingly made of clay or porcelain, begins to dissolve from the bottom up. The disintegration process is rapid but not explosive, with a cloud of fine, light tan dust forming and rising in a swirling, almost ethereal column that expands outwards before slowly descending. The dust particles are individually visible as they float, and the overall effect is one of delicate disintegration rather than shattering. Finally, only the empty pedestal and the intricately patterned marble floor remain.", + "image_path": null, + "video_path": null, + "num_inference_steps": 30, + "height": 480, + "width": 768, + "num_frames": 81 + }, + { + "caption": "MODAL_DISSOLVE Close-up view of a sloth resting on a thick tree branch within a dense, sun-dappled forest. The sloth's body, initially clearly defined, begins to subtly disintegrate. The process starts with a light dusting of particles from its lower back and rump. This quickly intensifies, with a visible cloud of fine, sparkling dust billowing outwards as the sloth's form gradually vanishes. The dissolution proceeds in a wave-like manner, moving from rear to front. The head and arms are the last parts to disappear, leaving only scattered motes of dust that slowly disperse amongst the leaves, blending seamlessly with the forest environment. The overall effect is dreamlike and ethereal.", + "image_path": null, + "video_path": null, + "num_inference_steps": 30, + "height": 480, + "width": 768, + "num_frames": 81 + }, + { + "caption": "MODAL_DISSOLVE High-resolution video depicting the complete digital dissolution of an orange Porsche 911 GT3 RS within a garage environment. The car's dissolution proceeds in three discernible stages: (1) Initial shimmering along the car's edges and body panels, creating a subtle, high-frequency displacement effect. (2) Rapid disintegration of the vehicle into a dense cloud of primarily orange and black particles, varying in size and opacity; particle motion exhibits both outward and swirling movements. (3) Complete disappearance of the car, leaving behind only a remaining, smaller, seemingly fiery-textured rubber duck model. The overall effect resembles a controlled explosion or rapid combustion, creating a dynamic, visually complex transformation. The garage's lighting and shadows remain consistent throughout the dissolution, providing clear visual contrast.", + "image_path": null, + "video_path": null, + "num_inference_steps": 30, + "height": 480, + "width": 768, + "num_frames": 81 + }, + { + "caption": "MODAL_DISSOLVE High-resolution video depicting the complete disintegration of a white origami crane. The disintegration process is initiated at the head of the crane and proceeds in a generally downward direction. The disintegration manifests as the rapid breakdown of paper fibers into a cloud of fine particulate matter. The particle size appears consistent, with a texture similar to very fine powder. The rate of disintegration increases over time, resulting in a visually dynamic and texturally complex effect. 
The background consists of a dark-stained wooden surface, providing a high-contrast setting that highlights the white particles' dispersal and movement. The final state shows only residual particulate matter scattered sparsely on the surface.", + "image_path": null, + "video_path": null, + "num_inference_steps": 30, + "height": 480, + "width": 768, + "num_frames": 81 + } + ] +} diff --git a/finetrainers/config.py b/finetrainers/config.py index 38936773..e0cda5e8 100644 --- a/finetrainers/config.py +++ b/finetrainers/config.py @@ -3,15 +3,15 @@ from .models import ModelSpecification from .models.cogvideox import CogVideoXModelSpecification -from .models.hunyuan_video import HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG, HUNYUAN_VIDEO_T2V_LORA_CONFIG +from .models.hunyuan_video import HunyuanVideoModelSpecification from .models.ltx_video import LTXVideoModelSpecification from .models.wan import WanModelSpecification class ModelType(str, Enum): + COGVIDEOX = "cogvideox" HUNYUAN_VIDEO = "hunyuan_video" LTX_VIDEO = "ltx_video" - COGVIDEOX = "cogvideox" WAN = "wan" @@ -22,8 +22,8 @@ class TrainingType(str, Enum): SUPPORTED_MODEL_CONFIGS = { ModelType.HUNYUAN_VIDEO: { - TrainingType.LORA: HUNYUAN_VIDEO_T2V_LORA_CONFIG, - TrainingType.FULL_FINETUNE: HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG, + TrainingType.LORA: HunyuanVideoModelSpecification, + TrainingType.FULL_FINETUNE: HunyuanVideoModelSpecification, }, ModelType.LTX_VIDEO: { TrainingType.LORA: LTXVideoModelSpecification, diff --git a/finetrainers/data/dataset.py b/finetrainers/data/dataset.py index a319f9ab..672bc113 100644 --- a/finetrainers/data/dataset.py +++ b/finetrainers/data/dataset.py @@ -801,8 +801,9 @@ def _has_data_caption_file_pairs(root: Union[pathlib.Path, List[str]], remote: b else: caption_files = [file for file in root if file.endswith(".txt")] for caption_file in caption_files: + caption_file = pathlib.Path(caption_file) for extension in [*constants.SUPPORTED_IMAGE_FILE_EXTENSIONS, *constants.SUPPORTED_VIDEO_FILE_EXTENSIONS]: - data_filename = caption_file.with_suffix(f".{extension}") + data_filename = caption_file.with_suffix(f".{extension}").name if data_filename in root: return True return False diff --git a/finetrainers/models/hunyuan_video/__init__.py b/finetrainers/models/hunyuan_video/__init__.py index 8ac729e9..518a4286 100644 --- a/finetrainers/models/hunyuan_video/__init__.py +++ b/finetrainers/models/hunyuan_video/__init__.py @@ -1,2 +1 @@ -from .full_finetune import HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG -from .lora import HUNYUAN_VIDEO_T2V_LORA_CONFIG +from .base_specification import HunyuanVideoModelSpecification diff --git a/finetrainers/models/hunyuan_video/base_specification.py b/finetrainers/models/hunyuan_video/base_specification.py new file mode 100644 index 00000000..1cd73fce --- /dev/null +++ b/finetrainers/models/hunyuan_video/base_specification.py @@ -0,0 +1,413 @@ +import os +from typing import Any, Dict, List, Optional, Tuple + +import torch +from accelerate import init_empty_weights +from diffusers import ( + AutoencoderKLHunyuanVideo, + FlowMatchEulerDiscreteScheduler, + HunyuanVideoPipeline, + HunyuanVideoTransformer3DModel, +) +from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution +from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, LlamaModel + +from ... import data +from ... 
import functional as FF +from ...logging import get_logger +from ...processors import CLIPPooledProcessor, LlamaProcessor, ProcessorMixin +from ...typing import ArtifactType, SchedulerType +from ...utils import get_non_null_items +from ..modeling_utils import ModelSpecification + + +logger = get_logger() + + +class HunyuanLatentEncodeProcessor(ProcessorMixin): + r""" + Processor to encode image/video into latents using the HunyuanVideo VAE. + + Args: + output_names (`List[str]`): + The names of the outputs that the processor returns. The outputs are in the following order: + - latents: The latents of the input image/video. + """ + + def __init__(self, output_names: List[str]): + super().__init__() + self.output_names = output_names + assert len(self.output_names) == 1 + + def forward( + self, + vae: AutoencoderKLHunyuanVideo, + image: Optional[torch.Tensor] = None, + video: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + ) -> Dict[str, torch.Tensor]: + device = vae.device + dtype = vae.dtype + + if image is not None: + video = image.unsqueeze(1) + + assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor" + video = video.to(device=device, dtype=vae.dtype) + video = video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W] + + if compute_posterior: + latents = vae.encode(video).latent_dist.sample(generator=generator) + latents = latents.to(dtype=dtype) + else: + if vae.use_slicing and video.shape[0] > 1: + encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)] + moments = torch.cat(encoded_slices) + else: + moments = vae._encode(video) + latents = moments.to(dtype=dtype) + + return {self.output_names[0]: latents} + + +class HunyuanVideoModelSpecification(ModelSpecification): + def __init__( + self, + pretrained_model_name_or_path: str = "hunyuanvideo-community/HunyuanVideo", + tokenizer_id: Optional[str] = None, + text_encoder_id: Optional[str] = None, + transformer_id: Optional[str] = None, + vae_id: Optional[str] = None, + text_encoder_dtype: torch.dtype = torch.bfloat16, + transformer_dtype: torch.dtype = torch.bfloat16, + vae_dtype: torch.dtype = torch.bfloat16, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, + condition_model_processors: List[ProcessorMixin] = None, + latent_model_processors: List[ProcessorMixin] = None, + **kwargs, + ) -> None: + super().__init__( + pretrained_model_name_or_path=pretrained_model_name_or_path, + tokenizer_id=tokenizer_id, + text_encoder_id=text_encoder_id, + transformer_id=transformer_id, + vae_id=vae_id, + text_encoder_dtype=text_encoder_dtype, + transformer_dtype=transformer_dtype, + vae_dtype=vae_dtype, + revision=revision, + cache_dir=cache_dir, + ) + + if condition_model_processors is None: + condition_model_processors = [ + LlamaProcessor(["encoder_hidden_states", "encoder_attention_mask"]), + CLIPPooledProcessor( + ["pooled_projections"], + input_names={"tokenizer_2": "tokenizer", "text_encoder_2": "text_encoder"}, + ), + ] + if latent_model_processors is None: + latent_model_processors = [HunyuanLatentEncodeProcessor(["latents"])] + + self.condition_model_processors = condition_model_processors + self.latent_model_processors = latent_model_processors + + @property + def _resolution_dim_keys(self): + # TODO + return { + "latents": (2, 3, 4), + } + + def load_condition_models(self) -> Dict[str, torch.nn.Module]: + if self.tokenizer_id is not None: + tokenizer = AutoTokenizer.from_pretrained( + 
self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir + ) + else: + tokenizer = AutoTokenizer.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=self.revision, + cache_dir=self.cache_dir, + ) + + if self.tokenizer_2_id is not None: + tokenizer_2 = CLIPTokenizer.from_pretrained( + self.tokenizer_2_id, revision=self.revision, cache_dir=self.cache_dir + ) + else: + tokenizer_2 = CLIPTokenizer.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="tokenizer_2", + revision=self.revision, + cache_dir=self.cache_dir, + ) + + if self.text_encoder_id is not None: + text_encoder = LlamaModel.from_pretrained( + self.text_encoder_id, + torch_dtype=self.text_encoder_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + else: + text_encoder = LlamaModel.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="text_encoder", + torch_dtype=self.text_encoder_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + + if self.text_encoder_2_id is not None: + text_encoder_2 = CLIPTextModel.from_pretrained( + self.text_encoder_2_id, + torch_dtype=self.text_encoder_2_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + else: + text_encoder_2 = CLIPTextModel.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="text_encoder_2", + torch_dtype=self.text_encoder_2_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + + return { + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + } + + def load_latent_models(self) -> Dict[str, torch.nn.Module]: + if self.vae_id is not None: + vae = AutoencoderKLHunyuanVideo.from_pretrained( + self.vae_id, + torch_dtype=self.vae_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + else: + vae = AutoencoderKLHunyuanVideo.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="vae", + torch_dtype=self.vae_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + + return {"vae": vae} + + def load_diffusion_models(self) -> Dict[str, torch.nn.Module]: + if self.transformer_id is not None: + transformer = HunyuanVideoTransformer3DModel.from_pretrained( + self.transformer_id, + torch_dtype=self.transformer_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + else: + transformer = HunyuanVideoTransformer3DModel.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=self.transformer_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return {"transformer": transformer, "scheduler": scheduler} + + def load_pipeline( + self, + tokenizer: Optional[AutoTokenizer] = None, + tokenizer_2: Optional[CLIPTokenizer] = None, + text_encoder: Optional[LlamaModel] = None, + text_encoder_2: Optional[CLIPTextModel] = None, + transformer: Optional[HunyuanVideoTransformer3DModel] = None, + vae: Optional[AutoencoderKLHunyuanVideo] = None, + scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None, + enable_slicing: bool = False, + enable_tiling: bool = False, + enable_model_cpu_offload: bool = False, + training: bool = False, + **kwargs, + ) -> HunyuanVideoPipeline: + components = { + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + } + components = get_non_null_items(components) + + pipe = 
HunyuanVideoPipeline.from_pretrained( + self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir + ) + pipe.text_encoder.to(self.text_encoder_dtype) + pipe.text_encoder_2.to(self.text_encoder_2_dtype) + pipe.vae.to(self.vae_dtype) + + if not training: + pipe.transformer.to(self.transformer_dtype) + + if enable_slicing: + pipe.vae.enable_slicing() + if enable_tiling: + pipe.vae.enable_tiling() + if enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + + return pipe + + @torch.no_grad() + def prepare_conditions( + self, + tokenizer: AutoTokenizer, + tokenizer_2: CLIPTokenizer, + text_encoder: LlamaModel, + text_encoder_2: CLIPTextModel, + caption: str, + max_sequence_length: int = 256, + **kwargs, + ) -> Dict[str, Any]: + conditions = { + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "caption": caption, + "max_sequence_length": max_sequence_length, + **kwargs, + } + input_keys = set(conditions.keys()) + conditions = super().prepare_conditions(**conditions) + conditions = {k: v for k, v in conditions.items() if k not in input_keys} + return conditions + + @torch.no_grad() + def prepare_latents( + self, + vae: AutoencoderKLHunyuanVideo, + image: Optional[torch.Tensor] = None, + video: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + **kwargs, + ) -> Dict[str, torch.Tensor]: + conditions = { + "vae": vae, + "image": image, + "video": video, + "generator": generator, + "compute_posterior": compute_posterior, + **kwargs, + } + input_keys = set(conditions.keys()) + conditions = super().prepare_latents(**conditions) + conditions = {k: v for k, v in conditions.items() if k not in input_keys} + return conditions + + def forward( + self, + transformer: HunyuanVideoTransformer3DModel, + condition_model_conditions: Dict[str, torch.Tensor], + latent_model_conditions: Dict[str, torch.Tensor], + sigmas: torch.Tensor, + guidance: float = 1.0, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + **kwargs, + ) -> Tuple[torch.Tensor, ...]: + if compute_posterior: + latents = latent_model_conditions.pop("latents") + else: + posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents")) + latents = posterior.sample(generator=generator) + del posterior + + latents = latents * self.vae_config.scaling_factor + noise = torch.zeros_like(latents).normal_(generator=generator) + noisy_latents = FF.flow_match_xt(latents, noise, sigmas) + + timesteps = (sigmas.flatten() * 1000.0).long() + guidance = latents.new_full((latents.size(0),), fill_value=guidance) * 1000.0 + + latent_model_conditions["hidden_states"] = noisy_latents.to(latents) + latent_model_conditions["guidance"] = guidance + + pred = transformer( + **latent_model_conditions, + **condition_model_conditions, + timestep=timesteps, + return_dict=False, + )[0] + target = FF.flow_match_target(noise, latents) + + return pred, target, sigmas + + def validation( + self, + pipeline: HunyuanVideoPipeline, + prompt: str, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: Optional[int] = None, + num_inference_steps: int = 50, + generator: Optional[torch.Generator] = None, + **kwargs, + ) -> List[ArtifactType]: + generation_kwargs = { + "prompt": prompt, + "height": height, + "width": width, + "num_frames": num_frames, + "num_inference_steps": num_inference_steps, + "generator": generator, + "return_dict": True, 
+ "output_type": "pil", + } + generation_kwargs = get_non_null_items(generation_kwargs) + video = pipeline(**generation_kwargs).frames[0] + return [data.VideoArtifact(value=video)] + + def _save_lora_weights( + self, + directory: str, + transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, + scheduler: Optional[SchedulerType] = None, + *args, + **kwargs, + ) -> None: + # TODO(aryan): this needs refactoring + if transformer_state_dict is not None: + HunyuanVideoPipeline.save_lora_weights(directory, transformer_state_dict, safe_serialization=True) + if scheduler is not None: + scheduler.save_pretrained(os.path.join(directory, "scheduler")) + + def _save_model( + self, + directory: str, + transformer: HunyuanVideoTransformer3DModel, + transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, + scheduler: Optional[SchedulerType] = None, + ) -> None: + # TODO(aryan): this needs refactoring + if transformer_state_dict is not None: + with init_empty_weights(): + transformer_copy = HunyuanVideoTransformer3DModel.from_config(transformer.config) + transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True) + transformer_copy.save_pretrained(os.path.join(directory, "transformer")) + if scheduler is not None: + scheduler.save_pretrained(os.path.join(directory, "scheduler")) diff --git a/finetrainers/models/hunyuan_video/full_finetune.py b/finetrainers/models/hunyuan_video/full_finetune.py deleted file mode 100644 index 65e73f54..00000000 --- a/finetrainers/models/hunyuan_video/full_finetune.py +++ /dev/null @@ -1,30 +0,0 @@ -from diffusers import HunyuanVideoPipeline - -from .lora import ( - collate_fn_t2v, - forward_pass, - initialize_pipeline, - load_condition_models, - load_diffusion_models, - load_latent_models, - post_latent_preparation, - prepare_conditions, - prepare_latents, - validation, -) - - -# TODO(aryan): refactor into model specs for better re-use -HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG = { - "pipeline_cls": HunyuanVideoPipeline, - "load_condition_models": load_condition_models, - "load_latent_models": load_latent_models, - "load_diffusion_models": load_diffusion_models, - "initialize_pipeline": initialize_pipeline, - "prepare_conditions": prepare_conditions, - "prepare_latents": prepare_latents, - "post_latent_preparation": post_latent_preparation, - "collate_fn": collate_fn_t2v, - "forward_pass": forward_pass, - "validation": validation, -} diff --git a/finetrainers/models/hunyuan_video/lora.py b/finetrainers/models/hunyuan_video/lora.py deleted file mode 100644 index 1d8ccd1f..00000000 --- a/finetrainers/models/hunyuan_video/lora.py +++ /dev/null @@ -1,368 +0,0 @@ -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -import torch.nn as nn -from accelerate.logging import get_logger -from diffusers import ( - AutoencoderKLHunyuanVideo, - FlowMatchEulerDiscreteScheduler, - HunyuanVideoPipeline, - HunyuanVideoTransformer3DModel, -) -from PIL import Image -from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizer - - -logger = get_logger("finetrainers") # pylint: disable=invalid-name - - -def load_condition_models( - model_id: str = "hunyuanvideo-community/HunyuanVideo", - text_encoder_dtype: torch.dtype = torch.float16, - text_encoder_2_dtype: torch.dtype = torch.float16, - revision: Optional[str] = None, - cache_dir: Optional[str] = None, - **kwargs, -) -> Dict[str, nn.Module]: - tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, 
cache_dir=cache_dir) - text_encoder = LlamaModel.from_pretrained( - model_id, subfolder="text_encoder", torch_dtype=text_encoder_dtype, revision=revision, cache_dir=cache_dir - ) - tokenizer_2 = CLIPTokenizer.from_pretrained( - model_id, subfolder="tokenizer_2", revision=revision, cache_dir=cache_dir - ) - text_encoder_2 = CLIPTextModel.from_pretrained( - model_id, subfolder="text_encoder_2", torch_dtype=text_encoder_2_dtype, revision=revision, cache_dir=cache_dir - ) - return { - "tokenizer": tokenizer, - "text_encoder": text_encoder, - "tokenizer_2": tokenizer_2, - "text_encoder_2": text_encoder_2, - } - - -def load_latent_models( - model_id: str = "hunyuanvideo-community/HunyuanVideo", - vae_dtype: torch.dtype = torch.float16, - revision: Optional[str] = None, - cache_dir: Optional[str] = None, - **kwargs, -) -> Dict[str, nn.Module]: - vae = AutoencoderKLHunyuanVideo.from_pretrained( - model_id, subfolder="vae", torch_dtype=vae_dtype, revision=revision, cache_dir=cache_dir - ) - return {"vae": vae} - - -def load_diffusion_models( - model_id: str = "hunyuanvideo-community/HunyuanVideo", - transformer_dtype: torch.dtype = torch.bfloat16, - shift: float = 1.0, - revision: Optional[str] = None, - cache_dir: Optional[str] = None, - **kwargs, -) -> Dict[str, Union[nn.Module, FlowMatchEulerDiscreteScheduler]]: - transformer = HunyuanVideoTransformer3DModel.from_pretrained( - model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir - ) - scheduler = FlowMatchEulerDiscreteScheduler(shift=shift) - return {"transformer": transformer, "scheduler": scheduler} - - -def initialize_pipeline( - model_id: str = "hunyuanvideo-community/HunyuanVideo", - text_encoder_dtype: torch.dtype = torch.float16, - text_encoder_2_dtype: torch.dtype = torch.float16, - transformer_dtype: torch.dtype = torch.bfloat16, - vae_dtype: torch.dtype = torch.float16, - tokenizer: Optional[LlamaTokenizer] = None, - text_encoder: Optional[LlamaModel] = None, - tokenizer_2: Optional[CLIPTokenizer] = None, - text_encoder_2: Optional[CLIPTextModel] = None, - transformer: Optional[HunyuanVideoTransformer3DModel] = None, - vae: Optional[AutoencoderKLHunyuanVideo] = None, - scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None, - device: Optional[torch.device] = None, - revision: Optional[str] = None, - cache_dir: Optional[str] = None, - enable_slicing: bool = False, - enable_tiling: bool = False, - enable_model_cpu_offload: bool = False, - is_training: bool = False, - **kwargs, -) -> HunyuanVideoPipeline: - component_name_pairs = [ - ("tokenizer", tokenizer), - ("text_encoder", text_encoder), - ("tokenizer_2", tokenizer_2), - ("text_encoder_2", text_encoder_2), - ("transformer", transformer), - ("vae", vae), - ("scheduler", scheduler), - ] - components = {} - for name, component in component_name_pairs: - if component is not None: - components[name] = component - - pipe = HunyuanVideoPipeline.from_pretrained(model_id, **components, revision=revision, cache_dir=cache_dir) - pipe.text_encoder = pipe.text_encoder.to(dtype=text_encoder_dtype) - pipe.text_encoder_2 = pipe.text_encoder_2.to(dtype=text_encoder_2_dtype) - pipe.vae = pipe.vae.to(dtype=vae_dtype) - - # The transformer should already be in the correct dtype when training, so we don't need to cast it here. - # If we cast, whilst using fp8 layerwise upcasting hooks, it will lead to an error in the training during - # DDP optimizer step. 
- if not is_training: - pipe.transformer = pipe.transformer.to(dtype=transformer_dtype) - - if enable_slicing: - pipe.vae.enable_slicing() - if enable_tiling: - pipe.vae.enable_tiling() - - if enable_model_cpu_offload: - pipe.enable_model_cpu_offload(device=device) - else: - pipe.to(device=device) - - return pipe - - -def prepare_conditions( - tokenizer: LlamaTokenizer, - text_encoder: LlamaModel, - tokenizer_2: CLIPTokenizer, - text_encoder_2: CLIPTextModel, - prompt: Union[str, List[str]], - guidance: float = 1.0, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - max_sequence_length: int = 256, - # TODO(aryan): make configurable - prompt_template: Dict[str, Any] = { - "template": ( - "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " - "1. The main content and theme of the video." - "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." - "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." - "4. background environment, light, style and atmosphere." - "5. camera angles, movements, and transitions used in the video:<|eot_id|>" - "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" - ), - "crop_start": 95, - }, - **kwargs, -) -> torch.Tensor: - device = device or text_encoder.device - dtype = dtype or text_encoder.dtype - - if isinstance(prompt, str): - prompt = [prompt] - - conditions = {} - conditions.update( - _get_llama_prompt_embeds(tokenizer, text_encoder, prompt, prompt_template, device, dtype, max_sequence_length) - ) - conditions.update(_get_clip_prompt_embeds(tokenizer_2, text_encoder_2, prompt, device, dtype)) - - guidance = torch.tensor([guidance], device=device, dtype=dtype) * 1000.0 - conditions["guidance"] = guidance - - return conditions - - -def prepare_latents( - vae: AutoencoderKLHunyuanVideo, - image_or_video: torch.Tensor, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - generator: Optional[torch.Generator] = None, - precompute: bool = False, - **kwargs, -) -> torch.Tensor: - device = device or vae.device - dtype = dtype or vae.dtype - - if image_or_video.ndim == 4: - image_or_video = image_or_video.unsqueeze(2) - assert image_or_video.ndim == 5, f"Expected 5D tensor, got {image_or_video.ndim}D tensor" - - image_or_video = image_or_video.to(device=device, dtype=vae.dtype) - image_or_video = image_or_video.permute(0, 2, 1, 3, 4).contiguous() # [B, C, F, H, W] -> [B, F, C, H, W] - if not precompute: - latents = vae.encode(image_or_video).latent_dist.sample(generator=generator) - latents = latents * vae.config.scaling_factor - latents = latents.to(dtype=dtype) - return {"latents": latents} - else: - if vae.use_slicing and image_or_video.shape[0] > 1: - encoded_slices = [vae._encode(x_slice) for x_slice in image_or_video.split(1)] - h = torch.cat(encoded_slices) - else: - h = vae._encode(image_or_video) - return {"latents": h} - - -def post_latent_preparation( - vae_config: Dict[str, Any], - latents: torch.Tensor, - **kwargs, -) -> torch.Tensor: - latents = latents * vae_config.scaling_factor - return {"latents": latents} - - -def collate_fn_t2v(batch: List[List[Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]: - return { - "prompts": [x["prompt"] for x in batch[0]], - "videos": torch.stack([x["video"] for x in batch[0]]), - } - - -def forward_pass( - transformer: HunyuanVideoTransformer3DModel, - prompt_embeds: torch.Tensor, - pooled_prompt_embeds: 
torch.Tensor, - prompt_attention_mask: torch.Tensor, - guidance: torch.Tensor, - latents: torch.Tensor, - noisy_latents: torch.Tensor, - timesteps: torch.LongTensor, - **kwargs, -) -> torch.Tensor: - denoised_latents = transformer( - hidden_states=noisy_latents, - timestep=timesteps, - encoder_hidden_states=prompt_embeds, - pooled_projections=pooled_prompt_embeds, - encoder_attention_mask=prompt_attention_mask, - guidance=guidance, - return_dict=False, - )[0] - - return {"latents": denoised_latents} - - -def validation( - pipeline: HunyuanVideoPipeline, - prompt: str, - image: Optional[Image.Image] = None, - video: Optional[List[Image.Image]] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_frames: Optional[int] = None, - num_videos_per_prompt: int = 1, - generator: Optional[torch.Generator] = None, - **kwargs, -): - generation_kwargs = { - "prompt": prompt, - "height": height, - "width": width, - "num_frames": num_frames, - "num_inference_steps": 30, - "num_videos_per_prompt": num_videos_per_prompt, - "generator": generator, - "return_dict": True, - "output_type": "pil", - } - generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None} - output = pipeline(**generation_kwargs).frames[0] - return [("video", output)] - - -def _get_llama_prompt_embeds( - tokenizer: LlamaTokenizer, - text_encoder: LlamaModel, - prompt: List[str], - prompt_template: Dict[str, Any], - device: torch.device, - dtype: torch.dtype, - max_sequence_length: int = 256, - num_hidden_layers_to_skip: int = 2, -) -> Tuple[torch.Tensor, torch.Tensor]: - batch_size = len(prompt) - prompt = [prompt_template["template"].format(p) for p in prompt] - - crop_start = prompt_template.get("crop_start", None) - if crop_start is None: - prompt_template_input = tokenizer( - prompt_template["template"], - padding="max_length", - return_tensors="pt", - return_length=False, - return_overflowing_tokens=False, - return_attention_mask=False, - ) - crop_start = prompt_template_input["input_ids"].shape[-1] - # Remove <|eot_id|> token and placeholder {} - crop_start -= 2 - - max_sequence_length += crop_start - text_inputs = tokenizer( - prompt, - max_length=max_sequence_length, - padding="max_length", - truncation=True, - return_tensors="pt", - return_length=False, - return_overflowing_tokens=False, - return_attention_mask=True, - ) - text_input_ids = text_inputs.input_ids.to(device=device) - prompt_attention_mask = text_inputs.attention_mask.to(device=device) - - prompt_embeds = text_encoder( - input_ids=text_input_ids, - attention_mask=prompt_attention_mask, - output_hidden_states=True, - ).hidden_states[-(num_hidden_layers_to_skip + 1)] - prompt_embeds = prompt_embeds.to(dtype=dtype) - - if crop_start is not None and crop_start > 0: - prompt_embeds = prompt_embeds[:, crop_start:] - prompt_attention_mask = prompt_attention_mask[:, crop_start:] - - prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) - - return {"prompt_embeds": prompt_embeds, "prompt_attention_mask": prompt_attention_mask} - - -def _get_clip_prompt_embeds( - tokenizer_2: CLIPTokenizer, - text_encoder_2: CLIPTextModel, - prompt: Union[str, List[str]], - device: torch.device, - dtype: torch.dtype, - max_sequence_length: int = 77, -) -> torch.Tensor: - text_inputs = tokenizer_2( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - return_tensors="pt", - ) - - prompt_embeds = text_encoder_2(text_inputs.input_ids.to(device), output_hidden_states=False).pooler_output - prompt_embeds = 
prompt_embeds.to(dtype=dtype) - - return {"pooled_prompt_embeds": prompt_embeds} - - -# TODO(aryan): refactor into model specs for better re-use -HUNYUAN_VIDEO_T2V_LORA_CONFIG = { - "pipeline_cls": HunyuanVideoPipeline, - "load_condition_models": load_condition_models, - "load_latent_models": load_latent_models, - "load_diffusion_models": load_diffusion_models, - "initialize_pipeline": initialize_pipeline, - "prepare_conditions": prepare_conditions, - "prepare_latents": prepare_latents, - "post_latent_preparation": post_latent_preparation, - "collate_fn": collate_fn_t2v, - "forward_pass": forward_pass, - "validation": validation, -} diff --git a/finetrainers/processors/__init__.py b/finetrainers/processors/__init__.py index a006352e..e55b3d14 100644 --- a/finetrainers/processors/__init__.py +++ b/finetrainers/processors/__init__.py @@ -1,3 +1,5 @@ from .base import ProcessorMixin +from .clip import CLIPPooledProcessor +from .llama import LlamaProcessor from .t5 import T5Processor from .text import CaptionEmbeddingDropoutProcessor, CaptionTextDropoutProcessor diff --git a/finetrainers/processors/base.py b/finetrainers/processors/base.py index 3862476d..9853ead0 100644 --- a/finetrainers/processors/base.py +++ b/finetrainers/processors/base.py @@ -1,13 +1,19 @@ import inspect -from typing import Any +from typing import Any, Dict, List class ProcessorMixin: def __init__(self) -> None: self._forward_parameter_names = inspect.signature(self.forward).parameters.keys() + self.output_names: List[str] = None + self.input_names: Dict[str, Any] = None def __call__(self, *args, **kwargs) -> Any: - acceptable_kwargs = {k: v for k, v in kwargs.items() if k in self._forward_parameter_names} + shallow_copy_kwargs = dict(kwargs.items()) + if self.input_names is not None: + for k, v in self.input_names.items(): + shallow_copy_kwargs[v] = shallow_copy_kwargs.pop(k) + acceptable_kwargs = {k: v for k, v in shallow_copy_kwargs.items() if k in self._forward_parameter_names} return self.forward(*args, **acceptable_kwargs) def forward(self, *args, **kwargs) -> Any: diff --git a/finetrainers/processors/clip.py b/finetrainers/processors/clip.py new file mode 100644 index 00000000..178addf8 --- /dev/null +++ b/finetrainers/processors/clip.py @@ -0,0 +1,65 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from transformers import CLIPTextModel, CLIPTokenizer, CLIPTokenizerFast + +from .base import ProcessorMixin + + +class CLIPPooledProcessor(ProcessorMixin): + r""" + Processor for the Llama family of models. This processor is used to encode text inputs and return the embeddings + and attention masks for the input text. + + Args: + output_names (`List[str]`): + The names of the outputs that the processor should return. The first output is the embeddings of the input + text and the second output is the attention mask for the input text. + """ + + def __init__(self, output_names: List[str] = None, input_names: Optional[Dict[str, Any]] = None) -> None: + super().__init__() + + self.output_names = output_names + self.input_names = input_names + + assert len(output_names) == 1 + if input_names is not None: + assert len(input_names) <= 3 + + def forward( + self, + tokenizer: Union[CLIPTokenizer, CLIPTokenizerFast], + text_encoder: CLIPTextModel, + caption: Union[str, List[str]], + ) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Encode the input text and return the embeddings and attention mask for the input text. 
+ + Args: + tokenizer (`Union[LlamaTokenizer, LlamaTokenizerFast]`): + The tokenizer used to tokenize the input text. + text_encoder (`LlamaModel`): + The text encoder used to encode the input text. + caption (`Union[str, List[str]]`): + The input text to be encoded. + """ + if isinstance(caption, str): + caption = [caption] + + device = text_encoder.device + dtype = text_encoder.dtype + + text_inputs = tokenizer( + caption, + padding="max_length", + max_length=77, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(device) + + prompt_embeds = text_encoder(text_input_ids, output_hidden_states=False).pooler_output + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return {self.output_names[0]: prompt_embeds} diff --git a/finetrainers/processors/llama.py b/finetrainers/processors/llama.py new file mode 100644 index 00000000..749e5f31 --- /dev/null +++ b/finetrainers/processors/llama.py @@ -0,0 +1,118 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from transformers import LlamaModel, LlamaTokenizer, LlamaTokenizerFast + +from .base import ProcessorMixin + + +DEFAULT_PROMPT_TEMPLATE = { + "template": ( + "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. background environment, light, style and atmosphere." + "5. camera angles, movements, and transitions used in the video:<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + ), + "crop_start": 95, +} + + +class LlamaProcessor(ProcessorMixin): + r""" + Processor for the Llama family of models. This processor is used to encode text inputs and return the embeddings + and attention masks for the input text. + + Args: + output_names (`List[str]`): + The names of the outputs that the processor should return. The first output is the embeddings of the input + text and the second output is the attention mask for the input text. + """ + + def __init__(self, output_names: List[str] = None): + super().__init__() + + self.output_names = output_names + + assert len(output_names) == 2 + + def forward( + self, + tokenizer: Union[LlamaTokenizer, LlamaTokenizerFast], + text_encoder: LlamaModel, + caption: Union[str, List[str]], + max_sequence_length: int, + prompt_template: Optional[Dict[str, Any]] = None, + num_layers_to_skip: int = 2, + ) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Encode the input text and return the embeddings and attention mask for the input text. + + Args: + tokenizer (`Union[LlamaTokenizer, LlamaTokenizerFast]`): + The tokenizer used to tokenize the input text. + text_encoder (`LlamaModel`): + The text encoder used to encode the input text. + caption (`Union[str, List[str]]`): + The input text to be encoded. + max_sequence_length (`int`): + The maximum sequence length of the input text. + prompt_template (`Optional[Dict[str, Any]]`): + The prompt template to be used to encode the input text. 
+ """ + if prompt_template is None: + prompt_template = DEFAULT_PROMPT_TEMPLATE + if isinstance(caption, str): + caption = [caption] + + device = text_encoder.device + dtype = text_encoder.dtype + + batch_size = len(caption) + caption = [prompt_template["template"].format(c) for c in caption] + + crop_start = prompt_template.get("crop_start", None) + if crop_start is None: + prompt_template_input = tokenizer( + prompt_template["template"], + padding="max_length", + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=False, + ) + crop_start = prompt_template_input["input_ids"].shape[-1] + # Remove <|eot_id|> token and placeholder {} + crop_start -= 2 + + max_sequence_length += crop_start + text_inputs = tokenizer( + caption, + max_length=max_sequence_length, + padding="max_length", + truncation=True, + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + ) + text_input_ids = text_inputs.input_ids.to(device) + prompt_attention_mask = text_inputs.attention_mask.bool().to(device) + + prompt_embeds = text_encoder( + text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ).hidden_states[-(num_layers_to_skip + 1)] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + if crop_start is not None and crop_start > 0: + prompt_embeds = prompt_embeds[:, crop_start:] + prompt_attention_mask = prompt_attention_mask[:, crop_start:] + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + + return { + self.output_names[0]: prompt_embeds, + self.output_names[1]: prompt_attention_mask, + } diff --git a/finetrainers/processors/t5.py b/finetrainers/processors/t5.py index 6a810a65..96c2c194 100644 --- a/finetrainers/processors/t5.py +++ b/finetrainers/processors/t5.py @@ -21,6 +21,7 @@ def __init__(self, output_names: List[str]): super().__init__() self.output_names = output_names + assert len(self.output_names) == 2 def forward( diff --git a/tests/models/hunyuan_video/base_specification.py b/tests/models/hunyuan_video/base_specification.py new file mode 100644 index 00000000..e76b749e --- /dev/null +++ b/tests/models/hunyuan_video/base_specification.py @@ -0,0 +1,116 @@ +import pathlib +import sys + +import torch +from diffusers import AutoencoderKLHunyuanVideo, FlowMatchEulerDiscreteScheduler, HunyuanVideoTransformer3DModel +from transformers import ( + CLIPTextConfig, + CLIPTextModel, + CLIPTokenizer, + LlamaConfig, + LlamaModel, + LlamaTokenizer, +) + + +project_root = pathlib.Path(__file__).resolve().parents[2] +sys.path.append(str(project_root)) + +from finetrainers.models.hunyuan_video import HunyuanVideoModelSpecification # noqa + + +class DummyHunyuanVideoModelSpecification(HunyuanVideoModelSpecification): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def load_condition_models(self): + llama_text_encoder_config = LlamaConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=16, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=2, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + clip_text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=8, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=2, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + + torch.manual_seed(0) + text_encoder = LlamaModel(llama_text_encoder_config) + tokenizer = 
LlamaTokenizer.from_pretrained("finetrainers/dummy-hunyaunvideo", subfolder="tokenizer") + + torch.manual_seed(0) + text_encoder_2 = CLIPTextModel(clip_text_encoder_config) + tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + return { + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + } + + def load_latent_models(self): + torch.manual_seed(0) + vae = AutoencoderKLHunyuanVideo( + in_channels=3, + out_channels=3, + latent_channels=4, + down_block_types=( + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + ), + up_block_types=( + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + ), + block_out_channels=(8, 8, 8, 8), + layers_per_block=1, + act_fn="silu", + norm_num_groups=4, + scaling_factor=0.476986, + spatial_compression_ratio=8, + temporal_compression_ratio=4, + mid_block_add_attention=True, + ) + return {"vae": vae} + + def load_diffusion_models(self): + torch.manual_seed(0) + transformer = HunyuanVideoTransformer3DModel( + in_channels=4, + out_channels=4, + num_attention_heads=2, + attention_head_dim=10, + num_layers=2, + num_single_layers=2, + num_refiner_layers=1, + patch_size=1, + patch_size_t=1, + guidance_embeds=True, + text_embed_dim=16, + pooled_projection_dim=8, + rope_axes_dim=(2, 4, 4), + ) + scheduler = FlowMatchEulerDiscreteScheduler() + return {"transformer": transformer, "scheduler": scheduler} diff --git a/tests/trainer/test_sft_trainer.py b/tests/trainer/test_sft_trainer.py index 09410243..43be856d 100644 --- a/tests/trainer/test_sft_trainer.py +++ b/tests/trainer/test_sft_trainer.py @@ -20,8 +20,9 @@ from finetrainers import BaseArgs, SFTTrainer, TrainingType, get_logger # noqa from finetrainers.trainer.sft_trainer.config import SFTLowRankConfig, SFTFullRankConfig # noqa -from ..models.ltx_video.base_specification import DummyLTXVideoModelSpecification # noqa from ..models.cogvideox.base_specification import DummyCogVideoXModelSpecification # noqa +from ..models.hunyuan_video.base_specification import DummyHunyuanVideoModelSpecification # noqa +from ..models.ltx_video.base_specification import DummyLTXVideoModelSpecification # noqa from ..models.wan.base_specification import DummyWanModelSpecification # noqa @@ -208,14 +209,6 @@ def test___tp_degree_2___batch_size_2(self): self._test_training(args) -class SFTTrainerLTXVideoLoRATests___PTD(SFTTrainerLoRATestsMixin___PTD, unittest.TestCase): - model_specification_cls = DummyLTXVideoModelSpecification - - -class SFTTrainerLTXVideoFullFinetuneTests___PTD(SFTTrainerFullFinetuneTestsMixin___PTD, unittest.TestCase): - model_specification_cls = DummyLTXVideoModelSpecification - - class SFTTrainerCogVideoXLoRATests___PTD(SFTTrainerLoRATestsMixin___PTD, unittest.TestCase): model_specification_cls = DummyCogVideoXModelSpecification @@ -224,6 +217,22 @@ class SFTTrainerCogVideoXFullFinetuneTests___PTD(SFTTrainerFullFinetuneTestsMixi model_specification_cls = DummyCogVideoXModelSpecification +class SFTTrainerHunyuanVideoLoRATests___PTD(SFTTrainerLoRATestsMixin___PTD, unittest.TestCase): + model_specification_cls = DummyHunyuanVideoModelSpecification + + +class SFTTrainerHunyuanVideoFullFinetuneTests___PTD(SFTTrainerFullFinetuneTestsMixin___PTD, unittest.TestCase): + model_specification_cls = DummyHunyuanVideoModelSpecification + + +class SFTTrainerLTXVideoLoRATests___PTD(SFTTrainerLoRATestsMixin___PTD, 
unittest.TestCase): + model_specification_cls = DummyLTXVideoModelSpecification + + +class SFTTrainerLTXVideoFullFinetuneTests___PTD(SFTTrainerFullFinetuneTestsMixin___PTD, unittest.TestCase): + model_specification_cls = DummyLTXVideoModelSpecification + + class SFTTrainerWanLoRATests___PTD(SFTTrainerLoRATestsMixin___PTD, unittest.TestCase): model_specification_cls = DummyWanModelSpecification
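
A quick, illustrative sanity check of the `--target_modules` regex used in train.sh above. The module names below are hypothetical examples of how attention projections in diffusers' HunyuanVideoTransformer3DModel are commonly named, not output captured from the real model; the point is only to show which kinds of layers the pattern selects for LoRA and which it leaves untouched.

```python
import re

# Regex copied verbatim from --target_modules in train.sh.
pattern = re.compile(
    r"(transformer_blocks|single_transformer_blocks).*"
    r"(to_q|to_k|to_v|to_out.0|add_q_proj|add_k_proj|add_v_proj|to_add_out)"
)

# Hypothetical module names for illustration only.
candidate_modules = [
    "transformer_blocks.0.attn.to_q",             # dual-stream attention projection -> match
    "transformer_blocks.3.attn.add_k_proj",       # text-branch projection -> match
    "single_transformer_blocks.1.attn.to_out.0",  # attention output projection -> match
    "transformer_blocks.0.ff.net.0.proj",         # feed-forward layer -> no match
]

for name in candidate_modules:
    print(f"{name}: {'LoRA target' if pattern.search(name) else 'skipped'}")
```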
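
For reference, a minimal sketch of how the refactored SUPPORTED_MODEL_CONFIGS lookup in finetrainers/config.py is intended to be used: both training types for HunyuanVideo now resolve to the same HunyuanVideoModelSpecification class rather than separate per-type config dicts. The enum string values and the trimmed stand-in class here are illustrative; the real specification takes the full constructor arguments shown in base_specification.py.

```python
from enum import Enum


class ModelType(str, Enum):
    HUNYUAN_VIDEO = "hunyuan_video"


class TrainingType(str, Enum):
    LORA = "lora"
    FULL_FINETUNE = "full-finetune"


class HunyuanVideoModelSpecification:
    # Stand-in for finetrainers.models.hunyuan_video.HunyuanVideoModelSpecification.
    def __init__(self, pretrained_model_name_or_path: str = "hunyuanvideo-community/HunyuanVideo", **kwargs):
        self.pretrained_model_name_or_path = pretrained_model_name_or_path


SUPPORTED_MODEL_CONFIGS = {
    ModelType.HUNYUAN_VIDEO: {
        TrainingType.LORA: HunyuanVideoModelSpecification,
        TrainingType.FULL_FINETUNE: HunyuanVideoModelSpecification,
    },
}

# The lookup now returns a class (not a config dict); LoRA vs. full-rank behaviour
# is decided by the trainer, while the specification owns model loading and the forward pass.
spec_cls = SUPPORTED_MODEL_CONFIGS[ModelType.HUNYUAN_VIDEO][TrainingType.LORA]
spec = spec_cls(pretrained_model_name_or_path="hunyuanvideo-community/HunyuanVideo")
print(type(spec).__name__, spec.pretrained_model_name_or_path)
```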
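
Another piece worth calling out is the small change to ProcessorMixin.__call__: a processor can now declare `input_names` to rename incoming keyword arguments before dispatch, which is how CLIPPooledProcessor receives `tokenizer_2`/`text_encoder_2` from the HunyuanVideo condition dict while its forward() keeps the generic `tokenizer`/`text_encoder` parameter names. A self-contained toy version of that remapping, with a dummy forward() standing in for the real CLIP encoding:

```python
import inspect
from typing import Any, Dict, Optional


class MiniProcessor:
    # Mirrors the remapping logic added to finetrainers ProcessorMixin.__call__.
    input_names: Optional[Dict[str, Any]] = None

    def __init__(self) -> None:
        self._forward_parameter_names = inspect.signature(self.forward).parameters.keys()

    def __call__(self, **kwargs) -> Any:
        remapped = dict(kwargs)
        if self.input_names is not None:
            # Rename caller-side keys (e.g. "tokenizer_2") to the names forward() declares.
            for original, renamed in self.input_names.items():
                remapped[renamed] = remapped.pop(original)
        accepted = {k: v for k, v in remapped.items() if k in self._forward_parameter_names}
        return self.forward(**accepted)

    def forward(self, tokenizer, text_encoder, caption):
        # Dummy stand-in for the pooled CLIP text encoding.
        return f"encode {caption!r} with {tokenizer} + {text_encoder}"


proc = MiniProcessor()
proc.input_names = {"tokenizer_2": "tokenizer", "text_encoder_2": "text_encoder"}
# The caller passes the HunyuanVideo component names; the mixin renames them before dispatch.
print(proc(tokenizer_2="clip-tokenizer", text_encoder_2="clip-text-encoder", caption="a video"))
```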
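
Finally, a rough sketch of the training objective wired up in HunyuanVideoModelSpecification.forward(). It replaces the finetrainers FF.flow_match_xt / FF.flow_match_target helpers with plain PyTorch under the assumption that they implement the usual rectified-flow interpolation and velocity target; the real helpers may differ in sign conventions or weighting, so treat this as an outline rather than the exact implementation.

```python
import torch


def flow_match_xt(x0: torch.Tensor, noise: torch.Tensor, sigmas: torch.Tensor) -> torch.Tensor:
    # Assumed form: interpolate clean latents toward noise, x_t = (1 - sigma) * x0 + sigma * noise.
    sigmas = sigmas.view(-1, *([1] * (x0.ndim - 1)))
    return (1.0 - sigmas) * x0 + sigmas * noise


def flow_match_target(noise: torch.Tensor, x0: torch.Tensor) -> torch.Tensor:
    # Assumed velocity target pointing from data to noise.
    return noise - x0


batch_size = 1
latents = torch.randn(batch_size, 16, 4, 8, 8)   # toy [B, C, F, H, W] latent video
noise = torch.randn_like(latents)
sigmas = torch.rand(batch_size)                  # one sigma per sample in [0, 1)

noisy_latents = flow_match_xt(latents, noise, sigmas)
timesteps = (sigmas * 1000.0).long()                       # same scaling used in forward()
guidance = latents.new_full((batch_size,), 1.0) * 1000.0   # embedded guidance, as in the diff

target = flow_match_target(noise, latents)
# The transformer predicts `pred` from (noisy_latents, timesteps, guidance, text conditions);
# the trainer then computes a (possibly sigma-weighted) MSE between `pred` and `target`.
```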