[do not merge] Rebased refactor branch #266

Closed

Changes from all commits (30 commits)
42f902c
Initial set of common files and layers from vLLM (#226)
SolitaryThinker Mar 1, 2025
ac07e43
add sp comm (#231)
SolitaryThinker Mar 3, 2025
5252d50
Initial clip encoder and cli args organization (#232)
SolitaryThinker Mar 4, 2025
0f4c8d1
[Refactor] Add Hunyuan DiT Modeling (#241)
jzhang38 Mar 5, 2025
fc5a4bc
Refactor py (#246)
jzhang38 Mar 6, 2025
6e8b11c
v1 staging architecture
jzhang38 Mar 9, 2025
23fd3ed
[Do not merge] V1 encoders and model loading (#261)
SolitaryThinker Mar 14, 2025
b2def4b
DiT done and plub in pipeline (#252)
jzhang38 Mar 9, 2025
b631546
move refactor to fastvideo/v1 (#265)
SolitaryThinker Mar 14, 2025
8d99ec3
V1 (#257)
jzhang38 Mar 11, 2025
bf1fc27
remove unneeded file
SolitaryThinker Mar 14, 2025
eafeea4
revert pyproject.toml
SolitaryThinker Mar 14, 2025
1976b23
move v1's v0 code into v1/v0_reference_src
SolitaryThinker Mar 14, 2025
d5ac1e9
remove unused attention
SolitaryThinker Mar 14, 2025
d2db0d4
fix import paths
SolitaryThinker Mar 14, 2025
e976583
Add wan dit
JerryZhou54 Mar 15, 2025
fb44fba
debugging encoders
SolitaryThinker Mar 15, 2025
759f243
fix attn
jzhang38 Mar 15, 2025
f51e9d4
running, correctness isues
SolitaryThinker Mar 15, 2025
796eaf8
add toggle flags for v0 pipeline components
SolitaryThinker Mar 16, 2025
691f9d1
vae update
SolitaryThinker Mar 16, 2025
83348c6
update
jzhang38 Mar 16, 2025
4e056b9
Merge branch 'rebased-refactor' of https://github.com/SolitaryThinker…
jzhang38 Mar 16, 2025
c37535a
magic line
jzhang38 Mar 16, 2025
ae5ed0c
update
jzhang38 Mar 16, 2025
6c52846
Merge pull request #1 from SolitaryThinker/wei
SolitaryThinker Mar 16, 2025
e14d384
model/loader.py -> component_loader.py
SolitaryThinker Mar 16, 2025
bb6ae36
moved loader/ into models/
SolitaryThinker Mar 16, 2025
8140cb2
cleanup
SolitaryThinker Mar 16, 2025
dd10588
cleanup
SolitaryThinker Mar 16, 2025
2 changes: 2 additions & 0 deletions env_setup.sh
Expand Up @@ -8,5 +8,7 @@ pip install packaging ninja && pip install flash-attn==2.7.0.post2 --no-build-is

pip install -r requirements-lint.txt

pip install -r requirements.txt

# install fastvideo
pip install -e .
3 changes: 3 additions & 0 deletions fastvideo/v1/attention/__init__.py
@@ -0,0 +1,3 @@
from .flash_attn import (DistributedAttention, LocalAttention)

__all__ = ["DistributedAttention", "LocalAttention"]
138 changes: 138 additions & 0 deletions fastvideo/v1/attention/flash_attn.py
@@ -0,0 +1,138 @@
from itertools import accumulate
from typing import List, Optional

import torch
import torch.nn as nn
from fastvideo.v1.distributed.communication_op import sequence_model_parallel_all_to_all_4D, sequence_model_parallel_all_gather
from fastvideo.v1.distributed.parallel_state import get_sequence_model_parallel_rank, get_sequence_model_parallel_world_size
from flash_attn import flash_attn_func, flash_attn_varlen_func


class DistributedAttention(nn.Module):
"""Distributed attention module that supports sequence parallelism.

This class implements a minimal attention operation with support for distributed
processing across multiple GPUs using sequence parallelism. The implementation assumes
batch_size=1 and no padding tokens for simplicity.

The sequence parallelism strategy follows the Ulysses paper (https://arxiv.org/abs/2309.14509),
which proposes redistributing attention heads across sequence dimension to enable efficient
parallel processing of long sequences.

Args:
dropout_rate (float, optional): Dropout probability. Defaults to 0.0.
causal (bool, optional): Whether to use causal attention. Defaults to False.
softmax_scale (float, optional): Custom scaling factor for attention scores.
If None, uses 1/sqrt(head_dim). Defaults to None.
"""
def __init__(
self,
dropout_rate: float = 0.0,
causal: bool = False,
softmax_scale: Optional[float] = None,
):
super().__init__()
self.dropout_rate = dropout_rate
self.causal = causal
self.softmax_scale = softmax_scale

def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
replicated_q: Optional[torch.Tensor] = None,
replicated_k: Optional[torch.Tensor] = None,
replicated_v: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Forward pass for distributed attention.

Args:
q (torch.Tensor): Query tensor [batch_size, seq_len, num_heads, head_dim]
k (torch.Tensor): Key tensor [batch_size, seq_len, num_heads, head_dim]
v (torch.Tensor): Value tensor [batch_size, seq_len, num_heads, head_dim]
replicated_q (Optional[torch.Tensor]): Replicated query tensor, typically for text tokens
replicated_k (Optional[torch.Tensor]): Replicated key tensor
replicated_v (Optional[torch.Tensor]): Replicated value tensor

Returns:
Tuple[torch.Tensor, Optional[torch.Tensor]]: A tuple containing:
- o (torch.Tensor): Output tensor after attention for the main sequence
- replicated_o (Optional[torch.Tensor]): Output tensor for replicated tokens, if provided
"""
# Check input shapes
assert q.dim() == 4 and k.dim() == 4 and v.dim() == 4, "Expected 4D tensors"
# assert bs = 1
assert q.shape[0] == 1, "Batch size must be 1, and there should be no padding tokens"
batch_size, seq_len, num_heads, head_dim = q.shape
local_rank = get_sequence_model_parallel_rank()
world_size = get_sequence_model_parallel_world_size()

# Stack QKV
qkv = torch.cat([q, k, v], dim=0) # [3, seq_len, num_heads, head_dim]

# Redistribute heads across sequence dimension
qkv = sequence_model_parallel_all_to_all_4D(qkv, scatter_dim=2, gather_dim=1)

# Concatenate with replicated QKV if provided
if replicated_q is not None:
assert replicated_k is not None and replicated_v is not None
replicated_qkv = torch.cat([replicated_q, replicated_k, replicated_v], dim=0) # [3, replicated_seq_len, num_heads, head_dim]
heads_per_rank = num_heads // world_size
replicated_qkv = replicated_qkv[:, :, local_rank * heads_per_rank:(local_rank + 1) * heads_per_rank]
qkv = torch.cat([qkv, replicated_qkv], dim=1)

q, k, v = qkv.chunk(3, dim=0)
# Apply flash attention
output = flash_attn_func(
q,
k,
v,
dropout_p=self.dropout_rate,
softmax_scale=self.softmax_scale,
causal=self.causal
)
# Redistribute back if using sequence parallelism
replicated_output = None
if replicated_q is not None:
replicated_output = output[:, seq_len*world_size:]
output = output[:, :seq_len*world_size]
# TODO: make this asynchronous
replicated_output = sequence_model_parallel_all_gather(replicated_output, dim=2)
output = sequence_model_parallel_all_to_all_4D(output, scatter_dim=1, gather_dim=2)
return output, replicated_output


class LocalAttention(nn.Module):
def __init__(self, dropout_rate: float = 0.0, causal: bool = False, softmax_scale: Optional[float] = None):
super().__init__()
self.dropout_rate = dropout_rate
self.causal = causal
self.softmax_scale = softmax_scale

def forward(self, q, k, v):
"""
Apply local attention between query, key and value tensors.

Args:
q (torch.Tensor): Query tensor of shape [batch_size, seq_len, num_heads, head_dim]
k (torch.Tensor): Key tensor of shape [batch_size, seq_len, num_heads, head_dim]
v (torch.Tensor): Value tensor of shape [batch_size, seq_len, num_heads, head_dim]

Returns:
torch.Tensor: Output tensor after local attention
"""
# Check input shapes
assert q.dim() == 4 and k.dim() == 4 and v.dim() == 4, "Expected 4D tensors"

# Apply flash attention
output = flash_attn_func(
q,
k,
v,
dropout_p=self.dropout_rate,
softmax_scale=self.softmax_scale,
causal=self.causal
)

return output
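A minimal usage sketch for the two modules above, assuming the sequence-parallel process group has already been initialized (the helpers in fastvideo/v1/distributed/parallel_state are not shown in this diff), num_heads is divisible by the sequence-parallel world size, and flash-attn is installed; shapes, dtypes, and sizes are illustrative.

import torch
from fastvideo.v1.attention import DistributedAttention, LocalAttention

# Each rank holds its own sequence shard: [batch=1, seq_len/P, num_heads, head_dim].
q = torch.randn(1, 1024, 24, 128, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Replicated (e.g. text) tokens are identical on every rank: [1, text_len, num_heads, head_dim].
rq = torch.randn(1, 256, 24, 128, dtype=torch.bfloat16, device="cuda")
rk = torch.randn_like(rq)
rv = torch.randn_like(rq)

attn = DistributedAttention()
out, replicated_out = attn(q, k, v, rq, rk, rv)
# out:            [1, seq_len/P, num_heads, head_dim] -- re-sharded along the sequence
# replicated_out: [1, text_len, num_heads, head_dim]  -- gathered back to full heads

# LocalAttention is the single-GPU counterpart with the same tensor layout.
local_out = LocalAttention()(q, k, v)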
5 changes: 5 additions & 0 deletions fastvideo/v1/distributed/__init__.py
@@ -0,0 +1,5 @@
# SPDX-License-Identifier: Apache-2.0

from .communication_op import *
from .parallel_state import *
from .utils import *
55 changes: 55 additions & 0 deletions fastvideo/v1/distributed/communication_op.py
@@ -0,0 +1,55 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/communication_op.py

from typing import Any, Dict, Optional, Union

import torch
import torch.distributed

from .parallel_state import get_tp_group, get_sp_group


def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
"""All-reduce the input tensor across model parallel group."""
return get_tp_group().all_reduce(input_)


def tensor_model_parallel_all_gather(input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
"""All-gather the input tensor across model parallel group."""
return get_tp_group().all_gather(input_, dim)


def tensor_model_parallel_gather(input_: torch.Tensor,
dst: int = 0,
dim: int = -1) -> Optional[torch.Tensor]:
"""Gather the input tensor across model parallel group."""
return get_tp_group().gather(input_, dst, dim)


def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
Any]]] = None,
src: int = 0):
if not torch.distributed.is_initialized():
return tensor_dict
return get_tp_group().broadcast_tensor_dict(tensor_dict, src)


# TODO: remove model, make it sequence_parallel
def sequence_model_parallel_all_to_all_4D(input_: torch.Tensor,
scatter_dim: int = 2,
gather_dim: int = 1) -> torch.Tensor:
"""All-to-all communication of 4D tensors (e.g. QKV matrices) across sequence parallel group."""
return get_sp_group().all_to_all_4D(input_, scatter_dim, gather_dim)

def sequence_model_parallel_all_gather(input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
"""All-gather the input tensor across model parallel group."""
return get_sp_group().all_gather(input_, dim)
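A round-trip sketch for the sequence-parallel wrappers above, assuming the sequence-parallel group is initialized and num_heads is divisible by its world size; shapes are illustrative.

import torch
from fastvideo.v1.distributed import (sequence_model_parallel_all_gather,
                                      sequence_model_parallel_all_to_all_4D)

# Per-rank shard: [batch, seq_len/P, num_heads, head_dim]
x = torch.randn(1, 512, 24, 128, device="cuda")

# Trade sequence shards for head shards: -> [batch, seq_len, num_heads/P, head_dim]
y = sequence_model_parallel_all_to_all_4D(x, scatter_dim=2, gather_dim=1)

# The inverse mapping restores the original layout: -> [batch, seq_len/P, num_heads, head_dim]
z = sequence_model_parallel_all_to_all_4D(y, scatter_dim=1, gather_dim=2)
assert z.shape == x.shape

# all_gather concatenates shards along `dim` across the group, e.g. to rebuild full heads.
full_heads = sequence_model_parallel_all_gather(y, dim=2)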


Empty file.
@@ -0,0 +1,184 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/device_communicators/base_device_communicator.py

from typing import Optional

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

class DeviceCommunicatorBase:
"""
Base class for device-specific communicator.
It can use the `cpu_group` to initialize the communicator.
If the device has PyTorch integration (PyTorch can recognize its
communication backend), the `device_group` will also be given.
"""

def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
self.device = device or torch.device("cpu")
self.cpu_group = cpu_group
self.device_group = device_group
self.unique_name = unique_name
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group)
self.ranks = dist.get_process_group_ranks(cpu_group)
self.global_rank = dist.get_rank()
self.global_world_size = dist.get_world_size()
self.rank_in_group = dist.get_group_rank(self.cpu_group,
self.global_rank)

def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
dist.all_reduce(input_, group=self.device_group)
return input_

def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
input_size = input_.size()
# NOTE: we have to use concat-style all-gather here,
# stack-style all-gather has compatibility issues with
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
output_size = (input_size[0] * self.world_size, ) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(output_size,
dtype=input_.dtype,
device=input_.device)
# All-gather.
dist.all_gather_into_tensor(output_tensor,
input_,
group=self.device_group)
# Reshape
output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(self.world_size *
input_size[dim], ) +
input_size[dim + 1:])
return output_tensor

def gather(self,
input_: torch.Tensor,
dst: int = 0,
dim: int = -1) -> Optional[torch.Tensor]:
"""
NOTE: We assume that the input tensor is on the same device across
all the ranks.
NOTE: `dst` is the local rank of the destination rank.
"""
world_size = self.world_size
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()

# Allocate output tensor.
if self.rank_in_group == dst:
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
else:
gather_list = None
# Gather.
torch.distributed.gather(input_,
gather_list,
dst=self.ranks[dst],
group=self.device_group)
if self.rank_in_group == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
output_tensor = None
return output_tensor
def all_to_all_4D(self,
input_: torch.Tensor,
scatter_dim: int = 2,
gather_dim: int = 1) -> torch.Tensor:
"""Specialized all-to-all operation for 4D tensors (e.g., for QKV matrices).

Args:
input_ (torch.Tensor): 4D input tensor to be scattered and gathered.
scatter_dim (int, optional): Dimension along which to scatter. Defaults to 2.
gather_dim (int, optional): Dimension along which to gather. Defaults to 1.

Returns:
torch.Tensor: Output tensor after all-to-all operation.
"""
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return input_

assert input_.dim() == 4, f"input must be 4D tensor, got {input_.dim()} and shape {input_.shape}"

if scatter_dim == 2 and gather_dim == 1:
# input: (bs, seqlen/P, hc, hs) output: (bs, seqlen, hc/P, hs)
bs, shard_seqlen, hc, hs = input_.shape
seqlen = shard_seqlen * self.world_size
shard_hc = hc // self.world_size

# Reshape and transpose for scattering
input_t = (input_.reshape(bs, shard_seqlen, self.world_size, shard_hc, hs).transpose(0, 2).contiguous())

output = torch.empty_like(input_t)


torch.distributed.all_to_all_single(output, input_t, group=self.device_group)
torch.cuda.synchronize()

# Reshape and transpose back
output = output.reshape(seqlen, bs, shard_hc, hs).transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)

return output

elif scatter_dim == 1 and gather_dim == 2:
# input: (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs)
bs, seqlen, shard_hc, hs = input_.shape
hc = shard_hc * self.world_size
shard_seqlen = seqlen // self.world_size

# Reshape and transpose for scattering
input_t = (input_.reshape(bs, self.world_size, shard_seqlen, shard_hc,
hs).transpose(0,
3).transpose(0,
1).contiguous().reshape(self.world_size, shard_hc,
shard_seqlen, bs, hs))
output = torch.empty_like(input_t)


torch.distributed.all_to_all_single(output, input_t, group=self.device_group)
torch.cuda.synchronize()

# Reshape and transpose back
output = output.reshape(hc, shard_seqlen, bs, hs).transpose(0, 2).contiguous().reshape(bs, shard_seqlen, hc, hs)

return output
else:
raise RuntimeError("(scatter_dim, gather_dim) must be (2, 1) or (1, 2)")


def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
"""Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
torch.distributed.send(tensor, self.ranks[dst], self.device_group)

def recv(self,
size: torch.Size,
dtype: torch.dtype,
src: Optional[int] = None) -> torch.Tensor:
"""Receives a tensor from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
if src is None:
src = (self.rank_in_group - 1) % self.world_size

tensor = torch.empty(size, dtype=dtype, device=self.device)
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor

def destroy(self):
pass
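A single-process sketch of the shape contract that all_to_all_4D implements; the real path exchanges the shards with torch.distributed.all_to_all_single across the sequence-parallel group, and P, shapes, and values here are illustrative.

import torch

P = 4                            # sequence-parallel world size
bs, seqlen, hc, hs = 1, 8, 8, 2  # full (unsharded) dimensions
full = torch.arange(bs * seqlen * hc * hs, dtype=torch.float32).reshape(bs, seqlen, hc, hs)

# Before scatter_dim=2 / gather_dim=1: rank r holds a sequence shard with all heads.
seq_shards = full.chunk(P, dim=1)    # each: (bs, seqlen/P, hc, hs)

# After scatter_dim=2 / gather_dim=1: rank r holds the full sequence with a head shard.
head_shards = full.chunk(P, dim=2)   # each: (bs, seqlen, hc/P, hs)

for r in range(P):
    assert seq_shards[r].shape == (bs, seqlen // P, hc, hs)
    assert head_shards[r].shape == (bs, seqlen, hc // P, hs)

# scatter_dim=1 / gather_dim=2 is the inverse: it maps head_shards[r] back to the seq_shards[r] layout.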