diff --git a/csrc/sliding_tile_attention/st_attn/__init__.py b/csrc/sliding_tile_attention/st_attn/__init__.py
index 9d1e0a23..2017c79e 100644
--- a/csrc/sliding_tile_attention/st_attn/__init__.py
+++ b/csrc/sliding_tile_attention/st_attn/__init__.py
@@ -1,35 +1,212 @@
 import math
-
+import subprocess
 import torch
-from st_attn_cuda import sta_fwd
+from torch.nn.attention.flex_attention import flex_attention
+from functools import lru_cache
+from typing import Tuple
+from torch import BoolTensor, IntTensor
+from torch.nn.attention.flex_attention import create_block_mask
+
+# Peiyuan: This is necessary. Don't know why. See https://github.com/pytorch/pytorch/issues/135028
+torch._inductor.config.realize_opcount_threshold = 100
+
+
+def generate_sta_mask(canvas_twh, kernel_twh, tile_twh, text_length):
+    """Generates a 3D NATTEN attention mask with a given kernel size.
+
+    Args:
+        canvas_twh: The (time, height, width) dimensions of the canvas.
+        kernel_twh: The (time, height, width) dimensions of the kernel.
+        tile_twh: The (time, height, width) dimensions of a tile.
+        text_length: The number of text tokens appended to the video tokens.
+    """
+    canvas_t, canvas_h, canvas_w = canvas_twh
+    kernel_t, kernel_h, kernel_w = kernel_twh
+    tile_t_size, tile_h_size, tile_w_size = tile_twh
+    total_tile_size = tile_t_size * tile_h_size * tile_w_size
+    canvas_tile_t, canvas_tile_h, canvas_tile_w = canvas_t // tile_t_size, canvas_h // tile_h_size, canvas_w // tile_w_size
+    img_seq_len = canvas_t * canvas_h * canvas_w
+
+    def get_tile_t_x_y(idx: IntTensor) -> Tuple[IntTensor, IntTensor, IntTensor]:
+        tile_id = idx // total_tile_size
+        tile_t = tile_id // (canvas_tile_h * canvas_tile_w)
+        tile_h = (tile_id % (canvas_tile_h * canvas_tile_w)) // canvas_tile_w
+        tile_w = tile_id % canvas_tile_w
+        return tile_t, tile_h, tile_w
+
+    def sta_mask_3d(
+        b: IntTensor,
+        h: IntTensor,
+        q_idx: IntTensor,
+        kv_idx: IntTensor,
+    ) -> BoolTensor:
+        q_t_tile, q_x_tile, q_y_tile = get_tile_t_x_y(q_idx)
+        kv_t_tile, kv_x_tile, kv_y_tile = get_tile_t_x_y(kv_idx)
+        # kernel nominally attempts to center itself on the query, but kernel center
+        # is clamped to a fixed distance (kernel half-length) from the canvas edge
+        kernel_center_t = q_t_tile.clamp(kernel_t // 2, (canvas_tile_t - 1) - kernel_t // 2)
+        kernel_center_x = q_x_tile.clamp(kernel_h // 2, (canvas_tile_h - 1) - kernel_h // 2)
+        kernel_center_y = q_y_tile.clamp(kernel_w // 2, (canvas_tile_w - 1) - kernel_w // 2)
+        time_mask = (kernel_center_t - kv_t_tile).abs() <= kernel_t // 2
+        hori_mask = (kernel_center_x - kv_x_tile).abs() <= kernel_h // 2
+        vert_mask = (kernel_center_y - kv_y_tile).abs() <= kernel_w // 2
+        image_mask = (q_idx < img_seq_len) & (kv_idx < img_seq_len)
+        image_to_text_mask = (q_idx < img_seq_len) & (kv_idx >= img_seq_len) & (kv_idx < img_seq_len + text_length)
+        text_to_all_mask = (q_idx >= img_seq_len) & (kv_idx < img_seq_len + text_length)
+        return (image_mask & time_mask & hori_mask & vert_mask) | image_to_text_mask | text_to_all_mask
+
+    sta_mask_3d.__name__ = f"natten_3d_c{canvas_t}x{canvas_w}x{canvas_h}_k{kernel_t}x{kernel_w}x{kernel_h}"
+    return sta_mask_3d
+
+
+def get_sliding_tile_attention_mask(kernel_size, tile_size, img_size, text_length, device, text_max_len=256):
+    img_seq_len = img_size[0] * img_size[1] * img_size[2]
+    image_mask = generate_sta_mask(img_size, kernel_size, tile_size, text_length)
+    mask = create_block_mask(image_mask,
+                             B=None,
+                             H=None,
+                             Q_LEN=img_seq_len + text_max_len,
+                             KV_LEN=img_seq_len + text_max_len,
+                             device=device,
+                             _compile=True)
+    return mask
+
+def get_gpu_type():
+    try:
+        # Run nvidia-smi to get GPU information
+        result = subprocess.check_output(['nvidia-smi', '--query-gpu=name', '--format=csv,noheader']).decode('utf-8')
+
+        # Check if H100 is in any of the GPU names
+        gpus = [gpu.strip() for gpu in result.split('\n') if gpu.strip()]
+
+        for gpu in gpus:
+            if 'H100' in gpu:
+                return 'H100'
+            if '4090' in gpu:
+                return '4090'
+
+        return None
+    except Exception as e:
+        return None
+
+gpu_type = get_gpu_type()
+
+if gpu_type == 'H100':
+    from st_attn_cuda import sta_fwd
+
+
+@lru_cache(maxsize=32)
+def get_compiled_flex_attention(strategy, tile_size, image_size, text_length, device):
+    """
+    Create and compile flex attention with a specific sliding block mask.
+    This function is cached to avoid recompiling for the same parameters.
+
+    Args:
+        strategy (tuple): A tuple (t, h, w) defining the per-head kernel size, in tiles
+        tile_size (tuple): A tuple (ts_t, ts_h, ts_w) defining the tile size
+        image_size (tuple): A tuple (n_t, n_h, n_w) defining the image size
+        text_length (int): The text length
+        device (str): The device to use
+
+    Returns:
+        function: A compiled flex attention function with the specified mask
+    """
+    # The strategy is already expressed in tile units, so it is used as-is here
+    adjusted_strategy = strategy
+
+    # Get the sliding block attention mask
+    mask = get_sliding_tile_attention_mask(
+        adjusted_strategy,
+        tile_size,
+        image_size,
+        text_length,
+        device
+    )
+
+    def flex_attn_with_mask(q, k, v, scale=None):
+        return flex_attention(q, k, v, block_mask=mask, scale=scale)
+
+    # Compile the wrapper function
+    compiled_flex_attn = torch.compile(flex_attn_with_mask)
+
+    return compiled_flex_attn
+
+def flex_sliding_tile_attention(q_all, k_all, v_all, strategy, tile_size,
+                                image_size, text_length, scale=None):
+    device = q_all.device
+
+    # Get the compiled flex attention function (cached if called with the same parameters)
+    compiled_flex_attn = get_compiled_flex_attention(
+        strategy,
+        tile_size,
+        image_size,
+        text_length,
+        device
+    )
+
+    # Apply the compiled flex attention
+    output = compiled_flex_attn(q_all, k_all, v_all, scale=scale)
+
+    return output
 
 
 def sliding_tile_attention(q_all, k_all, v_all, window_size, text_length, has_text=True):
-    seq_length = q_all.shape[2]
-    if has_text:
-        assert q_all.shape[
-            2] == 115456, "STA currently only supports video with latent size (30, 48, 80), which is 117 frames x 768 x 1280 pixels"
-        assert q_all.shape[1] == len(window_size), "Number of heads must match the number of window sizes"
-        target_size = math.ceil(seq_length / 384) * 384
-        pad_size = target_size - seq_length
-        if pad_size > 0:
-            q_all = torch.cat([q_all, q_all[:, :, -pad_size:]], dim=2)
-            k_all = torch.cat([k_all, k_all[:, :, -pad_size:]], dim=2)
-            v_all = torch.cat([v_all, v_all[:, :, -pad_size:]], dim=2)
+    if gpu_type == 'H100':
+        seq_length = q_all.shape[2]
+        if has_text:
+            assert q_all.shape[
+                2] == 115456, "STA currently only supports video with latent size (30, 48, 80), which is 117 frames x 768 x 1280 pixels"
+            assert q_all.shape[1] == len(window_size), "Number of heads must match the number of window sizes"
+            target_size = math.ceil(seq_length / 384) * 384
+            pad_size = target_size - seq_length
+            if pad_size > 0:
+                q_all = torch.cat([q_all, q_all[:, :, -pad_size:]], dim=2)
+                k_all = torch.cat([k_all, k_all[:, :, -pad_size:]], dim=2)
+                v_all = torch.cat([v_all, v_all[:, :, -pad_size:]], dim=2)
+        else:
+            assert q_all.shape[2] == 82944
+
+        hidden_states = torch.empty_like(q_all)
+        # This for loop is ugly, but it is actually quite efficient. The sequence dimension alone can already oversubscribe SMs
+        for head_index, (t_kernel, h_kernel, w_kernel) in enumerate(window_size):
+            for batch in range(q_all.shape[0]):
+                q_head, k_head, v_head, o_head = (q_all[batch:batch + 1, head_index:head_index + 1],
+                                                  k_all[batch:batch + 1,
+                                                        head_index:head_index + 1], v_all[batch:batch + 1,
+                                                                                          head_index:head_index + 1],
+                                                  hidden_states[batch:batch + 1, head_index:head_index + 1])
+
+                _ = sta_fwd(q_head, k_head, v_head, o_head, t_kernel, h_kernel, w_kernel, text_length, False, has_text)
+        if has_text:
+            _ = sta_fwd(q_all, k_all, v_all, hidden_states, 3, 3, 3, text_length, True, True)
+        return hidden_states[:, :, :seq_length]
     else:
-        assert q_all.shape[2] == 82944
-
-    hidden_states = torch.empty_like(q_all)
-    # This for loop is ugly. but it is actually quite efficient. The sequence dimension alone can already oversubscribe SMs
-    for head_index, (t_kernel, h_kernel, w_kernel) in enumerate(window_size):
-        for batch in range(q_all.shape[0]):
-            q_head, k_head, v_head, o_head = (q_all[batch:batch + 1, head_index:head_index + 1],
-                                              k_all[batch:batch + 1,
-                                                    head_index:head_index + 1], v_all[batch:batch + 1,
-                                                                                      head_index:head_index + 1],
-                                              hidden_states[batch:batch + 1, head_index:head_index + 1])
-
-            _ = sta_fwd(q_head, k_head, v_head, o_head, t_kernel, h_kernel, w_kernel, text_length, False, has_text)
-    if has_text:
-        _ = sta_fwd(q_all, k_all, v_all, hidden_states, 3, 3, 3, text_length, True, True)
-    return hidden_states[:, :, :seq_length]
+        assert q_all.shape[
+            2] == 46336, "Flex STA currently only supports video with latent size (12, 48, 80), which is 45 frames x 768 x 1280 pixels"
+        head_num = q_all.size(1)
+        hidden_states = torch.empty_like(q_all)
+        strategy_to_heads = {}
+        for head_index in range(head_num):
+            strategy = tuple(window_size[head_index])  # Convert list to tuple for dict key
+            if strategy not in strategy_to_heads:
+                strategy_to_heads[strategy] = []
+            strategy_to_heads[strategy].append(head_index)
+        for strategy, heads in strategy_to_heads.items():
+            # Gather all heads with this strategy
+            query_heads = torch.cat([q_all[:, head_idx:head_idx + 1, :, :] for head_idx in heads], dim=1)
+            key_heads = torch.cat([k_all[:, head_idx:head_idx + 1, :, :] for head_idx in heads], dim=1)
+            value_heads = torch.cat([v_all[:, head_idx:head_idx + 1, :, :] for head_idx in heads], dim=1)
+
+            # Process all heads with this strategy at once
+            # processed_heads = selected_attn_processor[processor_idx](query_heads, key_heads, value_heads)
+            processed_heads = flex_sliding_tile_attention(query_heads, key_heads, value_heads, strategy, (6, 8, 8), (12, 48, 80), text_length)
+
+            # Distribute results back to the correct positions
+            for i, head_idx in enumerate(heads):
+                hidden_states[:, head_idx:head_idx + 1, :, :] = processed_heads[:, i:i + 1, :, :]
+
+        return hidden_states
diff --git a/fastvideo/models/hunyuan/diffusion/pipelines/pipeline_hunyuan_video.py b/fastvideo/models/hunyuan/diffusion/pipelines/pipeline_hunyuan_video.py
index 2b0b9cc6..27f1d891 100644
--- a/fastvideo/models/hunyuan/diffusion/pipelines/pipeline_hunyuan_video.py
+++ b/fastvideo/models/hunyuan/diffusion/pipelines/pipeline_hunyuan_video.py
@@ -546,6 +546,7 @@ def __call__(
             MultiPipelineCallbacks,
         ]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         vae_ver: str = "88-4c-sd",
+        use_cpu_offload: bool = False,
         enable_tiling: bool = False,
         enable_vae_sp: bool = False,
         n_tokens: Optional[int] = None,
@@ -904,14 +905,22 @@ def dict_to_3d_list(mask_strategy, t_max=50, l_max=60, h_max=24):
                 latents = (latents / self.vae.config.scaling_factor + self.vae.config.shift_factor)
             else:
                 latents = latents / self.vae.config.scaling_factor
+
+            if use_cpu_offload:
+                print("cpu offloaded")
+                self.transformer = self.transformer.to('cpu')
 
             with torch.autocast(device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled):
                 if enable_tiling:
+                    print("tiling enabled")
                     self.vae.enable_tiling()
                 if enable_vae_sp:
                     self.vae.enable_parallel()
                 image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
-
+
+            if use_cpu_offload:
+                self.transformer = self.transformer.to(device)
+
             if expand_temporal_dim or image.shape[2] == 1:
                 image = image.squeeze(2)
diff --git a/fastvideo/models/hunyuan/inference.py b/fastvideo/models/hunyuan/inference.py
index 62694e6d..4384439c 100644
--- a/fastvideo/models/hunyuan/inference.py
+++ b/fastvideo/models/hunyuan/inference.py
@@ -15,7 +15,7 @@ from fastvideo.models.hunyuan.utils.data_utils import align_to
 from fastvideo.models.hunyuan.vae import load_vae
 from fastvideo.utils.parallel_states import nccl_info
-
+from fastvideo.models.hunyuan.modules.fp8 import convert_fp8_linear
 
 class Inference(object):
@@ -76,7 +76,10 @@ def from_pretrained(cls, pretrained_model_path, args, device=None, **kwargs):
         # =========================== Build main model ===========================
         logger.info("Building model...")
-        factor_kwargs = {"device": device, "dtype": PRECISION_TO_TYPE[args.precision]}
+        if args.use_cpu_offload:
+            factor_kwargs = {"device": 'cpu', "dtype": PRECISION_TO_TYPE[args.precision]}
+        else:
+            factor_kwargs = {"device": device, "dtype": PRECISION_TO_TYPE[args.precision]}
         in_channels = args.latent_channels
         out_channels = args.latent_channels
@@ -86,6 +89,11 @@ def from_pretrained(cls, pretrained_model_path, args, device=None, **kwargs):
             out_channels=out_channels,
             factor_kwargs=factor_kwargs,
         )
+
+        if args.use_fp8:
+            print("loading fp8 model")
+            convert_fp8_linear(model, args.dit_weight, original_dtype=PRECISION_TO_TYPE[args.precision])
+
         model = model.to(device)
         model = Inference.load_state_dict(args, model, pretrained_model_path)
         if args.enable_torch_compile:
@@ -453,6 +461,7 @@ def predict(
         # Pipeline inference
         # ========================================================================
         start_time = time.time()
+        torch._dynamo.config.cache_size_limit = 125
         samples = self.pipeline(
             prompt=prompt,
             height=target_height,
@@ -469,6 +478,7 @@ def predict(
             data_type="video" if target_video_length > 1 else "image",
             is_progress_bar=True,
             vae_ver=self.args.vae,
+            use_cpu_offload=self.args.use_cpu_offload,
             enable_tiling=self.args.vae_tiling,
             enable_vae_sp=self.args.vae_sp,
             mask_strategy=mask_strategy,
diff --git a/fastvideo/models/hunyuan/modules/attenion.py b/fastvideo/models/hunyuan/modules/attenion.py
index 699b3ae1..c5f61dcd 100644
--- a/fastvideo/models/hunyuan/modules/attenion.py
+++ b/fastvideo/models/hunyuan/modules/attenion.py
@@ -34,11 +34,11 @@ def attention(
     return out
 
 
-def tile(x, sp_size):
-    x = rearrange(x, "b (sp t h w) head d -> b (t sp h w) head d", sp=sp_size, t=30 // sp_size, h=48, w=80)
+def tile(x, sp_size, t_size):
+    x = rearrange(x, "b (sp t h w) head d -> b (t sp h w) head d", sp=sp_size, t=(t_size // sp_size), h=48, w=80)
     return rearrange(x,
                      "b (n_t ts_t n_h ts_h n_w ts_w) h d -> b (n_t n_h n_w ts_t ts_h ts_w) h d",
-                     n_t=5,
+                     n_t=(t_size // 6),
                      n_h=6,
                      n_w=10,
                      ts_t=6,
@@ -46,16 +46,16 @@ def tile(x, sp_size):
                      ts_w=8)
 
 
-def untile(x, sp_size):
+def untile(x, sp_size, t_size):
     x = rearrange(x,
                   "b (n_t n_h n_w ts_t ts_h ts_w) h d -> b (n_t ts_t n_h ts_h n_w ts_w) h d",
-                  n_t=5,
+                  n_t=(t_size // 6),
                   n_h=6,
                   n_w=10,
                   ts_t=6,
                   ts_h=8,
                   ts_w=8)
-    return rearrange(x, "b (t sp h w) head d -> b (sp t h w) head d", sp=sp_size, t=30 // sp_size, h=48, w=80)
+    return rearrange(x, "b (t sp h w) head d -> b (sp t h w) head d", sp=sp_size, t=(t_size // sp_size), h=48, w=80)
 
 
 def parallel_attention(q, k, v, img_q_len, img_kv_len, text_mask, mask_strategy=None):
@@ -83,9 +83,10 @@ def shrink_head(encoder_state, dim):
         encoder_sequence_length = encoder_query.size(1)
 
         if mask_strategy[0] is not None:
-            query = torch.cat([tile(query, nccl_info.sp_size), encoder_query], dim=1).transpose(1, 2)
-            key = torch.cat([tile(key, nccl_info.sp_size), encoder_key], dim=1).transpose(1, 2)
-            value = torch.cat([tile(value, nccl_info.sp_size), encoder_value], dim=1).transpose(1, 2)
+            t_size = int(query.shape[1] / (8 * 6 * 8 * 10))  # ts_h n_h ts_w n_w
+            query = torch.cat([tile(query, nccl_info.sp_size, t_size), encoder_query], dim=1).transpose(1, 2)
+            key = torch.cat([tile(key, nccl_info.sp_size, t_size), encoder_key], dim=1).transpose(1, 2)
+            value = torch.cat([tile(value, nccl_info.sp_size, t_size), encoder_value], dim=1).transpose(1, 2)
         head_num = query.size(1)
         current_rank = nccl_info.rank_within_group
@@ -107,7 +108,7 @@ def shrink_head(encoder_state, dim):
                                   dim=1)
 
     if mask_strategy[0] is not None:
-        hidden_states = untile(hidden_states, nccl_info.sp_size)
+        hidden_states = untile(hidden_states, nccl_info.sp_size, t_size)
 
     if get_sequence_parallel_state():
         hidden_states = all_to_all_4D(hidden_states, scatter_dim=1, gather_dim=2)
diff --git a/fastvideo/models/hunyuan/modules/fp8.py b/fastvideo/models/hunyuan/modules/fp8.py
new file mode 100644
index 00000000..0cff7d9d
--- /dev/null
+++ b/fastvideo/models/hunyuan/modules/fp8.py
@@ -0,0 +1,100 @@
+import os
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+def get_fp_maxval(bits=8, mantissa_bit=3, sign_bits=1):
+    _bits = torch.tensor(bits)
+    _mantissa_bit = torch.tensor(mantissa_bit)
+    _sign_bits = torch.tensor(sign_bits)
+    M = torch.clamp(torch.round(_mantissa_bit), 1, _bits - _sign_bits)
+    E = _bits - _sign_bits - M
+    bias = 2 ** (E - 1) - 1
+    mantissa = 1
+    for i in range(mantissa_bit - 1):
+        mantissa += 1 / (2 ** (i+1))
+    maxval = mantissa * 2 ** (2**E - 1 - bias)
+    return maxval
+
+def quantize_to_fp8(x, bits=8, mantissa_bit=3, sign_bits=1):
+    """
+    Default is E4M3.
+    """
+    bits = torch.tensor(bits)
+    mantissa_bit = torch.tensor(mantissa_bit)
+    sign_bits = torch.tensor(sign_bits)
+    M = torch.clamp(torch.round(mantissa_bit), 1, bits - sign_bits)
+    E = bits - sign_bits - M
+    bias = 2 ** (E - 1) - 1
+    mantissa = 1
+    for i in range(mantissa_bit - 1):
+        mantissa += 1 / (2 ** (i+1))
+    maxval = mantissa * 2 ** (2**E - 1 - bias)
+    minval = - maxval
+    minval = - maxval if sign_bits == 1 else torch.zeros_like(maxval)
+    input_clamp = torch.min(torch.max(x, minval), maxval)
+    log_scales = torch.clamp((torch.floor(torch.log2(torch.abs(input_clamp)) + bias)).detach(), 1.0)
+    log_scales = 2.0 ** (log_scales - M - bias.type(x.dtype))
+    # dequant
+    qdq_out = torch.round(input_clamp / log_scales) * log_scales
+    return qdq_out, log_scales
+
+def fp8_tensor_quant(x, scale, bits=8, mantissa_bit=3, sign_bits=1):
+    for i in range(len(x.shape) - 1):
+        scale = scale.unsqueeze(-1)
+    new_x = x / scale
+    quant_dequant_x, log_scales = quantize_to_fp8(new_x, bits=bits, mantissa_bit=mantissa_bit, sign_bits=sign_bits)
+    return quant_dequant_x, scale, log_scales
+
+def fp8_activation_dequant(qdq_out, scale, dtype):
+    qdq_out = qdq_out.type(dtype)
+    quant_dequant_x = qdq_out * scale.to(dtype)
+    return quant_dequant_x
+
+def fp8_linear_forward(cls, original_dtype, input):
+    weight_dtype = cls.weight.dtype
+    #####
+    if cls.weight.dtype != torch.float8_e4m3fn:
+        maxval = get_fp_maxval()
+        scale = torch.max(torch.abs(cls.weight.flatten())) / maxval
+        linear_weight, scale, log_scales = fp8_tensor_quant(cls.weight, scale)
+        linear_weight = linear_weight.to(torch.float8_e4m3fn)
+        weight_dtype = linear_weight.dtype
+    else:
+        scale = cls.fp8_scale.to(cls.weight.device)
+        linear_weight = cls.weight
+    #####
+
+    if weight_dtype == torch.float8_e4m3fn and cls.weight.sum() != 0:
+        if True or len(input.shape) == 3:
+            cls_dequant = fp8_activation_dequant(linear_weight, scale, original_dtype)
+            if cls.bias != None:
+                output = F.linear(input, cls_dequant, cls.bias)
+            else:
+                output = F.linear(input, cls_dequant)
+            return output
+        else:
+            return cls.original_forward(input.to(original_dtype))
+    else:
+        return cls.original_forward(input)
+
+def convert_fp8_linear(module, dit_weight_path, original_dtype, params_to_keep={}):
+    setattr(module, "fp8_matmul_enabled", True)
+
+    # loading fp8 mapping file
+    fp8_map_path = dit_weight_path.replace('.pt', '_map.pt')
+    if os.path.exists(fp8_map_path):
+        fp8_map = torch.load(fp8_map_path, map_location=lambda storage, loc: storage)
+    else:
+        raise ValueError(f"Invalid fp8_map path: {fp8_map_path}.")
+
+    fp8_layers = []
+    for key, layer in module.named_modules():
+        if isinstance(layer, nn.Linear) and ('double_blocks' in key or 'single_blocks' in key):
+            fp8_layers.append(key)
+            original_forward = layer.forward
+            layer.weight = torch.nn.Parameter(layer.weight.to(torch.float8_e4m3fn))
+            setattr(layer, "fp8_scale", fp8_map[key].to(dtype=original_dtype))
+            setattr(layer, "original_forward", original_forward)
+            setattr(layer, "forward", lambda input, m=layer: fp8_linear_forward(m, original_dtype, input))
\ No newline at end of file
diff --git a/fastvideo/sample/sample_t2v_hunyuan_STA.py b/fastvideo/sample/sample_t2v_hunyuan_STA.py
index 9354ca7a..27b7a058 100644
--- a/fastvideo/sample/sample_t2v_hunyuan_STA.py
+++ b/fastvideo/sample/sample_t2v_hunyuan_STA.py
@@ -85,7 +85,9 @@ def teacache_forward(
             img_mod2_gate,
         ) = self.double_blocks[0].img_mod(vec_).chunk(6, dim=-1)
         normed_inp = self.double_blocks[0].img_norm1(inp)
-        modulated_inp = modulate(normed_inp, shift=img_mod1_shift, scale=img_mod1_scale)
+        modulated_inp = modulate(normed_inp, shift=img_mod1_shift, scale=img_mod1_scale).to("cpu")
+        del inp, vec_, img_mod1_shift, img_mod1_scale, normed_inp
+
         if self.cnt == 0 or self.cnt == self.num_steps - 1:
             should_calc = True
             self.accumulated_rel_l1_distance = 0
@@ -106,9 +108,10 @@ def teacache_forward(
             self.cnt = 0
     if self.enable_teacache:
         if not should_calc:
-            img += self.previous_residual
+            img += self.previous_residual.to(img.device)
+            self.previous_residual = self.previous_residual.to(img.device)
         else:
-            ori_img = img.clone()
+            ori_img = img.clone().to("cpu")
             # --------------------- Pass through DiT blocks ------------------------
             for index, block in enumerate(self.double_blocks):
                 double_block_args = [img, txt, vec, freqs_cis, text_mask, mask_strategy[index]]
@@ -133,7 +136,8 @@ def teacache_forward(
                     features_list.append(x[:, :img_seq_len, ...])
 
             img = x[:, :img_seq_len, ...]
-            self.previous_residual = img - ori_img
+            self.previous_residual = (img.clone().to("cpu") - ori_img).to("cpu")
+            del ori_img
     else:
         # --------------------- Pass through DiT blocks ------------------------
         for index, block in enumerate(self.double_blocks):
@@ -301,6 +305,11 @@ def main(args):
         action="store_true",
         help="Use CPU offload for the model load.",
    )
+    parser.add_argument(
+        "--use-fp8",
+        action="store_true",
+        help="Use FP8 Quantization for the model load.",
+    )
     parser.add_argument(
         "--dit-weight",
         type=str,
@@ -373,6 +382,7 @@ def main(args):
     parser.add_argument("--text-states-dim-2", type=int, default=768)
     parser.add_argument("--tokenizer-2", type=str, default="clipL")
     parser.add_argument("--text-len-2", type=int, default=77)
+    parser.add_argument("--vae_tiling", action='store_true')
     parser.add_argument("--skip_time_steps", type=int, default=10)
     parser.add_argument(
         "--mask_strategy_selected",
diff --git a/fastvideo/sample/sample_t2v_stepvideo_STA.py b/fastvideo/sample/sample_t2v_stepvideo_STA.py
index fc0c8cd8..4a09d37e 100644
--- a/fastvideo/sample/sample_t2v_stepvideo_STA.py
+++ b/fastvideo/sample/sample_t2v_stepvideo_STA.py
@@ -341,7 +341,7 @@ def teacache_forward(
     )
 
     # TeaCache
-    pipeline.transformer.__class__.enable_teacache = True
+    pipeline.transformer.__class__.enable_teacache = args.enable_teacache
    pipeline.transformer.__class__.cnt = 0
     pipeline.transformer.__class__.num_steps = args.infer_steps
     pipeline.transformer.__class__.rel_l1_thresh = args.rel_l1_thresh  # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
diff --git a/scripts/inference/inference_hunyuan_STA.sh b/scripts/inference/inference_hunyuan_STA.sh
index e17f284b..efde6992 100644
--- a/scripts/inference/inference_hunyuan_STA.sh
+++ b/scripts/inference/inference_hunyuan_STA.sh
@@ -45,5 +45,4 @@ CUDA_VISIBLE_DEVICES=1 torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_p
     --model_path $MODEL_BASE \
     --mask_strategy_file_path $mask_strategy_file_path \
     --dit-weight ${MODEL_BASE}/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt \
-    --vae-sp \
-    --enable_torch_compile
\ No newline at end of file
+    --vae-sp
\ No newline at end of file
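
Not part of the patch: a minimal usage sketch of the new FlexAttention fallback path added to st_attn/__init__.py, under stated assumptions. The import path (st_attn), batch size, head count, head dimension, dtype, and the per-head kernel of (1, 3, 3) tiles are illustrative choices, not values mandated by the diff; real kernel sizes come from the mask-strategy files used by the sampling scripts.

import torch
from st_attn import sliding_tile_attention  # assumes the package in csrc/sliding_tile_attention is importable as `st_attn`

batch, heads, head_dim = 1, 24, 128           # assumed sizes for illustration
text_length = 256
seq_len = 12 * 48 * 80 + text_length          # (12, 48, 80) video latent + text tokens = 46336, as required by the flex branch

q = torch.randn(batch, heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# One (t, h, w) kernel per head, expressed in (6, 8, 8) tiles; (1, 3, 3) is an illustrative choice that fits the 2x6x10 tile canvas.
window_size = [[1, 3, 3]] * heads

# On a non-H100 GPU this dispatches to the compiled flex_attention path with a sliding-tile block mask.
out = sliding_tile_attention(q, k, v, window_size, text_length, has_text=True)
print(out.shape)  # torch.Size([1, 24, 46336, 128])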