|
3 | 3 | a video reconstruction based on a provided prompt. It utilizes the CogVideoX pipeline to
|
4 | 4 | process video frames, apply the DDIM inverse scheduler, and produce an output video.
|
5 | 5 |
|
| 6 | +**Please notice that this script is based on the CogVideoX 5B model, and would not generate |
| 7 | +a good result for 2B variants.** |
| 8 | +
|
6 | 9 | Usage:
|
7 |
| - python script.py --model-path /path/to/model --prompt "a prompt" --video-path /path/to/video.mp4 --output-path /path/to/output |
| 10 | + python ddim_inversion.py |
| 11 | + --model-path /path/to/model |
| 12 | + --prompt "a prompt" |
| 13 | + --video-path /path/to/video.mp4 |
| 14 | + --output-path /path/to/output |
| 15 | +
|
| 16 | +For more details about the cli arguments, please run `python ddim_inversion.py --help`. |
8 | 17 |
|
9 | 18 | Author:
|
10 | 19 | LittleNyima <littlenyima[at]163[dot]com>
|
|
15 | 24 | import os
|
16 | 25 | from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union, cast
|
17 | 26 |
|
18 |
| -import decord |
19 | 27 | import torch
|
20 | 28 | import torch.nn.functional as F
|
21 | 29 | import torchvision.transforms as T
|
|
27 | 35 | from diffusers.schedulers import CogVideoXDDIMScheduler, DDIMInverseScheduler
|
28 | 36 | from diffusers.utils import export_to_video
|
29 | 37 |
|
| 38 | +# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error. |
| 39 | +# Very few bug reports but it happens. Look in decord Github issues for more relevant information. |
| 40 | +import decord # isort: skip |
| 41 | + |
30 | 42 |
|
31 | 43 | class DDIMInversionArguments(TypedDict):
|
32 | 44 | model_path: str
|
@@ -399,6 +411,8 @@ def ddim_inversion(
|
399 | 411 | device: torch.device,
|
400 | 412 | ):
|
401 | 413 | pipeline: CogVideoXPipeline = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device=device)
|
| 414 | + if not pipeline.transformer.config.use_rotary_positional_embeddings: |
| 415 | + raise NotImplementedError("This script supports CogVideoX 5B model only.") |
402 | 416 | video_frames = get_video_frames(
|
403 | 417 | video_path=video_path,
|
404 | 418 | width=width,
|
|
0 commit comments