From a133b5856d18e544335841590321ebe13d94b2e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CBrianChen1129=E2=80=9D?= Date: Thu, 19 Dec 2024 14:53:20 +0000 Subject: [PATCH 01/42] fix lora cp saving issue --- assets/prompt.txt | 20 ++- demo/gradio_web_demo.py | 188 +++++++++++++++----- fastvideo/models/mochi_hf/pipeline_mochi.py | 11 +- 3 files changed, 159 insertions(+), 60 deletions(-) diff --git a/assets/prompt.txt b/assets/prompt.txt index d10865e3..3c87a8fd 100644 --- a/assets/prompt.txt +++ b/assets/prompt.txt @@ -1,8 +1,12 @@ -Will Smith casually eats noodles, his relaxed demeanor contrasting with the energetic background of a bustling street food market. The scene captures a mix of humor and authenticity. Mid-shot framing, vibrant lighting. -A lone hiker stands atop a towering cliff, silhouetted against the vast horizon. The rugged landscape stretches endlessly beneath, its earthy tones blending into the soft blues of the sky. The scene captures the spirit of exploration and human resilience. High angle, dynamic framing, with soft natural lighting emphasizing the grandeur of nature. -A hand with delicate fingers picks up a bright yellow lemon from a wooden bowl filled with lemons and sprigs of mint against a peach-colored background. The hand gently tosses the lemon up and catches it, showcasing its smooth texture. A beige string bag sits beside the bowl, adding a rustic touch to the scene. Additional lemons, one halved, are scattered around the base of the bowl. The even lighting enhances the vibrant colors and creates a fresh, inviting atmosphere. -A curious raccoon peers through a vibrant field of yellow sunflowers, its eyes wide with interest. The playful yet serene atmosphere is complemented by soft natural light filtering through the petals. Mid-shot, warm and cheerful tones. -A superintelligent humanoid robot waking up. The robot has a sleek metallic body with futuristic design features. Its glowing red eyes are the focal point, emanating a sharp, intense light as it powers on. The scene is set in a dimly lit, high-tech laboratory filled with glowing control panels, robotic arms, and holographic screens. The setting emphasizes advanced technology and an atmosphere of mystery. The ambiance is eerie and dramatic, highlighting the moment of awakening and the robots immense intelligence. Photorealistic style with a cinematic, dark sci-fi aesthetic. Aspect ratio: 16:9 --v 6.1 -fox in the forest close-up quickly turned its head to the left -Man walking his dog in the woods on a hot sunny day -A majestic lion strides across the golden savanna, its powerful frame glistening under the warm afternoon sun. The tall grass ripples gently in the breeze, enhancing the lion's commanding presence. The tone is vibrant, embodying the raw energy of the wild. Low angle, steady tracking shot, cinematic. \ No newline at end of file +A hand enters the frame, pulling a sheet of plastic wrap over three balls of dough placed on a wooden surface. The plastic wrap is stretched to cover the dough more securely. The hand adjusts the wrap, ensuring that it is tight and smooth over the dough. The scene focuses on the hand's movements as it secures the edges of the plastic wrap. No new objects appear, and the camera remains stationary, focusing on the action of covering the dough. +A vintage train snakes through the mountains, its plume of white steam rising dramatically against the jagged peaks. The cars glint in the late afternoon sun, their deep crimson and gold accents lending a touch of elegance. 
The tracks carve a precarious path along the cliffside, revealing glimpses of a roaring river far below. Inside, passengers peer out the large windows, their faces lit with awe as the landscape unfolds. +A crowded rooftop bar buzzes with energy, the city skyline twinkling like a field of stars in the background. Strings of fairy lights hang above, casting a warm, golden glow over the scene. Groups of people gather around high tables, their laughter blending with the soft rhythm of live jazz. The aroma of freshly mixed cocktails and charred appetizers wafts through the air, mingling with the cool night breeze. +En "The Matrix", Neo, interpretado por Keanu Reeves, personifica la lucha contra un sistema opresor a través de su icónica imagen, que incluye unos anteojos oscuros. Estos lentes no son solo un accesorio de moda; representan una barrera entre la realidad y la percepción. Al usar estos anteojos, Neo se sumerge en un mundo donde la verdad se oculta detrás de ilusiones y engaños. La oscuridad de los lentes simboliza la ignorancia y el control que las máquinas tienen sobre la humanidad, mientras que su propia búsqueda de la verdad lo lleva a descubrir sus auténticos poderes. La escena en que se los pone se convierte en un momento crucial, marcando su transformación de un simple programador a "El Elegido". Esta imagen se ha convertido en un ícono cultural, encapsulando el mensaje de que, al enfrentar la oscuridad, podemos encontrar la luz que nos guía hacia la libertad. Así, los anteojos de Neo se convierten en un símbolo de resistencia y autoconocimiento en un mundo manipulado. +Medium close up. Low-angle shot. A woman in a 1950s retro dress sits in a diner bathed in neon light, surrounded by classic decor and lively chatter. The camera starts with a medium shot of her sitting at the counter, then slowly zooms in as she blows a shiny pink bubblegum bubble. The bubble swells dramatically before popping with a soft, playful burst. The scene is vibrant and nostalgic, evoking the fun and carefree spirit of the 1950s. +Will Smith eats noodles. +A short clip of the blonde woman taking a sip from her whiskey glass, her eyes locking with the camera as she smirks playfully. The background shows a group of people laughing and enjoying the party, with vibrant neon signs illuminating the space. The shot is taken in a way that conveys the feeling of a tipsy, carefree night out. The camera then zooms in on her face as she winks, creating a cheeky, flirtatious vibe. +A superintelligent humanoid robot waking up. The robot has a sleek metallic body with futuristic design features. Its glowing red eyes are the focal point, emanating a sharp, intense light as it powers on. The scene is set in a dimly lit, high-tech laboratory filled with glowing control panels, robotic arms, and holographic screens. The setting emphasizes advanced technology and an atmosphere of mystery. The ambiance is eerie and dramatic, highlighting the moment of awakening and the robot's immense intelligence. Photorealistic style with a cinematic, dark sci-fi aesthetic. Aspect ratio: 16:9 --v 6.1 +A chimpanzee lead vocalist singing into a microphone on stage. The camera zooms in to show him singing. There is a spotlight on him. +A hand enters the frame, pulling a sheet of plastic wrap over three balls of dough placed on a wooden surface. The plastic wrap is stretched to cover the dough more securely. The hand adjusts the wrap, ensuring that it is tight and smooth over the dough. 
The scene focuses on the hand's movements as it secures the edges of the plastic wrap. No new objects appear, and the camera remains stationary, focusing on the action of covering the dough. +A vintage train snakes through the mountains, its plume of white steam rising dramatically against the jagged peaks. The cars glint in the late afternoon sun, their deep crimson and gold accents lending a touch of elegance. The tracks carve a precarious path along the cliffside, revealing glimpses of a roaring river far below. Inside, passengers peer out the large windows, their faces lit with awe as the landscape unfolds. +A superintelligent humanoid robot waking up. The robot has a sleek metallic body with futuristic design features. Its glowing red eyes are the focal point, emanating a sharp, intense light as it powers on. The scene is set in a dimly lit, high-tech laboratory filled with glowing control panels, robotic arms, and holographic screens. The setting emphasizes advanced technology and an atmosphere of mystery. The ambiance is eerie and dramatic, highlighting the moment of awakening and the robot's immense intelligence. Photorealistic style with a cinematic, dark sci-fi aesthetic. Aspect ratio: 16:9 --v 6.1 diff --git a/demo/gradio_web_demo.py b/demo/gradio_web_demo.py index 29320489..945938f4 100644 --- a/demo/gradio_web_demo.py +++ b/demo/gradio_web_demo.py @@ -2,87 +2,122 @@ import torch from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformer3DModel +from fastvideo.models.mochi_genmo.mochi_preview.dit.joint_model.asymm_models_joint import AsymmDiTJoint +from fastvideo.models.mochi_genmo.mochi_preview.vae.models import Decoder, Encoder from diffusers import FlowMatchEulerDiscreteScheduler from diffusers.utils import export_to_video from fastvideo.distill.solver import PCMFMScheduler import tempfile import os import argparse - +from safetensors.torch import load_file def init_args(): parser = argparse.ArgumentParser() - parser.add_argument("--prompts", nargs="+", default=[]) - parser.add_argument("--num_frames", type=int, default=25) + parser.add_argument("--prompts", nargs='+', default=[]) + parser.add_argument("--num_frames", type=int, default=139) parser.add_argument("--height", type=int, default=480) parser.add_argument("--width", type=int, default=848) - parser.add_argument("--num_inference_steps", type=int, default=8) + parser.add_argument("--num_inference_steps", type=int, default=64) parser.add_argument("--guidance_scale", type=float, default=4.5) parser.add_argument("--model_path", type=str, default="data/mochi") - parser.add_argument("--seed", type=int, default=12345) + parser.add_argument("--seed", type=int, default=42) parser.add_argument("--transformer_path", type=str, default=None) - parser.add_argument("--scheduler_type", type=str, default="pcm_linear_quadratic") + parser.add_argument("--scheduler_type", type=str, default="euler") parser.add_argument("--lora_checkpoint_dir", type=str, default=None) parser.add_argument("--shift", type=float, default=8.0) - parser.add_argument("--num_euler_timesteps", type=int, default=50) - parser.add_argument("--linear_threshold", type=float, default=0.1) - parser.add_argument("--linear_range", type=float, default=0.75) + parser.add_argument("--num_euler_timesteps", type=int, default=100) + parser.add_argument("--linear_threshold", type=float, default=0.025) + parser.add_argument("--linear_range", type=float, default=0.5) parser.add_argument("--cpu_offload", 
action="store_true") return parser.parse_args() - def load_model(args): device = "cuda" if torch.cuda.is_available() else "cpu" if args.scheduler_type == "euler": scheduler = FlowMatchEulerDiscreteScheduler() else: - linear_quadratic = True if "linear_quadratic" in args.scheduler_type else False - scheduler = PCMFMScheduler( - 1000, - args.shift, - args.num_euler_timesteps, - linear_quadratic, - args.linear_threshold, - args.linear_range, + scheduler = PCMFMScheduler(1000, args.shift, args.num_euler_timesteps, False, args.linear_threshold, args.linear_range) + + mochi_genmo = False + if mochi_genmo: + vae_encoder_path = "/root/data/fastmochi_genmo/encoder.safetensors" + vae_decoder_path = "/root/data/fastmochi_genmo/decoder.safetensors" + dit_path = "/root/data/fastmochi_genmo/dit.safetensors" + + vae_encoder_state_dict = load_file(vae_encoder_path) + vae_decoder_state_dict = load_file(vae_decoder_path) + dit_state_dict = load_file(dit_path) + + vae_encoder = Encoder( + in_channels=15, + base_channels=64, + channel_multipliers=[1, 2, 4, 6], + num_res_blocks=[3, 3, 4, 6, 3], + latent_dim=12, + temporal_reductions=[1, 2, 3], + spatial_reductions=[2, 2, 2], + prune_bottlenecks=[False, False, False, False, False], + has_attentions=[False, True, True, True, True], + affine=True, + bias=True, + input_is_conv_1x1=True, + padding_mode="replicate", ) - - if args.transformer_path: - transformer = MochiTransformer3DModel.from_pretrained(args.transformer_path) - else: - transformer = MochiTransformer3DModel.from_pretrained( - args.model_path, subfolder="transformer/" + vae_decoder = Decoder( + out_channels=3, + base_channels=128, + channel_multipliers=[1, 2, 4, 6], + temporal_expansions=[1, 2, 3], + spatial_expansions=[2, 2, 2], + num_res_blocks=[3, 3, 4, 6, 3], + latent_dim=12, + has_attention=[False, False, False, False, False], + output_norm=False, + nonlinearity="silu", + output_nonlinearity="silu", + causal=True, ) - - pipe = MochiPipeline.from_pretrained( - args.model_path, transformer=transformer, scheduler=scheduler - ) + transformer = AsymmDiTJoint() + + vae_encoder.load_state_dict(vae_encoder_state_dict) + vae_decoder.load_state_dict(vae_decoder_state_dict) + transformer.load_state_dict(dit_state_dict) + + transformer.config.in_channels = 12 + else: + if args.transformer_path: + transformer = MochiTransformer3DModel.from_pretrained(args.transformer_path) + else: + transformer = MochiTransformer3DModel.from_pretrained(args.model_path, subfolder='transformer/') + + pipe = MochiPipeline.from_pretrained(args.model_path, transformer=transformer, scheduler=scheduler) + # from IPython import embed + # embed() + # del pipe.vae.encoder + # del pipe.vae.decoder + # pipe.vae.encoder = vae_encoder + # pipe.vae.decoder = vae_decoder + + pipe.enable_vae_tiling() - # pipe.to(device) - # if args.cpu_offload: - pipe.enable_sequential_cpu_offload() + pipe.to(device) + if args.cpu_offload: + pipe.enable_model_cpu_offload() return pipe - -def generate_video( - prompt, - negative_prompt, - use_negative_prompt, - seed, - guidance_scale, - num_frames, - height, - width, - num_inference_steps, - randomize_seed=False, -): +def generate_video(prompt, negative_prompt, use_negative_prompt, seed, guidance_scale, + num_frames, height, width, num_inference_steps, randomize_seed=False): if randomize_seed: seed = torch.randint(0, 1000000, (1,)).item() - + + pipe = load_model(args) + print("load model successfully") generator = torch.Generator(device="cuda").manual_seed(seed) - + if not use_negative_prompt: negative_prompt = 
None - + with torch.autocast("cuda", dtype=torch.bfloat16): output = pipe( prompt=[prompt], @@ -94,16 +129,24 @@ def generate_video( guidance_scale=guidance_scale, generator=generator, ).frames[0] +<<<<<<< Updated upstream +======= + +>>>>>>> Stashed changes output_path = os.path.join(tempfile.mkdtemp(), "output.mp4") export_to_video(output, output_path, fps=30) return output_path, seed +<<<<<<< Updated upstream +======= +>>>>>>> Stashed changes examples = [ "A hand enters the frame, pulling a sheet of plastic wrap over three balls of dough placed on a wooden surface. The plastic wrap is stretched to cover the dough more securely. The hand adjusts the wrap, ensuring that it is tight and smooth over the dough. The scene focuses on the hand’s movements as it secures the edges of the plastic wrap. No new objects appear, and the camera remains stationary, focusing on the action of covering the dough.", "A vintage train snakes through the mountains, its plume of white steam rising dramatically against the jagged peaks. The cars glint in the late afternoon sun, their deep crimson and gold accents lending a touch of elegance. The tracks carve a precarious path along the cliffside, revealing glimpses of a roaring river far below. Inside, passengers peer out the large windows, their faces lit with awe as the landscape unfolds.", "A crowded rooftop bar buzzes with energy, the city skyline twinkling like a field of stars in the background. Strings of fairy lights hang above, casting a warm, golden glow over the scene. Groups of people gather around high tables, their laughter blending with the soft rhythm of live jazz. The aroma of freshly mixed cocktails and charred appetizers wafts through the air, mingling with the cool night breeze.", +<<<<<<< Updated upstream ] args = init_args() @@ -112,6 +155,18 @@ def generate_video( with gr.Blocks() as demo: gr.Markdown("# Fastvideo Mochi Video Generation Demo") +======= + "Will Smith eats noodles.", + "A short clip of the blonde woman taking a sip from her whiskey glass, her eyes locking with the camera as she smirks playfully. The background shows a group of people laughing and enjoying the party, with vibrant neon signs illuminating the space. The shot is taken in a way that conveys the feeling of a tipsy, carefree night out. The camera then zooms in on her face as she winks, creating a cheeky, flirtatious vibe.", + "A superintelligent humanoid robot waking up. The robot has a sleek metallic body with futuristic design features. Its glowing red eyes are the focal point, emanating a sharp, intense light as it powers on. The scene is set in a dimly lit, high-tech laboratory filled with glowing control panels, robotic arms, and holographic screens. The setting emphasizes advanced technology and an atmosphere of mystery. The ambiance is eerie and dramatic, highlighting the moment of awakening and the robot's immense intelligence. Photorealistic style with a cinematic, dark sci-fi aesthetic. 
Aspect ratio: 16:9 --v 6.1" +] + +args = init_args() + +with gr.Blocks() as demo: + gr.Markdown("# Mochi Video Generation Demo") + +>>>>>>> Stashed changes with gr.Group(): with gr.Row(): prompt = gr.Text( @@ -123,6 +178,7 @@ def generate_video( ) run_button = gr.Button("Run", scale=0) result = gr.Video(label="Result", show_label=False) +<<<<<<< Updated upstream with gr.Accordion("Advanced options", open=False): with gr.Group(): @@ -162,10 +218,27 @@ def generate_video( use_negative_prompt = gr.Checkbox( label="Use negative prompt", value=False ) +======= + + with gr.Accordion("Advanced options", open=False): + with gr.Group(): + with gr.Row(): + height = gr.Slider(label="Height", minimum=256, maximum=1024, step=1, value=args.height) + width = gr.Slider(label="Width", minimum=256, maximum=1024, step=1, value=args.width) + + with gr.Row(): + num_frames = gr.Slider(label="Number of Frames", minimum=8, maximum=256, step=1, value=args.num_frames) + guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=20, step=0.5, value=args.guidance_scale) + num_inference_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=100, step=1, value=args.num_inference_steps) + + with gr.Row(): + use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False) +>>>>>>> Stashed changes negative_prompt = gr.Text( label="Negative prompt", max_lines=1, placeholder="Enter a negative prompt", +<<<<<<< Updated upstream visible=False, ) @@ -177,6 +250,17 @@ def generate_video( gr.Examples(examples=examples, inputs=prompt) +======= + visible=False + ) + + seed = gr.Slider(label="Seed", minimum=0, maximum=1000000, step=1, value=args.seed) + randomize_seed = gr.Checkbox(label="Randomize seed", value=True) + seed_output = gr.Number(label="Used Seed") + + gr.Examples(examples=examples, inputs=prompt) + +>>>>>>> Stashed changes use_negative_prompt.change( fn=lambda x: gr.update(visible=x), inputs=use_negative_prompt, @@ -185,6 +269,7 @@ def generate_video( run_button.click( fn=generate_video, +<<<<<<< Updated upstream inputs=[ prompt, negative_prompt, @@ -202,3 +287,12 @@ def generate_video( if __name__ == "__main__": demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=7860) +======= + inputs=[prompt, negative_prompt, use_negative_prompt, seed, guidance_scale, + num_frames, height, width, num_inference_steps, randomize_seed], + outputs=[result, seed_output] + ) + +if __name__ == "__main__": + demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=7860) +>>>>>>> Stashed changes diff --git a/fastvideo/models/mochi_hf/pipeline_mochi.py b/fastvideo/models/mochi_hf/pipeline_mochi.py index 01e47478..3bb9eadb 100644 --- a/fastvideo/models/mochi_hf/pipeline_mochi.py +++ b/fastvideo/models/mochi_hf/pipeline_mochi.py @@ -36,7 +36,7 @@ from einops import rearrange from fastvideo.utils.parallel_states import get_sequence_parallel_state, nccl_info from fastvideo.utils.communications import all_gather - +from diffusers.loaders import Mochi1LoraLoaderMixin if is_torch_xla_available(): import torch_xla.core.xla_model as xm @@ -165,7 +165,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class MochiPipeline(DiffusionPipeline): +class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin): r""" The mochi pipeline for text-to-video generation. @@ -502,7 +502,8 @@ def prepare_latents( f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32) + latents = latents.to(dtype) return latents @property @@ -533,8 +534,8 @@ def __call__( negative_prompt: Optional[Union[str, List[str]]] = None, height: Optional[int] = None, width: Optional[int] = None, - num_frames: int = 16, - num_inference_steps: int = 28, + num_frames: int = 19, + num_inference_steps: int = 64, timesteps: List[int] = None, guidance_scale: float = 4.5, num_videos_per_prompt: Optional[int] = 1, From a3f8fc2e11c592e641d660cbabc6f6279b14e8ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CBrianChen1129=E2=80=9D?= Date: Thu, 19 Dec 2024 14:57:19 +0000 Subject: [PATCH 02/42] fix lora save issue --- scripts/finetune/finetune_mochi_lora.sh | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/scripts/finetune/finetune_mochi_lora.sh b/scripts/finetune/finetune_mochi_lora.sh index 26a745cd..2f5a39bd 100644 --- a/scripts/finetune/finetune_mochi_lora.sh +++ b/scripts/finetune/finetune_mochi_lora.sh @@ -9,16 +9,16 @@ torchrun --nnodes 1 --nproc_per_node 2 \ --data_json_path data/Mochi-Black-Myth/videos2caption.json \ --validation_prompt_dir data/Mochi-Black-Myth/validation \ --gradient_checkpointing \ - --train_batch_size=1 \ + --train_batch_size 1 \ --num_latent_t 14 \ --sp_size 2 \ --train_sp_batch_size 1 \ --dataloader_num_workers 1 \ - --gradient_accumulation_steps=2 \ - --max_train_steps=2000 \ - --learning_rate=5e-6 \ - --mixed_precision=bf16 \ - --checkpointing_steps=200 \ + --gradient_accumulation_steps 2 \ + --max_train_steps 2000 \ + --learning_rate 5e-6 \ + --mixed_precision bf16 \ + --checkpointing_steps 200 \ --validation_steps 100 \ --validation_sampling_steps 64 \ --checkpoints_total_limit 3 \ From 580575cce8fbb154b0cd0d740e0b3cebbc0d2b5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CBrianChen1129=E2=80=9D?= Date: Thu, 19 Dec 2024 14:58:16 +0000 Subject: [PATCH 03/42] fix lora save issue --- assets/prompt.txt | 12 -- demo/gradio_web_demo.py | 298 ---------------------------------------- 2 files changed, 310 deletions(-) delete mode 100644 assets/prompt.txt delete mode 100644 demo/gradio_web_demo.py diff --git a/assets/prompt.txt b/assets/prompt.txt deleted file mode 100644 index 3c87a8fd..00000000 --- a/assets/prompt.txt +++ /dev/null @@ -1,12 +0,0 @@ -A hand enters the frame, pulling a sheet of plastic wrap over three balls of dough placed on a wooden surface. The plastic wrap is stretched to cover the dough more securely. The hand adjusts the wrap, ensuring that it is tight and smooth over the dough. The scene focuses on the hand's movements as it secures the edges of the plastic wrap. No new objects appear, and the camera remains stationary, focusing on the action of covering the dough. -A vintage train snakes through the mountains, its plume of white steam rising dramatically against the jagged peaks. The cars glint in the late afternoon sun, their deep crimson and gold accents lending a touch of elegance. The tracks carve a precarious path along the cliffside, revealing glimpses of a roaring river far below. Inside, passengers peer out the large windows, their faces lit with awe as the landscape unfolds. -A crowded rooftop bar buzzes with energy, the city skyline twinkling like a field of stars in the background. Strings of fairy lights hang above, casting a warm, golden glow over the scene. 
Groups of people gather around high tables, their laughter blending with the soft rhythm of live jazz. The aroma of freshly mixed cocktails and charred appetizers wafts through the air, mingling with the cool night breeze. -En "The Matrix", Neo, interpretado por Keanu Reeves, personifica la lucha contra un sistema opresor a través de su icónica imagen, que incluye unos anteojos oscuros. Estos lentes no son solo un accesorio de moda; representan una barrera entre la realidad y la percepción. Al usar estos anteojos, Neo se sumerge en un mundo donde la verdad se oculta detrás de ilusiones y engaños. La oscuridad de los lentes simboliza la ignorancia y el control que las máquinas tienen sobre la humanidad, mientras que su propia búsqueda de la verdad lo lleva a descubrir sus auténticos poderes. La escena en que se los pone se convierte en un momento crucial, marcando su transformación de un simple programador a "El Elegido". Esta imagen se ha convertido en un ícono cultural, encapsulando el mensaje de que, al enfrentar la oscuridad, podemos encontrar la luz que nos guía hacia la libertad. Así, los anteojos de Neo se convierten en un símbolo de resistencia y autoconocimiento en un mundo manipulado. -Medium close up. Low-angle shot. A woman in a 1950s retro dress sits in a diner bathed in neon light, surrounded by classic decor and lively chatter. The camera starts with a medium shot of her sitting at the counter, then slowly zooms in as she blows a shiny pink bubblegum bubble. The bubble swells dramatically before popping with a soft, playful burst. The scene is vibrant and nostalgic, evoking the fun and carefree spirit of the 1950s. -Will Smith eats noodles. -A short clip of the blonde woman taking a sip from her whiskey glass, her eyes locking with the camera as she smirks playfully. The background shows a group of people laughing and enjoying the party, with vibrant neon signs illuminating the space. The shot is taken in a way that conveys the feeling of a tipsy, carefree night out. The camera then zooms in on her face as she winks, creating a cheeky, flirtatious vibe. -A superintelligent humanoid robot waking up. The robot has a sleek metallic body with futuristic design features. Its glowing red eyes are the focal point, emanating a sharp, intense light as it powers on. The scene is set in a dimly lit, high-tech laboratory filled with glowing control panels, robotic arms, and holographic screens. The setting emphasizes advanced technology and an atmosphere of mystery. The ambiance is eerie and dramatic, highlighting the moment of awakening and the robot's immense intelligence. Photorealistic style with a cinematic, dark sci-fi aesthetic. Aspect ratio: 16:9 --v 6.1 -A chimpanzee lead vocalist singing into a microphone on stage. The camera zooms in to show him singing. There is a spotlight on him. -A hand enters the frame, pulling a sheet of plastic wrap over three balls of dough placed on a wooden surface. The plastic wrap is stretched to cover the dough more securely. The hand adjusts the wrap, ensuring that it is tight and smooth over the dough. The scene focuses on the hand's movements as it secures the edges of the plastic wrap. No new objects appear, and the camera remains stationary, focusing on the action of covering the dough. -A vintage train snakes through the mountains, its plume of white steam rising dramatically against the jagged peaks. The cars glint in the late afternoon sun, their deep crimson and gold accents lending a touch of elegance. 
The tracks carve a precarious path along the cliffside, revealing glimpses of a roaring river far below. Inside, passengers peer out the large windows, their faces lit with awe as the landscape unfolds. -A superintelligent humanoid robot waking up. The robot has a sleek metallic body with futuristic design features. Its glowing red eyes are the focal point, emanating a sharp, intense light as it powers on. The scene is set in a dimly lit, high-tech laboratory filled with glowing control panels, robotic arms, and holographic screens. The setting emphasizes advanced technology and an atmosphere of mystery. The ambiance is eerie and dramatic, highlighting the moment of awakening and the robot's immense intelligence. Photorealistic style with a cinematic, dark sci-fi aesthetic. Aspect ratio: 16:9 --v 6.1 diff --git a/demo/gradio_web_demo.py b/demo/gradio_web_demo.py deleted file mode 100644 index 945938f4..00000000 --- a/demo/gradio_web_demo.py +++ /dev/null @@ -1,298 +0,0 @@ -import gradio as gr -import torch -from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline -from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformer3DModel -from fastvideo.models.mochi_genmo.mochi_preview.dit.joint_model.asymm_models_joint import AsymmDiTJoint -from fastvideo.models.mochi_genmo.mochi_preview.vae.models import Decoder, Encoder -from diffusers import FlowMatchEulerDiscreteScheduler -from diffusers.utils import export_to_video -from fastvideo.distill.solver import PCMFMScheduler -import tempfile -import os -import argparse -from safetensors.torch import load_file - -def init_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--prompts", nargs='+', default=[]) - parser.add_argument("--num_frames", type=int, default=139) - parser.add_argument("--height", type=int, default=480) - parser.add_argument("--width", type=int, default=848) - parser.add_argument("--num_inference_steps", type=int, default=64) - parser.add_argument("--guidance_scale", type=float, default=4.5) - parser.add_argument("--model_path", type=str, default="data/mochi") - parser.add_argument("--seed", type=int, default=42) - parser.add_argument("--transformer_path", type=str, default=None) - parser.add_argument("--scheduler_type", type=str, default="euler") - parser.add_argument("--lora_checkpoint_dir", type=str, default=None) - parser.add_argument("--shift", type=float, default=8.0) - parser.add_argument("--num_euler_timesteps", type=int, default=100) - parser.add_argument("--linear_threshold", type=float, default=0.025) - parser.add_argument("--linear_range", type=float, default=0.5) - parser.add_argument("--cpu_offload", action="store_true") - return parser.parse_args() - -def load_model(args): - device = "cuda" if torch.cuda.is_available() else "cpu" - if args.scheduler_type == "euler": - scheduler = FlowMatchEulerDiscreteScheduler() - else: - scheduler = PCMFMScheduler(1000, args.shift, args.num_euler_timesteps, False, args.linear_threshold, args.linear_range) - - mochi_genmo = False - if mochi_genmo: - vae_encoder_path = "/root/data/fastmochi_genmo/encoder.safetensors" - vae_decoder_path = "/root/data/fastmochi_genmo/decoder.safetensors" - dit_path = "/root/data/fastmochi_genmo/dit.safetensors" - - vae_encoder_state_dict = load_file(vae_encoder_path) - vae_decoder_state_dict = load_file(vae_decoder_path) - dit_state_dict = load_file(dit_path) - - vae_encoder = Encoder( - in_channels=15, - base_channels=64, - channel_multipliers=[1, 2, 4, 6], - num_res_blocks=[3, 3, 4, 6, 3], - latent_dim=12, - 
temporal_reductions=[1, 2, 3], - spatial_reductions=[2, 2, 2], - prune_bottlenecks=[False, False, False, False, False], - has_attentions=[False, True, True, True, True], - affine=True, - bias=True, - input_is_conv_1x1=True, - padding_mode="replicate", - ) - vae_decoder = Decoder( - out_channels=3, - base_channels=128, - channel_multipliers=[1, 2, 4, 6], - temporal_expansions=[1, 2, 3], - spatial_expansions=[2, 2, 2], - num_res_blocks=[3, 3, 4, 6, 3], - latent_dim=12, - has_attention=[False, False, False, False, False], - output_norm=False, - nonlinearity="silu", - output_nonlinearity="silu", - causal=True, - ) - transformer = AsymmDiTJoint() - - vae_encoder.load_state_dict(vae_encoder_state_dict) - vae_decoder.load_state_dict(vae_decoder_state_dict) - transformer.load_state_dict(dit_state_dict) - - transformer.config.in_channels = 12 - else: - if args.transformer_path: - transformer = MochiTransformer3DModel.from_pretrained(args.transformer_path) - else: - transformer = MochiTransformer3DModel.from_pretrained(args.model_path, subfolder='transformer/') - - pipe = MochiPipeline.from_pretrained(args.model_path, transformer=transformer, scheduler=scheduler) - # from IPython import embed - # embed() - # del pipe.vae.encoder - # del pipe.vae.decoder - # pipe.vae.encoder = vae_encoder - # pipe.vae.decoder = vae_decoder - - - pipe.enable_vae_tiling() - pipe.to(device) - if args.cpu_offload: - pipe.enable_model_cpu_offload() - return pipe - -def generate_video(prompt, negative_prompt, use_negative_prompt, seed, guidance_scale, - num_frames, height, width, num_inference_steps, randomize_seed=False): - if randomize_seed: - seed = torch.randint(0, 1000000, (1,)).item() - - pipe = load_model(args) - print("load model successfully") - generator = torch.Generator(device="cuda").manual_seed(seed) - - if not use_negative_prompt: - negative_prompt = None - - with torch.autocast("cuda", dtype=torch.bfloat16): - output = pipe( - prompt=[prompt], - negative_prompt=negative_prompt, - height=height, - width=width, - num_frames=num_frames, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - generator=generator, - ).frames[0] -<<<<<<< Updated upstream - -======= - ->>>>>>> Stashed changes - output_path = os.path.join(tempfile.mkdtemp(), "output.mp4") - export_to_video(output, output_path, fps=30) - return output_path, seed - -<<<<<<< Updated upstream - -======= ->>>>>>> Stashed changes -examples = [ - "A hand enters the frame, pulling a sheet of plastic wrap over three balls of dough placed on a wooden surface. The plastic wrap is stretched to cover the dough more securely. The hand adjusts the wrap, ensuring that it is tight and smooth over the dough. The scene focuses on the hand’s movements as it secures the edges of the plastic wrap. No new objects appear, and the camera remains stationary, focusing on the action of covering the dough.", - "A vintage train snakes through the mountains, its plume of white steam rising dramatically against the jagged peaks. The cars glint in the late afternoon sun, their deep crimson and gold accents lending a touch of elegance. The tracks carve a precarious path along the cliffside, revealing glimpses of a roaring river far below. Inside, passengers peer out the large windows, their faces lit with awe as the landscape unfolds.", - "A crowded rooftop bar buzzes with energy, the city skyline twinkling like a field of stars in the background. Strings of fairy lights hang above, casting a warm, golden glow over the scene. 
Groups of people gather around high tables, their laughter blending with the soft rhythm of live jazz. The aroma of freshly mixed cocktails and charred appetizers wafts through the air, mingling with the cool night breeze.", -<<<<<<< Updated upstream -] - -args = init_args() -pipe = load_model(args) -print("load model successfully") -with gr.Blocks() as demo: - gr.Markdown("# Fastvideo Mochi Video Generation Demo") - -======= - "Will Smith eats noodles.", - "A short clip of the blonde woman taking a sip from her whiskey glass, her eyes locking with the camera as she smirks playfully. The background shows a group of people laughing and enjoying the party, with vibrant neon signs illuminating the space. The shot is taken in a way that conveys the feeling of a tipsy, carefree night out. The camera then zooms in on her face as she winks, creating a cheeky, flirtatious vibe.", - "A superintelligent humanoid robot waking up. The robot has a sleek metallic body with futuristic design features. Its glowing red eyes are the focal point, emanating a sharp, intense light as it powers on. The scene is set in a dimly lit, high-tech laboratory filled with glowing control panels, robotic arms, and holographic screens. The setting emphasizes advanced technology and an atmosphere of mystery. The ambiance is eerie and dramatic, highlighting the moment of awakening and the robot's immense intelligence. Photorealistic style with a cinematic, dark sci-fi aesthetic. Aspect ratio: 16:9 --v 6.1" -] - -args = init_args() - -with gr.Blocks() as demo: - gr.Markdown("# Mochi Video Generation Demo") - ->>>>>>> Stashed changes - with gr.Group(): - with gr.Row(): - prompt = gr.Text( - label="Prompt", - show_label=False, - max_lines=1, - placeholder="Enter your prompt", - container=False, - ) - run_button = gr.Button("Run", scale=0) - result = gr.Video(label="Result", show_label=False) -<<<<<<< Updated upstream - - with gr.Accordion("Advanced options", open=False): - with gr.Group(): - with gr.Row(): - height = gr.Slider( - label="Height", - minimum=256, - maximum=1024, - step=32, - value=args.height, - ) - width = gr.Slider( - label="Width", minimum=256, maximum=1024, step=32, value=args.width - ) - - with gr.Row(): - num_frames = gr.Slider( - label="Number of Frames", - minimum=21, - maximum=163, - value=args.num_frames, - ) - guidance_scale = gr.Slider( - label="Guidance Scale", - minimum=1, - maximum=12, - value=args.guidance_scale, - ) - num_inference_steps = gr.Slider( - label="Inference Steps", - minimum=4, - maximum=100, - value=args.num_inference_steps, - ) - - with gr.Row(): - use_negative_prompt = gr.Checkbox( - label="Use negative prompt", value=False - ) -======= - - with gr.Accordion("Advanced options", open=False): - with gr.Group(): - with gr.Row(): - height = gr.Slider(label="Height", minimum=256, maximum=1024, step=1, value=args.height) - width = gr.Slider(label="Width", minimum=256, maximum=1024, step=1, value=args.width) - - with gr.Row(): - num_frames = gr.Slider(label="Number of Frames", minimum=8, maximum=256, step=1, value=args.num_frames) - guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=20, step=0.5, value=args.guidance_scale) - num_inference_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=100, step=1, value=args.num_inference_steps) - - with gr.Row(): - use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False) ->>>>>>> Stashed changes - negative_prompt = gr.Text( - label="Negative prompt", - max_lines=1, - placeholder="Enter a negative 
prompt", -<<<<<<< Updated upstream - visible=False, - ) - - seed = gr.Slider( - label="Seed", minimum=0, maximum=1000000, step=1, value=args.seed - ) - randomize_seed = gr.Checkbox(label="Randomize seed", value=True) - seed_output = gr.Number(label="Used Seed") - - gr.Examples(examples=examples, inputs=prompt) - -======= - visible=False - ) - - seed = gr.Slider(label="Seed", minimum=0, maximum=1000000, step=1, value=args.seed) - randomize_seed = gr.Checkbox(label="Randomize seed", value=True) - seed_output = gr.Number(label="Used Seed") - - gr.Examples(examples=examples, inputs=prompt) - ->>>>>>> Stashed changes - use_negative_prompt.change( - fn=lambda x: gr.update(visible=x), - inputs=use_negative_prompt, - outputs=negative_prompt, - ) - - run_button.click( - fn=generate_video, -<<<<<<< Updated upstream - inputs=[ - prompt, - negative_prompt, - use_negative_prompt, - seed, - guidance_scale, - num_frames, - height, - width, - num_inference_steps, - randomize_seed, - ], - outputs=[result, seed_output], - ) - -if __name__ == "__main__": - demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=7860) -======= - inputs=[prompt, negative_prompt, use_negative_prompt, seed, guidance_scale, - num_frames, height, width, num_inference_steps, randomize_seed], - outputs=[result, seed_output] - ) - -if __name__ == "__main__": - demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=7860) ->>>>>>> Stashed changes From 4c169916ec658ace3b27652cfe5da02dfc6d8615 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CBrianChen1129=E2=80=9D?= Date: Thu, 19 Dec 2024 15:01:55 +0000 Subject: [PATCH 04/42] Revert "fix lora save issue" This reverts commit 580575cce8fbb154b0cd0d740e0b3cebbc0d2b5c. --- assets/prompt.txt | 12 ++ demo/gradio_web_demo.py | 298 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 310 insertions(+) create mode 100644 assets/prompt.txt create mode 100644 demo/gradio_web_demo.py diff --git a/assets/prompt.txt b/assets/prompt.txt new file mode 100644 index 00000000..3c87a8fd --- /dev/null +++ b/assets/prompt.txt @@ -0,0 +1,12 @@ +A hand enters the frame, pulling a sheet of plastic wrap over three balls of dough placed on a wooden surface. The plastic wrap is stretched to cover the dough more securely. The hand adjusts the wrap, ensuring that it is tight and smooth over the dough. The scene focuses on the hand's movements as it secures the edges of the plastic wrap. No new objects appear, and the camera remains stationary, focusing on the action of covering the dough. +A vintage train snakes through the mountains, its plume of white steam rising dramatically against the jagged peaks. The cars glint in the late afternoon sun, their deep crimson and gold accents lending a touch of elegance. The tracks carve a precarious path along the cliffside, revealing glimpses of a roaring river far below. Inside, passengers peer out the large windows, their faces lit with awe as the landscape unfolds. +A crowded rooftop bar buzzes with energy, the city skyline twinkling like a field of stars in the background. Strings of fairy lights hang above, casting a warm, golden glow over the scene. Groups of people gather around high tables, their laughter blending with the soft rhythm of live jazz. The aroma of freshly mixed cocktails and charred appetizers wafts through the air, mingling with the cool night breeze. +En "The Matrix", Neo, interpretado por Keanu Reeves, personifica la lucha contra un sistema opresor a través de su icónica imagen, que incluye unos anteojos oscuros. 
Estos lentes no son solo un accesorio de moda; representan una barrera entre la realidad y la percepción. Al usar estos anteojos, Neo se sumerge en un mundo donde la verdad se oculta detrás de ilusiones y engaños. La oscuridad de los lentes simboliza la ignorancia y el control que las máquinas tienen sobre la humanidad, mientras que su propia búsqueda de la verdad lo lleva a descubrir sus auténticos poderes. La escena en que se los pone se convierte en un momento crucial, marcando su transformación de un simple programador a "El Elegido". Esta imagen se ha convertido en un ícono cultural, encapsulando el mensaje de que, al enfrentar la oscuridad, podemos encontrar la luz que nos guía hacia la libertad. Así, los anteojos de Neo se convierten en un símbolo de resistencia y autoconocimiento en un mundo manipulado. +Medium close up. Low-angle shot. A woman in a 1950s retro dress sits in a diner bathed in neon light, surrounded by classic decor and lively chatter. The camera starts with a medium shot of her sitting at the counter, then slowly zooms in as she blows a shiny pink bubblegum bubble. The bubble swells dramatically before popping with a soft, playful burst. The scene is vibrant and nostalgic, evoking the fun and carefree spirit of the 1950s. +Will Smith eats noodles. +A short clip of the blonde woman taking a sip from her whiskey glass, her eyes locking with the camera as she smirks playfully. The background shows a group of people laughing and enjoying the party, with vibrant neon signs illuminating the space. The shot is taken in a way that conveys the feeling of a tipsy, carefree night out. The camera then zooms in on her face as she winks, creating a cheeky, flirtatious vibe. +A superintelligent humanoid robot waking up. The robot has a sleek metallic body with futuristic design features. Its glowing red eyes are the focal point, emanating a sharp, intense light as it powers on. The scene is set in a dimly lit, high-tech laboratory filled with glowing control panels, robotic arms, and holographic screens. The setting emphasizes advanced technology and an atmosphere of mystery. The ambiance is eerie and dramatic, highlighting the moment of awakening and the robot's immense intelligence. Photorealistic style with a cinematic, dark sci-fi aesthetic. Aspect ratio: 16:9 --v 6.1 +A chimpanzee lead vocalist singing into a microphone on stage. The camera zooms in to show him singing. There is a spotlight on him. +A hand enters the frame, pulling a sheet of plastic wrap over three balls of dough placed on a wooden surface. The plastic wrap is stretched to cover the dough more securely. The hand adjusts the wrap, ensuring that it is tight and smooth over the dough. The scene focuses on the hand's movements as it secures the edges of the plastic wrap. No new objects appear, and the camera remains stationary, focusing on the action of covering the dough. +A vintage train snakes through the mountains, its plume of white steam rising dramatically against the jagged peaks. The cars glint in the late afternoon sun, their deep crimson and gold accents lending a touch of elegance. The tracks carve a precarious path along the cliffside, revealing glimpses of a roaring river far below. Inside, passengers peer out the large windows, their faces lit with awe as the landscape unfolds. +A superintelligent humanoid robot waking up. The robot has a sleek metallic body with futuristic design features. Its glowing red eyes are the focal point, emanating a sharp, intense light as it powers on. 
The scene is set in a dimly lit, high-tech laboratory filled with glowing control panels, robotic arms, and holographic screens. The setting emphasizes advanced technology and an atmosphere of mystery. The ambiance is eerie and dramatic, highlighting the moment of awakening and the robot's immense intelligence. Photorealistic style with a cinematic, dark sci-fi aesthetic. Aspect ratio: 16:9 --v 6.1 diff --git a/demo/gradio_web_demo.py b/demo/gradio_web_demo.py new file mode 100644 index 00000000..945938f4 --- /dev/null +++ b/demo/gradio_web_demo.py @@ -0,0 +1,298 @@ +import gradio as gr +import torch +from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline +from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformer3DModel +from fastvideo.models.mochi_genmo.mochi_preview.dit.joint_model.asymm_models_joint import AsymmDiTJoint +from fastvideo.models.mochi_genmo.mochi_preview.vae.models import Decoder, Encoder +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import export_to_video +from fastvideo.distill.solver import PCMFMScheduler +import tempfile +import os +import argparse +from safetensors.torch import load_file + +def init_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--prompts", nargs='+', default=[]) + parser.add_argument("--num_frames", type=int, default=139) + parser.add_argument("--height", type=int, default=480) + parser.add_argument("--width", type=int, default=848) + parser.add_argument("--num_inference_steps", type=int, default=64) + parser.add_argument("--guidance_scale", type=float, default=4.5) + parser.add_argument("--model_path", type=str, default="data/mochi") + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--transformer_path", type=str, default=None) + parser.add_argument("--scheduler_type", type=str, default="euler") + parser.add_argument("--lora_checkpoint_dir", type=str, default=None) + parser.add_argument("--shift", type=float, default=8.0) + parser.add_argument("--num_euler_timesteps", type=int, default=100) + parser.add_argument("--linear_threshold", type=float, default=0.025) + parser.add_argument("--linear_range", type=float, default=0.5) + parser.add_argument("--cpu_offload", action="store_true") + return parser.parse_args() + +def load_model(args): + device = "cuda" if torch.cuda.is_available() else "cpu" + if args.scheduler_type == "euler": + scheduler = FlowMatchEulerDiscreteScheduler() + else: + scheduler = PCMFMScheduler(1000, args.shift, args.num_euler_timesteps, False, args.linear_threshold, args.linear_range) + + mochi_genmo = False + if mochi_genmo: + vae_encoder_path = "/root/data/fastmochi_genmo/encoder.safetensors" + vae_decoder_path = "/root/data/fastmochi_genmo/decoder.safetensors" + dit_path = "/root/data/fastmochi_genmo/dit.safetensors" + + vae_encoder_state_dict = load_file(vae_encoder_path) + vae_decoder_state_dict = load_file(vae_decoder_path) + dit_state_dict = load_file(dit_path) + + vae_encoder = Encoder( + in_channels=15, + base_channels=64, + channel_multipliers=[1, 2, 4, 6], + num_res_blocks=[3, 3, 4, 6, 3], + latent_dim=12, + temporal_reductions=[1, 2, 3], + spatial_reductions=[2, 2, 2], + prune_bottlenecks=[False, False, False, False, False], + has_attentions=[False, True, True, True, True], + affine=True, + bias=True, + input_is_conv_1x1=True, + padding_mode="replicate", + ) + vae_decoder = Decoder( + out_channels=3, + base_channels=128, + channel_multipliers=[1, 2, 4, 6], + temporal_expansions=[1, 2, 3], + spatial_expansions=[2, 2, 
2], + num_res_blocks=[3, 3, 4, 6, 3], + latent_dim=12, + has_attention=[False, False, False, False, False], + output_norm=False, + nonlinearity="silu", + output_nonlinearity="silu", + causal=True, + ) + transformer = AsymmDiTJoint() + + vae_encoder.load_state_dict(vae_encoder_state_dict) + vae_decoder.load_state_dict(vae_decoder_state_dict) + transformer.load_state_dict(dit_state_dict) + + transformer.config.in_channels = 12 + else: + if args.transformer_path: + transformer = MochiTransformer3DModel.from_pretrained(args.transformer_path) + else: + transformer = MochiTransformer3DModel.from_pretrained(args.model_path, subfolder='transformer/') + + pipe = MochiPipeline.from_pretrained(args.model_path, transformer=transformer, scheduler=scheduler) + # from IPython import embed + # embed() + # del pipe.vae.encoder + # del pipe.vae.decoder + # pipe.vae.encoder = vae_encoder + # pipe.vae.decoder = vae_decoder + + + pipe.enable_vae_tiling() + pipe.to(device) + if args.cpu_offload: + pipe.enable_model_cpu_offload() + return pipe + +def generate_video(prompt, negative_prompt, use_negative_prompt, seed, guidance_scale, + num_frames, height, width, num_inference_steps, randomize_seed=False): + if randomize_seed: + seed = torch.randint(0, 1000000, (1,)).item() + + pipe = load_model(args) + print("load model successfully") + generator = torch.Generator(device="cuda").manual_seed(seed) + + if not use_negative_prompt: + negative_prompt = None + + with torch.autocast("cuda", dtype=torch.bfloat16): + output = pipe( + prompt=[prompt], + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=num_frames, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + generator=generator, + ).frames[0] +<<<<<<< Updated upstream + +======= + +>>>>>>> Stashed changes + output_path = os.path.join(tempfile.mkdtemp(), "output.mp4") + export_to_video(output, output_path, fps=30) + return output_path, seed + +<<<<<<< Updated upstream + +======= +>>>>>>> Stashed changes +examples = [ + "A hand enters the frame, pulling a sheet of plastic wrap over three balls of dough placed on a wooden surface. The plastic wrap is stretched to cover the dough more securely. The hand adjusts the wrap, ensuring that it is tight and smooth over the dough. The scene focuses on the hand’s movements as it secures the edges of the plastic wrap. No new objects appear, and the camera remains stationary, focusing on the action of covering the dough.", + "A vintage train snakes through the mountains, its plume of white steam rising dramatically against the jagged peaks. The cars glint in the late afternoon sun, their deep crimson and gold accents lending a touch of elegance. The tracks carve a precarious path along the cliffside, revealing glimpses of a roaring river far below. Inside, passengers peer out the large windows, their faces lit with awe as the landscape unfolds.", + "A crowded rooftop bar buzzes with energy, the city skyline twinkling like a field of stars in the background. Strings of fairy lights hang above, casting a warm, golden glow over the scene. Groups of people gather around high tables, their laughter blending with the soft rhythm of live jazz. 
The aroma of freshly mixed cocktails and charred appetizers wafts through the air, mingling with the cool night breeze.", +<<<<<<< Updated upstream +] + +args = init_args() +pipe = load_model(args) +print("load model successfully") +with gr.Blocks() as demo: + gr.Markdown("# Fastvideo Mochi Video Generation Demo") + +======= + "Will Smith eats noodles.", + "A short clip of the blonde woman taking a sip from her whiskey glass, her eyes locking with the camera as she smirks playfully. The background shows a group of people laughing and enjoying the party, with vibrant neon signs illuminating the space. The shot is taken in a way that conveys the feeling of a tipsy, carefree night out. The camera then zooms in on her face as she winks, creating a cheeky, flirtatious vibe.", + "A superintelligent humanoid robot waking up. The robot has a sleek metallic body with futuristic design features. Its glowing red eyes are the focal point, emanating a sharp, intense light as it powers on. The scene is set in a dimly lit, high-tech laboratory filled with glowing control panels, robotic arms, and holographic screens. The setting emphasizes advanced technology and an atmosphere of mystery. The ambiance is eerie and dramatic, highlighting the moment of awakening and the robot's immense intelligence. Photorealistic style with a cinematic, dark sci-fi aesthetic. Aspect ratio: 16:9 --v 6.1" +] + +args = init_args() + +with gr.Blocks() as demo: + gr.Markdown("# Mochi Video Generation Demo") + +>>>>>>> Stashed changes + with gr.Group(): + with gr.Row(): + prompt = gr.Text( + label="Prompt", + show_label=False, + max_lines=1, + placeholder="Enter your prompt", + container=False, + ) + run_button = gr.Button("Run", scale=0) + result = gr.Video(label="Result", show_label=False) +<<<<<<< Updated upstream + + with gr.Accordion("Advanced options", open=False): + with gr.Group(): + with gr.Row(): + height = gr.Slider( + label="Height", + minimum=256, + maximum=1024, + step=32, + value=args.height, + ) + width = gr.Slider( + label="Width", minimum=256, maximum=1024, step=32, value=args.width + ) + + with gr.Row(): + num_frames = gr.Slider( + label="Number of Frames", + minimum=21, + maximum=163, + value=args.num_frames, + ) + guidance_scale = gr.Slider( + label="Guidance Scale", + minimum=1, + maximum=12, + value=args.guidance_scale, + ) + num_inference_steps = gr.Slider( + label="Inference Steps", + minimum=4, + maximum=100, + value=args.num_inference_steps, + ) + + with gr.Row(): + use_negative_prompt = gr.Checkbox( + label="Use negative prompt", value=False + ) +======= + + with gr.Accordion("Advanced options", open=False): + with gr.Group(): + with gr.Row(): + height = gr.Slider(label="Height", minimum=256, maximum=1024, step=1, value=args.height) + width = gr.Slider(label="Width", minimum=256, maximum=1024, step=1, value=args.width) + + with gr.Row(): + num_frames = gr.Slider(label="Number of Frames", minimum=8, maximum=256, step=1, value=args.num_frames) + guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=20, step=0.5, value=args.guidance_scale) + num_inference_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=100, step=1, value=args.num_inference_steps) + + with gr.Row(): + use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False) +>>>>>>> Stashed changes + negative_prompt = gr.Text( + label="Negative prompt", + max_lines=1, + placeholder="Enter a negative prompt", +<<<<<<< Updated upstream + visible=False, + ) + + seed = gr.Slider( + label="Seed", minimum=0, 
maximum=1000000, step=1, value=args.seed + ) + randomize_seed = gr.Checkbox(label="Randomize seed", value=True) + seed_output = gr.Number(label="Used Seed") + + gr.Examples(examples=examples, inputs=prompt) + +======= + visible=False + ) + + seed = gr.Slider(label="Seed", minimum=0, maximum=1000000, step=1, value=args.seed) + randomize_seed = gr.Checkbox(label="Randomize seed", value=True) + seed_output = gr.Number(label="Used Seed") + + gr.Examples(examples=examples, inputs=prompt) + +>>>>>>> Stashed changes + use_negative_prompt.change( + fn=lambda x: gr.update(visible=x), + inputs=use_negative_prompt, + outputs=negative_prompt, + ) + + run_button.click( + fn=generate_video, +<<<<<<< Updated upstream + inputs=[ + prompt, + negative_prompt, + use_negative_prompt, + seed, + guidance_scale, + num_frames, + height, + width, + num_inference_steps, + randomize_seed, + ], + outputs=[result, seed_output], + ) + +if __name__ == "__main__": + demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=7860) +======= + inputs=[prompt, negative_prompt, use_negative_prompt, seed, guidance_scale, + num_frames, height, width, num_inference_steps, randomize_seed], + outputs=[result, seed_output] + ) + +if __name__ == "__main__": + demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=7860) +>>>>>>> Stashed changes From 677efd9ccc4f4daa7453c5a1fb4ba9d714ffc9a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CBrianChen1129=E2=80=9D?= Date: Thu, 19 Dec 2024 15:04:41 +0000 Subject: [PATCH 05/42] fix lora save issue --- assets/prompt.txt | 20 ++--- demo/gradio_web_demo.py | 190 ++++++++++------------------------------ 2 files changed, 56 insertions(+), 154 deletions(-) diff --git a/assets/prompt.txt b/assets/prompt.txt index 3c87a8fd..d10865e3 100644 --- a/assets/prompt.txt +++ b/assets/prompt.txt @@ -1,12 +1,8 @@ -A hand enters the frame, pulling a sheet of plastic wrap over three balls of dough placed on a wooden surface. The plastic wrap is stretched to cover the dough more securely. The hand adjusts the wrap, ensuring that it is tight and smooth over the dough. The scene focuses on the hand's movements as it secures the edges of the plastic wrap. No new objects appear, and the camera remains stationary, focusing on the action of covering the dough. -A vintage train snakes through the mountains, its plume of white steam rising dramatically against the jagged peaks. The cars glint in the late afternoon sun, their deep crimson and gold accents lending a touch of elegance. The tracks carve a precarious path along the cliffside, revealing glimpses of a roaring river far below. Inside, passengers peer out the large windows, their faces lit with awe as the landscape unfolds. -A crowded rooftop bar buzzes with energy, the city skyline twinkling like a field of stars in the background. Strings of fairy lights hang above, casting a warm, golden glow over the scene. Groups of people gather around high tables, their laughter blending with the soft rhythm of live jazz. The aroma of freshly mixed cocktails and charred appetizers wafts through the air, mingling with the cool night breeze. -En "The Matrix", Neo, interpretado por Keanu Reeves, personifica la lucha contra un sistema opresor a través de su icónica imagen, que incluye unos anteojos oscuros. Estos lentes no son solo un accesorio de moda; representan una barrera entre la realidad y la percepción. Al usar estos anteojos, Neo se sumerge en un mundo donde la verdad se oculta detrás de ilusiones y engaños. 
La oscuridad de los lentes simboliza la ignorancia y el control que las máquinas tienen sobre la humanidad, mientras que su propia búsqueda de la verdad lo lleva a descubrir sus auténticos poderes. La escena en que se los pone se convierte en un momento crucial, marcando su transformación de un simple programador a "El Elegido". Esta imagen se ha convertido en un ícono cultural, encapsulando el mensaje de que, al enfrentar la oscuridad, podemos encontrar la luz que nos guía hacia la libertad. Así, los anteojos de Neo se convierten en un símbolo de resistencia y autoconocimiento en un mundo manipulado. -Medium close up. Low-angle shot. A woman in a 1950s retro dress sits in a diner bathed in neon light, surrounded by classic decor and lively chatter. The camera starts with a medium shot of her sitting at the counter, then slowly zooms in as she blows a shiny pink bubblegum bubble. The bubble swells dramatically before popping with a soft, playful burst. The scene is vibrant and nostalgic, evoking the fun and carefree spirit of the 1950s. -Will Smith eats noodles. -A short clip of the blonde woman taking a sip from her whiskey glass, her eyes locking with the camera as she smirks playfully. The background shows a group of people laughing and enjoying the party, with vibrant neon signs illuminating the space. The shot is taken in a way that conveys the feeling of a tipsy, carefree night out. The camera then zooms in on her face as she winks, creating a cheeky, flirtatious vibe. -A superintelligent humanoid robot waking up. The robot has a sleek metallic body with futuristic design features. Its glowing red eyes are the focal point, emanating a sharp, intense light as it powers on. The scene is set in a dimly lit, high-tech laboratory filled with glowing control panels, robotic arms, and holographic screens. The setting emphasizes advanced technology and an atmosphere of mystery. The ambiance is eerie and dramatic, highlighting the moment of awakening and the robot's immense intelligence. Photorealistic style with a cinematic, dark sci-fi aesthetic. Aspect ratio: 16:9 --v 6.1 -A chimpanzee lead vocalist singing into a microphone on stage. The camera zooms in to show him singing. There is a spotlight on him. -A hand enters the frame, pulling a sheet of plastic wrap over three balls of dough placed on a wooden surface. The plastic wrap is stretched to cover the dough more securely. The hand adjusts the wrap, ensuring that it is tight and smooth over the dough. The scene focuses on the hand's movements as it secures the edges of the plastic wrap. No new objects appear, and the camera remains stationary, focusing on the action of covering the dough. -A vintage train snakes through the mountains, its plume of white steam rising dramatically against the jagged peaks. The cars glint in the late afternoon sun, their deep crimson and gold accents lending a touch of elegance. The tracks carve a precarious path along the cliffside, revealing glimpses of a roaring river far below. Inside, passengers peer out the large windows, their faces lit with awe as the landscape unfolds. -A superintelligent humanoid robot waking up. The robot has a sleek metallic body with futuristic design features. Its glowing red eyes are the focal point, emanating a sharp, intense light as it powers on. The scene is set in a dimly lit, high-tech laboratory filled with glowing control panels, robotic arms, and holographic screens. The setting emphasizes advanced technology and an atmosphere of mystery. 
The ambiance is eerie and dramatic, highlighting the moment of awakening and the robot's immense intelligence. Photorealistic style with a cinematic, dark sci-fi aesthetic. Aspect ratio: 16:9 --v 6.1 +Will Smith casually eats noodles, his relaxed demeanor contrasting with the energetic background of a bustling street food market. The scene captures a mix of humor and authenticity. Mid-shot framing, vibrant lighting. +A lone hiker stands atop a towering cliff, silhouetted against the vast horizon. The rugged landscape stretches endlessly beneath, its earthy tones blending into the soft blues of the sky. The scene captures the spirit of exploration and human resilience. High angle, dynamic framing, with soft natural lighting emphasizing the grandeur of nature. +A hand with delicate fingers picks up a bright yellow lemon from a wooden bowl filled with lemons and sprigs of mint against a peach-colored background. The hand gently tosses the lemon up and catches it, showcasing its smooth texture. A beige string bag sits beside the bowl, adding a rustic touch to the scene. Additional lemons, one halved, are scattered around the base of the bowl. The even lighting enhances the vibrant colors and creates a fresh, inviting atmosphere. +A curious raccoon peers through a vibrant field of yellow sunflowers, its eyes wide with interest. The playful yet serene atmosphere is complemented by soft natural light filtering through the petals. Mid-shot, warm and cheerful tones. +A superintelligent humanoid robot waking up. The robot has a sleek metallic body with futuristic design features. Its glowing red eyes are the focal point, emanating a sharp, intense light as it powers on. The scene is set in a dimly lit, high-tech laboratory filled with glowing control panels, robotic arms, and holographic screens. The setting emphasizes advanced technology and an atmosphere of mystery. The ambiance is eerie and dramatic, highlighting the moment of awakening and the robots immense intelligence. Photorealistic style with a cinematic, dark sci-fi aesthetic. Aspect ratio: 16:9 --v 6.1 +fox in the forest close-up quickly turned its head to the left +Man walking his dog in the woods on a hot sunny day +A majestic lion strides across the golden savanna, its powerful frame glistening under the warm afternoon sun. The tall grass ripples gently in the breeze, enhancing the lion's commanding presence. The tone is vibrant, embodying the raw energy of the wild. Low angle, steady tracking shot, cinematic. 
\ No newline at end of file diff --git a/demo/gradio_web_demo.py b/demo/gradio_web_demo.py index 945938f4..1d4b5962 100644 --- a/demo/gradio_web_demo.py +++ b/demo/gradio_web_demo.py @@ -2,122 +2,87 @@ import torch from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformer3DModel -from fastvideo.models.mochi_genmo.mochi_preview.dit.joint_model.asymm_models_joint import AsymmDiTJoint -from fastvideo.models.mochi_genmo.mochi_preview.vae.models import Decoder, Encoder from diffusers import FlowMatchEulerDiscreteScheduler from diffusers.utils import export_to_video from fastvideo.distill.solver import PCMFMScheduler import tempfile import os import argparse -from safetensors.torch import load_file + def init_args(): parser = argparse.ArgumentParser() - parser.add_argument("--prompts", nargs='+', default=[]) - parser.add_argument("--num_frames", type=int, default=139) + parser.add_argument("--prompts", nargs="+", default=[]) + parser.add_argument("--num_frames", type=int, default=25) parser.add_argument("--height", type=int, default=480) parser.add_argument("--width", type=int, default=848) - parser.add_argument("--num_inference_steps", type=int, default=64) + parser.add_argument("--num_inference_steps", type=int, default=8) parser.add_argument("--guidance_scale", type=float, default=4.5) parser.add_argument("--model_path", type=str, default="data/mochi") - parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--seed", type=int, default=12345) parser.add_argument("--transformer_path", type=str, default=None) - parser.add_argument("--scheduler_type", type=str, default="euler") + parser.add_argument("--scheduler_type", type=str, default="pcm_linear_quadratic") parser.add_argument("--lora_checkpoint_dir", type=str, default=None) parser.add_argument("--shift", type=float, default=8.0) - parser.add_argument("--num_euler_timesteps", type=int, default=100) - parser.add_argument("--linear_threshold", type=float, default=0.025) - parser.add_argument("--linear_range", type=float, default=0.5) + parser.add_argument("--num_euler_timesteps", type=int, default=50) + parser.add_argument("--linear_threshold", type=float, default=0.1) + parser.add_argument("--linear_range", type=float, default=0.75) parser.add_argument("--cpu_offload", action="store_true") return parser.parse_args() + def load_model(args): device = "cuda" if torch.cuda.is_available() else "cpu" if args.scheduler_type == "euler": scheduler = FlowMatchEulerDiscreteScheduler() else: - scheduler = PCMFMScheduler(1000, args.shift, args.num_euler_timesteps, False, args.linear_threshold, args.linear_range) - - mochi_genmo = False - if mochi_genmo: - vae_encoder_path = "/root/data/fastmochi_genmo/encoder.safetensors" - vae_decoder_path = "/root/data/fastmochi_genmo/decoder.safetensors" - dit_path = "/root/data/fastmochi_genmo/dit.safetensors" - - vae_encoder_state_dict = load_file(vae_encoder_path) - vae_decoder_state_dict = load_file(vae_decoder_path) - dit_state_dict = load_file(dit_path) - - vae_encoder = Encoder( - in_channels=15, - base_channels=64, - channel_multipliers=[1, 2, 4, 6], - num_res_blocks=[3, 3, 4, 6, 3], - latent_dim=12, - temporal_reductions=[1, 2, 3], - spatial_reductions=[2, 2, 2], - prune_bottlenecks=[False, False, False, False, False], - has_attentions=[False, True, True, True, True], - affine=True, - bias=True, - input_is_conv_1x1=True, - padding_mode="replicate", + linear_quadratic = True if "linear_quadratic" in 
args.scheduler_type else False + scheduler = PCMFMScheduler( + 1000, + args.shift, + args.num_euler_timesteps, + linear_quadratic, + args.linear_threshold, + args.linear_range, ) - vae_decoder = Decoder( - out_channels=3, - base_channels=128, - channel_multipliers=[1, 2, 4, 6], - temporal_expansions=[1, 2, 3], - spatial_expansions=[2, 2, 2], - num_res_blocks=[3, 3, 4, 6, 3], - latent_dim=12, - has_attention=[False, False, False, False, False], - output_norm=False, - nonlinearity="silu", - output_nonlinearity="silu", - causal=True, - ) - transformer = AsymmDiTJoint() - - vae_encoder.load_state_dict(vae_encoder_state_dict) - vae_decoder.load_state_dict(vae_decoder_state_dict) - transformer.load_state_dict(dit_state_dict) - - transformer.config.in_channels = 12 + + if args.transformer_path: + transformer = MochiTransformer3DModel.from_pretrained(args.transformer_path) else: - if args.transformer_path: - transformer = MochiTransformer3DModel.from_pretrained(args.transformer_path) - else: - transformer = MochiTransformer3DModel.from_pretrained(args.model_path, subfolder='transformer/') - - pipe = MochiPipeline.from_pretrained(args.model_path, transformer=transformer, scheduler=scheduler) - # from IPython import embed - # embed() - # del pipe.vae.encoder - # del pipe.vae.decoder - # pipe.vae.encoder = vae_encoder - # pipe.vae.decoder = vae_decoder - - + transformer = MochiTransformer3DModel.from_pretrained( + args.model_path, subfolder="transformer/" + ) + + pipe = MochiPipeline.from_pretrained( + args.model_path, transformer=transformer, scheduler=scheduler + ) pipe.enable_vae_tiling() - pipe.to(device) - if args.cpu_offload: - pipe.enable_model_cpu_offload() + # pipe.to(device) + # if args.cpu_offload: + pipe.enable_sequential_cpu_offload() return pipe -def generate_video(prompt, negative_prompt, use_negative_prompt, seed, guidance_scale, - num_frames, height, width, num_inference_steps, randomize_seed=False): + +def generate_video( + prompt, + negative_prompt, + use_negative_prompt, + seed, + guidance_scale, + num_frames, + height, + width, + num_inference_steps, + randomize_seed=False, +): if randomize_seed: seed = torch.randint(0, 1000000, (1,)).item() - - pipe = load_model(args) - print("load model successfully") + generator = torch.Generator(device="cuda").manual_seed(seed) - + if not use_negative_prompt: negative_prompt = None - + with torch.autocast("cuda", dtype=torch.bfloat16): output = pipe( prompt=[prompt], @@ -129,24 +94,16 @@ def generate_video(prompt, negative_prompt, use_negative_prompt, seed, guidance_ guidance_scale=guidance_scale, generator=generator, ).frames[0] -<<<<<<< Updated upstream -======= - ->>>>>>> Stashed changes output_path = os.path.join(tempfile.mkdtemp(), "output.mp4") export_to_video(output, output_path, fps=30) return output_path, seed -<<<<<<< Updated upstream -======= ->>>>>>> Stashed changes examples = [ "A hand enters the frame, pulling a sheet of plastic wrap over three balls of dough placed on a wooden surface. The plastic wrap is stretched to cover the dough more securely. The hand adjusts the wrap, ensuring that it is tight and smooth over the dough. The scene focuses on the hand’s movements as it secures the edges of the plastic wrap. No new objects appear, and the camera remains stationary, focusing on the action of covering the dough.", "A vintage train snakes through the mountains, its plume of white steam rising dramatically against the jagged peaks. 
The cars glint in the late afternoon sun, their deep crimson and gold accents lending a touch of elegance. The tracks carve a precarious path along the cliffside, revealing glimpses of a roaring river far below. Inside, passengers peer out the large windows, their faces lit with awe as the landscape unfolds.", "A crowded rooftop bar buzzes with energy, the city skyline twinkling like a field of stars in the background. Strings of fairy lights hang above, casting a warm, golden glow over the scene. Groups of people gather around high tables, their laughter blending with the soft rhythm of live jazz. The aroma of freshly mixed cocktails and charred appetizers wafts through the air, mingling with the cool night breeze.", -<<<<<<< Updated upstream ] args = init_args() @@ -155,18 +112,6 @@ def generate_video(prompt, negative_prompt, use_negative_prompt, seed, guidance_ with gr.Blocks() as demo: gr.Markdown("# Fastvideo Mochi Video Generation Demo") -======= - "Will Smith eats noodles.", - "A short clip of the blonde woman taking a sip from her whiskey glass, her eyes locking with the camera as she smirks playfully. The background shows a group of people laughing and enjoying the party, with vibrant neon signs illuminating the space. The shot is taken in a way that conveys the feeling of a tipsy, carefree night out. The camera then zooms in on her face as she winks, creating a cheeky, flirtatious vibe.", - "A superintelligent humanoid robot waking up. The robot has a sleek metallic body with futuristic design features. Its glowing red eyes are the focal point, emanating a sharp, intense light as it powers on. The scene is set in a dimly lit, high-tech laboratory filled with glowing control panels, robotic arms, and holographic screens. The setting emphasizes advanced technology and an atmosphere of mystery. The ambiance is eerie and dramatic, highlighting the moment of awakening and the robot's immense intelligence. Photorealistic style with a cinematic, dark sci-fi aesthetic. 
Aspect ratio: 16:9 --v 6.1" -] - -args = init_args() - -with gr.Blocks() as demo: - gr.Markdown("# Mochi Video Generation Demo") - ->>>>>>> Stashed changes with gr.Group(): with gr.Row(): prompt = gr.Text( @@ -178,7 +123,6 @@ def generate_video(prompt, negative_prompt, use_negative_prompt, seed, guidance_ ) run_button = gr.Button("Run", scale=0) result = gr.Video(label="Result", show_label=False) -<<<<<<< Updated upstream with gr.Accordion("Advanced options", open=False): with gr.Group(): @@ -218,27 +162,10 @@ def generate_video(prompt, negative_prompt, use_negative_prompt, seed, guidance_ use_negative_prompt = gr.Checkbox( label="Use negative prompt", value=False ) -======= - - with gr.Accordion("Advanced options", open=False): - with gr.Group(): - with gr.Row(): - height = gr.Slider(label="Height", minimum=256, maximum=1024, step=1, value=args.height) - width = gr.Slider(label="Width", minimum=256, maximum=1024, step=1, value=args.width) - - with gr.Row(): - num_frames = gr.Slider(label="Number of Frames", minimum=8, maximum=256, step=1, value=args.num_frames) - guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=20, step=0.5, value=args.guidance_scale) - num_inference_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=100, step=1, value=args.num_inference_steps) - - with gr.Row(): - use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False) ->>>>>>> Stashed changes negative_prompt = gr.Text( label="Negative prompt", max_lines=1, placeholder="Enter a negative prompt", -<<<<<<< Updated upstream visible=False, ) @@ -250,17 +177,6 @@ def generate_video(prompt, negative_prompt, use_negative_prompt, seed, guidance_ gr.Examples(examples=examples, inputs=prompt) -======= - visible=False - ) - - seed = gr.Slider(label="Seed", minimum=0, maximum=1000000, step=1, value=args.seed) - randomize_seed = gr.Checkbox(label="Randomize seed", value=True) - seed_output = gr.Number(label="Used Seed") - - gr.Examples(examples=examples, inputs=prompt) - ->>>>>>> Stashed changes use_negative_prompt.change( fn=lambda x: gr.update(visible=x), inputs=use_negative_prompt, @@ -269,7 +185,6 @@ def generate_video(prompt, negative_prompt, use_negative_prompt, seed, guidance_ run_button.click( fn=generate_video, -<<<<<<< Updated upstream inputs=[ prompt, negative_prompt, @@ -286,13 +201,4 @@ def generate_video(prompt, negative_prompt, use_negative_prompt, seed, guidance_ ) if __name__ == "__main__": - demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=7860) -======= - inputs=[prompt, negative_prompt, use_negative_prompt, seed, guidance_scale, - num_frames, height, width, num_inference_steps, randomize_seed], - outputs=[result, seed_output] - ) - -if __name__ == "__main__": - demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=7860) ->>>>>>> Stashed changes + demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=7860) \ No newline at end of file From c31f50762d7fcd63910881095de2a1822a15b2b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CBrianChen1129=E2=80=9D?= Date: Mon, 23 Dec 2024 20:43:32 +0000 Subject: [PATCH 06/42] debug hunyuan hf sp --- assets/prompt_test.txt | 1 + fastvideo/models/flash_attn_no_pad.py | 2 +- fastvideo/models/hunyuan/modules/models.py | 1 - .../models/hunyuan_hf/modeling_hunyuan.py | 969 ++++++++++++++++++ .../models/hunyuan_hf/pipeline_hunyuan.py | 698 +++++++++++++ fastvideo/models/mochi_hf/modeling_mochi.py | 80 +- fastvideo/sample/sample_t2v_hunyuan_hf.py | 169 +++ fastvideo/train.py | 16 +- 
fastvideo/train_hunyuan_hf.py | 746 ++++++++++++++ fastvideo/utils/checkpoint.py | 4 +- prompts.txt | 4 + scripts/finetune/finetune_hunyuan.sh | 2 +- scripts/inference/inference_hunyuan.sh | 10 +- scripts/inference/inference_hunyuan_hf.sh | 21 + scripts/inference/inference_mochi_sp.sh | 7 +- 15 files changed, 2684 insertions(+), 46 deletions(-) create mode 100644 assets/prompt_test.txt create mode 100644 fastvideo/models/hunyuan_hf/modeling_hunyuan.py create mode 100644 fastvideo/models/hunyuan_hf/pipeline_hunyuan.py create mode 100644 fastvideo/sample/sample_t2v_hunyuan_hf.py create mode 100644 fastvideo/train_hunyuan_hf.py create mode 100644 prompts.txt create mode 100644 scripts/inference/inference_hunyuan_hf.sh diff --git a/assets/prompt_test.txt b/assets/prompt_test.txt new file mode 100644 index 00000000..11e03e47 --- /dev/null +++ b/assets/prompt_test.txt @@ -0,0 +1 @@ +A majestic lion strides across the golden savanna, its powerful frame glistening under the warm afternoon sun. The tall grass ripples gently in the breeze, enhancing the lion's commanding presence. The tone is vibrant, embodying the raw energy of the wild. Low angle, steady tracking shot, cinematic. diff --git a/fastvideo/models/flash_attn_no_pad.py b/fastvideo/models/flash_attn_no_pad.py index ff917e22..18cf5444 100644 --- a/fastvideo/models/flash_attn_no_pad.py +++ b/fastvideo/models/flash_attn_no_pad.py @@ -14,7 +14,7 @@ def flash_attn_no_pad( x_unpad, indices, cu_seqlens, max_s, used_seqlens_in_batch = unpad_input( x, key_padding_mask ) - + x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads) output_unpad = flash_attn_varlen_qkvpacked_func( x_unpad, diff --git a/fastvideo/models/hunyuan/modules/models.py b/fastvideo/models/hunyuan/modules/models.py index f9e5dcb1..d60448d2 100644 --- a/fastvideo/models/hunyuan/modules/models.py +++ b/fastvideo/models/hunyuan/modules/models.py @@ -20,7 +20,6 @@ from fastvideo.utils.parallel_states import nccl_info - class MMDoubleStreamBlock(nn.Module): """ A multimodal dit block with seperate modulation for diff --git a/fastvideo/models/hunyuan_hf/modeling_hunyuan.py b/fastvideo/models/hunyuan_hf/modeling_hunyuan.py new file mode 100644 index 00000000..46d27e3c --- /dev/null +++ b/fastvideo/models/hunyuan_hf/modeling_hunyuan.py @@ -0,0 +1,969 @@ +# Copyright 2024 The Hunyuan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
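For reference, the flash_attn_no_pad helper touched in the hunk above packs query/key/value into one tensor, unpads it with the key-padding mask, and dispatches to flash-attn's varlen QKV-packed kernel. A minimal usage sketch matching the calling convention used later in this patch (tensor sizes and dtype are illustrative assumptions; a CUDA build of flash-attn is required):

    import torch
    from fastvideo.models.flash_attn_no_pad import flash_attn_no_pad

    batch, seq_len, heads, head_dim = 1, 512, 24, 128
    # Packed QKV laid out as (batch, seq, 3, heads, head_dim), half precision on GPU.
    qkv = torch.randn(batch, seq_len, 3, heads, head_dim, device="cuda", dtype=torch.bfloat16)
    # Boolean key-padding mask: True marks valid tokens, False marks padding.
    mask = torch.ones(batch, seq_len, device="cuda", dtype=torch.bool)
    mask[:, -64:] = False  # pretend the last 64 positions are padding
    out = flash_attn_no_pad(qkv, mask, causal=False, dropout_p=0.0, softmax_scale=None)
    print(out.shape)  # expected: (batch, seq_len, heads, head_dim)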
+ +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import PeftAdapterMixin +from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from diffusers.models.attention import FeedForward +from diffusers.models.attention_processor import Attention, AttentionProcessor +from diffusers.models.embeddings import ( + CombinedTimestepGuidanceTextProjEmbeddings, + CombinedTimestepTextProjEmbeddings, + get_1d_rotary_pos_embed, +) +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle +from fastvideo.utils.parallel_states import get_sequence_parallel_state, nccl_info +from fastvideo.utils.communications import all_gather, all_to_all_4D +from fastvideo.models.flash_attn_no_pad import flash_attn_no_pad +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +import sys +import pdb +class ForkedPdb(pdb.Pdb): + """A Pdb subclass that may be used + from a forked multiprocessing child + + """ + def interaction(self, *args, **kwargs): + _stdin = sys.stdin + try: + sys.stdin = open('/dev/stdin') + pdb.Pdb.interaction(self, *args, **kwargs) + finally: + sys.stdin = _stdin +def apply_rotary_emb( + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + use_real: bool = True, + use_real_unbind_dim: int = -1, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
+ """ + x = x.transpose(1,2) + if use_real: + cos, sin = freqs_cis # [S, D] + cos = cos[None, None] + sin = sin[None, None] + cos, sin = cos.to(x.device), sin.to(x.device) + + if use_real_unbind_dim == -1: + # Used for flux, cogvideox, hunyuan-dit + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + # Used for Stable Audio + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out.transpose(1,2) + else: + # used for lumina + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + + return x_out.type_as(x).transpose(1,2) +class HunyuanVideoAttnProcessor2_0: + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "HunyuanVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, #[1, 38160, 3072] [1, 19080, 3072] + encoder_hidden_states: Optional[torch.Tensor] = None, # [1, 256, 3072] [1, 256, 3072] + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, # [38160, 128] [38106, 128] + ) -> torch.Tensor: + + if attn.add_q_proj is None and encoder_hidden_states is not None: + sequence_length = hidden_states.size(1) # 19080 + encoder_sequence_length = encoder_hidden_states.size(1) # 256 + hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) # [1, 19336, 3072] + + # 1. QKV projections + query = attn.to_q(hidden_states) # [1, 19080, 3072] + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(2, (attn.heads, -1)) # [1, 19080, 24, 128] + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + # 2. 
QK normalization + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # from IPython import embed + # embed() + def shrink_head(x, dim): + local_heads = x.shape[dim] // nccl_info.sp_size + return x.narrow(dim, nccl_info.rank_within_group * local_heads, local_heads) + + if get_sequence_parallel_state(): + # Handle sequence parallelism for main hidden states + # Note: We scatter on heads dim (1) and gather on sequence dim (2) since tensors are transposed + if attn.add_q_proj is None: + qkv_ = [query, key, value] + + qk_h, qk_eh = [], [] + for item in qkv_: + qk_h.append(item[:,:sequence_length,:,:]) + qk_eh.append(item[:,sequence_length:,:,:]) + for i in range(len(qkv_)): + qk_h[i] = all_to_all_4D(qk_h[i], scatter_dim=2, gather_dim=1) + qk_eh[i] = shrink_head(qk_eh[i], dim=2) + value = torch.cat([qk_h[2],qk_eh[2]], dim=1) + else: + query = all_to_all_4D(query, scatter_dim=2, gather_dim=1) # [1, 24, 19080, 128] -> [1, 12, 38160, 128] + key = all_to_all_4D(key, scatter_dim=2, gather_dim=1) + value = all_to_all_4D(value, scatter_dim=2, gather_dim=1) + + + + # if image_rotary_emb is not None: + + # freqs_cos, freqs_sin = image_rotary_emb + # freqs_cos = shrink_head(freqs_cos, dim=1) + # freqs_sin = shrink_head(freqs_sin, dim=1) + # image_rotary_emb = (freqs_cos, freqs_sin) + + # 3. Rotational positional embeddings applied to latent stream + + # freqs_cos, freqs_sin = image_rotary_emb + # freqs_cos = freqs_cos.unsqueeze(1).expand(-1, attn.heads // nccl_info.sp_size, -1) + # freqs_sin = freqs_sin.unsqueeze(1).expand(-1, attn.heads // nccl_info.sp_size, -1) + # image_rotary_emb = (freqs_cos, freqs_sin) + + if image_rotary_emb is not None: + #from diffusers.models.embeddings import apply_rotary_emb + + if attn.add_q_proj is None and encoder_hidden_states is not None: + if get_sequence_parallel_state(): + query = torch.cat( + [ + apply_rotary_emb(qk_h[0], image_rotary_emb), qk_eh[0] + ], + dim=1, + ) + key = torch.cat( + [ + apply_rotary_emb(qk_h[1], image_rotary_emb), qk_eh[1] + ], + dim=1, + ) + # if get_sequence_parallel_state() and attn.add_q_proj is None: + + else: + query = torch.cat( + [ + apply_rotary_emb(query[:,: -encoder_hidden_states.shape[1], :], image_rotary_emb), + query[:, -encoder_hidden_states.shape[1] :,:], + ], + dim=1, + ) + key = torch.cat( + [ + apply_rotary_emb(key[:,: -encoder_hidden_states.shape[1], :], image_rotary_emb), + key[:, -encoder_hidden_states.shape[1] :,:], + ], + dim=1, + ) + + else: + + query = apply_rotary_emb(query, image_rotary_emb) # [1, 24, 38160, 128] + key = apply_rotary_emb(key, image_rotary_emb) + + # 4. 
Encoder condition QKV projection and normalization + if attn.add_q_proj is not None and encoder_hidden_states is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) # [1, 256, 3072] + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + encoder_query = encoder_query.unflatten(2, (attn.heads, -1)) # [1, 24, 256, 128] + encoder_key = encoder_key.unflatten(2, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(2, (attn.heads, -1)) + + if attn.norm_added_q is not None: + encoder_query = attn.norm_added_q(encoder_query) + if attn.norm_added_k is not None: + encoder_key = attn.norm_added_k(encoder_key) + if get_sequence_parallel_state(): + encoder_query = shrink_head(encoder_query, dim=2) + encoder_key = shrink_head(encoder_key, dim=2) + encoder_value = shrink_head(encoder_value, dim=2) + sequence_length = query.size(1) + encoder_sequence_length = encoder_query.size(1) + + query = torch.cat([query, encoder_query], dim=1).unsqueeze(2) # [1, 24, 1, 38416, 128] + + key = torch.cat([key, encoder_key], dim=1).unsqueeze(2) + value = torch.cat([value, encoder_value], dim=1).unsqueeze(2) + else: + # from IPython import embed + # embed() + query = query.unsqueeze(2) # [1, 24, 1, 38416, 128] + key = key.unsqueeze(2) + value = value.unsqueeze(2) + + qkv = torch.cat([query, key, value], dim=2) + # 5. Attention + + attention_mask = attention_mask[:,0,:] + + hidden_states = flash_attn_no_pad( #[1, 38416, 24, 128] + qkv, attention_mask, causal=False, dropout_p=0.0, softmax_scale=None # [2, 25696] + ) + + + + # hidden_states = F.scaled_dot_product_attention( + # query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + # ) + #hidden_states = hidden_states.flatten(2, 3) # [1, 38416, 3072] + hidden_states = hidden_states.to(query.dtype) + + # 6. 
Output projection + if encoder_hidden_states is not None: + if get_sequence_parallel_state(): + if attn.add_q_proj is not None: + + hidden_states, encoder_hidden_states = hidden_states.split_with_sizes( + (sequence_length, encoder_sequence_length), dim=1 + ) + # B, S, H, D + + hidden_states = all_to_all_4D(hidden_states, scatter_dim=1, gather_dim=2) + encoder_hidden_states = all_gather( + encoder_hidden_states, dim=2 + ).contiguous() + + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + encoder_hidden_states = encoder_hidden_states.flatten(2, 3) + encoder_hidden_states = encoder_hidden_states.to(query.dtype) + else: + + hidden_states, encoder_hidden_states = hidden_states.split_with_sizes( + (sequence_length * nccl_info.sp_size, encoder_sequence_length), dim=1 + ) + # B, S, H, D + #ForkedPdb().set_trace() + hidden_states = all_to_all_4D(hidden_states, scatter_dim=1, gather_dim=2) + encoder_hidden_states = all_gather( + encoder_hidden_states, dim=2 + ).contiguous() + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + encoder_hidden_states = encoder_hidden_states.flatten(2, 3) + encoder_hidden_states = encoder_hidden_states.to(query.dtype) + else: + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + hidden_states, encoder_hidden_states = hidden_states.split_with_sizes( + (sequence_length, encoder_sequence_length), dim=1 + ) + if getattr(attn, "to_out", None) is not None: + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + if getattr(attn, "to_add_out", None) is not None: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states # [1, 38160, 3072] [1, 256, 3072] + + +class HunyuanVideoPatchEmbed(nn.Module): + def __init__( + self, + patch_size: Union[int, Tuple[int, int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + super().__init__() + + patch_size = (patch_size, patch_size, patch_size) if isinstance(patch_size, int) else patch_size + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.proj(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) # BCFHW -> BNC + return hidden_states + + +class HunyuanVideoAdaNorm(nn.Module): + def __init__(self, in_features: int, out_features: Optional[int] = None) -> None: + super().__init__() + + out_features = out_features or 2 * in_features + self.linear = nn.Linear(in_features, out_features) + self.nonlinearity = nn.SiLU() + + def forward( + self, temb: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + temb = self.linear(self.nonlinearity(temb)) + gate_msa, gate_mlp = temb.chunk(2, dim=1) + gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1) + return gate_msa, gate_mlp + + +class HunyuanVideoIndividualTokenRefinerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_width_ratio: str = 4.0, + mlp_drop_rate: float = 0.0, + attention_bias: bool = True, + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + heads=num_attention_heads, + 
dim_head=attention_head_dim, + bias=attention_bias, + ) + + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) + self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate) + + self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + norm_hidden_states = self.norm1(hidden_states) + + attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=None, + attention_mask=attention_mask, + ) + + gate_msa, gate_mlp = self.norm_out(temb) + hidden_states = hidden_states + attn_output * gate_msa + + ff_output = self.ff(self.norm2(hidden_states)) + hidden_states = hidden_states + ff_output * gate_mlp + + return hidden_states + + +class HunyuanVideoIndividualTokenRefiner(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + num_layers: int, + mlp_width_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + attention_bias: bool = True, + ) -> None: + super().__init__() + + self.refiner_blocks = nn.ModuleList( + [ + HunyuanVideoIndividualTokenRefinerBlock( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + attention_bias=attention_bias, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> None: + self_attn_mask = None + if attention_mask is not None: + batch_size = attention_mask.shape[0] + seq_len = attention_mask.shape[1] + attention_mask = attention_mask.to(hidden_states.device).bool() + self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1) + self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) + self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool() + self_attn_mask[:, :, :, 0] = True + + for block in self.refiner_blocks: + hidden_states = block(hidden_states, temb, self_attn_mask) + + return hidden_states + + +class HunyuanVideoTokenRefiner(nn.Module): + def __init__( + self, + in_channels: int, + num_attention_heads: int, + attention_head_dim: int, + num_layers: int, + mlp_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + attention_bias: bool = True, + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.time_text_embed = CombinedTimestepTextProjEmbeddings( + embedding_dim=hidden_size, pooled_projection_dim=in_channels + ) + self.proj_in = nn.Linear(in_channels, hidden_size, bias=True) + self.token_refiner = HunyuanVideoIndividualTokenRefiner( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + num_layers=num_layers, + mlp_width_ratio=mlp_ratio, + mlp_drop_rate=mlp_drop_rate, + attention_bias=attention_bias, + ) + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + attention_mask: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + if attention_mask is None: + pooled_projections = hidden_states.mean(dim=1) + else: + original_dtype = hidden_states.dtype + mask_float = attention_mask.float().unsqueeze(-1) + pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1) + pooled_projections = pooled_projections.to(original_dtype) + + temb = self.time_text_embed(timestep, pooled_projections) + hidden_states = 
self.proj_in(hidden_states) + hidden_states = self.token_refiner(hidden_states, temb, attention_mask) + + return hidden_states + + +class HunyuanVideoRotaryPosEmbed(nn.Module): + def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None: + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.rope_dim = rope_dim + self.theta = theta + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + rope_sizes = [num_frames // self.patch_size_t * nccl_info.sp_size, height // self.patch_size, width // self.patch_size] + + axes_grids = [] + for i in range(3): + # Note: The following line diverges from original behaviour. We create the grid on the device, whereas + # original implementation creates it on CPU and then moves it to device. This results in numerical + # differences in layerwise debugging outputs, but visually it is the same. + grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32) + axes_grids.append(grid) + + grid = torch.meshgrid(*axes_grids, indexing="ij") # [W, H, T] + grid = torch.stack(grid, dim=0) # [3, W, H, T] + + freqs = [] + for i in range(3): + freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True) + freqs.append(freq) + + freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2) + freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2) + return freqs_cos, freqs_sin + + +class HunyuanVideoSingleTransformerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 4.0, + qk_norm: str = "rms_norm", + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + mlp_dim = int(hidden_size * mlp_ratio) + + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=hidden_size, + bias=True, + processor=HunyuanVideoAttnProcessor2_0(), + qk_norm=qk_norm, + eps=1e-6, + pre_only=True, + ) + + self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm") + self.proj_mlp = nn.Linear(hidden_size, mlp_dim) + self.act_mlp = nn.GELU(approximate="tanh") + self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.shape[1] + hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) + + residual = hidden_states + + # 1. Input normalization + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) + + norm_hidden_states, norm_encoder_hidden_states = ( + norm_hidden_states[:, :-text_seq_length, :], + norm_hidden_states[:, -text_seq_length:, :], + ) + + # 2. Attention + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, + ) + attn_output = torch.cat([attn_output, context_attn_output], dim=1) + + # 3. 
Modulation and residual connection + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states) + hidden_states = hidden_states + residual + + hidden_states, encoder_hidden_states = ( + hidden_states[:, :-text_seq_length, :], + hidden_states[:, -text_seq_length:, :], + ) + return hidden_states, encoder_hidden_states + + +class HunyuanVideoTransformerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float, + qk_norm: str = "rms_norm", + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm") + self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm") + + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + added_kv_proj_dim=hidden_size, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=hidden_size, + context_pre_only=False, + bias=True, + processor=HunyuanVideoAttnProcessor2_0(), + qk_norm=qk_norm, + eps=1e-6, + ) + + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") + + self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # 1. Input normalization + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + + # 2. Joint attention + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=freqs_cis, + ) + + # 3. Modulation and residual connection + hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1) + + norm_hidden_states = self.norm2(hidden_states) + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + # 4. Feed-forward + ff_output = self.ff(norm_hidden_states) + context_ff_output = self.ff_context(norm_encoder_hidden_states) + + hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + + return hidden_states, encoder_hidden_states + + +class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): + r""" + A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo). + + Args: + in_channels (`int`, defaults to `16`): + The number of channels in the input. + out_channels (`int`, defaults to `16`): + The number of channels in the output. 
+        num_attention_heads (`int`, defaults to `24`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, defaults to `128`):
+            The number of channels in each head.
+        num_layers (`int`, defaults to `20`):
+            The number of layers of dual-stream blocks to use.
+        num_single_layers (`int`, defaults to `40`):
+            The number of layers of single-stream blocks to use.
+        num_refiner_layers (`int`, defaults to `2`):
+            The number of layers of refiner blocks to use.
+        mlp_ratio (`float`, defaults to `4.0`):
+            The ratio of the hidden layer size to the input size in the feedforward network.
+        patch_size (`int`, defaults to `2`):
+            The size of the spatial patches to use in the patch embedding layer.
+        patch_size_t (`int`, defaults to `1`):
+            The size of the temporal patches to use in the patch embedding layer.
+        qk_norm (`str`, defaults to `rms_norm`):
+            The normalization to use for the query and key projections in the attention layers.
+        guidance_embeds (`bool`, defaults to `True`):
+            Whether to use guidance embeddings in the model.
+        text_embed_dim (`int`, defaults to `4096`):
+            Input dimension of text embeddings from the text encoder.
+        pooled_projection_dim (`int`, defaults to `768`):
+            The dimension of the pooled projection of the text embeddings.
+        rope_theta (`float`, defaults to `256.0`):
+            The value of theta to use in the RoPE layer.
+        rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
+            The dimensions of the axes to use in the RoPE layer.
+    """
+
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        in_channels: int = 16,
+        out_channels: int = 16,
+        num_attention_heads: int = 24,
+        attention_head_dim: int = 128,
+        num_layers: int = 20,
+        num_single_layers: int = 40,
+        num_refiner_layers: int = 2,
+        mlp_ratio: float = 4.0,
+        patch_size: int = 2,
+        patch_size_t: int = 1,
+        qk_norm: str = "rms_norm",
+        guidance_embeds: bool = True,
+        text_embed_dim: int = 4096,
+        pooled_projection_dim: int = 768,
+        rope_theta: float = 256.0,
+        rope_axes_dim: Tuple[int] = (16, 56, 56),
+    ) -> None:
+        super().__init__()
+
+        inner_dim = num_attention_heads * attention_head_dim
+        out_channels = out_channels or in_channels
+
+        # 1. Latent and condition embedders
+        self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
+        self.context_embedder = HunyuanVideoTokenRefiner(
+            text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
+        )
+        self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
+
+        # 2. RoPE
+        self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
+
+        # 3. Dual stream transformer blocks
+        self.transformer_blocks = nn.ModuleList(
+            [
+                HunyuanVideoTransformerBlock(
+                    num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+        # 4. Single stream transformer blocks
+        self.single_transformer_blocks = nn.ModuleList(
+            [
+                HunyuanVideoSingleTransformerBlock(
+                    num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+                )
+                for _ in range(num_single_layers)
+            ]
+        )
+
+        # 5.
Output projection + self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels) + + self.gradient_checkpointing = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
+ ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_attention_mask: torch.Tensor, + pooled_projections: torch.Tensor, + guidance: torch.Tensor = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p, p_t = self.config.patch_size, self.config.patch_size_t + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p + post_patch_width = width // p + + # 1. RoPE + + image_rotary_emb = self.rope(hidden_states) + + # 2. Conditional embeddings + temb = self.time_text_embed(timestep, guidance, pooled_projections) + hidden_states = self.x_embedder(hidden_states) + encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask) + + # 3. Attention mask preparation + latent_sequence_length = hidden_states.shape[1] + condition_sequence_length = encoder_hidden_states.shape[1] + sequence_length = latent_sequence_length + condition_sequence_length + attention_mask = torch.zeros( + batch_size, sequence_length, sequence_length, device=hidden_states.device, dtype=torch.bool + ) # [B, N, N] + + effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,] + effective_sequence_length = latent_sequence_length + effective_condition_sequence_length + + for i in range(batch_size): + attention_mask[i, : effective_sequence_length[i], : effective_sequence_length[i]] = True + + # 4. 
Transformer blocks + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + + for block in self.transformer_blocks: + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + attention_mask, + image_rotary_emb, + **ckpt_kwargs, + ) + + for block in self.single_transformer_blocks: + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + attention_mask, + image_rotary_emb, + **ckpt_kwargs, + ) + + else: + for block in self.transformer_blocks: + hidden_states, encoder_hidden_states = block( + hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb + ) + + for block in self.single_transformer_blocks: + hidden_states, encoder_hidden_states = block( + hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb + ) + + # 5. Output projection + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p + ) + hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7) + hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (hidden_states,) + + return Transformer2DModelOutput(sample=hidden_states) diff --git a/fastvideo/models/hunyuan_hf/pipeline_hunyuan.py b/fastvideo/models/hunyuan_hf/pipeline_hunyuan.py new file mode 100644 index 00000000..aac471e8 --- /dev/null +++ b/fastvideo/models/hunyuan_hf/pipeline_hunyuan.py @@ -0,0 +1,698 @@ +# Copyright 2024 The HunyuanVideo Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
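As a sanity check on the output projection at the end of HunyuanVideoTransformer3DModel.forward above: the reshape -> permute -> flatten sequence un-patchifies the token sequence back into a (batch, channels, frames, height, width) latent. A minimal standalone sketch with illustrative sizes (not taken from the patch):

    import torch

    B, C_out, p_t, p = 1, 16, 1, 2  # batch, latent channels, temporal / spatial patch sizes
    T_, H_, W_ = 4, 6, 8            # post-patch frames, height, width
    # One vector of size p_t * p * p * C_out per patch, as produced by proj_out.
    tokens = torch.randn(B, T_ * H_ * W_, p_t * p * p * C_out)

    x = tokens.reshape(B, T_, H_, W_, C_out, p_t, p, p)
    x = x.permute(0, 4, 1, 5, 2, 6, 3, 7)   # (B, C, T', p_t, H', p, W', p)
    x = x.flatten(6, 7).flatten(4, 5).flatten(2, 3)
    print(x.shape)  # torch.Size([1, 16, 4, 12, 16]) == (B, C, T'*p_t, H'*p, W'*p)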
+ +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast + +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.loaders import HunyuanVideoLoraLoaderMixin +from diffusers.models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.hunyuan_video.pipeline_output import HunyuanVideoPipelineOutput +from fastvideo.utils.parallel_states import get_sequence_parallel_state, nccl_info +from einops import rearrange +from fastvideo.utils.communications import all_gather + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel + >>> from diffusers.utils import export_to_video + + >>> model_id = "tencent/HunyuanVideo" + >>> transformer = HunyuanVideoTransformer3DModel.from_pretrained( + ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16 + ... ) + >>> pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16) + >>> pipe.vae.enable_tiling() + >>> pipe.to("cuda") + + >>> output = pipe( + ... prompt="A cat walks on the grass, realistic", + ... height=320, + ... width=512, + ... num_frames=61, + ... num_inference_steps=30, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=15) + ``` +""" + + +DEFAULT_PROMPT_TEMPLATE = { + "template": ( + "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. background environment, light, style and atmosphere." + "5. camera angles, movements, and transitions used in the video:<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + ), + "crop_start": 95, +} + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. 
If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin): + r""" + Pipeline for text-to-video generation using HunyuanVideo. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`LlamaModel`]): + [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + tokenizer_2 (`LlamaTokenizer`): + Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + transformer ([`HunyuanVideoTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLHunyuanVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder_2 ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer_2 (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). 
+ """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + text_encoder: LlamaModel, + tokenizer: LlamaTokenizerFast, + transformer: HunyuanVideoTransformer3DModel, + vae: AutoencoderKLHunyuanVideo, + scheduler: FlowMatchEulerDiscreteScheduler, + text_encoder_2: CLIPTextModel, + tokenizer_2: CLIPTokenizer, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + ) + + self.vae_scale_factor_temporal = ( + self.vae.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + self.vae_scale_factor_spatial = ( + self.vae.spatial_compression_ratio if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _get_llama_prompt_embeds( + self, + prompt: Union[str, List[str]], + prompt_template: Dict[str, Any], + num_videos_per_prompt: int = 1, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 256, + num_hidden_layers_to_skip: int = 2, + ) -> Tuple[torch.Tensor, torch.Tensor]: + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + prompt = [prompt_template["template"].format(p) for p in prompt] + + crop_start = prompt_template.get("crop_start", None) + if crop_start is None: + prompt_template_input = self.tokenizer( + prompt_template["template"], + padding="max_length", + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=False, + ) + crop_start = prompt_template_input["input_ids"].shape[-1] + # Remove <|eot_id|> token and placeholder {} + crop_start -= 2 + + max_sequence_length += crop_start + text_inputs = self.tokenizer( + prompt, + max_length=max_sequence_length, + padding="max_length", + truncation=True, + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + ) + text_input_ids = text_inputs.input_ids.to(device=device) + prompt_attention_mask = text_inputs.attention_mask.to(device=device) + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-(num_hidden_layers_to_skip + 1)] + prompt_embeds = prompt_embeds.to(dtype=dtype) + + if crop_start is not None and crop_start > 0: + prompt_embeds = prompt_embeds[:, crop_start:] + prompt_attention_mask = prompt_attention_mask[:, crop_start:] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt) + prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len) + + return prompt_embeds, prompt_attention_mask + + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 77, + ) -> 
torch.Tensor: + device = device or self._execution_device + dtype = dtype or self.text_encoder_2.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]] = None, + prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 256, + ): + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds( + prompt, + prompt_template, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=max_sequence_length, + ) + + if pooled_prompt_embeds is None: + if prompt_2 is None and pooled_prompt_embeds is None: + prompt_2 = prompt + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=77, + ) + + return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + prompt_template=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. 
Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_template is not None: + if not isinstance(prompt_template, dict): + raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}") + if "template" not in prompt_template: + raise ValueError( + f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}" + ) + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: 32, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + num_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. 
+ """ + self.vae.disable_tiling() + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Union[str, List[str]] = None, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + num_inference_steps: int = 50, + sigmas: List[float] = None, + guidance_scale: float = 6.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + max_sequence_length: int = 256, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead. + height (`int`, defaults to `720`): + The height in pixels of the generated image. + width (`int`, defaults to `1280`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `129`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, defaults to `6.0`): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. Note that the only available HunyuanVideo model is + CFG-distilled, which means that traditional guidance between unconditional and conditional latent is + not applied. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. 
+            latents (`torch.Tensor`, *optional*):
+                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a
+                latents tensor is generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+                provided, text embeddings are generated from the `prompt` input argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated video frames. Choose between `PIL.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple.
+            attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end
+                of each denoising step during inference with the following arguments: `callback_on_step_end(self:
+                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+
+        Examples:
+
+        Returns:
+            [`~HunyuanVideoPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned; otherwise a `tuple` is
+                returned where the first element is a list with the generated video frames.
+        """
+
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt,
+            prompt_2,
+            height,
+            width,
+            prompt_embeds,
+            callback_on_step_end_tensor_inputs,
+            prompt_template,
+        )
+
+        self._guidance_scale = guidance_scale
+        self._attention_kwargs = attention_kwargs
+        self._interrupt = False
+
+        device = self._execution_device
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        # 3.
Encode input prompt + prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + device=device, + max_sequence_length=max_sequence_length, + ) + + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + prompt_attention_mask = prompt_attention_mask.to(transformer_dtype) + if pooled_prompt_embeds is not None: + pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + ) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_latent_frames, + torch.float32, + device, + generator, + latents, + ) + world_size, rank = nccl_info.sp_size, nccl_info.rank_within_group + + if get_sequence_parallel_state(): + latents = rearrange( + latents, "b t (n s) h w -> b t n s h w", n=world_size + ).contiguous() + latents = latents[:, :, rank, :, :, :] + # 6. Prepare guidance condition + guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0 + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = latents.to(transformer_dtype) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + pooled_projections=pooled_prompt_embeds, + guidance=guidance, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if get_sequence_parallel_state(): + latents = all_gather(latents, dim=2) + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) 
+ + return HunyuanVideoPipelineOutput(frames=video) diff --git a/fastvideo/models/mochi_hf/modeling_mochi.py b/fastvideo/models/mochi_hf/modeling_mochi.py index a57d0fa7..d0aae801 100644 --- a/fastvideo/models/mochi_hf/modeling_mochi.py +++ b/fastvideo/models/mochi_hf/modeling_mochi.py @@ -55,6 +55,20 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +import sys +import pdb +class ForkedPdb(pdb.Pdb): + """A Pdb subclass that may be used + from a forked multiprocessing child + + """ + def interaction(self, *args, **kwargs): + _stdin = sys.stdin + try: + sys.stdin = open('/dev/stdin') + pdb.Pdb.interaction(self, *args, **kwargs) + finally: + sys.stdin = _stdin class FeedForward(HF_FeedForward): def __init__( @@ -164,19 +178,20 @@ def __init__(self): def __call__( self, attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - encoder_attention_mask: torch.Tensor, + hidden_states: torch.Tensor, # [2, 25440, 3072] [2, 12720, 3072] + encoder_hidden_states: torch.Tensor, # [2, 256, 1536] [2, 256, 1536] + encoder_attention_mask: torch.Tensor, # [2, 256] attention_mask: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + image_rotary_emb: Optional[torch.Tensor] = None, # [25440, 24, 64] [25440, 24, 64] + ) -> torch.Tensor: # [b, s, h * d] - query = attn.to_q(hidden_states) + + query = attn.to_q(hidden_states) # [2, 25440, 3072] key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) # [b, s, h=24, d=128] - query = query.unflatten(2, (attn.heads, -1)) + query = query.unflatten(2, (attn.heads, -1)) # [2, 25440, 24, 128] key = key.unflatten(2, (attn.heads, -1)) value = value.unflatten(2, (attn.heads, -1)) @@ -185,12 +200,12 @@ def __call__( if attn.norm_k is not None: key = attn.norm_k(key) # [b, 256, h * d] - encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_query = attn.add_q_proj(encoder_hidden_states) # [2, 256, 3072] encoder_key = attn.add_k_proj(encoder_hidden_states) encoder_value = attn.add_v_proj(encoder_hidden_states) # [b, 256, h=24, d=128] - encoder_query = encoder_query.unflatten(2, (attn.heads, -1)) + encoder_query = encoder_query.unflatten(2, (attn.heads, -1)) # [2, 256, 24, 128] encoder_key = encoder_key.unflatten(2, (attn.heads, -1)) encoder_value = encoder_value.unflatten(2, (attn.heads, -1)) @@ -200,12 +215,15 @@ def __call__( encoder_key = attn.norm_added_k(encoder_key) if image_rotary_emb is not None: - freqs_cos, freqs_sin = image_rotary_emb[0], image_rotary_emb[1] + freqs_cos, freqs_sin = image_rotary_emb[0], image_rotary_emb[1] # [25440, 24, 64] # shard the head dimension + + if get_sequence_parallel_state(): # B, S, H, D to (S, B,) H, D # batch_size, seq_len, attn_heads, head_dim - query = all_to_all_4D(query, scatter_dim=2, gather_dim=1) + + query = all_to_all_4D(query, scatter_dim=2, gather_dim=1) # [2, 25440, 24, 128] key = all_to_all_4D(key, scatter_dim=2, gather_dim=1) value = all_to_all_4D(value, scatter_dim=2, gather_dim=1) @@ -214,12 +232,13 @@ def shrink_head(encoder_state, dim): return encoder_state.narrow( dim, nccl_info.rank_within_group * local_heads, local_heads ) - - encoder_query = shrink_head(encoder_query, dim=2) + ForkedPdb().set_trace() + encoder_query = shrink_head(encoder_query, dim=2) # [2, 256, 12, 128] encoder_key = shrink_head(encoder_key, dim=2) encoder_value = shrink_head(encoder_value, dim=2) + if image_rotary_emb is not None: - freqs_cos = shrink_head(freqs_cos, dim=1) + freqs_cos = shrink_head(freqs_cos, dim=1) # 
[25440, 12, 64] freqs_sin = shrink_head(freqs_sin, dim=1) if image_rotary_emb is not None: @@ -232,7 +251,7 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin): return torch.stack([cos, sin], dim=-1).flatten(-2) - query = apply_rotary_emb(query, freqs_cos, freqs_sin) + query = apply_rotary_emb(query, freqs_cos, freqs_sin) # [2, 25440, 24, 128] key = apply_rotary_emb(key, freqs_cos, freqs_sin) # query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2) @@ -246,16 +265,17 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin): encoder_sequence_length = encoder_query.size(1) # H - query = torch.cat([query, encoder_query], dim=1).unsqueeze(2) + query = torch.cat([query, encoder_query], dim=1).unsqueeze(2) # [2, 25696, 1, 24, 128] key = torch.cat([key, encoder_key], dim=1).unsqueeze(2) value = torch.cat([value, encoder_value], dim=1).unsqueeze(2) # B, S, 3, H, D - qkv = torch.cat([query, key, value], dim=2) + qkv = torch.cat([query, key, value], dim=2) # [2, 25696, 3, 24, 128] attn_mask = encoder_attention_mask[:, :].bool() attn_mask = F.pad(attn_mask, (sequence_length, 0), value=True) - hidden_states = flash_attn_no_pad( - qkv, attn_mask, causal=False, dropout_p=0.0, softmax_scale=None + + hidden_states = flash_attn_no_pad( #[2, 25696, 24, 128] + qkv, attn_mask, causal=False, dropout_p=0.0, softmax_scale=None # [2, 25696] ) # hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask = None, dropout_p=0.0, is_causal=False) @@ -270,6 +290,7 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin): (sequence_length, encoder_sequence_length), dim=1 ) # B, S, H, D + hidden_states = all_to_all_4D(hidden_states, scatter_dim=1, gather_dim=2) encoder_hidden_states = all_gather( encoder_hidden_states, dim=2 @@ -285,14 +306,14 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin): hidden_states, encoder_hidden_states = hidden_states.split_with_sizes( (sequence_length, encoder_sequence_length), dim=1 ) - + # linear proj - hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[0](hidden_states) #[2, 25440, 3072] # dropout - hidden_states = attn.to_out[1](hidden_states) + hidden_states = attn.to_out[1](hidden_states) #[2, 256, 3072] if hasattr(attn, "to_add_out"): - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) #[2, 256, 1536] return hidden_states, encoder_hidden_states @@ -477,6 +498,7 @@ def _get_positions( dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: scale = (self.target_area / (height * width)) ** 0.5 + t = torch.arange(num_frames * nccl_info.sp_size, device=device, dtype=dtype) h = self._centers( -height * scale / 2, height * scale / 2, height, device, dtype @@ -484,7 +506,7 @@ def _get_positions( w = self._centers(-width * scale / 2, width * scale / 2, width, device, dtype) grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij") - + positions = torch.stack([grid_t, grid_h, grid_w], dim=-1).view(-1, 3) return positions @@ -543,6 +565,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): """ _supports_gradient_checkpointing = True + _no_split_modules = ["MochiTransformerBlock"] @register_to_config def __init__( @@ -665,15 +688,16 @@ def forward( hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) hidden_states = self.patch_embed(hidden_states) hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2) - + image_rotary_emb = self.rope( - self.pos_frequencies, - num_frames, + self.pos_frequencies, # [3, 24, 
64] + num_frames, # 8 post_patch_height, post_patch_width, - device=hidden_states.device, + device=hidden_states.device, # [2, 12720, 3072] dtype=torch.float32, ) + attn_outputs_list = [] for i, block in enumerate(self.transformer_blocks): if self.gradient_checkpointing: diff --git a/fastvideo/sample/sample_t2v_hunyuan_hf.py b/fastvideo/sample/sample_t2v_hunyuan_hf.py new file mode 100644 index 00000000..14ed518c --- /dev/null +++ b/fastvideo/sample/sample_t2v_hunyuan_hf.py @@ -0,0 +1,169 @@ +import torch +import torch.distributed as dist + +from diffusers.utils import export_to_video +from fastvideo.utils.parallel_states import ( + initialize_sequence_parallel_state, + nccl_info, +) +import argparse +import os +import json +from typing import Optional +from safetensors.torch import save_file, load_file +from peft import set_peft_model_state_dict, inject_adapter_in_model, load_peft_weights +from peft import LoraConfig +import sys +import pdb +import copy +from typing import Dict +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import convert_unet_state_dict_to_peft +from fastvideo.distill.solver import PCMFMScheduler +from fastvideo.models.hunyuan.diffusion.schedulers import FlowMatchDiscreteScheduler +from fastvideo.models.hunyuan_hf.pipeline_hunyuan import HunyuanVideoPipeline +from fastvideo.models.hunyuan_hf.modeling_hunyuan import HunyuanVideoTransformer3DModel + +def initialize_distributed(): + local_rank = int(os.getenv("RANK", 0)) + world_size = int(os.getenv("WORLD_SIZE", 1)) + print("world_size", world_size) + torch.cuda.set_device(local_rank) + dist.init_process_group( + backend="nccl", init_method="env://", world_size=world_size, rank=local_rank + ) + initialize_sequence_parallel_state(world_size) + + +def main(args): + initialize_distributed() + print(nccl_info.sp_size) + device = torch.cuda.current_device() + # Peiyuan: GPU seed will cause A100 and H100 to produce different results ..... 
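+    # Model setup below: the transformer is loaded in bfloat16 and wrapped in a float16
+    # HunyuanVideoPipeline; an optional LoRA adapter is attached with load_lora_weights
+    # and rescaled via set_adapters using lora_alpha / lora_rank from lora_config.json.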
+ weight_dtype = torch.bfloat16 + + if args.transformer_path is not None: + transformer = HunyuanVideoTransformer3DModel.from_pretrained(args.transformer_path) + else: + transformer = HunyuanVideoTransformer3DModel.from_pretrained( + args.model_path, subfolder="transformer/", torch_dtype=weight_dtype + ) + + pipe = HunyuanVideoPipeline.from_pretrained( + args.model_path, transformer=transformer, torch_dtype=torch.float16 + ) + + pipe.enable_vae_tiling() + + if args.lora_checkpoint_dir is not None: + print(f"Loading LoRA weights from {args.lora_checkpoint_dir}") + config_path = os.path.join(args.lora_checkpoint_dir, "lora_config.json") + with open(config_path, "r") as f: + lora_config_dict = json.load(f) + rank = lora_config_dict["lora_params"]["lora_rank"] + lora_alpha = lora_config_dict["lora_params"]["lora_alpha"] + lora_scaling = lora_alpha / rank + pipe.load_lora_weights(args.lora_checkpoint_dir, adapter_name="default") + pipe.set_adapters(["default"], [lora_scaling]) + print(f"Successfully Loaded LoRA weights from {args.lora_checkpoint_dir}") + #pipe.to(device) + + pipe.enable_model_cpu_offload(device) + + # Generate videos from the input prompt + + if args.prompt_embed_path is not None: + prompt_embeds = ( + torch.load(args.prompt_embed_path, map_location="cpu", weights_only=True) + .to(device) + .unsqueeze(0) + ) + encoder_attention_mask = ( + torch.load( + args.encoder_attention_mask_path, map_location="cpu", weights_only=True + ) + .to(device) + .unsqueeze(0) + ) + prompts = None + elif args.prompt_path is not None: + prompts = [line.strip() for line in open(args.prompt_path, "r")] + prompt_embeds = None + encoder_attention_mask = None + else: + prompts = args.prompts + prompt_embeds = None + encoder_attention_mask = None + + if prompts is not None: + with torch.autocast("cuda", dtype=torch.bfloat16): + for prompt in prompts: + generator = torch.Generator("cpu").manual_seed(args.seed) + video = pipe( + prompt=[prompt], + height=args.height, + width=args.width, + num_frames=args.num_frames, + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale, + generator=generator, + ).frames + if nccl_info.global_rank <= 0: + os.makedirs(args.output_path, exist_ok=True) + suffix = prompt.split(".")[0] + export_to_video( + video[0], + os.path.join(args.output_path, f"{suffix}.mp4"), + fps=30, + ) + else: + with torch.autocast("cuda", dtype=torch.bfloat16): + generator = torch.Generator("cpu").manual_seed(args.seed) + videos = pipe( + prompt_embeds=prompt_embeds, + prompt_attention_mask=encoder_attention_mask, + height=args.height, + width=args.width, + num_frames=args.num_frames, + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale, + generator=generator, + ).frames + + if nccl_info.global_rank <= 0: + export_to_video(videos[0], args.output_path + ".mp4", fps=30) + + +if __name__ == "__main__": + # arg parse + parser = argparse.ArgumentParser() + parser.add_argument("--prompts", nargs="+", default=[]) + parser.add_argument("--num_frames", type=int, default=163) + parser.add_argument("--height", type=int, default=480) + parser.add_argument("--width", type=int, default=848) + parser.add_argument("--num_inference_steps", type=int, default=64) + parser.add_argument("--guidance_scale", type=float, default=4.5) + parser.add_argument("--model_name", type=str, default="hunyuan") + parser.add_argument("--model_path", type=str, default="data/hunyuan") + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--output_path", 
type=str, default="./outputs.mp4") + parser.add_argument("--transformer_path", type=str, default=None) + parser.add_argument("--prompt_embed_path", type=str, default=None) + parser.add_argument("--prompt_path", type=str, default=None) + parser.add_argument("--scheduler_type", type=str, default="euler") + parser.add_argument("--encoder_attention_mask_path", type=str, default=None) + parser.add_argument( + "--lora_checkpoint_dir", + type=str, + default=None, + help="Path to the directory containing LoRA checkpoints", + ) + parser.add_argument("--flow_shift", type=int, default=7, help="Flow shift parameter.") + parser.add_argument("--flow-reverse",action="store_true",help="If reverse, learning/sampling from t=1 -> t=0.",) + parser.add_argument("--flow-solver", type=str, default="euler", help="Solver for flow matching.") + parser.add_argument("--shift", type=float, default=8.0) + parser.add_argument("--num_euler_timesteps", type=int, default=100) + parser.add_argument("--linear_threshold", type=float, default=0.025) + parser.add_argument("--linear_range", type=float, default=0.75) + args = parser.parse_args() + main(args) diff --git a/fastvideo/train.py b/fastvideo/train.py index 401cb99c..13d4f9ce 100644 --- a/fastvideo/train.py +++ b/fastvideo/train.py @@ -47,6 +47,7 @@ ) from fastvideo.utils.logging_ import main_print from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline +from diffusers.pipelines.hunyuan_video import HunyuanVideoPipeline # Will error if the minimal version of diffusers is not installed. Remove at your own risks. check_min_version("0.31.0") @@ -224,7 +225,11 @@ def main(args): ) if args.use_lora: - assert args.model_type == "mochi", "LoRA is only supported for Mochi model." + #assert args.model_type == "mochi", "LoRA is only supported for Mochi model." 
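+        # LoRA is now supported for both Mochi and Hunyuan: the pipeline class selected
+        # below is threaded through to save_lora_checkpoint (and used for lora_state_dict
+        # on resume) so the adapter is saved and reloaded with the matching implementation.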
+ if args.model_type == "mochi": + pipeline = MochiPipeline + elif args.model_type == "hunyuan": + pipeline = HunyuanVideoPipeline transformer.requires_grad_(False) transformer_lora_config = LoraConfig( r=args.lora_rank, @@ -232,10 +237,13 @@ def main(args): init_lora_weights=True, target_modules=["to_k", "to_q", "to_v", "to_out.0"], ) + from IPython import embed + embed() + transformer.add_adapter(transformer_lora_config) if args.resume_from_lora_checkpoint: - lora_state_dict = MochiPipeline.lora_state_dict( + lora_state_dict = pipeline.lora_state_dict( args.resume_from_lora_checkpoint ) transformer_state_dict = { @@ -456,7 +464,7 @@ def main(args): if args.use_lora: # Save LoRA weights save_lora_checkpoint( - transformer, optimizer, rank, args.output_dir, step + transformer, optimizer, rank, args.output_dir, step, pipeline ) else: # Your existing checkpoint saving code @@ -467,7 +475,7 @@ def main(args): if args.use_lora: save_lora_checkpoint( - transformer, optimizer, rank, args.output_dir, args.max_train_steps + transformer, optimizer, rank, args.output_dir, args.max_train_steps, pipeline ) else: save_checkpoint( diff --git a/fastvideo/train_hunyuan_hf.py b/fastvideo/train_hunyuan_hf.py new file mode 100644 index 00000000..13d4f9ce --- /dev/null +++ b/fastvideo/train_hunyuan_hf.py @@ -0,0 +1,746 @@ +import argparse +from email.policy import strict +import logging +import math +import os +import shutil +from pathlib import Path +from fastvideo.utils.parallel_states import ( + initialize_sequence_parallel_state, + destroy_sequence_parallel_group, + get_sequence_parallel_state, + nccl_info, +) +from fastvideo.utils.communications import sp_parallel_dataloader_wrapper, broadcast +from fastvideo.models.mochi_hf.mochi_latents_utils import normalize_dit_input +from fastvideo.utils.validation import log_validation +import time +from torch.utils.data import DataLoader +import torch +from torch.distributed.fsdp import ( + FullyShardedDataParallel as FSDP, + StateDictType, + FullStateDictConfig, +) +import json +from torch.utils.data.distributed import DistributedSampler +from fastvideo.utils.dataset_utils import LengthGroupedSampler +import wandb +from accelerate.utils import set_seed +from tqdm.auto import tqdm +from fastvideo.utils.fsdp_util import get_dit_fsdp_kwargs, apply_fsdp_checkpointing +from diffusers.utils import convert_unet_state_dict_to_peft +from diffusers import FlowMatchEulerDiscreteScheduler +from fastvideo.utils.load import load_transformer +from diffusers.optimization import get_scheduler +from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformer3DModel +from diffusers.utils import check_min_version +from fastvideo.dataset.latent_datasets import LatentDataset, latent_collate_function +import torch.distributed as dist +from safetensors.torch import save_file, load_file +from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from fastvideo.utils.checkpoint import ( + save_checkpoint, + save_lora_checkpoint, + resume_lora_optimizer, +) +from fastvideo.utils.logging_ import main_print +from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline +from diffusers.pipelines.hunyuan_video import HunyuanVideoPipeline + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
+check_min_version("0.31.0") +import time +from collections import deque + + +def compute_density_for_timestep_sampling( + weighting_scheme: str, + batch_size: int, + generator, + logit_mean: float = None, + logit_std: float = None, + mode_scale: float = None, +): + """ + Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal( + mean=logit_mean, + std=logit_std, + size=(batch_size,), + device="cpu", + generator=generator, + ) + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device="cpu", generator=generator) + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device="cpu", generator=generator) + return u + + +def get_sigmas(noise_scheduler, device, timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(device) + timesteps = timesteps.to(device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + +def train_one_step( + transformer, + model_type, + optimizer, + lr_scheduler, + loader, + noise_scheduler, + noise_random_generator, + gradient_accumulation_steps, + sp_size, + precondition_outputs, + max_grad_norm, + weighting_scheme, + logit_mean, + logit_std, + mode_scale, +): + total_loss = 0.0 + optimizer.zero_grad() + for _ in range(gradient_accumulation_steps): + ( + latents, + encoder_hidden_states, + latents_attention_mask, + encoder_attention_mask, + ) = next(loader) + latents = normalize_dit_input(model_type, latents) + batch_size = latents.shape[0] + noise = torch.randn_like(latents) + u = compute_density_for_timestep_sampling( + weighting_scheme=weighting_scheme, + batch_size=batch_size, + generator=noise_random_generator, + logit_mean=logit_mean, + logit_std=logit_std, + mode_scale=mode_scale, + ) + indices = (u * noise_scheduler.config.num_train_timesteps).long() + timesteps = noise_scheduler.timesteps[indices].to(device=latents.device) + if sp_size > 1: + # Make sure that the timesteps are the same across all sp processes. 
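+            # Each rank in a sequence-parallel group works on a slice of the same sample,
+            # so all ranks must step with the same t: broadcast() syncs the sampled
+            # timesteps across the group before sigmas are looked up.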
+ broadcast(timesteps) + sigmas = get_sigmas( + noise_scheduler, + latents.device, + timesteps, + n_dim=latents.ndim, + dtype=latents.dtype, + ) + noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise + # if rank<=0: + # print("2222222222222222222222222222222222222222222222") + # print(type(latents_attention_mask)) + # print(latents_attention_mask) + with torch.autocast("cuda", torch.bfloat16): + model_pred = transformer( + noisy_model_input, + encoder_hidden_states, + timesteps, + encoder_attention_mask, # B, L + return_dict=False, + )[0] + # if rank<=0: + # print("333333333333333333333333333333333333333333333333") + if precondition_outputs: + model_pred = noisy_model_input - model_pred * sigmas + if precondition_outputs: + target = latents + else: + target = noise - latents + + loss = ( + torch.mean((model_pred.float() - target.float()) ** 2) + / gradient_accumulation_steps + ) + + loss.backward() + + avg_loss = loss.detach().clone() + dist.all_reduce(avg_loss, op=dist.ReduceOp.AVG) + total_loss += avg_loss.item() + + grad_norm = transformer.clip_grad_norm_(max_grad_norm) + optimizer.step() + lr_scheduler.step() + return total_loss, grad_norm.item() + + +def main(args): + torch.backends.cuda.matmul.allow_tf32 = True + + local_rank = int(os.environ["LOCAL_RANK"]) + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + dist.init_process_group("nccl") + torch.cuda.set_device(local_rank) + device = torch.cuda.current_device() + initialize_sequence_parallel_state(args.sp_size) + + # If passed along, set the training seed now. On GPU... + if args.seed is not None: + # TODO: t within the same seq parallel group should be the same. Noise should be different. + set_seed(args.seed + rank) + # We use different seeds for the noise generation in each process to ensure that the noise is different in a batch. + noise_random_generator = None + + # Handle the repository creation + if rank <= 0 and args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # For mixed precision training we cast all non-trainable weigths to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + + # Create model: + + main_print(f"--> loading model from {args.pretrained_model_name_or_path}") + # keep the master weight to float32 + transformer = load_transformer( + args.model_type, + args.dit_model_name_or_path, + args.pretrained_model_name_or_path, + torch.float32 if args.master_weight_type == "fp32" else torch.bfloat16, + ) + + if args.use_lora: + #assert args.model_type == "mochi", "LoRA is only supported for Mochi model." 
+ if args.model_type == "mochi": + pipeline = MochiPipeline + elif args.model_type == "hunyuan": + pipeline = HunyuanVideoPipeline + transformer.requires_grad_(False) + transformer_lora_config = LoraConfig( + r=args.lora_rank, + lora_alpha=args.lora_alpha, + init_lora_weights=True, + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + ) + from IPython import embed + embed() + + transformer.add_adapter(transformer_lora_config) + + if args.resume_from_lora_checkpoint: + lora_state_dict = pipeline.lora_state_dict( + args.resume_from_lora_checkpoint + ) + transformer_state_dict = { + f'{k.replace("transformer.", "")}': v + for k, v in lora_state_dict.items() + if k.startswith("transformer.") + } + transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) + incompatible_keys = set_peft_model_state_dict( + transformer, transformer_state_dict, adapter_name="default" + ) + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + main_print( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + main_print( + f" Total training parameters = {sum(p.numel() for p in transformer.parameters() if p.requires_grad) / 1e6} M" + ) + main_print( + f"--> Initializing FSDP with sharding strategy: {args.fsdp_sharding_startegy}" + ) + fsdp_kwargs, no_split_modules = get_dit_fsdp_kwargs( + transformer, + args.fsdp_sharding_startegy, + args.use_lora, + args.use_cpu_offload, + args.master_weight_type, + ) + + if args.use_lora: + transformer.config.lora_rank = args.lora_rank + transformer.config.lora_alpha = args.lora_alpha + transformer.config.lora_target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + transformer._no_split_modules = [ + no_split_module.__name__ for no_split_module in no_split_modules + ] + fsdp_kwargs["auto_wrap_policy"] = fsdp_kwargs["auto_wrap_policy"](transformer) + + transformer = FSDP(transformer, **fsdp_kwargs,) + main_print(f"--> model loaded") + + if args.gradient_checkpointing: + apply_fsdp_checkpointing( + transformer, no_split_modules, args.selective_checkpointing + ) + + # Set model as trainable. 
+ transformer.train() + + noise_scheduler = FlowMatchEulerDiscreteScheduler() + + params_to_optimize = transformer.parameters() + params_to_optimize = list(filter(lambda p: p.requires_grad, params_to_optimize)) + + optimizer = torch.optim.AdamW( + params_to_optimize, + lr=args.learning_rate, + betas=(0.9, 0.999), + weight_decay=args.weight_decay, + eps=1e-8, + ) + + init_steps = 0 + if args.resume_from_lora_checkpoint: + transformer, optimizer, init_steps = resume_lora_optimizer( + transformer, args.resume_from_lora_checkpoint, optimizer + ) + main_print(f"optimizer: {optimizer}") + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps, + num_training_steps=args.max_train_steps, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + last_epoch=init_steps - 1, + ) + + train_dataset = LatentDataset(args.data_json_path, args.num_latent_t, args.cfg) + sampler = ( + LengthGroupedSampler( + args.train_batch_size, + rank=rank, + world_size=world_size, + lengths=train_dataset.lengths, + group_frame=args.group_frame, + group_resolution=args.group_resolution, + ) + if (args.group_frame or args.group_resolution) + else DistributedSampler( + train_dataset, rank=rank, num_replicas=world_size, shuffle=False + ) + ) + + train_dataloader = DataLoader( + train_dataset, + sampler=sampler, + collate_fn=latent_collate_function, + pin_memory=True, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + drop_last=True, + ) + + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) + / args.gradient_accumulation_steps + * args.sp_size + / args.train_sp_batch_size + ) + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + if rank <= 0: + project = args.tracker_project_name or "fastvideo" + wandb.init(project=project, config=args) + + # Train! + total_batch_size = ( + args.train_batch_size + * world_size + * args.gradient_accumulation_steps + / args.sp_size + * args.train_sp_batch_size + ) + main_print("***** Running training *****") + main_print(f" Num examples = {len(train_dataset)}") + main_print(f" Dataloader size = {len(train_dataloader)}") + main_print(f" Num Epochs = {args.num_train_epochs}") + main_print(f" Resume training from step {init_steps}") + main_print(f" Instantaneous batch size per device = {args.train_batch_size}") + main_print( + f" Total train batch size (w. data & sequence parallel, accumulation) = {total_batch_size}" + ) + main_print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + main_print(f" Total optimization steps = {args.max_train_steps}") + main_print( + f" Total training parameters per FSDP shard = {sum(p.numel() for p in transformer.parameters() if p.requires_grad) / 1e9} B" + ) + # print dtype + main_print(f" Master weight dtype: {transformer.parameters().__next__().dtype}") + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + assert NotImplementedError("resume_from_checkpoint is not supported now.") + # TODO + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=init_steps, + desc="Steps", + # Only show the progress bar once on each machine. 
+ disable=local_rank > 0, + ) + + loader = sp_parallel_dataloader_wrapper( + train_dataloader, + device, + args.train_batch_size, + args.sp_size, + args.train_sp_batch_size, + ) + + step_times = deque(maxlen=100) + + # todo future + for i in range(init_steps): + next(loader) + for step in range(init_steps + 1, args.max_train_steps + 1): + start_time = time.time() + loss, grad_norm = train_one_step( + transformer, + args.model_type, + optimizer, + lr_scheduler, + loader, + noise_scheduler, + noise_random_generator, + args.gradient_accumulation_steps, + args.sp_size, + args.precondition_outputs, + args.max_grad_norm, + args.weighting_scheme, + args.logit_mean, + args.logit_std, + args.mode_scale, + ) + + step_time = time.time() - start_time + step_times.append(step_time) + avg_step_time = sum(step_times) / len(step_times) + + progress_bar.set_postfix( + { + "loss": f"{loss:.4f}", + "step_time": f"{step_time:.2f}s", + "grad_norm": grad_norm, + } + ) + progress_bar.update(1) + if rank <= 0: + wandb.log( + { + "train_loss": loss, + "learning_rate": lr_scheduler.get_last_lr()[0], + "step_time": step_time, + "avg_step_time": avg_step_time, + "grad_norm": grad_norm, + }, + step=step, + ) + if step % args.checkpointing_steps == 0: + if args.use_lora: + # Save LoRA weights + save_lora_checkpoint( + transformer, optimizer, rank, args.output_dir, step, pipeline + ) + else: + # Your existing checkpoint saving code + save_checkpoint(transformer, optimizer, rank, args.output_dir, step) + dist.barrier() + if args.log_validation and step % args.validation_steps == 0: + log_validation(args, transformer, device, torch.bfloat16, step) + + if args.use_lora: + save_lora_checkpoint( + transformer, optimizer, rank, args.output_dir, args.max_train_steps, pipeline + ) + else: + save_checkpoint( + transformer, optimizer, rank, args.output_dir, args.max_train_steps + ) + + if get_sequence_parallel_state(): + destroy_sequence_parallel_group() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_type", type=str, default="mochi", help="The type of model to train." + ) + # dataset & dataloader + parser.add_argument("--data_json_path", type=str, required=True) + parser.add_argument("--num_frames", type=int, default=163) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=10, + help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=16, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--num_latent_t", type=int, default=28, help="Number of latent timesteps." 
+ ) + parser.add_argument("--group_frame", action="store_true") # TODO + parser.add_argument("--group_resolution", action="store_true") # TODO + + # text encoder & vae & diffusion model + parser.add_argument("--pretrained_model_name_or_path", type=str) + parser.add_argument("--dit_model_name_or_path", type=str, default=None) + parser.add_argument("--cache_dir", type=str, default="./cache_dir") + + # diffusion setting + parser.add_argument("--ema_decay", type=float, default=0.999) + parser.add_argument("--ema_start_step", type=int, default=0) + parser.add_argument("--cfg", type=float, default=0.1) + parser.add_argument( + "--precondition_outputs", + action="store_true", + help="Whether to precondition the outputs of the model.", + ) + + # validation & logs + parser.add_argument("--validation_prompt_dir", type=str) + parser.add_argument("--uncond_prompt_dir", type=str) + parser.add_argument( + "--validation_sampling_steps", + type=str, + default="64", + help="use ',' to split multi sampling steps", + ) + parser.add_argument( + "--validation_guidance_scale", + type=str, + default="4.5", + help="use ',' to split multi scale", + ) + parser.add_argument("--validation_steps", type=int, default=50) + parser.add_argument("--log_validation", action="store_true") + parser.add_argument("--tracker_project_name", type=str, default=None) + parser.add_argument( + "--seed", type=int, default=None, help="A seed for reproducible training." + ) + parser.add_argument( + "--output_dir", + type=str, + default=None, + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--resume_from_lora_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous lora checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + + # optimizer & scheduler & Training + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_warmup_steps", + type=int, + default=10, + help="Number of steps for the warmup in the lr scheduler.", + ) + parser.add_argument( + "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument("--selective_checkpointing", type=float, default=1.0) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--use_cpu_offload", + action="store_true", + help="Whether to use CPU offload for param & gradient & optimizer states.", + ) + + parser.add_argument("--sp_size", type=int, default=1, help="For sequence parallel") + parser.add_argument( + "--train_sp_batch_size", + type=int, + default=1, + help="Batch size for sequence parallel training", + ) + + parser.add_argument( + "--use_lora", + action="store_true", + default=False, + help="Whether to use LoRA for finetuning.", + ) + parser.add_argument( + "--lora_alpha", type=int, default=256, help="Alpha parameter for LoRA." + ) + parser.add_argument( + "--lora_rank", type=int, default=128, help="LoRA rank parameter. " + ) + parser.add_argument("--fsdp_sharding_startegy", default="full") + + parser.add_argument( + "--weighting_scheme", + type=str, + default="uniform", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "uniform"], + ) + parser.add_argument( + "--logit_mean", + type=float, + default=0.0, + help="mean to use when using the `'logit_normal'` weighting scheme.", + ) + parser.add_argument( + "--logit_std", + type=float, + default=1.0, + help="std to use when using the `'logit_normal'` weighting scheme.", + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + # lr_scheduler + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. 
Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of cycles in the learning rate scheduler.", + ) + parser.add_argument( + "--lr_power", + type=float, + default=1.0, + help="Power factor of the polynomial scheduler.", + ) + parser.add_argument( + "--weight_decay", type=float, default=0.01, help="Weight decay to apply." + ) + parser.add_argument( + "--master_weight_type", + type=str, + default="fp32", + help="Weight type to use - fp32 or bf16.", + ) + + args = parser.parse_args() + main(args) diff --git a/fastvideo/utils/checkpoint.py b/fastvideo/utils/checkpoint.py index 3fd8bfcf..8b6b6f66 100644 --- a/fastvideo/utils/checkpoint.py +++ b/fastvideo/utils/checkpoint.py @@ -214,7 +214,7 @@ def resume_training(model, optimizer, checkpoint_dir, discriminator=False): return model, optimizer, step -def save_lora_checkpoint(transformer, optimizer, rank, output_dir, step): +def save_lora_checkpoint(transformer, optimizer, rank, output_dir, step, pipeline): with FSDP.state_dict_type( transformer, StateDictType.FULL_STATE_DICT, @@ -235,7 +235,7 @@ def save_lora_checkpoint(transformer, optimizer, rank, output_dir, step): transformer_lora_layers = get_peft_model_state_dict( model=transformer, state_dict=full_state_dict ) - MochiPipeline.save_lora_weights( + pipeline.save_lora_weights( save_directory=save_dir, transformer_lora_layers=transformer_lora_layers, is_main_process=True, diff --git a/prompts.txt b/prompts.txt new file mode 100644 index 00000000..c228a916 --- /dev/null +++ b/prompts.txt @@ -0,0 +1,4 @@ +Will Smith casually eats noodles, his relaxed demeanor contrasting with the energetic background of a bustling street food market. The scene captures a mix of humor and authenticity. Mid-shot framing, vibrant lighting. +A lone hiker stands atop a towering cliff, silhouetted against the vast horizon. The rugged landscape stretches endlessly beneath, its earthy tones blending into the soft blues of the sky. The scene captures the spirit of exploration and human resilience. High angle, dynamic framing, with soft natural lighting emphasizing the grandeur of nature. +A hand with delicate fingers picks up a bright yellow lemon from a wooden bowl filled with lemons and sprigs of mint against a peach-colored background. The hand gently tosses the lemon up and catches it, showcasing its smooth texture. A beige string bag sits beside the bowl, adding a rustic touch to the scene. Additional lemons, one halved, are scattered around the base of the bowl. The even lighting enhances the vibrant colors and creates a fresh, inviting atmosphere. +A curious raccoon peers through a vibrant field of yellow sunflowers, its eyes wide with interest. The playful yet serene atmosphere is complemented by soft natural light filtering through the petals. Mid-shot, warm and cheerful tones. 
diff --git a/scripts/finetune/finetune_hunyuan.sh b/scripts/finetune/finetune_hunyuan.sh index e5f1e722..567afc54 100644 --- a/scripts/finetune/finetune_hunyuan.sh +++ b/scripts/finetune/finetune_hunyuan.sh @@ -4,7 +4,7 @@ export WANDB_MODE=online torchrun --nnodes 1 --nproc_per_node 1 \ fastvideo/train.py \ --seed 42 \ - --pretrained_model_name_or_path data/FastHunyuan \ + --pretrained_model_name_or_path data/hunyuan \ --dit_model_name_or_path data/hunyuan/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt\ --model_type "hunyuan" \ --cache_dir data/.cache \ diff --git a/scripts/inference/inference_hunyuan.sh b/scripts/inference/inference_hunyuan.sh index 0340431f..8ff2a1da 100644 --- a/scripts/inference/inference_hunyuan.sh +++ b/scripts/inference/inference_hunyuan.sh @@ -1,13 +1,13 @@ #!/bin/bash num_gpus=4 -export MODEL_BASE=data/FastHunyuan +export MODEL_BASE=/root/hunyuan torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29503 \ fastvideo/sample/sample_t2v_hunyuan.py \ - --height 720 \ - --width 1280 \ - --num_frames 125 \ - --num_inference_steps 6 \ + --height 480 \ + --width 848 \ + --num_frames 93 \ + --num_inference_steps 50 \ --guidance_scale 1 \ --embedded_cfg_scale 6 \ --flow_shift 17 \ diff --git a/scripts/inference/inference_hunyuan_hf.sh b/scripts/inference/inference_hunyuan_hf.sh new file mode 100644 index 00000000..abd36a4e --- /dev/null +++ b/scripts/inference/inference_hunyuan_hf.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +num_gpus=4 +torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29503 \ + fastvideo/sample/sample_t2v_hunyuan_hf.py \ + --model_path ~/hunyuan_hf/ \ + --prompt_path "assets/prompt_test.txt" \ + --num_frames 93 \ + --height 480 \ + --width 848 \ + --num_inference_steps 50 \ + --guidance_scale 1.5 \ + --output_path outputs_video/hunyuan_hf/ \ + --seed 1024 \ + --linear_threshold 0.1 \ + --flow_shift 17 \ + --flow-reverse \ + --linear_range 0.75 \ + + + diff --git a/scripts/inference/inference_mochi_sp.sh b/scripts/inference/inference_mochi_sp.sh index 3c50185b..4ee373f9 100644 --- a/scripts/inference/inference_mochi_sp.sh +++ b/scripts/inference/inference_mochi_sp.sh @@ -1,15 +1,14 @@ #!/bin/bash num_gpus=4 - torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29503 \ fastvideo/sample/sample_t2v_mochi.py \ - --model_path data/FastMochi-diffusers \ + --model_path ~/mochi \ --prompt_path "assets/prompt.txt" \ - --num_frames 163 \ + --num_frames 91 \ --height 480 \ --width 848 \ - --num_inference_steps 8 \ + --num_inference_steps 64 \ --guidance_scale 1.5 \ --output_path outputs_video/mochi_sp/ \ --seed 1024 \ From ac58b2fb3bb8d53f15349f760eef8969891f1b18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CBrianChen1129=E2=80=9D?= Date: Tue, 24 Dec 2024 02:27:42 +0000 Subject: [PATCH 07/42] test hunyuan hf --- fastvideo/test/test_hunyuan_hf.py | 81 +++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 fastvideo/test/test_hunyuan_hf.py diff --git a/fastvideo/test/test_hunyuan_hf.py b/fastvideo/test/test_hunyuan_hf.py new file mode 100644 index 00000000..9f312570 --- /dev/null +++ b/fastvideo/test/test_hunyuan_hf.py @@ -0,0 +1,81 @@ +import torch +from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel +from diffusers.utils import export_to_video +import random +import numpy as np +import argparse +import os +from fastvideo.models.hunyuan.diffusion.schedulers import FlowMatchDiscreteScheduler + +def parse_args(): + parser = argparse.ArgumentParser(description='Generate video using 
Hunyuan model') + + parser.add_argument('--prompt', type=str, required=True, help='Text prompt for video generation') + parser.add_argument('--model_path', type=str, default='/root/hunyuan_hf/', help='Path to the Hunyuan model directory') + parser.add_argument('--output_dir', type=str, default='outputs_video/hunyuan_hf', help='Directory to save the output video') + parser.add_argument('--height', type=int, default=480, help='Height of the output video') + parser.add_argument('--width', type=int, default=848, help='Width of the output video') + parser.add_argument('--num_frames', type=int, default=93, help='Number of frames to generate') + parser.add_argument('--num_inference_steps', type=int, default=50, help='Number of inference steps') + parser.add_argument('--seed', type=int, default=1024, help='Random seed for generation') + parser.add_argument('--fps', type=int, default=24, help='Frames per second for the output video') + + + return parser.parse_args() + +def main(): + args = parse_args() + prompt_candidates = ["Will Smith casually eats noodles, his relaxed demeanor contrasting with the energetic background of a bustling street food market. The scene captures a mix of humor and authenticity. Mid-shot framing, vibrant lighting.", + "A lone hiker stands atop a towering cliff, silhouetted against the vast horizon. The rugged landscape stretches endlessly beneath, its earthy tones blending into the soft blues of the sky. The scene captures the spirit of exploration and human resilience. High angle, dynamic framing, with soft natural lighting emphasizing the grandeur of nature.", + "A hand with delicate fingers picks up a bright yellow lemon from a wooden bowl filled with lemons and sprigs of mint against a peach-colored background. The hand gently tosses the lemon up and catches it, showcasing its smooth texture. A beige string bag sits beside the bowl, adding a rustic touch to the scene. Additional lemons, one halved, are scattered around the base of the bowl. The even lighting enhances the vibrant colors and creates a fresh, inviting atmosphere.", + "A curious raccoon peers through a vibrant field of yellow sunflowers, its eyes wide with interest. The playful yet serene atmosphere is complemented by soft natural light filtering through the petals. Mid-shot, warm and cheerful tones.", + "A superintelligent humanoid robot waking up. The robot has a sleek metallic body with futuristic design features. Its glowing red eyes are the focal point, emanating a sharp, intense light as it powers on. The scene is set in a dimly lit, high-tech laboratory filled with glowing control panels, robotic arms, and holographic screens. The setting emphasizes advanced technology and an atmosphere of mystery. The ambiance is eerie and dramatic, highlighting the moment of awakening and the robots immense intelligence. Photorealistic style with a cinematic, dark sci-fi aesthetic. Aspect ratio: 16:9 --v 6.1", + "fox in the forest close-up quickly turned its head to the left.", + "Man walking his dog in the woods on a hot sunny day", + "A majestic lion strides across the golden savanna, its powerful frame glistening under the warm afternoon sun. The tall grass ripples gently in the breeze, enhancing the lion's commanding presence. The tone is vibrant, embodying the raw energy of the wild. 
Low angle, steady tracking shot, cinematic."] + # Set random seed + generator = torch.Generator("cpu").manual_seed(args.seed) + # Load transformer model + transformer = HunyuanVideoTransformer3DModel.from_pretrained( + pretrained_model_name_or_path=args.model_path, + subfolder="transformer", + torch_dtype=torch.bfloat16, + local_files_only=True + ) + + # Initialize pipeline + pipe = HunyuanVideoPipeline.from_pretrained( + pretrained_model_name_or_path=args.model_path, + transformer=transformer, + torch_dtype=torch.float16, + local_files_only=True + ) + #pipe.vae = pipe.vae.to(torch.bfloat16) + pipe.vae.enable_tiling() + + # Move to GPU + device = torch.cuda.current_device() + pipe.to(device) + #pipe.enable_model_cpu_offload(device) + + # Create output directory if it doesn't exist + os.makedirs(args.output_dir, exist_ok=True) + file_name = args.prompt[:20] + output_path = os.path.join(args.output_dir, file_name + 'output.mp4') + + # Generate video + output = pipe( + prompt=args.prompt, + height=args.height, + width=args.width, + num_frames=args.num_frames, + num_inference_steps=args.num_inference_steps, + generator=generator + ).frames[0] + + # Save video + export_to_video(output, output_path, fps=args.fps) + print(f"Video saved to: {output_path}") + +if __name__ == "__main__": + main() From 77b690c83e390c25467861a3256561db8c53238a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CBrianChen1129=E2=80=9D?= Date: Tue, 24 Dec 2024 21:06:33 +0000 Subject: [PATCH 08/42] add huanyuan hf inference and train --- .../models/hunyuan_hf/modeling_hunyuan.py | 322 +++++------------- .../models/hunyuan_hf/pipeline_hunyuan.py | 18 +- .../models/mochi_hf/mochi_latents_utils.py | 4 +- fastvideo/models/mochi_hf/modeling_mochi.py | 80 ++--- fastvideo/sample/sample_t2v_hunyuan_hf.py | 11 +- fastvideo/train.py | 17 +- fastvideo/utils/load.py | 33 +- 7 files changed, 176 insertions(+), 309 deletions(-) diff --git a/fastvideo/models/hunyuan_hf/modeling_hunyuan.py b/fastvideo/models/hunyuan_hf/modeling_hunyuan.py index 46d27e3c..b4c60bba 100644 --- a/fastvideo/models/hunyuan_hf/modeling_hunyuan.py +++ b/fastvideo/models/hunyuan_hf/modeling_hunyuan.py @@ -19,7 +19,7 @@ import torch.nn.functional as F from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.loaders import PeftAdapterMixin +from diffusers.loaders import PeftAdapterMixin, FromOriginalModelMixin from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from diffusers.models.attention import FeedForward from diffusers.models.attention_processor import Attention, AttentionProcessor @@ -36,68 +36,11 @@ from fastvideo.models.flash_attn_no_pad import flash_attn_no_pad logger = logging.get_logger(__name__) # pylint: disable=invalid-name -import sys -import pdb -class ForkedPdb(pdb.Pdb): - """A Pdb subclass that may be used - from a forked multiprocessing child - - """ - def interaction(self, *args, **kwargs): - _stdin = sys.stdin - try: - sys.stdin = open('/dev/stdin') - pdb.Pdb.interaction(self, *args, **kwargs) - finally: - sys.stdin = _stdin -def apply_rotary_emb( - x: torch.Tensor, - freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], - use_real: bool = True, - use_real_unbind_dim: int = -1, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings - to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. 
The input tensors are - reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting - tensors contain rotary embeddings and are returned as real tensors. - - Args: - x (`torch.Tensor`): - Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply - freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) - - Returns: - Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. - """ - x = x.transpose(1,2) - if use_real: - cos, sin = freqs_cis # [S, D] - cos = cos[None, None] - sin = sin[None, None] - cos, sin = cos.to(x.device), sin.to(x.device) - - if use_real_unbind_dim == -1: - # Used for flux, cogvideox, hunyuan-dit - x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] - x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) - elif use_real_unbind_dim == -2: - # Used for Stable Audio - x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] - x_rotated = torch.cat([-x_imag, x_real], dim=-1) - else: - raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") - - out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) - - return out.transpose(1,2) - else: - # used for lumina - x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) - freqs_cis = freqs_cis.unsqueeze(2) - x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) - - return x_out.type_as(x).transpose(1,2) +def shrink_head(encoder_state, dim): + local_heads = encoder_state.shape[dim] // nccl_info.sp_size + return encoder_state.narrow( + dim, nccl_info.rank_within_group * local_heads, local_heads + ) class HunyuanVideoAttnProcessor2_0: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): @@ -108,210 +51,136 @@ def __init__(self): def __call__( self, attn: Attention, - hidden_states: torch.Tensor, #[1, 38160, 3072] [1, 19080, 3072] - encoder_hidden_states: Optional[torch.Tensor] = None, # [1, 256, 3072] [1, 256, 3072] + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, # [38160, 128] [38106, 128] + image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: - - if attn.add_q_proj is None and encoder_hidden_states is not None: - sequence_length = hidden_states.size(1) # 19080 - encoder_sequence_length = encoder_hidden_states.size(1) # 256 - hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) # [1, 19336, 3072] + sequence_length = hidden_states.size(1) + encoder_sequence_length = encoder_hidden_states.size(1) + if attn.add_q_proj is None and encoder_hidden_states is not None: + hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) # 1. QKV projections - query = attn.to_q(hidden_states) # [1, 19080, 3072] + query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) - query = query.unflatten(2, (attn.heads, -1)) # [1, 19080, 24, 128] - key = key.unflatten(2, (attn.heads, -1)) - value = value.unflatten(2, (attn.heads, -1)) + query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) # 2. 
QK normalization if attn.norm_q is not None: query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) - - # from IPython import embed - # embed() - def shrink_head(x, dim): - local_heads = x.shape[dim] // nccl_info.sp_size - return x.narrow(dim, nccl_info.rank_within_group * local_heads, local_heads) - - if get_sequence_parallel_state(): - # Handle sequence parallelism for main hidden states - # Note: We scatter on heads dim (1) and gather on sequence dim (2) since tensors are transposed - if attn.add_q_proj is None: - qkv_ = [query, key, value] - - qk_h, qk_eh = [], [] - for item in qkv_: - qk_h.append(item[:,:sequence_length,:,:]) - qk_eh.append(item[:,sequence_length:,:,:]) - for i in range(len(qkv_)): - qk_h[i] = all_to_all_4D(qk_h[i], scatter_dim=2, gather_dim=1) - qk_eh[i] = shrink_head(qk_eh[i], dim=2) - value = torch.cat([qk_h[2],qk_eh[2]], dim=1) - else: - query = all_to_all_4D(query, scatter_dim=2, gather_dim=1) # [1, 24, 19080, 128] -> [1, 12, 38160, 128] - key = all_to_all_4D(key, scatter_dim=2, gather_dim=1) - value = all_to_all_4D(value, scatter_dim=2, gather_dim=1) - + image_rotary_emb = ( + shrink_head(image_rotary_emb[0], dim=0), + shrink_head(image_rotary_emb[1], dim=0), + ) - - # if image_rotary_emb is not None: - - # freqs_cos, freqs_sin = image_rotary_emb - # freqs_cos = shrink_head(freqs_cos, dim=1) - # freqs_sin = shrink_head(freqs_sin, dim=1) - # image_rotary_emb = (freqs_cos, freqs_sin) - # 3. Rotational positional embeddings applied to latent stream - - # freqs_cos, freqs_sin = image_rotary_emb - # freqs_cos = freqs_cos.unsqueeze(1).expand(-1, attn.heads // nccl_info.sp_size, -1) - # freqs_sin = freqs_sin.unsqueeze(1).expand(-1, attn.heads // nccl_info.sp_size, -1) - # image_rotary_emb = (freqs_cos, freqs_sin) - if image_rotary_emb is not None: - #from diffusers.models.embeddings import apply_rotary_emb + from diffusers.models.embeddings import apply_rotary_emb if attn.add_q_proj is None and encoder_hidden_states is not None: - if get_sequence_parallel_state(): - query = torch.cat( - [ - apply_rotary_emb(qk_h[0], image_rotary_emb), qk_eh[0] - ], - dim=1, - ) - key = torch.cat( - [ - apply_rotary_emb(qk_h[1], image_rotary_emb), qk_eh[1] - ], - dim=1, - ) - # if get_sequence_parallel_state() and attn.add_q_proj is None: - - else: - query = torch.cat( - [ - apply_rotary_emb(query[:,: -encoder_hidden_states.shape[1], :], image_rotary_emb), - query[:, -encoder_hidden_states.shape[1] :,:], - ], - dim=1, - ) - key = torch.cat( - [ - apply_rotary_emb(key[:,: -encoder_hidden_states.shape[1], :], image_rotary_emb), - key[:, -encoder_hidden_states.shape[1] :,:], - ], - dim=1, - ) - + query = torch.cat( + [ + apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb), + query[:, :, -encoder_hidden_states.shape[1] :], + ], + dim=2, + ) + key = torch.cat( + [ + apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb), + key[:, :, -encoder_hidden_states.shape[1] :], + ], + dim=2, + ) else: - - query = apply_rotary_emb(query, image_rotary_emb) # [1, 24, 38160, 128] + query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) # 4. 
Encoder condition QKV projection and normalization if attn.add_q_proj is not None and encoder_hidden_states is not None: - encoder_query = attn.add_q_proj(encoder_hidden_states) # [1, 256, 3072] + encoder_query = attn.add_q_proj(encoder_hidden_states) encoder_key = attn.add_k_proj(encoder_hidden_states) encoder_value = attn.add_v_proj(encoder_hidden_states) - encoder_query = encoder_query.unflatten(2, (attn.heads, -1)) # [1, 24, 256, 128] - encoder_key = encoder_key.unflatten(2, (attn.heads, -1)) - encoder_value = encoder_value.unflatten(2, (attn.heads, -1)) + encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2) if attn.norm_added_q is not None: encoder_query = attn.norm_added_q(encoder_query) if attn.norm_added_k is not None: encoder_key = attn.norm_added_k(encoder_key) - if get_sequence_parallel_state(): - encoder_query = shrink_head(encoder_query, dim=2) - encoder_key = shrink_head(encoder_key, dim=2) - encoder_value = shrink_head(encoder_value, dim=2) - sequence_length = query.size(1) - encoder_sequence_length = encoder_query.size(1) - - query = torch.cat([query, encoder_query], dim=1).unsqueeze(2) # [1, 24, 1, 38416, 128] - key = torch.cat([key, encoder_key], dim=1).unsqueeze(2) - value = torch.cat([value, encoder_value], dim=1).unsqueeze(2) - else: - # from IPython import embed - # embed() - query = query.unsqueeze(2) # [1, 24, 1, 38416, 128] - key = key.unsqueeze(2) - value = value.unsqueeze(2) + query = torch.cat([query, encoder_query], dim=2) # [1, 24, 1, 38416, 128] + key = torch.cat([key, encoder_key], dim=2) + value = torch.cat([value, encoder_value], dim=2) + + if get_sequence_parallel_state(): + query_img, query_txt = query[:,:,:sequence_length,:], query[:,:,sequence_length:,:] + key_img, key_txt = key[:,:,:sequence_length,:], key[:,:,sequence_length:,:] + value_img, value_txt = value[:,:,:sequence_length,:], value[:,:,sequence_length:,:] + query_img = all_to_all_4D(query_img, scatter_dim=1, gather_dim=2) # + key_img = all_to_all_4D(key_img, scatter_dim=1, gather_dim=2) + value_img = all_to_all_4D(value_img, scatter_dim=1, gather_dim=2) + + query_txt = shrink_head(query_txt, dim=1) + key_txt = shrink_head(key_txt, dim=1) + value_txt = shrink_head(value_txt, dim=1) + query = torch.cat([query_img, query_txt], dim=2) + key = torch.cat([key_img, key_txt], dim=2) + value = torch.cat([value_img, value_txt], dim=2) + query = query.unsqueeze(2) # [1, 24, 1, 38416, 128] + key = key.unsqueeze(2) + value = value.unsqueeze(2) qkv = torch.cat([query, key, value], dim=2) + qkv = qkv.transpose(1,3) # 5. Attention - attention_mask = attention_mask[:,0,:] - - hidden_states = flash_attn_no_pad( #[1, 38416, 24, 128] - qkv, attention_mask, causal=False, dropout_p=0.0, softmax_scale=None # [2, 25696] - ) - + seq_len = qkv.shape[1] + attn_len = attention_mask.shape[1] + attention_mask = F.pad(attention_mask, (seq_len-attn_len, 0), value=True) + hidden_states = flash_attn_no_pad(qkv, attention_mask, causal=False, dropout_p=0.0, softmax_scale=None) # [1, 39184, 6, 128] - # hidden_states = F.scaled_dot_product_attention( - # query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - # ) - #hidden_states = hidden_states.flatten(2, 3) # [1, 38416, 3072] - hidden_states = hidden_states.to(query.dtype) - - # 6. 
Output projection - if encoder_hidden_states is not None: - if get_sequence_parallel_state(): - if attn.add_q_proj is not None: - - hidden_states, encoder_hidden_states = hidden_states.split_with_sizes( - (sequence_length, encoder_sequence_length), dim=1 - ) - # B, S, H, D - - hidden_states = all_to_all_4D(hidden_states, scatter_dim=1, gather_dim=2) - encoder_hidden_states = all_gather( - encoder_hidden_states, dim=2 - ).contiguous() - - hidden_states = hidden_states.flatten(2, 3) - hidden_states = hidden_states.to(query.dtype) - encoder_hidden_states = encoder_hidden_states.flatten(2, 3) - encoder_hidden_states = encoder_hidden_states.to(query.dtype) - else: - - hidden_states, encoder_hidden_states = hidden_states.split_with_sizes( - (sequence_length * nccl_info.sp_size, encoder_sequence_length), dim=1 - ) - # B, S, H, D - #ForkedPdb().set_trace() - hidden_states = all_to_all_4D(hidden_states, scatter_dim=1, gather_dim=2) - encoder_hidden_states = all_gather( - encoder_hidden_states, dim=2 - ).contiguous() - hidden_states = hidden_states.flatten(2, 3) - hidden_states = hidden_states.to(query.dtype) - encoder_hidden_states = encoder_hidden_states.flatten(2, 3) - encoder_hidden_states = encoder_hidden_states.to(query.dtype) - else: - hidden_states = hidden_states.flatten(2, 3) - hidden_states = hidden_states.to(query.dtype) - hidden_states, encoder_hidden_states = hidden_states.split_with_sizes( - (sequence_length, encoder_sequence_length), dim=1 + if get_sequence_parallel_state(): + hidden_states, encoder_hidden_states = hidden_states.split_with_sizes( + (sequence_length * nccl_info.sp_size, encoder_sequence_length), dim=1 + ) + hidden_states = all_to_all_4D(hidden_states, scatter_dim=1, gather_dim=2) + encoder_hidden_states = all_gather( + encoder_hidden_states, dim=2 + ).contiguous() + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + encoder_hidden_states = encoder_hidden_states.flatten(2, 3) + encoder_hidden_states = encoder_hidden_states.to(query.dtype) + else: + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + # 6. Output projection + if encoder_hidden_states is not None: + hidden_states, encoder_hidden_states = ( + hidden_states[:, : -encoder_hidden_states.shape[1]], + hidden_states[:, -encoder_hidden_states.shape[1] :], ) + if encoder_hidden_states is not None: if getattr(attn, "to_out", None) is not None: hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[1](hidden_states) - + if getattr(attn, "to_add_out", None) is not None: encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - return hidden_states, encoder_hidden_states # [1, 38160, 3072] [1, 256, 3072] + return hidden_states, encoder_hidden_states class HunyuanVideoPatchEmbed(nn.Module): @@ -506,8 +375,7 @@ def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], thet def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, num_channels, num_frames, height, width = hidden_states.shape - - rope_sizes = [num_frames // self.patch_size_t * nccl_info.sp_size, height // self.patch_size, width // self.patch_size] + rope_sizes = [num_frames * nccl_info.sp_size // self.patch_size_t, height // self.patch_size, width // self.patch_size] axes_grids = [] for i in range(3): @@ -516,7 +384,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # differences in layerwise debugging outputs, but visually it is the same. 
grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32) axes_grids.append(grid) - grid = torch.meshgrid(*axes_grids, indexing="ij") # [W, H, T] grid = torch.stack(grid, dim=0) # [3, W, H, T] @@ -681,7 +548,7 @@ def forward( return hidden_states, encoder_hidden_states -class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): +class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): r""" A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo). @@ -880,7 +747,6 @@ def forward( post_patch_width = width // p # 1. RoPE - image_rotary_emb = self.rope(hidden_states) # 2. Conditional embeddings @@ -896,7 +762,7 @@ def forward( batch_size, sequence_length, sequence_length, device=hidden_states.device, dtype=torch.bool ) # [B, N, N] - effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,] + effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) effective_sequence_length = latent_sequence_length + effective_condition_sequence_length for i in range(batch_size): diff --git a/fastvideo/models/hunyuan_hf/pipeline_hunyuan.py b/fastvideo/models/hunyuan_hf/pipeline_hunyuan.py index aac471e8..e0b7753e 100644 --- a/fastvideo/models/hunyuan_hf/pipeline_hunyuan.py +++ b/fastvideo/models/hunyuan_hf/pipeline_hunyuan.py @@ -33,7 +33,20 @@ from fastvideo.utils.communications import all_gather logger = logging.get_logger(__name__) # pylint: disable=invalid-name +import sys +import pdb +class ForkedPdb(pdb.Pdb): + """A Pdb subclass that may be used + from a forked multiprocessing child + """ + def interaction(self, *args, **kwargs): + _stdin = sys.stdin + try: + sys.stdin = open('/dev/stdin') + pdb.Pdb.interaction(self, *args, **kwargs) + finally: + sys.stdin = _stdin EXAMPLE_DOC_STRING = """ Examples: ```python @@ -618,6 +631,7 @@ def __call__( # 5. Prepare latent variables num_channels_latents = self.transformer.config.in_channels num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latents = self.prepare_latents( batch_size * num_videos_per_prompt, num_channels_latents, @@ -630,12 +644,12 @@ def __call__( latents, ) world_size, rank = nccl_info.sp_size, nccl_info.rank_within_group - if get_sequence_parallel_state(): latents = rearrange( latents, "b t (n s) h w -> b t n s h w", n=world_size ).contiguous() latents = latents[:, :, rank, :, :, :] + # 6. 
Prepare guidance condition guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0 @@ -681,7 +695,7 @@ def __call__( if get_sequence_parallel_state(): latents = all_gather(latents, dim=2) - + if not output_type == "latent": latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor video = self.vae.decode(latents, return_dict=False)[0] diff --git a/fastvideo/models/mochi_hf/mochi_latents_utils.py b/fastvideo/models/mochi_hf/mochi_latents_utils.py index 7c9249f8..d75df607 100644 --- a/fastvideo/models/mochi_hf/mochi_latents_utils.py +++ b/fastvideo/models/mochi_hf/mochi_latents_utils.py @@ -36,11 +36,13 @@ def normalize_dit_input(model_type, latents): - if model_type == "mochi": + if model_type == "mochi_hf": latents_mean = mochi_latents_mean.to(latents.device, latents.dtype) latents_std = mochi_latents_std.to(latents.device, latents.dtype) latents = (latents - latents_mean) / latents_std return latents + elif model_type == "hunyuan_hf": + return latents * 0.476986 elif model_type == "hunyuan": return latents * 0.476986 else: diff --git a/fastvideo/models/mochi_hf/modeling_mochi.py b/fastvideo/models/mochi_hf/modeling_mochi.py index d0aae801..a57d0fa7 100644 --- a/fastvideo/models/mochi_hf/modeling_mochi.py +++ b/fastvideo/models/mochi_hf/modeling_mochi.py @@ -55,20 +55,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -import sys -import pdb -class ForkedPdb(pdb.Pdb): - """A Pdb subclass that may be used - from a forked multiprocessing child - - """ - def interaction(self, *args, **kwargs): - _stdin = sys.stdin - try: - sys.stdin = open('/dev/stdin') - pdb.Pdb.interaction(self, *args, **kwargs) - finally: - sys.stdin = _stdin class FeedForward(HF_FeedForward): def __init__( @@ -178,20 +164,19 @@ def __init__(self): def __call__( self, attn: Attention, - hidden_states: torch.Tensor, # [2, 25440, 3072] [2, 12720, 3072] - encoder_hidden_states: torch.Tensor, # [2, 256, 1536] [2, 256, 1536] - encoder_attention_mask: torch.Tensor, # [2, 256] + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_attention_mask: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, # [25440, 24, 64] [25440, 24, 64] - ) -> torch.Tensor: + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: # [b, s, h * d] - - query = attn.to_q(hidden_states) # [2, 25440, 3072] + query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) # [b, s, h=24, d=128] - query = query.unflatten(2, (attn.heads, -1)) # [2, 25440, 24, 128] + query = query.unflatten(2, (attn.heads, -1)) key = key.unflatten(2, (attn.heads, -1)) value = value.unflatten(2, (attn.heads, -1)) @@ -200,12 +185,12 @@ def __call__( if attn.norm_k is not None: key = attn.norm_k(key) # [b, 256, h * d] - encoder_query = attn.add_q_proj(encoder_hidden_states) # [2, 256, 3072] + encoder_query = attn.add_q_proj(encoder_hidden_states) encoder_key = attn.add_k_proj(encoder_hidden_states) encoder_value = attn.add_v_proj(encoder_hidden_states) # [b, 256, h=24, d=128] - encoder_query = encoder_query.unflatten(2, (attn.heads, -1)) # [2, 256, 24, 128] + encoder_query = encoder_query.unflatten(2, (attn.heads, -1)) encoder_key = encoder_key.unflatten(2, (attn.heads, -1)) encoder_value = encoder_value.unflatten(2, (attn.heads, -1)) @@ -215,15 +200,12 @@ def __call__( encoder_key = attn.norm_added_k(encoder_key) if image_rotary_emb is not None: - freqs_cos, 
freqs_sin = image_rotary_emb[0], image_rotary_emb[1] # [25440, 24, 64] + freqs_cos, freqs_sin = image_rotary_emb[0], image_rotary_emb[1] # shard the head dimension - - if get_sequence_parallel_state(): # B, S, H, D to (S, B,) H, D # batch_size, seq_len, attn_heads, head_dim - - query = all_to_all_4D(query, scatter_dim=2, gather_dim=1) # [2, 25440, 24, 128] + query = all_to_all_4D(query, scatter_dim=2, gather_dim=1) key = all_to_all_4D(key, scatter_dim=2, gather_dim=1) value = all_to_all_4D(value, scatter_dim=2, gather_dim=1) @@ -232,13 +214,12 @@ def shrink_head(encoder_state, dim): return encoder_state.narrow( dim, nccl_info.rank_within_group * local_heads, local_heads ) - ForkedPdb().set_trace() - encoder_query = shrink_head(encoder_query, dim=2) # [2, 256, 12, 128] + + encoder_query = shrink_head(encoder_query, dim=2) encoder_key = shrink_head(encoder_key, dim=2) encoder_value = shrink_head(encoder_value, dim=2) - if image_rotary_emb is not None: - freqs_cos = shrink_head(freqs_cos, dim=1) # [25440, 12, 64] + freqs_cos = shrink_head(freqs_cos, dim=1) freqs_sin = shrink_head(freqs_sin, dim=1) if image_rotary_emb is not None: @@ -251,7 +232,7 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin): return torch.stack([cos, sin], dim=-1).flatten(-2) - query = apply_rotary_emb(query, freqs_cos, freqs_sin) # [2, 25440, 24, 128] + query = apply_rotary_emb(query, freqs_cos, freqs_sin) key = apply_rotary_emb(key, freqs_cos, freqs_sin) # query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2) @@ -265,17 +246,16 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin): encoder_sequence_length = encoder_query.size(1) # H - query = torch.cat([query, encoder_query], dim=1).unsqueeze(2) # [2, 25696, 1, 24, 128] + query = torch.cat([query, encoder_query], dim=1).unsqueeze(2) key = torch.cat([key, encoder_key], dim=1).unsqueeze(2) value = torch.cat([value, encoder_value], dim=1).unsqueeze(2) # B, S, 3, H, D - qkv = torch.cat([query, key, value], dim=2) # [2, 25696, 3, 24, 128] + qkv = torch.cat([query, key, value], dim=2) attn_mask = encoder_attention_mask[:, :].bool() attn_mask = F.pad(attn_mask, (sequence_length, 0), value=True) - - hidden_states = flash_attn_no_pad( #[2, 25696, 24, 128] - qkv, attn_mask, causal=False, dropout_p=0.0, softmax_scale=None # [2, 25696] + hidden_states = flash_attn_no_pad( + qkv, attn_mask, causal=False, dropout_p=0.0, softmax_scale=None ) # hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask = None, dropout_p=0.0, is_causal=False) @@ -290,7 +270,6 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin): (sequence_length, encoder_sequence_length), dim=1 ) # B, S, H, D - hidden_states = all_to_all_4D(hidden_states, scatter_dim=1, gather_dim=2) encoder_hidden_states = all_gather( encoder_hidden_states, dim=2 @@ -306,14 +285,14 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin): hidden_states, encoder_hidden_states = hidden_states.split_with_sizes( (sequence_length, encoder_sequence_length), dim=1 ) - + # linear proj - hidden_states = attn.to_out[0](hidden_states) #[2, 25440, 3072] + hidden_states = attn.to_out[0](hidden_states) # dropout - hidden_states = attn.to_out[1](hidden_states) #[2, 256, 3072] + hidden_states = attn.to_out[1](hidden_states) if hasattr(attn, "to_add_out"): - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) #[2, 256, 1536] + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) return hidden_states, encoder_hidden_states @@ -498,7 +477,6 @@ def _get_positions( dtype: Optional[torch.dtype] = None, ) 
-> torch.Tensor: scale = (self.target_area / (height * width)) ** 0.5 - t = torch.arange(num_frames * nccl_info.sp_size, device=device, dtype=dtype) h = self._centers( -height * scale / 2, height * scale / 2, height, device, dtype @@ -506,7 +484,7 @@ def _get_positions( w = self._centers(-width * scale / 2, width * scale / 2, width, device, dtype) grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij") - + positions = torch.stack([grid_t, grid_h, grid_w], dim=-1).view(-1, 3) return positions @@ -565,7 +543,6 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): """ _supports_gradient_checkpointing = True - _no_split_modules = ["MochiTransformerBlock"] @register_to_config def __init__( @@ -688,16 +665,15 @@ def forward( hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) hidden_states = self.patch_embed(hidden_states) hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2) - + image_rotary_emb = self.rope( - self.pos_frequencies, # [3, 24, 64] - num_frames, # 8 + self.pos_frequencies, + num_frames, post_patch_height, post_patch_width, - device=hidden_states.device, # [2, 12720, 3072] + device=hidden_states.device, dtype=torch.float32, ) - attn_outputs_list = [] for i, block in enumerate(self.transformer_blocks): if self.gradient_checkpointing: diff --git a/fastvideo/sample/sample_t2v_hunyuan_hf.py b/fastvideo/sample/sample_t2v_hunyuan_hf.py index 14ed518c..620add3b 100644 --- a/fastvideo/sample/sample_t2v_hunyuan_hf.py +++ b/fastvideo/sample/sample_t2v_hunyuan_hf.py @@ -23,8 +23,8 @@ from fastvideo.models.hunyuan.diffusion.schedulers import FlowMatchDiscreteScheduler from fastvideo.models.hunyuan_hf.pipeline_hunyuan import HunyuanVideoPipeline from fastvideo.models.hunyuan_hf.modeling_hunyuan import HunyuanVideoTransformer3DModel - def initialize_distributed(): + os.environ["TOKENIZERS_PARALLELISM"] = "false" local_rank = int(os.getenv("RANK", 0)) world_size = int(os.getenv("WORLD_SIZE", 1)) print("world_size", world_size) @@ -105,7 +105,6 @@ def main(args): width=args.width, num_frames=args.num_frames, num_inference_steps=args.num_inference_steps, - guidance_scale=args.guidance_scale, generator=generator, ).frames if nccl_info.global_rank <= 0: @@ -126,7 +125,6 @@ def main(args): width=args.width, num_frames=args.num_frames, num_inference_steps=args.num_inference_steps, - guidance_scale=args.guidance_scale, generator=generator, ).frames @@ -158,12 +156,5 @@ def main(args): default=None, help="Path to the directory containing LoRA checkpoints", ) - parser.add_argument("--flow_shift", type=int, default=7, help="Flow shift parameter.") - parser.add_argument("--flow-reverse",action="store_true",help="If reverse, learning/sampling from t=1 -> t=0.",) - parser.add_argument("--flow-solver", type=str, default="euler", help="Solver for flow matching.") - parser.add_argument("--shift", type=float, default=8.0) - parser.add_argument("--num_euler_timesteps", type=int, default=100) - parser.add_argument("--linear_threshold", type=float, default=0.025) - parser.add_argument("--linear_range", type=float, default=0.75) args = parser.parse_args() main(args) diff --git a/fastvideo/train.py b/fastvideo/train.py index 13d4f9ce..5c7e0b30 100644 --- a/fastvideo/train.py +++ b/fastvideo/train.py @@ -47,7 +47,6 @@ ) from fastvideo.utils.logging_ import main_print from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline -from diffusers.pipelines.hunyuan_video import HunyuanVideoPipeline # Will error if the minimal version of diffusers is 
not installed. Remove at your own risks. check_min_version("0.31.0") @@ -225,11 +224,7 @@ def main(args): ) if args.use_lora: - #assert args.model_type == "mochi", "LoRA is only supported for Mochi model." - if args.model_type == "mochi": - pipeline = MochiPipeline - elif args.model_type == "hunyuan": - pipeline = HunyuanVideoPipeline + assert args.model_type == "mochi", "LoRA is only supported for Mochi model." transformer.requires_grad_(False) transformer_lora_config = LoraConfig( r=args.lora_rank, @@ -237,13 +232,10 @@ def main(args): init_lora_weights=True, target_modules=["to_k", "to_q", "to_v", "to_out.0"], ) - from IPython import embed - embed() - transformer.add_adapter(transformer_lora_config) if args.resume_from_lora_checkpoint: - lora_state_dict = pipeline.lora_state_dict( + lora_state_dict = MochiPipeline.lora_state_dict( args.resume_from_lora_checkpoint ) transformer_state_dict = { @@ -464,7 +456,7 @@ def main(args): if args.use_lora: # Save LoRA weights save_lora_checkpoint( - transformer, optimizer, rank, args.output_dir, step, pipeline + transformer, optimizer, rank, args.output_dir, step ) else: # Your existing checkpoint saving code @@ -475,7 +467,7 @@ def main(args): if args.use_lora: save_lora_checkpoint( - transformer, optimizer, rank, args.output_dir, args.max_train_steps, pipeline + transformer, optimizer, rank, args.output_dir, args.max_train_steps ) else: save_checkpoint( @@ -741,6 +733,5 @@ def main(args): default="fp32", help="Weight type to use - fp32 or bf16.", ) - args = parser.parse_args() main(args) diff --git a/fastvideo/utils/load.py b/fastvideo/utils/load.py index d29cb007..7dce3c32 100644 --- a/fastvideo/utils/load.py +++ b/fastvideo/utils/load.py @@ -3,13 +3,18 @@ MochiTransformer3DModel, MochiTransformerBlock, ) +from fastvideo.models.hunyuan_hf.modeling_hunyuan import ( + HunyuanVideoTransformer3DModel, + HunyuanVideoSingleTransformerBlock, + HunyuanVideoTransformerBlock, +) from fastvideo.models.hunyuan.modules.models import ( HYVideoDiffusionTransformer, MMDoubleStreamBlock, MMSingleStreamBlock, ) from fastvideo.models.hunyuan.vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D -from diffusers import AutoencoderKLMochi +from diffusers import AutoencoderKLMochi, AutoencoderKLHunyuanVideo from transformers import T5EncoderModel, AutoTokenizer import os from torch import nn @@ -255,7 +260,7 @@ def load_transformer( pretrained_model_name_or_path, master_weight_type, ): - if model_type == "mochi": + if model_type == "mochi_hf": if dit_model_name_or_path: transformer = MochiTransformer3DModel.from_pretrained( dit_model_name_or_path, @@ -269,6 +274,20 @@ def load_transformer( torch_dtype=master_weight_type, # torch_dtype=torch.bfloat16 if args.use_lora else torch.float32, ) + elif model_type == "hunyuan_hf": + if dit_model_name_or_path: + transformer = HunyuanVideoTransformer3DModel.from_pretrained( + dit_model_name_or_path, + torch_dtype=master_weight_type, + # torch_dtype=torch.bfloat16 if args.use_lora else torch.float32, + ) + else: + transformer = HunyuanVideoTransformer3DModel.from_pretrained( + pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=master_weight_type, + # torch_dtype=torch.bfloat16 if args.use_lora else torch.float32, + ) elif model_type == "hunyuan": transformer = HYVideoDiffusionTransformer( in_channels=16, out_channels=16, **hunyuan_config, dtype=master_weight_type, @@ -281,12 +300,18 @@ def load_transformer( def load_vae(model_type, pretrained_model_name_or_path): weight_dtype = torch.float32 - if 
model_type == "mochi": + if model_type == "mochi_hf": vae = AutoencoderKLMochi.from_pretrained( pretrained_model_name_or_path, subfolder="vae", torch_dtype=weight_dtype ).to("cuda") autocast_type = torch.bfloat16 fps = 30 + elif model_type == "hunyuan_hf": + vae = AutoencoderKLHunyuanVideo.from_pretrained( + pretrained_model_name_or_path, subfolder="vae", torch_dtype=weight_dtype + ).to("cuda") + autocast_type = torch.bfloat16 + fps = 24 elif model_type == "hunyuan": vae_precision = torch.float32 vae_path = os.path.join( @@ -332,6 +357,8 @@ def get_no_split_modules(transformer): # if of type MochiTransformer3DModel if isinstance(transformer, MochiTransformer3DModel): return (MochiTransformerBlock,) + elif isinstance(transformer, HunyuanVideoTransformer3DModel): + return (HunyuanVideoSingleTransformerBlock, HunyuanVideoTransformerBlock) elif isinstance(transformer, HYVideoDiffusionTransformer): return (MMDoubleStreamBlock, MMSingleStreamBlock) else: From 04c161055a45e39c79dcdf798709b8d825198f53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CBrianChen1129=E2=80=9D?= Date: Fri, 27 Dec 2024 03:44:33 +0000 Subject: [PATCH 09/42] support hunyuan hf lora --- .../models/hunyuan_hf/modeling_hunyuan.py | 34 +++++++++++------ .../models/hunyuan_hf/pipeline_hunyuan.py | 37 +++++++++--------- .../models/mochi_hf/mochi_latents_utils.py | 2 +- fastvideo/train.py | 26 +++++++------ fastvideo/utils/load.py | 6 +-- fastvideo/utils/validation.py | 2 +- scripts/finetune/finetune_hunyuan_hf_lora.sh | 38 +++++++++++++++++++ 7 files changed, 97 insertions(+), 48 deletions(-) create mode 100644 scripts/finetune/finetune_hunyuan_hf_lora.sh diff --git a/fastvideo/models/hunyuan_hf/modeling_hunyuan.py b/fastvideo/models/hunyuan_hf/modeling_hunyuan.py index b4c60bba..655206ef 100644 --- a/fastvideo/models/hunyuan_hf/modeling_hunyuan.py +++ b/fastvideo/models/hunyuan_hf/modeling_hunyuan.py @@ -13,7 +13,6 @@ # limitations under the License. from typing import Any, Dict, List, Optional, Tuple, Union - import torch import torch.nn as nn import torch.nn.functional as F @@ -56,6 +55,7 @@ def __call__( attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: + sequence_length = hidden_states.size(1) encoder_sequence_length = encoder_hidden_states.size(1) if attn.add_q_proj is None and encoder_hidden_states is not None: @@ -72,14 +72,15 @@ def __call__( # 2. QK normalization if attn.norm_q is not None: - query = attn.norm_q(query) + query = attn.norm_q(query).to(value) if attn.norm_k is not None: - key = attn.norm_k(key) + key = attn.norm_k(key).to(value) + image_rotary_emb = ( shrink_head(image_rotary_emb[0], dim=0), shrink_head(image_rotary_emb[1], dim=0), ) - + # 3. 
Rotational positional embeddings applied to latent stream if image_rotary_emb is not None: from diffusers.models.embeddings import apply_rotary_emb @@ -114,11 +115,11 @@ def __call__( encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2) if attn.norm_added_q is not None: - encoder_query = attn.norm_added_q(encoder_query) + encoder_query = attn.norm_added_q(encoder_query).to(encoder_value) if attn.norm_added_k is not None: - encoder_key = attn.norm_added_k(encoder_key) + encoder_key = attn.norm_added_k(encoder_key).to(encoder_value) - query = torch.cat([query, encoder_query], dim=2) # [1, 24, 1, 38416, 128] + query = torch.cat([query, encoder_query], dim=2) key = torch.cat([key, encoder_key], dim=2) value = torch.cat([value, encoder_value], dim=2) @@ -136,19 +137,20 @@ def __call__( query = torch.cat([query_img, query_txt], dim=2) key = torch.cat([key_img, key_txt], dim=2) value = torch.cat([value_img, value_txt], dim=2) - query = query.unsqueeze(2) # [1, 24, 1, 38416, 128] + + query = query.unsqueeze(2) key = key.unsqueeze(2) value = value.unsqueeze(2) - qkv = torch.cat([query, key, value], dim=2) qkv = qkv.transpose(1,3) + # 5. Attention attention_mask = attention_mask[:,0,:] seq_len = qkv.shape[1] attn_len = attention_mask.shape[1] attention_mask = F.pad(attention_mask, (seq_len-attn_len, 0), value=True) - hidden_states = flash_attn_no_pad(qkv, attention_mask, causal=False, dropout_p=0.0, softmax_scale=None) # [1, 39184, 6, 128] + hidden_states = flash_attn_no_pad(qkv, attention_mask, causal=False, dropout_p=0.0, softmax_scale=None) if get_sequence_parallel_state(): hidden_states, encoder_hidden_states = hidden_states.split_with_sizes( @@ -172,6 +174,7 @@ def __call__( hidden_states[:, : -encoder_hidden_states.shape[1]], hidden_states[:, -encoder_hidden_states.shape[1] :], ) + if encoder_hidden_states is not None: if getattr(attn, "to_out", None) is not None: hidden_states = attn.to_out[0](hidden_states) @@ -717,14 +720,18 @@ def _set_gradient_checkpointing(self, module, value=False): def forward( self, hidden_states: torch.Tensor, - timestep: torch.LongTensor, encoder_hidden_states: torch.Tensor, + timestep: torch.LongTensor, encoder_attention_mask: torch.Tensor, - pooled_projections: torch.Tensor, guidance: torch.Tensor = None, attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + if guidance == None: + guidance = torch.tensor( + [6016.0], device=hidden_states.device, dtype=torch.bfloat16 + ) + if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() lora_scale = attention_kwargs.pop("scale", 1.0) @@ -746,6 +753,9 @@ def forward( post_patch_height = height // p post_patch_width = width // p + pooled_projections = encoder_hidden_states[:, 0, : self.config.pooled_projection_dim] + encoder_hidden_states = encoder_hidden_states[:, 1:] + # 1. 
RoPE image_rotary_emb = self.rope(hidden_states) diff --git a/fastvideo/models/hunyuan_hf/pipeline_hunyuan.py b/fastvideo/models/hunyuan_hf/pipeline_hunyuan.py index e0b7753e..936d8a98 100644 --- a/fastvideo/models/hunyuan_hf/pipeline_hunyuan.py +++ b/fastvideo/models/hunyuan_hf/pipeline_hunyuan.py @@ -14,7 +14,7 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union - +import torch.nn.functional as F import numpy as np import torch from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast @@ -33,20 +33,7 @@ from fastvideo.utils.communications import all_gather logger = logging.get_logger(__name__) # pylint: disable=invalid-name -import sys -import pdb -class ForkedPdb(pdb.Pdb): - """A Pdb subclass that may be used - from a forked multiprocessing child - """ - def interaction(self, *args, **kwargs): - _stdin = sys.stdin - try: - sys.stdin = open('/dev/stdin') - pdb.Pdb.interaction(self, *args, **kwargs) - finally: - sys.stdin = _stdin EXAMPLE_DOC_STRING = """ Examples: ```python @@ -325,6 +312,7 @@ def encode_prompt( dtype: Optional[torch.dtype] = None, max_sequence_length: int = 256, ): + if prompt_embeds is None: prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds( prompt, @@ -574,7 +562,7 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - + # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, @@ -603,7 +591,7 @@ def __call__( # 3. Encode input prompt prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt( prompt=prompt, - prompt_2=prompt_2, + prompt_2=prompt, prompt_template=prompt_template, num_videos_per_prompt=num_videos_per_prompt, prompt_embeds=prompt_embeds, @@ -643,6 +631,7 @@ def __call__( generator, latents, ) + # check sequence_parallel world_size, rank = nccl_info.sp_size, nccl_info.rank_within_group if get_sequence_parallel_state(): latents = rearrange( @@ -665,13 +654,21 @@ def __call__( latent_model_input = latents.to(transformer_dtype) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - + if pooled_prompt_embeds.shape[-1] != prompt_embeds.shape[-1]: + pooled_prompt_embeds_padding = F.pad( + pooled_prompt_embeds, + (0, prompt_embeds.shape[2] - pooled_prompt_embeds.shape[1]), + value=0, + ).unsqueeze(1) + encoder_hidden_states = torch.cat( + [pooled_prompt_embeds_padding, prompt_embeds], dim=1 + ) + noise_pred = self.transformer( - hidden_states=latent_model_input, + hidden_states=latent_model_input, + encoder_hidden_states=encoder_hidden_states, # [1, 257, 4096] timestep=timestep, - encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_attention_mask, - pooled_projections=pooled_prompt_embeds, guidance=guidance, attention_kwargs=attention_kwargs, return_dict=False, diff --git a/fastvideo/models/mochi_hf/mochi_latents_utils.py b/fastvideo/models/mochi_hf/mochi_latents_utils.py index d75df607..07233456 100644 --- a/fastvideo/models/mochi_hf/mochi_latents_utils.py +++ b/fastvideo/models/mochi_hf/mochi_latents_utils.py @@ -36,7 +36,7 @@ def normalize_dit_input(model_type, latents): - if model_type == "mochi_hf": + if model_type == "mochi": latents_mean = mochi_latents_mean.to(latents.device, latents.dtype) latents_std = mochi_latents_std.to(latents.device, latents.dtype) latents = (latents - latents_mean) / latents_std diff --git 
a/fastvideo/train.py b/fastvideo/train.py index 5c7e0b30..095f1bef 100644 --- a/fastvideo/train.py +++ b/fastvideo/train.py @@ -47,6 +47,8 @@ ) from fastvideo.utils.logging_ import main_print from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline +from fastvideo.models.hunyuan_hf.pipeline_hunyuan import HunyuanVideoPipeline + # Will error if the minimal version of diffusers is not installed. Remove at your own risks. check_min_version("0.31.0") @@ -149,10 +151,7 @@ def train_one_step( dtype=latents.dtype, ) noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise - # if rank<=0: - # print("2222222222222222222222222222222222222222222222") - # print(type(latents_attention_mask)) - # print(latents_attention_mask) + with torch.autocast("cuda", torch.bfloat16): model_pred = transformer( noisy_model_input, @@ -161,8 +160,7 @@ def train_one_step( encoder_attention_mask, # B, L return_dict=False, )[0] - # if rank<=0: - # print("333333333333333333333333333333333333333333333333") + if precondition_outputs: model_pred = noisy_model_input - model_pred * sigmas if precondition_outputs: @@ -224,7 +222,11 @@ def main(args): ) if args.use_lora: - assert args.model_type == "mochi", "LoRA is only supported for Mochi model." + assert args.model_type != "hunyuan", "LoRA is only supported for huggingface model. Please use hunyuan_hf for lora finetuning" + if args.model_type == "mochi": + pipe = MochiPipeline + elif args.model_type == "hunyuan_hf": + pipe = HunyuanVideoPipeline transformer.requires_grad_(False) transformer_lora_config = LoraConfig( r=args.lora_rank, @@ -235,7 +237,7 @@ def main(args): transformer.add_adapter(transformer_lora_config) if args.resume_from_lora_checkpoint: - lora_state_dict = MochiPipeline.lora_state_dict( + lora_state_dict = pipe.lora_state_dict( args.resume_from_lora_checkpoint ) transformer_state_dict = { @@ -456,7 +458,7 @@ def main(args): if args.use_lora: # Save LoRA weights save_lora_checkpoint( - transformer, optimizer, rank, args.output_dir, step + transformer, optimizer, rank, args.output_dir, step, pipe ) else: # Your existing checkpoint saving code @@ -467,7 +469,7 @@ def main(args): if args.use_lora: save_lora_checkpoint( - transformer, optimizer, rank, args.output_dir, args.max_train_steps + transformer, optimizer, rank, args.output_dir, args.max_train_steps, pipe ) else: save_checkpoint( @@ -481,7 +483,7 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( - "--model_type", type=str, default="mochi", help="The type of model to train." + "--model_type", type=str, default="mochi", help="The type of model to train. 
Currentlt support [mochi, hunyuan_hf, hunyuan]" ) # dataset & dataloader parser.add_argument("--data_json_path", type=str, required=True) @@ -733,5 +735,7 @@ def main(args): default="fp32", help="Weight type to use - fp32 or bf16.", ) + args = parser.parse_args() main(args) + \ No newline at end of file diff --git a/fastvideo/utils/load.py b/fastvideo/utils/load.py index 7dce3c32..2039b25d 100644 --- a/fastvideo/utils/load.py +++ b/fastvideo/utils/load.py @@ -260,7 +260,7 @@ def load_transformer( pretrained_model_name_or_path, master_weight_type, ): - if model_type == "mochi_hf": + if model_type == "mochi": if dit_model_name_or_path: transformer = MochiTransformer3DModel.from_pretrained( dit_model_name_or_path, @@ -300,7 +300,7 @@ def load_transformer( def load_vae(model_type, pretrained_model_name_or_path): weight_dtype = torch.float32 - if model_type == "mochi_hf": + if model_type == "mochi": vae = AutoencoderKLMochi.from_pretrained( pretrained_model_name_or_path, subfolder="vae", torch_dtype=weight_dtype ).to("cuda") @@ -346,7 +346,7 @@ def load_vae(model_type, pretrained_model_name_or_path): def load_text_encoder(model_type, pretrained_model_name_or_path, device): if model_type == "mochi": text_encoder = MochiTextEncoderWrapper(pretrained_model_name_or_path, device) - elif model_type == "hunyuan": + elif model_type == "hunyuan" or "hunyuan_hf": text_encoder = HunyuanTextEncoderWrapper(pretrained_model_name_or_path, device) else: raise ValueError(f"Unsupported model type: {model_type}") diff --git a/fastvideo/utils/validation.py b/fastvideo/utils/validation.py index aabbc5a5..be8751ea 100644 --- a/fastvideo/utils/validation.py +++ b/fastvideo/utils/validation.py @@ -222,7 +222,7 @@ def log_validation( vae_spatial_scale_factor = 8 vae_temporal_scale_factor = 6 num_channels_latents = 12 - elif args.model_type == "hunyuan": + elif args.model_type == "hunyuan" or "hunyuan_hf": vae_spatial_scale_factor = 8 vae_temporal_scale_factor = 4 num_channels_latents = 16 diff --git a/scripts/finetune/finetune_hunyuan_hf_lora.sh b/scripts/finetune/finetune_hunyuan_hf_lora.sh new file mode 100644 index 00000000..0a062166 --- /dev/null +++ b/scripts/finetune/finetune_hunyuan_hf_lora.sh @@ -0,0 +1,38 @@ +export WANDB_BASE_URL="https://api.wandb.ai" +export WANDB_MODE=online + +torchrun --nnodes 1 --nproc_per_node 4 \ + fastvideo/train.py \ + --seed 1024 \ + --pretrained_model_name_or_path ~/data/hunyuan_diffusers \ + --model_type hunyuan_hf \ + --cache_dir ~/data/.cache \ + --data_json_path data/Image-Vid-Finetune-Src/videos2caption.json \ + --validation_prompt_dir data/Image-Vid-Finetune-Src/validation \ + --gradient_checkpointing \ + --train_batch_size 1 \ + --num_latent_t 24 \ + --sp_size 4 \ + --train_sp_batch_size 1 \ + --dataloader_num_workers 4 \ + --gradient_accumulation_steps 2 \ + --max_train_steps 2000 \ + --learning_rate 5e-6 \ + --mixed_precision bf16 \ + --checkpointing_steps 200 \ + --validation_steps 100 \ + --validation_sampling_steps 50 \ + --checkpoints_total_limit 3 \ + --allow_tf32 \ + --ema_start_step 0 \ + --cfg 0.0 \ + --ema_decay 0.999 \ + --log_validation \ + --output_dir ~/data/outputs/HSH-Taylor-Finetune-Hunyuan \ + --tracker_project_name HSH-Taylor-Finetune-Hunyuan \ + --num_frames 93 \ + --validation_guidance_scale "1.0" \ + --group_frame \ + --use_lora \ + --lora_rank 128 \ + --lora_alpha 256 From 594b65afd041fd8020b1d422f037c1225d4acca9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CBrianChen1129=E2=80=9D?= Date: Fri, 27 Dec 2024 04:02:39 +0000 Subject: [PATCH 10/42] syn 
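One note on the load.py and validation.py branches above: Python parses model_type == "hunyuan" or "hunyuan_hf" as (model_type == "hunyuan") or "hunyuan_hf", and a non-empty string literal is always truthy, so the elif matches every model type. A membership test expresses the intended check; minimal sketch:

    model_type = "mochi"

    # Always truthy: the right-hand operand is a non-empty string.
    if model_type == "hunyuan" or "hunyuan_hf":
        print("reached even for mochi")

    # Intended behaviour: match only the two Hunyuan variants.
    if model_type in ("hunyuan", "hunyuan_hf"):
        print("not reached for mochi")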
with main --- assets/prompt.txt | 4 +- assets/prompt_test.txt | 2 +- fastvideo/test/test_hunyuan_hf.py | 5 +- fastvideo/train_hunyuan_hf.py | 746 ------------------ scripts/finetune/finetune_hunyuan.sh | 19 +- scripts/finetune/finetune_mochi_lora.sh | 19 +- scripts/inference/inference_hunyuan.sh | 10 +- scripts/inference/inference_hunyuan_hf.sh | 13 +- scripts/inference/inference_mochi_sp.sh | 18 +- scripts/preprocess/preprocess_hunyuan_data.sh | 42 +- 10 files changed, 73 insertions(+), 805 deletions(-) delete mode 100644 fastvideo/train_hunyuan_hf.py diff --git a/assets/prompt.txt b/assets/prompt.txt index d10865e3..0d0a3c67 100644 --- a/assets/prompt.txt +++ b/assets/prompt.txt @@ -3,6 +3,6 @@ A lone hiker stands atop a towering cliff, silhouetted against the vast horizon. A hand with delicate fingers picks up a bright yellow lemon from a wooden bowl filled with lemons and sprigs of mint against a peach-colored background. The hand gently tosses the lemon up and catches it, showcasing its smooth texture. A beige string bag sits beside the bowl, adding a rustic touch to the scene. Additional lemons, one halved, are scattered around the base of the bowl. The even lighting enhances the vibrant colors and creates a fresh, inviting atmosphere. A curious raccoon peers through a vibrant field of yellow sunflowers, its eyes wide with interest. The playful yet serene atmosphere is complemented by soft natural light filtering through the petals. Mid-shot, warm and cheerful tones. A superintelligent humanoid robot waking up. The robot has a sleek metallic body with futuristic design features. Its glowing red eyes are the focal point, emanating a sharp, intense light as it powers on. The scene is set in a dimly lit, high-tech laboratory filled with glowing control panels, robotic arms, and holographic screens. The setting emphasizes advanced technology and an atmosphere of mystery. The ambiance is eerie and dramatic, highlighting the moment of awakening and the robots immense intelligence. Photorealistic style with a cinematic, dark sci-fi aesthetic. Aspect ratio: 16:9 --v 6.1 -fox in the forest close-up quickly turned its head to the left -Man walking his dog in the woods on a hot sunny day +fox in the forest close-up quickly turned its head to the left. +Man walking his dog in the woods on a hot sunny day. A majestic lion strides across the golden savanna, its powerful frame glistening under the warm afternoon sun. The tall grass ripples gently in the breeze, enhancing the lion's commanding presence. The tone is vibrant, embodying the raw energy of the wild. Low angle, steady tracking shot, cinematic. \ No newline at end of file diff --git a/assets/prompt_test.txt b/assets/prompt_test.txt index 11e03e47..333cef7e 100644 --- a/assets/prompt_test.txt +++ b/assets/prompt_test.txt @@ -1 +1 @@ -A majestic lion strides across the golden savanna, its powerful frame glistening under the warm afternoon sun. The tall grass ripples gently in the breeze, enhancing the lion's commanding presence. The tone is vibrant, embodying the raw energy of the wild. Low angle, steady tracking shot, cinematic. +Wukong stands prominently against a clear sky. Wukong's fur is dense and dark, framing an intense expression as they hold a large staff confidently across their shoulders. Elaborate golden armor intricately covers Wukong's torso, adorned with ornate designs and embellishments. With every subtle movement, Wukong exudes a sense of readiness and power, as if preparing for an impending challenge. 
The crown atop Wukong's head glints majestically in the sunlight, symbolizing leadership and authority. Wukong's fierce eyes remain focused and vigilant, capturing the viewer's attention with an aura of both mystery and strength. \ No newline at end of file diff --git a/fastvideo/test/test_hunyuan_hf.py b/fastvideo/test/test_hunyuan_hf.py index 9f312570..9f46cb57 100644 --- a/fastvideo/test/test_hunyuan_hf.py +++ b/fastvideo/test/test_hunyuan_hf.py @@ -10,8 +10,8 @@ def parse_args(): parser = argparse.ArgumentParser(description='Generate video using Hunyuan model') - parser.add_argument('--prompt', type=str, required=True, help='Text prompt for video generation') - parser.add_argument('--model_path', type=str, default='/root/hunyuan_hf/', help='Path to the Hunyuan model directory') + parser.add_argument('--prompt', type=str, default="", help='Text prompt for video generation') + parser.add_argument('--model_path', type=str, default="/mbz/users/hao.zhang/data/hunyuan_diffusers", help='Path to the Hunyuan model directory') parser.add_argument('--output_dir', type=str, default='outputs_video/hunyuan_hf', help='Directory to save the output video') parser.add_argument('--height', type=int, default=480, help='Height of the output video') parser.add_argument('--width', type=int, default=848, help='Width of the output video') @@ -34,6 +34,7 @@ def main(): "Man walking his dog in the woods on a hot sunny day", "A majestic lion strides across the golden savanna, its powerful frame glistening under the warm afternoon sun. The tall grass ripples gently in the breeze, enhancing the lion's commanding presence. The tone is vibrant, embodying the raw energy of the wild. Low angle, steady tracking shot, cinematic."] # Set random seed + #args.prompt = "Will Smith casually eats noodles, his relaxed demeanor contrasting with the energetic background of a bustling street food market. The scene captures a mix of humor and authenticity. Mid-shot framing, vibrant lighting." 
generator = torch.Generator("cpu").manual_seed(args.seed) # Load transformer model transformer = HunyuanVideoTransformer3DModel.from_pretrained( diff --git a/fastvideo/train_hunyuan_hf.py b/fastvideo/train_hunyuan_hf.py deleted file mode 100644 index 13d4f9ce..00000000 --- a/fastvideo/train_hunyuan_hf.py +++ /dev/null @@ -1,746 +0,0 @@ -import argparse -from email.policy import strict -import logging -import math -import os -import shutil -from pathlib import Path -from fastvideo.utils.parallel_states import ( - initialize_sequence_parallel_state, - destroy_sequence_parallel_group, - get_sequence_parallel_state, - nccl_info, -) -from fastvideo.utils.communications import sp_parallel_dataloader_wrapper, broadcast -from fastvideo.models.mochi_hf.mochi_latents_utils import normalize_dit_input -from fastvideo.utils.validation import log_validation -import time -from torch.utils.data import DataLoader -import torch -from torch.distributed.fsdp import ( - FullyShardedDataParallel as FSDP, - StateDictType, - FullStateDictConfig, -) -import json -from torch.utils.data.distributed import DistributedSampler -from fastvideo.utils.dataset_utils import LengthGroupedSampler -import wandb -from accelerate.utils import set_seed -from tqdm.auto import tqdm -from fastvideo.utils.fsdp_util import get_dit_fsdp_kwargs, apply_fsdp_checkpointing -from diffusers.utils import convert_unet_state_dict_to_peft -from diffusers import FlowMatchEulerDiscreteScheduler -from fastvideo.utils.load import load_transformer -from diffusers.optimization import get_scheduler -from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformer3DModel -from diffusers.utils import check_min_version -from fastvideo.dataset.latent_datasets import LatentDataset, latent_collate_function -import torch.distributed as dist -from safetensors.torch import save_file, load_file -from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from fastvideo.utils.checkpoint import ( - save_checkpoint, - save_lora_checkpoint, - resume_lora_optimizer, -) -from fastvideo.utils.logging_ import main_print -from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline -from diffusers.pipelines.hunyuan_video import HunyuanVideoPipeline - -# Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0") -import time -from collections import deque - - -def compute_density_for_timestep_sampling( - weighting_scheme: str, - batch_size: int, - generator, - logit_mean: float = None, - logit_std: float = None, - mode_scale: float = None, -): - """ - Compute the density for sampling the timesteps when doing SD3 training. - - Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. - - SD3 paper reference: https://arxiv.org/abs/2403.03206v1. - """ - if weighting_scheme == "logit_normal": - # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). 
- u = torch.normal( - mean=logit_mean, - std=logit_std, - size=(batch_size,), - device="cpu", - generator=generator, - ) - u = torch.nn.functional.sigmoid(u) - elif weighting_scheme == "mode": - u = torch.rand(size=(batch_size,), device="cpu", generator=generator) - u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) - else: - u = torch.rand(size=(batch_size,), device="cpu", generator=generator) - return u - - -def get_sigmas(noise_scheduler, device, timesteps, n_dim=4, dtype=torch.float32): - sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) - schedule_timesteps = noise_scheduler.timesteps.to(device) - timesteps = timesteps.to(device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < n_dim: - sigma = sigma.unsqueeze(-1) - return sigma - - -def train_one_step( - transformer, - model_type, - optimizer, - lr_scheduler, - loader, - noise_scheduler, - noise_random_generator, - gradient_accumulation_steps, - sp_size, - precondition_outputs, - max_grad_norm, - weighting_scheme, - logit_mean, - logit_std, - mode_scale, -): - total_loss = 0.0 - optimizer.zero_grad() - for _ in range(gradient_accumulation_steps): - ( - latents, - encoder_hidden_states, - latents_attention_mask, - encoder_attention_mask, - ) = next(loader) - latents = normalize_dit_input(model_type, latents) - batch_size = latents.shape[0] - noise = torch.randn_like(latents) - u = compute_density_for_timestep_sampling( - weighting_scheme=weighting_scheme, - batch_size=batch_size, - generator=noise_random_generator, - logit_mean=logit_mean, - logit_std=logit_std, - mode_scale=mode_scale, - ) - indices = (u * noise_scheduler.config.num_train_timesteps).long() - timesteps = noise_scheduler.timesteps[indices].to(device=latents.device) - if sp_size > 1: - # Make sure that the timesteps are the same across all sp processes. - broadcast(timesteps) - sigmas = get_sigmas( - noise_scheduler, - latents.device, - timesteps, - n_dim=latents.ndim, - dtype=latents.dtype, - ) - noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise - # if rank<=0: - # print("2222222222222222222222222222222222222222222222") - # print(type(latents_attention_mask)) - # print(latents_attention_mask) - with torch.autocast("cuda", torch.bfloat16): - model_pred = transformer( - noisy_model_input, - encoder_hidden_states, - timesteps, - encoder_attention_mask, # B, L - return_dict=False, - )[0] - # if rank<=0: - # print("333333333333333333333333333333333333333333333333") - if precondition_outputs: - model_pred = noisy_model_input - model_pred * sigmas - if precondition_outputs: - target = latents - else: - target = noise - latents - - loss = ( - torch.mean((model_pred.float() - target.float()) ** 2) - / gradient_accumulation_steps - ) - - loss.backward() - - avg_loss = loss.detach().clone() - dist.all_reduce(avg_loss, op=dist.ReduceOp.AVG) - total_loss += avg_loss.item() - - grad_norm = transformer.clip_grad_norm_(max_grad_norm) - optimizer.step() - lr_scheduler.step() - return total_loss, grad_norm.item() - - -def main(args): - torch.backends.cuda.matmul.allow_tf32 = True - - local_rank = int(os.environ["LOCAL_RANK"]) - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - dist.init_process_group("nccl") - torch.cuda.set_device(local_rank) - device = torch.cuda.current_device() - initialize_sequence_parallel_state(args.sp_size) - - # If passed along, set the training seed now. On GPU... 
- if args.seed is not None: - # TODO: t within the same seq parallel group should be the same. Noise should be different. - set_seed(args.seed + rank) - # We use different seeds for the noise generation in each process to ensure that the noise is different in a batch. - noise_random_generator = None - - # Handle the repository creation - if rank <= 0 and args.output_dir is not None: - os.makedirs(args.output_dir, exist_ok=True) - - # For mixed precision training we cast all non-trainable weigths to half-precision - # as these weights are only used for inference, keeping weights in full precision is not required. - - # Create model: - - main_print(f"--> loading model from {args.pretrained_model_name_or_path}") - # keep the master weight to float32 - transformer = load_transformer( - args.model_type, - args.dit_model_name_or_path, - args.pretrained_model_name_or_path, - torch.float32 if args.master_weight_type == "fp32" else torch.bfloat16, - ) - - if args.use_lora: - #assert args.model_type == "mochi", "LoRA is only supported for Mochi model." - if args.model_type == "mochi": - pipeline = MochiPipeline - elif args.model_type == "hunyuan": - pipeline = HunyuanVideoPipeline - transformer.requires_grad_(False) - transformer_lora_config = LoraConfig( - r=args.lora_rank, - lora_alpha=args.lora_alpha, - init_lora_weights=True, - target_modules=["to_k", "to_q", "to_v", "to_out.0"], - ) - from IPython import embed - embed() - - transformer.add_adapter(transformer_lora_config) - - if args.resume_from_lora_checkpoint: - lora_state_dict = pipeline.lora_state_dict( - args.resume_from_lora_checkpoint - ) - transformer_state_dict = { - f'{k.replace("transformer.", "")}': v - for k, v in lora_state_dict.items() - if k.startswith("transformer.") - } - transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) - incompatible_keys = set_peft_model_state_dict( - transformer, transformer_state_dict, adapter_name="default" - ) - if incompatible_keys is not None: - # check only for unexpected keys - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) - if unexpected_keys: - main_print( - f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " - f" {unexpected_keys}. " - ) - - main_print( - f" Total training parameters = {sum(p.numel() for p in transformer.parameters() if p.requires_grad) / 1e6} M" - ) - main_print( - f"--> Initializing FSDP with sharding strategy: {args.fsdp_sharding_startegy}" - ) - fsdp_kwargs, no_split_modules = get_dit_fsdp_kwargs( - transformer, - args.fsdp_sharding_startegy, - args.use_lora, - args.use_cpu_offload, - args.master_weight_type, - ) - - if args.use_lora: - transformer.config.lora_rank = args.lora_rank - transformer.config.lora_alpha = args.lora_alpha - transformer.config.lora_target_modules = ["to_k", "to_q", "to_v", "to_out.0"] - transformer._no_split_modules = [ - no_split_module.__name__ for no_split_module in no_split_modules - ] - fsdp_kwargs["auto_wrap_policy"] = fsdp_kwargs["auto_wrap_policy"](transformer) - - transformer = FSDP(transformer, **fsdp_kwargs,) - main_print(f"--> model loaded") - - if args.gradient_checkpointing: - apply_fsdp_checkpointing( - transformer, no_split_modules, args.selective_checkpointing - ) - - # Set model as trainable. 
- transformer.train() - - noise_scheduler = FlowMatchEulerDiscreteScheduler() - - params_to_optimize = transformer.parameters() - params_to_optimize = list(filter(lambda p: p.requires_grad, params_to_optimize)) - - optimizer = torch.optim.AdamW( - params_to_optimize, - lr=args.learning_rate, - betas=(0.9, 0.999), - weight_decay=args.weight_decay, - eps=1e-8, - ) - - init_steps = 0 - if args.resume_from_lora_checkpoint: - transformer, optimizer, init_steps = resume_lora_optimizer( - transformer, args.resume_from_lora_checkpoint, optimizer - ) - main_print(f"optimizer: {optimizer}") - - lr_scheduler = get_scheduler( - args.lr_scheduler, - optimizer=optimizer, - num_warmup_steps=args.lr_warmup_steps, - num_training_steps=args.max_train_steps, - num_cycles=args.lr_num_cycles, - power=args.lr_power, - last_epoch=init_steps - 1, - ) - - train_dataset = LatentDataset(args.data_json_path, args.num_latent_t, args.cfg) - sampler = ( - LengthGroupedSampler( - args.train_batch_size, - rank=rank, - world_size=world_size, - lengths=train_dataset.lengths, - group_frame=args.group_frame, - group_resolution=args.group_resolution, - ) - if (args.group_frame or args.group_resolution) - else DistributedSampler( - train_dataset, rank=rank, num_replicas=world_size, shuffle=False - ) - ) - - train_dataloader = DataLoader( - train_dataset, - sampler=sampler, - collate_fn=latent_collate_function, - pin_memory=True, - batch_size=args.train_batch_size, - num_workers=args.dataloader_num_workers, - drop_last=True, - ) - - num_update_steps_per_epoch = math.ceil( - len(train_dataloader) - / args.gradient_accumulation_steps - * args.sp_size - / args.train_sp_batch_size - ) - args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - - if rank <= 0: - project = args.tracker_project_name or "fastvideo" - wandb.init(project=project, config=args) - - # Train! - total_batch_size = ( - args.train_batch_size - * world_size - * args.gradient_accumulation_steps - / args.sp_size - * args.train_sp_batch_size - ) - main_print("***** Running training *****") - main_print(f" Num examples = {len(train_dataset)}") - main_print(f" Dataloader size = {len(train_dataloader)}") - main_print(f" Num Epochs = {args.num_train_epochs}") - main_print(f" Resume training from step {init_steps}") - main_print(f" Instantaneous batch size per device = {args.train_batch_size}") - main_print( - f" Total train batch size (w. data & sequence parallel, accumulation) = {total_batch_size}" - ) - main_print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") - main_print(f" Total optimization steps = {args.max_train_steps}") - main_print( - f" Total training parameters per FSDP shard = {sum(p.numel() for p in transformer.parameters() if p.requires_grad) / 1e9} B" - ) - # print dtype - main_print(f" Master weight dtype: {transformer.parameters().__next__().dtype}") - - # Potentially load in the weights and states from a previous save - if args.resume_from_checkpoint: - assert NotImplementedError("resume_from_checkpoint is not supported now.") - # TODO - - progress_bar = tqdm( - range(0, args.max_train_steps), - initial=init_steps, - desc="Steps", - # Only show the progress bar once on each machine. 
- disable=local_rank > 0, - ) - - loader = sp_parallel_dataloader_wrapper( - train_dataloader, - device, - args.train_batch_size, - args.sp_size, - args.train_sp_batch_size, - ) - - step_times = deque(maxlen=100) - - # todo future - for i in range(init_steps): - next(loader) - for step in range(init_steps + 1, args.max_train_steps + 1): - start_time = time.time() - loss, grad_norm = train_one_step( - transformer, - args.model_type, - optimizer, - lr_scheduler, - loader, - noise_scheduler, - noise_random_generator, - args.gradient_accumulation_steps, - args.sp_size, - args.precondition_outputs, - args.max_grad_norm, - args.weighting_scheme, - args.logit_mean, - args.logit_std, - args.mode_scale, - ) - - step_time = time.time() - start_time - step_times.append(step_time) - avg_step_time = sum(step_times) / len(step_times) - - progress_bar.set_postfix( - { - "loss": f"{loss:.4f}", - "step_time": f"{step_time:.2f}s", - "grad_norm": grad_norm, - } - ) - progress_bar.update(1) - if rank <= 0: - wandb.log( - { - "train_loss": loss, - "learning_rate": lr_scheduler.get_last_lr()[0], - "step_time": step_time, - "avg_step_time": avg_step_time, - "grad_norm": grad_norm, - }, - step=step, - ) - if step % args.checkpointing_steps == 0: - if args.use_lora: - # Save LoRA weights - save_lora_checkpoint( - transformer, optimizer, rank, args.output_dir, step, pipeline - ) - else: - # Your existing checkpoint saving code - save_checkpoint(transformer, optimizer, rank, args.output_dir, step) - dist.barrier() - if args.log_validation and step % args.validation_steps == 0: - log_validation(args, transformer, device, torch.bfloat16, step) - - if args.use_lora: - save_lora_checkpoint( - transformer, optimizer, rank, args.output_dir, args.max_train_steps, pipeline - ) - else: - save_checkpoint( - transformer, optimizer, rank, args.output_dir, args.max_train_steps - ) - - if get_sequence_parallel_state(): - destroy_sequence_parallel_group() - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--model_type", type=str, default="mochi", help="The type of model to train." - ) - # dataset & dataloader - parser.add_argument("--data_json_path", type=str, required=True) - parser.add_argument("--num_frames", type=int, default=163) - parser.add_argument( - "--dataloader_num_workers", - type=int, - default=10, - help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", - ) - parser.add_argument( - "--train_batch_size", - type=int, - default=16, - help="Batch size (per device) for the training dataloader.", - ) - parser.add_argument( - "--num_latent_t", type=int, default=28, help="Number of latent timesteps." 
- ) - parser.add_argument("--group_frame", action="store_true") # TODO - parser.add_argument("--group_resolution", action="store_true") # TODO - - # text encoder & vae & diffusion model - parser.add_argument("--pretrained_model_name_or_path", type=str) - parser.add_argument("--dit_model_name_or_path", type=str, default=None) - parser.add_argument("--cache_dir", type=str, default="./cache_dir") - - # diffusion setting - parser.add_argument("--ema_decay", type=float, default=0.999) - parser.add_argument("--ema_start_step", type=int, default=0) - parser.add_argument("--cfg", type=float, default=0.1) - parser.add_argument( - "--precondition_outputs", - action="store_true", - help="Whether to precondition the outputs of the model.", - ) - - # validation & logs - parser.add_argument("--validation_prompt_dir", type=str) - parser.add_argument("--uncond_prompt_dir", type=str) - parser.add_argument( - "--validation_sampling_steps", - type=str, - default="64", - help="use ',' to split multi sampling steps", - ) - parser.add_argument( - "--validation_guidance_scale", - type=str, - default="4.5", - help="use ',' to split multi scale", - ) - parser.add_argument("--validation_steps", type=int, default=50) - parser.add_argument("--log_validation", action="store_true") - parser.add_argument("--tracker_project_name", type=str, default=None) - parser.add_argument( - "--seed", type=int, default=None, help="A seed for reproducible training." - ) - parser.add_argument( - "--output_dir", - type=str, - default=None, - help="The output directory where the model predictions and checkpoints will be written.", - ) - parser.add_argument( - "--checkpoints_total_limit", - type=int, - default=None, - help=("Max number of checkpoints to store."), - ) - parser.add_argument( - "--checkpointing_steps", - type=int, - default=500, - help=( - "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" - " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" - " training using `--resume_from_checkpoint`." - ), - ) - parser.add_argument( - "--resume_from_checkpoint", - type=str, - default=None, - help=( - "Whether training should be resumed from a previous checkpoint. Use a path saved by" - ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' - ), - ) - parser.add_argument( - "--resume_from_lora_checkpoint", - type=str, - default=None, - help=( - "Whether training should be resumed from a previous lora checkpoint. Use a path saved by" - ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' - ), - ) - parser.add_argument( - "--logging_dir", - type=str, - default="logs", - help=( - "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" - " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." - ), - ) - - # optimizer & scheduler & Training - parser.add_argument("--num_train_epochs", type=int, default=100) - parser.add_argument( - "--max_train_steps", - type=int, - default=None, - help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", - ) - parser.add_argument( - "--gradient_accumulation_steps", - type=int, - default=1, - help="Number of updates steps to accumulate before performing a backward/update pass.", - ) - parser.add_argument( - "--learning_rate", - type=float, - default=1e-4, - help="Initial learning rate (after the potential warmup period) to use.", - ) - parser.add_argument( - "--scale_lr", - action="store_true", - default=False, - help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", - ) - parser.add_argument( - "--lr_warmup_steps", - type=int, - default=10, - help="Number of steps for the warmup in the lr scheduler.", - ) - parser.add_argument( - "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." - ) - parser.add_argument( - "--gradient_checkpointing", - action="store_true", - help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", - ) - parser.add_argument("--selective_checkpointing", type=float, default=1.0) - parser.add_argument( - "--allow_tf32", - action="store_true", - help=( - "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" - " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" - ), - ) - parser.add_argument( - "--mixed_precision", - type=str, - default=None, - choices=["no", "fp16", "bf16"], - help=( - "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" - " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" - " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." - ), - ) - parser.add_argument( - "--use_cpu_offload", - action="store_true", - help="Whether to use CPU offload for param & gradient & optimizer states.", - ) - - parser.add_argument("--sp_size", type=int, default=1, help="For sequence parallel") - parser.add_argument( - "--train_sp_batch_size", - type=int, - default=1, - help="Batch size for sequence parallel training", - ) - - parser.add_argument( - "--use_lora", - action="store_true", - default=False, - help="Whether to use LoRA for finetuning.", - ) - parser.add_argument( - "--lora_alpha", type=int, default=256, help="Alpha parameter for LoRA." - ) - parser.add_argument( - "--lora_rank", type=int, default=128, help="LoRA rank parameter. " - ) - parser.add_argument("--fsdp_sharding_startegy", default="full") - - parser.add_argument( - "--weighting_scheme", - type=str, - default="uniform", - choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "uniform"], - ) - parser.add_argument( - "--logit_mean", - type=float, - default=0.0, - help="mean to use when using the `'logit_normal'` weighting scheme.", - ) - parser.add_argument( - "--logit_std", - type=float, - default=1.0, - help="std to use when using the `'logit_normal'` weighting scheme.", - ) - parser.add_argument( - "--mode_scale", - type=float, - default=1.29, - help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", - ) - # lr_scheduler - parser.add_argument( - "--lr_scheduler", - type=str, - default="constant", - help=( - 'The scheduler type to use. 
Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' - ' "constant", "constant_with_warmup"]' - ), - ) - parser.add_argument( - "--lr_num_cycles", - type=int, - default=1, - help="Number of cycles in the learning rate scheduler.", - ) - parser.add_argument( - "--lr_power", - type=float, - default=1.0, - help="Power factor of the polynomial scheduler.", - ) - parser.add_argument( - "--weight_decay", type=float, default=0.01, help="Weight decay to apply." - ) - parser.add_argument( - "--master_weight_type", - type=str, - default="fp32", - help="Weight type to use - fp32 or bf16.", - ) - - args = parser.parse_args() - main(args) diff --git a/scripts/finetune/finetune_hunyuan.sh b/scripts/finetune/finetune_hunyuan.sh index 5509dfa7..621023ca 100644 --- a/scripts/finetune/finetune_hunyuan.sh +++ b/scripts/finetune/finetune_hunyuan.sh @@ -8,19 +8,19 @@ torchrun --nnodes 1 --nproc_per_node 8 \ --dit_model_name_or_path data/hunyuan/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt\ --model_type "hunyuan" \ --cache_dir data/.cache \ - --data_json_path data/Image-Vid-Finetune-HunYuan/videos2caption.json \ - --validation_prompt_dir data/Image-Vid-Finetune-HunYuan/validation \ + --data_json_path data/Image-Vid-Finetune-Src/videos2caption_overfit.json \ + --validation_prompt_dir data/Image-Vid-Finetune-Src/validation \ --gradient_checkpointing \ - --train_batch_size=1 \ + --train_batch_size 1 \ --num_latent_t 24 \ --sp_size 4 \ --train_sp_batch_size 1 \ --dataloader_num_workers 4 \ - --gradient_accumulation_steps=1 \ - --max_train_steps=2000 \ - --learning_rate=5e-6 \ - --mixed_precision=bf16 \ - --checkpointing_steps=200 \ + --gradient_accumulation_steps 1 \ + --max_train_steps 2000 \ + --learning_rate 5e-6 \ + --mixed_precision bf16 \ + --checkpointing_steps 200 \ --validation_steps 100 \ --validation_sampling_steps 64 \ --checkpoints_total_limit 3 \ @@ -29,8 +29,9 @@ torchrun --nnodes 1 --nproc_per_node 8 \ --cfg 0.0 \ --ema_decay 0.999 \ --log_validation \ - --output_dir=data/outputs/HSH-Taylor-Finetune-Hunyuan \ + --output_dir ~/data/outputs/test/ \ --tracker_project_name HSH-Taylor-Finetune-Hunyuan \ --num_frames 93 \ --validation_guidance_scale "1.0" \ + --master_weight_type fp32 \ --group_frame \ No newline at end of file diff --git a/scripts/finetune/finetune_mochi_lora.sh b/scripts/finetune/finetune_mochi_lora.sh index 2f5a39bd..4015a61d 100644 --- a/scripts/finetune/finetune_mochi_lora.sh +++ b/scripts/finetune/finetune_mochi_lora.sh @@ -1,17 +1,18 @@ export WANDB_BASE_URL="https://api.wandb.ai" export WANDB_MODE=online -torchrun --nnodes 1 --nproc_per_node 2 \ - fastvideo/train.py \ +CUDA_VISIBLE_DEVICES=6 torchrun --nnodes 1 --nproc_per_node 1 --master_port 29403 \ + fastvideo/train_new.py \ --seed 42 \ - --pretrained_model_name_or_path data/mochi \ + --model_type mochi \ + --pretrained_model_name_or_path ~/data/mochi_diffusers \ --cache_dir data/.cache \ - --data_json_path data/Mochi-Black-Myth/videos2caption.json \ - --validation_prompt_dir data/Mochi-Black-Myth/validation \ + --data_json_path data/Encoder_Overfit_Data/videos2caption.json \ + --validation_prompt_dir data/validation_prompt_embed_mask \ --gradient_checkpointing \ --train_batch_size 1 \ - --num_latent_t 14 \ - --sp_size 2 \ + --num_latent_t 2 \ + --sp_size 1 \ --train_sp_batch_size 1 \ --dataloader_num_workers 1 \ --gradient_accumulation_steps 2 \ @@ -27,11 +28,11 @@ torchrun --nnodes 1 --nproc_per_node 2 \ --cfg 0.0 \ --ema_decay 0.999 \ --log_validation \ - 
--output_dir=data/outputs/Black-Myth-Lora-FT \ + --output_dir data/outputs/Black-Myth-Lora-FT \ --tracker_project_name Black-Myth-Lora-Finetune \ --num_frames 91 \ --lora_rank 128 \ --lora_alpha 256 \ - --master_weight_type "bf16" \ + --master_weight_type fp32 \ --use_lora \ --use_cpu_offload diff --git a/scripts/inference/inference_hunyuan.sh b/scripts/inference/inference_hunyuan.sh index 8ff2a1da..fb419b73 100644 --- a/scripts/inference/inference_hunyuan.sh +++ b/scripts/inference/inference_hunyuan.sh @@ -1,8 +1,8 @@ #!/bin/bash -num_gpus=4 -export MODEL_BASE=/root/hunyuan -torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29503 \ +num_gpus=1 +export MODEL_BASE=~/data/hunyuan/ +torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29703 \ fastvideo/sample/sample_t2v_hunyuan.py \ --height 480 \ --width 848 \ @@ -10,10 +10,10 @@ torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29503 \ --num_inference_steps 50 \ --guidance_scale 1 \ --embedded_cfg_scale 6 \ - --flow_shift 17 \ + --flow_shift 7 \ --flow-reverse \ --prompt ./assets/prompt.txt \ --seed 1024 \ - --output_path outputs_video/hunyuan/cfg6/ \ + --output_path outputs_video/hunyuan/ \ --model_path $MODEL_BASE \ --dit-weight ${MODEL_BASE}/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt diff --git a/scripts/inference/inference_hunyuan_hf.sh b/scripts/inference/inference_hunyuan_hf.sh index abd36a4e..d957acea 100644 --- a/scripts/inference/inference_hunyuan_hf.sh +++ b/scripts/inference/inference_hunyuan_hf.sh @@ -3,19 +3,16 @@ num_gpus=4 torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29503 \ fastvideo/sample/sample_t2v_hunyuan_hf.py \ - --model_path ~/hunyuan_hf/ \ - --prompt_path "assets/prompt_test.txt" \ + --model_path ~/data/hunyuan_diffusers/ \ + --prompt_path "assets/prompt.txt" \ --num_frames 93 \ --height 480 \ --width 848 \ --num_inference_steps 50 \ - --guidance_scale 1.5 \ - --output_path outputs_video/hunyuan_hf/ \ + --guidance_scale 6 \ + --output_path outputs_video/hunyuan_hf_test/ \ --seed 1024 \ - --linear_threshold 0.1 \ - --flow_shift 17 \ - --flow-reverse \ - --linear_range 0.75 \ + diff --git a/scripts/inference/inference_mochi_sp.sh b/scripts/inference/inference_mochi_sp.sh index 4ee373f9..48a4c527 100644 --- a/scripts/inference/inference_mochi_sp.sh +++ b/scripts/inference/inference_mochi_sp.sh @@ -1,9 +1,9 @@ #!/bin/bash num_gpus=4 -torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29503 \ +CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29603 \ fastvideo/sample/sample_t2v_mochi.py \ - --model_path ~/mochi \ + --model_path ~/data/mochi_diffusers/ \ --prompt_path "assets/prompt.txt" \ --num_frames 91 \ --height 480 \ @@ -16,3 +16,17 @@ torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29503 \ --linear_threshold 0.1 \ --linear_range 0.75 +num_gpus=4 +torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29503 \ + fastvideo/sample/sample_t2v_mochi.py \ + --model_path ~/data/mochi_diffusers/ \ + --prompt_path "assets/prompt_test.txt" \ + --num_frames 163 \ + --height 480 \ + --width 848 \ + --num_inference_steps 64 \ + --guidance_scale 6 \ + --output_path outputs_video/debug \ + --shift 8 \ + --seed 12345 \ + --scheduler_type "pcm_linear_quadratic" diff --git a/scripts/preprocess/preprocess_hunyuan_data.sh b/scripts/preprocess/preprocess_hunyuan_data.sh index fc452337..f4e64ce7 100644 --- a/scripts/preprocess/preprocess_hunyuan_data.sh +++ b/scripts/preprocess/preprocess_hunyuan_data.sh @@ 
-1,29 +1,29 @@ # export WANDB_MODE="offline" -GPU_NUM=1 # 2,4,8 -MODEL_PATH="data/hunyuan" +GPU_NUM=1 +MODEL_PATH="/mbz/users/hao.zhang/data/hunyuan" MODEL_TYPE="hunyuan" DATA_MERGE_PATH="data/Image-Vid-Finetune-Src/merge.txt" -OUTPUT_DIR="data/Image-Vid-Finetune-HunYuan" -VALIDATION_PATH="assets/prompt.txt" +OUTPUT_DIR="data/Image-Vid-Finetune-Src" +VALIDATION_PATH="assets/prompt_HSH.txt" -torchrun --nproc_per_node=$GPU_NUM \ - fastvideo/data_preprocess/preprocess_vae_latents.py \ - --model_path $MODEL_PATH \ - --data_merge_path $DATA_MERGE_PATH \ - --train_batch_size=1 \ - --max_height=480 \ - --max_width=848 \ - --num_frames=93 \ - --dataloader_num_workers 1 \ - --output_dir=$OUTPUT_DIR \ - --model_type $MODEL_TYPE \ - --train_fps 24 +# torchrun --nproc_per_node=$GPU_NUM \ +# fastvideo/data_preprocess/preprocess_vae_latents.py \ +# --model_path $MODEL_PATH \ +# --data_merge_path $DATA_MERGE_PATH \ +# --train_batch_size=1 \ +# --max_height=480 \ +# --max_width=848 \ +# --num_frames=93 \ +# --dataloader_num_workers 1 \ +# --output_dir=$OUTPUT_DIR \ +# --model_type $MODEL_TYPE \ +# --train_fps 24 -torchrun --nproc_per_node=$GPU_NUM \ - fastvideo/data_preprocess/preprocess_text_embeddings.py \ - --model_type $MODEL_TYPE \ - --model_path $MODEL_PATH \ - --output_dir=$OUTPUT_DIR +# torchrun --nproc_per_node=$GPU_NUM \ +# fastvideo/data_preprocess/preprocess_text_embeddings.py \ +# --model_type $MODEL_TYPE \ +# --model_path $MODEL_PATH \ +# --output_dir=$OUTPUT_DIR torchrun --nproc_per_node=1 \ fastvideo/data_preprocess/preprocess_validation_text_embeddings.py \ From 36560ebed3f5a1bca86363026c644c13d55b7752 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CBrianChen1129=E2=80=9D?= Date: Sun, 29 Dec 2024 21:50:17 +0000 Subject: [PATCH 11/42] unified hunyuan hf --- fastvideo/sample/sample_t2v_hunyuan_hf.py | 265 +++++++++++++++--- fastvideo/test/test_hunyuan_hf.py | 82 ------ fastvideo/utils/checkpoint.py | 10 +- .../inference_hunyuan_hf_quantization.sh | 19 ++ 4 files changed, 256 insertions(+), 120 deletions(-) delete mode 100644 fastvideo/test/test_hunyuan_hf.py create mode 100644 scripts/inference/inference_hunyuan_hf_quantization.sh diff --git a/fastvideo/sample/sample_t2v_hunyuan_hf.py b/fastvideo/sample/sample_t2v_hunyuan_hf.py index 620add3b..a7369826 100644 --- a/fastvideo/sample/sample_t2v_hunyuan_hf.py +++ b/fastvideo/sample/sample_t2v_hunyuan_hf.py @@ -1,28 +1,22 @@ import torch import torch.distributed as dist - +from diffusers import BitsAndBytesConfig from diffusers.utils import export_to_video +import imageio as iio +import math +import numpy as np +import io +import time +import argparse +import os +import json from fastvideo.utils.parallel_states import ( initialize_sequence_parallel_state, nccl_info, ) -import argparse -import os -import json -from typing import Optional -from safetensors.torch import save_file, load_file -from peft import set_peft_model_state_dict, inject_adapter_in_model, load_peft_weights -from peft import LoraConfig -import sys -import pdb -import copy -from typing import Dict -from diffusers import FlowMatchEulerDiscreteScheduler -from diffusers.utils import convert_unet_state_dict_to_peft -from fastvideo.distill.solver import PCMFMScheduler -from fastvideo.models.hunyuan.diffusion.schedulers import FlowMatchDiscreteScheduler from fastvideo.models.hunyuan_hf.pipeline_hunyuan import HunyuanVideoPipeline from fastvideo.models.hunyuan_hf.modeling_hunyuan import HunyuanVideoTransformer3DModel + def initialize_distributed(): 
os.environ["TOKENIZERS_PARALLELISM"] = "false" local_rank = int(os.getenv("RANK", 0)) @@ -33,9 +27,8 @@ def initialize_distributed(): backend="nccl", init_method="env://", world_size=world_size, rank=local_rank ) initialize_sequence_parallel_state(world_size) - - -def main(args): + +def inference(args): initialize_distributed() print(nccl_info.sp_size) device = torch.cuda.current_device() @@ -50,7 +43,7 @@ def main(args): ) pipe = HunyuanVideoPipeline.from_pretrained( - args.model_path, transformer=transformer, torch_dtype=torch.float16 + args.model_path, transformer=transformer, torch_dtype=weight_dtype ) pipe.enable_vae_tiling() @@ -113,7 +106,7 @@ def main(args): export_to_video( video[0], os.path.join(args.output_path, f"{suffix}.mp4"), - fps=30, + fps=24, ) else: with torch.autocast("cuda", dtype=torch.bfloat16): @@ -129,32 +122,230 @@ def main(args): ).frames if nccl_info.global_rank <= 0: - export_to_video(videos[0], args.output_path + ".mp4", fps=30) + export_to_video(videos[0], args.output_path + ".mp4", fps=24) +def inference_quantization(args): + torch.manual_seed(args.seed) + device = "cuda" if torch.cuda.is_available() else "cpu" + prompt_template = { + "template": ( + "<|start_header_cid|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the contents, including objects, people, and anything else." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the contents." + "4. Background environment, light, style, atmosphere, and qualities." + "5. Camera angles, movements, and transitions used in the video." + "6. Thematic and aesthetic concepts associated with the scene, i.e. 
realistic, futuristic, fairy tale, etc<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + ), + "crop_start": 95, + } + + + model_id = args.model_path + + if args.quantization == "nf4": + quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_quant_type="nf4", llm_int8_skip_modules=["proj_out", "norm_out"]) + transformer = HunyuanVideoTransformer3DModel.from_pretrained( + model_id, subfolder="transformer/" ,torch_dtype=torch.bfloat16, quantization_config=quantization_config + ) + if args.quantization == "int8": + quantization_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["proj_out", "norm_out"]) + transformer = HunyuanVideoTransformer3DModel.from_pretrained( + model_id, subfolder="transformer/" ,torch_dtype=torch.bfloat16, quantization_config=quantization_config + ) + elif not args.quantization: + transformer = HunyuanVideoTransformer3DModel.from_pretrained( + model_id, subfolder="transformer/" ,torch_dtype=torch.bfloat16 + ).to(device) + + print("Max vram for read transofrmer:", round(torch.cuda.max_memory_allocated(device="cuda") / 1024 ** 3, 3), "GiB") + torch.cuda.reset_max_memory_allocated(device) + + if not args.cpu_offload: + pipe = HunyuanVideoPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device) + pipe.transformer = transformer + else: + pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16) + torch.cuda.reset_max_memory_allocated(device) + pipe.scheduler._shift = args.flow_shift + pipe.vae.enable_tiling() + if args.cpu_offload: + pipe.enable_model_cpu_offload() + print("Max vram for init pipeline:", round(torch.cuda.max_memory_allocated(device="cuda") / 1024 ** 3, 3), "GiB") + with open(args.prompt) as f: + prompts = f.readlines() + + generator = torch.Generator("cpu").manual_seed(args.seed) + os.makedirs(os.path.dirname(args.output_path), exist_ok=True) + torch.cuda.reset_max_memory_allocated(device) + for prompt in prompts: + start_time = time.perf_counter() + output = pipe( + prompt=prompt, + height = args.height, + width = args.width, + num_frames = args.num_frames, + prompt_template=prompt_template, + num_inference_steps = args.num_inference_steps, + generator=generator, + ).frames[0] + export_to_video(output, os.path.join(args.output_path, f"{prompt[:100]}.mp4"), fps=args.fps) + print("Time:", round(time.perf_counter() - start_time, 2), "seconds") + print("Max vram for denoise:", round(torch.cuda.max_memory_allocated(device="cuda") / 1024 ** 3, 3), "GiB") if __name__ == "__main__": - # arg parse parser = argparse.ArgumentParser() - parser.add_argument("--prompts", nargs="+", default=[]) - parser.add_argument("--num_frames", type=int, default=163) - parser.add_argument("--height", type=int, default=480) - parser.add_argument("--width", type=int, default=848) - parser.add_argument("--num_inference_steps", type=int, default=64) - parser.add_argument("--guidance_scale", type=float, default=4.5) - parser.add_argument("--model_name", type=str, default="hunyuan") - parser.add_argument("--model_path", type=str, default="data/hunyuan") - parser.add_argument("--seed", type=int, default=42) - parser.add_argument("--output_path", type=str, default="./outputs.mp4") - parser.add_argument("--transformer_path", type=str, default=None) + + # Basic parameters + parser.add_argument("--prompt", type=str, help="prompt file for inference") parser.add_argument("--prompt_embed_path", type=str, default=None) 
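For reference, the NF4 branch of inference_quantization above reduces to passing a BitsAndBytesConfig when loading the transformer and optionally enabling CPU offload. A condensed sketch, with the model path and skipped modules taken from the script and everything else illustrative:

    import torch
    from diffusers import BitsAndBytesConfig
    from fastvideo.models.hunyuan_hf.modeling_hunyuan import HunyuanVideoTransformer3DModel
    from fastvideo.models.hunyuan_hf.pipeline_hunyuan import HunyuanVideoPipeline

    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        llm_int8_skip_modules=["proj_out", "norm_out"],   # keep the output layers unquantized
    )
    transformer = HunyuanVideoTransformer3DModel.from_pretrained(
        "data/hunyuan_diffusers", subfolder="transformer",
        torch_dtype=torch.bfloat16, quantization_config=quant_config,
    )
    pipe = HunyuanVideoPipeline.from_pretrained(
        "data/hunyuan_diffusers", transformer=transformer, torch_dtype=torch.bfloat16,
    )
    pipe.vae.enable_tiling()
    pipe.enable_model_cpu_offload()   # pairs with NF4 quantization to fit inference on one GPU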
parser.add_argument("--prompt_path", type=str, default=None) - parser.add_argument("--scheduler_type", type=str, default="euler") - parser.add_argument("--encoder_attention_mask_path", type=str, default=None) + parser.add_argument("--num_frames", type=int, default=16) + parser.add_argument("--height", type=int, default=256) + parser.add_argument("--width", type=int, default=256) + parser.add_argument("--num_inference_steps", type=int, default=50) + parser.add_argument("--model_path", type=str, default="data/hunyuan") + parser.add_argument("--transformer_path", type=str, default=None) + parser.add_argument("--output_path", type=str, default="./outputs/video") + parser.add_argument("--fps", type=int, default=24) + parser.add_argument("--quantization", type=str, default=None) + parser.add_argument("--cpu_offload", action="store_true") parser.add_argument( "--lora_checkpoint_dir", type=str, default=None, help="Path to the directory containing LoRA checkpoints", ) + # Additional parameters + parser.add_argument( + "--denoise-type", + type=str, + default="flow", + help="Denoise type for noised inputs.", + ) + parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.") + parser.add_argument( + "--neg_prompt", type=str, default=None, help="Negative prompt for sampling." + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=1.0, + help="Classifier free guidance scale.", + ) + parser.add_argument( + "--embedded_cfg_scale", + type=float, + default=6.0, + help="Embedded classifier free guidance scale.", + ) + parser.add_argument( + "--flow_shift", type=int, default=7, help="Flow shift parameter." + ) + parser.add_argument( + "--batch_size", type=int, default=1, help="Batch size for inference." + ) + parser.add_argument( + "--num_videos", + type=int, + default=1, + help="Number of videos to generate per prompt.", + ) + parser.add_argument( + "--load-key", + type=str, + default="module", + help="Key to load the model states. 'module' for the main model, 'ema' for the EMA model.", + ) + parser.add_argument( + "--use-cpu-offload", + action="store_true", + help="Use CPU offload for the model load.", + ) + parser.add_argument( + "--dit-weight", + type=str, + default="data/hunyuan/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", + ) + parser.add_argument( + "--reproduce", + action="store_true", + help="Enable reproducibility by setting random seeds and deterministic algorithms.", + ) + parser.add_argument( + "--disable-autocast", + action="store_true", + help="Disable autocast for denoising loop and vae decoding in pipeline sampling.", + ) + + # Flow Matching + parser.add_argument( + "--flow-reverse", + action="store_true", + help="If reverse, learning/sampling from t=1 -> t=0.", + ) + parser.add_argument( + "--flow-solver", type=str, default="euler", help="Solver for flow matching." + ) + parser.add_argument( + "--use-linear-quadratic-schedule", + action="store_true", + help="Use linear quadratic schedule for flow matching. 
Following MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)", + ) + parser.add_argument( + "--linear-schedule-end", + type=int, + default=25, + help="End step for linear quadratic schedule for flow matching.", + ) + + # Model parameters + parser.add_argument("--model", type=str, default="HYVideo-T/2-cfgdistill") + parser.add_argument("--latent-channels", type=int, default=16) + parser.add_argument( + "--precision", type=str, default="bf16", choices=["fp32", "fp16", "bf16", "fp8"] + ) + parser.add_argument( + "--rope-theta", type=int, default=256, help="Theta used in RoPE." + ) + + parser.add_argument("--vae", type=str, default="884-16c-hy") + parser.add_argument( + "--vae-precision", type=str, default="fp16", choices=["fp32", "fp16", "bf16"] + ) + parser.add_argument("--vae-tiling", action="store_true", default=True) + + parser.add_argument("--text-encoder", type=str, default="llm") + parser.add_argument( + "--text-encoder-precision", + type=str, + default="fp16", + choices=["fp32", "fp16", "bf16"], + ) + parser.add_argument("--text-states-dim", type=int, default=4096) + parser.add_argument("--text-len", type=int, default=256) + parser.add_argument("--tokenizer", type=str, default="llm") + parser.add_argument("--prompt-template", type=str, default="dit-llm-encode") + parser.add_argument( + "--prompt-template-video", type=str, default="dit-llm-encode-video" + ) + parser.add_argument("--hidden-state-skip-layer", type=int, default=2) + parser.add_argument("--apply-final-norm", action="store_true") + + parser.add_argument("--text-encoder-2", type=str, default="clipL") + parser.add_argument( + "--text-encoder-precision-2", + type=str, + default="fp16", + choices=["fp32", "fp16", "bf16"], + ) + parser.add_argument("--text-states-dim-2", type=int, default=768) + parser.add_argument("--tokenizer-2", type=str, default="clipL") + parser.add_argument("--text-len-2", type=int, default=77) + args = parser.parse_args() - main(args) + if args.quantization: + inference_quantization(args) + else: + inference(args) \ No newline at end of file diff --git a/fastvideo/test/test_hunyuan_hf.py b/fastvideo/test/test_hunyuan_hf.py deleted file mode 100644 index 9f46cb57..00000000 --- a/fastvideo/test/test_hunyuan_hf.py +++ /dev/null @@ -1,82 +0,0 @@ -import torch -from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel -from diffusers.utils import export_to_video -import random -import numpy as np -import argparse -import os -from fastvideo.models.hunyuan.diffusion.schedulers import FlowMatchDiscreteScheduler - -def parse_args(): - parser = argparse.ArgumentParser(description='Generate video using Hunyuan model') - - parser.add_argument('--prompt', type=str, default="", help='Text prompt for video generation') - parser.add_argument('--model_path', type=str, default="/mbz/users/hao.zhang/data/hunyuan_diffusers", help='Path to the Hunyuan model directory') - parser.add_argument('--output_dir', type=str, default='outputs_video/hunyuan_hf', help='Directory to save the output video') - parser.add_argument('--height', type=int, default=480, help='Height of the output video') - parser.add_argument('--width', type=int, default=848, help='Width of the output video') - parser.add_argument('--num_frames', type=int, default=93, help='Number of frames to generate') - parser.add_argument('--num_inference_steps', type=int, default=50, help='Number of inference steps') - parser.add_argument('--seed', type=int, default=1024, help='Random seed for generation') - parser.add_argument('--fps', 
type=int, default=24, help='Frames per second for the output video') - - - return parser.parse_args() - -def main(): - args = parse_args() - prompt_candidates = ["Will Smith casually eats noodles, his relaxed demeanor contrasting with the energetic background of a bustling street food market. The scene captures a mix of humor and authenticity. Mid-shot framing, vibrant lighting.", - "A lone hiker stands atop a towering cliff, silhouetted against the vast horizon. The rugged landscape stretches endlessly beneath, its earthy tones blending into the soft blues of the sky. The scene captures the spirit of exploration and human resilience. High angle, dynamic framing, with soft natural lighting emphasizing the grandeur of nature.", - "A hand with delicate fingers picks up a bright yellow lemon from a wooden bowl filled with lemons and sprigs of mint against a peach-colored background. The hand gently tosses the lemon up and catches it, showcasing its smooth texture. A beige string bag sits beside the bowl, adding a rustic touch to the scene. Additional lemons, one halved, are scattered around the base of the bowl. The even lighting enhances the vibrant colors and creates a fresh, inviting atmosphere.", - "A curious raccoon peers through a vibrant field of yellow sunflowers, its eyes wide with interest. The playful yet serene atmosphere is complemented by soft natural light filtering through the petals. Mid-shot, warm and cheerful tones.", - "A superintelligent humanoid robot waking up. The robot has a sleek metallic body with futuristic design features. Its glowing red eyes are the focal point, emanating a sharp, intense light as it powers on. The scene is set in a dimly lit, high-tech laboratory filled with glowing control panels, robotic arms, and holographic screens. The setting emphasizes advanced technology and an atmosphere of mystery. The ambiance is eerie and dramatic, highlighting the moment of awakening and the robots immense intelligence. Photorealistic style with a cinematic, dark sci-fi aesthetic. Aspect ratio: 16:9 --v 6.1", - "fox in the forest close-up quickly turned its head to the left.", - "Man walking his dog in the woods on a hot sunny day", - "A majestic lion strides across the golden savanna, its powerful frame glistening under the warm afternoon sun. The tall grass ripples gently in the breeze, enhancing the lion's commanding presence. The tone is vibrant, embodying the raw energy of the wild. Low angle, steady tracking shot, cinematic."] - # Set random seed - #args.prompt = "Will Smith casually eats noodles, his relaxed demeanor contrasting with the energetic background of a bustling street food market. The scene captures a mix of humor and authenticity. Mid-shot framing, vibrant lighting." 
- generator = torch.Generator("cpu").manual_seed(args.seed) - # Load transformer model - transformer = HunyuanVideoTransformer3DModel.from_pretrained( - pretrained_model_name_or_path=args.model_path, - subfolder="transformer", - torch_dtype=torch.bfloat16, - local_files_only=True - ) - - # Initialize pipeline - pipe = HunyuanVideoPipeline.from_pretrained( - pretrained_model_name_or_path=args.model_path, - transformer=transformer, - torch_dtype=torch.float16, - local_files_only=True - ) - #pipe.vae = pipe.vae.to(torch.bfloat16) - pipe.vae.enable_tiling() - - # Move to GPU - device = torch.cuda.current_device() - pipe.to(device) - #pipe.enable_model_cpu_offload(device) - - # Create output directory if it doesn't exist - os.makedirs(args.output_dir, exist_ok=True) - file_name = args.prompt[:20] - output_path = os.path.join(args.output_dir, file_name + 'output.mp4') - - # Generate video - output = pipe( - prompt=args.prompt, - height=args.height, - width=args.width, - num_frames=args.num_frames, - num_inference_steps=args.num_inference_steps, - generator=generator - ).frames[0] - - # Save video - export_to_video(output, output_path, fps=args.fps) - print(f"Video saved to: {output_path}") - -if __name__ == "__main__": - main() diff --git a/fastvideo/utils/checkpoint.py b/fastvideo/utils/checkpoint.py index 8b6b6f66..7aeb10c3 100644 --- a/fastvideo/utils/checkpoint.py +++ b/fastvideo/utils/checkpoint.py @@ -38,6 +38,14 @@ def save_checkpoint(model, optimizer, rank, output_dir, step, discriminator=Fals weight_path = os.path.join(save_dir, "diffusion_pytorch_model.safetensors") save_file(cpu_state, weight_path) config_dict = dict(model.config) + config_dict.pop('dtype') + # dtype = config_dict['dtype'] + # if dtype == torch.float32: + # config_dict['dtype'] = 'fp32' + # elif dtype == torch.float16: + # config_dict['dtype'] = 'fp16' + # elif dtype == torch.bfloat16: + # config_dict['dtype'] = 'bf16' config_path = os.path.join(save_dir, "config.json") # save dict as json with open(config_path, "w") as f: @@ -49,7 +57,7 @@ def save_checkpoint(model, optimizer, rank, output_dir, step, discriminator=Fals save_file(cpu_state, weight_path) optimizer_path = os.path.join(save_dir, "discriminator_optimizer.pt") torch.save(optim_state, optimizer_path) - + main_print(f"--> checkpoint saved at step {step}") def save_checkpoint_generator_discriminator( model, optimizer, discriminator, discriminator_optimizer, rank, output_dir, step, diff --git a/scripts/inference/inference_hunyuan_hf_quantization.sh b/scripts/inference/inference_hunyuan_hf_quantization.sh new file mode 100644 index 00000000..92cbefb7 --- /dev/null +++ b/scripts/inference/inference_hunyuan_hf_quantization.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +num_gpus=1 +export MODEL_BASE="data/FastHunyuan-git ad" +torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 12345 \ + fastvideo/sample/sample_t2v_diffusers_hunyuan.py \ + --height 720 \ + --width 1280 \ + --num_frames 45 \ + --num_inference_steps 50 \ + --guidance_scale 1 \ + --embedded_cfg_scale 6 \ + --flow_shift 17 \ + --prompt ./assets/prompt_HSH.txt \ + --seed 1024 \ + --output_path outputs_video/hunyuan_quant/nf4/ \ + --model_path $MODEL_BASE \ + --quantization "nf4" \ + --cpu_offload From d25c235558702a3d0b36dc93ecd442ea53e8e458 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CBrianChen1129=E2=80=9D?= Date: Sun, 29 Dec 2024 21:52:29 +0000 Subject: [PATCH 12/42] unified hunyuan hf --- scripts/inference/inference_hunyuan_hf.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff 
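The checkpoint.py change above pops the dtype entry before config.json is written because torch.dtype objects are not JSON-serializable; the commented-out block sketches the alternative of storing a string tag instead. A minimal illustration (the config keys here are made up):

    import json
    import torch

    config_dict = {"in_channels": 16, "num_layers": 20, "dtype": torch.bfloat16}  # illustrative

    dtype = config_dict.pop("dtype")      # what the patch does before dumping
    # Alternative kept commented out in the patch: map the dtype to a string.
    # config_dict["dtype"] = {torch.float32: "fp32",
    #                         torch.float16: "fp16",
    #                         torch.bfloat16: "bf16"}[dtype]

    with open("config.json", "w") as f:
        json.dump(config_dict, f)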
--git a/scripts/inference/inference_hunyuan_hf.sh b/scripts/inference/inference_hunyuan_hf.sh index d957acea..545b1562 100644 --- a/scripts/inference/inference_hunyuan_hf.sh +++ b/scripts/inference/inference_hunyuan_hf.sh @@ -3,16 +3,16 @@ num_gpus=4 torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29503 \ fastvideo/sample/sample_t2v_hunyuan_hf.py \ - --model_path ~/data/hunyuan_diffusers/ \ + --model_path data/hunyuan_diffusers/ \ --prompt_path "assets/prompt.txt" \ --num_frames 93 \ --height 480 \ --width 848 \ --num_inference_steps 50 \ - --guidance_scale 6 \ - --output_path outputs_video/hunyuan_hf_test/ \ + --output_path outputs_video/hunyuan_hf/ \ --seed 1024 \ + From 557dbca227b03bae4eadd0bafc084dae37032d84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CBrianChen1129=E2=80=9D?= Date: Sun, 29 Dec 2024 21:57:12 +0000 Subject: [PATCH 13/42] unified hunyuan hf --- scripts/finetune/finetune_hunyuan.sh | 8 ++-- scripts/finetune/finetune_hunyuan_hf_lora.sh | 15 ++++---- scripts/preprocess/preprocess_hunyuan_data.sh | 38 +++++++++---------- 3 files changed, 30 insertions(+), 31 deletions(-) diff --git a/scripts/finetune/finetune_hunyuan.sh b/scripts/finetune/finetune_hunyuan.sh index 621023ca..d6078741 100644 --- a/scripts/finetune/finetune_hunyuan.sh +++ b/scripts/finetune/finetune_hunyuan.sh @@ -1,14 +1,14 @@ export WANDB_BASE_URL="https://api.wandb.ai" export WANDB_MODE=online -torchrun --nnodes 1 --nproc_per_node 8 \ +torchrun --nnodes 1 --nproc_per_node 4 \ fastvideo/train.py \ --seed 42 \ --pretrained_model_name_or_path data/hunyuan \ --dit_model_name_or_path data/hunyuan/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt\ --model_type "hunyuan" \ --cache_dir data/.cache \ - --data_json_path data/Image-Vid-Finetune-Src/videos2caption_overfit.json \ + --data_json_path data/Image-Vid-Finetune-Src/videos2caption.json \ --validation_prompt_dir data/Image-Vid-Finetune-Src/validation \ --gradient_checkpointing \ --train_batch_size 1 \ @@ -22,14 +22,14 @@ torchrun --nnodes 1 --nproc_per_node 8 \ --mixed_precision bf16 \ --checkpointing_steps 200 \ --validation_steps 100 \ - --validation_sampling_steps 64 \ + --validation_sampling_steps 50 \ --checkpoints_total_limit 3 \ --allow_tf32 \ --ema_start_step 0 \ --cfg 0.0 \ --ema_decay 0.999 \ --log_validation \ - --output_dir ~/data/outputs/test/ \ + --output_dir data/outputs/HSH-Taylor-Finetune-Hunyuan/ \ --tracker_project_name HSH-Taylor-Finetune-Hunyuan \ --num_frames 93 \ --validation_guidance_scale "1.0" \ diff --git a/scripts/finetune/finetune_hunyuan_hf_lora.sh b/scripts/finetune/finetune_hunyuan_hf_lora.sh index 0a062166..3a2148fd 100644 --- a/scripts/finetune/finetune_hunyuan_hf_lora.sh +++ b/scripts/finetune/finetune_hunyuan_hf_lora.sh @@ -1,12 +1,11 @@ export WANDB_BASE_URL="https://api.wandb.ai" export WANDB_MODE=online - -torchrun --nnodes 1 --nproc_per_node 4 \ +torchrun --nnodes 1 --nproc_per_node 4 --master_port 29903 \ fastvideo/train.py \ --seed 1024 \ - --pretrained_model_name_or_path ~/data/hunyuan_diffusers \ + --pretrained_model_name_or_path data/hunyuan_diffusers \ --model_type hunyuan_hf \ - --cache_dir ~/data/.cache \ + --cache_dir data/.cache \ --data_json_path data/Image-Vid-Finetune-Src/videos2caption.json \ --validation_prompt_dir data/Image-Vid-Finetune-Src/validation \ --gradient_checkpointing \ @@ -19,7 +18,7 @@ torchrun --nnodes 1 --nproc_per_node 4 \ --max_train_steps 2000 \ --learning_rate 5e-6 \ --mixed_precision bf16 \ - --checkpointing_steps 200 \ + --checkpointing_steps 500 \ 
--validation_steps 100 \ --validation_sampling_steps 50 \ --checkpoints_total_limit 3 \ @@ -28,11 +27,11 @@ torchrun --nnodes 1 --nproc_per_node 4 \ --cfg 0.0 \ --ema_decay 0.999 \ --log_validation \ - --output_dir ~/data/outputs/HSH-Taylor-Finetune-Hunyuan \ + --output_dir data/outputs/HSH-Taylor-Finetune-Hunyuan \ --tracker_project_name HSH-Taylor-Finetune-Hunyuan \ --num_frames 93 \ --validation_guidance_scale "1.0" \ --group_frame \ --use_lora \ - --lora_rank 128 \ - --lora_alpha 256 + --lora_rank 64 \ + --lora_alpha 128 \ diff --git a/scripts/preprocess/preprocess_hunyuan_data.sh b/scripts/preprocess/preprocess_hunyuan_data.sh index f4e64ce7..f001fd0a 100644 --- a/scripts/preprocess/preprocess_hunyuan_data.sh +++ b/scripts/preprocess/preprocess_hunyuan_data.sh @@ -1,29 +1,29 @@ # export WANDB_MODE="offline" GPU_NUM=1 -MODEL_PATH="/mbz/users/hao.zhang/data/hunyuan" +MODEL_PATH="data/hunyuan" MODEL_TYPE="hunyuan" DATA_MERGE_PATH="data/Image-Vid-Finetune-Src/merge.txt" OUTPUT_DIR="data/Image-Vid-Finetune-Src" -VALIDATION_PATH="assets/prompt_HSH.txt" +VALIDATION_PATH="assets/prompt.txt" -# torchrun --nproc_per_node=$GPU_NUM \ -# fastvideo/data_preprocess/preprocess_vae_latents.py \ -# --model_path $MODEL_PATH \ -# --data_merge_path $DATA_MERGE_PATH \ -# --train_batch_size=1 \ -# --max_height=480 \ -# --max_width=848 \ -# --num_frames=93 \ -# --dataloader_num_workers 1 \ -# --output_dir=$OUTPUT_DIR \ -# --model_type $MODEL_TYPE \ -# --train_fps 24 +torchrun --nproc_per_node=$GPU_NUM \ + fastvideo/data_preprocess/preprocess_vae_latents.py \ + --model_path $MODEL_PATH \ + --data_merge_path $DATA_MERGE_PATH \ + --train_batch_size=1 \ + --max_height=480 \ + --max_width=848 \ + --num_frames=93 \ + --dataloader_num_workers 1 \ + --output_dir=$OUTPUT_DIR \ + --model_type $MODEL_TYPE \ + --train_fps 24 -# torchrun --nproc_per_node=$GPU_NUM \ -# fastvideo/data_preprocess/preprocess_text_embeddings.py \ -# --model_type $MODEL_TYPE \ -# --model_path $MODEL_PATH \ -# --output_dir=$OUTPUT_DIR +torchrun --nproc_per_node=$GPU_NUM \ + fastvideo/data_preprocess/preprocess_text_embeddings.py \ + --model_type $MODEL_TYPE \ + --model_path $MODEL_PATH \ + --output_dir=$OUTPUT_DIR torchrun --nproc_per_node=1 \ fastvideo/data_preprocess/preprocess_validation_text_embeddings.py \ From 0c058c300c772f8dcfcba53a7493812b9b16ce51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CBrianChen1129=E2=80=9D?= Date: Tue, 7 Jan 2025 05:25:33 +0000 Subject: [PATCH 14/42] add lora --- fastvideo/models/hunyuan_hf/pipeline_hunyuan.py | 2 +- fastvideo/train.py | 3 ++- fastvideo/utils/validation.py | 8 +++++--- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/fastvideo/models/hunyuan_hf/pipeline_hunyuan.py b/fastvideo/models/hunyuan_hf/pipeline_hunyuan.py index 936d8a98..2a636920 100644 --- a/fastvideo/models/hunyuan_hf/pipeline_hunyuan.py +++ b/fastvideo/models/hunyuan_hf/pipeline_hunyuan.py @@ -347,7 +347,7 @@ def check_inputs( prompt_template=None, ): if height % 16 != 0 or width % 16 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs diff --git a/fastvideo/train.py b/fastvideo/train.py index 095f1bef..ec77fa2f 100644 --- a/fastvideo/train.py +++ b/fastvideo/train.py @@ -465,7 +465,7 @@ def main(args): 
save_checkpoint(transformer, optimizer, rank, args.output_dir, step) dist.barrier() if args.log_validation and step % args.validation_steps == 0: - log_validation(args, transformer, device, torch.bfloat16, step) + log_validation(args, transformer, device, torch.bfloat16, step, shift=args.shift) if args.use_lora: save_lora_checkpoint( @@ -564,6 +564,7 @@ def main(args): " training using `--resume_from_checkpoint`." ), ) + parser.add_argument("--shift", type=float, default=1.0, help=("Set shift to 7 for hunyuan model.")) parser.add_argument( "--resume_from_checkpoint", type=str, diff --git a/fastvideo/utils/validation.py b/fastvideo/utils/validation.py index be8751ea..36d54974 100644 --- a/fastvideo/utils/validation.py +++ b/fastvideo/utils/validation.py @@ -47,6 +47,7 @@ def prepare_latents( def sample_validation_video( + model_type, transformer, vae, scheduler, @@ -105,7 +106,7 @@ def sample_validation_video( threshold_noise = 0.025 sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise) sigmas = np.array(sigmas) - if scheduler_type == "euler": + if scheduler_type == "euler" and model_type == "mochi": #todo timesteps, num_inference_steps = retrieve_timesteps( scheduler, num_inference_steps, device, timesteps, sigmas, ) @@ -233,7 +234,7 @@ def log_validation( ) vae.enable_tiling() if scheduler_type == "euler": - scheduler = FlowMatchEulerDiscreteScheduler() + scheduler = FlowMatchEulerDiscreteScheduler(shift=shift) else: linear_quadraic = True if scheduler_type == "pcm_linear_quadratic" else False scheduler = PCMFMScheduler( @@ -292,8 +293,9 @@ def log_validation( negative_prompt_attention_mask = ( torch.zeros(256).bool().to(device).unsqueeze(0) ) - generator = torch.Generator(device="cuda").manual_seed(12345) + generator = torch.Generator(device="cpu").manual_seed(1024) video = sample_validation_video( + args.model_type, transformer, vae, scheduler, From f35ea706cf9ee089bb9a47a300a95356d0752bbe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CBrianChen1129=E2=80=9D?= Date: Tue, 7 Jan 2025 23:13:20 +0000 Subject: [PATCH 15/42] unify hunyuan hf inference --- .../inference/inference_diffusers_hunyuan.sh | 20 ------------------- 1 file changed, 20 deletions(-) delete mode 100644 scripts/inference/inference_diffusers_hunyuan.sh diff --git a/scripts/inference/inference_diffusers_hunyuan.sh b/scripts/inference/inference_diffusers_hunyuan.sh deleted file mode 100644 index 376d2255..00000000 --- a/scripts/inference/inference_diffusers_hunyuan.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -num_gpus=1 -export MODEL_BASE="data/FastHunyuan-diffusers" -torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 12345 \ - fastvideo/sample/sample_t2v_diffusers_hunyuan.py \ - --height 720 \ - --width 1280 \ - --num_frames 45 \ - --num_inference_steps 6 \ - --guidance_scale 1 \ - --embedded_cfg_scale 6 \ - --flow_shift 17 \ - --flow-reverse \ - --prompt ./assets/prompt.txt \ - --seed 1024 \ - --output_path outputs_video/hunyuan_quant/nf4/ \ - --model_path $MODEL_BASE \ - --quantization "nf4" \ - --cpu_offload \ No newline at end of file From cad6e6d1f98da6a42bc079e3ae1302adc1320cca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CBrianChen1129=E2=80=9D?= Date: Tue, 7 Jan 2025 23:14:41 +0000 Subject: [PATCH 16/42] unify hunyuan hf inference --- .../sample/sample_t2v_diffusers_hunyuan.py | 282 ------------------ 1 file changed, 282 deletions(-) delete mode 100644 fastvideo/sample/sample_t2v_diffusers_hunyuan.py diff --git a/fastvideo/sample/sample_t2v_diffusers_hunyuan.py 
b/fastvideo/sample/sample_t2v_diffusers_hunyuan.py deleted file mode 100644 index 81b1fb34..00000000 --- a/fastvideo/sample/sample_t2v_diffusers_hunyuan.py +++ /dev/null @@ -1,282 +0,0 @@ -import argparse -import io -import os -import time - -import imageio as iio -import numpy as np -import torch -from diffusers import (BitsAndBytesConfig, HunyuanVideoPipeline, - HunyuanVideoTransformer3DModel) - - -def export_to_video_bytes(fps, frames): - request = iio.core.Request("", mode="w", extension=".mp4") - pyavobject = iio.plugins.pyav.PyAVPlugin(request) - if isinstance(frames, np.ndarray): - frames = (np.array(frames) * 255).astype('uint8') - else: - frames = np.array(frames) - new_bytes = pyavobject.write(frames, codec="libx264", fps=fps) - out_bytes = io.BytesIO(new_bytes) - return out_bytes - - -def export_to_video(frames, path, fps): - video_bytes = export_to_video_bytes(fps, frames) - video_bytes.seek(0) - with open(path, "wb") as f: - f.write(video_bytes.getbuffer()) - - -def main(args): - torch.manual_seed(args.seed) - device = "cuda" if torch.cuda.is_available() else "cpu" - prompt_template = { - "template": - ("<|start_header_cid|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " - "1. The main content and theme of the video." - "2. The color, shape, size, texture, quantity, text, and spatial relationships of the contents, including objects, people, and anything else." - "3. Actions, events, behaviors temporal relationships, physical movement changes of the contents." - "4. Background environment, light, style, atmosphere, and qualities." - "5. Camera angles, movements, and transitions used in the video." - "6. Thematic and aesthetic concepts associated with the scene, i.e. realistic, futuristic, fairy tale, etc<|eot_id|>" - "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"), - "crop_start": - 95, - } - - model_id = args.model_path - - if args.quantization == "nf4": - quantization_config = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_compute_dtype=torch.bfloat16, - bnb_4bit_quant_type="nf4", - llm_int8_skip_modules=["proj_out", "norm_out"]) - transformer = HunyuanVideoTransformer3DModel.from_pretrained( - model_id, - subfolder="transformer/", - torch_dtype=torch.bfloat16, - quantization_config=quantization_config) - if args.quantization == "int8": - quantization_config = BitsAndBytesConfig( - load_in_8bit=True, llm_int8_skip_modules=["proj_out", "norm_out"]) - transformer = HunyuanVideoTransformer3DModel.from_pretrained( - model_id, - subfolder="transformer/", - torch_dtype=torch.bfloat16, - quantization_config=quantization_config) - elif not args.quantization: - transformer = HunyuanVideoTransformer3DModel.from_pretrained( - model_id, subfolder="transformer/", - torch_dtype=torch.bfloat16).to(device) - - print("Max vram for read transformer:", - round(torch.cuda.max_memory_allocated(device="cuda") / 1024**3, 3), - "GiB") - torch.cuda.reset_max_memory_allocated(device) - - if not args.cpu_offload: - pipe = HunyuanVideoPipeline.from_pretrained( - model_id, torch_dtype=torch.bfloat16).to(device) - pipe.transformer = transformer - else: - pipe = HunyuanVideoPipeline.from_pretrained(model_id, - transformer=transformer, - torch_dtype=torch.bfloat16) - torch.cuda.reset_max_memory_allocated(device) - pipe.scheduler._shift = args.flow_shift - pipe.vae.enable_tiling() - if args.cpu_offload: - pipe.enable_model_cpu_offload() - print("Max vram for init pipeline:", - round(torch.cuda.max_memory_allocated(device="cuda") / 1024**3, 3), - "GiB") - 
with open(args.prompt) as f: - prompts = f.readlines() - - generator = torch.Generator("cpu").manual_seed(args.seed) - os.makedirs(os.path.dirname(args.output_path), exist_ok=True) - torch.cuda.reset_max_memory_allocated(device) - for prompt in prompts: - start_time = time.perf_counter() - output = pipe( - prompt=prompt, - height=args.height, - width=args.width, - num_frames=args.num_frames, - prompt_template=prompt_template, - num_inference_steps=args.num_inference_steps, - generator=generator, - ).frames[0] - export_to_video(output, - os.path.join(args.output_path, f"{prompt[:100]}.mp4"), - fps=args.fps) - print("Time:", round(time.perf_counter() - start_time, 2), "seconds") - print( - "Max vram for denoise:", - round(torch.cuda.max_memory_allocated(device="cuda") / 1024**3, 3), - "GiB") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - - # Basic parameters - parser.add_argument("--prompt", type=str, help="prompt file for inference") - parser.add_argument("--num_frames", type=int, default=16) - parser.add_argument("--height", type=int, default=256) - parser.add_argument("--width", type=int, default=256) - parser.add_argument("--num_inference_steps", type=int, default=50) - parser.add_argument("--model_path", type=str, default="data/hunyuan") - parser.add_argument("--output_path", type=str, default="./outputs/video") - parser.add_argument("--fps", type=int, default=24) - parser.add_argument("--quantization", type=str, default=None) - parser.add_argument("--cpu_offload", action="store_true") - # Additional parameters - parser.add_argument( - "--denoise-type", - type=str, - default="flow", - help="Denoise type for noised inputs.", - ) - parser.add_argument("--seed", - type=int, - default=None, - help="Seed for evaluation.") - parser.add_argument("--neg_prompt", - type=str, - default=None, - help="Negative prompt for sampling.") - parser.add_argument( - "--guidance_scale", - type=float, - default=1.0, - help="Classifier free guidance scale.", - ) - parser.add_argument( - "--embedded_cfg_scale", - type=float, - default=6.0, - help="Embedded classifier free guidance scale.", - ) - parser.add_argument("--flow_shift", - type=int, - default=7, - help="Flow shift parameter.") - parser.add_argument("--batch_size", - type=int, - default=1, - help="Batch size for inference.") - parser.add_argument( - "--num_videos", - type=int, - default=1, - help="Number of videos to generate per prompt.", - ) - parser.add_argument( - "--load-key", - type=str, - default="module", - help= - "Key to load the model states. 
'module' for the main model, 'ema' for the EMA model.", - ) - parser.add_argument( - "--use-cpu-offload", - action="store_true", - help="Use CPU offload for the model load.", - ) - parser.add_argument( - "--dit-weight", - type=str, - default= - "data/hunyuan/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", - ) - parser.add_argument( - "--reproduce", - action="store_true", - help= - "Enable reproducibility by setting random seeds and deterministic algorithms.", - ) - parser.add_argument( - "--disable-autocast", - action="store_true", - help= - "Disable autocast for denoising loop and vae decoding in pipeline sampling.", - ) - - # Flow Matching - parser.add_argument( - "--flow-reverse", - action="store_true", - help="If reverse, learning/sampling from t=1 -> t=0.", - ) - parser.add_argument("--flow-solver", - type=str, - default="euler", - help="Solver for flow matching.") - parser.add_argument( - "--use-linear-quadratic-schedule", - action="store_true", - help= - "Use linear quadratic schedule for flow matching. Following MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)", - ) - parser.add_argument( - "--linear-schedule-end", - type=int, - default=25, - help="End step for linear quadratic schedule for flow matching.", - ) - - # Model parameters - parser.add_argument("--model", type=str, default="HYVideo-T/2-cfgdistill") - parser.add_argument("--latent-channels", type=int, default=16) - parser.add_argument("--precision", - type=str, - default="bf16", - choices=["fp32", "fp16", "bf16", "fp8"]) - parser.add_argument("--rope-theta", - type=int, - default=256, - help="Theta used in RoPE.") - - parser.add_argument("--vae", type=str, default="884-16c-hy") - parser.add_argument("--vae-precision", - type=str, - default="fp16", - choices=["fp32", "fp16", "bf16"]) - parser.add_argument("--vae-tiling", action="store_true", default=True) - - parser.add_argument("--text-encoder", type=str, default="llm") - parser.add_argument( - "--text-encoder-precision", - type=str, - default="fp16", - choices=["fp32", "fp16", "bf16"], - ) - parser.add_argument("--text-states-dim", type=int, default=4096) - parser.add_argument("--text-len", type=int, default=256) - parser.add_argument("--tokenizer", type=str, default="llm") - parser.add_argument("--prompt-template", - type=str, - default="dit-llm-encode") - parser.add_argument("--prompt-template-video", - type=str, - default="dit-llm-encode-video") - parser.add_argument("--hidden-state-skip-layer", type=int, default=2) - parser.add_argument("--apply-final-norm", action="store_true") - - parser.add_argument("--text-encoder-2", type=str, default="clipL") - parser.add_argument( - "--text-encoder-precision-2", - type=str, - default="fp16", - choices=["fp32", "fp16", "bf16"], - ) - parser.add_argument("--text-states-dim-2", type=int, default=768) - parser.add_argument("--tokenizer-2", type=str, default="clipL") - parser.add_argument("--text-len-2", type=int, default=77) - - args = parser.parse_args() - main(args) From 18320426e0217d27e0dc2e3adb738ce8fb4a906f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CBrianChen1129=E2=80=9D?= Date: Tue, 7 Jan 2025 23:24:07 +0000 Subject: [PATCH 17/42] syn with main --- assets/prompt.txt | 4 +- fastvideo/sample/sample_t2v_hunyuan_hf.py | 93 +++++++++++-------- scripts/finetune/finetune_hunyuan.sh | 23 +++-- scripts/finetune/finetune_mochi_lora.sh | 19 ++-- scripts/inference/inference_hunyuan.sh | 18 ++-- scripts/inference/inference_hunyuan_hf.sh | 29 +++++- .../inference_hunyuan_hf_quantization.sh 
| 8 +- scripts/inference/inference_mochi_sp.sh | 24 +---- scripts/preprocess/preprocess_hunyuan_data.sh | 4 +- 9 files changed, 125 insertions(+), 97 deletions(-) diff --git a/assets/prompt.txt b/assets/prompt.txt index 0d0a3c67..d10865e3 100644 --- a/assets/prompt.txt +++ b/assets/prompt.txt @@ -3,6 +3,6 @@ A lone hiker stands atop a towering cliff, silhouetted against the vast horizon. A hand with delicate fingers picks up a bright yellow lemon from a wooden bowl filled with lemons and sprigs of mint against a peach-colored background. The hand gently tosses the lemon up and catches it, showcasing its smooth texture. A beige string bag sits beside the bowl, adding a rustic touch to the scene. Additional lemons, one halved, are scattered around the base of the bowl. The even lighting enhances the vibrant colors and creates a fresh, inviting atmosphere. A curious raccoon peers through a vibrant field of yellow sunflowers, its eyes wide with interest. The playful yet serene atmosphere is complemented by soft natural light filtering through the petals. Mid-shot, warm and cheerful tones. A superintelligent humanoid robot waking up. The robot has a sleek metallic body with futuristic design features. Its glowing red eyes are the focal point, emanating a sharp, intense light as it powers on. The scene is set in a dimly lit, high-tech laboratory filled with glowing control panels, robotic arms, and holographic screens. The setting emphasizes advanced technology and an atmosphere of mystery. The ambiance is eerie and dramatic, highlighting the moment of awakening and the robots immense intelligence. Photorealistic style with a cinematic, dark sci-fi aesthetic. Aspect ratio: 16:9 --v 6.1 -fox in the forest close-up quickly turned its head to the left. -Man walking his dog in the woods on a hot sunny day. +fox in the forest close-up quickly turned its head to the left +Man walking his dog in the woods on a hot sunny day A majestic lion strides across the golden savanna, its powerful frame glistening under the warm afternoon sun. The tall grass ripples gently in the breeze, enhancing the lion's commanding presence. The tone is vibrant, embodying the raw energy of the wild. Low angle, steady tracking shot, cinematic. \ No newline at end of file diff --git a/fastvideo/sample/sample_t2v_hunyuan_hf.py b/fastvideo/sample/sample_t2v_hunyuan_hf.py index a7369826..018aff7b 100644 --- a/fastvideo/sample/sample_t2v_hunyuan_hf.py +++ b/fastvideo/sample/sample_t2v_hunyuan_hf.py @@ -59,9 +59,10 @@ def inference(args): pipe.load_lora_weights(args.lora_checkpoint_dir, adapter_name="default") pipe.set_adapters(["default"], [lora_scaling]) print(f"Successfully Loaded LoRA weights from {args.lora_checkpoint_dir}") - #pipe.to(device) - - pipe.enable_model_cpu_offload(device) + if args.cpu_offload: + pipe.enable_model_cpu_offload(device) + else: + pipe.to(device) # Generate videos from the input prompt @@ -128,54 +129,67 @@ def inference_quantization(args): torch.manual_seed(args.seed) device = "cuda" if torch.cuda.is_available() else "cpu" prompt_template = { - "template": ( - "<|start_header_cid|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " - "1. The main content and theme of the video." - "2. The color, shape, size, texture, quantity, text, and spatial relationships of the contents, including objects, people, and anything else." - "3. Actions, events, behaviors temporal relationships, physical movement changes of the contents." - "4. 
Background environment, light, style, atmosphere, and qualities." - "5. Camera angles, movements, and transitions used in the video." - "6. Thematic and aesthetic concepts associated with the scene, i.e. realistic, futuristic, fairy tale, etc<|eot_id|>" - "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" - ), - "crop_start": 95, + "template": + ("<|start_header_cid|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the contents, including objects, people, and anything else." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the contents." + "4. Background environment, light, style, atmosphere, and qualities." + "5. Camera angles, movements, and transitions used in the video." + "6. Thematic and aesthetic concepts associated with the scene, i.e. realistic, futuristic, fairy tale, etc<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"), + "crop_start":95, } - - model_id = args.model_path if args.quantization == "nf4": - quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_quant_type="nf4", llm_int8_skip_modules=["proj_out", "norm_out"]) + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_quant_type="nf4", + llm_int8_skip_modules=["proj_out", "norm_out"]) transformer = HunyuanVideoTransformer3DModel.from_pretrained( - model_id, subfolder="transformer/" ,torch_dtype=torch.bfloat16, quantization_config=quantization_config - ) + model_id, + subfolder="transformer/", + torch_dtype=torch.bfloat16, + quantization_config=quantization_config) if args.quantization == "int8": - quantization_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["proj_out", "norm_out"]) + quantization_config = BitsAndBytesConfig( + load_in_8bit=True, llm_int8_skip_modules=["proj_out", "norm_out"]) transformer = HunyuanVideoTransformer3DModel.from_pretrained( - model_id, subfolder="transformer/" ,torch_dtype=torch.bfloat16, quantization_config=quantization_config - ) + model_id, + subfolder="transformer/", + torch_dtype=torch.bfloat16, + quantization_config=quantization_config) elif not args.quantization: transformer = HunyuanVideoTransformer3DModel.from_pretrained( - model_id, subfolder="transformer/" ,torch_dtype=torch.bfloat16 - ).to(device) - - print("Max vram for read transofrmer:", round(torch.cuda.max_memory_allocated(device="cuda") / 1024 ** 3, 3), "GiB") + model_id, subfolder="transformer/", + torch_dtype=torch.bfloat16).to(device) + + print("Max vram for read transformer:", + round(torch.cuda.max_memory_allocated(device="cuda") / 1024**3, 3), + "GiB") torch.cuda.reset_max_memory_allocated(device) - + if not args.cpu_offload: - pipe = HunyuanVideoPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device) + pipe = HunyuanVideoPipeline.from_pretrained( + model_id, torch_dtype=torch.bfloat16).to(device) pipe.transformer = transformer else: - pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16) + pipe = HunyuanVideoPipeline.from_pretrained(model_id, + transformer=transformer, + torch_dtype=torch.bfloat16) torch.cuda.reset_max_memory_allocated(device) pipe.scheduler._shift = args.flow_shift pipe.vae.enable_tiling() if args.cpu_offload: pipe.enable_model_cpu_offload() - print("Max vram for 
init pipeline:", round(torch.cuda.max_memory_allocated(device="cuda") / 1024 ** 3, 3), "GiB") + print("Max vram for init pipeline:", + round(torch.cuda.max_memory_allocated(device="cuda") / 1024**3, 3), + "GiB") with open(args.prompt) as f: prompts = f.readlines() - + generator = torch.Generator("cpu").manual_seed(args.seed) os.makedirs(os.path.dirname(args.output_path), exist_ok=True) torch.cuda.reset_max_memory_allocated(device) @@ -183,16 +197,21 @@ def inference_quantization(args): start_time = time.perf_counter() output = pipe( prompt=prompt, - height = args.height, - width = args.width, - num_frames = args.num_frames, + height=args.height, + width=args.width, + num_frames=args.num_frames, prompt_template=prompt_template, - num_inference_steps = args.num_inference_steps, + num_inference_steps=args.num_inference_steps, generator=generator, ).frames[0] - export_to_video(output, os.path.join(args.output_path, f"{prompt[:100]}.mp4"), fps=args.fps) + export_to_video(output, + os.path.join(args.output_path, f"{prompt[:100]}.mp4"), + fps=args.fps) print("Time:", round(time.perf_counter() - start_time, 2), "seconds") - print("Max vram for denoise:", round(torch.cuda.max_memory_allocated(device="cuda") / 1024 ** 3, 3), "GiB") + print( + "Max vram for denoise:", + round(torch.cuda.max_memory_allocated(device="cuda") / 1024**3, 3), + "GiB") if __name__ == "__main__": parser = argparse.ArgumentParser() diff --git a/scripts/finetune/finetune_hunyuan.sh b/scripts/finetune/finetune_hunyuan.sh index 847debbd..12e8e43a 100644 --- a/scripts/finetune/finetune_hunyuan.sh +++ b/scripts/finetune/finetune_hunyuan.sh @@ -1,39 +1,38 @@ export WANDB_BASE_URL="https://api.wandb.ai" export WANDB_MODE=online -torchrun --nnodes 1 --nproc_per_node 4 \ +torchrun --nnodes 1 --nproc_per_node 8 \ fastvideo/train.py \ --seed 42 \ --pretrained_model_name_or_path data/hunyuan \ --dit_model_name_or_path data/hunyuan/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt\ --model_type "hunyuan" \ --cache_dir data/.cache \ - --data_json_path data/Image-Vid-Finetune-Src/videos2caption.json \ - --validation_prompt_dir data/Image-Vid-Finetune-Src/validation \ + --data_json_path data/Image-Vid-Finetune-HunYuan/videos2caption.json \ + --validation_prompt_dir data/Image-Vid-Finetune-HunYuan/validation \ --gradient_checkpointing \ - --train_batch_size 1 \ + --train_batch_size=1 \ --num_latent_t 24 \ --sp_size 4 \ --train_sp_batch_size 1 \ --dataloader_num_workers 4 \ - --gradient_accumulation_steps 1 \ - --max_train_steps 2000 \ - --learning_rate 5e-6 \ - --mixed_precision bf16 \ - --checkpointing_steps 200 \ + --gradient_accumulation_steps=1 \ + --max_train_steps=2000 \ + --learning_rate=5e-6 \ + --mixed_precision=bf16 \ + --checkpointing_steps=200 \ --validation_steps 100 \ - --validation_sampling_steps 50 \ + --validation_sampling_steps 64 \ --checkpoints_total_limit 3 \ --allow_tf32 \ --ema_start_step 0 \ --cfg 0.0 \ --ema_decay 0.999 \ --log_validation \ - --output_dir data/outputs/HSH-Taylor-Finetune-Hunyuan/ \ + --output_dir=data/outputs/HSH-Taylor-Finetune-Hunyuan \ --tracker_project_name HSH-Taylor-Finetune-Hunyuan \ --num_frames 93 \ --num_height 720 \ --num_width 1280 \ --validation_guidance_scale "1.0" \ - --master_weight_type fp32 \ --group_frame \ No newline at end of file diff --git a/scripts/finetune/finetune_mochi_lora.sh b/scripts/finetune/finetune_mochi_lora.sh index 4015a61d..2f5a39bd 100644 --- a/scripts/finetune/finetune_mochi_lora.sh +++ b/scripts/finetune/finetune_mochi_lora.sh @@ -1,18 +1,17 @@ 
export WANDB_BASE_URL="https://api.wandb.ai" export WANDB_MODE=online -CUDA_VISIBLE_DEVICES=6 torchrun --nnodes 1 --nproc_per_node 1 --master_port 29403 \ - fastvideo/train_new.py \ +torchrun --nnodes 1 --nproc_per_node 2 \ + fastvideo/train.py \ --seed 42 \ - --model_type mochi \ - --pretrained_model_name_or_path ~/data/mochi_diffusers \ + --pretrained_model_name_or_path data/mochi \ --cache_dir data/.cache \ - --data_json_path data/Encoder_Overfit_Data/videos2caption.json \ - --validation_prompt_dir data/validation_prompt_embed_mask \ + --data_json_path data/Mochi-Black-Myth/videos2caption.json \ + --validation_prompt_dir data/Mochi-Black-Myth/validation \ --gradient_checkpointing \ --train_batch_size 1 \ - --num_latent_t 2 \ - --sp_size 1 \ + --num_latent_t 14 \ + --sp_size 2 \ --train_sp_batch_size 1 \ --dataloader_num_workers 1 \ --gradient_accumulation_steps 2 \ @@ -28,11 +27,11 @@ CUDA_VISIBLE_DEVICES=6 torchrun --nnodes 1 --nproc_per_node 1 --master_port 2940 --cfg 0.0 \ --ema_decay 0.999 \ --log_validation \ - --output_dir data/outputs/Black-Myth-Lora-FT \ + --output_dir=data/outputs/Black-Myth-Lora-FT \ --tracker_project_name Black-Myth-Lora-Finetune \ --num_frames 91 \ --lora_rank 128 \ --lora_alpha 256 \ - --master_weight_type fp32 \ + --master_weight_type "bf16" \ --use_lora \ --use_cpu_offload diff --git a/scripts/inference/inference_hunyuan.sh b/scripts/inference/inference_hunyuan.sh index fb419b73..0340431f 100644 --- a/scripts/inference/inference_hunyuan.sh +++ b/scripts/inference/inference_hunyuan.sh @@ -1,19 +1,19 @@ #!/bin/bash -num_gpus=1 -export MODEL_BASE=~/data/hunyuan/ -torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29703 \ +num_gpus=4 +export MODEL_BASE=data/FastHunyuan +torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29503 \ fastvideo/sample/sample_t2v_hunyuan.py \ - --height 480 \ - --width 848 \ - --num_frames 93 \ - --num_inference_steps 50 \ + --height 720 \ + --width 1280 \ + --num_frames 125 \ + --num_inference_steps 6 \ --guidance_scale 1 \ --embedded_cfg_scale 6 \ - --flow_shift 7 \ + --flow_shift 17 \ --flow-reverse \ --prompt ./assets/prompt.txt \ --seed 1024 \ - --output_path outputs_video/hunyuan/ \ + --output_path outputs_video/hunyuan/cfg6/ \ --model_path $MODEL_BASE \ --dit-weight ${MODEL_BASE}/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt diff --git a/scripts/inference/inference_hunyuan_hf.sh b/scripts/inference/inference_hunyuan_hf.sh index 545b1562..a72178d8 100644 --- a/scripts/inference/inference_hunyuan_hf.sh +++ b/scripts/inference/inference_hunyuan_hf.sh @@ -3,8 +3,8 @@ num_gpus=4 torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29503 \ fastvideo/sample/sample_t2v_hunyuan_hf.py \ - --model_path data/hunyuan_diffusers/ \ - --prompt_path "assets/prompt.txt" \ + --model_path ~/data/hunyuan_diffusers/ \ + --prompt_path "assets/prompt_test_3.txt" \ --num_frames 93 \ --height 480 \ --width 848 \ @@ -12,6 +12,31 @@ torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29503 \ --output_path outputs_video/hunyuan_hf/ \ --seed 1024 \ +num_gpus=4 +CUDA_VISIBLE_DEVICES=4,5,6,7 torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29503 \ + fastvideo/sample/sample_t2v_hunyuan_hf.py \ + --model_path ~/data/hunyuan_diffusers/ \ + --prompt_path "assets/prompt_test_3.txt" \ + --num_frames 93 \ + --height 480 \ + --width 848 \ + --num_inference_steps 50 \ + --output_path outputs_video/hunyuan_hf_new_4/ \ + --seed 1024 \ + --lora_checkpoint_dir 
data/outputs/HSH-Taylor-Finetune-Hunyuan_8e5_ra32_v40/lora-checkpoint-6000/ + +num_gpus=4 +torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29603 \ + fastvideo/sample/sample_t2v_hunyuan_hf.py \ + --model_path ~/data/hunyuan_diffusers/ \ + --prompt_path "assets/prompt_test_3.txt" \ + --num_frames 93 \ + --height 480 \ + --width 848 \ + --num_inference_steps 50 \ + --output_path outputs_video/hunyuan_hf_new_5/ \ + --seed 12345 \ + --lora_checkpoint_dir data/outputs/HSH-Taylor-Finetune-Hunyuan_1e4_ra64_v40/lora-checkpoint-5250/ diff --git a/scripts/inference/inference_hunyuan_hf_quantization.sh b/scripts/inference/inference_hunyuan_hf_quantization.sh index 92cbefb7..a94460f4 100644 --- a/scripts/inference/inference_hunyuan_hf_quantization.sh +++ b/scripts/inference/inference_hunyuan_hf_quantization.sh @@ -1,17 +1,17 @@ #!/bin/bash num_gpus=1 -export MODEL_BASE="data/FastHunyuan-git ad" +export MODEL_BASE="data/FastHunyuan" torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 12345 \ - fastvideo/sample/sample_t2v_diffusers_hunyuan.py \ + fastvideo/sample/sample_t2v_hunyuan_hf.py \ --height 720 \ --width 1280 \ --num_frames 45 \ - --num_inference_steps 50 \ + --num_inference_steps 6 \ --guidance_scale 1 \ --embedded_cfg_scale 6 \ --flow_shift 17 \ - --prompt ./assets/prompt_HSH.txt \ + --prompt ./assets/prompt.txt \ --seed 1024 \ --output_path outputs_video/hunyuan_quant/nf4/ \ --model_path $MODEL_BASE \ diff --git a/scripts/inference/inference_mochi_sp.sh b/scripts/inference/inference_mochi_sp.sh index 48a4c527..1ded5727 100644 --- a/scripts/inference/inference_mochi_sp.sh +++ b/scripts/inference/inference_mochi_sp.sh @@ -1,32 +1,18 @@ #!/bin/bash num_gpus=4 -CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29603 \ + +torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29503 \ fastvideo/sample/sample_t2v_mochi.py \ - --model_path ~/data/mochi_diffusers/ \ + --model_path data/FastMochi-diffusers \ --prompt_path "assets/prompt.txt" \ - --num_frames 91 \ + --num_frames 163 \ --height 480 \ --width 848 \ - --num_inference_steps 64 \ + --num_inference_steps 8 \ --guidance_scale 1.5 \ --output_path outputs_video/mochi_sp/ \ --seed 1024 \ --scheduler_type "pcm_linear_quadratic" \ --linear_threshold 0.1 \ --linear_range 0.75 - -num_gpus=4 -torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29503 \ - fastvideo/sample/sample_t2v_mochi.py \ - --model_path ~/data/mochi_diffusers/ \ - --prompt_path "assets/prompt_test.txt" \ - --num_frames 163 \ - --height 480 \ - --width 848 \ - --num_inference_steps 64 \ - --guidance_scale 6 \ - --output_path outputs_video/debug \ - --shift 8 \ - --seed 12345 \ - --scheduler_type "pcm_linear_quadratic" diff --git a/scripts/preprocess/preprocess_hunyuan_data.sh b/scripts/preprocess/preprocess_hunyuan_data.sh index f001fd0a..fc452337 100644 --- a/scripts/preprocess/preprocess_hunyuan_data.sh +++ b/scripts/preprocess/preprocess_hunyuan_data.sh @@ -1,9 +1,9 @@ # export WANDB_MODE="offline" -GPU_NUM=1 +GPU_NUM=1 # 2,4,8 MODEL_PATH="data/hunyuan" MODEL_TYPE="hunyuan" DATA_MERGE_PATH="data/Image-Vid-Finetune-Src/merge.txt" -OUTPUT_DIR="data/Image-Vid-Finetune-Src" +OUTPUT_DIR="data/Image-Vid-Finetune-HunYuan" VALIDATION_PATH="assets/prompt.txt" torchrun --nproc_per_node=$GPU_NUM \ From 5d21b16c2c2900471a9cf1074c01e93738ca6620 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CBrianChen1129=E2=80=9D?= Date: Tue, 7 Jan 2025 23:29:07 +0000 Subject: [PATCH 18/42] syn --- assets/prompt_test.txt | 1 - 
fastvideo/models/hunyuan/modules/models.py | 1 + scripts/finetune/finetune_hunyuan_hf_lora.sh | 8 ++++---- scripts/inference/inference_mochi_sp.sh | 1 + 4 files changed, 6 insertions(+), 5 deletions(-) delete mode 100644 assets/prompt_test.txt diff --git a/assets/prompt_test.txt b/assets/prompt_test.txt deleted file mode 100644 index 333cef7e..00000000 --- a/assets/prompt_test.txt +++ /dev/null @@ -1 +0,0 @@ -Wukong stands prominently against a clear sky. Wukong's fur is dense and dark, framing an intense expression as they hold a large staff confidently across their shoulders. Elaborate golden armor intricately covers Wukong's torso, adorned with ornate designs and embellishments. With every subtle movement, Wukong exudes a sense of readiness and power, as if preparing for an impending challenge. The crown atop Wukong's head glints majestically in the sunlight, symbolizing leadership and authority. Wukong's fierce eyes remain focused and vigilant, capturing the viewer's attention with an aura of both mystery and strength. \ No newline at end of file diff --git a/fastvideo/models/hunyuan/modules/models.py b/fastvideo/models/hunyuan/modules/models.py index 68e33bc9..759897e2 100644 --- a/fastvideo/models/hunyuan/modules/models.py +++ b/fastvideo/models/hunyuan/modules/models.py @@ -19,6 +19,7 @@ from .posemb_layers import apply_rotary_emb from .token_refiner import SingleTokenRefiner + class MMDoubleStreamBlock(nn.Module): """ A multimodal dit block with separate modulation for diff --git a/scripts/finetune/finetune_hunyuan_hf_lora.sh b/scripts/finetune/finetune_hunyuan_hf_lora.sh index 3a2148fd..8f7e85a4 100644 --- a/scripts/finetune/finetune_hunyuan_hf_lora.sh +++ b/scripts/finetune/finetune_hunyuan_hf_lora.sh @@ -14,9 +14,9 @@ torchrun --nnodes 1 --nproc_per_node 4 --master_port 29903 \ --sp_size 4 \ --train_sp_batch_size 1 \ --dataloader_num_workers 4 \ - --gradient_accumulation_steps 2 \ + --gradient_accumulation_steps 4 \ --max_train_steps 2000 \ - --learning_rate 5e-6 \ + --learning_rate 8e-5 \ --mixed_precision bf16 \ --checkpointing_steps 500 \ --validation_steps 100 \ @@ -33,5 +33,5 @@ torchrun --nnodes 1 --nproc_per_node 4 --master_port 29903 \ --validation_guidance_scale "1.0" \ --group_frame \ --use_lora \ - --lora_rank 64 \ - --lora_alpha 128 \ + --lora_rank 32 \ + --lora_alpha 32 \ diff --git a/scripts/inference/inference_mochi_sp.sh b/scripts/inference/inference_mochi_sp.sh index 1ded5727..5f2da77e 100644 --- a/scripts/inference/inference_mochi_sp.sh +++ b/scripts/inference/inference_mochi_sp.sh @@ -16,3 +16,4 @@ torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29503 \ --scheduler_type "pcm_linear_quadratic" \ --linear_threshold 0.1 \ --linear_range 0.75 + \ No newline at end of file From 1d7d637b805b90fa0f475e69b8831ee24e20c3a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CBrianChen1129=E2=80=9D?= Date: Tue, 7 Jan 2025 23:30:46 +0000 Subject: [PATCH 19/42] syn --- scripts/inference/inference_hunyuan_hf.sh | 28 ----------------------- scripts/inference/inference_mochi_sp.sh | 1 - 2 files changed, 29 deletions(-) diff --git a/scripts/inference/inference_hunyuan_hf.sh b/scripts/inference/inference_hunyuan_hf.sh index a72178d8..b14d4fc0 100644 --- a/scripts/inference/inference_hunyuan_hf.sh +++ b/scripts/inference/inference_hunyuan_hf.sh @@ -12,32 +12,4 @@ torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29503 \ --output_path outputs_video/hunyuan_hf/ \ --seed 1024 \ -num_gpus=4 -CUDA_VISIBLE_DEVICES=4,5,6,7 torchrun --nnodes=1 
--nproc_per_node=$num_gpus --master_port 29503 \ - fastvideo/sample/sample_t2v_hunyuan_hf.py \ - --model_path ~/data/hunyuan_diffusers/ \ - --prompt_path "assets/prompt_test_3.txt" \ - --num_frames 93 \ - --height 480 \ - --width 848 \ - --num_inference_steps 50 \ - --output_path outputs_video/hunyuan_hf_new_4/ \ - --seed 1024 \ - --lora_checkpoint_dir data/outputs/HSH-Taylor-Finetune-Hunyuan_8e5_ra32_v40/lora-checkpoint-6000/ - -num_gpus=4 -torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29603 \ - fastvideo/sample/sample_t2v_hunyuan_hf.py \ - --model_path ~/data/hunyuan_diffusers/ \ - --prompt_path "assets/prompt_test_3.txt" \ - --num_frames 93 \ - --height 480 \ - --width 848 \ - --num_inference_steps 50 \ - --output_path outputs_video/hunyuan_hf_new_5/ \ - --seed 12345 \ - --lora_checkpoint_dir data/outputs/HSH-Taylor-Finetune-Hunyuan_1e4_ra64_v40/lora-checkpoint-5250/ - - - diff --git a/scripts/inference/inference_mochi_sp.sh b/scripts/inference/inference_mochi_sp.sh index 5f2da77e..1ded5727 100644 --- a/scripts/inference/inference_mochi_sp.sh +++ b/scripts/inference/inference_mochi_sp.sh @@ -16,4 +16,3 @@ torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29503 \ --scheduler_type "pcm_linear_quadratic" \ --linear_threshold 0.1 \ --linear_range 0.75 - \ No newline at end of file From c1cf441afbd5d4431464ebfb24341626892ade13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CBrianChen1129=E2=80=9D?= Date: Tue, 7 Jan 2025 23:31:33 +0000 Subject: [PATCH 20/42] syn --- scripts/inference/inference_mochi_sp.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/inference/inference_mochi_sp.sh b/scripts/inference/inference_mochi_sp.sh index 1ded5727..ad76a936 100644 --- a/scripts/inference/inference_mochi_sp.sh +++ b/scripts/inference/inference_mochi_sp.sh @@ -15,4 +15,4 @@ torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29503 \ --seed 1024 \ --scheduler_type "pcm_linear_quadratic" \ --linear_threshold 0.1 \ - --linear_range 0.75 + --linear_range 0.75 \ No newline at end of file From 22d499bab211a38f8d16503eddf337cfbecd09e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CBrianChen1129=E2=80=9D?= Date: Tue, 7 Jan 2025 23:32:38 +0000 Subject: [PATCH 21/42] syn --- scripts/inference/inference_mochi_sp.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/inference/inference_mochi_sp.sh b/scripts/inference/inference_mochi_sp.sh index ad76a936..5f2da77e 100644 --- a/scripts/inference/inference_mochi_sp.sh +++ b/scripts/inference/inference_mochi_sp.sh @@ -15,4 +15,5 @@ torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29503 \ --seed 1024 \ --scheduler_type "pcm_linear_quadratic" \ --linear_threshold 0.1 \ - --linear_range 0.75 \ No newline at end of file + --linear_range 0.75 + \ No newline at end of file From e7ea0d7cd073e547b9329c3685ff7afe2b84ba8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CBrianChen1129=E2=80=9D?= Date: Tue, 7 Jan 2025 23:33:03 +0000 Subject: [PATCH 22/42] syn --- scripts/inference/inference_mochi_sp.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/scripts/inference/inference_mochi_sp.sh b/scripts/inference/inference_mochi_sp.sh index 5f2da77e..1ded5727 100644 --- a/scripts/inference/inference_mochi_sp.sh +++ b/scripts/inference/inference_mochi_sp.sh @@ -16,4 +16,3 @@ torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29503 \ --scheduler_type "pcm_linear_quadratic" \ --linear_threshold 0.1 \ --linear_range 0.75 - \ No newline at end of file From 
afe24e2be9c93a8f657023b10940faf0f806e368 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CBrianChen1129=E2=80=9D?= Date: Tue, 7 Jan 2025 23:34:48 +0000 Subject: [PATCH 23/42] syn --- prompts.txt | 4 ---- 1 file changed, 4 deletions(-) delete mode 100644 prompts.txt diff --git a/prompts.txt b/prompts.txt deleted file mode 100644 index c228a916..00000000 --- a/prompts.txt +++ /dev/null @@ -1,4 +0,0 @@ -Will Smith casually eats noodles, his relaxed demeanor contrasting with the energetic background of a bustling street food market. The scene captures a mix of humor and authenticity. Mid-shot framing, vibrant lighting. -A lone hiker stands atop a towering cliff, silhouetted against the vast horizon. The rugged landscape stretches endlessly beneath, its earthy tones blending into the soft blues of the sky. The scene captures the spirit of exploration and human resilience. High angle, dynamic framing, with soft natural lighting emphasizing the grandeur of nature. -A hand with delicate fingers picks up a bright yellow lemon from a wooden bowl filled with lemons and sprigs of mint against a peach-colored background. The hand gently tosses the lemon up and catches it, showcasing its smooth texture. A beige string bag sits beside the bowl, adding a rustic touch to the scene. Additional lemons, one halved, are scattered around the base of the bowl. The even lighting enhances the vibrant colors and creates a fresh, inviting atmosphere. -A curious raccoon peers through a vibrant field of yellow sunflowers, its eyes wide with interest. The playful yet serene atmosphere is complemented by soft natural light filtering through the petals. Mid-shot, warm and cheerful tones. From 41f8ac9875b5d834aab4de8c0b1133277f355328 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CBrianChen1129=E2=80=9D?= Date: Wed, 8 Jan 2025 02:00:07 +0000 Subject: [PATCH 24/42] syn --- .../models/hunyuan_hf/modeling_hunyuan.py | 427 +++++++++++------- .../models/hunyuan_hf/pipeline_hunyuan.py | 205 +++++---- fastvideo/train.py | 34 +- fastvideo/utils/checkpoint.py | 8 +- fastvideo/utils/load.py | 32 +- fastvideo/utils/validation.py | 4 +- 6 files changed, 436 insertions(+), 274 deletions(-) diff --git a/fastvideo/models/hunyuan_hf/modeling_hunyuan.py b/fastvideo/models/hunyuan_hf/modeling_hunyuan.py index 655206ef..a722ca63 100644 --- a/fastvideo/models/hunyuan_hf/modeling_hunyuan.py +++ b/fastvideo/models/hunyuan_hf/modeling_hunyuan.py @@ -13,34 +13,41 @@ # limitations under the License. 
from typing import Any, Dict, List, Optional, Tuple, Union + import torch import torch.nn as nn import torch.nn.functional as F - from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.loaders import PeftAdapterMixin, FromOriginalModelMixin -from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin from diffusers.models.attention import FeedForward from diffusers.models.attention_processor import Attention, AttentionProcessor from diffusers.models.embeddings import ( CombinedTimestepGuidanceTextProjEmbeddings, - CombinedTimestepTextProjEmbeddings, - get_1d_rotary_pos_embed, -) + CombinedTimestepTextProjEmbeddings, get_1d_rotary_pos_embed) from diffusers.models.modeling_outputs import Transformer2DModelOutput from diffusers.models.modeling_utils import ModelMixin -from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle -from fastvideo.utils.parallel_states import get_sequence_parallel_state, nccl_info -from fastvideo.utils.communications import all_gather, all_to_all_4D +from diffusers.models.normalization import (AdaLayerNormContinuous, + AdaLayerNormZero, + AdaLayerNormZeroSingle) +from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging, + scale_lora_layers, unscale_lora_layers) + from fastvideo.models.flash_attn_no_pad import flash_attn_no_pad +from fastvideo.utils.communications import all_gather, all_to_all_4D +from fastvideo.utils.parallel_states import (get_sequence_parallel_state, + nccl_info) + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + def shrink_head(encoder_state, dim): local_heads = encoder_state.shape[dim] // nccl_info.sp_size - return encoder_state.narrow( - dim, nccl_info.rank_within_group * local_heads, local_heads - ) + return encoder_state.narrow(dim, nccl_info.rank_within_group * local_heads, + local_heads) + + class HunyuanVideoAttnProcessor2_0: + def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( @@ -53,13 +60,14 @@ def __call__( hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: sequence_length = hidden_states.size(1) encoder_sequence_length = encoder_hidden_states.size(1) - if attn.add_q_proj is None and encoder_hidden_states is not None: - hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) + if attn.add_q_proj is None and encoder_hidden_states is not None: + hidden_states = torch.cat([hidden_states, encoder_hidden_states], + dim=1) # 1. QKV projections query = attn.to_q(hidden_states) @@ -72,7 +80,7 @@ def __call__( # 2. 
QK normalization if attn.norm_q is not None: - query = attn.norm_q(query).to(value) + query = attn.norm_q(query).to(value) if attn.norm_k is not None: key = attn.norm_k(key).to(value) @@ -88,15 +96,19 @@ def __call__( if attn.add_q_proj is None and encoder_hidden_states is not None: query = torch.cat( [ - apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb), - query[:, :, -encoder_hidden_states.shape[1] :], + apply_rotary_emb( + query[:, :, :-encoder_hidden_states.shape[1]], + image_rotary_emb), + query[:, :, -encoder_hidden_states.shape[1]:], ], dim=2, ) key = torch.cat( [ - apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb), - key[:, :, -encoder_hidden_states.shape[1] :], + apply_rotary_emb( + key[:, :, :-encoder_hidden_states.shape[1]], + image_rotary_emb), + key[:, :, -encoder_hidden_states.shape[1]:], ], dim=2, ) @@ -110,56 +122,73 @@ def __call__( encoder_key = attn.add_k_proj(encoder_hidden_states) encoder_value = attn.add_v_proj(encoder_hidden_states) - encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2) - encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2) - encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + encoder_query = encoder_query.unflatten( + 2, (attn.heads, -1)).transpose(1, 2) + encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose( + 1, 2) + encoder_value = encoder_value.unflatten( + 2, (attn.heads, -1)).transpose(1, 2) if attn.norm_added_q is not None: - encoder_query = attn.norm_added_q(encoder_query).to(encoder_value) + encoder_query = attn.norm_added_q(encoder_query).to( + encoder_value) if attn.norm_added_k is not None: encoder_key = attn.norm_added_k(encoder_key).to(encoder_value) - query = torch.cat([query, encoder_query], dim=2) + query = torch.cat([query, encoder_query], dim=2) key = torch.cat([key, encoder_key], dim=2) value = torch.cat([value, encoder_value], dim=2) - + if get_sequence_parallel_state(): - query_img, query_txt = query[:,:,:sequence_length,:], query[:,:,sequence_length:,:] - key_img, key_txt = key[:,:,:sequence_length,:], key[:,:,sequence_length:,:] - value_img, value_txt = value[:,:,:sequence_length,:], value[:,:,sequence_length:,:] - query_img = all_to_all_4D(query_img, scatter_dim=1, gather_dim=2) # + query_img, query_txt = query[:, :, : + sequence_length, :], query[:, :, + sequence_length:, :] + key_img, key_txt = key[:, :, : + sequence_length, :], key[:, :, + sequence_length:, :] + value_img, value_txt = value[:, :, : + sequence_length, :], value[:, :, + sequence_length:, :] + query_img = all_to_all_4D(query_img, scatter_dim=1, + gather_dim=2) # key_img = all_to_all_4D(key_img, scatter_dim=1, gather_dim=2) value_img = all_to_all_4D(value_img, scatter_dim=1, gather_dim=2) - query_txt = shrink_head(query_txt, dim=1) + query_txt = shrink_head(query_txt, dim=1) key_txt = shrink_head(key_txt, dim=1) value_txt = shrink_head(value_txt, dim=1) query = torch.cat([query_img, query_txt], dim=2) key = torch.cat([key_img, key_txt], dim=2) value = torch.cat([value_img, value_txt], dim=2) - - query = query.unsqueeze(2) + + query = query.unsqueeze(2) key = key.unsqueeze(2) value = value.unsqueeze(2) qkv = torch.cat([query, key, value], dim=2) - qkv = qkv.transpose(1,3) - + qkv = qkv.transpose(1, 3) + # 5. 
Attention - attention_mask = attention_mask[:,0,:] + attention_mask = attention_mask[:, 0, :] seq_len = qkv.shape[1] attn_len = attention_mask.shape[1] - attention_mask = F.pad(attention_mask, (seq_len-attn_len, 0), value=True) + attention_mask = F.pad(attention_mask, (seq_len - attn_len, 0), + value=True) + + hidden_states = flash_attn_no_pad(qkv, + attention_mask, + causal=False, + dropout_p=0.0, + softmax_scale=None) - hidden_states = flash_attn_no_pad(qkv, attention_mask, causal=False, dropout_p=0.0, softmax_scale=None) - if get_sequence_parallel_state(): hidden_states, encoder_hidden_states = hidden_states.split_with_sizes( - (sequence_length * nccl_info.sp_size, encoder_sequence_length), dim=1 - ) - hidden_states = all_to_all_4D(hidden_states, scatter_dim=1, gather_dim=2) - encoder_hidden_states = all_gather( - encoder_hidden_states, dim=2 - ).contiguous() + (sequence_length * nccl_info.sp_size, encoder_sequence_length), + dim=1) + hidden_states = all_to_all_4D(hidden_states, + scatter_dim=1, + gather_dim=2) + encoder_hidden_states = all_gather(encoder_hidden_states, + dim=2).contiguous() hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.to(query.dtype) encoder_hidden_states = encoder_hidden_states.flatten(2, 3) @@ -171,10 +200,10 @@ def __call__( # 6. Output projection if encoder_hidden_states is not None: hidden_states, encoder_hidden_states = ( - hidden_states[:, : -encoder_hidden_states.shape[1]], - hidden_states[:, -encoder_hidden_states.shape[1] :], + hidden_states[:, :-encoder_hidden_states.shape[1]], + hidden_states[:, -encoder_hidden_states.shape[1]:], ) - + if encoder_hidden_states is not None: if getattr(attn, "to_out", None) is not None: hidden_states = attn.to_out[0](hidden_states) @@ -187,6 +216,7 @@ def __call__( class HunyuanVideoPatchEmbed(nn.Module): + def __init__( self, patch_size: Union[int, Tuple[int, int, int]] = 16, @@ -195,17 +225,25 @@ def __init__( ) -> None: super().__init__() - patch_size = (patch_size, patch_size, patch_size) if isinstance(patch_size, int) else patch_size - self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + patch_size = (patch_size, patch_size, patch_size) if isinstance( + patch_size, int) else patch_size + self.proj = nn.Conv3d(in_chans, + embed_dim, + kernel_size=patch_size, + stride=patch_size) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.proj(hidden_states) - hidden_states = hidden_states.flatten(2).transpose(1, 2) # BCFHW -> BNC + hidden_states = hidden_states.flatten(2).transpose(1, + 2) # BCFHW -> BNC return hidden_states class HunyuanVideoAdaNorm(nn.Module): - def __init__(self, in_features: int, out_features: Optional[int] = None) -> None: + + def __init__(self, + in_features: int, + out_features: Optional[int] = None) -> None: super().__init__() out_features = out_features or 2 * in_features @@ -214,7 +252,8 @@ def __init__(self, in_features: int, out_features: Optional[int] = None) -> None def forward( self, temb: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + torch.Tensor]: temb = self.linear(self.nonlinearity(temb)) gate_msa, gate_mlp = temb.chunk(2, dim=1) gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1) @@ -222,6 +261,7 @@ def forward( class HunyuanVideoIndividualTokenRefinerBlock(nn.Module): + def __init__( self, num_attention_heads: int, @@ -234,7 +274,9 @@ def __init__( hidden_size = 
num_attention_heads * attention_head_dim - self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) + self.norm1 = nn.LayerNorm(hidden_size, + elementwise_affine=True, + eps=1e-6) self.attn = Attention( query_dim=hidden_size, cross_attention_dim=None, @@ -243,8 +285,13 @@ def __init__( bias=attention_bias, ) - self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) - self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate) + self.norm2 = nn.LayerNorm(hidden_size, + elementwise_affine=True, + eps=1e-6) + self.ff = FeedForward(hidden_size, + mult=mlp_width_ratio, + activation_fn="linear-silu", + dropout=mlp_drop_rate) self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size) @@ -272,6 +319,7 @@ def forward( class HunyuanVideoIndividualTokenRefiner(nn.Module): + def __init__( self, num_attention_heads: int, @@ -283,18 +331,15 @@ def __init__( ) -> None: super().__init__() - self.refiner_blocks = nn.ModuleList( - [ - HunyuanVideoIndividualTokenRefinerBlock( - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - mlp_width_ratio=mlp_width_ratio, - mlp_drop_rate=mlp_drop_rate, - attention_bias=attention_bias, - ) - for _ in range(num_layers) - ] - ) + self.refiner_blocks = nn.ModuleList([ + HunyuanVideoIndividualTokenRefinerBlock( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + attention_bias=attention_bias, + ) for _ in range(num_layers) + ]) def forward( self, @@ -307,7 +352,9 @@ def forward( batch_size = attention_mask.shape[0] seq_len = attention_mask.shape[1] attention_mask = attention_mask.to(hidden_states.device).bool() - self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1) + self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, + seq_len).repeat( + 1, 1, seq_len, 1) self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool() self_attn_mask[:, :, :, 0] = True @@ -319,6 +366,7 @@ def forward( class HunyuanVideoTokenRefiner(nn.Module): + def __init__( self, in_channels: int, @@ -334,8 +382,7 @@ def __init__( hidden_size = num_attention_heads * attention_head_dim self.time_text_embed = CombinedTimestepTextProjEmbeddings( - embedding_dim=hidden_size, pooled_projection_dim=in_channels - ) + embedding_dim=hidden_size, pooled_projection_dim=in_channels) self.proj_in = nn.Linear(in_channels, hidden_size, bias=True) self.token_refiner = HunyuanVideoIndividualTokenRefiner( num_attention_heads=num_attention_heads, @@ -357,7 +404,8 @@ def forward( else: original_dtype = hidden_states.dtype mask_float = attention_mask.float().unsqueeze(-1) - pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1) + pooled_projections = (hidden_states * mask_float).sum( + dim=1) / mask_float.sum(dim=1) pooled_projections = pooled_projections.to(original_dtype) temb = self.time_text_embed(timestep, pooled_projections) @@ -368,7 +416,12 @@ def forward( class HunyuanVideoRotaryPosEmbed(nn.Module): - def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None: + + def __init__(self, + patch_size: int, + patch_size_t: int, + rope_dim: List[int], + theta: float = 256.0) -> None: super().__init__() self.patch_size = patch_size @@ -378,29 +431,41 @@ def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], thet def 
forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, num_channels, num_frames, height, width = hidden_states.shape - rope_sizes = [num_frames * nccl_info.sp_size // self.patch_size_t, height // self.patch_size, width // self.patch_size] + rope_sizes = [ + num_frames * nccl_info.sp_size // self.patch_size_t, + height // self.patch_size, width // self.patch_size + ] axes_grids = [] for i in range(3): # Note: The following line diverges from original behaviour. We create the grid on the device, whereas # original implementation creates it on CPU and then moves it to device. This results in numerical # differences in layerwise debugging outputs, but visually it is the same. - grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32) + grid = torch.arange(0, + rope_sizes[i], + device=hidden_states.device, + dtype=torch.float32) axes_grids.append(grid) grid = torch.meshgrid(*axes_grids, indexing="ij") # [W, H, T] grid = torch.stack(grid, dim=0) # [3, W, H, T] freqs = [] for i in range(3): - freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True) + freq = get_1d_rotary_pos_embed(self.rope_dim[i], + grid[i].reshape(-1), + self.theta, + use_real=True) freqs.append(freq) - freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2) - freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2) + freqs_cos = torch.cat([f[0] for f in freqs], + dim=1) # (W * H * T, D / 2) + freqs_sin = torch.cat([f[1] for f in freqs], + dim=1) # (W * H * T, D / 2) return freqs_cos, freqs_sin class HunyuanVideoSingleTransformerBlock(nn.Module): + def __init__( self, num_attention_heads: int, @@ -440,7 +505,8 @@ def forward( image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: text_seq_length = encoder_hidden_states.shape[1] - hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) + hidden_states = torch.cat([hidden_states, encoder_hidden_states], + dim=1) residual = hidden_states @@ -475,6 +541,7 @@ def forward( class HunyuanVideoTransformerBlock(nn.Module): + def __init__( self, num_attention_heads: int, @@ -487,7 +554,8 @@ def __init__( hidden_size = num_attention_heads * attention_head_dim self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm") - self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm") + self.norm1_context = AdaLayerNormZero(hidden_size, + norm_type="layer_norm") self.attn = Attention( query_dim=hidden_size, @@ -503,11 +571,19 @@ def __init__( eps=1e-6, ) - self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") + self.norm2 = nn.LayerNorm(hidden_size, + elementwise_affine=False, + eps=1e-6) + self.ff = FeedForward(hidden_size, + mult=mlp_ratio, + activation_fn="gelu-approximate") - self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") + self.norm2_context = nn.LayerNorm(hidden_size, + elementwise_affine=False, + eps=1e-6) + self.ff_context = FeedForward(hidden_size, + mult=mlp_ratio, + activation_fn="gelu-approximate") def forward( self, @@ -518,10 +594,10 @@ def forward( freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # 1. 
Input normalization - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, emb=temb) norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( - encoder_hidden_states, emb=temb - ) + encoder_hidden_states, emb=temb) # 2. Joint attention attn_output, context_attn_output = self.attn( @@ -533,25 +609,30 @@ def forward( # 3. Modulation and residual connection hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1) - encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze( + 1) norm_hidden_states = self.norm2(hidden_states) norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + norm_hidden_states = norm_hidden_states * ( + 1 + scale_mlp[:, None]) + shift_mlp[:, None] + norm_encoder_hidden_states = norm_encoder_hidden_states * ( + 1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] # 4. Feed-forward ff_output = self.ff(norm_hidden_states) context_ff_output = self.ff_context(norm_encoder_hidden_states) hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output - encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze( + 1) * context_ff_output return hidden_states, encoder_hidden_states -class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): +class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, + FromOriginalModelMixin): r""" A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo). @@ -594,23 +675,23 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, @register_to_config def __init__( - self, - in_channels: int = 16, - out_channels: int = 16, - num_attention_heads: int = 24, - attention_head_dim: int = 128, - num_layers: int = 20, - num_single_layers: int = 40, - num_refiner_layers: int = 2, - mlp_ratio: float = 4.0, - patch_size: int = 2, - patch_size_t: int = 1, - qk_norm: str = "rms_norm", - guidance_embeds: bool = True, - text_embed_dim: int = 4096, - pooled_projection_dim: int = 768, - rope_theta: float = 256.0, - rope_axes_dim: Tuple[int] = (16, 56, 56), + self, + in_channels: int = 16, + out_channels: int = 16, + num_attention_heads: int = 24, + attention_head_dim: int = 128, + num_layers: int = 20, + num_single_layers: int = 40, + num_refiner_layers: int = 2, + mlp_ratio: float = 4.0, + patch_size: int = 2, + patch_size_t: int = 1, + qk_norm: str = "rms_norm", + guidance_embeds: bool = True, + text_embed_dim: int = 4096, + pooled_projection_dim: int = 768, + rope_theta: float = 256.0, + rope_axes_dim: Tuple[int] = (16, 56, 56), ) -> None: super().__init__() @@ -618,38 +699,45 @@ def __init__( out_channels = out_channels or in_channels # 1. 
Latent and condition embedders - self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim) + self.x_embedder = HunyuanVideoPatchEmbed( + (patch_size_t, patch_size, patch_size), in_channels, inner_dim) self.context_embedder = HunyuanVideoTokenRefiner( - text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers - ) - self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim) + text_embed_dim, + num_attention_heads, + attention_head_dim, + num_layers=num_refiner_layers) + self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings( + inner_dim, pooled_projection_dim) # 2. RoPE - self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta) + self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, + rope_axes_dim, rope_theta) # 3. Dual stream transformer blocks - self.transformer_blocks = nn.ModuleList( - [ - HunyuanVideoTransformerBlock( - num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm - ) - for _ in range(num_layers) - ] - ) + self.transformer_blocks = nn.ModuleList([ + HunyuanVideoTransformerBlock(num_attention_heads, + attention_head_dim, + mlp_ratio=mlp_ratio, + qk_norm=qk_norm) + for _ in range(num_layers) + ]) # 4. Single stream transformer blocks - self.single_transformer_blocks = nn.ModuleList( - [ - HunyuanVideoSingleTransformerBlock( - num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm - ) - for _ in range(num_single_layers) - ] - ) + self.single_transformer_blocks = nn.ModuleList([ + HunyuanVideoSingleTransformerBlock(num_attention_heads, + attention_head_dim, + mlp_ratio=mlp_ratio, + qk_norm=qk_norm) + for _ in range(num_single_layers) + ]) # 5. Output projection - self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6) - self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels) + self.norm_out = AdaLayerNormContinuous(inner_dim, + inner_dim, + elementwise_affine=False, + eps=1e-6) + self.proj_out = nn.Linear( + inner_dim, patch_size_t * patch_size * patch_size * out_channels) self.gradient_checkpointing = False @@ -664,12 +752,15 @@ def attn_processors(self) -> Dict[str, AttentionProcessor]: # set recursively processors = {} - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + def fn_recursive_add_processors(name: str, module: torch.nn.Module, + processors: Dict[str, + AttentionProcessor]): if hasattr(module, "get_processor"): processors[f"{name}.processor"] = module.get_processor() for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + fn_recursive_add_processors(f"{name}.{sub_name}", child, + processors) return processors @@ -679,7 +770,9 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + def set_attn_processor(self, processor: Union[AttentionProcessor, + Dict[str, + AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -700,7 +793,8 @@ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, Atte f" number of attention layers: {count}. 
Please make sure to pass {count} processor classes." ) - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, + processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor) @@ -708,7 +802,8 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + fn_recursive_attn_processor(f"{name}.{sub_name}", child, + processor) for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) @@ -727,10 +822,10 @@ def forward( attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: - if guidance == None: - guidance = torch.tensor( - [6016.0], device=hidden_states.device, dtype=torch.bfloat16 - ) + if guidance is None: + guidance = torch.tensor([6016.0], + device=hidden_states.device, + dtype=torch.bfloat16) if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() @@ -742,7 +837,8 @@ def forward( # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + if attention_kwargs is not None and attention_kwargs.get( + "scale", None) is not None: logger.warning( "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." ) @@ -753,35 +849,43 @@ def forward( post_patch_height = height // p post_patch_width = width // p - pooled_projections = encoder_hidden_states[:, 0, : self.config.pooled_projection_dim] + pooled_projections = encoder_hidden_states[:, 0, :self.config. + pooled_projection_dim] encoder_hidden_states = encoder_hidden_states[:, 1:] - + # 1. RoPE image_rotary_emb = self.rope(hidden_states) # 2. Conditional embeddings temb = self.time_text_embed(timestep, guidance, pooled_projections) hidden_states = self.x_embedder(hidden_states) - encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask) + encoder_hidden_states = self.context_embedder(encoder_hidden_states, + timestep, + encoder_attention_mask) # 3. Attention mask preparation latent_sequence_length = hidden_states.shape[1] condition_sequence_length = encoder_hidden_states.shape[1] sequence_length = latent_sequence_length + condition_sequence_length - attention_mask = torch.zeros( - batch_size, sequence_length, sequence_length, device=hidden_states.device, dtype=torch.bool - ) # [B, N, N] - - effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) + attention_mask = torch.zeros(batch_size, + sequence_length, + sequence_length, + device=hidden_states.device, + dtype=torch.bool) # [B, N, N] + + effective_condition_sequence_length = encoder_attention_mask.sum( + dim=1, dtype=torch.int) effective_sequence_length = latent_sequence_length + effective_condition_sequence_length for i in range(batch_size): - attention_mask[i, : effective_sequence_length[i], : effective_sequence_length[i]] = True + attention_mask[i, :effective_sequence_length[i], : + effective_sequence_length[i]] = True # 4. 
Transformer blocks if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) @@ -790,7 +894,9 @@ def custom_forward(*inputs): return custom_forward - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + ckpt_kwargs: Dict[str, Any] = { + "use_reentrant": False + } if is_torch_version(">=", "1.11.0") else {} for block in self.transformer_blocks: hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( @@ -817,21 +923,22 @@ def custom_forward(*inputs): else: for block in self.transformer_blocks: hidden_states, encoder_hidden_states = block( - hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb - ) + hidden_states, encoder_hidden_states, temb, attention_mask, + image_rotary_emb) for block in self.single_transformer_blocks: hidden_states, encoder_hidden_states = block( - hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb - ) + hidden_states, encoder_hidden_states, temb, attention_mask, + image_rotary_emb) # 5. Output projection hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states.reshape( - batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p - ) + hidden_states = hidden_states.reshape(batch_size, + post_patch_num_frames, + post_patch_height, + post_patch_width, -1, p_t, p, p) hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7) hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) @@ -840,6 +947,6 @@ def custom_forward(*inputs): unscale_lora_layers(self, lora_scale) if not return_dict: - return (hidden_states,) + return (hidden_states, ) return Transformer2DModelOutput(sample=hidden_states) diff --git a/fastvideo/models/hunyuan_hf/pipeline_hunyuan.py b/fastvideo/models/hunyuan_hf/pipeline_hunyuan.py index 2a636920..18329e4d 100644 --- a/fastvideo/models/hunyuan_hf/pipeline_hunyuan.py +++ b/fastvideo/models/hunyuan_hf/pipeline_hunyuan.py @@ -14,23 +14,28 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import torch.nn.functional as F + import numpy as np import torch -from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast - +import torch.nn.functional as F from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback from diffusers.loaders import HunyuanVideoLoraLoaderMixin -from diffusers.models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel +from diffusers.models import (AutoencoderKLHunyuanVideo, + HunyuanVideoTransformer3DModel) +from diffusers.pipelines.hunyuan_video.pipeline_output import \ + HunyuanVideoPipelineOutput +from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from diffusers.utils import logging, replace_example_docstring from diffusers.utils.torch_utils import randn_tensor from diffusers.video_processor import VideoProcessor -from diffusers.pipelines.pipeline_utils import DiffusionPipeline -from diffusers.pipelines.hunyuan_video.pipeline_output import HunyuanVideoPipelineOutput -from fastvideo.utils.parallel_states import get_sequence_parallel_state, nccl_info from einops import rearrange +from transformers import (CLIPTextModel, CLIPTokenizer, LlamaModel, + LlamaTokenizerFast) + from 
fastvideo.utils.communications import all_gather +from fastvideo.utils.parallel_states import (get_sequence_parallel_state, + nccl_info) logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -60,18 +65,17 @@ ``` """ - DEFAULT_PROMPT_TEMPLATE = { - "template": ( - "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " - "1. The main content and theme of the video." - "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." - "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." - "4. background environment, light, style and atmosphere." - "5. camera angles, movements, and transitions used in the video:<|eot_id|>" - "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" - ), - "crop_start": 95, + "template": + ("<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. background environment, light, style and atmosphere." + "5. camera angles, movements, and transitions used in the video:<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"), + "crop_start": + 95, } @@ -108,9 +112,12 @@ def retrieve_timesteps( second element is the number of inference steps. """ if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + raise ValueError( + "Only one of `timesteps` or `sigmas` can be passed. 
Please choose one to set custom values" + ) if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + accepts_timesteps = "timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" @@ -120,7 +127,8 @@ def retrieve_timesteps( timesteps = scheduler.timesteps num_inference_steps = len(timesteps) elif sigmas is not None: - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + accept_sigmas = "sigmas" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accept_sigmas: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" @@ -186,13 +194,14 @@ def __init__( tokenizer_2=tokenizer_2, ) - self.vae_scale_factor_temporal = ( - self.vae.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 - ) - self.vae_scale_factor_spatial = ( - self.vae.spatial_compression_ratio if hasattr(self, "vae") and self.vae is not None else 8 - ) - self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.vae_scale_factor_temporal = (self.vae.temporal_compression_ratio + if hasattr(self, "vae") + and self.vae is not None else 4) + self.vae_scale_factor_spatial = (self.vae.spatial_compression_ratio + if hasattr(self, "vae") + and self.vae is not None else 8) + self.video_processor = VideoProcessor( + vae_scale_factor=self.vae_scale_factor_spatial) def _get_llama_prompt_embeds( self, @@ -254,9 +263,12 @@ def _get_llama_prompt_embeds( # duplicate text embeddings for each generation per prompt, using mps friendly method _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt) - prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, + seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat( + 1, num_videos_per_prompt) + prompt_attention_mask = prompt_attention_mask.view( + batch_size * num_videos_per_prompt, seq_len) return prompt_embeds, prompt_attention_mask @@ -283,19 +295,25 @@ def _get_clip_prompt_embeds( ) text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + untruncated_ids = self.tokenizer_2(prompt, + padding="longest", + return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[ + -1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode( + untruncated_ids[:, max_sequence_length - 1:-1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {max_sequence_length} tokens: {removed_text}" - ) + f" {max_sequence_length} tokens: {removed_text}") - prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output + 
prompt_embeds = self.text_encoder_2( + text_input_ids.to(device), + output_hidden_states=False).pooler_output # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt) - prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, + -1) return prompt_embeds @@ -347,11 +365,13 @@ def check_inputs( prompt_template=None, ): if height % 16 != 0 or width % 16 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + raise ValueError( + f"`height` and `width` have to be divisible by 16 but are {height} and {width}." + ) if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs - ): + k in self._callback_tensor_inputs + for k in callback_on_step_end_tensor_inputs): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) @@ -359,25 +379,31 @@ def check_inputs( if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) + " only forward one of the two.") elif prompt_2 is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) + " only forward one of the two.") elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): - raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + elif prompt is not None and (not isinstance(prompt, str) + and not isinstance(prompt, list)): + raise ValueError( + f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" + ) + elif prompt_2 is not None and (not isinstance(prompt_2, str) + and not isinstance(prompt_2, list)): + raise ValueError( + f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}" + ) if prompt_template is not None: if not isinstance(prompt_template, dict): - raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}") + raise ValueError( + f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}" + ) if "template" not in prompt_template: raise ValueError( f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}" @@ -392,7 +418,8 @@ def prepare_latents( num_frames: int = 129, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + generator: Optional[Union[torch.Generator, + List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: if latents is not None: @@ -411,7 +438,10 @@ def prepare_latents( f" size of {batch_size}. 
Make sure the batch size matches the length of the generators." ) - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = randn_tensor(shape, + generator=generator, + device=device, + dtype=dtype) return latents def enable_vae_slicing(self): @@ -472,7 +502,8 @@ def __call__( sigmas: List[float] = None, guidance_scale: float = 6.0, num_videos_per_prompt: Optional[int] = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + generator: Optional[Union[torch.Generator, + List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, pooled_prompt_embeds: Optional[torch.Tensor] = None, @@ -480,9 +511,9 @@ def __call__( output_type: Optional[str] = "pil", return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, - callback_on_step_end: Optional[ - Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] - ] = None, + callback_on_step_end: Optional[Union[Callable[[int, int, Dict], + None], PipelineCallback, + MultiPipelineCallbacks]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, max_sequence_length: int = 256, @@ -560,9 +591,10 @@ def __call__( indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ - if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + if isinstance(callback_on_step_end, + (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - + # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, @@ -608,7 +640,8 @@ def __call__( pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype) # 4. Prepare timesteps - sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + sigmas = np.linspace(1.0, 0.0, num_inference_steps + + 1)[:-1] if sigmas is None else sigmas timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, @@ -618,7 +651,8 @@ def __call__( # 5. Prepare latent variables num_channels_latents = self.transformer.config.in_channels - num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + num_latent_frames = (num_frames - + 1) // self.vae_scale_factor_temporal + 1 latents = self.prepare_latents( batch_size * num_videos_per_prompt, @@ -634,16 +668,19 @@ def __call__( # check sequence_parallel world_size, rank = nccl_info.sp_size, nccl_info.rank_within_group if get_sequence_parallel_state(): - latents = rearrange( - latents, "b t (n s) h w -> b t n s h w", n=world_size - ).contiguous() + latents = rearrange(latents, + "b t (n s) h w -> b t n s h w", + n=world_size).contiguous() latents = latents[:, :, rank, :, :, :] - + # 6. Prepare guidance condition - guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0 + guidance = torch.tensor([guidance_scale] * latents.shape[0], + dtype=transformer_dtype, + device=device) * 1000.0 # 7. 
Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + num_warmup_steps = len( + timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -657,16 +694,17 @@ def __call__( if pooled_prompt_embeds.shape[-1] != prompt_embeds.shape[-1]: pooled_prompt_embeds_padding = F.pad( pooled_prompt_embeds, - (0, prompt_embeds.shape[2] - pooled_prompt_embeds.shape[1]), + (0, prompt_embeds.shape[2] - + pooled_prompt_embeds.shape[1]), value=0, ).unsqueeze(1) encoder_hidden_states = torch.cat( - [pooled_prompt_embeds_padding, prompt_embeds], dim=1 - ) - + [pooled_prompt_embeds_padding, prompt_embeds], dim=1) + noise_pred = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=encoder_hidden_states, # [1, 257, 4096] + hidden_states=latent_model_input, + encoder_hidden_states= + encoder_hidden_states, # [1, 257, 4096] timestep=timestep, encoder_attention_mask=prompt_attention_mask, guidance=guidance, @@ -675,28 +713,37 @@ def __call__( )[0] # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + latents = self.scheduler.step(noise_pred, + t, + latents, + return_dict=False)[0] if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + callback_outputs = callback_on_step_end( + self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + prompt_embeds = callback_outputs.pop( + "prompt_embeds", prompt_embeds) # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and + (i + 1) % self.scheduler.order == 0): progress_bar.update() if get_sequence_parallel_state(): latents = all_gather(latents, dim=2) if not output_type == "latent": - latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor + latents = latents.to( + self.vae.dtype) / self.vae.config.scaling_factor video = self.vae.decode(latents, return_dict=False)[0] - video = self.video_processor.postprocess_video(video, output_type=output_type) + video = self.video_processor.postprocess_video( + video, output_type=output_type) else: video = latents @@ -704,6 +751,6 @@ def __call__( self.maybe_free_model_hooks() if not return_dict: - return (video,) + return (video, ) return HunyuanVideoPipelineOutput(frames=video) diff --git a/fastvideo/train.py b/fastvideo/train.py index fb1dbc33..c5b02845 100644 --- a/fastvideo/train.py +++ b/fastvideo/train.py @@ -215,7 +215,7 @@ def main(args): if args.use_lora: assert args.model_type != "hunyuan", "LoRA is only supported for huggingface model. 
Please use hunyuan_hf for lora finetuning" - if args.model_type == "mochi": + if args.model_type == "mochi": pipe = MochiPipeline elif args.model_type == "hunyuan_hf": pipe = HunyuanVideoPipeline @@ -230,8 +230,7 @@ def main(args): if args.resume_from_lora_checkpoint: lora_state_dict = pipe.lora_state_dict( - args.resume_from_lora_checkpoint - ) + args.resume_from_lora_checkpoint) transformer_state_dict = { f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") @@ -445,21 +444,24 @@ def main(args): if step % args.checkpointing_steps == 0: if args.use_lora: # Save LoRA weights - save_lora_checkpoint( - transformer, optimizer, rank, args.output_dir, step, pipe - ) + save_lora_checkpoint(transformer, optimizer, rank, + args.output_dir, step, pipe) else: # Your existing checkpoint saving code save_checkpoint(transformer, optimizer, rank, args.output_dir, step) dist.barrier() if args.log_validation and step % args.validation_steps == 0: - log_validation(args, transformer, device, torch.bfloat16, step, shift=args.shift) + log_validation(args, + transformer, + device, + torch.bfloat16, + step, + shift=args.shift) if args.use_lora: - save_lora_checkpoint( - transformer, optimizer, rank, args.output_dir, args.max_train_steps, pipe - ) + save_lora_checkpoint(transformer, optimizer, rank, args.output_dir, + args.max_train_steps, pipe) else: save_checkpoint(transformer, optimizer, rank, args.output_dir, args.max_train_steps) @@ -471,7 +473,11 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( - "--model_type", type=str, default="mochi", help="The type of model to train. Currentlt support [mochi, hunyuan_hf, hunyuan]" + "--model_type", + type=str, + default="mochi", + help= + "The type of model to train. 
Currentlt support [mochi, hunyuan_hf, hunyuan]" ) # dataset & dataloader parser.add_argument("--data_json_path", type=str, required=True) @@ -557,7 +563,10 @@ def main(args): " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" " training using `--resume_from_checkpoint`."), ) - parser.add_argument("--shift", type=float, default=1.0, help=("Set shift to 7 for hunyuan model.")) + parser.add_argument("--shift", + type=float, + default=1.0, + help=("Set shift to 7 for hunyuan model.")) parser.add_argument( "--resume_from_checkpoint", type=str, @@ -744,4 +753,3 @@ def main(args): args = parser.parse_args() main(args) - \ No newline at end of file diff --git a/fastvideo/utils/checkpoint.py b/fastvideo/utils/checkpoint.py index fbf93ce8..60ae1ba5 100644 --- a/fastvideo/utils/checkpoint.py +++ b/fastvideo/utils/checkpoint.py @@ -15,7 +15,6 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType -from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline from fastvideo.utils.logging_ import main_print @@ -68,6 +67,7 @@ def save_checkpoint_optimizer(model, torch.save(optim_state, optimizer_path) main_print(f"--> checkpoint saved at step {step}") + def save_checkpoint(transformer, rank, output_dir, step): main_print(f"--> saving checkpoint at step {step}") with FSDP.state_dict_type( @@ -260,7 +260,8 @@ def resume_training(model, optimizer, checkpoint_dir, discriminator=False): return model, optimizer, step -def save_lora_checkpoint(transformer, optimizer, rank, output_dir, step, pipeline): +def save_lora_checkpoint(transformer, optimizer, rank, output_dir, step, + pipeline): with FSDP.state_dict_type( transformer, StateDictType.FULL_STATE_DICT, @@ -282,8 +283,7 @@ def save_lora_checkpoint(transformer, optimizer, rank, output_dir, step, pipelin # save lora weight main_print(f"--> saving LoRA checkpoint at step {step}") transformer_lora_layers = get_peft_model_state_dict( - model=transformer, state_dict=full_state_dict - ) + model=transformer, state_dict=full_state_dict) pipeline.save_lora_weights( save_directory=save_dir, transformer_lora_layers=transformer_lora_layers, diff --git a/fastvideo/utils/load.py b/fastvideo/utils/load.py index a1d8792e..fe748cf9 100644 --- a/fastvideo/utils/load.py +++ b/fastvideo/utils/load.py @@ -1,9 +1,9 @@ - import os from pathlib import Path + import torch import torch.nn.functional as F -from diffusers import AutoencoderKLMochi, AutoencoderKLHunyuanVideo +from diffusers import AutoencoderKLHunyuanVideo, AutoencoderKLMochi from torch import nn from transformers import AutoTokenizer, T5EncoderModel @@ -12,15 +12,11 @@ from fastvideo.models.hunyuan.text_encoder import TextEncoder from fastvideo.models.hunyuan.vae.autoencoder_kl_causal_3d import \ AutoencoderKLCausal3D -from fastvideo.models.mochi_hf.modeling_mochi import ( - MochiTransformer3DModel, - MochiTransformerBlock, -) from fastvideo.models.hunyuan_hf.modeling_hunyuan import ( - HunyuanVideoTransformer3DModel, - HunyuanVideoSingleTransformerBlock, - HunyuanVideoTransformerBlock, -) + HunyuanVideoSingleTransformerBlock, HunyuanVideoTransformer3DModel, + HunyuanVideoTransformerBlock) +from fastvideo.models.mochi_hf.modeling_mochi import (MochiTransformer3DModel, + MochiTransformerBlock) from fastvideo.utils.logging_ import main_print hunyuan_config = { @@ -312,8 +308,9 @@ def load_vae(model_type, pretrained_model_name_or_path): fps = 30 elif model_type == "hunyuan_hf": vae = 
AutoencoderKLHunyuanVideo.from_pretrained( - pretrained_model_name_or_path, subfolder="vae", torch_dtype=weight_dtype - ).to("cuda") + pretrained_model_name_or_path, + subfolder="vae", + torch_dtype=weight_dtype).to("cuda") autocast_type = torch.bfloat16 fps = 24 elif model_type == "hunyuan": @@ -347,9 +344,11 @@ def load_vae(model_type, pretrained_model_name_or_path): def load_text_encoder(model_type, pretrained_model_name_or_path, device): if model_type == "mochi": - text_encoder = MochiTextEncoderWrapper(pretrained_model_name_or_path, device) + text_encoder = MochiTextEncoderWrapper(pretrained_model_name_or_path, + device) elif model_type == "hunyuan" or "hunyuan_hf": - text_encoder = HunyuanTextEncoderWrapper(pretrained_model_name_or_path, device) + text_encoder = HunyuanTextEncoderWrapper(pretrained_model_name_or_path, + device) else: raise ValueError(f"Unsupported model type: {model_type}") return text_encoder @@ -358,9 +357,10 @@ def load_text_encoder(model_type, pretrained_model_name_or_path, device): def get_no_split_modules(transformer): # if of type MochiTransformer3DModel if isinstance(transformer, MochiTransformer3DModel): - return (MochiTransformerBlock,) + return (MochiTransformerBlock, ) elif isinstance(transformer, HunyuanVideoTransformer3DModel): - return (HunyuanVideoSingleTransformerBlock, HunyuanVideoTransformerBlock) + return (HunyuanVideoSingleTransformerBlock, + HunyuanVideoTransformerBlock) elif isinstance(transformer, HYVideoDiffusionTransformer): return (MMDoubleStreamBlock, MMSingleStreamBlock) else: diff --git a/fastvideo/utils/validation.py b/fastvideo/utils/validation.py index fb506285..42294e7e 100644 --- a/fastvideo/utils/validation.py +++ b/fastvideo/utils/validation.py @@ -4,7 +4,6 @@ import numpy as np import torch -import wandb from diffusers import FlowMatchEulerDiscreteScheduler from diffusers.utils import export_to_video from diffusers.utils.torch_utils import randn_tensor @@ -12,6 +11,7 @@ from einops import rearrange from tqdm import tqdm +import wandb from fastvideo.distill.solver import PCMFMScheduler from fastvideo.models.mochi_hf.pipeline_mochi import ( linear_quadratic_schedule, retrieve_timesteps) @@ -106,7 +106,7 @@ def sample_validation_video( threshold_noise = 0.025 sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise) sigmas = np.array(sigmas) - if scheduler_type == "euler" and model_type == "mochi": #todo + if scheduler_type == "euler" and model_type == "mochi": #todo timesteps, num_inference_steps = retrieve_timesteps( scheduler, num_inference_steps, From 3f2fc1ab5fd71b9f83cd8540b29fc5d762303248 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CBrianChen1129=E2=80=9D?= Date: Sun, 12 Jan 2025 21:57:05 +0000 Subject: [PATCH 25/42] fix train.py --- fastvideo/train.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/fastvideo/train.py b/fastvideo/train.py index c5b02845..65de715e 100644 --- a/fastvideo/train.py +++ b/fastvideo/train.py @@ -145,13 +145,16 @@ def train_one_step( dtype=latents.dtype, ) noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise - + training_guidance = torch.tensor([1000.0], + device=noisy_model_input.device, + dtype=torch.bfloat16) with torch.autocast("cuda", torch.bfloat16): model_pred = transformer( noisy_model_input, encoder_hidden_states, timesteps, encoder_attention_mask, # B, L + training_guidance, return_dict=False, )[0] @@ -448,7 +451,7 @@ def main(args): args.output_dir, step, pipe) else: # Your existing checkpoint saving code - 
save_checkpoint(transformer, optimizer, rank, args.output_dir, + save_checkpoint(transformer, rank, args.output_dir, step) dist.barrier() if args.log_validation and step % args.validation_steps == 0: @@ -463,7 +466,7 @@ def main(args): save_lora_checkpoint(transformer, optimizer, rank, args.output_dir, args.max_train_steps, pipe) else: - save_checkpoint(transformer, optimizer, rank, args.output_dir, + save_checkpoint(transformer, rank, args.output_dir, args.max_train_steps) if get_sequence_parallel_state(): From 32a7bb3e27eded78a58dff27117f5738b94a3b66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CBrianChen1129=E2=80=9D?= Date: Sun, 12 Jan 2025 22:12:54 +0000 Subject: [PATCH 26/42] update README --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 15156eb4..a85c8c0a 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ We now support NF4 and LLM-INT8 quantized inference using BitsAndBytes for FastH # Download the model weight python scripts/huggingface/download_hf.py --repo_id=FastVideo/FastHunyuan-diffusers --local_dir=data/FastHunyuan-diffusers --repo_type=model # CLI inference -bash scripts/inference/inference_diffusers_hunyuan.sh +bash scripts/inference/inference_hunyuan_hf_quantization.sh ``` For more information about the VRAM requirements for BitsAndBytes quantization, please refer to the table below (timing measured on an H100 GPU): @@ -124,8 +124,9 @@ bash scripts/finetune/finetune_mochi.sh # for mochi ``` **Note that for finetuning, we did not tune the hyperparameters in the provided script** ### ⚡ Lora Finetune -Currently, we only provide Lora Finetune for Mochi model, the command for Lora Finetune is +Currently, both Mochi and Huyuan model support Lora Finetune through diffusers. 
``` +bash scripts/finetune/finetune_hunyuan_hf_lora.sh bash scripts/finetune/finetune_mochi_lora.sh ``` ### Minimum Hardware Requirement From 5af5b6ac00bbd65450cec8748b23e00f47c611db Mon Sep 17 00:00:00 2001 From: rlsu9 Date: Mon, 13 Jan 2025 01:00:15 +0000 Subject: [PATCH 27/42] add dataset preparation scripts --- .../dataset_preparation/prepare_json_file.py | 133 +++++++++++++++ scripts/dataset_preparation/resize_videos.py | 158 ++++++++++++++++++ 2 files changed, 291 insertions(+) create mode 100644 scripts/dataset_preparation/prepare_json_file.py create mode 100644 scripts/dataset_preparation/resize_videos.py diff --git a/scripts/dataset_preparation/prepare_json_file.py b/scripts/dataset_preparation/prepare_json_file.py new file mode 100644 index 00000000..8433666d --- /dev/null +++ b/scripts/dataset_preparation/prepare_json_file.py @@ -0,0 +1,133 @@ +import os +import json +import cv2 +from pathlib import Path + +def get_video_info(video_path, prompt_text): + """Extract video information using OpenCV and corresponding prompt text""" + cap = cv2.VideoCapture(str(video_path)) + + if not cap.isOpened(): + print(f"Error: Could not open video {video_path}") + return None + + # Get video properties + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + fps = cap.get(cv2.CAP_PROP_FPS) + frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + duration = frame_count / fps if fps > 0 else 0 + + cap.release() + + return { + "path": video_path.name, + "resolution": { + "width": width, + "height": height + }, + "fps": fps, + "duration": duration, + "cap": [prompt_text] + } + +def read_prompt_file(prompt_path): + """Read and return the content of a prompt file""" + try: + with open(prompt_path, 'r', encoding='utf-8') as f: + return f.read().strip() + except Exception as e: + print(f"Error reading prompt file {prompt_path}: {e}") + return None + +def process_videos_and_prompts(video_dir_path, prompt_dir_path, verbose=False): + """Process videos and their corresponding prompt files + + Args: + video_dir_path (str): Path to directory containing video files + prompt_dir_path (str): Path to directory containing prompt files + verbose (bool): Whether to print verbose processing information + """ + video_dir = Path(video_dir_path) + prompt_dir = Path(prompt_dir_path) + processed_data = [] + + # Ensure directories exist + if not video_dir.exists() or not prompt_dir.exists(): + print(f"Error: One or both directories do not exist:\nVideos: {video_dir}\nPrompts: {prompt_dir}") + return [] + + # Process each video file + for video_file in video_dir.glob('*.mp4'): + video_name = video_file.stem + prompt_file = prompt_dir / f"{video_name}.txt" + + # Check if corresponding prompt file exists + if not prompt_file.exists(): + print(f"Warning: No prompt file found for video {video_name}") + continue + + # Read prompt content + prompt_text = read_prompt_file(prompt_file) + if prompt_text is None: + continue + + # Process video and add to results + video_info = get_video_info(video_file, prompt_text) + if video_info: + processed_data.append(video_info) + + return processed_data + +def save_results(processed_data, output_path): + """Save processed data to JSON file + + Args: + processed_data (list): List of processed video information + output_path (str): Full path for output JSON file + """ + output_path = Path(output_path) + + # Create parent directories if they don't exist + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, 'w', 
encoding='utf-8') as f: + json.dump(processed_data, f, indent=2, ensure_ascii=False) + + return output_path + +def parse_args(): + """Parse command line arguments""" + import argparse + + parser = argparse.ArgumentParser(description='Process videos and their corresponding prompt files') + parser.add_argument('--video_dir', '-v', required=True, + help='Directory containing video files') + parser.add_argument('--prompt_dir', '-p', required=True, + help='Directory containing prompt text files') + parser.add_argument('--output_path', '-o', required=True, + help='Full path for output JSON file (e.g., /path/to/output/videos2caption.json)') + parser.add_argument('--verbose', action='store_true', + help='Print verbose processing information') + + return parser.parse_args() + +if __name__ == "__main__": + # Parse command line arguments + args = parse_args() + + # Process videos and prompts + processed_videos = process_videos_and_prompts(args.video_dir, args.prompt_dir, args.verbose) + + if processed_videos: + # Save results + output_path = save_results(processed_videos, args.output_path) + + print(f"\nProcessed {len(processed_videos)} videos") + print(f"Results saved to: {output_path}") + + # Print example of processed data + print("\nExample of processed video info:") + print(json.dumps(processed_videos[0], indent=2)) + else: + print("No videos were processed successfully") \ No newline at end of file diff --git a/scripts/dataset_preparation/resize_videos.py b/scripts/dataset_preparation/resize_videos.py new file mode 100644 index 00000000..321e4999 --- /dev/null +++ b/scripts/dataset_preparation/resize_videos.py @@ -0,0 +1,158 @@ +import argparse +from pathlib import Path +import time +import logging +from tqdm import tqdm +from moviepy.editor import VideoFileClip +import numpy as np +from skimage.transform import resize +from concurrent.futures import ProcessPoolExecutor, as_completed +import multiprocessing +import random +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + handlers=[logging.FileHandler('video_processing.log')] +) + +def is_16_9_ratio(width: int, height: int, tolerance: float = 0.1) -> bool: + target_ratio = 16 / 9 + actual_ratio = width / height + return abs(actual_ratio - target_ratio) <= (target_ratio * tolerance) + +def resize_video(args_tuple): + """ + Resize a single video file. 
+ args_tuple: (input_file, output_dir, width, height, fps) + """ + input_file, output_dir, width, height, fps = args_tuple + video = None + resized = None + try: + output_file = output_dir / f"{input_file.name}" + + if output_file.exists(): + output_file.unlink() + + video = VideoFileClip(str(input_file)) + + if not is_16_9_ratio(video.w, video.h): + return (input_file.name, "skipped", "Not 16:9") + + def process_frame(frame): + frame_float = frame.astype(float) / 255.0 + resized = resize(frame_float, + (height, width, 3), + mode='reflect', + anti_aliasing=True, + preserve_range=True) + return (resized * 255).astype(np.uint8) + + resized = video.fl_image(process_frame) + resized = resized.set_fps(fps) + + resized.write_videofile( + str(output_file), + codec='libx264', + audio_codec='aac', + temp_audiofile=f'temp-audio-{input_file.stem}.m4a', + remove_temp=True, + verbose=False, + logger=None, + fps=fps + ) + + return (input_file.name, "success", None) + + except Exception as e: + return (input_file.name, "failed", str(e)) + finally: + try: + if video is not None: + video.close() + if resized is not None: + resized.close() + except: + pass + +def process_folder(args): + input_path = Path(args.input_dir) + output_path = Path(args.output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.webm'} + video_files = [f for f in input_path.iterdir() if f.is_file() and f.suffix.lower() in video_extensions] + + if not video_files: + print(f"No video files found in {args.input_dir}") + return + + print(f"Found {len(video_files)} videos") + print(f"Target: {args.width}x{args.height} at {args.fps}fps") + + # Prepare arguments for parallel processing + process_args = [ + (video_file, output_path, args.width, args.height, args.fps) + for video_file in video_files + ] + + successful = 0 + skipped = 0 + failed = [] + + # Use ProcessPoolExecutor instead of ThreadPoolExecutor + with tqdm(total=len(video_files), desc="Converting videos", dynamic_ncols=True) as pbar: + # Use max_workers as specified or default to CPU count + max_workers = args.max_workers + with ProcessPoolExecutor(max_workers=max_workers) as executor: + # Submit all tasks + future_to_file = { + executor.submit(resize_video, arg): arg[0] + for arg in process_args + } + + # Process completed tasks + for future in as_completed(future_to_file): + filename, status, message = future.result() + if status == "success": + successful += 1 + elif status == "skipped": + skipped += 1 + else: + failed.append((filename, message)) + pbar.update(1) + + # Print final summary + print(f"\nDone! 
Processed: {successful}, Skipped: {skipped}, Failed: {len(failed)}") + if failed: + print("Failed files:") + for fname, error in failed: + print(f"- {fname}: {error}") + +def parse_args(): + parser = argparse.ArgumentParser(description='Batch resize videos to specified resolution and FPS (16:9 only)') + parser.add_argument('--input_dir', required=True, help='Input directory containing video files') + parser.add_argument('--output_dir', required=True, help='Output directory for processed videos') + parser.add_argument('--width', type=int, default=1280, help='Target width in pixels (default: 848)') + parser.add_argument('--height', type=int, default=720, help='Target height in pixels (default: 480)') + parser.add_argument('--fps', type=int, default=30, help='Target frames per second (default: 30)') + parser.add_argument('--max_workers', type=int, default=4, help='Maximum number of concurrent processes (default: 4)') + parser.add_argument('--log-level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], default='INFO', help='Set the logging level (default: INFO)') + return parser.parse_args() + +def main(): + args = parse_args() + logging.getLogger().setLevel(getattr(logging, args.log_level)) + + if not Path(args.input_dir).exists(): + logging.error(f"Input directory not found: {args.input_dir}") + return + + start_time = time.time() + process_folder(args) + duration = time.time() - start_time + logging.info(f"Batch processing completed in {duration:.2f} seconds") + +if __name__ == "__main__": + main() From 54d1e5d1221678f136e31598c45523c435a13691 Mon Sep 17 00:00:00 2001 From: rlsu9 Date: Mon, 13 Jan 2025 01:14:24 +0000 Subject: [PATCH 28/42] format --- fastvideo/train.py | 7 +- .../dataset_preparation/prepare_json_file.py | 87 +++++---- scripts/dataset_preparation/resize_videos.py | 180 ++++++++++-------- 3 files changed, 155 insertions(+), 119 deletions(-) diff --git a/fastvideo/train.py b/fastvideo/train.py index 65de715e..c3c8ebee 100644 --- a/fastvideo/train.py +++ b/fastvideo/train.py @@ -146,8 +146,8 @@ def train_one_step( ) noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise training_guidance = torch.tensor([1000.0], - device=noisy_model_input.device, - dtype=torch.bfloat16) + device=noisy_model_input.device, + dtype=torch.bfloat16) with torch.autocast("cuda", torch.bfloat16): model_pred = transformer( noisy_model_input, @@ -451,8 +451,7 @@ def main(args): args.output_dir, step, pipe) else: # Your existing checkpoint saving code - save_checkpoint(transformer, rank, args.output_dir, - step) + save_checkpoint(transformer, rank, args.output_dir, step) dist.barrier() if args.log_validation and step % args.validation_steps == 0: log_validation(args, diff --git a/scripts/dataset_preparation/prepare_json_file.py b/scripts/dataset_preparation/prepare_json_file.py index 8433666d..3a9ca8d0 100644 --- a/scripts/dataset_preparation/prepare_json_file.py +++ b/scripts/dataset_preparation/prepare_json_file.py @@ -1,25 +1,26 @@ -import os import json -import cv2 from pathlib import Path +import cv2 + + def get_video_info(video_path, prompt_text): """Extract video information using OpenCV and corresponding prompt text""" cap = cv2.VideoCapture(str(video_path)) - + if not cap.isOpened(): print(f"Error: Could not open video {video_path}") return None - + # Get video properties width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) fps = cap.get(cv2.CAP_PROP_FPS) frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) duration = frame_count / fps if fps > 0 
else 0 - + cap.release() - + return { "path": video_path.name, "resolution": { @@ -31,6 +32,7 @@ def get_video_info(video_path, prompt_text): "cap": [prompt_text] } + def read_prompt_file(prompt_path): """Read and return the content of a prompt file""" try: @@ -40,6 +42,7 @@ def read_prompt_file(prompt_path): print(f"Error reading prompt file {prompt_path}: {e}") return None + def process_videos_and_prompts(video_dir_path, prompt_dir_path, verbose=False): """Process videos and their corresponding prompt files @@ -51,34 +54,37 @@ def process_videos_and_prompts(video_dir_path, prompt_dir_path, verbose=False): video_dir = Path(video_dir_path) prompt_dir = Path(prompt_dir_path) processed_data = [] - + # Ensure directories exist if not video_dir.exists() or not prompt_dir.exists(): - print(f"Error: One or both directories do not exist:\nVideos: {video_dir}\nPrompts: {prompt_dir}") + print( + f"Error: One or both directories do not exist:\nVideos: {video_dir}\nPrompts: {prompt_dir}" + ) return [] - + # Process each video file for video_file in video_dir.glob('*.mp4'): video_name = video_file.stem prompt_file = prompt_dir / f"{video_name}.txt" - + # Check if corresponding prompt file exists if not prompt_file.exists(): print(f"Warning: No prompt file found for video {video_name}") continue - + # Read prompt content prompt_text = read_prompt_file(prompt_file) if prompt_text is None: continue - + # Process video and add to results video_info = get_video_info(video_file, prompt_text) if video_info: processed_data.append(video_info) - + return processed_data + def save_results(processed_data, output_path): """Save processed data to JSON file @@ -87,47 +93,62 @@ def save_results(processed_data, output_path): output_path (str): Full path for output JSON file """ output_path = Path(output_path) - + # Create parent directories if they don't exist output_path.parent.mkdir(parents=True, exist_ok=True) - + with open(output_path, 'w', encoding='utf-8') as f: json.dump(processed_data, f, indent=2, ensure_ascii=False) - + return output_path + def parse_args(): """Parse command line arguments""" import argparse - - parser = argparse.ArgumentParser(description='Process videos and their corresponding prompt files') - parser.add_argument('--video_dir', '-v', required=True, - help='Directory containing video files') - parser.add_argument('--prompt_dir', '-p', required=True, - help='Directory containing prompt text files') - parser.add_argument('--output_path', '-o', required=True, - help='Full path for output JSON file (e.g., /path/to/output/videos2caption.json)') - parser.add_argument('--verbose', action='store_true', - help='Print verbose processing information') - + + parser = argparse.ArgumentParser( + description='Process videos and their corresponding prompt files') + parser.add_argument('--video_dir', + '-v', + required=True, + help='Directory containing video files') + parser.add_argument('--prompt_dir', + '-p', + required=True, + help='Directory containing prompt text files') + parser.add_argument( + '--output_path', + '-o', + required=True, + help= + 'Full path for output JSON file (e.g., /path/to/output/videos2caption.json)' + ) + parser.add_argument('--verbose', + action='store_true', + help='Print verbose processing information') + return parser.parse_args() + if __name__ == "__main__": # Parse command line arguments args = parse_args() - + # Process videos and prompts - processed_videos = process_videos_and_prompts(args.video_dir, args.prompt_dir, args.verbose) - + processed_videos = 
process_videos_and_prompts(args.video_dir, + args.prompt_dir, + args.verbose) + if processed_videos: # Save results output_path = save_results(processed_videos, args.output_path) - + print(f"\nProcessed {len(processed_videos)} videos") print(f"Results saved to: {output_path}") - + # Print example of processed data print("\nExample of processed video info:") print(json.dumps(processed_videos[0], indent=2)) else: - print("No videos were processed successfully") \ No newline at end of file + print("No videos were processed successfully") diff --git a/scripts/dataset_preparation/resize_videos.py b/scripts/dataset_preparation/resize_videos.py index 321e4999..523fa6ef 100644 --- a/scripts/dataset_preparation/resize_videos.py +++ b/scripts/dataset_preparation/resize_videos.py @@ -1,26 +1,26 @@ import argparse -from pathlib import Path -import time import logging -from tqdm import tqdm -from moviepy.editor import VideoFileClip +import time +from concurrent.futures import ProcessPoolExecutor, as_completed +from pathlib import Path + import numpy as np +from moviepy.editor import VideoFileClip from skimage.transform import resize -from concurrent.futures import ProcessPoolExecutor, as_completed -import multiprocessing -import random +from tqdm import tqdm + # Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(levelname)s - %(message)s', - handlers=[logging.FileHandler('video_processing.log')] -) +logging.basicConfig(level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + handlers=[logging.FileHandler('video_processing.log')]) + def is_16_9_ratio(width: int, height: int, tolerance: float = 0.1) -> bool: target_ratio = 16 / 9 actual_ratio = width / height return abs(actual_ratio - target_ratio) <= (target_ratio * tolerance) + def resize_video(args_tuple): """ Resize a single video file. 
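For reference, the `is_16_9_ratio` helper shown in the hunk above accepts any aspect ratio within ±10% of 16:9 before a video is resized. A minimal, self-contained sanity check of that tolerance (illustrative only, not part of this patch; the sample resolutions are assumptions):

```python
# Standalone copy of the helper from resize_videos.py, with a few spot checks
# of the 10% tolerance used to decide whether a clip is processed or skipped.
def is_16_9_ratio(width: int, height: int, tolerance: float = 0.1) -> bool:
    target_ratio = 16 / 9
    actual_ratio = width / height
    return abs(actual_ratio - target_ratio) <= (target_ratio * tolerance)

assert is_16_9_ratio(1280, 720)       # exactly 16:9 -> processed
assert is_16_9_ratio(1920, 1088)      # ~1.76, within 10% of 16:9 -> processed
assert not is_16_9_ratio(1280, 960)   # 4:3 -> returned as ("skipped", "Not 16:9")
```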
@@ -29,80 +29,69 @@ def resize_video(args_tuple): input_file, output_dir, width, height, fps = args_tuple video = None resized = None - try: - output_file = output_dir / f"{input_file.name}" - - if output_file.exists(): - output_file.unlink() - - video = VideoFileClip(str(input_file)) - - if not is_16_9_ratio(video.w, video.h): - return (input_file.name, "skipped", "Not 16:9") - - def process_frame(frame): - frame_float = frame.astype(float) / 255.0 - resized = resize(frame_float, - (height, width, 3), - mode='reflect', - anti_aliasing=True, - preserve_range=True) - return (resized * 255).astype(np.uint8) - - resized = video.fl_image(process_frame) - resized = resized.set_fps(fps) - - resized.write_videofile( - str(output_file), - codec='libx264', - audio_codec='aac', - temp_audiofile=f'temp-audio-{input_file.stem}.m4a', - remove_temp=True, - verbose=False, - logger=None, - fps=fps - ) - - return (input_file.name, "success", None) - - except Exception as e: - return (input_file.name, "failed", str(e)) - finally: - try: - if video is not None: - video.close() - if resized is not None: - resized.close() - except: - pass + output_file = output_dir / f"{input_file.name}" + + if output_file.exists(): + output_file.unlink() + + video = VideoFileClip(str(input_file)) + + if not is_16_9_ratio(video.w, video.h): + return (input_file.name, "skipped", "Not 16:9") + + def process_frame(frame): + frame_float = frame.astype(float) / 255.0 + resized = resize(frame_float, (height, width, 3), + mode='reflect', + anti_aliasing=True, + preserve_range=True) + return (resized * 255).astype(np.uint8) + + resized = video.fl_image(process_frame) + resized = resized.set_fps(fps) + + resized.write_videofile(str(output_file), + codec='libx264', + audio_codec='aac', + temp_audiofile=f'temp-audio-{input_file.stem}.m4a', + remove_temp=True, + verbose=False, + logger=None, + fps=fps) + + return (input_file.name, "success", None) + def process_folder(args): input_path = Path(args.input_dir) output_path = Path(args.output_dir) output_path.mkdir(parents=True, exist_ok=True) - + video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.webm'} - video_files = [f for f in input_path.iterdir() if f.is_file() and f.suffix.lower() in video_extensions] - + video_files = [ + f for f in input_path.iterdir() + if f.is_file() and f.suffix.lower() in video_extensions + ] + if not video_files: print(f"No video files found in {args.input_dir}") return print(f"Found {len(video_files)} videos") print(f"Target: {args.width}x{args.height} at {args.fps}fps") - + # Prepare arguments for parallel processing - process_args = [ - (video_file, output_path, args.width, args.height, args.fps) - for video_file in video_files - ] - + process_args = [(video_file, output_path, args.width, args.height, + args.fps) for video_file in video_files] + successful = 0 skipped = 0 failed = [] - + # Use ProcessPoolExecutor instead of ThreadPoolExecutor - with tqdm(total=len(video_files), desc="Converting videos", dynamic_ncols=True) as pbar: + with tqdm(total=len(video_files), + desc="Converting videos", + dynamic_ncols=True) as pbar: # Use max_workers as specified or default to CPU count max_workers = args.max_workers with ProcessPoolExecutor(max_workers=max_workers) as executor: @@ -111,7 +100,7 @@ def process_folder(args): executor.submit(resize_video, arg): arg[0] for arg in process_args } - + # Process completed tasks for future in as_completed(future_to_file): filename, status, message = future.result() @@ -122,37 +111,64 @@ def process_folder(args): else: 
failed.append((filename, message)) pbar.update(1) - + # Print final summary - print(f"\nDone! Processed: {successful}, Skipped: {skipped}, Failed: {len(failed)}") + print( + f"\nDone! Processed: {successful}, Skipped: {skipped}, Failed: {len(failed)}" + ) if failed: print("Failed files:") for fname, error in failed: print(f"- {fname}: {error}") + def parse_args(): - parser = argparse.ArgumentParser(description='Batch resize videos to specified resolution and FPS (16:9 only)') - parser.add_argument('--input_dir', required=True, help='Input directory containing video files') - parser.add_argument('--output_dir', required=True, help='Output directory for processed videos') - parser.add_argument('--width', type=int, default=1280, help='Target width in pixels (default: 848)') - parser.add_argument('--height', type=int, default=720, help='Target height in pixels (default: 480)') - parser.add_argument('--fps', type=int, default=30, help='Target frames per second (default: 30)') - parser.add_argument('--max_workers', type=int, default=4, help='Maximum number of concurrent processes (default: 4)') - parser.add_argument('--log-level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], default='INFO', help='Set the logging level (default: INFO)') + parser = argparse.ArgumentParser( + description= + 'Batch resize videos to specified resolution and FPS (16:9 only)') + parser.add_argument('--input_dir', + required=True, + help='Input directory containing video files') + parser.add_argument('--output_dir', + required=True, + help='Output directory for processed videos') + parser.add_argument('--width', + type=int, + default=1280, + help='Target width in pixels (default: 848)') + parser.add_argument('--height', + type=int, + default=720, + help='Target height in pixels (default: 480)') + parser.add_argument('--fps', + type=int, + default=30, + help='Target frames per second (default: 30)') + parser.add_argument( + '--max_workers', + type=int, + default=4, + help='Maximum number of concurrent processes (default: 4)') + parser.add_argument('--log-level', + choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], + default='INFO', + help='Set the logging level (default: INFO)') return parser.parse_args() + def main(): args = parse_args() logging.getLogger().setLevel(getattr(logging, args.log_level)) - + if not Path(args.input_dir).exists(): logging.error(f"Input directory not found: {args.input_dir}") return - + start_time = time.time() process_folder(args) duration = time.time() - start_time logging.info(f"Batch processing completed in {duration:.2f} seconds") + if __name__ == "__main__": main() From 2017c082526d0fd15a930c458066894a1699e6bd Mon Sep 17 00:00:00 2001 From: rlsu9 Date: Mon, 13 Jan 2025 03:10:52 +0000 Subject: [PATCH 29/42] syn with main; add readme --- README.md | 36 ++++++++++++++++++++++--- scripts/inference/inference_mochi_hf.sh | 18 +++++++++++++ 2 files changed, 51 insertions(+), 3 deletions(-) create mode 100644 scripts/inference/inference_mochi_hf.sh diff --git a/README.md b/README.md index 32a6e2c4..c76f5cec 100644 --- a/README.md +++ b/README.md @@ -121,15 +121,45 @@ bash scripts/finetune/finetune_mochi.sh # for mochi ``` **Note that for finetuning, we did not tune the hyperparameters in the provided script** ### ⚡ Lora Finetune -Currently, both Mochi and Huyuan model support Lora Finetune through diffusers. 
+#### 🎯 Demos of Black-Myth Wukong
+https://drive.google.com/file/d/10GxtDcrda7fJofrzfn0D5sPQZlO9D--F/view?usp=sharing
+More demos and prompts can be found in [here](https://huggingface.co/FastVideo/Hunyuan-Black-Myth-Wukong-lora-weight). You can download the Lora weight through:
+```bash
+python scripts/huggingface/download_hf.py --repo_id=FastVideo/Hunyuan-Black-Myth-Wukong-lora-weight --local_dir=data/Hunyuan-Black-Myth-Wukong-lora-weight --repo_type=model
+```
+Currently, both Mochi and Hunyuan models support Lora finetuning through diffusers. To generate personalized videos from your own dataset, you'll need to follow three main steps: dataset preparation, finetuning, and inference.
+
+#### Dataset Preparation
+We provide scripts to help you get started training on your own characters!
+Run the following to organize your dataset into the videos2caption.json file used by preprocessing. Specify your video folder and the corresponding caption folder (caption files should be .txt files with the same name as their videos):
+```
+python scripts/dataset_preparation/prepare_json_file.py --video_dir data/input_videos/ --prompt_dir data/captions/ --output_path data/output_folder/videos2caption.json --verbose
+```
+We also provide a script to resize your videos:
+```
+python scripts/dataset_preparation/resize_videos.py \
+    --input_dir data/raw_videos/ \
+    --output_dir data/resized_videos/ \
+    --width 1280 \
+    --height 720 \
+    --fps 30
+```
+#### Finetuning
+After dataset preparation and preprocessing, you can start finetuning your model with LoRA:
 ```
 bash scripts/finetune/finetune_hunyuan_hf_lora.sh
 bash scripts/finetune/finetune_mochi_lora.sh
 ```
-### Minimum Hardware Requirement
+#### Inference
+For inference with a LoRA checkpoint, run the following scripts with the additional parameter --lora_checkpoint_dir:
+```
+bash scripts/inference/inference_hunyuan_hf.sh
+bash scripts/inference/inference_mochi_hf.sh
+```
+#### Minimum Hardware Requirement
 - 40 GB GPU memory each for 2 GPUs with lora
 - 30 GB GPU memory each for 2 GPUs with CPU offload and lora.
-### Finetune with Both Image and Video
+#### Finetune with Both Image and Video
 Our codebase support finetuning with both image and video.
```bash bash scripts/finetune/finetune_hunyuan.sh diff --git a/scripts/inference/inference_mochi_hf.sh b/scripts/inference/inference_mochi_hf.sh new file mode 100644 index 00000000..c7ff9547 --- /dev/null +++ b/scripts/inference/inference_mochi_hf.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +num_gpus=4 + +torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29503 \ + fastvideo/sample/sample_t2v_mochi.py \ + --model_path data/FastMochi-diffusers \ + --prompt_path "assets/prompt.txt" \ + --num_frames 163 \ + --height 480 \ + --width 848 \ + --num_inference_steps 8 \ + --guidance_scale 1.5 \ + --output_path outputs_video/mochi_hf/ \ + --seed 1024 \ + --scheduler_type "pcm_linear_quadratic" \ + --linear_threshold 0.1 \ + --linear_range 0.75 From 26e7dc9c8f467017da4c9c5b5bbc9ac5577e7964 Mon Sep 17 00:00:00 2001 From: rlsu9 Date: Mon, 13 Jan 2025 03:12:46 +0000 Subject: [PATCH 30/42] syn with main; add readme --- scripts/inference/inference_mochi_sp.sh | 18 ------------------ 1 file changed, 18 deletions(-) delete mode 100644 scripts/inference/inference_mochi_sp.sh diff --git a/scripts/inference/inference_mochi_sp.sh b/scripts/inference/inference_mochi_sp.sh deleted file mode 100644 index 1ded5727..00000000 --- a/scripts/inference/inference_mochi_sp.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash - -num_gpus=4 - -torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29503 \ - fastvideo/sample/sample_t2v_mochi.py \ - --model_path data/FastMochi-diffusers \ - --prompt_path "assets/prompt.txt" \ - --num_frames 163 \ - --height 480 \ - --width 848 \ - --num_inference_steps 8 \ - --guidance_scale 1.5 \ - --output_path outputs_video/mochi_sp/ \ - --seed 1024 \ - --scheduler_type "pcm_linear_quadratic" \ - --linear_threshold 0.1 \ - --linear_range 0.75 From 9109210220a447cf2284dcefb911f3119a5ef8bd Mon Sep 17 00:00:00 2001 From: rlsu9 Date: Mon, 13 Jan 2025 03:30:41 +0000 Subject: [PATCH 31/42] ready for lora release --- README.md | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index c76f5cec..18799c27 100644 --- a/README.md +++ b/README.md @@ -120,10 +120,8 @@ Then you can run the finetune with: bash scripts/finetune/finetune_mochi.sh # for mochi ``` **Note that for finetuning, we did not tune the hyperparameters in the provided script** -### ⚡ Lora Finetune -#### 🎯 Demos of Black-Myth Wukong -https://drive.google.com/file/d/10GxtDcrda7fJofrzfn0D5sPQZlO9D--F/view?usp=sharing -More demos and prompts can be found in [here](https://huggingface.co/FastVideo/Hunyuan-Black-Myth-Wukong-lora-weight). You can download the Lora weight through: +### ⚡ Lora Finetune +Demos and prompts of Black-Myth-Wukong can be found in [here](https://huggingface.co/FastVideo/Hunyuan-Black-Myth-Wukong-lora-weight). 
You can download the Lora weight through: ```bash python scripts/huggingface/download_hf.py --repo_id=FastVideo/Hunyuan-Black-Myth-Wukong-lora-weight --local_dir=data/Hunyuan-Black-Myth-Wukong-lora-weight --repo_type=model ``` From e357df9fe249542b3ba43536998c91a33110c6f4 Mon Sep 17 00:00:00 2001 From: rlsu9 Date: Mon, 13 Jan 2025 04:39:32 +0000 Subject: [PATCH 32/42] format check --- fastvideo/sample/sample_t2v_hunyuan_hf.py | 144 +++++++++++++--------- 1 file changed, 84 insertions(+), 60 deletions(-) diff --git a/fastvideo/sample/sample_t2v_hunyuan_hf.py b/fastvideo/sample/sample_t2v_hunyuan_hf.py index d97a6ddb..b98e56c1 100644 --- a/fastvideo/sample/sample_t2v_hunyuan_hf.py +++ b/fastvideo/sample/sample_t2v_hunyuan_hf.py @@ -2,7 +2,7 @@ import torch.distributed as dist from diffusers import BitsAndBytesConfig from diffusers.utils import export_to_video -import imageio as iio +import imageio as iio import math import numpy as np import io @@ -17,17 +17,20 @@ from fastvideo.models.hunyuan_hf.pipeline_hunyuan import HunyuanVideoPipeline from fastvideo.models.hunyuan_hf.modeling_hunyuan import HunyuanVideoTransformer3DModel + def initialize_distributed(): os.environ["TOKENIZERS_PARALLELISM"] = "false" local_rank = int(os.getenv("RANK", 0)) world_size = int(os.getenv("WORLD_SIZE", 1)) print("world_size", world_size) torch.cuda.set_device(local_rank) - dist.init_process_group( - backend="nccl", init_method="env://", world_size=world_size, rank=local_rank - ) + dist.init_process_group(backend="nccl", + init_method="env://", + world_size=world_size, + rank=local_rank) initialize_sequence_parallel_state(world_size) - + + def inference(args): initialize_distributed() print(nccl_info.sp_size) @@ -36,29 +39,35 @@ def inference(args): weight_dtype = torch.bfloat16 if args.transformer_path is not None: - transformer = HunyuanVideoTransformer3DModel.from_pretrained(args.transformer_path) + transformer = HunyuanVideoTransformer3DModel.from_pretrained( + args.transformer_path) else: transformer = HunyuanVideoTransformer3DModel.from_pretrained( - args.model_path, subfolder="transformer/", torch_dtype=weight_dtype - ) + args.model_path, + subfolder="transformer/", + torch_dtype=weight_dtype) - pipe = HunyuanVideoPipeline.from_pretrained( - args.model_path, transformer=transformer, torch_dtype=weight_dtype - ) + pipe = HunyuanVideoPipeline.from_pretrained(args.model_path, + transformer=transformer, + torch_dtype=weight_dtype) pipe.enable_vae_tiling() if args.lora_checkpoint_dir is not None: print(f"Loading LoRA weights from {args.lora_checkpoint_dir}") - config_path = os.path.join(args.lora_checkpoint_dir, "lora_config.json") + config_path = os.path.join(args.lora_checkpoint_dir, + "lora_config.json") with open(config_path, "r") as f: lora_config_dict = json.load(f) rank = lora_config_dict["lora_params"]["lora_rank"] lora_alpha = lora_config_dict["lora_params"]["lora_alpha"] lora_scaling = lora_alpha / rank - pipe.load_lora_weights(args.lora_checkpoint_dir, adapter_name="default") + pipe.load_lora_weights(args.lora_checkpoint_dir, + adapter_name="default") pipe.set_adapters(["default"], [lora_scaling]) - print(f"Successfully Loaded LoRA weights from {args.lora_checkpoint_dir}") + print( + f"Successfully Loaded LoRA weights from {args.lora_checkpoint_dir}" + ) if args.cpu_offload: pipe.enable_model_cpu_offload(device) else: @@ -67,18 +76,13 @@ def inference(args): # Generate videos from the input prompt if args.prompt_embed_path is not None: - prompt_embeds = ( - torch.load(args.prompt_embed_path, 
map_location="cpu", weights_only=True) - .to(device) - .unsqueeze(0) - ) - encoder_attention_mask = ( - torch.load( - args.encoder_attention_mask_path, map_location="cpu", weights_only=True - ) - .to(device) - .unsqueeze(0) - ) + prompt_embeds = (torch.load(args.prompt_embed_path, + map_location="cpu", + weights_only=True).to(device).unsqueeze(0)) + encoder_attention_mask = (torch.load( + args.encoder_attention_mask_path, + map_location="cpu", + weights_only=True).to(device).unsqueeze(0)) prompts = None elif args.prompt_path is not None: prompts = [line.strip() for line in open(args.prompt_path, "r")] @@ -121,10 +125,11 @@ def inference(args): num_inference_steps=args.num_inference_steps, generator=generator, ).frames - + if nccl_info.global_rank <= 0: export_to_video(videos[0], args.output_path + ".mp4", fps=24) + def inference_quantization(args): torch.manual_seed(args.seed) device = "cuda" if torch.cuda.is_available() else "cpu" @@ -138,7 +143,8 @@ def inference_quantization(args): "5. Camera angles, movements, and transitions used in the video." "6. Thematic and aesthetic concepts associated with the scene, i.e. realistic, futuristic, fairy tale, etc<|eot_id|>" "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"), - "crop_start":95, + "crop_start": + 95, } model_id = args.model_path @@ -213,6 +219,7 @@ def inference_quantization(args): round(torch.cuda.max_memory_allocated(device="cuda") / 1024**3, 3), "GiB") + if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -243,10 +250,14 @@ def inference_quantization(args): default="flow", help="Denoise type for noised inputs.", ) - parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.") - parser.add_argument( - "--neg_prompt", type=str, default=None, help="Negative prompt for sampling." - ) + parser.add_argument("--seed", + type=int, + default=None, + help="Seed for evaluation.") + parser.add_argument("--neg_prompt", + type=str, + default=None, + help="Negative prompt for sampling.") parser.add_argument( "--guidance_scale", type=float, @@ -259,12 +270,14 @@ def inference_quantization(args): default=6.0, help="Embedded classifier free guidance scale.", ) - parser.add_argument( - "--flow_shift", type=int, default=7, help="Flow shift parameter." - ) - parser.add_argument( - "--batch_size", type=int, default=1, help="Batch size for inference." - ) + parser.add_argument("--flow_shift", + type=int, + default=7, + help="Flow shift parameter.") + parser.add_argument("--batch_size", + type=int, + default=1, + help="Batch size for inference.") parser.add_argument( "--num_videos", type=int, @@ -275,22 +288,26 @@ def inference_quantization(args): "--load-key", type=str, default="module", - help="Key to load the model states. 'module' for the main model, 'ema' for the EMA model.", + help= + "Key to load the model states. 
'module' for the main model, 'ema' for the EMA model.", ) parser.add_argument( "--dit-weight", type=str, - default="data/hunyuan/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", + default= + "data/hunyuan/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", ) parser.add_argument( "--reproduce", action="store_true", - help="Enable reproducibility by setting random seeds and deterministic algorithms.", + help= + "Enable reproducibility by setting random seeds and deterministic algorithms.", ) parser.add_argument( "--disable-autocast", action="store_true", - help="Disable autocast for denoising loop and vae decoding in pipeline sampling.", + help= + "Disable autocast for denoising loop and vae decoding in pipeline sampling.", ) # Flow Matching @@ -299,13 +316,15 @@ def inference_quantization(args): action="store_true", help="If reverse, learning/sampling from t=1 -> t=0.", ) - parser.add_argument( - "--flow-solver", type=str, default="euler", help="Solver for flow matching." - ) + parser.add_argument("--flow-solver", + type=str, + default="euler", + help="Solver for flow matching.") parser.add_argument( "--use-linear-quadratic-schedule", action="store_true", - help="Use linear quadratic schedule for flow matching. Following MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)", + help= + "Use linear quadratic schedule for flow matching. Following MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)", ) parser.add_argument( "--linear-schedule-end", @@ -317,17 +336,20 @@ def inference_quantization(args): # Model parameters parser.add_argument("--model", type=str, default="HYVideo-T/2-cfgdistill") parser.add_argument("--latent-channels", type=int, default=16) - parser.add_argument( - "--precision", type=str, default="bf16", choices=["fp32", "fp16", "bf16", "fp8"] - ) - parser.add_argument( - "--rope-theta", type=int, default=256, help="Theta used in RoPE." 
- ) + parser.add_argument("--precision", + type=str, + default="bf16", + choices=["fp32", "fp16", "bf16", "fp8"]) + parser.add_argument("--rope-theta", + type=int, + default=256, + help="Theta used in RoPE.") parser.add_argument("--vae", type=str, default="884-16c-hy") - parser.add_argument( - "--vae-precision", type=str, default="fp16", choices=["fp32", "fp16", "bf16"] - ) + parser.add_argument("--vae-precision", + type=str, + default="fp16", + choices=["fp32", "fp16", "bf16"]) parser.add_argument("--vae-tiling", action="store_true", default=True) parser.add_argument("--text-encoder", type=str, default="llm") @@ -340,10 +362,12 @@ def inference_quantization(args): parser.add_argument("--text-states-dim", type=int, default=4096) parser.add_argument("--text-len", type=int, default=256) parser.add_argument("--tokenizer", type=str, default="llm") - parser.add_argument("--prompt-template", type=str, default="dit-llm-encode") - parser.add_argument( - "--prompt-template-video", type=str, default="dit-llm-encode-video" - ) + parser.add_argument("--prompt-template", + type=str, + default="dit-llm-encode") + parser.add_argument("--prompt-template-video", + type=str, + default="dit-llm-encode-video") parser.add_argument("--hidden-state-skip-layer", type=int, default=2) parser.add_argument("--apply-final-norm", action="store_true") @@ -362,4 +386,4 @@ def inference_quantization(args): if args.quantization: inference_quantization(args) else: - inference(args) \ No newline at end of file + inference(args) From 6bb07059002e420061e224b0046106e82ccae027 Mon Sep 17 00:00:00 2001 From: rlsu9 Date: Mon, 13 Jan 2025 04:41:12 +0000 Subject: [PATCH 33/42] format check --- fastvideo/sample/sample_t2v_hunyuan_hf.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/fastvideo/sample/sample_t2v_hunyuan_hf.py b/fastvideo/sample/sample_t2v_hunyuan_hf.py index b98e56c1..0a2abeb3 100644 --- a/fastvideo/sample/sample_t2v_hunyuan_hf.py +++ b/fastvideo/sample/sample_t2v_hunyuan_hf.py @@ -2,10 +2,6 @@ import torch.distributed as dist from diffusers import BitsAndBytesConfig from diffusers.utils import export_to_video -import imageio as iio -import math -import numpy as np -import io import time import argparse import os From aedb4a288a6d717b40f97acf324c69125221bc84 Mon Sep 17 00:00:00 2001 From: rlsu9 Date: Mon, 13 Jan 2025 04:47:43 +0000 Subject: [PATCH 34/42] format check --- fastvideo/sample/sample_t2v_hunyuan_hf.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/fastvideo/sample/sample_t2v_hunyuan_hf.py b/fastvideo/sample/sample_t2v_hunyuan_hf.py index 0a2abeb3..bb1bccfc 100644 --- a/fastvideo/sample/sample_t2v_hunyuan_hf.py +++ b/fastvideo/sample/sample_t2v_hunyuan_hf.py @@ -7,9 +7,7 @@ import os import json from fastvideo.utils.parallel_states import ( - initialize_sequence_parallel_state, - nccl_info, -) + initialize_sequence_parallel_state, nccl_info) from fastvideo.models.hunyuan_hf.pipeline_hunyuan import HunyuanVideoPipeline from fastvideo.models.hunyuan_hf.modeling_hunyuan import HunyuanVideoTransformer3DModel From 5a086e97d7b3d496ece06a546d326c5e2a452fd3 Mon Sep 17 00:00:00 2001 From: rlsu9 Date: Mon, 13 Jan 2025 04:52:42 +0000 Subject: [PATCH 35/42] format check --- fastvideo/sample/sample_t2v_hunyuan_hf.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/fastvideo/sample/sample_t2v_hunyuan_hf.py b/fastvideo/sample/sample_t2v_hunyuan_hf.py index bb1bccfc..3cc9124b 100644 --- a/fastvideo/sample/sample_t2v_hunyuan_hf.py +++ 
b/fastvideo/sample/sample_t2v_hunyuan_hf.py @@ -1,15 +1,17 @@ +import argparse +import json +import os +import time + import torch import torch.distributed as dist from diffusers import BitsAndBytesConfig from diffusers.utils import export_to_video -import time -import argparse -import os -import json + +from fastvideo.models.hunyuan_hf.modeling_hunyuan import HunyuanVideoTransformer3DModel +from fastvideo.models.hunyuan_hf.pipeline_hunyuan import HunyuanVideoPipeline from fastvideo.utils.parallel_states import ( initialize_sequence_parallel_state, nccl_info) -from fastvideo.models.hunyuan_hf.pipeline_hunyuan import HunyuanVideoPipeline -from fastvideo.models.hunyuan_hf.modeling_hunyuan import HunyuanVideoTransformer3DModel def initialize_distributed(): From e819bdb548d55490ae9b805e8cc3de821af9f383 Mon Sep 17 00:00:00 2001 From: rlsu9 Date: Mon, 13 Jan 2025 04:57:22 +0000 Subject: [PATCH 36/42] format check --- fastvideo/sample/sample_t2v_hunyuan_hf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fastvideo/sample/sample_t2v_hunyuan_hf.py b/fastvideo/sample/sample_t2v_hunyuan_hf.py index 3cc9124b..b1f5ba1d 100644 --- a/fastvideo/sample/sample_t2v_hunyuan_hf.py +++ b/fastvideo/sample/sample_t2v_hunyuan_hf.py @@ -8,7 +8,8 @@ from diffusers import BitsAndBytesConfig from diffusers.utils import export_to_video -from fastvideo.models.hunyuan_hf.modeling_hunyuan import HunyuanVideoTransformer3DModel +from fastvideo.models.hunyuan_hf.modeling_hunyuan import \ + HunyuanVideoTransformer3DModel from fastvideo.models.hunyuan_hf.pipeline_hunyuan import HunyuanVideoPipeline from fastvideo.utils.parallel_states import ( initialize_sequence_parallel_state, nccl_info) From 7fcc4149537da29d339a8237b9dd1026a0c41618 Mon Sep 17 00:00:00 2001 From: rlsu9 Date: Mon, 13 Jan 2025 05:12:46 +0000 Subject: [PATCH 37/42] scripts clean --- scripts/finetune/finetune_hunyuan.sh | 7 +++---- scripts/finetune/finetune_hunyuan_hf_lora.sh | 13 ++++++------- scripts/inference/inference_hunyuan_hf.sh | 8 ++++---- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/scripts/finetune/finetune_hunyuan.sh b/scripts/finetune/finetune_hunyuan.sh index 12e8e43a..2b8811a4 100644 --- a/scripts/finetune/finetune_hunyuan.sh +++ b/scripts/finetune/finetune_hunyuan.sh @@ -12,7 +12,7 @@ torchrun --nnodes 1 --nproc_per_node 8 \ --validation_prompt_dir data/Image-Vid-Finetune-HunYuan/validation \ --gradient_checkpointing \ --train_batch_size=1 \ - --num_latent_t 24 \ + --num_latent_t 32 \ --sp_size 4 \ --train_sp_batch_size 1 \ --dataloader_num_workers 4 \ @@ -31,8 +31,7 @@ torchrun --nnodes 1 --nproc_per_node 8 \ --log_validation \ --output_dir=data/outputs/HSH-Taylor-Finetune-Hunyuan \ --tracker_project_name HSH-Taylor-Finetune-Hunyuan \ - --num_frames 93 \ + --num_frames 125 \ --num_height 720 \ --num_width 1280 \ - --validation_guidance_scale "1.0" \ - --group_frame \ No newline at end of file + --validation_guidance_scale "1.0" \ \ No newline at end of file diff --git a/scripts/finetune/finetune_hunyuan_hf_lora.sh b/scripts/finetune/finetune_hunyuan_hf_lora.sh index 8f7e85a4..8e4c48ed 100644 --- a/scripts/finetune/finetune_hunyuan_hf_lora.sh +++ b/scripts/finetune/finetune_hunyuan_hf_lora.sh @@ -6,16 +6,16 @@ torchrun --nnodes 1 --nproc_per_node 4 --master_port 29903 \ --pretrained_model_name_or_path data/hunyuan_diffusers \ --model_type hunyuan_hf \ --cache_dir data/.cache \ - --data_json_path data/Image-Vid-Finetune-Src/videos2caption.json \ - --validation_prompt_dir 
data/Image-Vid-Finetune-Src/validation \ + --data_json_path data/Image-Vid-Finetune-HunYuan/videos2caption.json \ + --validation_prompt_dir data/Image-Vid-Finetune-HunYuan/validation \ --gradient_checkpointing \ --train_batch_size 1 \ - --num_latent_t 24 \ + --num_latent_t 32 \ --sp_size 4 \ --train_sp_batch_size 1 \ --dataloader_num_workers 4 \ --gradient_accumulation_steps 4 \ - --max_train_steps 2000 \ + --max_train_steps 6000 \ --learning_rate 8e-5 \ --mixed_precision bf16 \ --checkpointing_steps 500 \ @@ -29,9 +29,8 @@ torchrun --nnodes 1 --nproc_per_node 4 --master_port 29903 \ --log_validation \ --output_dir data/outputs/HSH-Taylor-Finetune-Hunyuan \ --tracker_project_name HSH-Taylor-Finetune-Hunyuan \ - --num_frames 93 \ + --num_frames 125 \ --validation_guidance_scale "1.0" \ - --group_frame \ --use_lora \ --lora_rank 32 \ - --lora_alpha 32 \ + --lora_alpha 32 \ No newline at end of file diff --git a/scripts/inference/inference_hunyuan_hf.sh b/scripts/inference/inference_hunyuan_hf.sh index b14d4fc0..d263db3c 100644 --- a/scripts/inference/inference_hunyuan_hf.sh +++ b/scripts/inference/inference_hunyuan_hf.sh @@ -4,10 +4,10 @@ num_gpus=4 torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29503 \ fastvideo/sample/sample_t2v_hunyuan_hf.py \ --model_path ~/data/hunyuan_diffusers/ \ - --prompt_path "assets/prompt_test_3.txt" \ - --num_frames 93 \ - --height 480 \ - --width 848 \ + --prompt_path "assets/prompt.txt" \ + --num_frames 125 \ + --height 720 \ + --width 1280 \ --num_inference_steps 50 \ --output_path outputs_video/hunyuan_hf/ \ --seed 1024 \ From 1d0d787c1d90cd848e674cc96f7e987c6fda6bca Mon Sep 17 00:00:00 2001 From: rlsu9 Date: Mon, 13 Jan 2025 17:14:47 +0000 Subject: [PATCH 38/42] fix shift issue in scripts --- scripts/finetune/finetune_hunyuan.sh | 1 + scripts/finetune/finetune_hunyuan_hf_lora.sh | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/scripts/finetune/finetune_hunyuan.sh b/scripts/finetune/finetune_hunyuan.sh index 2b8811a4..cb1809d3 100644 --- a/scripts/finetune/finetune_hunyuan.sh +++ b/scripts/finetune/finetune_hunyuan.sh @@ -34,4 +34,5 @@ torchrun --nnodes 1 --nproc_per_node 8 \ --num_frames 125 \ --num_height 720 \ --num_width 1280 \ + --shift 7 \ --validation_guidance_scale "1.0" \ \ No newline at end of file diff --git a/scripts/finetune/finetune_hunyuan_hf_lora.sh b/scripts/finetune/finetune_hunyuan_hf_lora.sh index 8e4c48ed..ccf59ace 100644 --- a/scripts/finetune/finetune_hunyuan_hf_lora.sh +++ b/scripts/finetune/finetune_hunyuan_hf_lora.sh @@ -31,6 +31,8 @@ torchrun --nnodes 1 --nproc_per_node 4 --master_port 29903 \ --tracker_project_name HSH-Taylor-Finetune-Hunyuan \ --num_frames 125 \ --validation_guidance_scale "1.0" \ + --shift 7 \ --use_lora \ --lora_rank 32 \ - --lora_alpha 32 \ No newline at end of file + --lora_alpha 32 + \ \ No newline at end of file From f7c9a37c23698f9454b87ef0e0a04d95a1053d5e Mon Sep 17 00:00:00 2001 From: rlsu9 Date: Mon, 13 Jan 2025 18:17:49 +0000 Subject: [PATCH 39/42] add change log --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 18799c27..9bd76c70 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,7 @@ Comparison between original FastHunyuan, LLM-INT8 quantized FastHunyuan and NF4 https://github.com/user-attachments/assets/cf89efb5-5f68-4949-a085-f41c1ef26c94 ## Change Log +- ```2025/01/13```: Support Lora finetuning for HunyuanVideo. 
- ```2024/12/25```: Enable single 4090 inference for `FastHunyuan`, please rerun the installation steps to update the environment. - ```2024/12/17```: `FastVideo` v1.0 is released. From a7e5aac3d0b1e44c76996e0579040f36fc3dc44f Mon Sep 17 00:00:00 2001 From: rlsu9 Date: Mon, 13 Jan 2025 20:17:51 +0000 Subject: [PATCH 40/42] fix huynuan ft scripts val steps --- scripts/finetune/finetune_hunyuan.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/finetune/finetune_hunyuan.sh b/scripts/finetune/finetune_hunyuan.sh index cb1809d3..f420bb2d 100644 --- a/scripts/finetune/finetune_hunyuan.sh +++ b/scripts/finetune/finetune_hunyuan.sh @@ -18,11 +18,11 @@ torchrun --nnodes 1 --nproc_per_node 8 \ --dataloader_num_workers 4 \ --gradient_accumulation_steps=1 \ --max_train_steps=2000 \ - --learning_rate=5e-6 \ + --learning_rate=1e-5 \ --mixed_precision=bf16 \ --checkpointing_steps=200 \ --validation_steps 100 \ - --validation_sampling_steps 64 \ + --validation_sampling_steps 50 \ --checkpoints_total_limit 3 \ --allow_tf32 \ --ema_start_step 0 \ From db92841807370b5e1d4be04bc8ed59b46c69959a Mon Sep 17 00:00:00 2001 From: rlsu9 Date: Tue, 14 Jan 2025 01:07:23 +0000 Subject: [PATCH 41/42] pr version revision --- fastvideo/sample/sample_t2v_hunyuan_hf.py | 14 -------------- fastvideo/utils/checkpoint.py | 7 ------- scripts/finetune/finetune_hunyuan_hf_lora.sh | 11 +++++------ .../inference/inference_hunyuan_hf_quantization.sh | 2 +- 4 files changed, 6 insertions(+), 28 deletions(-) diff --git a/fastvideo/sample/sample_t2v_hunyuan_hf.py b/fastvideo/sample/sample_t2v_hunyuan_hf.py index b1f5ba1d..58eeb8b8 100644 --- a/fastvideo/sample/sample_t2v_hunyuan_hf.py +++ b/fastvideo/sample/sample_t2v_hunyuan_hf.py @@ -130,19 +130,6 @@ def inference(args): def inference_quantization(args): torch.manual_seed(args.seed) device = "cuda" if torch.cuda.is_available() else "cpu" - prompt_template = { - "template": - ("<|start_header_cid|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " - "1. The main content and theme of the video." - "2. The color, shape, size, texture, quantity, text, and spatial relationships of the contents, including objects, people, and anything else." - "3. Actions, events, behaviors temporal relationships, physical movement changes of the contents." - "4. Background environment, light, style, atmosphere, and qualities." - "5. Camera angles, movements, and transitions used in the video." - "6. Thematic and aesthetic concepts associated with the scene, i.e. 
realistic, futuristic, fairy tale, etc<|eot_id|>" - "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"), - "crop_start": - 95, - } model_id = args.model_path if args.quantization == "nf4": @@ -203,7 +190,6 @@ def inference_quantization(args): height=args.height, width=args.width, num_frames=args.num_frames, - prompt_template=prompt_template, num_inference_steps=args.num_inference_steps, generator=generator, ).frames[0] diff --git a/fastvideo/utils/checkpoint.py b/fastvideo/utils/checkpoint.py index 60ae1ba5..162c68a9 100644 --- a/fastvideo/utils/checkpoint.py +++ b/fastvideo/utils/checkpoint.py @@ -46,13 +46,6 @@ def save_checkpoint_optimizer(model, save_file(cpu_state, weight_path) config_dict = dict(model.config) config_dict.pop('dtype') - # dtype = config_dict['dtype'] - # if dtype == torch.float32: - # config_dict['dtype'] = 'fp32' - # elif dtype == torch.float16: - # config_dict['dtype'] = 'fp16' - # elif dtype == torch.bfloat16: - # config_dict['dtype'] = 'bf16' config_path = os.path.join(save_dir, "config.json") # save dict as json with open(config_path, "w") as f: diff --git a/scripts/finetune/finetune_hunyuan_hf_lora.sh b/scripts/finetune/finetune_hunyuan_hf_lora.sh index ccf59ace..254973ae 100644 --- a/scripts/finetune/finetune_hunyuan_hf_lora.sh +++ b/scripts/finetune/finetune_hunyuan_hf_lora.sh @@ -6,8 +6,8 @@ torchrun --nnodes 1 --nproc_per_node 4 --master_port 29903 \ --pretrained_model_name_or_path data/hunyuan_diffusers \ --model_type hunyuan_hf \ --cache_dir data/.cache \ - --data_json_path data/Image-Vid-Finetune-HunYuan/videos2caption.json \ - --validation_prompt_dir data/Image-Vid-Finetune-HunYuan/validation \ + --data_json_path data/Black-Myth-Wukong/videos2caption.json \ + --validation_prompt_dir data/Black-Myth-Wukong/validation \ --gradient_checkpointing \ --train_batch_size 1 \ --num_latent_t 32 \ @@ -27,12 +27,11 @@ torchrun --nnodes 1 --nproc_per_node 4 --master_port 29903 \ --cfg 0.0 \ --ema_decay 0.999 \ --log_validation \ - --output_dir data/outputs/HSH-Taylor-Finetune-Hunyuan \ - --tracker_project_name HSH-Taylor-Finetune-Hunyuan \ + --output_dir data/outputs/Hunyuan-lora-finetuning-Black-Myth-Wukong \ + --tracker_project_name Hunyuan-lora-finetuning-Black-Myth-Wukong \ --num_frames 125 \ --validation_guidance_scale "1.0" \ --shift 7 \ --use_lora \ --lora_rank 32 \ - --lora_alpha 32 - \ \ No newline at end of file + --lora_alpha 32 \ No newline at end of file diff --git a/scripts/inference/inference_hunyuan_hf_quantization.sh b/scripts/inference/inference_hunyuan_hf_quantization.sh index a94460f4..a3a7cadc 100644 --- a/scripts/inference/inference_hunyuan_hf_quantization.sh +++ b/scripts/inference/inference_hunyuan_hf_quantization.sh @@ -1,7 +1,7 @@ #!/bin/bash num_gpus=1 -export MODEL_BASE="data/FastHunyuan" +export MODEL_BASE="data/FastHunyuan-diffusers" torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 12345 \ fastvideo/sample/sample_t2v_hunyuan_hf.py \ --height 720 \ From a44972bdb2c25a1d63cf1de199d5849313990eba Mon Sep 17 00:00:00 2001 From: rlsu9 Date: Tue, 14 Jan 2025 01:13:59 +0000 Subject: [PATCH 42/42] ready for lora pr --- fastvideo/train.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/fastvideo/train.py b/fastvideo/train.py index c3c8ebee..5878b9f8 100644 --- a/fastvideo/train.py +++ b/fastvideo/train.py @@ -145,18 +145,20 @@ def train_one_step( dtype=latents.dtype, ) noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise - training_guidance = torch.tensor([1000.0], - 
device=noisy_model_input.device, - dtype=torch.bfloat16) - with torch.autocast("cuda", torch.bfloat16): - model_pred = transformer( - noisy_model_input, - encoder_hidden_states, - timesteps, - encoder_attention_mask, # B, L - training_guidance, - return_dict=False, - )[0] + with torch.autocast("cuda", dtype=torch.bfloat16): + input_kwargs = { + "hidden_states": noisy_model_input, + "encoder_hidden_states": encoder_hidden_states, + "timestep": timesteps, + "encoder_attention_mask": encoder_attention_mask, # B, L + "return_dict": False, + } + if 'hunyuan' in model_type: + input_kwargs["guidance"] = torch.tensor( + [1000.0], + device=noisy_model_input.device, + dtype=torch.bfloat16) + model_pred = transformer(**input_kwargs)[0] if precondition_outputs: model_pred = noisy_model_input - model_pred * sigmas
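The final hunk above switches the transformer call in the training step (`train_one_step`) from positional arguments to a kwargs dict, so the fixed guidance tensor is only passed when the model is a Hunyuan variant. A minimal sketch of that dispatch in isolation — `build_transformer_inputs` is a hypothetical helper name introduced here for illustration, not part of the patch:

```python
import torch

def build_transformer_inputs(model_type, noisy_model_input,
                             encoder_hidden_states, timesteps,
                             encoder_attention_mask):
    # Core inputs shared by all supported transformers.
    input_kwargs = {
        "hidden_states": noisy_model_input,
        "encoder_hidden_states": encoder_hidden_states,
        "timestep": timesteps,
        "encoder_attention_mask": encoder_attention_mask,  # B, L
        "return_dict": False,
    }
    # Only Hunyuan-style models receive the embedded guidance tensor
    # shown in the patch; other models are called without it.
    if "hunyuan" in model_type:
        input_kwargs["guidance"] = torch.tensor(
            [1000.0],
            device=noisy_model_input.device,
            dtype=torch.bfloat16)
    return input_kwargs

# Usage mirrors the patched training step:
#   model_pred = transformer(**build_transformer_inputs(...))[0]
```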