diff --git a/examples/wan2_1/README.md b/examples/wan2_1/README.md new file mode 100644 index 0000000000..d9a91f741b --- /dev/null +++ b/examples/wan2_1/README.md @@ -0,0 +1,287 @@ +
+
+# 🚀 Wan: Open and Advanced Large-Scale Video Generative Models
+
+ +In this repository, we present an efficient MindSpore implementation of [Wan2.1](https://github.com/Wan-Video/Wan2.1). This repository is built on the models and code released by the Alibaba Wan group. We are grateful for their exceptional work and generous contribution to open source. + +- 👍 **SOTA Performance**: **Wan2.1** consistently outperforms existing open-source models and state-of-the-art commercial solutions across multiple benchmarks. +- 👍 **Multiple Tasks**: **Wan2.1** excels in Text-to-Video, Image-to-Video, Video Editing, Text-to-Image, and Video-to-Audio, advancing the field of video generation. +- 👍 **Visual Text Generation**: **Wan2.1** is the first video model capable of generating both Chinese and English text, featuring robust text generation that enhances its practical applications. +- 👍 **Powerful Video VAE**: **Wan-VAE** delivers exceptional efficiency and performance, encoding and decoding 1080P videos of any length while preserving temporal information, making it an ideal foundation for video and image generation. + + +## Video Demos + +The following videos are generated based on MindSpore and Ascend 910*. + +- Text-to-Video + +https://github.com/user-attachments/assets/f6705d28-7755-447b-a256-6727f66d693b + +```text +prompt: A sepia-toned vintage photograph depicting a whimsical bicycle race featuring several dogs wearing goggles and tiny cycling outfits. The canine racers, with determined expressions and blurred motion, pedal miniature bicycles on a dusty road. Spectators in period clothing line the sides, adding to the nostalgic atmosphere. Slightly grainy and blurred, mimicking old photos, with soft side lighting enhancing the warm tones and rustic charm of the scene. 'Bicycle Race' captures this unique moment in a medium shot, focusing on both the racers and the lively crowd. +``` + +https://github.com/user-attachments/assets/1e1da53a-9112-4fc3-bb8e-b458497c4806 + +```text +prompt: Film quality, professional quality, rich details. The video begins to show the surface of a pond, and the camera slowly zooms in to a close-up. The water surface begins to bubble, and then a blonde woman is seen coming out of the lotus pond soaked all over, showing the subtle changes in her facial expression, creating a dreamy atmosphere. +``` + +https://github.com/user-attachments/assets/34e4501f-a207-40bb-bb6c-b162ff6505b0 + +```text +prompt: Two anthropomorphic cats wearing boxing suits and bright gloves fiercely battled on the boxing ring under the spotlight. Their muscles are tight, displaying the strength and agility of professional boxers. A spotted dog judge stood aside. The animals in the audience around cheered and cheered, adding a lively atmosphere to the competition. The cat's boxing movements are quick and powerful, with its paws tracing blurry trajectories in the air. The screen adopts a dynamic blur effect, close ups, and focuses on the intense confrontation on the boxing ring. +``` + +https://github.com/user-attachments/assets/aceda253-78a2-4fa5-9edc-83f035c7c2ea + +```text +prompt: Sports photography full of dynamism, several motorcycles fiercely compete on the loess flying track, their wheels rolling up the dust in the sky. The motorcyclist is wearing professional racing clothes. The camera uses a high-speed shutter to capture moments, follows from the side and rear, and finally freezes in a close-up of a motorcycle, showcasing its exquisite body lines and powerful mechanical beauty, creating a tense and exciting racing atmosphere. 
Close up dynamic perspective, perfectly presenting the visual impact of speed and power.
+```
+
+https://github.com/user-attachments/assets/c00ca7b8-5e05-4776-8c72-ae19e6bd44f5
+
+```text
+prompt: 电影画质,专业质量,丰富细节。一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。
+```
+
+
+- Image-to-Video
+
+
+https://github.com/user-attachments/assets/d37bf480-595e-4a41-95f8-acbc421b7428
+
+```text
+prompt: Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.
+```
+
+
+## 🔥 Latest News!!
+
+* Feb 28, 2025: 👋 The MindSpore implementation of Wan2.1 is released, supporting text-to-video and image-to-video inference with the 1.3B and 14B models.
+
+
+## 📑 Todo List
+- Wan2.1 Text-to-Video
+  - [x] Single-NPU inference code of the 14B and 1.3B models
+  - [x] Multi-NPU inference acceleration for the 14B model
+  - [ ] Prompt extension
+  - [ ] Gradio demo
+- Wan2.1 Image-to-Video
+  - [x] Single-NPU inference code of the 14B model
+  - [x] Multi-NPU inference acceleration for the 14B model
+  - [ ] Prompt extension
+  - [ ] Gradio demo
+
+
+## Quickstart
+
+### Requirements
+
+The code has been tested in the following environment:
+
+| mindspore | ascend driver | firmware | cann toolkit/kernel |
+| :---: | :---: | :---: | :---: |
+| 2.5.0 | 24.1.0 | 7.35.23 | 8.0.RC3.beta1 |
+
+
+### Installation
+Clone the repo:
+```
+git clone https://github.com/mindspore-lab/mindone
+cd mindone/examples/wan2_1
+```
+
+Install dependencies:
+```
+pip install -r requirements.txt
+```
+
+### Model Download
+
+| Models | Download Link | Notes |
+|---------------|---------------|-------|
+| T2V-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | Supports both 480P and 720P |
+| I2V-14B-720P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | Supports 720P |
+| I2V-14B-480P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | Supports 480P |
+| T2V-1.3B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | Supports 480P |
+
+> 💡Note: The 1.3B model is capable of generating videos at 720P resolution. However, due to limited training at this resolution, the results are generally less stable than at 480P. For optimal performance, we recommend using 480P resolution.
+
+
+Download the models using huggingface-cli or modelscope:
+```
+pip install "huggingface_hub[cli]"
+huggingface-cli download Wan-AI/Wan2.1-T2V-1.3B --local-dir ./Wan2.1-T2V-1.3B
+huggingface-cli download Wan-AI/Wan2.1-T2V-14B --local-dir ./Wan2.1-T2V-14B
+huggingface-cli download Wan-AI/Wan2.1-I2V-14B-480P --local-dir ./Wan2.1-I2V-14B-480P
+huggingface-cli download Wan-AI/Wan2.1-I2V-14B-720P --local-dir ./Wan2.1-I2V-14B-720P
+```
+
+Or download them with the modelscope CLI:
+```
+pip install modelscope
+modelscope download Wan-AI/Wan2.1-T2V-1.3B --local_dir ./Wan2.1-T2V-1.3B
+```
+
+### Run Text-to-Video Generation
+
+This repository supports two Text-to-Video models (1.3B and 14B) and two resolutions (480P and 720P). The parameters and configurations for these models are as follows:
+
+| Task | 480P | 720P | Model |
+|:----:|:----:|:----:|:---------------:|
+| t2v-14B | ✔️ | ✔️ | Wan2.1-T2V-14B |
+| t2v-1.3B | ✔️ |  | Wan2.1-T2V-1.3B |
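
The accepted `--size` strings for each task are defined in `wan/configs/__init__.py` (`SUPPORTED_SIZES` and `SIZE_CONFIGS`). As a minimal sketch of how to inspect them from Python, assuming you run it from `examples/wan2_1` with the `mindone` repo root on `PYTHONPATH` (as `generate.py` arranges) so the local `wan` package is importable:

```python
# Sketch: list the size strings each task accepts and how a string maps to (W, H).
# Assumes the working directory is examples/wan2_1 so `wan` resolves to the local package.
from wan.configs import SIZE_CONFIGS, SUPPORTED_SIZES

print(SUPPORTED_SIZES["t2v-1.3B"])  # ('480*832', '832*480')
print(SUPPORTED_SIZES["t2v-14B"])   # ('720*1280', '1280*720', '480*832', '832*480')
print(SIZE_CONFIGS["832*480"])      # (832, 480), i.e. width x height
```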
+
+
+- Single-NPU inference
+
+```
+python generate.py \
+  --task t2v-14B \
+  --size 1280*720 \
+  --ckpt_dir ./Wan2.1-T2V-14B \
+  --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
+```
+
+```
+python generate.py \
+  --task t2v-1.3B \
+  --size 832*480 \
+  --ckpt_dir ./Wan2.1-T2V-1.3B \
+  --sample_shift 8 \
+  --sample_guide_scale 6 \
+  --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
+```
+
+> 💡Note: If you are using the `T2V-1.3B` model, we recommend setting `--sample_guide_scale 6`. The `--sample_shift` parameter can be adjusted within the range of 8 to 12 depending on the results.
+
+
+- Multi-NPU inference
+
+```
+msrun --worker_num=2 --local_worker_num=2 generate.py \
+  --task t2v-14B \
+  --size 1280*720 \
+  --ckpt_dir ./Wan2.1-T2V-14B \
+  --dit_zero3 --t5_zero3 --ulysses_sp \
+  --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
+```
+
+> 💡 720P T2V inference runs on a single card, but using more cards accelerates the generation process.
+
+### Run Image-to-Video Generation
+
+Similar to Text-to-Video, Image-to-Video supports different resolutions. The specific parameters and their corresponding settings are as follows:
+
+| Task | 480P | 720P | Model |
+|:----:|:----:|:----:|:-------------------:|
+| i2v-14B |  | ✔️ | Wan2.1-I2V-14B-720P |
+| i2v-14B | ✔️ |  | Wan2.1-I2V-14B-480P |
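
For the I2V task, `--size` is interpreted as a target area and the output keeps the aspect ratio of the input image (see the note below the commands). The following minimal sketch mirrors the latent-size computation in `wan/image2video.py` to show how the final height and width are derived; the helper name is illustrative, the stride and patch values are those of the 14B config (`vae_stride=(4, 8, 8)`, `patch_size=(1, 2, 2)`), and the example input size is arbitrary:

```python
import numpy as np

# Sketch of how WanI2V derives the output resolution from a target area and the
# input image's aspect ratio (mirrors the computation in wan/image2video.py).
def i2v_output_size(img_h, img_w, max_area=1280 * 720, vae_stride=(4, 8, 8), patch_size=(1, 2, 2)):
    aspect_ratio = img_h / img_w
    lat_h = round(np.sqrt(max_area * aspect_ratio) // vae_stride[1] // patch_size[1] * patch_size[1])
    lat_w = round(np.sqrt(max_area / aspect_ratio) // vae_stride[2] // patch_size[2] * patch_size[2])
    return lat_h * vae_stride[1], lat_w * vae_stride[2]  # (height, width)

print(i2v_output_size(768, 1024))  # a 4:3 landscape input mapped onto the 720P area -> (824, 1104)
```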
+
+
+- Single-NPU inference
+
+```
+python generate.py \
+  --task i2v-14B \
+  --size 832*480 \
+  --ckpt_dir ./Wan2.1-I2V-14B-480P \
+  --image examples/i2v_input.JPG \
+  --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
+```
+
+> 💡For the Image-to-Video task, the `size` parameter represents the area of the generated video, with the aspect ratio following that of the original input image (see the sketch above the commands).
+
+
+- Multi-NPU inference
+
+```
+msrun --worker_num=2 --local_worker_num=2 generate.py \
+  --task i2v-14B --size 1280*720 \
+  --ckpt_dir ./Wan2.1-I2V-14B-720P \
+  --dit_zero3 --t5_zero3 --ulysses_sp \
+  --image examples/i2v_input.JPG \
+  --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
+```
+
+> 💡At least 2 cards are required to run 720P I2V generation without running out of memory; using up to 8 cards further accelerates the generation process.
+
+## Performance
+
+Experiments were run on Ascend 910* with MindSpore 2.5.0 in **PyNative** mode:
+
+| model | w x h x f | cards | steps | NPU peak memory | s/video |
+|:------------:|:------------:|:------------:|:------------:|:------------:|:------------:|
+| T2V-1.3B | 832x480x81 | 1 | 50 | 21GB | ~235 |
+| T2V-14B | 1280x720x81 | 1 | 50 | 52.2GB | ~4650 |
+| I2V-14B | 832x480x81 | 1 | 40 | 50GB | ~1150 |
+| I2V-14B | 1280x720x81 | 4 | 40 | 25GB | ~1000 |
+
+
+## Citation
+
+```
+@article{wan2.1,
+    title   = {Wan: Open and Advanced Large-Scale Video Generative Models},
+    author  = {Wan Team},
+    journal = {},
+    year    = {2025}
+}
+```
+
+## License Agreement
+The models in this repository are licensed under the Apache 2.0 License. We claim no rights over the content you generate, granting you the freedom to use it while ensuring that your usage complies with the provisions of this license. You are fully accountable for your use of the models, which must not involve sharing any content that violates applicable laws, causes harm to individuals or groups, disseminates personal information intended for harm, spreads misinformation, or targets vulnerable populations. For a complete list of restrictions and details regarding your rights, please refer to the full text of the [license](LICENSE.txt).
+
+
+## Acknowledgements
+
+We would like to thank the contributors to the [Wan2.1](https://github.com/Wan-Video/Wan2.1), [SD3](https://huggingface.co/stabilityai/stable-diffusion-3-medium), [Qwen](https://huggingface.co/Qwen), [umt5-xxl](https://huggingface.co/google/umt5-xxl), [diffusers](https://github.com/huggingface/diffusers) and [HuggingFace](https://huggingface.co) repositories for their open research.
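
As a supplement to the CLI usage above, the same pipelines can be driven directly from Python. The sketch below mirrors the calls made in `generate.py` for the single-card 1.3B T2V path; it is not a drop-in replacement, the checkpoint path, seed, and sampling values are only examples, and it assumes `mindone` and the local `wan` package are importable (e.g. run from `examples/wan2_1` with the `mindone` repo root on `PYTHONPATH`, as `generate.py` arranges).

```python
# Minimal programmatic T2V sketch mirroring generate.py (values are examples).
import wan
from wan.configs import SIZE_CONFIGS, WAN_CONFIGS
from wan.utils.utils import cache_video

cfg = WAN_CONFIGS["t2v-1.3B"]
pipe = wan.WanT2V(
    config=cfg,
    checkpoint_dir="./Wan2.1-T2V-1.3B",
    rank=0,
    t5_zero3=False,
    dit_zero3=False,
    use_usp=False,
    t5_cpu=False,
)

video = pipe.generate(
    "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
    size=SIZE_CONFIGS["832*480"],
    frame_num=81,
    shift=8.0,
    sample_solver="unipc",
    sampling_steps=50,
    guide_scale=6.0,
    seed=42,
    offload_model=False,
)

# Save the result the same way generate.py does.
cache_video(tensor=video[None], save_file="t2v_1.3B_demo.mp4", fps=cfg.sample_fps,
            nrow=1, normalize=True, value_range=(-1, 1))
```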
diff --git a/examples/wan2_1/examples/i2v_input.JPG b/examples/wan2_1/examples/i2v_input.JPG new file mode 100644 index 0000000000..6f6dcced59 Binary files /dev/null and b/examples/wan2_1/examples/i2v_input.JPG differ diff --git a/examples/wan2_1/generate.py b/examples/wan2_1/generate.py new file mode 100644 index 0000000000..a377f3a820 --- /dev/null +++ b/examples/wan2_1/generate.py @@ -0,0 +1,302 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import argparse +import logging +import os +import random +import sys +from datetime import datetime + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +mindone_lib_path = os.path.abspath(os.path.join(__dir__, "../../")) +sys.path.insert(0, mindone_lib_path) + +import wan +from PIL import Image +from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS +from wan.utils.utils import cache_image, cache_video, str2bool + +import mindspore as ms +import mindspore.mint.distributed as dist +from mindspore.communication import GlobalComm + +EXAMPLE_PROMPT = { + "t2v-1.3B": { + "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + }, + "t2v-14B": { + "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + }, + "t2i-14B": { + "prompt": "一个朴素端庄的美人", + }, + "i2v-14B": { + "prompt": "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. " + "The fluffy-furred feline gazes directly at the camera with a relaxed expression. " + "Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. " + "The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. " + "A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.", + "image": "examples/i2v_input.JPG", + }, +} + + +def _validate_args(args): + # Basic check + assert args.ckpt_dir is not None, "Please specify the checkpoint directory." + assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}" + assert args.task in EXAMPLE_PROMPT, f"Unsupport task: {args.task}" + + # The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks. + if args.sample_steps is None: + args.sample_steps = 40 if "i2v" in args.task else 50 + + if args.sample_shift is None: + args.sample_shift = 5.0 + if "i2v" in args.task and args.size in ["832*480", "480*832"]: + args.sample_shift = 3.0 + + # The default number of frames are 1 for text-to-image tasks and 81 for other tasks. + if args.frame_num is None: + args.frame_num = 1 if "t2i" in args.task else 81 + + # T2I frame_num check + if "t2i" in args.task: + assert args.frame_num == 1, f"Unsupport frame_num {args.frame_num} for task {args.task}" + + args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(0, sys.maxsize) + # Size check + assert ( + args.size in SUPPORTED_SIZES[args.task] + ), f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}" + + +def _parse_args(): + parser = argparse.ArgumentParser(description="Generate a image or video from a text prompt or image using Wan") + parser.add_argument( + "--task", type=str, default="t2v-14B", choices=list(WAN_CONFIGS.keys()), help="The task to run." 
+ ) + parser.add_argument( + "--size", + type=str, + default="1280*720", + choices=list(SIZE_CONFIGS.keys()), + help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image.", + ) + parser.add_argument( + "--frame_num", + type=int, + default=None, + help="How many frames to sample from a image or video. The number should be 4n+1", + ) + parser.add_argument("--ckpt_dir", type=str, default=None, help="The path to the checkpoint directory.") + parser.add_argument( + "--offload_model", + type=str2bool, + default=None, + help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage.", + ) + parser.add_argument("--ulysses_size", type=int, default=1, help="The size of the ulysses parallelism in DiT.") + parser.add_argument("--ring_size", type=int, default=1, help="The size of the ring attention parallelism in DiT.") + parser.add_argument("--t5_zero3", action="store_true", default=False, help="Whether to use ZeRO3 for T5.") + parser.add_argument("--t5_cpu", action="store_true", default=False, help="Whether to place T5 model on CPU.") + parser.add_argument("--dit_zero3", action="store_true", default=False, help="Whether to use ZeRO3 for DiT.") + parser.add_argument("--save_file", type=str, default=None, help="The file to save the generated image or video to.") + parser.add_argument("--prompt", type=str, default=None, help="The prompt to generate the image or video from.") + parser.add_argument("--use_prompt_extend", action="store_true", default=False, help="Whether to use prompt extend.") + parser.add_argument( + "--prompt_extend_method", + type=str, + default="local_qwen", + choices=["dashscope", "local_qwen"], + help="The prompt extend method to use.", + ) + parser.add_argument("--prompt_extend_model", type=str, default=None, help="The prompt extend model to use.") + parser.add_argument( + "--prompt_extend_target_lang", + type=str, + default="ch", + choices=["ch", "en"], + help="The target language of prompt extend.", + ) + parser.add_argument("--base_seed", type=int, default=0, help="The seed to use for generating the image or video.") + parser.add_argument("--image", type=str, default=None, help="The image to generate the video from.") + parser.add_argument( + "--sample_solver", type=str, default="unipc", choices=["unipc", "dpm++"], help="The solver used to sample." + ) + parser.add_argument("--sample_steps", type=int, default=None, help="The sampling steps.") + parser.add_argument( + "--sample_shift", type=float, default=None, help="Sampling shift factor for flow matching schedulers." 
+ ) + parser.add_argument("--sample_guide_scale", type=float, default=5.0, help="Classifier free guidance scale.") + + # extra for mindspore + parser.add_argument("--ulysses_sp", action="store_true", default=False, help="turn on ulysses parallelism in DiT.") + + args = parser.parse_args() + + _validate_args(args) + + return args + + +def _init_logging(rank): + # logging + if rank == 0: + # set format + logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s] %(levelname)s: %(message)s", + handlers=[logging.StreamHandler(stream=sys.stdout)], + ) + else: + logging.basicConfig(level=logging.ERROR) + + +def generate(args): + if args.ulysses_sp or args.t5_zero3 or args.dit_zero3: + dist.init_process_group(backend="hccl") + ms.set_auto_parallel_context(parallel_mode="data_parallel") + + rank = dist.get_rank() + world_size = dist.get_world_size() + + if args.ulysses_sp: + args.ulysses_size = world_size + else: + assert not ( + args.t5_zero3 or args.dit_zero3 + ), "t5_zero3 and dit_zero3 are not supported in non-distributed environments." + assert not ( + args.ulysses_size > 1 or args.ring_size > 1 + ), "context parallel are not supported in non-distributed environments." + rank = 0 + world_size = 1 + + _init_logging(rank) + + if args.offload_model is None: + args.offload_model = False + logging.info(f"offload_model is not specified, set to {args.offload_model}.") + + if args.ulysses_size > 1 or args.ring_size > 1: + assert ( + args.ulysses_size * args.ring_size == world_size + ), "The number of ulysses_size and ring_size should be equal to the world size." + + if args.use_prompt_extend: + raise NotImplementedError("prompt_extend is not supported") + + cfg = WAN_CONFIGS[args.task] + if args.ulysses_size > 1: + assert cfg.num_heads % args.ulysses_size == 0, "`num_heads` must be divisible by `ulysses_size`." 
+ + logging.info(f"Generation job args: {args}") + logging.info(f"Generation model config: {cfg}") + + # TODO: GlobalComm.INITED -> mint.is_initialzed + if GlobalComm.INITED: + base_seed = [args.base_seed] if rank == 0 else [None] + dist.broadcast_object_list(base_seed, src=0) + args.base_seed = base_seed[0] + + if "t2v" in args.task or "t2i" in args.task: + if args.prompt is None: + args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] + logging.info(f"Input prompt: {args.prompt}") + if args.use_prompt_extend: + raise NotImplementedError + + logging.info("Creating WanT2V pipeline.") + wan_t2v = wan.WanT2V( + config=cfg, + checkpoint_dir=args.ckpt_dir, + rank=rank, + t5_zero3=args.t5_zero3, + dit_zero3=args.dit_zero3, + use_usp=args.ulysses_sp, + t5_cpu=args.t5_cpu, + ) + + logging.info(f"Generating {'image' if 't2i' in args.task else 'video'} ...") + video = wan_t2v.generate( + args.prompt, + size=SIZE_CONFIGS[args.size], + frame_num=args.frame_num, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=args.sample_steps, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model, + ) + + else: + if args.prompt is None: + args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] + if args.image is None: + args.image = EXAMPLE_PROMPT[args.task]["image"] + logging.info(f"Input prompt: {args.prompt}") + logging.info(f"Input image: {args.image}") + + img = Image.open(args.image).convert("RGB") + if args.use_prompt_extend: + raise NotImplementedError + + logging.info("Creating WanI2V pipeline.") + wan_i2v = wan.WanI2V( + config=cfg, + checkpoint_dir=args.ckpt_dir, + rank=rank, + t5_zero3=args.t5_zero3, + dit_zero3=args.dit_zero3, + use_usp=args.ulysses_sp, + t5_cpu=args.t5_cpu, + ) + + logging.info("Generating video ...") + video = wan_i2v.generate( + args.prompt, + img, + max_area=MAX_AREA_CONFIGS[args.size], + frame_num=args.frame_num, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=args.sample_steps, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model, + ) + + if rank == 0: + if args.save_file is None: + formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S") + formatted_prompt = args.prompt.replace(" ", "_").replace("/", "_")[:50] + suffix = ".png" if "t2i" in args.task else ".mp4" + args.save_file = ( + f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + + suffix + ) + + if "t2i" in args.task: + logging.info(f"Saving generated image to {args.save_file}") + cache_image( + tensor=video.squeeze(1)[None], save_file=args.save_file, nrow=1, normalize=True, value_range=(-1, 1) + ) + else: + logging.info(f"Saving generated video to {args.save_file}") + cache_video( + tensor=video[None], + save_file=args.save_file, + fps=cfg.sample_fps, + nrow=1, + normalize=True, + value_range=(-1, 1), + ) + logging.info("Finished.") + + +if __name__ == "__main__": + args = _parse_args() + # TODO: remove global seed + ms.set_seed(args.base_seed) + generate(args) diff --git a/examples/wan2_1/requirements.txt b/examples/wan2_1/requirements.txt new file mode 100644 index 0000000000..34c2419b20 --- /dev/null +++ b/examples/wan2_1/requirements.txt @@ -0,0 +1,12 @@ +opencv-python>=4.9.0.80 +transformers>=4.49.0 +tokenizers>=0.20.3 +torch +tqdm +imageio +easydict +ftfy +dashscope +imageio-ffmpeg +gradio>=5.0.0 +numpy>=1.23.5,<2 diff --git a/examples/wan2_1/wan/__init__.py b/examples/wan2_1/wan/__init__.py new file mode 100644 index 
0000000000..3196bf11a6 --- /dev/null +++ b/examples/wan2_1/wan/__init__.py @@ -0,0 +1,2 @@ +from .image2video import WanI2V +from .text2video import WanT2V diff --git a/examples/wan2_1/wan/acceleration/__init__.py b/examples/wan2_1/wan/acceleration/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/wan2_1/wan/acceleration/communications.py b/examples/wan2_1/wan/acceleration/communications.py new file mode 100644 index 0000000000..06769bf331 --- /dev/null +++ b/examples/wan2_1/wan/acceleration/communications.py @@ -0,0 +1,95 @@ +from typing import Callable, Literal, Tuple + +import mindspore.mint as mint +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor +from mindspore.communication import GlobalComm, get_group_size, get_rank + + +def _split(x: Tensor, dim: int, rank: int, world_size: int) -> Tensor: + dim_size = x.shape[dim] + tensor_list = x.split(dim_size // world_size, axis=dim) + x = tensor_list[rank] + return x + + +def _communicate_along_dim(x: Tensor, dim: int, func: Callable[[Tensor], Tensor]) -> Tensor: + x = x.swapaxes(0, dim) + x = func(x) + x = x.swapaxes(dim, 0) + return x + + +class SplitForwardGatherBackward(nn.Cell): + def __init__( + self, dim: int = 0, grad_scale: Literal["up", "down"] = "down", group: str = GlobalComm.WORLD_COMM_GROUP + ) -> None: + super().__init__() + self.dim = dim + self.rank = get_rank(group) + self.world_size = get_group_size(group) + self.gather = ops.AllGather(group=group) + + if grad_scale == "up": + self.scale = self.world_size + else: + self.scale = 1 / self.world_size + + def construct(self, x: Tensor) -> Tensor: + return _split(x, self.dim, self.rank, self.world_size) + + def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: + dout = dout * self.scale + dout = _communicate_along_dim(dout, self.dim, self.gather) + return (dout,) + + +class GatherForwardSplitBackward(nn.Cell): + def __init__( + self, dim: int = 0, grad_scale: Literal["up", "down"] = "up", group: str = GlobalComm.WORLD_COMM_GROUP + ) -> None: + super().__init__() + self.dim = dim + self.rank = get_rank(group) + self.world_size = get_group_size(group) + self.gather = ops.AllGather(group=group) + + if grad_scale == "up": + self.scale = self.world_size + else: + self.scale = 1 / self.world_size + + def construct(self, x: Tensor) -> Tensor: + x = _communicate_along_dim(x, self.dim, self.gather) + return x + + def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: + dout = dout * self.scale + dout = _split(dout, self.dim, self.rank, self.world_size) + return (dout,) + + +class AlltoAll(nn.Cell): + def __init__(self, split_dim: int = 2, concat_dim: int = 1, group: str = GlobalComm.WORLD_COMM_GROUP) -> None: + super().__init__() + assert split_dim >= 0 and concat_dim >= 0 + self.split_dim = split_dim + self.concat_dim = concat_dim + self.group = group + + @staticmethod + def _all_to_all(x: Tensor, split_dim: int, concat_dim: int, group: str = GlobalComm.WORLD_COMM_GROUP): + world_size = get_group_size(group) + input_list = list(mint.chunk(x, world_size, dim=split_dim)) + output_list = [mint.empty_like(input_list[0]) for _ in range(world_size)] + mint.distributed.all_to_all(output_list, input_list, group=group) + return mint.cat(output_list, dim=concat_dim) + + def construct(self, x: Tensor) -> Tensor: + x = self._all_to_all(x, self.split_dim, self.concat_dim, group=self.group) + return x + + def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: + dout = 
self._all_to_all(dout, self.concat_dim, self.split_dim, group=self.group) + return (dout,) diff --git a/examples/wan2_1/wan/acceleration/parallel_states.py b/examples/wan2_1/wan/acceleration/parallel_states.py new file mode 100644 index 0000000000..b9384cac62 --- /dev/null +++ b/examples/wan2_1/wan/acceleration/parallel_states.py @@ -0,0 +1,35 @@ +from typing import Optional + +from mindspore.communication import create_group, get_group_size, get_rank + +_GLOBAL_PARALLEL_GROUPS = dict() + + +def set_sequence_parallel_group(group: str) -> None: + _GLOBAL_PARALLEL_GROUPS["sequence"] = group + + +def get_sequence_parallel_group() -> Optional[str]: + return _GLOBAL_PARALLEL_GROUPS.get("sequence", None) + + +def create_parallel_group(sequence_parallel_shards: int) -> None: + if sequence_parallel_shards <= 1: + raise ValueError( + f"`sequence_parallel_shards` must be larger than 1 to enable sequence parallel, but get `{sequence_parallel_shards}`." + ) + + device_num = get_group_size() + if device_num % sequence_parallel_shards != 0: + raise ValueError( + f"Total number of devices `{device_num}` must be divisible by the number of sequence parallel shards `{sequence_parallel_shards}`." + ) + + rank_id = get_rank() + sp_group_id = rank_id // sequence_parallel_shards + sp_group_rank_ids = list( + range(sp_group_id * sequence_parallel_shards, (sp_group_id + 1) * sequence_parallel_shards) + ) + sp_group_name = f"sp_group_{sp_group_id}" + create_group(sp_group_name, sp_group_rank_ids) + set_sequence_parallel_group(sp_group_name) diff --git a/examples/wan2_1/wan/configs/__init__.py b/examples/wan2_1/wan/configs/__init__.py new file mode 100644 index 0000000000..78cf59fd12 --- /dev/null +++ b/examples/wan2_1/wan/configs/__init__.py @@ -0,0 +1,43 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import copy +import os + +from .wan_i2v_14B import i2v_14B +from .wan_t2v_1_3B import t2v_1_3B +from .wan_t2v_14B import t2v_14B + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +# the config of t2i_14B is the same as t2v_14B +t2i_14B = copy.deepcopy(t2v_14B) +t2i_14B.__name__ = "Config: Wan T2I 14B" + +WAN_CONFIGS = { + "t2v-14B": t2v_14B, + "t2v-1.3B": t2v_1_3B, + "i2v-14B": i2v_14B, + "t2i-14B": t2i_14B, +} + +SIZE_CONFIGS = { + "720*1280": (720, 1280), + "1280*720": (1280, 720), + "480*832": (480, 832), + "832*480": (832, 480), + "1024*1024": (1024, 1024), +} + +MAX_AREA_CONFIGS = { + "720*1280": 720 * 1280, + "1280*720": 1280 * 720, + "480*832": 480 * 832, + "832*480": 832 * 480, +} + +SUPPORTED_SIZES = { + "t2v-14B": ("720*1280", "1280*720", "480*832", "832*480"), + "t2v-1.3B": ("480*832", "832*480"), + "i2v-14B": ("720*1280", "1280*720", "480*832", "832*480"), + "t2i-14B": tuple(SIZE_CONFIGS.keys()), +} diff --git a/examples/wan2_1/wan/configs/shared_config.py b/examples/wan2_1/wan/configs/shared_config.py new file mode 100644 index 0000000000..cdadb4697c --- /dev/null +++ b/examples/wan2_1/wan/configs/shared_config.py @@ -0,0 +1,23 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
+from easydict import EasyDict + +import mindspore as ms + +# ------------------------ Wan shared config ------------------------# +wan_shared_cfg = EasyDict() + +# t5 +wan_shared_cfg.t5_model = "umt5_xxl" +wan_shared_cfg.t5_dtype = ms.bfloat16 +wan_shared_cfg.text_len = 512 + +# transformer +wan_shared_cfg.param_dtype = ms.bfloat16 + +# inference +wan_shared_cfg.num_train_timesteps = 1000 +wan_shared_cfg.sample_fps = 16 +wan_shared_cfg.sample_neg_prompt = ( + "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留," + "丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" +) diff --git a/examples/wan2_1/wan/configs/wan_i2v_14B.py b/examples/wan2_1/wan/configs/wan_i2v_14B.py new file mode 100644 index 0000000000..221e134ba6 --- /dev/null +++ b/examples/wan2_1/wan/configs/wan_i2v_14B.py @@ -0,0 +1,36 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +from easydict import EasyDict + +import mindspore as ms + +from .shared_config import wan_shared_cfg + +# ------------------------ Wan I2V 14B ------------------------# + +i2v_14B = EasyDict(__name__="Config: Wan I2V 14B") +i2v_14B.update(wan_shared_cfg) + +i2v_14B.t5_checkpoint = "models_t5_umt5-xxl-enc-bf16.pth" +i2v_14B.t5_tokenizer = "google/umt5-xxl" + +# clip +i2v_14B.clip_model = "clip_xlm_roberta_vit_h_14" +i2v_14B.clip_dtype = ms.float16 +i2v_14B.clip_checkpoint = "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" +i2v_14B.clip_tokenizer = "xlm-roberta-large" + +# vae +i2v_14B.vae_checkpoint = "Wan2.1_VAE.pth" +i2v_14B.vae_stride = (4, 8, 8) + +# transformer +i2v_14B.patch_size = (1, 2, 2) +i2v_14B.dim = 5120 +i2v_14B.ffn_dim = 13824 +i2v_14B.freq_dim = 256 +i2v_14B.num_heads = 40 +i2v_14B.num_layers = 40 +i2v_14B.window_size = (-1, -1) +i2v_14B.qk_norm = True +i2v_14B.cross_attn_norm = True +i2v_14B.eps = 1e-6 diff --git a/examples/wan2_1/wan/configs/wan_t2v_14B.py b/examples/wan2_1/wan/configs/wan_t2v_14B.py new file mode 100644 index 0000000000..ac3ae0161a --- /dev/null +++ b/examples/wan2_1/wan/configs/wan_t2v_14B.py @@ -0,0 +1,29 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +from easydict import EasyDict + +from .shared_config import wan_shared_cfg + +# ------------------------ Wan T2V 14B ------------------------# + +t2v_14B = EasyDict(__name__="Config: Wan T2V 14B") +t2v_14B.update(wan_shared_cfg) + +# t5 +t2v_14B.t5_checkpoint = "models_t5_umt5-xxl-enc-bf16.pth" +t2v_14B.t5_tokenizer = "google/umt5-xxl" + +# vae +t2v_14B.vae_checkpoint = "Wan2.1_VAE.pth" +t2v_14B.vae_stride = (4, 8, 8) + +# transformer +t2v_14B.patch_size = (1, 2, 2) +t2v_14B.dim = 5120 +t2v_14B.ffn_dim = 13824 +t2v_14B.freq_dim = 256 +t2v_14B.num_heads = 40 +t2v_14B.num_layers = 40 +t2v_14B.window_size = (-1, -1) +t2v_14B.qk_norm = True +t2v_14B.cross_attn_norm = True +t2v_14B.eps = 1e-6 diff --git a/examples/wan2_1/wan/configs/wan_t2v_1_3B.py b/examples/wan2_1/wan/configs/wan_t2v_1_3B.py new file mode 100644 index 0000000000..63d0a037ea --- /dev/null +++ b/examples/wan2_1/wan/configs/wan_t2v_1_3B.py @@ -0,0 +1,29 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
+from easydict import EasyDict + +from .shared_config import wan_shared_cfg + +# ------------------------ Wan T2V 1.3B ------------------------# + +t2v_1_3B = EasyDict(__name__="Config: Wan T2V 1.3B") +t2v_1_3B.update(wan_shared_cfg) + +# t5 +t2v_1_3B.t5_checkpoint = "models_t5_umt5-xxl-enc-bf16.pth" +t2v_1_3B.t5_tokenizer = "google/umt5-xxl" + +# vae +t2v_1_3B.vae_checkpoint = "Wan2.1_VAE.pth" +t2v_1_3B.vae_stride = (4, 8, 8) + +# transformer +t2v_1_3B.patch_size = (1, 2, 2) +t2v_1_3B.dim = 1536 +t2v_1_3B.ffn_dim = 8960 +t2v_1_3B.freq_dim = 256 +t2v_1_3B.num_heads = 12 +t2v_1_3B.num_layers = 30 +t2v_1_3B.window_size = (-1, -1) +t2v_1_3B.qk_norm = True +t2v_1_3B.cross_attn_norm = True +t2v_1_3B.eps = 1e-6 diff --git a/examples/wan2_1/wan/image2video.py b/examples/wan2_1/wan/image2video.py new file mode 100644 index 0000000000..6b3035e7d1 --- /dev/null +++ b/examples/wan2_1/wan/image2video.py @@ -0,0 +1,289 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import logging +import math +import os +import random +import sys +from functools import partial + +import numpy as np +from tqdm import tqdm + +import mindspore as ms +import mindspore.mint as mint +import mindspore.mint.distributed as dist +import mindspore.mint.nn.functional as functional +from mindspore.communication import GlobalComm, get_group_size +from mindspore.nn.utils import no_init_parameters + +from mindone.trainers.zero import prepare_network + +from .acceleration.parallel_states import create_parallel_group +from .modules.clip import CLIPModel +from .modules.model import WanModel +from .modules.t5 import T5EncoderModel +from .modules.vae import WanVAE +from .utils.fm_solvers import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps +from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from .utils.utils import pil2tensor + + +class WanI2V: + def __init__( + self, + config, + checkpoint_dir, + rank=0, + t5_zero3=False, + dit_zero3=False, + use_usp=False, + t5_cpu=False, + ): + r""" + Initializes the image-to-video generation model components. + + Args: + config (EasyDict): + Object containing model parameters initialized from config.py + checkpoint_dir (`str`): + Path to directory containing model checkpoints + rank (`int`, *optional*, defaults to 0): + Process rank for distributed training + t5_zero3 (`bool`, *optional*, defaults to False): + Enable ZeRO3 sharding for T5 model + dit_zero3 (`bool`, *optional*, defaults to False): + Enable ZeRO3 sharding for DiT model + use_usp (`bool`, *optional*, defaults to False): + Enable distribution strategy of USP. + t5_cpu (`bool`, *optional*, defaults to False): + Whether to place T5 model on CPU. Only works without t5_zero3. 
+ """ + self.config = config + self.rank = rank + self.use_usp = use_usp + self.t5_cpu = t5_cpu + + self.num_train_timesteps = config.num_train_timesteps + self.param_dtype = config.param_dtype + + if use_usp: + self.sp_size = get_group_size(GlobalComm.WORLD_COMM_GROUP) + create_parallel_group(self.sp_size) + else: + self.sp_size = 1 + + shard_fn = partial(prepare_network, zero_stage=3, optimizer_parallel_group=GlobalComm.WORLD_COMM_GROUP) + self.text_encoder = T5EncoderModel( + text_len=config.text_len, + dtype=config.t5_dtype, + checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), + tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), + shard_fn=shard_fn if t5_zero3 else None, + ) + + self.vae_stride = config.vae_stride + self.patch_size = config.patch_size + self.vae = WanVAE(vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype=self.param_dtype) + + self.clip = CLIPModel( + dtype=config.clip_dtype, + checkpoint_path=os.path.join(checkpoint_dir, config.clip_checkpoint), + tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer), + ) + + logging.info(f"Creating WanModel from {checkpoint_dir}") + with no_init_parameters(): + self.model = WanModel.from_pretrained(checkpoint_dir, mindspore_dtype=self.param_dtype) + self.model.init_parameters_data() + self.model.set_train(False) + for param in self.model.trainable_params(): + param.requires_grad = False + + # TODO: GlobalComm.INITED -> mint.is_initialzed + if GlobalComm.INITED: + dist.barrier() + if dit_zero3: + self.model = shard_fn(self.model) + + self.sample_neg_prompt = config.sample_neg_prompt + + def generate( + self, + input_prompt, + img, + max_area=720 * 1280, + frame_num=81, + shift=5.0, + sample_solver="unipc", + sampling_steps=40, + guide_scale=5.0, + n_prompt="", + seed=-1, + offload_model=True, + ): + r""" + Generates video frames from input image and text prompt using diffusion process. + + Args: + input_prompt (`str`): + Text prompt for content generation. + img (PIL.Image.Image): + Input image tensor. Shape: [3, H, W] + max_area (`int`, *optional*, defaults to 720*1280): + Maximum pixel area for latent space calculation. Controls video resolution scaling + frame_num (`int`, *optional*, defaults to 81): + How many frames to sample from a video. The number should be 4n+1 + shift (`float`, *optional*, defaults to 5.0): + Noise schedule shift parameter. Affects temporal dynamics + [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0. + sample_solver (`str`, *optional*, defaults to 'unipc'): + Solver used to sample the video. + sampling_steps (`int`, *optional*, defaults to 40): + Number of diffusion sampling steps. Higher values improve quality but slow generation + guide_scale (`float`, *optional*, defaults 5.0): + Classifier-free guidance scale. Controls prompt adherence vs. creativity + n_prompt (`str`, *optional*, defaults to ""): + Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` + seed (`int`, *optional*, defaults to -1): + Random seed for noise generation. If -1, use random seed + offload_model (`bool`, *optional*, defaults to True): + If True, offloads models to CPU during generation to save VRAM + + Returns: + mindspore.Tensor: + Generated video frames tensor. 
Dimensions: (C, N H, W) where: + - C: Color channels (3 for RGB) + - N: Number of frames (81) + - H: Frame height (from max_area) + - W: Frame width from max_area) + """ + img = pil2tensor(img).sub_(0.5).div_(0.5) + + F = frame_num + h, w = img.shape[1:] + aspect_ratio = h / w + lat_h = round(np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] // self.patch_size[1] * self.patch_size[1]) + lat_w = round(np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] // self.patch_size[2] * self.patch_size[2]) + h = lat_h * self.vae_stride[1] + w = lat_w * self.vae_stride[2] + + max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (self.patch_size[1] * self.patch_size[2]) + max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size + + seed = seed if seed >= 0 else random.randint(0, sys.maxsize) + seed_g = ms.Generator() + seed_g.manual_seed(seed) + noise = mint.randn(16, 21, lat_h, lat_w, dtype=self.param_dtype, generator=seed_g) + + msk = mint.ones((1, 81, lat_h, lat_w)) + msk[:, 1:] = 0 + msk = mint.concat([mint.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.transpose(1, 2)[0] + + if n_prompt == "": + n_prompt = self.sample_neg_prompt + + # preprocess + if not self.t5_cpu: + context = self.text_encoder([input_prompt]) + context_null = self.text_encoder([n_prompt]) + if offload_model: + raise NotImplementedError() + else: + raise NotImplementedError() + + clip_context = self.clip.visual([img[:, None, :, :]]).to(self.param_dtype) + if offload_model: + raise NotImplementedError() + + y = self.vae.encode( + [ + mint.concat( + [ + functional.interpolate(img[None], size=(h, w), mode="bicubic") + .transpose(0, 1) + .to(self.param_dtype), + mint.zeros((3, 80, h, w), dtype=self.param_dtype), + ], + dim=1, + ) + ] + )[0] + y = mint.concat([msk, y]).to(self.param_dtype) + + # evaluation mode + if sample_solver == "unipc": + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, shift=1, use_dynamic_shifting=False + ) + sample_scheduler.set_timesteps(sampling_steps, shift=shift) + timesteps = sample_scheduler.timesteps + elif sample_solver == "dpm++": + sample_scheduler = FlowDPMSolverMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, shift=1, use_dynamic_shifting=False + ) + sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) + timesteps, _ = retrieve_timesteps(sample_scheduler, sigmas=sampling_sigmas) + else: + raise NotImplementedError("Unsupported solver.") + + # sample videos + latent = noise + + arg_c = { + "context": [context[0]], + "clip_fea": clip_context, + "seq_len": max_seq_len, + "y": [y], + } + + arg_null = { + "context": context_null, + "clip_fea": clip_context, + "seq_len": max_seq_len, + "y": [y], + } + + if offload_model: + raise NotImplementedError() + + for _, t in enumerate(tqdm(timesteps)): + latent_model_input = [latent] + timestep = [t] + + timestep = mint.stack(timestep) + + noise_pred_cond = self.model(latent_model_input, t=timestep, **arg_c)[0] + if offload_model: + raise NotImplementedError() + noise_pred_uncond = self.model(latent_model_input, t=timestep, **arg_null)[0] + if offload_model: + raise NotImplementedError() + noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond) + + temp_x0 = sample_scheduler.step( + noise_pred.unsqueeze(0), t, latent.unsqueeze(0), return_dict=False, generator=seed_g + )[0] + latent = temp_x0.squeeze(0) + + x0 = [latent] + del 
latent_model_input, timestep + + if offload_model: + raise NotImplementedError() + + if self.rank == 0: + # TODO: handle this + # np.save("latent.npy", x0[0].to(ms.float32).asnumpy()) + videos = self.vae.decode(x0) + + del noise, latent + del sample_scheduler + if offload_model: + raise NotImplementedError() + # TODO: GlobalComm.INITED -> mint.is_initialzed + if GlobalComm.INITED: + dist.barrier() + + return videos[0] if self.rank == 0 else None diff --git a/examples/wan2_1/wan/modules/__init__.py b/examples/wan2_1/wan/modules/__init__.py new file mode 100644 index 0000000000..8b86ac828a --- /dev/null +++ b/examples/wan2_1/wan/modules/__init__.py @@ -0,0 +1,14 @@ +from .model import WanModel +from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model +from .tokenizers import HuggingfaceTokenizer +from .vae import WanVAE + +__all__ = [ + "WanVAE", + "WanModel", + "T5Model", + "T5Encoder", + "T5Decoder", + "T5EncoderModel", + "HuggingfaceTokenizer", +] diff --git a/examples/wan2_1/wan/modules/clip.py b/examples/wan2_1/wan/modules/clip.py new file mode 100644 index 0000000000..fba9e2f874 --- /dev/null +++ b/examples/wan2_1/wan/modules/clip.py @@ -0,0 +1,578 @@ +# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip'' +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import logging +import math +from typing import List, Tuple + +import numpy as np + +import mindspore as ms +import mindspore.dataset.vision as vision +import mindspore.mint as mint +import mindspore.mint.nn.functional as F +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Parameter, Tensor +from mindspore.dataset.transforms import Compose +from mindspore.nn.utils import no_init_parameters + +from ..utils.utils import load_pth +from .tokenizers import HuggingfaceTokenizer +from .xlm_roberta import XLMRoberta + +__all__ = [ + "XLMRobertaCLIP", + "clip_xlm_roberta_vit_h_14", + "CLIPModel", +] + + +def pos_interpolate(pos: Tensor, seq_len: int) -> Tensor: + if pos.shape[1] == seq_len: + return pos + else: + src_grid = int(math.sqrt(pos.shape[1])) + tar_grid = int(math.sqrt(seq_len)) + n = pos.shape[1] - src_grid * src_grid + return mint.cat( + [ + pos[:, :n], + F.interpolate( + pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(0, 3, 1, 2), + size=(tar_grid, tar_grid), + mode="bicubic", + align_corners=False, + ) + .flatten(2) + .transpose(1, 2), + ], + dim=1, + ) + + +class QuickGELU(nn.Cell): + def construct(self, x: Tensor) -> Tensor: + return x * mint.sigmoid(1.702 * x) + + +class LayerNorm(mint.nn.LayerNorm): + # TODO: to float32 + def construct(self, x: Tensor) -> Tensor: + return super().construct(x).type_as(x) + + +class SelfAttention(nn.Cell): + def __init__( + self, + dim: int, + num_heads: int, + causal: bool = False, + attn_dropout: float = 0.0, + proj_dropout: float = 0.0, + dtype: ms.Type = ms.float32, + ) -> None: + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.causal = causal + self.attn_dropout = attn_dropout + self.proj_dropout = proj_dropout + + # layers + self.to_qkv = mint.nn.Linear(dim, dim * 3, dtype=dtype) + self.proj = mint.nn.Linear(dim, dim, dtype=dtype) + + def construct(self, x: Tensor) -> Tensor: + """ + x: [B, L, C]. 
+ """ + b, s, c, n, d = *x.shape, self.num_heads, self.head_dim + + # compute query, key, value + q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2) + + # compute attention + p = self.attn_dropout if self.training else 0.0 + x = ops.flash_attention_score( + query=q, + key=k, + value=v, + head_num=self.num_heads, + keep_prob=1.0 - p, + input_layout="BSND", + ) + x = x.reshape(b, s, c) + + # output + x = self.proj(x) + x = F.dropout(x, self.proj_dropout, self.training) + return x + + +class SwiGLU(nn.Cell): + def __init__(self, dim: int, mid_dim: int, dtype: ms.Type = ms.float32) -> None: + super().__init__() + self.dim = dim + self.mid_dim = mid_dim + + # layers + self.fc1 = mint.nn.Linear(dim, mid_dim, dtype=dtype) + self.fc2 = mint.nn.Linear(dim, mid_dim, dtype=dtype) + self.fc3 = mint.nn.Linear(mid_dim, dim, dtype=dtype) + + def construct(self, x: Tensor) -> Tensor: + x = F.silu(self.fc1(x)) * self.fc2(x) + x = self.fc3(x) + return x + + +class AttentionBlock(nn.Cell): + def __init__( + self, + dim: int, + mlp_ratio: float, + num_heads: int, + post_norm: bool = False, + causal: bool = False, + activation: str = "quick_gelu", + attn_dropout: str = 0.0, + proj_dropout: str = 0.0, + norm_eps: str = 1e-5, + dtype: ms.Type = ms.float32, + ) -> None: + assert activation in ["quick_gelu", "gelu", "swi_glu"] + super().__init__() + self.dim = dim + self.mlp_ratio = mlp_ratio + self.num_heads = num_heads + self.post_norm = post_norm + self.causal = causal + self.norm_eps = norm_eps + + # layers + self.norm1 = LayerNorm(dim, eps=norm_eps, dtype=dtype) + self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout, dtype=dtype) + self.norm2 = LayerNorm(dim, eps=norm_eps, dtype=dtype) + if activation == "swi_glu": + self.mlp = SwiGLU(dim, int(dim * mlp_ratio), dtype=dtype) + else: + self.mlp = nn.SequentialCell( + mint.nn.Linear(dim, int(dim * mlp_ratio), dtype=dtype), + QuickGELU() if activation == "quick_gelu" else mint.nn.GELU(), + mint.nn.Linear(int(dim * mlp_ratio), dim, dtype=dtype), + mint.nn.Dropout(proj_dropout), + ) + + def construct(self, x: Tensor) -> Tensor: + if self.post_norm: + x = x + self.norm1(self.attn(x)) + x = x + self.norm2(self.mlp(x)) + else: + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + + +class AttentionPool(nn.Cell): + def __init__( + self, + dim: int, + mlp_ratio: float, + num_heads: int, + activation: str = "gelu", + proj_dropout: float = 0.0, + norm_eps: float = 1e-5, + dtype: ms.Type = ms.float32, + ) -> None: + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.mlp_ratio = mlp_ratio + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.proj_dropout = proj_dropout + self.norm_eps = norm_eps + + # layers + gain = 1.0 / math.sqrt(dim) + self.cls_embedding = Parameter(Tensor(gain * np.random.randn(1, 1, dim), dtype=dtype)) + self.to_q = mint.nn.Linear(dim, dim, dtype=dtype) + self.to_kv = mint.nn.Linear(dim, dim * 2, dtype=dtype) + self.proj = mint.nn.Linear(dim, dim, dtype=dtype) + self.norm = LayerNorm(dim, eps=norm_eps, dtype=dtype) + self.mlp = nn.SequentialCell( + mint.nn.Linear(dim, int(dim * mlp_ratio), dtype=dtype), + QuickGELU() if activation == "quick_gelu" else nn.GELU(), + mint.nn.Linear(int(dim * mlp_ratio), dim, dtype=dtype), + mint.nn.Dropout(proj_dropout), + ) + + def construct(self, x: Tensor) -> Tensor: + """ + x: [B, L, C]. 
+ """ + b, s, c, n, d = *x.shape, self.num_heads, self.head_dim + + # compute query, key, value + q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand((b, -1, -1, -1)) + k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2) + + # compute attention + x = ops.flash_attention_score( + query=q, + key=k, + value=v, + head_num=self.num_heads, + input_layout="BSND", + ) + x = x.reshape(b, 1, c) + + # output + x = self.proj(x) + x = F.dropout(x, self.proj_dropout, self.training) + + # mlp + x = x + self.mlp(self.norm(x)) + return x[:, 0] + + +class VisionTransformer(nn.Cell): + def __init__( + self, + image_size: int = 224, + patch_size: int = 16, + dim: int = 768, + mlp_ratio: int = 4, + out_dim: int = 512, + num_heads: int = 12, + num_layers: int = 12, + pool_type: str = "token", + pre_norm: bool = True, + post_norm: bool = False, + activation: str = "quick_gelu", + attn_dropout: float = 0.0, + proj_dropout: float = 0.0, + embedding_dropout: float = 0.0, + norm_eps: float = 1e-5, + dtype: ms.Type = ms.float32, + ) -> None: + if image_size % patch_size != 0: + print("[WARNING] image_size is not divisible by patch_size", flush=True) + assert pool_type in ("token", "token_fc", "attn_pool") + out_dim = out_dim or dim + super().__init__() + self.image_size = image_size + self.patch_size = patch_size + self.num_patches = (image_size // patch_size) ** 2 + self.dim = dim + self.mlp_ratio = mlp_ratio + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.pool_type = pool_type + self.post_norm = post_norm + self.norm_eps = norm_eps + + # embeddings + gain = 1.0 / math.sqrt(dim) + self.patch_embedding = mint.nn.Conv2d( + 3, dim, kernel_size=patch_size, stride=patch_size, bias=not pre_norm, dtype=dtype + ) + if pool_type in ("token", "token_fc"): + self.cls_embedding = Parameter(Tensor(gain * np.random.randn(1, 1, dim), dtype=dtype)) + self.pos_embedding = Parameter( + Tensor( + gain * np.random.randn(1, self.num_patches + (1 if pool_type in ("token", "token_fc") else 0), dim), + dtype=dtype, + ) + ) + self.dropout = mint.nn.Dropout(embedding_dropout) + + # transformer + self.pre_norm = LayerNorm(dim, eps=norm_eps, dtype=dtype) if pre_norm else None + self.transformer = nn.SequentialCell( + *[ + AttentionBlock( + dim, + mlp_ratio, + num_heads, + post_norm, + False, + activation, + attn_dropout, + proj_dropout, + norm_eps, + dtype=dtype, + ) + for _ in range(num_layers) + ] + ) + self.post_norm = LayerNorm(dim, eps=norm_eps, dtype=dtype) + + # head + if pool_type == "token": + self.head = Parameter(Tensor(gain * np.random.randn(dim, out_dim), dtype=dtype)) + elif pool_type == "token_fc": + self.head = mint.nn.Linear(dim, out_dim, dtype=dtype) + elif pool_type == "attn_pool": + self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, proj_dropout, norm_eps, dtype=dtype) + + def construct(self, x: Tensor, interpolation: bool = False, use_31_block: bool = False) -> Tensor: + b = x.shape[0] + + # embeddings + x = self.patch_embedding(x).flatten(2).permute(0, 2, 1) + if self.pool_type in ("token", "token_fc"): + x = mint.cat([self.cls_embedding.expand((b, -1, -1)), x], dim=1) + if interpolation: + e = pos_interpolate(self.pos_embedding, x.shape[1]) + else: + e = self.pos_embedding + x = self.dropout(x + e) + if self.pre_norm is not None: + x = self.pre_norm(x) + + # transformer + if use_31_block: + x = self.transformer[:-1](x) + return x + else: + x = self.transformer(x) + return x + + +class XLMRobertaWithHead(XLMRoberta): + def __init__(self, dtype: ms.Type = 
ms.float32, **kwargs) -> None: + self.out_dim = kwargs.pop("out_dim") + super().__init__(dtype=dtype, **kwargs) + + # head + mid_dim = (self.dim + self.out_dim) // 2 + self.head = nn.SequentialCell( + mint.nn.Linear(self.dim, mid_dim, bias=False, dtype=dtype), + mint.nn.GELU(), + mint.nn.Linear(mid_dim, self.out_dim, bias=False, dtype=dtype), + ) + + def construct(self, ids: Tensor) -> Tensor: + # xlm-roberta + x = super().construct(ids) + + # average pooling + mask = ids.ne(self.pad_id).unsqueeze(-1).to(x.dtype) + x = (x * mask).sum(dim=1) / mask.sum(dim=1) + + # head + x = self.head(x) + return x + + +class XLMRobertaCLIP(nn.Cell): + def __init__( + self, + embed_dim: int = 1024, + image_size: int = 224, + patch_size: int = 14, + vision_dim: int = 1280, + vision_mlp_ratio: float = 4, + vision_heads: int = 16, + vision_layers: int = 32, + vision_pool: str = "token", + vision_pre_norm: bool = True, + vision_post_norm: bool = False, + activation="gelu", + vocab_size: int = 250002, + max_text_len: int = 514, + type_size: int = 1, + pad_id: int = 1, + text_dim: int = 1024, + text_heads: int = 16, + text_layers: int = 24, + text_post_norm: bool = True, + text_dropout: float = 0.1, + attn_dropout: float = 0.0, + proj_dropout: float = 0.0, + embedding_dropout: float = 0.0, + norm_eps: float = 1e-5, + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.image_size = image_size + self.patch_size = patch_size + self.vision_dim = vision_dim + self.vision_mlp_ratio = vision_mlp_ratio + self.vision_heads = vision_heads + self.vision_layers = vision_layers + self.vision_pre_norm = vision_pre_norm + self.vision_post_norm = vision_post_norm + self.activation = activation + self.vocab_size = vocab_size + self.max_text_len = max_text_len + self.type_size = type_size + self.pad_id = pad_id + self.text_dim = text_dim + self.text_heads = text_heads + self.text_layers = text_layers + self.text_post_norm = text_post_norm + self.norm_eps = norm_eps + self.dtype = dtype + + # models + self.visual = VisionTransformer( + image_size=image_size, + patch_size=patch_size, + dim=vision_dim, + mlp_ratio=vision_mlp_ratio, + out_dim=embed_dim, + num_heads=vision_heads, + num_layers=vision_layers, + pool_type=vision_pool, + pre_norm=vision_pre_norm, + post_norm=vision_post_norm, + activation=activation, + attn_dropout=attn_dropout, + proj_dropout=proj_dropout, + embedding_dropout=embedding_dropout, + norm_eps=norm_eps, + dtype=dtype, + ) + self.textual = XLMRobertaWithHead( + vocab_size=vocab_size, + max_seq_len=max_text_len, + type_size=type_size, + pad_id=pad_id, + dim=text_dim, + out_dim=embed_dim, + num_heads=text_heads, + num_layers=text_layers, + post_norm=text_post_norm, + dropout=text_dropout, + dtype=dtype, + ) + self.log_scale = Parameter(Tensor(math.log(1 / 0.07) * np.ones([]), dtype=dtype)) + + def construct(self, imgs: Tensor, txt_ids: Tensor) -> Tuple[Tensor, Tensor]: + """ + imgs: [B, 3, H, W] of ms.float32. + - mean: [0.48145466, 0.4578275, 0.40821073] + - std: [0.26862954, 0.26130258, 0.27577711] + txt_ids: [B, L] of ms.int32. + Encoded by data.CLIPTokenizer. 
+ """ + xi = self.visual(imgs) + xt = self.textual(txt_ids) + return xi, xt + + +def _clip( + pretrained=False, + pretrained_name=None, + model_cls=XLMRobertaCLIP, + return_transforms=False, + return_tokenizer=False, + tokenizer_padding="eos", + dtype: ms.Type = ms.float32, + **kwargs, +): + # init model + model = model_cls(**kwargs, dtype=dtype) + output = (model,) + + # init transforms + if return_transforms: + # mean and std + if "siglip" in pretrained_name.lower(): + mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] + else: + mean = [0.48145466, 0.4578275, 0.40821073] + std = [0.26862954, 0.26130258, 0.27577711] + + # transforms + transforms = Compose( + [ + vision.Resize((model.image_size, model.image_size), interpolation=vision.Inter.BICUBIC), + vision.ToTensor(), + vision.Normalize(mean=mean, std=std, is_hwc=False), + ] + ) + output += (transforms,) + return output[0] if len(output) == 1 else output + + +def clip_xlm_roberta_vit_h_14( + pretrained=False, pretrained_name="open-clip-xlm-roberta-large-vit-huge-14", dtype: ms.Type = ms.float32, **kwargs +): + cfg = dict( + embed_dim=1024, + image_size=224, + patch_size=14, + vision_dim=1280, + vision_mlp_ratio=4, + vision_heads=16, + vision_layers=32, + vision_pool="token", + activation="gelu", + vocab_size=250002, + max_text_len=514, + type_size=1, + pad_id=1, + text_dim=1024, + text_heads=16, + text_layers=24, + text_post_norm=True, + text_dropout=0.1, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0, + ) + cfg.update(**kwargs) + return _clip(pretrained, pretrained_name, XLMRobertaCLIP, dtype=dtype, **cfg) + + +class CLIPModel: + def __init__(self, dtype, checkpoint_path, tokenizer_path): + self.dtype = dtype + self.checkpoint_path = checkpoint_path + self.tokenizer_path = tokenizer_path + + # init model + with no_init_parameters(): + model, self.transforms = clip_xlm_roberta_vit_h_14( + pretrained=False, + return_transforms=True, + return_tokenizer=False, + dtype=dtype, + ) + model.set_train(False) + for param in model.trainable_params(): + param.requires_grad = False + + if checkpoint_path is not None: + logging.info(f"loading {checkpoint_path}") + if checkpoint_path.endswith(".pth"): + param_dict = load_pth(checkpoint_path, dtype=model.dtype) + ms.load_param_into_net(model, param_dict) + else: + ms.load_checkpoint(checkpoint_path, model) + model.init_parameters_data() + + self.model = model + + # init tokenizer + self.tokenizer = HuggingfaceTokenizer( + name=tokenizer_path, seq_len=self.model.max_text_len - 2, clean="whitespace" + ) + + def visual(self, videos: List[Tensor]) -> Tensor: + # preprocess + size = (self.model.image_size,) * 2 + videos = mint.cat( + [F.interpolate(u.transpose(0, 1), size=size, mode="bicubic", align_corners=False) for u in videos] + ) + videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5).asnumpy()) + + out = self.model.visual(Tensor(videos, dtype=self.model.dtype), use_31_block=True) + return out diff --git a/examples/wan2_1/wan/modules/model.py b/examples/wan2_1/wan/modules/model.py new file mode 100644 index 0000000000..c741feae89 --- /dev/null +++ b/examples/wan2_1/wan/modules/model.py @@ -0,0 +1,678 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
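+"""Wan 2.1 diffusion transformer backbone (MindSpore).
+
+Illustrative usage sketch; the tiny configuration below is hypothetical (not a released
+checkpoint) and only shows how the inputs described in `WanModel.construct` fit together:
+
+    import mindspore as ms
+    from mindspore import mint
+
+    model = WanModel(model_type="t2v", in_dim=16, dim=128, ffn_dim=512, num_heads=8, num_layers=2)
+    x = [mint.randn(16, 1, 32, 32)]              # one latent clip, [C_in, F, H, W]
+    t = ms.Tensor([500], dtype=ms.int32)         # diffusion timestep per sample
+    context = [mint.randn(77, 4096)]             # text embeddings, [L, text_dim]
+    latents = model(x, t, context, seq_len=256)  # list of denoised latents
+"""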
+import math +from typing import List, Optional, Tuple + +import numpy as np + +import mindspore as ms +import mindspore.mint as mint +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Parameter, Tensor +from mindspore.communication import GlobalComm, get_group_size + +from mindone.diffusers.configuration_utils import ConfigMixin, register_to_config +from mindone.diffusers.models.modeling_utils import ModelMixin +from mindone.models.utils import normal_, xavier_uniform_, zeros_ + +from ..acceleration.communications import AlltoAll, GatherForwardSplitBackward, SplitForwardGatherBackward +from ..acceleration.parallel_states import get_sequence_parallel_group + +__all__ = ["WanModel"] + + +def sinusoidal_embedding_1d(dim: int, position: Tensor) -> Tensor: + # preprocess + assert dim % 2 == 0 + half = dim // 2 + position = position.type(ms.float32) + + # calculation + sinusoid = mint.outer(position, mint.pow(10000, -mint.arange(half).to(position.dtype).div(half))) + x = mint.cat([mint.cos(sinusoid), mint.sin(sinusoid)], dim=1) + return x + + +def rope_params(max_seq_len: int, dim: int, theta: float = 10000) -> Tensor: + assert dim % 2 == 0 + freqs = mint.outer(mint.arange(max_seq_len), 1.0 / mint.pow(theta, mint.arange(0, dim, 2).to(ms.float32).div(dim))) + freqs = mint.stack([mint.cos(freqs), mint.sin(freqs)], dim=-1) + return freqs + + +def complex_mult(a: Tensor, b: Tensor) -> Tensor: + a_real, a_complex = a[..., 0], a[..., 1] + b_real, b_complex = b[..., 0], b[..., 1] + out_real = a_real * b_real - a_complex * b_complex + out_complex = a_real * b_complex + b_real * a_complex + return mint.stack([out_real, out_complex], dim=-1) + + +def rope_apply(x: Tensor, grid_sizes: Tensor, freqs: Tensor) -> Tensor: + n, c = x.shape[2], x.shape[3] // 2 + + # split freqs + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + # loop over samples + output = [] + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + seq_len = f * h * w + + # precompute multipliers + x_i = x[i, :seq_len].to(ms.float32).reshape(seq_len, n, -1, 2) + freqs_i = mint.cat( + [ + freqs[0][:f].view(f, 1, 1, -1, 2).expand((f, h, w, -1, 2)), + freqs[1][:h].view(1, h, 1, -1, 2).expand((f, h, w, -1, 2)), + freqs[2][:w].view(1, 1, w, -1, 2).expand((f, h, w, -1, 2)), + ], + dim=-2, + ).reshape(seq_len, 1, -1, 2) + + # apply rotary embedding + x_i = complex_mult(x_i, freqs_i).flatten(2) + x_i = mint.cat([x_i.to(x.dtype), x[i, seq_len:]]) + + # append to collection + output.append(x_i) + return mint.stack(output) + + +class WanRMSNorm(nn.Cell): + def __init__(self, dim: int, eps: float = 1e-5, dtype: ms.Type = ms.float32) -> None: + super().__init__() + self.dim = dim + self.eps = eps + self.weight = Parameter(Tensor(np.ones(dim), dtype=dtype)) + + def construct(self, x: Tensor) -> Tensor: + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return self._norm(x.float()).type_as(x) * self.weight + + def _norm(self, x: Tensor) -> Tensor: + return x * mint.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + +class WanLayerNorm(mint.nn.LayerNorm): + def __init__( + self, dim: int, eps: float = 1e-6, elementwise_affine: bool = False, dtype: ms.Type = ms.float32 + ) -> None: + super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps, dtype=dtype) + + def construct(self, x: Tensor) -> Tensor: + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + # TODO: to float32 + return super().construct(x).type_as(x) + + +class WanSelfAttention(nn.Cell): + def __init__( + self, + dim: int, + num_heads: 
int, + window_size: Tuple[int, int] = (-1, -1), + qk_norm: bool = True, + eps=1e-6, + dtype: ms.Type = ms.float32, + ) -> None: + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.eps = eps + + # layers + self.q = mint.nn.Linear(dim, dim, dtype=dtype) + self.k = mint.nn.Linear(dim, dim, dtype=dtype) + self.v = mint.nn.Linear(dim, dim, dtype=dtype) + self.o = mint.nn.Linear(dim, dim, dtype=dtype) + self.norm_q = WanRMSNorm(dim, eps=eps, dtype=dtype) if qk_norm else mint.nn.Identity() + self.norm_k = WanRMSNorm(dim, eps=eps, dtype=dtype) if qk_norm else mint.nn.Identity() + + sp_group = get_sequence_parallel_group() + if sp_group is not None: + self.all_to_all = AlltoAll(split_dim=2, concat_dim=1, group=sp_group) + self.all_to_all_back = AlltoAll(split_dim=1, concat_dim=2, group=sp_group) + self.sp_size = get_group_size(sp_group) + else: + self.all_to_all = mint.nn.Identity() + self.all_to_all_back = mint.nn.Identity() + self.sp_size = 1 + + def construct(self, x: Tensor, seq_lens: Tensor, grid_sizes: Tensor, freqs: Tensor) -> Tensor: + r""" + Args: + x(Tensor): Shape [B, L, num_heads, C / num_heads] + seq_lens(Tensor): Shape [B] + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x)).view(b, s, n, d) + q = self.all_to_all(q) + k = self.norm_k(self.k(x)).view(b, s, n, d) + k = self.all_to_all(k) + v = self.v(x).view(b, s, n, d) + v = self.all_to_all(v) + return q, k, v + + q, k, v = qkv_fn(x) + + assert self.window_size == (-1, -1) + + x = ops.flash_attention_score( + query=rope_apply(q, grid_sizes, freqs), + key=rope_apply(k, grid_sizes, freqs), + value=v, + head_num=self.num_heads // self.sp_size, + actual_seq_kvlen=seq_lens // self.sp_size, + scalar_value=1 / math.sqrt(q.shape[-1]), + input_layout="BSND", + ) + + # output + x = self.all_to_all_back(x) + x = x.flatten(2) + x = self.o(x) + return x + + +class WanT2VCrossAttention(WanSelfAttention): + def construct(self, x: Tensor, context: Tensor, context_lens: Optional[Tensor]) -> Tensor: + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + b, n, d = x.shape[0], self.num_heads, self.head_dim + + # compute query, key, value + q = self.norm_q(self.q(x)).view(b, -1, n, d) + k = self.norm_k(self.k(context)).view(b, -1, n, d) + v = self.v(context).view(b, -1, n, d) + + # compute attention + x = ops.flash_attention_score( + q, + k, + v, + head_num=self.num_heads, + actual_seq_kvlen=context_lens, + scalar_value=1 / math.sqrt(q.shape[-1]), + input_layout="BSND", + ) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class WanI2VCrossAttention(WanSelfAttention): + def __init__( + self, + dim: int, + num_heads: int, + window_size: Tuple[int, int] = (-1, -1), + qk_norm: bool = True, + eps: float = 1e-6, + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__(dim, num_heads, window_size, qk_norm, eps, dtype=dtype) + + self.k_img = mint.nn.Linear(dim, dim, dtype=dtype) + self.v_img = mint.nn.Linear(dim, dim, dtype=dtype) + self.norm_k_img = WanRMSNorm(dim, eps=eps, dtype=dtype) if qk_norm else mint.nn.Identity() + + def construct(self, x: Tensor, context: Tensor, context_lens: Tensor) -> Tensor: + 
r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + context_img = context[:, :257] + context = context[:, 257:] + b, n, d = x.shape[0], self.num_heads, self.head_dim + + # compute query, key, value + q = self.norm_q(self.q(x)).view(b, -1, n, d) + k = self.norm_k(self.k(context)).view(b, -1, n, d) + v = self.v(context).view(b, -1, n, d) + k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d) + k_img = self.all_to_all(k_img) + v_img = self.v_img(context_img).view(b, -1, n, d) + v_img = self.all_to_all(v_img) + img_x = ops.flash_attention_score( + q, + k_img, + v_img, + head_num=self.num_heads, + scalar_value=1 / math.sqrt(q.shape[-1]), + input_layout="BSND", + ) + # compute attention + x = ops.flash_attention_score( + q, + k, + v, + head_num=self.num_heads, + actual_seq_kvlen=context_lens, + scalar_value=1 / math.sqrt(q.shape[-1]), + input_layout="BSND", + ) + + # output + x = x.flatten(2) + img_x = img_x.flatten(2) + x = x + img_x + x = self.o(x) + return x + + +WAN_CROSSATTENTION_CLASSES = { + "t2v_cross_attn": WanT2VCrossAttention, + "i2v_cross_attn": WanI2VCrossAttention, +} + + +class WanAttentionBlock(nn.Cell): + def __init__( + self, + cross_attn_type: str, + dim: int, + ffn_dim: int, + num_heads: int, + window_size: Tuple[int, int] = (-1, -1), + qk_norm: bool = True, + cross_attn_norm: bool = False, + eps: float = 1e-6, + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__() + self.dim = dim + self.ffn_dim = ffn_dim + self.num_heads = num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + # layers + self.norm1 = WanLayerNorm(dim, eps, dtype=dtype) + self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps, dtype=dtype) + self.norm3 = ( + WanLayerNorm(dim, eps, elementwise_affine=True, dtype=dtype) if cross_attn_norm else mint.nn.Identity() + ) + self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type]( + dim, num_heads, (-1, -1), qk_norm, eps, dtype=dtype + ) + self.norm2 = WanLayerNorm(dim, eps, dtype=dtype) + # TODO: mint.nn.GELU -> mint.nn.GELU(approximate="tanh") + self.ffn = nn.SequentialCell( + mint.nn.Linear(dim, ffn_dim, dtype=dtype), + mint.nn.GELU(), + mint.nn.Linear(ffn_dim, dim, dtype=dtype), + ) + + # modulation + self.modulation = Parameter(Tensor(np.random.randn(1, 6, dim) / dim**0.5, dtype=dtype)) + + def construct( + self, + x: Tensor, + e: Tensor, + seq_lens: Tensor, + grid_sizes: Tensor, + freqs: Tensor, + context: Tensor, + context_lens: Tensor, + ) -> Tensor: + r""" + Args: + x(Tensor): Shape [B, L, C] + e(Tensor): Shape [B, 6, C] + seq_lens(Tensor): Shape [B], length of each sequence in batch + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + e = (self.modulation + e).chunk(6, dim=1) + + # self-attention + y = self.self_attn(self.norm1(x) * (1 + e[1]) + e[0], seq_lens, grid_sizes, freqs) + x = x + y * e[2] + + # cross-attention & ffn function + def cross_attn_ffn(x: Tensor, context: Tensor, context_lens: Tensor, e: Tensor) -> Tensor: + x = x + self.cross_attn(self.norm3(x), context, context_lens) + y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3]) + x = x + y * e[5] + return x + + x = cross_attn_ffn(x, context, context_lens, e) + return x + + +class Head(nn.Cell): + def __init__(self, dim: int, out_dim: int, patch_size: int, eps: float = 1e-6, dtype: ms.Type = ms.float32) -> None: + 
super().__init__() + self.dim = dim + self.out_dim = out_dim + self.patch_size = patch_size + self.eps = eps + + # layers + out_dim = math.prod(patch_size) * out_dim + self.norm = WanLayerNorm(dim, eps, dtype=dtype) + self.head = mint.nn.Linear(dim, out_dim, dtype=dtype) + + # modulation + self.modulation = Parameter(Tensor(np.random.randn(1, 2, dim) / dim**0.5, dtype=dtype)) + + def construct(self, x: Tensor, e: Tensor) -> Tensor: + r""" + Args: + x(Tensor): Shape [B, L1, C] + e(Tensor): Shape [B, C] + """ + e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1) + x = self.head(self.norm(x) * (1 + e[1]) + e[0]) + return x + + +class MLPProj(nn.Cell): + def __init__(self, in_dim: int, out_dim: int, dtype: ms.Type = ms.float32) -> None: + super().__init__() + + self.proj = nn.SequentialCell( + mint.nn.LayerNorm(in_dim, dtype=dtype), + mint.nn.Linear(in_dim, in_dim, dtype=dtype), + mint.nn.GELU(), + mint.nn.Linear(in_dim, out_dim, dtype=dtype), + mint.nn.LayerNorm(out_dim, dtype=dtype), + ) + + def construct(self, image_embeds: Tensor) -> Tensor: + clip_extra_context_tokens = self.proj(image_embeds) + return clip_extra_context_tokens + + +class WanModel(ModelMixin, ConfigMixin): + r""" + Wan diffusion backbone supporting both text-to-video and image-to-video. + """ + + ignore_for_config = ["patch_size", "cross_attn_norm", "qk_norm", "text_dim", "window_size"] + _no_split_modules = ["WanAttentionBlock"] + + @register_to_config + def __init__( + self, + model_type: str = "t2v", + patch_size: Tuple[int, int, int] = (1, 2, 2), + text_len: int = 512, + in_dim: int = 16, + dim: int = 2048, + ffn_dim: int = 8192, + freq_dim: int = 256, + text_dim: int = 4096, + out_dim: int = 16, + num_heads: int = 16, + num_layers: int = 32, + window_size: Tuple[int, int] = (-1, -1), + qk_norm: bool = True, + cross_attn_norm: bool = True, + eps: float = 1e-6, + dtype: ms.Type = ms.float32, + ) -> None: + r""" + Initialize the diffusion model backbone. 
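+
+        The backbone patchifies video latents with a 3D convolution, adds sinusoidal time
+        embeddings and 3D rotary position embeddings, and stacks `num_layers` attention
+        blocks with text cross-attention (plus CLIP image cross-attention in 'i2v' mode).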
+ + Args: + model_type (`str`, *optional*, defaults to 't2v'): + Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) + patch_size (`tuple`, *optional*, defaults to (1, 2, 2)): + 3D patch dimensions for video embedding (t_patch, h_patch, w_patch) + text_len (`int`, *optional*, defaults to 512): + Fixed length for text embeddings + in_dim (`int`, *optional*, defaults to 16): + Input video channels (C_in) + dim (`int`, *optional*, defaults to 2048): + Hidden dimension of the transformer + ffn_dim (`int`, *optional*, defaults to 8192): + Intermediate dimension in feed-forward network + freq_dim (`int`, *optional*, defaults to 256): + Dimension for sinusoidal time embeddings + text_dim (`int`, *optional*, defaults to 4096): + Input dimension for text embeddings + out_dim (`int`, *optional*, defaults to 16): + Output video channels (C_out) + num_heads (`int`, *optional*, defaults to 16): + Number of attention heads + num_layers (`int`, *optional*, defaults to 32): + Number of transformer blocks + window_size (`tuple`, *optional*, defaults to (-1, -1)): + Window size for local attention (-1 indicates global attention) + qk_norm (`bool`, *optional*, defaults to True): + Enable query/key normalization + cross_attn_norm (`bool`, *optional*, defaults to False): + Enable cross-attention normalization + eps (`float`, *optional*, defaults to 1e-6): + Epsilon value for normalization layers + """ + + super().__init__() + + assert model_type in ["t2v", "i2v"] + self.model_type = model_type + + self.patch_size = patch_size + self.text_len = text_len + self.in_dim = in_dim + self.dim = dim + self.ffn_dim = ffn_dim + self.freq_dim = freq_dim + self.text_dim = text_dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + # embeddings + self.patch_embedding = mint.nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, dtype=dtype) + # TODO: mint.nn.GELU -> mint.nn.GELU(approximate="tanh") + self.text_embedding = nn.SequentialCell( + mint.nn.Linear(text_dim, dim, dtype=dtype), + mint.nn.GELU(), + mint.nn.Linear(dim, dim, dtype=dtype), + ) + + self.time_embedding = nn.SequentialCell( + mint.nn.Linear(freq_dim, dim, dtype=dtype), mint.nn.SiLU(), mint.nn.Linear(dim, dim, dtype=dtype) + ) + self.time_projection = nn.SequentialCell(mint.nn.SiLU(), mint.nn.Linear(dim, dim * 6)) + + # blocks + cross_attn_type = "t2v_cross_attn" if model_type == "t2v" else "i2v_cross_attn" + self.blocks = nn.CellList( + [ + WanAttentionBlock( + cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, dtype=dtype + ) + for _ in range(num_layers) + ] + ) + + # head + self.head = Head(dim, out_dim, patch_size, eps, dtype=dtype) + + # buffers (don't use register_buffer otherwise dtype will be changed in to()) + assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 + d = dim // num_heads + self.freqs = mint.cat( + [rope_params(1024, d - 4 * (d // 6)), rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6))], + dim=1, + ) + + if model_type == "i2v": + self.img_emb = MLPProj(1280, dim, dtype=dtype) + + sp_group = get_sequence_parallel_group() + if sp_group is not None: + self.split_forward_gather_backward = SplitForwardGatherBackward(dim=1, group=sp_group) + self.gather_forward_split_backward = GatherForwardSplitBackward(dim=1, group=sp_group) + self.sp_size = get_group_size(GlobalComm.WORLD_COMM_GROUP) + assert self.num_heads 
% self.sp_size == 0 + else: + self.split_forward_gather_backward = mint.nn.Identity() + self.gather_forward_split_backward = mint.nn.Identity() + self.sp_size = 1 + + # initialize weights + self.init_weights() + + def construct( + self, + x: List[Tensor], + t: Tensor, + context: List[Tensor], + seq_len: int, + clip_fea: Optional[Tensor] = None, + y: List[Tensor] = None, + ) -> List[Tensor]: + r""" + Forward pass through the diffusion model + + Args: + x (List[Tensor]): + List of input video tensors, each with shape [C_in, F, H, W] + t (Tensor): + Diffusion timesteps tensor of shape [B] + context (List[Tensor]): + List of text embeddings each with shape [L, C] + seq_len (`int`): + Maximum sequence length for positional encoding + clip_fea (Tensor, *optional*): + CLIP image features for image-to-video mode + y (List[Tensor], *optional*): + Conditional video inputs for image-to-video mode, same shape as x + + Returns: + List[Tensor]: + List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] + """ + if self.model_type == "i2v": + assert clip_fea is not None and y is not None + # params + if y is not None: + x = [mint.cat([u, v], dim=0) for u, v in zip(x, y)] + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + grid_sizes = mint.stack([Tensor(u.shape[2:], dtype=ms.int32) for u in x]) + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = Tensor([u.shape[1] for u in x], dtype=ms.int32) + assert seq_lens.max() <= seq_len + x = mint.cat([mint.cat([u, u.new_zeros((1, seq_len - u.shape[1], u.shape[2]))], dim=1) for u in x]) + + # time embeddings + e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t).to(self.dtype)) + e0 = self.time_projection(e) + # TODO: reshape -> unflatten + e0 = e0.reshape(e0.shape[0], 6, self.dim, *e.shape[2:]) + + # context + context_lens = None + context = self.text_embedding( + mint.stack([mint.cat([u, u.new_zeros((self.text_len - u.shape[0], u.shape[1]))]) for u in context]) + ) + + if clip_fea is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = mint.concat([context_clip, context], dim=1) + + assert x.shape[1] % self.sp_size == 0 + x = self.split_forward_gather_backward(x) + + # arguments + kwargs = dict( + e=e0, seq_lens=seq_lens, grid_sizes=grid_sizes, freqs=self.freqs, context=context, context_lens=context_lens + ) + + for block in self.blocks: + x = block(x, **kwargs) + + # head + x = self.head(x, e) + + x = self.gather_forward_split_backward(x) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return [u.float() for u in x] + + def unpatchify(self, x: Tensor, grid_sizes: Tensor) -> List[Tensor]: + r""" + Reconstruct video tensors from patch embeddings. + + Args: + x (List[Tensor]): + List of patchified features, each with shape [L, C_out * prod(patch_size)] + grid_sizes (Tensor): + Original spatial-temporal grid dimensions before patching, + shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) + + Returns: + List[Tensor]: + Reconstructed video tensors with shape [C_out, F, H / 8, W / 8] + """ + + c = self.out_dim + out = [] + for u, v in zip(x, grid_sizes.tolist()): + u = u[: math.prod(v)].view(*v, *self.patch_size, c) + u = mint.einsum("fhwpqrc->cfphqwr", u) + u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) + out.append(u) + return out + + def init_weights(self) -> None: + r""" + Initialize model parameters using Xavier initialization. 
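+
+        Linear layers and the patch embedding are initialized with Xavier uniform, the text
+        and time embedding MLPs use a normal(std=0.02) initializer, and the output head is
+        zero-initialized.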
+ """ + + # basic init + for _, m in self.cells_and_names(): + if isinstance(m, mint.nn.Linear): + xavier_uniform_(m.weight) + if m.bias is not None: + zeros_(m.bias) + + # init embeddings + xavier_uniform_(self.patch_embedding.weight) + for _, m in self.text_embedding.cells_and_names(): + if isinstance(m, mint.nn.Linear): + normal_(m.weight, std=0.02) + for _, m in self.time_embedding.cells_and_names(): + if isinstance(m, mint.nn.Linear): + normal_(m.weight, std=0.02) + + # init output layer + zeros_(self.head.head.weight) diff --git a/examples/wan2_1/wan/modules/t5.py b/examples/wan2_1/wan/modules/t5.py new file mode 100644 index 0000000000..a39bd5c3d3 --- /dev/null +++ b/examples/wan2_1/wan/modules/t5.py @@ -0,0 +1,557 @@ +# Modified from transformers.models.t5.modeling_t5 +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import logging +import math +from typing import Any, Callable, Dict, Optional, Union + +import numpy as np + +import mindspore as ms +import mindspore.mint as mint +import mindspore.mint.nn.functional as F +import mindspore.nn as nn +from mindspore import Parameter, Tensor +from mindspore.nn.utils import no_init_parameters + +from mindone.models.utils import normal_, ones_ +from mindone.transformers.modeling_attn_mask_utils import dtype_to_min + +from ..utils.utils import load_pth +from .tokenizers import HuggingfaceTokenizer + +__all__ = ["T5Model", "T5Encoder", "T5Decoder", "T5EncoderModel"] + + +def fp16_clamp(x: Tensor) -> Tensor: + if x.dtype == ms.float16 and mint.isinf(x).any(): + clamp = Tensor(np.finfo(np.float16).max - 1000) + x = mint.clamp(x, min=-clamp, max=clamp) + return x + + +def init_weights(m: Any) -> None: + if isinstance(m, T5LayerNorm): + ones_(m.weight) + elif isinstance(m, T5Model): + normal_(m.token_embedding.weight, std=1.0) + elif isinstance(m, T5FeedForward): + normal_(m.gate[0].weight, std=m.dim**-0.5) + normal_(m.fc1.weight, std=m.dim**-0.5) + normal_(m.fc2.weight, std=m.dim_ffn**-0.5) + elif isinstance(m, T5Attention): + normal_(m.q.weight, std=(m.dim * m.dim_attn) ** -0.5) + normal_(m.k.weight, std=m.dim**-0.5) + normal_(m.v.weight, std=m.dim**-0.5) + normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5) + elif isinstance(m, T5RelativeEmbedding): + normal_(m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5) + + +class GELU(nn.Cell): + def construct(self, x: Tensor) -> Tensor: + return 0.5 * x * (1.0 + mint.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * mint.pow(x, 3.0)))) + + +class T5LayerNorm(nn.Cell): + def __init__(self, dim: int, eps: float = 1e-6, dtype: ms.Type = ms.float32) -> None: + super(T5LayerNorm, self).__init__() + self.dim = dim + self.eps = eps + self.weight = Parameter(Tensor(np.ones(dim), dtype=dtype)) + + def construct(self, x: Tensor) -> Tensor: + x = x * mint.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps) + if self.weight.dtype in [ms.float16, ms.bfloat16]: + x = x.type_as(self.weight) + return self.weight * x + + +class T5Attention(nn.Cell): + def __init__( + self, dim: int, dim_attn: int, num_heads: int, dropout: float = 0.1, dtype: ms.Type = ms.float32 + ) -> None: + assert dim_attn % num_heads == 0 + super(T5Attention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.num_heads = num_heads + self.head_dim = dim_attn // num_heads + + # layers + self.q = mint.nn.Linear(dim, dim_attn, bias=False, dtype=dtype) + self.k = mint.nn.Linear(dim, dim_attn, bias=False, dtype=dtype) + self.v = mint.nn.Linear(dim, dim_attn, bias=False, 
dtype=dtype) + self.o = mint.nn.Linear(dim_attn, dim, bias=False, dtype=dtype) + self.dropout = mint.nn.Dropout(dropout) + + def construct( + self, + x: Tensor, + context: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + pos_bias: Optional[Tensor] = None, + ) -> Tensor: + """ + x: [B, L1, C]. + context: [B, L2, C] or None. + mask: [B, L2] or [B, L1, L2] or None. + """ + # check inputs + context = x if context is None else context + b, n, c = x.shape[0], self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).view(b, -1, n, c) + k = self.k(context).view(b, -1, n, c) + v = self.v(context).view(b, -1, n, c) + + # attention bias + attn_bias = x.new_zeros((b, n, q.shape[1], k.shape[1])) + if pos_bias is not None: + attn_bias += pos_bias + if mask is not None: + assert mask.ndim in [2, 3] + mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1) + attn_bias.masked_fill_(mask == 0, dtype_to_min(x.dtype)) + + # compute attention (T5 does not use scaling) + attn = mint.einsum("binc,bjnc->bnij", q, k) + attn_bias + attn = F.softmax(attn.float(), dim=-1).type_as(attn) + x = mint.einsum("bnij,bjnc->binc", attn, v) + + # output + x = x.reshape(b, -1, n * c) + x = self.o(x) + x = self.dropout(x) + return x + + +class T5FeedForward(nn.Cell): + def __init__(self, dim: int, dim_ffn: int, dropout: float = 0.1, dtype: ms.Type = ms.float32) -> None: + super(T5FeedForward, self).__init__() + self.dim = dim + self.dim_ffn = dim_ffn + + # layers + self.gate = nn.SequentialCell(mint.nn.Linear(dim, dim_ffn, bias=False, dtype=dtype), GELU()) + self.fc1 = mint.nn.Linear(dim, dim_ffn, bias=False, dtype=dtype) + self.fc2 = mint.nn.Linear(dim_ffn, dim, bias=False, dtype=dtype) + self.dropout = mint.nn.Dropout(dropout) + + def construct(self, x: Tensor) -> Tensor: + x = self.fc1(x) * self.gate(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + + +class T5SelfAttention(nn.Cell): + def __init__( + self, + dim: int, + dim_attn: int, + dim_ffn: int, + num_heads: int, + num_buckets: int, + shared_pos: bool = True, + dropout: float = 0.1, + dtype: ms.Type = ms.float32, + ) -> None: + super(T5SelfAttention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.norm1 = T5LayerNorm(dim, dtype=dtype) + self.attn = T5Attention(dim, dim_attn, num_heads, dropout, dtype=dtype) + self.norm2 = T5LayerNorm(dim, dtype=dtype) + self.ffn = T5FeedForward(dim, dim_ffn, dropout, dtype=dtype) + self.pos_embedding = ( + None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, dtype=dtype) + ) + + def construct(self, x: Tensor, mask: Optional[Tensor] = None, pos_bias: Optional[Tensor] = None) -> Tensor: + e = pos_bias if self.shared_pos else self.pos_embedding(x.shape[1], x.shape[1]) + x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e)) + x = fp16_clamp(x + self.ffn(self.norm2(x))) + return x + + +class T5CrossAttention(nn.Cell): + def __init__( + self, + dim: int, + dim_attn: int, + dim_ffn: int, + num_heads: int, + num_buckets: int, + shared_pos: bool = True, + dropout: float = 0.1, + dtype: ms.Type = ms.float32, + ) -> None: + super(T5CrossAttention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.norm1 = T5LayerNorm(dim, 
dtype=dtype) + self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout, dtype=dtype) + self.norm2 = T5LayerNorm(dim, dtype=dtype) + self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout, dtype=dtype) + self.norm3 = T5LayerNorm(dim, dtype=dtype) + self.ffn = T5FeedForward(dim, dim_ffn, dropout, dtype=dtype) + self.pos_embedding = ( + None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False, dtype=dtype) + ) + + def construct( + self, + x: Tensor, + mask: Optional[Tensor] = None, + encoder_states: Optional[Tensor] = None, + encoder_mask: Optional[Tensor] = None, + pos_bias: Optional[Tensor] = None, + ) -> Tensor: + e = pos_bias if self.shared_pos else self.pos_embedding(x.shape[1], x.shape[1]) + x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e)) + x = fp16_clamp(x + self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask)) + x = fp16_clamp(x + self.ffn(self.norm3(x))) + return x + + +class T5RelativeEmbedding(nn.Cell): + def __init__( + self, num_buckets: int, num_heads: int, bidirectional: bool, max_dist: int = 128, dtype: ms.Type = ms.float32 + ) -> None: + super(T5RelativeEmbedding, self).__init__() + self.num_buckets = num_buckets + self.num_heads = num_heads + self.bidirectional = bidirectional + self.max_dist = max_dist + + # layers + self.embedding = mint.nn.Embedding(num_buckets, num_heads, dtype=dtype) + + def construct(self, lq: int, lk: int) -> Tensor: + rel_pos = mint.arange(lk).unsqueeze(0) - mint.arange(lq).unsqueeze(1) + rel_pos = self._relative_position_bucket(rel_pos) + rel_pos_embeds = self.embedding(rel_pos) + rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0) # [1, N, Lq, Lk] + return rel_pos_embeds.contiguous() + + def _relative_position_bucket(self, rel_pos: Tensor) -> Tensor: + # preprocess + if self.bidirectional: + num_buckets = self.num_buckets // 2 + rel_buckets = (rel_pos > 0).to(ms.int32) * num_buckets + rel_pos = mint.abs(rel_pos) + else: + num_buckets = self.num_buckets + rel_buckets = 0 + rel_pos = -mint.min(rel_pos, mint.zeros_like(rel_pos)) + + # embeddings for small and large positions + max_exact = num_buckets // 2 + rel_pos_large = max_exact + ( + mint.log(rel_pos.float() / max_exact) / math.log(self.max_dist / max_exact) * (num_buckets - max_exact) + ).to(ms.int32) + rel_pos_large = mint.min(rel_pos_large, mint.full_like(rel_pos_large, num_buckets - 1)) + rel_buckets += mint.where(rel_pos < max_exact, rel_pos, rel_pos_large) + return rel_buckets + + +class T5Encoder(nn.Cell): + def __init__( + self, + vocab: Union[int, mint.nn.Embedding], + dim: int, + dim_attn: int, + dim_ffn: int, + num_heads: int, + num_layers: int, + num_buckets: int, + shared_pos: bool = True, + dropout: float = 0.1, + dtype: ms.Type = ms.float32, + ) -> None: + super(T5Encoder, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.num_buckets = num_buckets + self.shared_pos = shared_pos + self.dtype = dtype + + # layers + self.token_embedding = ( + vocab if isinstance(vocab, mint.nn.Embedding) else mint.nn.Embedding(vocab, dim, dtype=dtype) + ) + self.pos_embedding = ( + T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, dtype=dtype) if shared_pos else None + ) + self.dropout = mint.nn.Dropout(dropout) + self.blocks = nn.CellList( + [ + T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout, dtype=dtype) + for _ in range(num_layers) + ] + ) + 
self.norm = T5LayerNorm(dim, dtype=dtype) + + # initialize weights + self.apply(init_weights) + + def construct(self, ids: Tensor, mask: Optional[Tensor] = None) -> Tensor: + x = self.token_embedding(ids) + x = self.dropout(x) + e = self.pos_embedding(x.shape[1], x.shape[1]) if self.shared_pos else None + for block in self.blocks: + x = block(x, mask, pos_bias=e) + x = self.norm(x) + x = self.dropout(x) + return x + + +class T5Decoder(nn.Cell): + def __init__( + self, + vocab: Union[int, mint.nn.Embedding], + dim: int, + dim_attn: int, + dim_ffn: int, + num_heads: int, + num_layers: int, + num_buckets: int, + shared_pos: bool = True, + dropout: float = 0.1, + dtype: ms.Type = ms.float32, + ) -> None: + super(T5Decoder, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.num_buckets = num_buckets + self.shared_pos = shared_pos + self.dtype = dtype + + # layers + self.token_embedding = ( + vocab if isinstance(vocab, mint.nn.Embedding) else mint.nn.Embedding(vocab, dim, dtype=dtype) + ) + self.pos_embedding = ( + T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False, dtype=dtype) if shared_pos else None + ) + self.dropout = mint.nn.Dropout(dropout) + self.blocks = nn.CellList( + [ + T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout, dtype=dtype) + for _ in range(num_layers) + ] + ) + self.norm = T5LayerNorm(dim, dtype=dtype) + + # initialize weights + self.apply(init_weights) + + def construct( + self, + ids: Tensor, + mask: Optional[Tensor] = None, + encoder_states: Optional[Tensor] = None, + encoder_mask: Optional[Tensor] = None, + ) -> Tensor: + b, s = ids.shape + + # causal mask + if mask is None: + mask = mint.tril(mint.ones((1, s, s))) + elif mask.ndim == 2: + mask = mint.tril(mask.unsqueeze(1).expand((-1, s, -1))) + + # layers + x = self.token_embedding(ids) + x = self.dropout(x) + e = self.pos_embedding(x.shape[1], x.shape[1]) if self.shared_pos else None + for block in self.blocks: + x = block(x, mask, encoder_states, encoder_mask, pos_bias=e) + x = self.norm(x) + x = self.dropout(x) + return x + + +class T5Model(nn.Cell): + def __init__( + self, + vocab_size: int, + dim: int, + dim_attn: int, + dim_ffn: int, + num_heads: int, + encoder_layers: int, + decoder_layers: int, + num_buckets: int, + shared_pos: bool = True, + dropout: float = 0.1, + dtype: ms.Type = ms.float32, + ) -> None: + super(T5Model, self).__init__() + self.vocab_size = vocab_size + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.encoder_layers = encoder_layers + self.decoder_layers = decoder_layers + self.num_buckets = num_buckets + self.dtype = dtype + + # layers + self.token_embedding = mint.nn.Embedding(vocab_size, dim, dtype=dtype) + self.encoder = T5Encoder( + self.token_embedding, + dim, + dim_attn, + dim_ffn, + num_heads, + encoder_layers, + num_buckets, + shared_pos, + dropout, + dtype=dtype, + ) + self.decoder = T5Decoder( + self.token_embedding, + dim, + dim_attn, + dim_ffn, + num_heads, + decoder_layers, + num_buckets, + shared_pos, + dropout, + dtype=dtype, + ) + self.head = mint.nn.Linear(dim, vocab_size, bias=False, dtype=dtype) + + # initialize weights + self.apply(init_weights) + + def construct(self, encoder_ids: Tensor, encoder_mask: Tensor, decoder_ids: Tensor, decoder_mask: Tensor) -> Tensor: + x = self.encoder(encoder_ids, encoder_mask) + x = self.decoder(decoder_ids, decoder_mask, x, 
encoder_mask) + x = self.head(x) + return x + + +def _t5( + name: str, + encoder_only: bool = False, + decoder_only: bool = False, + return_tokenizer: bool = False, + tokenizer_kwargs: Dict[str, Any] = {}, + dtype: ms.Type = ms.float32, + **kwargs, +): + # sanity check + assert not (encoder_only and decoder_only) + + # params + if encoder_only: + model_cls = T5Encoder + kwargs["vocab"] = kwargs.pop("vocab_size") + kwargs["num_layers"] = kwargs.pop("encoder_layers") + _ = kwargs.pop("decoder_layers") + elif decoder_only: + model_cls = T5Decoder + kwargs["vocab"] = kwargs.pop("vocab_size") + kwargs["num_layers"] = kwargs.pop("decoder_layers") + _ = kwargs.pop("encoder_layers") + else: + model_cls = T5Model + + # init model + model = model_cls(dtype=dtype, **kwargs) + + # init tokenizer + if return_tokenizer: + from .tokenizers import HuggingfaceTokenizer + + tokenizer = HuggingfaceTokenizer(f"google/{name}", **tokenizer_kwargs) + return model, tokenizer + else: + return model + + +def umt5_xxl(**kwargs): + cfg = dict( + vocab_size=256384, + dim=4096, + dim_attn=4096, + dim_ffn=10240, + num_heads=64, + encoder_layers=24, + decoder_layers=24, + num_buckets=32, + shared_pos=False, + dropout=0.1, + ) + cfg.update(**kwargs) + return _t5("umt5-xxl", **cfg) + + +class T5EncoderModel: + def __init__( + self, + text_len: int, + dtype=ms.bfloat16, + checkpoint_path: Optional[str] = None, + tokenizer_path: Optional[str] = None, + shard_fn: Optional[Callable] = None, + ) -> None: + self.text_len = text_len + self.dtype = dtype + self.checkpoint_path = checkpoint_path + self.tokenizer_path = tokenizer_path + + # init model + with no_init_parameters(): + model = umt5_xxl(encoder_only=True, return_tokenizer=False, dtype=dtype) + model.set_train(False) + for param in model.trainable_params(): + param.requires_grad = False + + if checkpoint_path is not None: + logging.info(f"loading {checkpoint_path}") + if checkpoint_path.endswith(".pth"): + param_dict = load_pth(checkpoint_path, dtype=model.dtype) + ms.load_param_into_net(model, param_dict) + else: + ms.load_checkpoint(checkpoint_path, model) + model.init_parameters_data() + + self.model = model + if shard_fn is not None: + self.model = shard_fn(self.model) + + # init tokenizer + self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean="whitespace") + + def __call__(self, texts): + ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True, return_tensors="np") + ids, mask = Tensor(ids), Tensor(mask) + seq_lens = mask.gt(0).sum(dim=1).to(ms.int32) + context = self.model(ids, mask) + return [u[:v] for u, v in zip(context, seq_lens)] diff --git a/examples/wan2_1/wan/modules/tokenizers.py b/examples/wan2_1/wan/modules/tokenizers.py new file mode 100644 index 0000000000..ec85c97538 --- /dev/null +++ b/examples/wan2_1/wan/modules/tokenizers.py @@ -0,0 +1,78 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
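+"""Tokenizer wrappers around `transformers.AutoTokenizer` with optional text cleaning.
+
+Illustrative usage (the tokenizer name mirrors the UMT5 model referenced in `t5.py`):
+
+    tokenizer = HuggingfaceTokenizer("google/umt5-xxl", seq_len=512, clean="whitespace")
+    ids, mask = tokenizer(["a cat playing piano"], return_mask=True, return_tensors="np")
+"""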
+import html +import string + +import ftfy +import regex as re +from transformers import AutoTokenizer + +__all__ = ["HuggingfaceTokenizer"] + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def canonicalize(text, keep_punctuation_exact_string=None): + text = text.replace("_", " ") + if keep_punctuation_exact_string: + text = keep_punctuation_exact_string.join( + part.translate(str.maketrans("", "", string.punctuation)) + for part in text.split(keep_punctuation_exact_string) + ) + else: + text = text.translate(str.maketrans("", "", string.punctuation)) + text = text.lower() + text = re.sub(r"\s+", " ", text) + return text.strip() + + +class HuggingfaceTokenizer: + def __init__(self, name, seq_len=None, clean=None, **kwargs): + assert clean in (None, "whitespace", "lower", "canonicalize") + self.name = name + self.seq_len = seq_len + self.clean = clean + + # init tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs) + self.vocab_size = self.tokenizer.vocab_size + + def __call__(self, sequence, **kwargs): + return_mask = kwargs.pop("return_mask", False) + + # arguments + _kwargs = {"return_tensors": "pt"} + if self.seq_len is not None: + _kwargs.update({"padding": "max_length", "truncation": True, "max_length": self.seq_len}) + _kwargs.update(**kwargs) + + # tokenization + if isinstance(sequence, str): + sequence = [sequence] + if self.clean: + sequence = [self._clean(u) for u in sequence] + ids = self.tokenizer(sequence, **_kwargs) + + # output + if return_mask: + return ids.input_ids, ids.attention_mask + else: + return ids.input_ids + + def _clean(self, text): + if self.clean == "whitespace": + text = whitespace_clean(basic_clean(text)) + elif self.clean == "lower": + text = whitespace_clean(basic_clean(text)).lower() + elif self.clean == "canonicalize": + text = canonicalize(basic_clean(text)) + return text diff --git a/examples/wan2_1/wan/modules/vae.py b/examples/wan2_1/wan/modules/vae.py new file mode 100644 index 0000000000..b97c1cd9e2 --- /dev/null +++ b/examples/wan2_1/wan/modules/vae.py @@ -0,0 +1,623 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import logging +import math +from typing import List, Optional, Tuple, Union + +import numpy as np + +import mindspore as ms +import mindspore.mint as mint +import mindspore.mint.nn.functional as F +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Parameter, Tensor +from mindspore.nn.utils import no_init_parameters + +from mindone.models.utils import zeros_ + +from ..utils.utils import load_pth + +__all__ = ["WanVAE"] + +CACHE_T = 2 + + +class CausalConv3d(mint.nn.Conv3d): + """ + Causal 3d convolusion. 
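+
+    Padding is applied only on the past side of the temporal axis, so an output frame never
+    depends on future frames; the optional `cache_x` argument lets chunked encoding/decoding
+    reuse the trailing frames of the previous chunk as temporal context.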
+ """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def construct(self, x: Tensor, cache_x: Optional[Tensor] = None) -> Tensor: + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + x = mint.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + + return super().construct(x) + + +class RMS_norm(nn.Cell): + def __init__( + self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False, dtype: ms.Type = ms.float32 + ) -> None: + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = Parameter(Tensor(np.ones(shape), dtype=dtype)) + self.bias = Parameter(Tensor(np.zeros(shape), dtype=dtype)) if bias else 0.0 + + def construct(self, x: Tensor) -> Tensor: + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + +class Upsample(mint.nn.Upsample): + def construct(self, x: Tensor) -> Tensor: + """ + Fix bfloat16 support for nearest neighbor interpolation. + """ + return super().construct(x.float()).type_as(x) + + +class Resample(nn.Cell): + def __init__(self, dim: int, mode: str, dtype: ms.Type = ms.float32) -> None: + assert mode in ("none", "upsample2d", "upsample3d", "downsample2d", "downsample3d") + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.SequentialCell( + Upsample(scale_factor=(2.0, 2.0), mode="nearest"), + mint.nn.Conv2d(dim, dim // 2, 3, padding=1, dtype=dtype), + ) + elif mode == "upsample3d": + self.resample = nn.SequentialCell( + Upsample(scale_factor=(2.0, 2.0), mode="nearest"), + mint.nn.Conv2d(dim, dim // 2, 3, padding=1, dtype=dtype), + ) + self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0), dtype=dtype) + + elif mode == "downsample2d": + self.resample = nn.SequentialCell( + mint.nn.ZeroPad2d((0, 1, 0, 1)), mint.nn.Conv2d(dim, dim, 3, stride=(2, 2), dtype=dtype) + ) + elif mode == "downsample3d": + self.resample = nn.SequentialCell( + nn.ZeroPad2d((0, 1, 0, 1)), mint.nn.Conv2d(dim, dim, 3, stride=(2, 2), dtype=dtype) + ) + self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0), dtype=dtype) + + else: + self.resample = mint.nn.Identity() + + def construct(self, x: Tensor, feat_cache: Optional[Tensor] = None, feat_idx: List[int] = [0]) -> Tensor: + b, c, t, h, w = x.shape + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + # cache last frame of last two chunk + cache_x = mint.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2), cache_x], dim=2) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = mint.cat([mint.zeros_like(cache_x), cache_x], dim=2) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = mint.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = 
x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = x.transpose(1, 2).flatten(0, 1) # b c t h w -> (b t) c h w + x = self.resample(x) + x = x.reshape(b, t, *x.shape[1:]).transpose(1, 2) # (b t) c h w -> b c t h w + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv(mint.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + +class ResidualBlock(nn.Cell): + def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.0, dtype: ms.Type = ms.float32) -> None: + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # layers + self.residual = nn.SequentialCell( + RMS_norm(in_dim, images=False, dtype=dtype), + mint.nn.SiLU(), + CausalConv3d(in_dim, out_dim, 3, padding=1, dtype=dtype), + RMS_norm(out_dim, images=False, dtype=dtype), + mint.nn.SiLU(), + mint.nn.Dropout(dropout), + CausalConv3d(out_dim, out_dim, 3, padding=1, dtype=dtype), + ) + self.shortcut = CausalConv3d(in_dim, out_dim, 1, dtype=dtype) if in_dim != out_dim else mint.nn.Identity() + + def construct(self, x: Tensor, feat_cache: Optional[Tensor] = None, feat_idx: List[int] = [0]) -> Tensor: + h = self.shortcut(x) + for layer in self.residual: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = mint.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2), cache_x], dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + h + + +class AttentionBlock(nn.Cell): + """ + Causal self-attention with a single head. 
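+
+    Attention is computed independently per frame over spatial positions (the temporal axis
+    is folded into the batch) using flash attention with a single head.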
+ """ + + def __init__(self, dim: int, dtype: ms.Type = ms.float32) -> None: + super().__init__() + self.dim = dim + + # layers + self.norm = RMS_norm(dim, dtype=dtype) + self.to_qkv = mint.nn.Conv2d(dim, dim * 3, 1, dtype=dtype) + self.proj = mint.nn.Conv2d(dim, dim, 1, dtype=dtype) + + # zero out the last layer params + zeros_(self.proj.weight) + + def construct(self, x: Tensor) -> Tensor: + identity = x + b, c, t, h, w = x.shape + x = x.transpose(1, 2).reshape(-1, c, h, w) # b c t h w -> (b t) c h w + x = self.norm(x) + # compute query, key, value + q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1) + + # apply attention + x = ops.flash_attention_score(q, k, v, 1, scalar_value=1 / math.sqrt(q.shape[-1]), input_layout="BNSD") + x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) + + # output + x = self.proj(x) + x = x.reshape(b, t, *x.shape[1:]).transpose(1, 2) # (b t) c h w-> b c t h w + return x + identity + + +class Encoder3d(nn.Cell): + def __init__( + self, + dim: int = 128, + z_dim: int = 4, + dim_mult: List[int] = [1, 2, 4, 4], + num_res_blocks: int = 2, + attn_scales: List[float] = [], + temperal_downsample: List[bool] = [True, True, False], + dropout: float = 0.0, + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = CausalConv3d(3, dims[0], 3, padding=1, dtype=dtype) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout, dtype=dtype)) + if scale in attn_scales: + downsamples.append(AttentionBlock(out_dim, dtype=dtype)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = "downsample3d" if temperal_downsample[i] else "downsample2d" + downsamples.append(Resample(out_dim, mode=mode, dtype=dtype)) + scale /= 2.0 + self.downsamples = nn.SequentialCell(*downsamples) + + # middle blocks + self.middle = nn.SequentialCell( + ResidualBlock(out_dim, out_dim, dropout, dtype=dtype), + AttentionBlock(out_dim, dtype=dtype), + ResidualBlock(out_dim, out_dim, dropout, dtype=dtype), + ) + + # output blocks + self.head = nn.SequentialCell( + RMS_norm(out_dim, images=False, dtype=dtype), + mint.nn.SiLU(), + CausalConv3d(out_dim, z_dim, 3, padding=1, dtype=dtype), + ) + + def construct(self, x: Tensor, feat_cache: Optional[Tensor] = None, feat_idx: List[int] = [0]) -> Tensor: + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = mint.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2), cache_x], dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + # downsamples + for layer in self.downsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + # middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + # head + for layer in self.head: + if isinstance(layer, CausalConv3d) and 
feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = mint.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2), cache_x], dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +class Decoder3d(nn.Cell): + def __init__( + self, + dim: int = 128, + z_dim: int = 4, + dim_mult: List[int] = [1, 2, 4, 4], + num_res_blocks: int = 2, + attn_scales: List[float] = [], + temperal_upsample: List[bool] = [False, True, True], + dropout=0.0, + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2 ** (len(dim_mult) - 2) + + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1, dtype=dtype) + + # middle blocks + self.middle = nn.SequentialCell( + ResidualBlock(dims[0], dims[0], dropout, dtype=dtype), + AttentionBlock(dims[0], dtype=dtype), + ResidualBlock(dims[0], dims[0], dropout, dtype=dtype), + ) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i == 1 or i == 2 or i == 3: + in_dim = in_dim // 2 + for _ in range(num_res_blocks + 1): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout, dtype=dtype)) + if scale in attn_scales: + upsamples.append(AttentionBlock(out_dim, dtype=dtype)) + in_dim = out_dim + + # upsample block + if i != len(dim_mult) - 1: + mode = "upsample3d" if temperal_upsample[i] else "upsample2d" + upsamples.append(Resample(out_dim, mode=mode, dtype=dtype)) + scale *= 2.0 + self.upsamples = nn.SequentialCell(*upsamples) + + # output blocks + self.head = nn.SequentialCell( + RMS_norm(out_dim, images=False, dtype=dtype), + mint.nn.SiLU(), + CausalConv3d(out_dim, 3, 3, padding=1, dtype=dtype), + ) + + def construct(self, x: Tensor, feat_cache: Optional[Tensor] = None, feat_idx: List[int] = [0]) -> Tensor: + # conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = mint.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2), cache_x], dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + # middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + # upsamples + for layer in self.upsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + # head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = mint.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2), cache_x], dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +def count_conv3d(model: nn.Cell) -> int: + count = 0 + for _, m in model.cells_and_names(): + if isinstance(m, CausalConv3d): + count += 
1
+    return count
+
+
+class WanVAE_(nn.Cell):
+    def __init__(
+        self,
+        dim: int = 128,
+        z_dim: int = 4,
+        dim_mult: List[int] = [1, 2, 4, 4],
+        num_res_blocks: int = 2,
+        attn_scales: List[float] = [],
+        temperal_downsample: List[bool] = [True, True, False],
+        dropout: float = 0.0,
+        dtype: ms.Type = ms.float32,
+    ) -> None:
+        super().__init__()
+        self.dim = dim
+        self.z_dim = z_dim
+        self.dim_mult = dim_mult
+        self.num_res_blocks = num_res_blocks
+        self.attn_scales = attn_scales
+        self.temperal_downsample = temperal_downsample
+        self.temperal_upsample = temperal_downsample[::-1]
+        self.dtype = dtype
+
+        # modules
+        self.encoder = Encoder3d(
+            dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout, dtype=dtype
+        )
+        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1, dtype=dtype)
+        self.conv2 = CausalConv3d(z_dim, z_dim, 1, dtype=dtype)
+        self.decoder = Decoder3d(
+            dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout, dtype=dtype
+        )
+
+    def construct(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+        mu, log_var = self.encode(x)
+        z = self.reparameterize(mu, log_var)
+        x_recon = self.decode(z)
+        return x_recon, mu, log_var
+
+    def encode(self, x: Tensor, scale: List[Union[float, Tensor]]) -> Tensor:
+        self.clear_cache()
+        # cache
+        t = x.shape[2]
+        iter_ = 1 + (t - 1) // 4
+        # split the encoder input x along the temporal axis into chunks of 1, 4, 4, 4, ... frames
+        for i in range(iter_):
+            self._enc_conv_idx = [0]
+            if i == 0:
+                out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
+            else:
+                out_ = self.encoder(
+                    x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
+                    feat_cache=self._enc_feat_map,
+                    feat_idx=self._enc_conv_idx,
+                )
+                out = mint.cat([out, out_], 2)
+        mu, log_var = self.conv1(out).chunk(2, dim=1)
+        if isinstance(scale[0], Tensor):
+            mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
+        else:
+            mu = (mu - scale[0]) * scale[1]
+        self.clear_cache()
+        return mu
+
+    def decode(self, z: Tensor, scale: List[Union[float, Tensor]]) -> Tensor:
+        self.clear_cache()
+        # z: [b,c,t,h,w]
+        if isinstance(scale[0], Tensor):
+            z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
+        else:
+            z = z / scale[1] + scale[0]
+        iter_ = z.shape[2]
+        x = self.conv2(z)
+        for i in range(iter_):
+            self._conv_idx = [0]
+            if i == 0:
+                out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
+            else:
+                out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
+                out = mint.cat([out, out_], 2)
+        self.clear_cache()
+        return out
+
+    def reparameterize(self, mu: Tensor, log_var: Tensor) -> Tensor:
+        std = mint.exp(0.5 * log_var)
+        eps = mint.randn_like(std)
+        return eps * std + mu
+
+    def sample(self, imgs: Tensor, deterministic: bool = False) -> Tensor:
+        mu, log_var = self.encode(imgs)
+        if deterministic:
+            return mu
+        std = mint.exp(0.5 * log_var.clamp(-30.0, 20.0))
+        return mu + std * mint.randn_like(std)
+
+    def clear_cache(self) -> None:
+        self._conv_num = count_conv3d(self.decoder)
+        self._conv_idx = [0]
+        self._feat_map = [None] * self._conv_num
+        # cache encode
+        self._enc_conv_num = count_conv3d(self.encoder)
+        self._enc_conv_idx = [0]
+        self._enc_feat_map = [None] * self._enc_conv_num
+
+
+def _video_vae(pretrained_path: Optional[str] = None, z_dim: Optional[int] = None, **kwargs):
+    """
+    Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
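+
+    Keyword arguments override the default config below (dim=96, temporal downsampling on
+    the last two encoder stages); `pretrained_path` may be either a PyTorch `.pth` file or a
+    MindSpore checkpoint.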
+ """ + # params + cfg = dict( + dim=96, + z_dim=z_dim, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[False, True, True], + dropout=0.0, + ) + cfg.update(**kwargs) + + # init model + with no_init_parameters(): + model = WanVAE_(**cfg) + + # load checkpoint + if pretrained_path is not None: + logging.info(f"loading {pretrained_path}") + if pretrained_path.endswith(".pth"): + param_dict = load_pth(pretrained_path, dtype=model.dtype) + ms.load_param_into_net(model, param_dict) + else: + ms.load_checkpoint(pretrained_path, model) + model.init_parameters_data() + return model + + +class WanVAE: + def __init__(self, z_dim: int = 16, vae_pth: Optional[str] = None, dtype=ms.float32) -> None: + self.dtype = dtype + + mean = [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, + ] + std = [ + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.9160, + ] + self.mean = Tensor(mean, dtype=dtype) + self.std = Tensor(std, dtype=dtype) + self.scale = [self.mean, 1.0 / self.std] + + # init model + self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, dtype=dtype) + self.model.set_train(False) + for param in self.model.trainable_params(): + param.requires_grad = False + + def encode(self, videos: List[Tensor]) -> List[Tensor]: + """ + videos: A list of videos each with shape [C, T, H, W]. + """ + return [self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos] + + def decode(self, zs: List[Tensor]) -> List[Tensor]: + return [self.model.decode(u.unsqueeze(0), self.scale).float().clamp_(-1, 1).squeeze(0) for u in zs] diff --git a/examples/wan2_1/wan/modules/xlm_roberta.py b/examples/wan2_1/wan/modules/xlm_roberta.py new file mode 100644 index 0000000000..431a1d23b3 --- /dev/null +++ b/examples/wan2_1/wan/modules/xlm_roberta.py @@ -0,0 +1,184 @@ +# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import mindspore as ms +import mindspore.mint as mint +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor + +from mindone.transformers.modeling_attn_mask_utils import dtype_to_min + +__all__ = ["XLMRoberta", "xlm_roberta_large"] + + +class SelfAttention(nn.Cell): + def __init__( + self, dim: int, num_heads: int, dropout: float = 0.1, eps: float = 1e-5, dtype: ms.Type = ms.float32 + ) -> None: + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.eps = eps + + # layers + self.q = mint.nn.Linear(dim, dim, dtype=dtype) + self.k = mint.nn.Linear(dim, dim, dtype=dtype) + self.v = mint.nn.Linear(dim, dim, dtype=dtype) + self.o = mint.nn.Linear(dim, dim, dtype=dtype) + self.dropout = mint.nn.Dropout(dropout) + + def construct(self, x: Tensor, mask: Tensor) -> Tensor: + """ + x: [B, L, C]. 
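+        mask: attention mask of shape [B, 1, 1, L] as built in `XLMRoberta.construct`
+            (0.0 at valid token positions, the dtype minimum at padding positions);
+            it is forwarded to `ops.flash_attention_score` as `attn_mask`.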
+ """ + b, s, c, n, d = *x.shape, self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3) + k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3) + v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3) + + # compute attention + p = self.dropout.p if self.training else 0.0 + # TODO: check mask + x = ops.flash_attention_score(q, k, v, self.num_heads, attn_mask=mask, keep_prob=1 - p) + x = x.permute(0, 2, 1, 3).reshape(b, s, c) + + # output + x = self.o(x) + x = self.dropout(x) + return x + + +class AttentionBlock(nn.Cell): + def __init__( + self, + dim: int, + num_heads: int, + post_norm: bool, + dropout: float = 0.1, + eps: float = 1e-5, + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.post_norm = post_norm + self.eps = eps + + # layers + self.attn = SelfAttention(dim, num_heads, dropout, eps, dtype=dtype) + self.norm1 = mint.nn.LayerNorm(dim, eps=eps, dtype=dtype) + self.ffn = nn.SequentialCell( + mint.nn.Linear(dim, dim * 4, dtype=dtype), + mint.nn.GELU(), + mint.nn.Linear(dim * 4, dim, dtype=dtype), + mint.nn.Dropout(dropout), + ) + self.norm2 = mint.nn.LayerNorm(dim, eps=eps, dtype=dtype) + + def construct(self, x: Tensor, mask: Tensor) -> Tensor: + if self.post_norm: + x = self.norm1(x + self.attn(x, mask)) + x = self.norm2(x + self.ffn(x)) + else: + x = x + self.attn(self.norm1(x), mask) + x = x + self.ffn(self.norm2(x)) + return x + + +class XLMRoberta(nn.Cell): + """ + XLMRobertaModel with no pooler and no LM head. + """ + + def __init__( + self, + vocab_size: int = 250002, + max_seq_len: int = 514, + type_size: int = 1, + pad_id: int = 1, + dim: int = 1024, + num_heads: int = 16, + num_layers: int = 24, + post_norm: bool = True, + dropout: float = 0.1, + eps: float = 1e-5, + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__() + self.vocab_size = vocab_size + self.max_seq_len = max_seq_len + self.type_size = type_size + self.pad_id = pad_id + self.dim = dim + self.num_heads = num_heads + self.num_layers = num_layers + self.post_norm = post_norm + self.eps = eps + + # embeddings + self.token_embedding = mint.nn.Embedding(vocab_size, dim, padding_idx=pad_id, dtype=dtype) + self.type_embedding = mint.nn.Embedding(type_size, dim, dtype=dtype) + self.pos_embedding = mint.nn.Embedding(max_seq_len, dim, padding_idx=pad_id, dtype=dtype) + self.dropout = mint.nn.Dropout(dropout) + + # blocks + self.blocks = nn.CellList( + [AttentionBlock(dim, num_heads, post_norm, dropout, eps, dtype=dtype) for _ in range(num_layers)] + ) + + # norm layer + self.norm = mint.nn.LayerNorm(dim, eps=eps, dtype=dtype) + + def construct(self, ids: Tensor) -> Tensor: + """ + ids: [B, L] of mindspore.Tensor. + """ + b, s = ids.shape + mask = ids.ne(self.pad_id).to(ms.int32) + + # embeddings + x = ( + self.token_embedding(ids) + + self.type_embedding(mint.zeros_like(ids)) + + self.pos_embedding(self.pad_id + mint.cumsum(mask, dim=1) * mask) + ) + if self.post_norm: + x = self.norm(x) + x = self.dropout(x) + + # blocks + mask = mint.where(mask.view(b, 1, 1, s).gt(0), 0.0, dtype_to_min(x.dtype)) + for block in self.blocks: + x = block(x, mask) + + # output + if not self.post_norm: + x = self.norm(x) + return x + + +def xlm_roberta_large(pretrained: bool = False, return_tokenizer: bool = False, dtype: ms.Type = ms.float32, **kwargs): + """ + XLMRobertaLarge adapted from Huggingface. 
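+
+    A minimal usage sketch (the token ids below are illustrative, not produced by a real
+    tokenizer):
+
+        model = xlm_roberta_large(dtype=ms.float16)
+        ids = ms.Tensor([[0, 6, 5, 2, 1, 1]], dtype=ms.int32)  # [B, L]; 1 is the pad id
+        feats = model(ids)  # [B, L, 1024]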
+ """ + # params + cfg = dict( + vocab_size=250002, + max_seq_len=514, + type_size=1, + pad_id=1, + dim=1024, + num_heads=16, + num_layers=24, + post_norm=True, + dropout=0.1, + eps=1e-5, + ) + cfg.update(**kwargs) + + model = XLMRoberta(**cfg, dtype=dtype) + return model diff --git a/examples/wan2_1/wan/text2video.py b/examples/wan2_1/wan/text2video.py new file mode 100644 index 0000000000..97491b9e02 --- /dev/null +++ b/examples/wan2_1/wan/text2video.py @@ -0,0 +1,243 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import logging +import math +import os +import random +import sys +from functools import partial + +from tqdm import tqdm + +import mindspore as ms +import mindspore.mint as mint +import mindspore.mint.distributed as dist +from mindspore.communication import GlobalComm, get_group_size +from mindspore.nn.utils import no_init_parameters + +from mindone.trainers.zero import prepare_network + +from .acceleration.parallel_states import create_parallel_group +from .modules.model import WanModel +from .modules.t5 import T5EncoderModel +from .modules.vae import WanVAE +from .utils.fm_solvers import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps +from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + + +class WanT2V: + def __init__( + self, + config, + checkpoint_dir, + rank=0, + t5_zero3=False, + dit_zero3=False, + use_usp=False, + t5_cpu=False, + ): + r""" + Initializes the Wan text-to-video generation model components. + + Args: + config (EasyDict): + Object containing model parameters initialized from config.py + checkpoint_dir (`str`): + Path to directory containing model checkpoints + rank (`int`, *optional*, defaults to 0): + Process rank for distributed training + t5_zero3 (`bool`, *optional*, defaults to False): + Enable ZeRO3 sharding for T5 model + dit_zero3 (`bool`, *optional*, defaults to False): + Enable ZeRO3 sharding for DiT model + use_usp (`bool`, *optional*, defaults to False): + Enable distribution strategy of USP. + t5_cpu (`bool`, *optional*, defaults to False): + Whether to place T5 model on CPU. Only works without t5_zero3. 
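+
+        A minimal single-device sketch (the config object and checkpoint directory are
+        illustrative; `cfg` stands for an EasyDict initialized from config.py as above):
+
+            wan_t2v = WanT2V(config=cfg, checkpoint_dir="./Wan2.1-T2V-1.3B")
+            video = wan_t2v.generate("a corgi surfing a wave at sunset", size=(832, 480), frame_num=81)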
+ """ + self.config = config + self.rank = rank + self.t5_cpu = t5_cpu + + self.num_train_timesteps = config.num_train_timesteps + self.param_dtype = config.param_dtype + + if use_usp: + self.sp_size = get_group_size(GlobalComm.WORLD_COMM_GROUP) + create_parallel_group(self.sp_size) + else: + self.sp_size = 1 + + shard_fn = partial(prepare_network, zero_stage=3, optimizer_parallel_group=GlobalComm.WORLD_COMM_GROUP) + self.text_encoder = T5EncoderModel( + text_len=config.text_len, + dtype=config.t5_dtype, + checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), + tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), + shard_fn=shard_fn if t5_zero3 else None, + ) + + self.vae_stride = config.vae_stride + self.patch_size = config.patch_size + self.vae = WanVAE(vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype=self.param_dtype) + + logging.info(f"Creating WanModel from {checkpoint_dir}") + with no_init_parameters(): + self.model = WanModel.from_pretrained(checkpoint_dir, mindspore_dtype=self.param_dtype) + self.model.init_parameters_data() + self.model.set_train(False) + for param in self.model.trainable_params(): + param.requires_grad = False + + # TODO: GlobalComm.INITED -> mint.is_initialzed + if GlobalComm.INITED: + dist.barrier() + if dit_zero3: + self.model = shard_fn(self.model) + + self.sample_neg_prompt = config.sample_neg_prompt + + def generate( + self, + input_prompt, + size=(1280, 720), + frame_num=81, + shift=5.0, + sample_solver="unipc", + sampling_steps=50, + guide_scale=5.0, + n_prompt="", + seed=-1, + offload_model=False, + ): + r""" + Generates video frames from text prompt using diffusion process. + + Args: + input_prompt (`str`): + Text prompt for content generation + size (tupele[`int`], *optional*, defaults to (1280,720)): + Controls video resolution, (width,height). + frame_num (`int`, *optional*, defaults to 81): + How many frames to sample from a video. The number should be 4n+1 + shift (`float`, *optional*, defaults to 5.0): + Noise schedule shift parameter. Affects temporal dynamics + sample_solver (`str`, *optional*, defaults to 'unipc'): + Solver used to sample the video. + sampling_steps (`int`, *optional*, defaults to 40): + Number of diffusion sampling steps. Higher values improve quality but slow generation + guide_scale (`float`, *optional*, defaults 5.0): + Classifier-free guidance scale. Controls prompt adherence vs. creativity + n_prompt (`str`, *optional*, defaults to ""): + Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` + seed (`int`, *optional*, defaults to -1): + Random seed for noise generation. If -1, use random seed. + offload_model (`bool`, *optional*, defaults to False): + If True, offloads models to CPU during generation to save VRAM + + Returns: + mindspore.Tensor: + Generated video frames tensor. 
Dimensions: (C, N H, W) where: + - C: Color channels (3 for RGB) + - N: Number of frames (81) + - H: Frame height (from size) + - W: Frame width from size) + """ + # preprocess + F = frame_num + target_shape = ( + self.vae.model.z_dim, + (F - 1) // self.vae_stride[0] + 1, + size[1] // self.vae_stride[1], + size[0] // self.vae_stride[2], + ) + + seq_len = ( + math.ceil( + (target_shape[2] * target_shape[3]) + / (self.patch_size[1] * self.patch_size[2]) + * target_shape[1] + / self.sp_size + ) + * self.sp_size + ) + + if n_prompt == "": + n_prompt = self.sample_neg_prompt + seed = seed if seed >= 0 else random.randint(0, sys.maxsize) + seed_g = ms.Generator() + seed_g.manual_seed(seed) + + if not self.t5_cpu: + context = self.text_encoder([input_prompt]) + context_null = self.text_encoder([n_prompt]) + if offload_model: + raise NotImplementedError() + else: + raise NotImplementedError() + + noise = [ + mint.randn( + target_shape[0], + target_shape[1], + target_shape[2], + target_shape[3], + dtype=self.param_dtype, + generator=seed_g, + ) + ] + + # evaluation mode + if sample_solver == "unipc": + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, shift=1, use_dynamic_shifting=False + ) + sample_scheduler.set_timesteps(sampling_steps, shift=shift) + timesteps = sample_scheduler.timesteps + elif sample_solver == "dpm++": + sample_scheduler = FlowDPMSolverMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, shift=1, use_dynamic_shifting=False + ) + sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) + timesteps, _ = retrieve_timesteps(sample_scheduler, sigmas=sampling_sigmas) + else: + raise NotImplementedError("Unsupported solver.") + + # sample videos + latents = noise + + arg_c = {"context": context, "seq_len": seq_len} + arg_null = {"context": context_null, "seq_len": seq_len} + + for _, t in enumerate(tqdm(timesteps)): + latent_model_input = latents + timestep = [t] + + timestep = mint.stack(timestep) + + noise_pred_cond = self.model(latent_model_input, t=timestep, **arg_c)[0] + noise_pred_uncond = self.model(latent_model_input, t=timestep, **arg_null)[0] + + noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond) + + temp_x0 = sample_scheduler.step( + noise_pred.unsqueeze(0), t, latents[0].unsqueeze(0), return_dict=False, generator=seed_g + )[0] + latents = [temp_x0.squeeze(0)] + + x0 = latents + if offload_model: + raise NotImplementedError() + if self.rank == 0: + # TODO: handle this + # np.save("latent.npy", x0[0].to(ms.float32).asnumpy()) + videos = self.vae.decode(x0) + + del noise, latents + del sample_scheduler + if offload_model: + raise NotImplementedError() + # TODO: GlobalComm.INITED -> mint.is_initialzed + if GlobalComm.INITED: + dist.barrier() + + return videos[0] if self.rank == 0 else None diff --git a/examples/wan2_1/wan/utils/__init__.py b/examples/wan2_1/wan/utils/__init__.py new file mode 100644 index 0000000000..a2088df628 --- /dev/null +++ b/examples/wan2_1/wan/utils/__init__.py @@ -0,0 +1,10 @@ +from .fm_solvers import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps +from .fm_solvers_unipc import FlowUniPCMultistepScheduler + +__all__ = [ + "HuggingfaceTokenizer", + "get_sampling_sigmas", + "retrieve_timesteps", + "FlowDPMSolverMultistepScheduler", + "FlowUniPCMultistepScheduler", +] diff --git a/examples/wan2_1/wan/utils/fm_solvers.py b/examples/wan2_1/wan/utils/fm_solvers.py new file mode 100644 index 0000000000..bf0afe2caa --- /dev/null +++ 
b/examples/wan2_1/wan/utils/fm_solvers.py @@ -0,0 +1,812 @@ +# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +# Convert dpm solver for flow matching +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. + +import inspect +import math +from typing import List, Optional, Tuple, Union + +import numpy as np + +import mindspore as ms +import mindspore.mint as mint +import mindspore.ops as ops +from mindspore import Tensor + +from mindone.diffusers.configuration_utils import ConfigMixin, register_to_config +from mindone.diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput +from mindone.diffusers.utils import deprecate +from mindone.diffusers.utils.mindspore_utils import randn_tensor + + +def get_sampling_sigmas(sampling_steps, shift): + sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps] + sigma = shift * sigma / (1 + (shift - 1) * sigma) + + return sigma + + +def retrieve_timesteps( + scheduler, + num_inference_steps=None, + timesteps=None, + sigmas=None, + **kwargs, +): + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FlowDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + `FlowDPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs. + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. This determines the resolution of the diffusion process. + solver_order (`int`, defaults to 2): + The DPMSolver order which can be `1`, `2`, or `3`. It is recommended to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. This affects the number of model outputs stored + and used in multistep updates. + prediction_type (`str`, defaults to "flow_prediction"): + Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts + the flow of the diffusion process. + shift (`float`, *optional*, defaults to 1.0): + A factor used to adjust the sigmas in the noise schedule. 
It modifies the step sizes during the sampling + process. + use_dynamic_shifting (`bool`, defaults to `False`): + Whether to apply dynamic shifting to the timesteps based on image resolution. If `True`, the shifting is + applied on the fly. + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This method adjusts the predicted sample to prevent + saturation and improve photorealism. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++"`. + algorithm_type (`str`, defaults to `dpmsolver++`): + Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The + `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) + paper, and the `dpmsolver++` type implements the algorithms in the + [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or + `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. + solver_type (`str`, defaults to `midpoint`): + Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the + sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. + lower_order_final (`bool`, defaults to `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + euler_at_final (`bool`, defaults to `False`): + Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail + richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference + steps, but sometimes may result in blurring. + final_sigmas_type (`str`, *optional*, defaults to "zero"): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + lambda_min_clipped (`float`, defaults to `-inf`): + Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the + cosine (`squaredcos_cap_v2`) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output + contains the predicted Gaussian variance. 
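+
+        A minimal usage sketch mirroring `text2video.py` (the step count and shift value
+        are illustrative):
+
+            scheduler = FlowDPMSolverMultistepScheduler(num_train_timesteps=1000, shift=1, use_dynamic_shifting=False)
+            sigmas = get_sampling_sigmas(sampling_steps=50, shift=5.0)
+            timesteps, _ = retrieve_timesteps(scheduler, sigmas=sigmas)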
+ """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "flow_prediction", + shift: Optional[float] = 1.0, + use_dynamic_shifting=False, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", + lower_order_final: bool = True, + euler_at_final: bool = False, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, + invert_sigmas: bool = False, + ): + if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + deprecation_message = ( + f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. " + "Choose from `dpmsolver++` or `sde-dpmsolver++` instead" + ) + deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message) + + # settings for DPM-Solver + if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]: + if algorithm_type == "deis": + self.register_to_config(algorithm_type="dpmsolver++") + else: + raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}") + + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") + + if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero": + raise ValueError( + f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead." + ) + + # setable values + self.num_inference_steps = None + alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy() + sigmas = 1.0 - alphas + sigmas = ms.from_numpy(sigmas).to(dtype=ms.float32) + + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + self.sigmas = sigmas + self.timesteps = sigmas * num_train_timesteps + + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + self._step_index = None + self._begin_index = None + + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + Args: + begin_index (`int`): + The begin index for the scheduler. 
+ """ + self._begin_index = begin_index + + # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps + def set_timesteps( + self, + num_inference_steps: Union[int, None] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[Union[float, None]] = None, + shift: Optional[Union[float, None]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + Args: + num_inference_steps (`int`): + Total number of the spacing of the time steps. + """ + + if self.config.use_dynamic_shifting and mu is None: + raise ValueError(" you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`") + + if sigmas is None: + sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1] + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + if shift is None: + shift = self.config.shift + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + timesteps = sigmas * self.config.num_train_timesteps + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + + self.sigmas = ms.from_numpy(sigmas) + self.timesteps = ms.from_numpy(timesteps).to(dtype=ms.int32) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + + self._step_index = None + self._begin_index = None + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: Tensor) -> Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." 
+ https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (ms.float32, ms.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + # TODO: ops.quantile -> mint.quantile + s = ops.quantile(abs_sample, self.config.dynamic_thresholding_ratio, axis=1) + s = mint.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = mint.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def _sigma_to_alpha_sigma_t(self, sigma): + return 1 - sigma, sigma + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps + def time_shift(self, mu: float, sigma: float, t: Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output + def convert_model_output( + self, + model_output: Tensor, + *args, + sample: Tensor = None, + **kwargs, + ) -> Tensor: + """ + Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is + designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an + integral of the data prediction model. + + The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise + prediction and data prediction models. + + Args: + model_output (`Tensor`): + The direct output from the learned diffusion model. + sample (`Tensor`): + A current instance of a sample created by the diffusion process. + Returns: + `Tensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + # DPM-Solver++ needs to solve an integral of the data prediction model. + if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction`, or `flow_prediction` for the FlowDPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + + # DPM-Solver needs to solve an integral of the noise prediction model. 
+ elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + epsilon = sample - (1 - sigma_t) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the FlowDPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + x0_pred = self._threshold_sample(x0_pred) + epsilon = model_output + x0_pred + + return epsilon + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update + def dpm_solver_first_order_update( + self, + model_output: Tensor, + *args, + sample: Tensor = None, + noise: Optional[Tensor] = None, + **kwargs, + ) -> Tensor: + """ + One step for the first-order DPMSolver (equivalent to DDIM). + Args: + model_output (`Tensor`): + The direct output from the learned diffusion model. + sample (`Tensor`): + A current instance of a sample created by the diffusion process. + Returns: + `Tensor`: + The sample tensor at the previous timestep. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = mint.log(alpha_t) - mint.log(sigma_t) + lambda_s = mint.log(alpha_s) - mint.log(sigma_s) + + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver++": + x_t = (sigma_t / sigma_s) * sample - (alpha_t * (mint.exp(-h) - 1.0)) * model_output + elif self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / alpha_s) * sample - (sigma_t * (mint.exp(h) - 1.0)) * model_output + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ( + (sigma_t / sigma_s * mint.exp(-h)) * sample + + (alpha_t * (1 - mint.exp(-2.0 * h))) * model_output + + sigma_t * mint.sqrt(1.0 - mint.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + x_t = ( + (alpha_t / alpha_s) * sample + - 2.0 * (sigma_t * (mint.exp(h) - 1.0)) * model_output + + sigma_t * mint.sqrt(mint.exp(2 * h) - 1.0) * noise + ) + return x_t # pyright: ignore + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update + def multistep_dpm_solver_second_order_update( + self, + model_output_list: List[Tensor], + *args, + sample: Tensor = None, + noise: Optional[Tensor] = None, + **kwargs, + ) -> Tensor: + """ + One step for the second-order multistep DPMSolver. 
+ Args: + model_output_list (`List[Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`Tensor`): + A current instance of a sample created by the diffusion process. + Returns: + `Tensor`: + The sample tensor at the previous timestep. + """ + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], # pyright: ignore + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], # pyright: ignore + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = mint.log(alpha_t) - mint.log(sigma_t) + lambda_s0 = mint.log(alpha_s0) - mint.log(sigma_s0) + lambda_s1 = mint.log(alpha_s1) - mint.log(sigma_s1) + + m0, m1 = model_output_list[-1], model_output_list[-2] + + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (mint.exp(-h) - 1.0)) * D0 + - 0.5 * (alpha_t * (mint.exp(-h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (mint.exp(-h) - 1.0)) * D0 + + (alpha_t * ((mint.exp(-h) - 1.0) / h + 1.0)) * D1 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (mint.exp(h) - 1.0)) * D0 + - 0.5 * (sigma_t * (mint.exp(h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (mint.exp(h) - 1.0)) * D0 + - (sigma_t * ((mint.exp(h) - 1.0) / h - 1.0)) * D1 + ) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0 * mint.exp(-h)) * sample + + (alpha_t * (1 - mint.exp(-2.0 * h))) * D0 + + 0.5 * (alpha_t * (1 - mint.exp(-2.0 * h))) * D1 + + sigma_t * mint.sqrt(1.0 - mint.exp(-2 * h)) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0 * mint.exp(-h)) * sample + + (alpha_t * (1 - mint.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - mint.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + + sigma_t * mint.sqrt(1.0 - mint.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (mint.exp(h) - 1.0)) * D0 + - (sigma_t * (mint.exp(h) - 
1.0)) * D1 + + sigma_t * mint.sqrt(mint.exp(2 * h) - 1.0) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (mint.exp(h) - 1.0)) * D0 + - 2.0 * (sigma_t * ((mint.exp(h) - 1.0) / h - 1.0)) * D1 + + sigma_t * mint.sqrt(mint.exp(2 * h) - 1.0) * noise + ) + return x_t # pyright: ignore + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update + def multistep_dpm_solver_third_order_update( + self, + model_output_list: List[Tensor], + *args, + sample: Tensor = None, + **kwargs, + ) -> Tensor: + """ + One step for the third-order multistep DPMSolver. + Args: + model_output_list (`List[Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`Tensor`): + A current instance of a sample created by diffusion process. + Returns: + `Tensor`: + The sample tensor at the previous timestep. + """ + + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing`sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( + self.sigmas[self.step_index + 1], # pyright: ignore + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], # pyright: ignore + self.sigmas[self.step_index - 2], # pyright: ignore + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) + + lambda_t = mint.log(alpha_t) - mint.log(sigma_t) + lambda_s0 = mint.log(alpha_s0) - mint.log(sigma_s0) + lambda_s1 = mint.log(alpha_s1) - mint.log(sigma_s1) + lambda_s2 = mint.log(alpha_s2) - mint.log(sigma_s2) + + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] + + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 + r0, r1 = h_0 / h, h_1 / h + D0 = m0 + D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (mint.exp(-h) - 1.0)) * D0 + + (alpha_t * ((mint.exp(-h) - 1.0) / h + 1.0)) * D1 + - (alpha_t * ((mint.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (mint.exp(h) - 1.0)) * D0 + - (sigma_t * ((mint.exp(h) - 1.0) / h - 1.0)) * D1 + - (sigma_t * ((mint.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 + ) + return x_t # pyright: ignore + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if 
schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + # Modified from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step + def step( + self, + model_output: Tensor, + timestep: Union[int, Tensor], + sample: Tensor, + generator=None, + variance_noise: Optional[Tensor] = None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep DPMSolver. + Args: + model_output (`Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`Tensor`): + A current instance of a sample created by the diffusion process. + generator (`mindspore.Generator`, *optional*): + A random number generator. + variance_noise (`Tensor`): + Alternative to generating noise with `generator` by directly providing the noise for the variance + itself. Useful for methods such as [`LEdits++`]. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. 
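+
+            A minimal denoising-loop sketch mirroring `text2video.py` (the model call is a
+            placeholder for your own network):
+
+                for t in scheduler.timesteps:
+                    noise_pred = model(latent, t)  # placeholder model call
+                    latent = scheduler.step(noise_pred, t, latent, return_dict=False)[0]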
+ """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Improve numerical stability for small number of steps + lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( + self.config.euler_at_final + or (self.config.lower_order_final and len(self.timesteps) < 15) + or self.config.final_sigmas_type == "zero" + ) + lower_order_second = ( + (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 + ) + + model_output = self.convert_model_output(model_output, sample=sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(ms.float32) + if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None: + noise = randn_tensor(model_output.shape, generator=generator, dtype=ms.float32) + elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: + noise = variance_noise.to(dtype=ms.float32) + else: + noise = None + + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: + prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise) + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: + prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise) + else: + prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # Cast sample back to expected dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 # pyright: ignore + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input + def scale_model_input(self, sample: Tensor, *args, **kwargs) -> Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + Args: + sample (`Tensor`): + The input sample. + Returns: + `Tensor`: + A scaled input sample. 
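+
+            This scheduler requires no input scaling, so the sample is returned unchanged.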
+ """ + return sample + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input + def add_noise( + self, + original_samples: Tensor, + noise: Tensor, + timesteps: Tensor, + ) -> Tensor: + # Make sure sigmas and timesteps have the same dtype as original_samples + sigmas = self.sigmas.to(original_samples.dtype) + + schedule_timesteps = self.timesteps + + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/examples/wan2_1/wan/utils/fm_solvers_unipc.py b/examples/wan2_1/wan/utils/fm_solvers_unipc.py new file mode 100644 index 0000000000..bd0ad87d1c --- /dev/null +++ b/examples/wan2_1/wan/utils/fm_solvers_unipc.py @@ -0,0 +1,747 @@ +# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py +# Convert unipc for flow matching +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np + +import mindspore as ms +import mindspore.mint as mint +import mindspore.ops as ops +from mindspore import Tensor + +from mindone.diffusers.configuration_utils import ConfigMixin, register_to_config +from mindone.diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput +from mindone.diffusers.utils import deprecate + + +class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + solver_order (`int`, default `2`): + The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1` + due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for + unconditional sampling. + prediction_type (`str`, defaults to "flow_prediction"): + Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts + the flow of the diffusion process. + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. 
+ sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. + predict_x0 (`bool`, defaults to `True`): + Whether to use the updating algorithm on the predicted x0. + solver_type (`str`, default `bh2`): + Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2` + otherwise. + lower_order_final (`bool`, default `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + disable_corrector (`list`, default `[]`): + Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)` + and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is + usually disabled during the first few steps. + solver_p (`SchedulerMixin`, default `None`): + Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + use_exponential_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. 
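+
+        A minimal usage sketch mirroring `text2video.py` (the step count and shift value
+        are illustrative):
+
+            scheduler = FlowUniPCMultistepScheduler(num_train_timesteps=1000, shift=1, use_dynamic_shifting=False)
+            scheduler.set_timesteps(50, shift=5.0)
+            timesteps = scheduler.timesteps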
+ """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "flow_prediction", + shift: Optional[float] = 1.0, + use_dynamic_shifting=False, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: List[int] = [], + solver_p: SchedulerMixin = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + ): + if solver_type not in ["bh1", "bh2"]: + if solver_type in ["midpoint", "heun", "logrho"]: + self.register_to_config(solver_type="bh2") + else: + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") + + self.predict_x0 = predict_x0 + # setable values + self.num_inference_steps = None + alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy() + sigmas = 1.0 - alphas + sigmas = ms.from_numpy(sigmas).to(dtype=ms.float32) + + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + self.sigmas = sigmas + self.timesteps = sigmas * num_train_timesteps + + self.model_outputs = [None] * solver_order + self.timestep_list = [None] * solver_order + self.lower_order_nums = 0 + self.disable_corrector = disable_corrector + self.solver_p = solver_p + self.last_sample = None + self._step_index = None + self._begin_index = None + + self.sigmas = self.sigmas + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps + def set_timesteps( + self, + num_inference_steps: Union[int, None] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[Union[float, None]] = None, + shift: Optional[Union[float, None]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + Args: + num_inference_steps (`int`): + Total number of the spacing of the time steps. 
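+            sigmas (`List[float]`, *optional*):
+                Custom sigma schedule; when provided, `num_inference_steps` is not used.
+            mu (`float`, *optional*):
+                Shift value applied on the fly; required when `use_dynamic_shifting=True`.
+            shift (`float`, *optional*):
+                Overrides `config.shift` for the static sigma shifting.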
+ """ + + if self.config.use_dynamic_shifting and mu is None: + raise ValueError(" you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`") + + if sigmas is None: + sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1] # pyright: ignore + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore + else: + if shift is None: + shift = self.config.shift + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + timesteps = sigmas * self.config.num_train_timesteps + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) # pyright: ignore + + self.sigmas = ms.from_numpy(sigmas) + self.timesteps = ms.from_numpy(timesteps).to(dtype=ms.int32) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + self.last_sample = None + if self.solver_p: + self.solver_p.set_timesteps(self.num_inference_steps) + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: Tensor) -> Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." 
+ + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (ms.float32, ms.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + # TODO: ops.quantile -> mint.quantile + s = ops.quantile(abs_sample, self.config.dynamic_thresholding_ratio, axis=1) + s = mint.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = mint.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def _sigma_to_alpha_sigma_t(self, sigma): + return 1 - sigma, sigma + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps + def time_shift(self, mu: float, sigma: float, t: Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def convert_model_output( + self, + model_output: Tensor, + *args, + sample: Tensor = None, + **kwargs, + ) -> Tensor: + r""" + Convert the model output to the corresponding type the UniPC algorithm needs. + + Args: + model_output (`Tensor`): + The direct output from the learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `Tensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + + if self.predict_x0: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + else: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + epsilon = sample - (1 - sigma_t) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." 
+ ) + + if self.config.thresholding: + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + x0_pred = self._threshold_sample(x0_pred) + epsilon = model_output + x0_pred + + return epsilon + + def multistep_uni_p_bh_update( + self, + model_output: Tensor, + *args, + sample: Tensor = None, + order: int = None, # pyright: ignore + **kwargs, + ) -> Tensor: + """ + One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified. + + Args: + model_output (`Tensor`): + The direct output from the learned diffusion model at the current timestep. + prev_timestep (`int`): + The previous discrete timestep in the diffusion chain. + sample (`Tensor`): + A current instance of a sample created by the diffusion process. + order (`int`): + The order of UniP at this timestep (corresponds to the *p* in UniPC-p). + + Returns: + `Tensor`: + The sample tensor at the previous timestep. + """ + prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if order is None: + if len(args) > 2: + order = args[2] + else: + raise ValueError(" missing `order` as a required keyward argument") + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + model_output_list = self.model_outputs + + s0 = self.timestep_list[-1] + m0 = model_output_list[-1] + x = sample + + if self.solver_p: + x_t = self.solver_p.step(model_output, s0, x).prev_sample + return x_t + + sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = mint.log(alpha_t) - mint.log(sigma_t) + lambda_s0 = mint.log(alpha_s0) - mint.log(sigma_s0) + + h = lambda_t - lambda_s0 + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - i # pyright: ignore + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = mint.log(alpha_si) - mint.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) # pyright: ignore + + rks.append(1.0) + rks = Tensor(rks) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = mint.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = mint.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(mint.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = mint.stack(R) + b = Tensor(b) + + if len(D1s) > 0: + D1s = mint.stack(D1s, dim=1) # (B, K) + # for order 2, we use a simplified version + if order == 2: + rhos_p = Tensor([0.5], dtype=x.dtype) + else: + rhos_p = ms.Tensor(np.linalg.solve(R[:-1, :-1].asnumpy(), b[:-1].asnumpy()), dtype=x.dtype) + else: + D1s = None + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + pred_res = mint.einsum("k,bkc...->bc...", rhos_p, D1s) + else: + pred_res = 0 + x_t = x_t_ - alpha_t * B_h * pred_res + else: + x_t_ = alpha_t / 
alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + pred_res = mint.einsum("k,bkc...->bc...", rhos_p, D1s) + else: + pred_res = 0 + x_t = x_t_ - sigma_t * B_h * pred_res + + x_t = x_t.to(x.dtype) + return x_t + + def multistep_uni_c_bh_update( + self, + this_model_output: Tensor, + *args, + last_sample: Tensor = None, + this_sample: Tensor = None, + order: int = None, # pyright: ignore + **kwargs, + ) -> Tensor: + """ + One step for the UniC (B(h) version). + + Args: + this_model_output (`Tensor`): + The model outputs at `x_t`. + this_timestep (`int`): + The current timestep `t`. + last_sample (`Tensor`): + The generated sample before the last predictor `x_{t-1}`. + this_sample (`Tensor`): + The generated sample after the last predictor `x_{t}`. + order (`int`): + The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`. + + Returns: + `Tensor`: + The corrected sample tensor at the current timestep. + """ + this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None) + if last_sample is None: + if len(args) > 1: + last_sample = args[1] + else: + raise ValueError(" missing`last_sample` as a required keyward argument") + if this_sample is None: + if len(args) > 2: + this_sample = args[2] + else: + raise ValueError(" missing`this_sample` as a required keyward argument") + if order is None: + if len(args) > 3: + order = args[3] + else: + raise ValueError(" missing`order` as a required keyward argument") + if this_timestep is not None: + deprecate( + "this_timestep", + "1.0.0", + "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + model_output_list = self.model_outputs + + m0 = model_output_list[-1] + x = last_sample + x_t = this_sample + model_t = this_model_output + + sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1] # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = mint.log(alpha_t) - mint.log(sigma_t) + lambda_s0 = mint.log(alpha_s0) - mint.log(sigma_s0) + + h = lambda_t - lambda_s0 + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - (i + 1) # pyright: ignore + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = mint.log(alpha_si) - mint.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) # pyright: ignore + + rks.append(1.0) + rks = Tensor(rks) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = mint.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = mint.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(mint.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = mint.stack(R) + b = Tensor(b) + + if len(D1s) > 0: + D1s = mint.stack(D1s, dim=1) + else: + D1s = None + + # for order 1, we use a simplified version + if order == 1: + rhos_c = Tensor([0.5], dtype=x.dtype) + else: + rhos_c = ms.Tensor(np.linalg.solve(R.asnumpy(), b.asnumpy()), dtype=x.dtype) + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + corr_res = mint.einsum("k,bkc...->bc...", rhos_c[:-1], 
D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + corr_res = mint.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) + x_t = x_t.to(x.dtype) + return x_t + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: Tensor, + timestep: Union[int, Tensor], + sample: Tensor, + return_dict: bool = True, + generator=None, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep UniPC. + + Args: + model_output (`Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`Tensor`): + A current instance of a sample created by the diffusion process. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. 
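+
+        Example (illustrative sketch; ``scheduler``, ``model`` and ``latents`` are placeholders, and
+        ``use_dynamic_shifting`` is assumed to be disabled so that ``mu`` can be omitted):
+
+        ```python
+        scheduler.set_timesteps(num_inference_steps=50)
+        for t in scheduler.timesteps:
+            flow_pred = model(latents, t)  # user-provided denoiser returning the flow prediction
+            latents = scheduler.step(flow_pred, t, latents).prev_sample
+        ```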
+ + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + use_corrector = ( + self.step_index > 0 + and self.step_index - 1 not in self.disable_corrector + and self.last_sample is not None # pyright: ignore + ) + + model_output_convert = self.convert_model_output(model_output, sample=sample) + if use_corrector: + sample = self.multistep_uni_c_bh_update( + this_model_output=model_output_convert, + last_sample=self.last_sample, + this_sample=sample, + order=self.this_order, + ) + + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep # pyright: ignore + + if self.config.lower_order_final: + this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index) # pyright: ignore + else: + this_order = self.config.solver_order + + self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep + assert self.this_order > 0 + + self.last_sample = sample + prev_sample = self.multistep_uni_p_bh_update( + model_output=model_output, # pass the original non-converted model output, in case solver-p is used + sample=sample, + order=self.this_order, + ) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # upon completion increase step index by one + self._step_index += 1 # pyright: ignore + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: Tensor, *args, **kwargs) -> Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`Tensor`): + The input sample. + + Returns: + `Tensor`: + A scaled input sample. 
+ """ + return sample + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise + def add_noise( + self, + original_samples: Tensor, + noise: Tensor, + timesteps: Tensor, + ) -> Tensor: + # Make sure sigmas and timesteps have the same dtype as original_samples + sigmas = self.sigmas.to(dtype=original_samples.dtype) + + schedule_timesteps = self.timesteps + + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/examples/wan2_1/wan/utils/utils.py b/examples/wan2_1/wan/utils/utils.py new file mode 100644 index 0000000000..6f4f9055b0 --- /dev/null +++ b/examples/wan2_1/wan/utils/utils.py @@ -0,0 +1,251 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import argparse +import binascii +import logging +import math +import os +import os.path as osp +from typing import List, Optional, Tuple, Union + +import imageio +import ml_dtypes +import numpy as np +import torch +import tqdm +from PIL import Image + +import mindspore as ms +import mindspore.mint as mint +from mindspore import Parameter, Tensor + +__all__ = ["cache_video", "cache_image", "str2bool", "load_pth"] + +logger = logging.getLogger(__name__) + + +def rand_name(length=8, suffix=""): + name = binascii.b2a_hex(os.urandom(length)).decode("utf-8") + if suffix: + if not suffix.startswith("."): + suffix = "." 
+ suffix + name += suffix + return name + + +def cache_video(tensor, save_file=None, fps=30, suffix=".mp4", nrow=8, normalize=True, value_range=(-1, 1), retry=5): + tensor = tensor.float() + + # cache file + cache_file = osp.join("/tmp", rand_name(suffix=suffix)) if save_file is None else save_file + + # save to cache + error = None + for _ in range(retry): + try: + # preprocess + tensor = tensor.clamp(min(value_range), max(value_range)) + tensor = mint.stack( + [make_grid_ms(u, nrow=nrow, normalize=normalize, value_range=value_range) for u in tensor.unbind(2)], + dim=1, + ).permute(1, 2, 3, 0) + tensor = (tensor * 255).type(ms.uint8) + + # write video + writer = imageio.get_writer(cache_file, fps=fps, codec="libx264", quality=8) + for frame in tensor.asnumpy(): + writer.append_data(frame) + writer.close() + return cache_file + except Exception as e: + logger.warning(e) + continue + else: + print(f"cache_video failed, error: {error}", flush=True) + return None + + +def cache_image(tensor, save_file, nrow=8, normalize=True, value_range=(-1, 1), retry=5): + # cache file + suffix = osp.splitext(save_file)[1] + if suffix.lower() not in [".jpg", ".jpeg", ".png", ".tiff", ".gif", ".webp"]: + suffix = ".png" + + # save to cache + for _ in range(retry): + try: + tensor = tensor.clamp(min(value_range), max(value_range)) + save_image_ms(tensor, save_file, nrow=nrow, normalize=normalize, value_range=value_range) + return save_file + except Exception as e: + logger.warning(e) + continue + + +def str2bool(v): + """ + Convert a string to a boolean. + + Supported true values: 'yes', 'true', 't', 'y', '1' + Supported false values: 'no', 'false', 'f', 'n', '0' + + Args: + v (str): String to convert. + + Returns: + bool: Converted boolean value. + + Raises: + argparse.ArgumentTypeError: If the value cannot be converted to boolean. + """ + if isinstance(v, bool): + return v + v_lower = v.lower() + if v_lower in ("yes", "true", "t", "y", "1"): + return True + elif v_lower in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected (True/False)") + + +def load_pth(pth_path: str, dtype: ms.Type = ms.bfloat16): + logger.info(f"Loading PyTorch ckpt from {pth_path}.") + torch_data = torch.load(pth_path, map_location="cpu") + mindspore_data = dict() + for name, value in tqdm.tqdm(torch_data.items(), desc="converting to MindSpore format"): + if value.dtype == torch.bfloat16: + mindspore_data[name] = Parameter( + Tensor(value.view(dtype=torch.uint16).numpy().view(ml_dtypes.bfloat16), dtype=dtype) + ) + else: + mindspore_data[name] = Parameter(Tensor(value.numpy(), dtype=dtype)) + return mindspore_data + + +def pil2tensor(pic: Image.Image) -> ms.Tensor: + """ + convert PIL image to mindspore.Tensor + """ + pic = np.array(pic) + if pic.dtype != np.uint8: + pic = pic.astype(np.uint8) + pic = np.transpose(pic, (2, 0, 1)) # hwc -> chw + tensor = Tensor(pic, dtype=ms.float32) + tensor = tensor / 255.0 + + return tensor + + +def make_grid_ms( + tensor: ms.Tensor, + nrow: int = 8, + padding: int = 2, + normalize: bool = False, + value_range: Optional[Tuple[int, int]] = None, + scale_each: bool = False, + pad_value: float = 0.0, +) -> ms.Tensor: + """ + Make a grid of images. + + Args: + tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W) + or a list of images all of the same size. + nrow (int, optional): Number of images displayed in each row of the grid. + The final grid size is ``(B / nrow, nrow)``. Default: ``8``. 
+ padding (int, optional): amount of padding. Default: ``2``. + normalize (bool, optional): If True, shift the image to the range (0, 1), + by the min and max values specified by ``value_range``. Default: ``False``. + value_range (tuple, optional): tuple (min, max) where min and max are numbers, + then these numbers are used to normalize the image. By default, min and max + are computed from the tensor. + scale_each (bool, optional): If ``True``, scale each image in the batch of + images separately rather than the (min, max) over all images. Default: ``False``. + pad_value (float, optional): Value for the padded pixels. Default: ``0``. + + Returns: + grid (Tensor): the tensor containing grid of images. + """ + # if list of tensors, convert to a 4D mini-batch Tensor + if isinstance(tensor, list): + tensor = mint.stack(tensor, dim=0) + + if tensor.dim() == 2: # single image H x W + tensor = tensor.unsqueeze(0) + if tensor.dim() == 3: # single image + if tensor.shape[0] == 1: # if single-channel, convert to 3-channel + tensor = mint.cat((tensor, tensor, tensor), 0) + tensor = tensor.unsqueeze(0) + + if tensor.dim() == 4 and tensor.shape[1] == 1: # single-channel images + tensor = mint.cat((tensor, tensor, tensor), 1) + + if normalize is True: + tensor = tensor.clone() # avoid modifying tensor in-place + if value_range is not None and not isinstance(value_range, tuple): + raise TypeError("value_range has to be a tuple (min, max) if specified. min and max are numbers") + + def norm_ip(img, low, high): + img.clamp_(min=low, max=high) + img.sub_(low).div_(max(high - low, 1e-5)) + + def norm_range(t, value_range): + if value_range is not None: + norm_ip(t, value_range[0], value_range[1]) + else: + norm_ip(t, float(t.min()), float(t.max())) + + if scale_each is True: + for t in tensor: # loop over mini-batch dimension + norm_range(t, value_range) + else: + norm_range(tensor, value_range) + + if not isinstance(tensor, ms.Tensor): + raise TypeError("tensor should be of type ms Tensor") + if tensor.shape[0] == 1: + return tensor.squeeze(0) + + # make the mini-batch of images into a grid + nmaps = tensor.shape[0] + xmaps = min(nrow, nmaps) + ymaps = int(math.ceil(float(nmaps) / xmaps)) + height, width = int(tensor.shape[2] + padding), int(tensor.shape[3] + padding) + num_channels = tensor.shape[1] + grid = mint.full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value) + k = 0 + for y in range(ymaps): + for x in range(xmaps): + if k >= nmaps: + break + grid.narrow(1, y * height + padding, height - padding).narrow( + 2, x * width + padding, width - padding + ).copy_(tensor[k]) + k = k + 1 + return grid + + +def save_image_ms( + tensor: Union[ms.Tensor, List[ms.Tensor]], + fp: str, + format: Optional[str] = None, + **kwargs, +) -> None: + """ + Save a given Tensor into an image file. + + Args: + tensor (Tensor or list): Image to be saved. If given a mini-batch tensor, + saves the tensor as a grid of images by calling ``make_grid``. + fp (string or file object): A filename or a file object + format(Optional): If omitted, the format to use is determined from the filename extension. + If a file object was used instead of a filename, this parameter should always be used. + **kwargs: Other arguments are documented in ``make_grid``. 
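+
+    Example (illustrative; ``images`` is a placeholder ``(B, C, H, W)`` batch of tensors in ``[-1, 1]``):
+
+    ```python
+    save_image_ms(images, "grid.png", nrow=4, normalize=True, value_range=(-1, 1))
+    ```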
+ """ + + grid = make_grid_ms(tensor, **kwargs) + # Add 0.5 after unnormalizing to [0, 255] to round to the nearest integer + ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to(ms.uint8).asnumpy() + + im = Image.fromarray(ndarr) + im.save(fp, format=format) diff --git a/mindone/models/modules/parallel/conv.py b/mindone/models/modules/parallel/conv.py index 78e79a597c..2498eeff47 100644 --- a/mindone/models/modules/parallel/conv.py +++ b/mindone/models/modules/parallel/conv.py @@ -89,7 +89,7 @@ def construct(self, x): bias = self.param_wrapper_b(self.net.bias) if self.net.padding_mode != "zeros": output = self.net.conv2d( - mint.pad(input, self.net._reversed_padding, mode=self.net.padding_mode), + mint.pad(x, self.net._reversed_padding, mode=self.net.padding_mode), weight, bias, self.net.stride, @@ -99,7 +99,7 @@ def construct(self, x): ) else: output = self.net.conv2d( - input, weight, bias, self.net.stride, self.net.padding, self.net.dilation, self.net.groups + x, weight, bias, self.net.stride, self.net.padding, self.net.dilation, self.net.groups ) return output @@ -110,7 +110,7 @@ def construct(self, x): bias = self.param_wrapper_b(self.net.bias) if self.net.padding_mode != "zeros": output = self.net.conv3d( - mint.pad(input, self.net._reversed_padding, mode=self.net.padding_mode), + mint.pad(x, self.net._reversed_padding, mode=self.net.padding_mode), weight, bias, self.net.stride, @@ -120,6 +120,6 @@ def construct(self, x): ) else: output = self.net.conv3d( - input, weight, bias, self.net.stride, self.net.padding, self.net.dilation, self.net.groups + x, weight, bias, self.net.stride, self.net.padding, self.net.dilation, self.net.groups ) return output