compatibility with original HunyuanVideo repository #203

Open
YoadTew opened this issue Jan 9, 2025 · 6 comments

@YoadTew

YoadTew commented Jan 9, 2025

Thanks @a-r-r-o-w for the great work!

I had a question about the HunyuanVideo training script:

From what I've seen, the diffusers version of HunyuanVideo is not implemented with flash attention, as opposed to the original implementation.
As a result, generation in diffusers is slower and takes more VRAM (>80GB for 720p, 129-frame videos vs ~60GB for the flash-attention implementation).

Is this training script compatible with the original flash-attention implementation? And after training, is it possible to use the new checkpoints with the original HunyuanVideo repo?

@YoadTew YoadTew changed the title Comparability with original HunyuanVideo repository compatibility with original HunyuanVideo repository Jan 9, 2025
@ArEnSc
Contributor

ArEnSc commented Jan 9, 2025

From what I understand, flash attention is a drop-in replacement, so it should just work. I'm not an expert, though; this is how it works with LLMs.

@a-r-r-o-w
Owner

@YoadTew Thank you for your interest!

Yes, it is on my TODO list to support flash attention alongside sageattn.

In Diffusers, we expose the option of setting a custom attention processor, so all that's required is implementing the processor to deduce/create the right parameters to pass to flash.

It might also be possible to just leverage SDPBackend.FLASH_ATTENTION as a quick test, no? I haven't tried it yet for the trainer, so I'm unsure if it will work out-of-the-box.

https://pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend
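
For reference, a minimal sketch of what that quick test could look like (shapes are illustrative; this assumes fp16/bf16 inputs and no attention mask, since the flash backend rejects masked inputs):

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Force SDPA to dispatch to the flash-attention kernel inside this context.
# HunyuanVideo uses 24 heads with head_dim 128; the sequence length here is arbitrary.
query = torch.randn(1, 24, 4224, 128, device="cuda", dtype=torch.bfloat16)
key = torch.randn_like(query)
value = torch.randn_like(query)

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(query, key, value)  # raises if flash can't be used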

You might have to make some changes for it to work, and contributions are welcome, but I will try to support it as soon as I find some extra time.

@a-r-r-o-w
Owner

Oh, we use an attention mask, so SDPBackend.FLASH_ATTENTION won't work :/

The workaround to use it would be to use very long prompts, so that there are no padding tokens, and to not provide any attention mask at all. I think it should work, and I could add some checks that enforce that the user dataset has enough tokens so that no padding happens.
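
A rough sketch of what such a check could look like (the function name and call site are hypothetical, not part of the trainer):

import torch

def assert_no_padding(attention_mask: torch.Tensor) -> None:
    # attention_mask: (batch, seq_len) mask produced by the text tokenizer.
    # If every position is a real token, the mask carries no information and can be
    # dropped entirely, which makes the inputs eligible for SDPBackend.FLASH_ATTENTION.
    if not torch.all(attention_mask.bool()):
        raise ValueError(
            "Captions are shorter than the text sequence length, so padding tokens are "
            "present; the mask-free flash-attention path cannot be used for this batch."
        )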

@a-r-r-o-w a-r-r-o-w self-assigned this Jan 11, 2025
@YoadTew
Author

YoadTew commented Jan 12, 2025

Hey @a-r-r-o-w, you are correct: the SDPBackend.FLASH_ATTENTION solution does not work because of the padding mask.
But, based on the HunyuanVideo repository, I created a flash-attention processor that worked for me, with the same speed as the original repo.

Steps to reproduce:

  1. Set CUDA_HOME:
conda install -c nvidia cuda-toolkit=12.4
export CUDA_HOME=$(dirname $(dirname $(which nvcc)))
export PATH=$CUDA_HOME/bin:$PATH
export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH

and install flash-attention:

python -m pip install ninja
python -m pip install git+https://github.com/Dao-AILab/[email protected]
  2. Create a new AttnProcessor and replace it in HunyuanVideoSingleTransformerBlock and HunyuanVideoTransformerBlock.

Code for new AttnProcessor:

from typing import Optional

import torch
import torch.nn.functional as F

from diffusers.models.attention_processor import Attention
from diffusers.models.embeddings import apply_rotary_emb

try:
    import flash_attn
    from flash_attn.flash_attn_interface import _flash_attn_forward
    from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
    flash_attn = None
    flash_attn_varlen_func = None
    _flash_attn_forward = None

def get_cu_seqlens(attention_mask):
    """Calculate cu_seqlens_q, cu_seqlens_kv for flash_attn_varlen_func.

    Args:
        attention_mask (torch.Tensor): mask marking the valid (non-padding) positions of
            the packed image+text sequence for each sample

    Returns:
        torch.Tensor: the calculated cu_seqlens for flash attention
    """
    batch_size = attention_mask.shape[1]
    text_len = attention_mask.sum(dim=2)
    max_len = attention_mask.shape[2]

    cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")

    for i in range(batch_size):
        s = text_len[i]
        s1 = i * max_len + s
        s2 = (i + 1) * max_len
        cu_seqlens[2 * i + 1] = s1
        cu_seqlens[2 * i + 2] = s2

    return cu_seqlens
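
# Illustrative example of the layout this produces: with a batch of 2, max_len = 6 and
# valid lengths [4, 6], cu_seqlens comes out as [0, 4, 6, 12, 12]. Every sample contributes
# two segments (valid tokens, then padding tokens), so flash_attn_varlen_func treats the
# padding region as its own zero-or-more-token sequence instead of attending across it.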

class HunyuanVideoAttnProcessor2_0:
    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "HunyuanVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if attn.add_q_proj is None and encoder_hidden_states is not None:
            hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)

        # 1. QKV projections
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)

        # 2. QK normalization
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # 3. Rotational positional embeddings applied to latent stream
        if image_rotary_emb is not None:
            if attn.add_q_proj is None and encoder_hidden_states is not None:
                query = torch.cat(
                    [
                        apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
                        query[:, :, -encoder_hidden_states.shape[1] :],
                    ],
                    dim=2,
                )
                key = torch.cat(
                    [
                        apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
                        key[:, :, -encoder_hidden_states.shape[1] :],
                    ],
                    dim=2,
                )
            else:
                query = apply_rotary_emb(query, image_rotary_emb)
                key = apply_rotary_emb(key, image_rotary_emb)


        batch_size = hidden_states.shape[0]
        img_seq_len = hidden_states.shape[1]
        txt_seq_len = 0

        # 4. Encoder condition QKV projection and normalization
        if attn.add_q_proj is not None and encoder_hidden_states is not None:
            encoder_query = attn.add_q_proj(encoder_hidden_states)
            encoder_key = attn.add_k_proj(encoder_hidden_states)
            encoder_value = attn.add_v_proj(encoder_hidden_states)

            encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_query = attn.norm_added_q(encoder_query)
            if attn.norm_added_k is not None:
                encoder_key = attn.norm_added_k(encoder_key)

            query = torch.cat([query, encoder_query], dim=2)
            key = torch.cat([key, encoder_key], dim=2)
            value = torch.cat([value, encoder_value], dim=2)

            txt_seq_len = encoder_hidden_states.shape[1]

        # 5. Attention
        max_seqlen_q = max_seqlen_kv = img_seq_len + txt_seq_len
        cu_seqlens_q = cu_seqlens_kv = get_cu_seqlens(attention_mask)

        # Flatten to flash-attn's packed (total_tokens, heads, head_dim) layout;
        # reshape (rather than view) since the transposed tensors may be non-contiguous.
        pre_attn_layout = lambda x: x.reshape(x.shape[0] * x.shape[1], *x.shape[2:])
        query = pre_attn_layout(query.transpose(1, 2))
        key = pre_attn_layout(key.transpose(1, 2))
        value = pre_attn_layout(value.transpose(1, 2))

        hidden_states = flash_attn_varlen_func(
            query,
            key,
            value,
            cu_seqlens_q,
            cu_seqlens_kv,
            max_seqlen_q,
            max_seqlen_kv,
        )

        hidden_states = hidden_states.view(
            batch_size, max_seqlen_q, hidden_states.shape[-2], hidden_states.shape[-1]
        )

        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        # 6. Output projection
        if encoder_hidden_states is not None:
            hidden_states, encoder_hidden_states = (
                hidden_states[:, : -encoder_hidden_states.shape[1]],
                hidden_states[:, -encoder_hidden_states.shape[1] :],
            )

            if getattr(attn, "to_out", None) is not None:
                hidden_states = attn.to_out[0](hidden_states)
                hidden_states = attn.to_out[1](hidden_states)

            if getattr(attn, "to_add_out", None) is not None:
                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        return hidden_states, encoder_hidden_states
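
For anyone wanting to try this, a minimal sketch of how the processor could be wired in (the checkpoint id and the module walk are assumptions about your setup, not part of the snippet above):

import torch
from diffusers import HunyuanVideoTransformer3DModel
from diffusers.models.attention_processor import Attention

# Load the diffusers-format transformer and point every attention layer at the
# flash-attention processor defined above (it shadows diffusers' built-in class name).
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16
)
for module in transformer.modules():
    if isinstance(module, Attention):
        module.set_processor(HunyuanVideoAttnProcessor2_0())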

@sayakpaul
Collaborator

Wow, @YoadTew! Given the runtime benefit, I think it could be generally beneficial to diffusers users, too. Would you maybe like to open a PR to the diffusers repository for this?

@YoadTew
Author

YoadTew commented Jan 14, 2025

Hey @sayakpaul, sure, I will try to find time to open a PR this week.
