adding hunyuan hf (support lora finetuning); unified hunyuan hf inference with quantization #135

Merged 47 commits on Jan 14, 2025

Changes from 1 commit

Commits (47)
a133b58
fix lora cp saving issue
BrianChen1129 Dec 19, 2024
3008adf
Merge branch 'main' of github.com:jzhang38/FastVideo-OSP
BrianChen1129 Dec 19, 2024
a3f8fc2
fix lora save issue
BrianChen1129 Dec 19, 2024
580575c
fix lora save issue
BrianChen1129 Dec 19, 2024
4c16991
Revert "fix lora save issue"
BrianChen1129 Dec 19, 2024
677efd9
fix lora save issue
BrianChen1129 Dec 19, 2024
48fd2d2
Merge branch 'main' of github.com:jzhang38/FastVideo-OSP into yq-lora…
BrianChen1129 Dec 20, 2024
c31f507
debug hunyuan hf sp
BrianChen1129 Dec 23, 2024
ac58b2f
test hunyuan hf
BrianChen1129 Dec 24, 2024
77b690c
add huanyuan hf inference and train
BrianChen1129 Dec 24, 2024
04c1610
support hunyuan hf lora
BrianChen1129 Dec 27, 2024
bfe0448
syn with main
BrianChen1129 Dec 27, 2024
594b65a
syn with main
BrianChen1129 Dec 27, 2024
36560eb
unified hunyuan hf
BrianChen1129 Dec 29, 2024
d25c235
unified hunyuan hf
BrianChen1129 Dec 29, 2024
557dbca
unified hunyuan hf
BrianChen1129 Dec 29, 2024
0c058c3
add lora
BrianChen1129 Jan 7, 2025
cbc52b0
syn with main
BrianChen1129 Jan 7, 2025
f35ea70
unify hunyuan hf inference
BrianChen1129 Jan 7, 2025
cad6e6d
unify hunyuan hf inference
BrianChen1129 Jan 7, 2025
1832042
syn with main
BrianChen1129 Jan 7, 2025
5d21b16
syn
BrianChen1129 Jan 7, 2025
1d7d637
syn
BrianChen1129 Jan 7, 2025
c1cf441
syn
BrianChen1129 Jan 7, 2025
22d499b
syn
BrianChen1129 Jan 7, 2025
e7ea0d7
syn
BrianChen1129 Jan 7, 2025
afe24e2
syn
BrianChen1129 Jan 7, 2025
41f8ac9
syn
BrianChen1129 Jan 8, 2025
3f2fc1a
fix train.py
BrianChen1129 Jan 12, 2025
32a7bb3
update README
BrianChen1129 Jan 12, 2025
79be9b5
syn
BrianChen1129 Jan 12, 2025
5af5b6a
add dataset preparation scripts
BrianChen1129 Jan 13, 2025
54d1e5d
format
BrianChen1129 Jan 13, 2025
2017c08
syn with main; add readme
BrianChen1129 Jan 13, 2025
26e7dc9
syn with main; add readme
BrianChen1129 Jan 13, 2025
9109210
ready for lora release
BrianChen1129 Jan 13, 2025
e357df9
format check
BrianChen1129 Jan 13, 2025
6bb0705
format check
BrianChen1129 Jan 13, 2025
aedb4a2
format check
BrianChen1129 Jan 13, 2025
5a086e9
format check
BrianChen1129 Jan 13, 2025
e819bdb
format check
BrianChen1129 Jan 13, 2025
7fcc414
scripts clean
BrianChen1129 Jan 13, 2025
1d0d787
fix shift issue in scripts
BrianChen1129 Jan 13, 2025
f7c9a37
add change log
BrianChen1129 Jan 13, 2025
a7e5aac
fix huynuan ft scripts val steps
BrianChen1129 Jan 13, 2025
db92841
pr version revision
BrianChen1129 Jan 14, 2025
a44972b
ready for lora pr
BrianChen1129 Jan 14, 2025
format check
BrianChen1129 committed Jan 13, 2025
commit e357df9fe249542b3ba43536998c91a33110c6f4
144 changes: 84 additions & 60 deletions fastvideo/sample/sample_t2v_hunyuan_hf.py
@@ -2,7 +2,7 @@
 import torch.distributed as dist
 from diffusers import BitsAndBytesConfig
 from diffusers.utils import export_to_video
-import imageio as iio 
+import imageio as iio
 import math
 import numpy as np
 import io
@@ -17,17 +17,20 @@
 from fastvideo.models.hunyuan_hf.pipeline_hunyuan import HunyuanVideoPipeline
 from fastvideo.models.hunyuan_hf.modeling_hunyuan import HunyuanVideoTransformer3DModel


 def initialize_distributed():
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
     local_rank = int(os.getenv("RANK", 0))
     world_size = int(os.getenv("WORLD_SIZE", 1))
     print("world_size", world_size)
     torch.cuda.set_device(local_rank)
-    dist.init_process_group(
-        backend="nccl", init_method="env://", world_size=world_size, rank=local_rank
-    )
+    dist.init_process_group(backend="nccl",
+                            init_method="env://",
+                            world_size=world_size,
+                            rank=local_rank)
     initialize_sequence_parallel_state(world_size)


 def inference(args):
Collaborator @jzhang38 commented on Jan 13, 2025:
Why separate inference and inference quantization functions?

     initialize_distributed()
     print(nccl_info.sp_size)
@@ -36,29 +39,35 @@ def inference(args):
     weight_dtype = torch.bfloat16

     if args.transformer_path is not None:
-        transformer = HunyuanVideoTransformer3DModel.from_pretrained(args.transformer_path)
+        transformer = HunyuanVideoTransformer3DModel.from_pretrained(
+            args.transformer_path)
     else:
         transformer = HunyuanVideoTransformer3DModel.from_pretrained(
-            args.model_path, subfolder="transformer/", torch_dtype=weight_dtype
-        )
+            args.model_path,
+            subfolder="transformer/",
+            torch_dtype=weight_dtype)

-    pipe = HunyuanVideoPipeline.from_pretrained(
-        args.model_path, transformer=transformer, torch_dtype=weight_dtype
-    )
+    pipe = HunyuanVideoPipeline.from_pretrained(args.model_path,
+                                                transformer=transformer,
+                                                torch_dtype=weight_dtype)

     pipe.enable_vae_tiling()

     if args.lora_checkpoint_dir is not None:
         print(f"Loading LoRA weights from {args.lora_checkpoint_dir}")
-        config_path = os.path.join(args.lora_checkpoint_dir, "lora_config.json")
+        config_path = os.path.join(args.lora_checkpoint_dir,
+                                   "lora_config.json")
         with open(config_path, "r") as f:
             lora_config_dict = json.load(f)
         rank = lora_config_dict["lora_params"]["lora_rank"]
         lora_alpha = lora_config_dict["lora_params"]["lora_alpha"]
         lora_scaling = lora_alpha / rank
-        pipe.load_lora_weights(args.lora_checkpoint_dir, adapter_name="default")
+        pipe.load_lora_weights(args.lora_checkpoint_dir,
+                               adapter_name="default")
         pipe.set_adapters(["default"], [lora_scaling])
-        print(f"Successfully Loaded LoRA weights from {args.lora_checkpoint_dir}")
+        print(
+            f"Successfully Loaded LoRA weights from {args.lora_checkpoint_dir}"
+        )
     if args.cpu_offload:
         pipe.enable_model_cpu_offload(device)
     else:
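
For reference, the LoRA branch above reads the rank and alpha recorded at training time and applies the usual PEFT scaling of lora_alpha / lora_rank before activating the adapter. A minimal standalone sketch of the same flow, assuming a diffusers-format HunyuanVideo directory and a LoRA checkpoint laid out the way this PR saves it (both paths below are placeholders):

import json
import os

import torch
from fastvideo.models.hunyuan_hf.pipeline_hunyuan import HunyuanVideoPipeline

MODEL_PATH = "path/to/hunyuan-video-hf"  # placeholder
LORA_DIR = "path/to/lora_checkpoint"     # placeholder

pipe = HunyuanVideoPipeline.from_pretrained(MODEL_PATH,
                                            torch_dtype=torch.bfloat16)

# lora_config.json records the hyperparameters used during finetuning.
with open(os.path.join(LORA_DIR, "lora_config.json")) as f:
    lora_params = json.load(f)["lora_params"]

# PEFT convention: the adapter's effective weight is lora_alpha / lora_rank.
scaling = lora_params["lora_alpha"] / lora_params["lora_rank"]
pipe.load_lora_weights(LORA_DIR, adapter_name="default")
pipe.set_adapters(["default"], [scaling])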
@@ -67,18 +76,13 @@ def inference(args):
     # Generate videos from the input prompt

     if args.prompt_embed_path is not None:
-        prompt_embeds = (
-            torch.load(args.prompt_embed_path, map_location="cpu", weights_only=True)
-            .to(device)
-            .unsqueeze(0)
-        )
-        encoder_attention_mask = (
-            torch.load(
-                args.encoder_attention_mask_path, map_location="cpu", weights_only=True
-            )
-            .to(device)
-            .unsqueeze(0)
-        )
+        prompt_embeds = (torch.load(args.prompt_embed_path,
+                                    map_location="cpu",
+                                    weights_only=True).to(device).unsqueeze(0))
+        encoder_attention_mask = (torch.load(
+            args.encoder_attention_mask_path,
+            map_location="cpu",
+            weights_only=True).to(device).unsqueeze(0))
         prompts = None
     elif args.prompt_path is not None:
         prompts = [line.strip() for line in open(args.prompt_path, "r")]
@@ -121,10 +125,11 @@ def inference(args):
             num_inference_steps=args.num_inference_steps,
             generator=generator,
         ).frames
+
         if nccl_info.global_rank <= 0:
             export_to_video(videos[0], args.output_path + ".mp4", fps=24)


 def inference_quantization(args):
     torch.manual_seed(args.seed)
     device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -138,7 +143,8 @@ def inference_quantization(args):
         "5. Camera angles, movements, and transitions used in the video."
         "6. Thematic and aesthetic concepts associated with the scene, i.e. realistic, futuristic, fairy tale, etc<|eot_id|>"
         "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"),
-        "crop_start":95,
+        "crop_start":
+        95,
     }
     model_id = args.model_path

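inference_quantization's model-loading code is untouched by this formatting commit, so it does not appear in the diff, but the BitsAndBytesConfig import at the top of the file points at diffusers' bitsandbytes integration. A minimal sketch of that loading pattern, with illustrative 4-bit NF4 settings and a placeholder model path (the script's actual behavior is driven by its --quantization flag and the argparse options below):

import torch
from diffusers import BitsAndBytesConfig
from fastvideo.models.hunyuan_hf.modeling_hunyuan import HunyuanVideoTransformer3DModel
from fastvideo.models.hunyuan_hf.pipeline_hunyuan import HunyuanVideoPipeline

MODEL_PATH = "path/to/hunyuan-video-hf"  # placeholder

# Quantize only the DiT weights to 4-bit NF4; compute still runs in bf16.
quant_config = BitsAndBytesConfig(load_in_4bit=True,
                                  bnb_4bit_quant_type="nf4",
                                  bnb_4bit_compute_dtype=torch.bfloat16)
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    MODEL_PATH,
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16)

pipe = HunyuanVideoPipeline.from_pretrained(MODEL_PATH,
                                            transformer=transformer,
                                            torch_dtype=torch.bfloat16)
pipe.enable_vae_tiling()  # mirrors the unquantized path above
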
@@ -213,6 +219,7 @@ def inference_quantization(args):
           round(torch.cuda.max_memory_allocated(device="cuda") / 1024**3, 3),
           "GiB")

+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()

@@ -243,10 +250,14 @@ def inference_quantization(args):
         default="flow",
         help="Denoise type for noised inputs.",
     )
-    parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
-    parser.add_argument(
-        "--neg_prompt", type=str, default=None, help="Negative prompt for sampling."
-    )
+    parser.add_argument("--seed",
+                        type=int,
+                        default=None,
+                        help="Seed for evaluation.")
+    parser.add_argument("--neg_prompt",
+                        type=str,
+                        default=None,
+                        help="Negative prompt for sampling.")
     parser.add_argument(
         "--guidance_scale",
         type=float,
@@ -259,12 +270,14 @@ def inference_quantization(args):
         default=6.0,
         help="Embedded classifier free guidance scale.",
     )
-    parser.add_argument(
-        "--flow_shift", type=int, default=7, help="Flow shift parameter."
-    )
-    parser.add_argument(
-        "--batch_size", type=int, default=1, help="Batch size for inference."
-    )
+    parser.add_argument("--flow_shift",
+                        type=int,
+                        default=7,
+                        help="Flow shift parameter.")
+    parser.add_argument("--batch_size",
+                        type=int,
+                        default=1,
+                        help="Batch size for inference.")
     parser.add_argument(
         "--num_videos",
         type=int,
@@ -275,22 +288,26 @@ def inference_quantization(args):
         "--load-key",
         type=str,
         default="module",
-        help="Key to load the model states. 'module' for the main model, 'ema' for the EMA model.",
+        help=
+        "Key to load the model states. 'module' for the main model, 'ema' for the EMA model.",
     )
     parser.add_argument(
         "--dit-weight",
         type=str,
-        default="data/hunyuan/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt",
+        default=
+        "data/hunyuan/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt",
     )
     parser.add_argument(
         "--reproduce",
         action="store_true",
-        help="Enable reproducibility by setting random seeds and deterministic algorithms.",
+        help=
+        "Enable reproducibility by setting random seeds and deterministic algorithms.",
     )
     parser.add_argument(
         "--disable-autocast",
         action="store_true",
-        help="Disable autocast for denoising loop and vae decoding in pipeline sampling.",
+        help=
+        "Disable autocast for denoising loop and vae decoding in pipeline sampling.",
     )

     # Flow Matching
@@ -299,13 +316,15 @@ def inference_quantization(args):
         action="store_true",
         help="If reverse, learning/sampling from t=1 -> t=0.",
     )
-    parser.add_argument(
-        "--flow-solver", type=str, default="euler", help="Solver for flow matching."
-    )
+    parser.add_argument("--flow-solver",
+                        type=str,
+                        default="euler",
+                        help="Solver for flow matching.")
     parser.add_argument(
         "--use-linear-quadratic-schedule",
         action="store_true",
-        help="Use linear quadratic schedule for flow matching. Following MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)",
+        help=
+        "Use linear quadratic schedule for flow matching. Following MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)",
     )
     parser.add_argument(
         "--linear-schedule-end",
@@ -317,17 +336,20 @@ def inference_quantization(args):
     # Model parameters
     parser.add_argument("--model", type=str, default="HYVideo-T/2-cfgdistill")
     parser.add_argument("--latent-channels", type=int, default=16)
-    parser.add_argument(
-        "--precision", type=str, default="bf16", choices=["fp32", "fp16", "bf16", "fp8"]
-    )
-    parser.add_argument(
-        "--rope-theta", type=int, default=256, help="Theta used in RoPE."
-    )
+    parser.add_argument("--precision",
+                        type=str,
+                        default="bf16",
+                        choices=["fp32", "fp16", "bf16", "fp8"])
+    parser.add_argument("--rope-theta",
+                        type=int,
+                        default=256,
+                        help="Theta used in RoPE.")

     parser.add_argument("--vae", type=str, default="884-16c-hy")
-    parser.add_argument(
-        "--vae-precision", type=str, default="fp16", choices=["fp32", "fp16", "bf16"]
-    )
+    parser.add_argument("--vae-precision",
+                        type=str,
+                        default="fp16",
+                        choices=["fp32", "fp16", "bf16"])
     parser.add_argument("--vae-tiling", action="store_true", default=True)

     parser.add_argument("--text-encoder", type=str, default="llm")
@@ -340,10 +362,12 @@ def inference_quantization(args):
     parser.add_argument("--text-states-dim", type=int, default=4096)
     parser.add_argument("--text-len", type=int, default=256)
     parser.add_argument("--tokenizer", type=str, default="llm")
-    parser.add_argument("--prompt-template", type=str, default="dit-llm-encode")
-    parser.add_argument(
-        "--prompt-template-video", type=str, default="dit-llm-encode-video"
-    )
+    parser.add_argument("--prompt-template",
+                        type=str,
+                        default="dit-llm-encode")
+    parser.add_argument("--prompt-template-video",
+                        type=str,
+                        default="dit-llm-encode-video")
     parser.add_argument("--hidden-state-skip-layer", type=int, default=2)
     parser.add_argument("--apply-final-norm", action="store_true")

@@ -362,4 +386,4 @@ def inference_quantization(args):
     if args.quantization:
         inference_quantization(args)
     else:
-        inference(args)
\ No newline at end of file
+        inference(args)
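
As the final dispatch shows, passing --quantization routes everything to inference_quantization, which runs as a single process; the default inference path calls initialize_distributed() first, so it expects torchrun-style RANK and WORLD_SIZE environment variables with an NCCL backend available, and only global rank 0 exports the resulting .mp4.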