MultiheadAttention conversion #64

Open
Maritime-Moon opened this issue Dec 6, 2024 · 10 comments

@Maritime-Moon

How does torch.nn.MultiheadAttention integrate with sageattention?
What should I do?

@jt-zhang
Member

jt-zhang commented Dec 6, 2024

Thank you for reaching out.
SageAttention only takes q, k, v as inputs and returns attn_output. MultiheadAttention, in contrast, takes q, k, v before the linear projections as input and returns a linear transformation of attn_output. You will need to perform those projections before and after sageattention in plain PyTorch, referring to the code of MultiheadAttention.
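
A minimal sketch of that wrapping, assuming the `sageattn(q, k, v, tensor_layout=..., is_causal=...)` call used later in this thread; the module and parameter names below are illustrative, not code from this repo:

```python
# Sketch: keep the projections and head reshaping in plain PyTorch and let
# sageattn compute only the core attention. Names here are illustrative.
# Note: sageattn expects fp16/bf16 inputs, so run the module under half/bf16 or autocast.
import torch
import torch.nn as nn
from sageattention import sageattn  # assumed import path

class SageMultiheadAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)  # input projections
        self.out_proj = nn.Linear(embed_dim, embed_dim)      # output projection

    def forward(self, x: torch.Tensor, is_causal: bool = False) -> torch.Tensor:
        # x: (batch, seq_len, embed_dim)
        B, L, E = x.shape
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        # reshape to (batch, num_heads, seq_len, head_dim) to match tensor_layout="HND"
        q, k, v = (t.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
                   for t in (q, k, v))
        # core attention Softmax(QK^T / sqrt(d)) V, handled by sageattention
        o = sageattn(q, k, v, tensor_layout="HND", is_causal=is_causal)
        # merge heads and apply the output projection
        o = o.transpose(1, 2).reshape(B, L, E)
        return self.out_proj(o)
```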

@Maritime-Moon
Author

You mean I need to refactor MultiheadAttention, right?

@jt-zhang
Member

jt-zhang commented Dec 7, 2024

Yes, you should complete the computation contained in the MultiheadAttention function except for $\mathrm{Softmax}(QK^\top / \sqrt{d})\, V$, which can be done by sageattention.
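
For reference, the split in standard multi-head attention notation (not notation from this repo): sageattn covers only the underbraced term, and everything else stays in PyTorch.

$$
\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^O,
\qquad
\mathrm{head}_i = \underbrace{\mathrm{Softmax}\!\left(\frac{(Q W_i^Q)(K W_i^K)^\top}{\sqrt{d}}\right) (V W_i^V)}_{\text{computed by sageattn}}
$$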

By the way, you are welcome to Star and Watch our repository for updates; we would appreciate it.

@Maritime-Moon
Author

That is, I need to replace the softmax attention computation inside it, right?

@Maritime-Moon
Author

Do I need to do the attention calculation myself, like this?

def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
            key_padding_mask: Optional[torch.Tensor] = None,
            need_weights: bool = True,
            attn_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:

    if self.batch_first:
        query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

    # Linear projections
    query = F.linear(query, self.q_proj_weight)  # (N, L_Q, E)
    key = F.linear(key, self.k_proj_weight)      # (N, L_K, E)
    value = F.linear(value, self.v_proj_weight)  # (N, L_V, E)

    # Split into multiple heads
    query = query.view(query.size(0), query.size(1), self.num_heads, self.head_dim).transpose(1, 2)  # (N, num_heads, L_Q, head_dim)
    key = key.view(key.size(0), key.size(1), self.num_heads, self.head_dim).transpose(1, 2)          # (N, num_heads, L_K, head_dim)
    value = value.view(value.size(0), value.size(1), self.num_heads, self.head_dim).transpose(1, 2)  # (N, num_heads, L_V, head_dim)

    # Attention scores: QK^T / sqrt(d)
    dk = query.size(-1) ** 0.5
    attn_scores = torch.matmul(query, key.transpose(-2, -1)) / dk  # (N, num_heads, L_Q, L_K)

    # Using SageAttention
    attn_output = sageattn(attn_scores)

@Maritime-Moon
Author

I see that the internal computation of MultiheadAttention looks like the line below, and that is the part that can be replaced. But I don't know how to change it: it returns two values, while sageattn returns only one, and I don't know how to reconcile that.

attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, dropout_p)
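
One way to handle the mismatch, sketched under the assumption that downstream code never reads attn_output_weights (e.g., need_weights=False); the shape arguments (bsz, num_heads) and the q/k/v layout are illustrative assumptions, not taken from the PyTorch source:

```python
from typing import Optional, Tuple
import torch
from sageattention import sageattn  # assumed import path

def sage_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
              bsz: int, num_heads: int) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Stand-in for the two-value _scaled_dot_product_attention call.

    q, k, v are assumed to be (bsz * num_heads, seq_len, head_dim), as inside
    MultiheadAttention. sageattn does not materialize the attention matrix, so
    the second return value is always None; only valid if the weights are unused.
    """
    head_dim = q.size(-1)
    # reshape to (bsz, num_heads, seq_len, head_dim) to match tensor_layout="HND"
    q = q.reshape(bsz, num_heads, q.size(1), head_dim)
    k = k.reshape(bsz, num_heads, k.size(1), head_dim)
    v = v.reshape(bsz, num_heads, v.size(1), head_dim)
    attn_output = sageattn(q, k, v, tensor_layout="HND", is_causal=False)
    # back to (bsz * num_heads, tgt_len, head_dim), the shape the caller expects
    attn_output = attn_output.reshape(bsz * num_heads, q.size(2), head_dim)
    return attn_output, None
```

The original line then becomes `attn_output, attn_output_weights = sage_sdpa(q, k, v, bsz, num_heads)`, with attn_output_weights permanently None.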

@Maritime-Moon
Author

start = time.time()
attn_output1 = F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
print('ini_times:', start, '\t', time.time() - start, end='\t')
time.sleep(3)
start = time.time()
attn_output = sageattn(q, k, v, tensor_layout="HND", is_causal=is_causal)
print('sage_times:', start, '\t', time.time() - start)

After collecting these statistics, I found that my runtime did not decrease. Why is that?
ini_times: 1735125476.4466343 0.0 sage_times: 1735125479.447369 0.0010123252868652344
ini_times: 1735125479.4504008 0.0 sage_times: 1735125482.4540765 0.0017654895782470703
ini_times: 1735125482.4588583 0.0 sage_times: 1735125485.465288 0.0007486343383789062
ini_times: 1735125485.4700654 0.0 sage_times: 1735125488.4787986 0.0017690658569335938
ini_times: 1735125488.485545 0.0 sage_times: 1735125491.4895458 0.0
ini_times: 1735125491.4915357 0.0 sage_times: 1735125494.5012853 0.0006971359252929688
ini_times: 1735125494.50397 0.0 sage_times: 1735125497.5077531 0.0
ini_times: 1735125497.5097473 0.0 sage_times: 1735125500.5140347 0.0010018348693847656
ini_times: 1735125500.5170624 0.0 sage_times: 1735125503.518924 0.0
ini_times: 1735125503.5209193 0.0 sage_times: 1735125506.535997 0.0
ini_times: 1735125506.539023 0.0 sage_times: 1735125509.549657 0.0

@jt-zhang
Member

Please refer to our FLOPS testing code.
FLOPS testing needs a warmup (e.g., 5 iterations) and the average latency over many executions (e.g., 100). Also, the sequence length of q, k, v should be long enough.
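
Not the repository's benchmark script, just a minimal sketch of the warmup-and-average pattern described above; the shapes are illustrative. Note that CUDA kernels launch asynchronously, so without torch.cuda.synchronize() the timer in the snippet above returns before the kernel finishes, which is why both timings come out near zero there.

```python
import time
import torch
import torch.nn.functional as F
from sageattention import sageattn  # assumed import path

def bench(fn, warmup=5, iters=100):
    # warm up so kernel compilation/caching does not pollute the measurement
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()  # wait for queued kernels before starting the timer
    start = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()  # wait again so all timed kernels are included
    return (time.time() - start) / iters

# illustrative shapes: (batch, heads, seq_len, head_dim) with a long sequence
q = torch.randn(4, 32, 4096, 128, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

t_sdpa = bench(lambda: F.scaled_dot_product_attention(q, k, v))
t_sage = bench(lambda: sageattn(q, k, v, tensor_layout="HND", is_causal=False))
print(f"sdpa: {t_sdpa * 1e3:.3f} ms    sage: {t_sage * 1e3:.3f} ms")
```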

@Maritime-Moon
Author

Can you provide me with a link to that?

@Maritime-Moon
Author

My sequence length is only 197; is that too short?
