v1 #270

Merged
merged 60 commits into from
Mar 29, 2025
Changes from all commits (60 commits)
42f902c
Initial set of common files and layers from vLLM (#226)
SolitaryThinker Mar 1, 2025
ac07e43
add sp comm (#231)
SolitaryThinker Mar 3, 2025
5252d50
Initial clip encoder and cli args organization (#232)
SolitaryThinker Mar 4, 2025
0f4c8d1
[Refactor] Add Hunyuan DiT Modeling (#241)
jzhang38 Mar 5, 2025
fc5a4bc
Refactor py (#246)
jzhang38 Mar 6, 2025
6e8b11c
v1 staging architecture
jzhang38 Mar 9, 2025
23fd3ed
[Do not merge] V1 encoders and model loading (#261)
SolitaryThinker Mar 14, 2025
b2def4b
DiT done and plub in pipeline (#252)
jzhang38 Mar 9, 2025
b631546
move refactor to fastvideo/v1 (#265)
SolitaryThinker Mar 14, 2025
8d99ec3
V1 (#257)
jzhang38 Mar 11, 2025
bf1fc27
remove unneeded file
SolitaryThinker Mar 14, 2025
eafeea4
revert pyproject.toml
SolitaryThinker Mar 14, 2025
1976b23
move v1's v0 code into v1/v0_reference_src
SolitaryThinker Mar 14, 2025
d5ac1e9
remove unused attention
SolitaryThinker Mar 14, 2025
d2db0d4
fix import paths
SolitaryThinker Mar 14, 2025
e976583
Add wan dit
JerryZhou54 Mar 15, 2025
fb44fba
debugging encoders
SolitaryThinker Mar 15, 2025
759f243
fix attn
jzhang38 Mar 15, 2025
f51e9d4
running, correctness isues
SolitaryThinker Mar 15, 2025
796eaf8
add toggle flags for v0 pipeline components
SolitaryThinker Mar 16, 2025
691f9d1
vae update
SolitaryThinker Mar 16, 2025
83348c6
update
jzhang38 Mar 16, 2025
4e056b9
Merge branch 'rebased-refactor' of https://github.com/SolitaryThinker…
jzhang38 Mar 16, 2025
c37535a
magic line
jzhang38 Mar 16, 2025
ae5ed0c
update
jzhang38 Mar 16, 2025
6c52846
Merge pull request #1 from SolitaryThinker/wei
SolitaryThinker Mar 16, 2025
e14d384
model/loader.py -> component_loader.py
SolitaryThinker Mar 16, 2025
bb6ae36
moved loader/ into models/
SolitaryThinker Mar 16, 2025
8140cb2
cleanup
SolitaryThinker Mar 16, 2025
dd10588
cleanup
SolitaryThinker Mar 16, 2025
dc2a451
Streamline pipeline (#271)
jzhang38 Mar 17, 2025
2f59409
[v1] cleanup and remove more vllm dependencies (#274)
SolitaryThinker Mar 19, 2025
31057b6
[v1] Attention backend abstraction and Flash-attn backend (#276)
SolitaryThinker Mar 20, 2025
bc2c8df
refactor pipeline (#281)
SolitaryThinker Mar 25, 2025
86d2786
Add CLI for inference (#277)
kevin314 Mar 25, 2025
aa0ac8b
revert predict.py
SolitaryThinker Mar 26, 2025
f8890c7
format 80 column
SolitaryThinker Mar 26, 2025
42ad1e7
format config
SolitaryThinker Mar 26, 2025
5aeb825
removed composed package in pipeline
SolitaryThinker Mar 26, 2025
37ffba7
remove xpu/hip
SolitaryThinker Mar 26, 2025
ce56abb
remove more code
SolitaryThinker Mar 26, 2025
5085006
PY refactor text encoding stages (#286)
jzhang38 Mar 26, 2025
2fd05f5
Clean up pipeline (#287)
jzhang38 Mar 26, 2025
8140ccb
cleanup code
SolitaryThinker Mar 26, 2025
e092927
license headers and cite external code
SolitaryThinker Mar 26, 2025
baa5b0e
remove logging
SolitaryThinker Mar 26, 2025
2cc570a
Add Wan VAE & T5 Text Encoder (#291)
JerryZhou54 Mar 27, 2025
5ee07e3
More v1 cleanup (#292)
SolitaryThinker Mar 27, 2025
1a31234
Rebased numerical (#293)
SolitaryThinker Mar 28, 2025
6aee225
Rebased refactor (#295)
SolitaryThinker Mar 28, 2025
c836c6a
format (#296)
SolitaryThinker Mar 28, 2025
ac88ab6
done with type checking - v1/layers/ (#297)
BrianChen1129 Mar 28, 2025
024eb3e
Rebased refactor (#298)
SolitaryThinker Mar 29, 2025
5461e56
format platforms (#300)
SolitaryThinker Mar 29, 2025
f52ac3c
done with type checking - v1/models/ (#299)
JerryZhou54 Mar 29, 2025
8633119
Rebased refactor (#301)
SolitaryThinker Mar 29, 2025
2ad54bc
cleanup
SolitaryThinker Mar 29, 2025
a6fadc7
don't use data/ for model_path
SolitaryThinker Mar 29, 2025
aaac4d1
revert relative imports
SolitaryThinker Mar 29, 2025
de49a2d
format
SolitaryThinker Mar 29, 2025
2 changes: 1 addition & 1 deletion csrc/sliding_tile_attention/st_attn/__init__.py
@@ -8,7 +8,7 @@ def sliding_tile_attention(q_all, k_all, v_all, window_size, text_length, has_te
     seq_length = q_all.shape[2]
     if has_text:
         assert q_all.shape[
-            2] == 115456, "STA currently only supports video with latent size (30, 48, 80), which is 117 frames x 768 x 1280 pixels"
+            2] >= 115200, "STA currently only supports video with latent size (30, 48, 80), which is 117 frames x 768 x 1280 pixels"
     assert q_all.shape[1] == len(window_size), "Number of heads must match the number of window sizes"
     target_size = math.ceil(seq_length / 384) * 384
     pad_size = target_size - seq_length
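A quick sanity check of the relaxed assertion (illustrative only, not part of the PR): the latent grid (30, 48, 80) named in the error message accounts for 30 * 48 * 80 = 115200 video tokens, so the check now requires at least the full video latent rather than one exact combined length.

# Illustrative arithmetic only; not part of this PR.
latent_frames, latent_height, latent_width = 30, 48, 80
video_tokens = latent_frames * latent_height * latent_width
print(video_tokens)  # 115200
# The previous check required exactly 115456 tokens; the new one accepts any
# sequence of at least 115200 tokens (the full video latent plus whatever
# additional tokens, e.g. appended text, follow it).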
2 changes: 2 additions & 0 deletions env_setup.sh
@@ -8,5 +8,7 @@ pip install packaging ninja && pip install flash-attn==2.7.0.post2 --no-build-is

pip install -r requirements-lint.txt

pip install -r requirements.txt

# install fastvideo
pip install -e .
17 changes: 17 additions & 0 deletions fastvideo/v1/attention/__init__.py
@@ -0,0 +1,17 @@
# SPDX-License-Identifier: Apache-2.0

from fastvideo.v1.attention.backends.abstract import (AttentionBackend,
                                                       AttentionMetadata,
                                                       AttentionMetadataBuilder)
from fastvideo.v1.attention.layer import DistributedAttention, LocalAttention
from fastvideo.v1.attention.selector import get_attn_backend

__all__ = [
    "DistributedAttention",
    "LocalAttention",
    "AttentionBackend",
    "AttentionMetadata",
    "AttentionMetadataBuilder",
    # "AttentionState",
    "get_attn_backend",
]
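As a small illustration (not part of the diff), downstream code would typically import these symbols from the package rather than from the individual submodules; only the names listed in __all__ above are assumed to be available.

# Hypothetical usage of the re-exported names above; not part of this PR.
from fastvideo.v1.attention import (AttentionBackend, AttentionMetadata,
                                    DistributedAttention, LocalAttention,
                                    get_attn_backend)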
Empty file.
245 changes: 245 additions & 0 deletions fastvideo/v1/attention/backends/abstract.py
@@ -0,0 +1,245 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/attention/backends/abstract.py

from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import (TYPE_CHECKING, Any, Dict, Generic, Optional, Protocol, Set,
                    Type, TypeVar)

if TYPE_CHECKING:
    from fastvideo.v1.inference_args import InferenceArgs
    from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch

import torch


class AttentionBackend(ABC):
    """Abstract class for attention backends."""
    # For some attention backends, we allocate an output tensor before
    # calling the custom op. When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    accept_output_buffer: bool = False

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        raise NotImplementedError

    # @staticmethod
    # @abstractmethod
    # def get_state_cls() -> Type["AttentionState"]:
    #     raise NotImplementedError

    # @classmethod
    # def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
    #     return cls.get_metadata_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
        raise NotImplementedError


@dataclass
class AttentionMetadata:
    """Attention metadata for prefill and decode batched together."""
    # Current step of diffusion process
    current_timestep: int

    # @property
    # @abstractmethod
    # def inference_metadata(self) -> Optional["AttentionMetadata"]:
    #     """Return the attention metadata that's required to run prefill
    #     attention."""
    #     pass

    # @property
    # @abstractmethod
    # def training_metadata(self) -> Optional["AttentionMetadata"]:
    #     """Return the attention metadata that's required to run decode
    #     attention."""
    #     pass

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
                        ) -> Dict[str, Any]:
        """Similar to dataclasses.asdict, but avoids deepcopying."""
        if skip_fields is None:
            skip_fields = set()
        # Note that if we add dataclasses as fields, they will need
        # similar handling.
        return {
            field.name: getattr(self, field.name)
            for field in fields(self) if field.name not in skip_fields
        }


T = TypeVar("T", bound=AttentionMetadata)

# class AttentionState(ABC, Generic[T]):
#     """Holds attention backend-specific objects reused during the
#     lifetime of the model runner."""

#     @abstractmethod
#     def __init__(self, runner: "ModelRunnerBase"):
#         ...

#     @abstractmethod
#     @contextmanager
#     def graph_capture(self, max_batch_size: int):
#         """Context manager used when capturing CUDA graphs."""
#         yield

#     @abstractmethod
#     def graph_clone(self, batch_size: int) -> "AttentionState[T]":
#         """Clone attention state to save in CUDA graph metadata."""
#         ...

#     @abstractmethod
#     def graph_capture_get_metadata_for_batch(
#             self,
#             batch_size: int,
#             is_encoder_decoder_model: bool = False) -> T:
#         """Get attention metadata for CUDA graph capture of batch_size."""
#         ...

#     @abstractmethod
#     def get_graph_input_buffers(
#             self,
#             attn_metadata: T,
#             is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
#         """Get attention-specific input buffers for CUDA graph capture."""
#         ...

#     @abstractmethod
#     def prepare_graph_input_buffers(
#             self,
#             input_buffers: Dict[str, Any],
#             attn_metadata: T,
#             is_encoder_decoder_model: bool = False) -> None:
#         """In-place modify input buffers dict for CUDA graph replay."""
#         ...

#     @abstractmethod
#     def begin_forward(self, model_input: "ModelRunnerInputBase") -> None:
#         """Prepare state for forward pass."""
#         ...


class AttentionMetadataBuilder(ABC, Generic[T]):
    """Abstract class for attention metadata builders."""

    @abstractmethod
    def __init__(self) -> None:
        """Create the builder, remember some configuration and parameters."""
        raise NotImplementedError

    @abstractmethod
    def prepare(self) -> None:
        """Prepare for one batch."""
        raise NotImplementedError

    @abstractmethod
    def build(
        self,
        current_timestep: int,
        forward_batch: "ForwardBatch",
        inference_args: "InferenceArgs",
    ) -> T:
        """Build attention metadata with on-device tensors."""
        raise NotImplementedError


class AttentionLayer(Protocol):

    _k_scale: torch.Tensor
    _v_scale: torch.Tensor
    _k_scale_float: float
    _v_scale_float: float

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        ...


class AttentionImpl(ABC, Generic[T]):

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        softmax_scale: float,
        dropout_rate: float = 0.0,
        causal: bool = False,
        num_kv_heads: Optional[int] = None,
    ) -> None:
        raise NotImplementedError

    def preprocess_qkv(self, qkv: torch.Tensor,
                       attn_metadata: T) -> torch.Tensor:
        """Preprocess QKV tensor before performing attention operation.

        Default implementation returns the tensor unchanged.
        Subclasses can override this to implement custom preprocessing
        like reshaping, tiling, scaling, or other transformations.

        Called AFTER all_to_all for distributed attention

        Args:
            qkv: The query-key-value tensor
            attn_metadata: Metadata for the attention operation

        Returns:
            Processed QKV tensor
        """
        return qkv

    def postprocess_output(
        self,
        output: torch.Tensor,
        attn_metadata: T,
    ) -> torch.Tensor:
        """Postprocess the output tensor after the attention operation.

        Default implementation returns the tensor unchanged.
        Subclasses can override this to implement custom postprocessing
        like untiling, scaling, or other transformations.

        Called BEFORE all_to_all for distributed attention

        Args:
            output: The output tensor from the attention operation
            attn_metadata: Metadata for the attention operation

        Returns:
            Postprocessed output tensor
        """

        return output

    @abstractmethod
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: T,
    ) -> torch.Tensor:
        raise NotImplementedError
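To make the metadata contract concrete, here is a minimal sketch (hypothetical, not part of the PR) of a backend-specific AttentionMetadata subclass and the zero-copy dict conversion it inherits; the class name MyAttentionMetadata and the attn_mask field are invented for illustration.

# Hypothetical sketch; MyAttentionMetadata and attn_mask are illustrative only.
from dataclasses import dataclass
from typing import Optional

import torch

from fastvideo.v1.attention.backends.abstract import AttentionMetadata


@dataclass
class MyAttentionMetadata(AttentionMetadata):
    # Backend-specific field added on top of the base current_timestep.
    attn_mask: Optional[torch.Tensor] = None


meta = MyAttentionMetadata(current_timestep=3, attn_mask=torch.ones(1, 8, 8))
# asdict_zerocopy returns references to the field values (no deep copy), and
# skip_fields lets callers drop tensors they do not want to pass around.
print(meta.asdict_zerocopy(skip_fields={"attn_mask"}))  # {'current_timestep': 3}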
70 changes: 70 additions & 0 deletions fastvideo/v1/attention/backends/flash_attn.py
@@ -0,0 +1,70 @@
# SPDX-License-Identifier: Apache-2.0

from typing import List, Optional, Type

import torch
from flash_attn import flash_attn_func

from fastvideo.v1.attention.backends.abstract import (AttentionBackend,
                                                       AttentionImpl,
                                                       AttentionMetadata,
                                                       AttentionMetadataBuilder)
from fastvideo.v1.logger import init_logger

logger = init_logger(__name__)


class FlashAttentionBackend(AttentionBackend):

    accept_output_buffer: bool = True

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @staticmethod
    def get_name() -> str:
        return "FLASH_ATTN"

    @staticmethod
    def get_impl_cls() -> Type["FlashAttentionImpl"]:
        return FlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        raise NotImplementedError

    @staticmethod
    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
        raise NotImplementedError


class FlashAttentionImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        dropout_rate: float,
        causal: bool,
        softmax_scale: float,
        num_kv_heads: Optional[int] = None,
    ) -> None:
        self.dropout_rate = dropout_rate
        self.causal = causal
        self.softmax_scale = softmax_scale

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ):
        output = flash_attn_func(query,
                                 key,
                                 value,
                                 dropout_p=self.dropout_rate,
                                 softmax_scale=self.softmax_scale,
                                 causal=self.causal)
        return output
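A minimal smoke test of the new backend might look like the following (hypothetical, not part of the PR): flash_attn_func expects half-precision tensors of shape (batch, seq_len, num_heads, head_dim) on a CUDA device, and FlashAttentionImpl ignores the attn_metadata argument in this version.

# Hypothetical usage sketch; shapes and values are illustrative only.
import torch

from fastvideo.v1.attention.backends.flash_attn import FlashAttentionImpl

attn = FlashAttentionImpl(num_heads=8,
                          head_size=64,
                          dropout_rate=0.0,
                          causal=False,
                          softmax_scale=64**-0.5)

# flash_attn_func operates on (batch, seq_len, num_heads, head_dim) tensors.
q = torch.randn(2, 128, 8, 64, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

out = attn.forward(q, k, v, attn_metadata=None)  # metadata is unused here
print(out.shape)  # torch.Size([2, 128, 8, 64])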