[feat]: support flex STA for GPUs other than H100 #263

Open · wants to merge 8 commits into main
Changes from 1 commit
add flex sta
rlsu9 committed Mar 12, 2025

commit e45b41be7911cfa7e9c6a1f9d923cc13874340fb
86 changes: 83 additions & 3 deletions fastvideo/models/hunyuan/modules/attenion.py
@@ -7,11 +7,67 @@
except ImportError:
    print("Could not load Sliding Tile Attention.")
    sliding_tile_attention = None

from functools import lru_cache

from torch.nn.attention.flex_attention import flex_attention

from csrc.sliding_tile_attention.test.flex_sta_ref import get_sliding_tile_attention_mask
from fastvideo.models.flash_attn_no_pad import flash_attn_no_pad
from fastvideo.utils.communications import all_gather, all_to_all_4D
from fastvideo.utils.parallel_states import get_sequence_parallel_state, nccl_info

@lru_cache(maxsize=32)
def get_compiled_flex_attention(strategy, tile_size, image_size, text_length, device):
    """
    Create and compile flex attention with a specific sliding block mask.
    This function is cached to avoid recompiling for the same parameters.

    Args:
        strategy (tuple): A tuple (t, h, w) defining the sliding window strategy
        tile_size (tuple): A tuple (ts_t, ts_h, ts_w) defining the tile size
        image_size (tuple): A tuple (n_t, n_h, n_w) defining the image size
        text_length (int): The text length
        device (str): The device to use

    Returns:
        function: A compiled flex attention function with the specified mask
    """
    # The strategy is passed through unchanged here; no format conversion
    # (e.g. (ceil(t*3/2), h*2, w)) is applied in this function.
    adjusted_strategy = strategy

    # Build the sliding block attention mask
    mask = get_sliding_tile_attention_mask(
        adjusted_strategy,
        tile_size,
        image_size,
        text_length,
        device,
    )

    def flex_attn_with_mask(q, k, v, scale=None):
        return flex_attention(q, k, v, block_mask=mask, scale=scale)

    # Compile the wrapper function
    compiled_flex_attn = torch.compile(flex_attn_with_mask)

    return compiled_flex_attn
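For context, the lru_cache only pays off because every argument is hashable (tuples, an int, and a device). A minimal sketch of the intended caching behavior; the parameter values below are illustrative assumptions, not taken from this diff:

# Hypothetical values; both calls hit the same cache entry and reuse the
# compiled kernel instead of triggering a recompile.
fn_a = get_compiled_flex_attention((3, 3, 3), (6, 8, 8), (12, 48, 80), 256, "cuda")
fn_b = get_compiled_flex_attention((3, 3, 3), (6, 8, 8), (12, 48, 80), 256, "cuda")
assert fn_a is fn_b  # lru_cache returns the identical compiled function object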


def flex_sliding_tile_attention(q_all, k_all, v_all, strategy, tile_size, image_size, text_length, scale=None):
    device = q_all.device

    # Get the compiled flex attention function (cached across calls with the same parameters)
    compiled_flex_attn = get_compiled_flex_attention(strategy, tile_size, image_size, text_length, device)

    # Apply the compiled flex attention
    output = compiled_flex_attn(q_all, k_all, v_all, scale=scale)

    return output
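A hedged usage sketch: the shapes below are assumptions inferred from the call site later in this diff, where image_size is (12, 48, 80) and the sequence is video tokens followed by text tokens; none of these values are asserted by the PR itself.

import torch

# Assumed layout: [batch, heads, seq, head_dim], with
# seq = 12 * 48 * 80 video tokens + text_length text tokens, matching
# the sequence length the block mask is built for.
text_length = 256
seq_len = 12 * 48 * 80 + text_length
q = torch.randn(1, 4, seq_len, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flex_sliding_tile_attention(q, k, v, (3, 3, 3), (6, 8, 8), (12, 48, 80), text_length)
assert out.shape == q.shape  # flex_attention preserves the input shape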

def attention(
    q,
@@ -92,7 +148,31 @@ def shrink_head(encoder_state, dim):
        start_head = current_rank * head_num
        windows = [mask_strategy[head_idx + start_head] for head_idx in range(head_num)]

        if sliding_tile_attention is not None:
            hidden_states = sliding_tile_attention(query, key, value, windows, text_length).transpose(1, 2)
        else:
            # Fall back to flex attention when the fused STA kernel is unavailable
            hidden_states = torch.empty_like(query)

            # Group heads that share the same mask strategy so each group is
            # handled by a single flex attention call
            strategy_to_heads = {}
            for head_index in range(head_num):
                strategy = tuple(windows[head_index])  # convert list to tuple for use as a dict key
                if strategy not in strategy_to_heads:
                    strategy_to_heads[strategy] = []
                strategy_to_heads[strategy].append(head_index)

            for strategy, heads in strategy_to_heads.items():
                # Gather all heads with this strategy
                query_heads = torch.cat([query[:, head_idx:head_idx + 1, :, :] for head_idx in heads], dim=1)
                key_heads = torch.cat([key[:, head_idx:head_idx + 1, :, :] for head_idx in heads], dim=1)
                value_heads = torch.cat([value[:, head_idx:head_idx + 1, :, :] for head_idx in heads], dim=1)

                # Process all heads with this strategy at once
                processed_heads = flex_sliding_tile_attention(query_heads, key_heads, value_heads, strategy,
                                                              (6, 8, 8), (12, 48, 80), text_length)

                # Distribute results back to the correct head positions
                for i, head_idx in enumerate(heads):
                    hidden_states[:, head_idx:head_idx + 1, :, :] = processed_heads[:, i:i + 1, :, :]

            hidden_states = hidden_states.transpose(1, 2)
    else:
        query = torch.cat([query, encoder_query], dim=1)
        key = torch.cat([key, encoder_key], dim=1)
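Side note on the grouping step above: the same strategy-to-heads map can be built with collections.defaultdict. A standalone sketch, not part of this diff:

from collections import defaultdict

def group_heads_by_strategy(windows):
    """Map each distinct mask strategy (as a tuple) to the head indices that use it."""
    strategy_to_heads = defaultdict(list)
    for head_index, window in enumerate(windows):
        strategy_to_heads[tuple(window)].append(head_index)
    return dict(strategy_to_heads)

# group_heads_by_strategy([[3, 3, 3], [1, 6, 10], [3, 3, 3]])
# -> {(3, 3, 3): [0, 2], (1, 6, 10): [1]}

Batching heads this way keeps the number of compiled flex attention variants bounded by the number of distinct strategies rather than the number of heads.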