We can replace `scaled_dot_product_attention` easily. Take CogVideo as an example: just add the following code and run!
```python
from sageattention import sageattn
import torch.nn.functional as F

F.scaled_dot_product_attention = sageattn
```
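This works through module-level monkey-patching: reassigning the attribute reroutes every downstream call that looks the function up through the module. A minimal pure-Python sketch of the pattern (the module and function names below are stand-ins, not the real APIs):

```python
import types

# stand-in "framework" module exposing a default attention function
framework = types.ModuleType("framework")
framework.scaled_dot_product_attention = lambda q, k, v: "original"

def sageattn_stub(q, k, v):
    # stand-in for sageattention.sageattn (assumed drop-in signature)
    return "sageattn"

def model_forward(q, k, v):
    # downstream code that resolves the function through the module
    # attribute at call time, so it picks up the patched version
    return framework.scaled_dot_product_attention(q, k, v)

# the monkey-patch: every module-attribute call site now hits the stub
framework.scaled_dot_product_attention = sageattn_stub
print(model_forward(None, None, None))  # -> sageattn
```

Note that the patch only affects call sites that access the function through the module (`F.scaled_dot_product_attention(...)`); code that did `from torch.nn.functional import scaled_dot_product_attention` before the patch keeps its original reference.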
Specifically,

```shell
cd example
python sageattn_cogvideo.py
```

You will get a lossless video in `./example`, generated faster than with `python original_cogvideo.py`.
We will take CogVideo SAT as an example:
Once you have set up the environment for CogVideoX's SAT and can generate videos, you can plug SageAttention in easily by replacing lines 67-72 in `CogVideo/sat/sat/transformer_defaults.py`:
```python
67 | attn_output = torch.nn.functional.scaled_dot_product_attention(
68 |     query_layer, key_layer, value_layer,
69 |     attn_mask=None,
70 |     dropout_p=dropout_p,
71 |     is_causal=not is_full
72 | )
```
with the following code:

```python
from sageattention import sageattn

attn_output = sageattn(
    query_layer, key_layer, value_layer,
    is_causal=not is_full
)
```
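Note that the replacement drops the unused `attn_mask=None` and `dropout_p` arguments while forwarding `is_causal`. If you instead patch `F.scaled_dot_product_attention` globally (as in the first example), downstream call sites may still pass those arguments, so a thin adapter can absorb them. A hedged sketch (names are illustrative; it assumes the replacement only needs `q`, `k`, `v`, and `is_causal`, as in the call above):

```python
def make_sdpa_compatible(attn_fn):
    """Wrap an attention function so it tolerates the extra keyword
    arguments that scaled_dot_product_attention call sites may pass."""
    def wrapper(query, key, value, attn_mask=None, dropout_p=0.0,
                is_causal=False, **kwargs):
        # this sketch only handles the mask-free case used above
        if attn_mask is not None:
            raise NotImplementedError("attn_mask is not handled here")
        return attn_fn(query, key, value, is_causal=is_causal)
    return wrapper

# stand-in for sageattention.sageattn, recording how it was called
def fake_sageattn(q, k, v, is_causal=False):
    return ("called", is_causal)

patched = make_sdpa_compatible(fake_sageattn)
print(patched("q", "k", "v", dropout_p=0.1, is_causal=True))  # ('called', True)
```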