add mamba causal-conv1d-update kernel #48

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -6,6 +6,9 @@
from einops import rearrange
from typing import Literal, Optional

# same sentinel as PAD_SLOT_ID in vllm/attention/backends/utils.py
PAD_SLOT_ID = -1


@triton.autotune(
configs=[
@@ -328,3 +331,324 @@ def causal_conv1d_fn(
final_states_out,
activation,
)


@triton.autotune(
configs=[
triton.Config({"BLOCK_N": 128}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_N": 64}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_N": 32}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_N": 16}, num_stages=3, num_warps=8),
],
key=["dim"],
restore_value=["conv_state_ptr", "x_ptr"],
)
@triton.jit()
def _causal_conv1d_update_kernel(
# Pointers to matrices
x_ptr, # (batch, dim, seqlen)
w_ptr, # (dim, width)
bias_ptr,
conv_state_ptr,
cache_seqlens_ptr,
conv_state_indices_ptr,
o_ptr, # (batch, dim, seqlen)
# Matrix dimensions
batch,
dim,
seqlen,
state_len,
num_cache_lines,
# Strides
stride_x_seq, # stride to get to next sequence,
stride_x_dim, # stride to get to next feature-value,
stride_x_token, # stride to get to next token (same feature-index, same sequence-index)
stride_weight_dim, # stride to get to next dim-axis value
stride_weight_width, # stride to get to next width-axis value
stride_conv_state_seq,
stride_conv_state_dim,
stride_conv_state_tok,
stride_o_seq,
stride_o_dim,
stride_o_token,
# others
pad_slot_id,
# Meta-parameters
HAS_BIAS: tl.constexpr,
KERNEL_WIDTH: tl.constexpr,
SILU_ACTIVATION: tl.constexpr,
IS_CONTINUOUS_BATCHING: tl.constexpr,
IS_CIRCULAR_BUFFER: tl.constexpr,
NP2_STATELEN: tl.constexpr,
USE_PAD_SLOT: tl.constexpr,
BLOCK_N: tl.constexpr,
):
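    # launch grid is (batch, cdiv(dim, BLOCK_N)): each program instance handles one
    # sequence and one BLOCK_N-wide slice of the feature channels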
idx_seq = tl.program_id(0)
if idx_seq >= batch:
return

idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)

w_base = w_ptr + (idx_feats * stride_weight_dim)
if IS_CIRCULAR_BUFFER:
cache_seqlen = tl.load(cache_seqlens_ptr + idx_seq) # modulo later
else:
cache_seqlen = 0
    # resolve which row of conv_state holds the prior tokens for this sequence (continuous batching)
if IS_CONTINUOUS_BATCHING:
conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq)
else:
conv_state_batch_coord = idx_seq
if USE_PAD_SLOT:
if conv_state_batch_coord == pad_slot_id:
# not processing
return
conv_state_base = (
conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) + (idx_feats * stride_conv_state_dim)
) # [BLOCK_N,]

for idx_token in range(seqlen):
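        # each iteration computes one output token; history at negative relative positions
        # is read from conv_state, the remaining positions come from x itself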
x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim) # [BLOCK_N, ]

if HAS_BIAS:
bias = bias_ptr + idx_feats
mask_bias = idx_feats < dim
acc = tl.load(bias, mask=mask_bias, other=0.0).to(tl.float32) # [BLOCK_N]
else:
acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
PADDING_W = KERNEL_WIDTH - 1
for j in range(KERNEL_WIDTH):
            # index (relative to x) of the token that kernel column j multiplies for output token `idx_token`;
            # negative indices refer to cached history held in conv_state
            idx_x_w = j - PADDING_W + idx_token
x_ptrs = x_base + (idx_x_w * stride_x_token) # [BLOCK_N]
mask_x = (idx_x_w >= 0) & (idx_x_w < seqlen) & (idx_feats < dim)
            if IS_CIRCULAR_BUFFER:
                # TODO: circular-buffer read path is not validated yet; the Python wrapper
                # asserts cache_seqlens is None, so this branch is currently never taken
                assert 0
                conv_state_ptrs = (
                    conv_state_base + (((idx_x_w + cache_seqlen) % state_len) * stride_conv_state_tok)
                )  # [BLOCK_N]
            else:
                conv_state_ptrs = conv_state_base + ((idx_x_w + state_len) * stride_conv_state_tok)  # [BLOCK_N]
            # conv_state is only read for the positions before the first new token (idx_x_w < 0)
            mask_state = (conv_state_batch_coord < num_cache_lines) & (idx_x_w < 0) & (idx_feats < dim)
            conv_state = tl.load(conv_state_ptrs, mask_state, 0.0)
            matrix_x = tl.load(x_ptrs, mask=mask_x, other=conv_state)

w_ptrs = w_base + (j * stride_weight_width) # [BLOCK_N] tensor
mask_w = idx_feats < dim
matrix_w = tl.load(w_ptrs, mask_w, other=0.0)
acc += matrix_x * matrix_w # [BLOCK_N]

if SILU_ACTIVATION:
acc = acc / (1 + tl.exp(-acc))
        mask = (idx_token < seqlen) & (idx_feats < dim)  # token-index and feature-index bounds
o_ptrs = o_ptr + (idx_seq * stride_o_seq) + (idx_token * stride_o_token) + (idx_feats * stride_o_dim)
tl.store(o_ptrs, acc, mask=mask)

    # update conv_state so that it keeps the most recent `state_len` tokens:
    # drop the oldest `seqlen` cached tokens and append the new tokens from x
    if IS_CIRCULAR_BUFFER:
        # TODO: circular-buffer state update is not implemented yet
        # (the Python wrapper asserts cache_seqlens is None)
        assert 0
    else:
        idx_tokens = tl.arange(0, NP2_STATELEN)  # [NP2_STATELEN]

conv_state_ptrs_source = (
conv_state_ptr
+ (conv_state_batch_coord * stride_conv_state_seq)
+ (idx_feats * stride_conv_state_dim)[None, :]
+ ((idx_tokens + seqlen) * stride_conv_state_tok)[:, None]
        )  # [NP2_STATELEN, BLOCK_N]
mask = (
(conv_state_batch_coord < num_cache_lines)
& ((idx_tokens + seqlen) < state_len)[:, None]
& (idx_feats < dim)[None, :]
)
conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0)

        VAL = state_len - seqlen  # offset by which the existing state shifts left to make room for the new tokens
x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim)[None, :] # [1, BLOCK_N]

        x_ptrs = x_base + ((idx_tokens - VAL) * stride_x_token)[:, None]  # [NP2_STATELEN, BLOCK_N]

        mask_x = (
            (idx_tokens - VAL >= 0)[:, None] & (idx_tokens - VAL < seqlen)[:, None] & (idx_feats < dim)[None, :]
        )  # lower/upper token-index bounds and feature-index bound
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
tl.debug_barrier()

        # keep the shifted old state where it is still valid, otherwise take the newly loaded tokens from x
        new_conv_state = tl.where(mask, conv_state, loaded_x)
        conv_state_ptrs_target = conv_state_base + (idx_tokens * stride_conv_state_tok)[:, None]  # [NP2_STATELEN, BLOCK_N]
mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :]
tl.store(conv_state_ptrs_target, new_conv_state, mask)


def causal_conv1d_update(
x,
conv_state,
weight,
bias=None,
activation: Optional[Literal["silu", "swish"]] = None,
cache_seqlens: Optional[torch.Tensor] = None,
conv_state_indices: Optional[torch.Tensor] = None,
    pad_slot_id: Optional[int] = None,
):
"""
x: (batch, dim) or (batch, dim, seqlen)
new tokens whose causal-conv-1d need to be computed
conv_state: (..., dim, state_len), where state_len >= width - 1
(function as `init_state` in causal_conv1d_fn API)
hold the previous `state_len` tokens that we can use to compute causal-conv-1d of new tokens from 'x'
* if `conv_sate_indices` is provided: behave like continuous batching mode
* if `cache_seqlens` is provided: also behave like a circular buffer
==============
[in standard batching, naturally we expect conv_state[i] is used for x[i] with i is sequence-index
[in continuous batching, the corresponding prior data for sequence x[i] is
NOT NECESSARY from conv_state[i];
BUT CAN BE from conv_state[conv_state_indices[i]]
given i=batch_id=sequence_id
IN OTHER WORDS: conv_state[j] | x[i]
with j = i [if conv_state_indices is NOne
with j = conv_state_indices[i] otherwise
]
[NOTE: can be used as a circular buffer if `cache_seqlens` is provided]
weight: (dim, width)
(causal) 1d conv kernel
bias: (dim,)
cache_seqlens: (batch,), dtype int32.
[ PRIOR:
i.e. [conv_state[j][k] | x[i][0] ]
]
Hold the token-index (3rd axis) in the `conv_state` where we ...
If not None, the conv_state is treated as a circular buffer.
The conv_state will be updated by copying x to the conv_state starting at the index
@cache_seqlens % state_len.
conv_state_indices: (batch,), dtype int32
If present, then it is used to extract the row in `conv_state` to be used with corresponding sequence x[i]
i.e. for the given sequence i-th, and j-th is the index in `conv_state` where to get data to combine with x[i] for computing causal-1d-conv
j = i if conv_state_indices is None
j = conv_state_indices[i] otherwise
i.e. [conv_state[j] | x[i] ]
Useful for a continuous batching scenario.
pad_slot_id: int | None
If used, the constant value that we can use to compare with conv_state_indices[i], if
conv_state_indices[i] == pad_slot_id, then we ignore data from that row of conv_state[conv_state_indices[i]]

out: (batch, dim) or (batch, dim, seqlen)
"""
unsqueeze = x.dim() == 2
if unsqueeze:
# make it (batch, dim, seqlen) with seqlen == 1
x = x.unsqueeze(-1)
batch, dim, seqlen = x.shape
_, width = weight.shape

# conv_state: (..., dim, state_len), where state_len >= width - 1
state_len = conv_state.size(2)
assert state_len >= width - 1
assert dim == conv_state.size(1)
if conv_state_indices is None:
assert conv_state.size(0) >= batch
else:
assert (batch,) == conv_state_indices.shape
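    # number of rows available in the conv_state cache (may exceed the current batch size under continuous batching)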
num_cache_lines = conv_state.size(0)

stride_w_dim = weight.stride(0)
stride_w_width = weight.stride(1)

def grid(META):
return (
batch,
triton.cdiv(dim, META["BLOCK_N"]),
)

    assert cache_seqlens is None  # circular-buffer mode is not supported yet (not needed for vLLM)
out = torch.empty_like(x)
with torch.cuda.device(x.device.index):
_causal_conv1d_update_kernel[grid](
# Pointers to matrices
x,
weight,
bias,
conv_state,
cache_seqlens,
conv_state_indices,
out,
# Matrix dimensions
batch,
dim,
seqlen,
state_len,
num_cache_lines,
# stride
x.stride(0), # X (batch, dim, seqlen)
x.stride(1),
x.stride(2),
stride_w_dim,
stride_w_width,
conv_state.stride(0),
conv_state.stride(1),
conv_state.stride(2),
out.stride(0),
out.stride(1),
out.stride(2),
# others
pad_slot_id,
# META
HAS_BIAS=bias is not None,
KERNEL_WIDTH=width,
SILU_ACTIVATION=activation in ["silu", "swish"],
IS_CONTINUOUS_BATCHING=conv_state_indices is not None,
IS_CIRCULAR_BUFFER=cache_seqlens is not None,
NP2_STATELEN=triton.next_power_of_2(state_len),
USE_PAD_SLOT=pad_slot_id is not None,
)
if unsqueeze:
out = out.squeeze(-1)
return out
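

# A minimal reference sketch of what the update computes, kept for documentation only and
# not used anywhere in this file. The helper name below is this sketch's own; it assumes
# matching dtypes, x of shape (batch, dim, seqlen) with seqlen <= state_len, standard
# batching (no conv_state_indices) and no circular buffer (no cache_seqlens).
def _causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None):
    import torch.nn.functional as F

    seqlen = x.size(-1)
    state_len = conv_state.size(-1)
    dim = weight.size(0)
    # prepend the cached history and run a depthwise (groups=dim) causal conv,
    # keeping only the outputs that correspond to the new tokens in x
    z = torch.cat([conv_state, x], dim=-1)  # (batch, dim, state_len + seqlen)
    out = F.conv1d(z, weight.unsqueeze(1), bias, groups=dim)[..., -seqlen:]
    if activation in ("silu", "swish"):
        out = F.silu(out)
    # roll the cache: drop the oldest `seqlen` cached tokens and append the new ones
    conv_state.copy_(z[..., -state_len:])
    return out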


def causal_conv1d_update_vllm(
x: torch.Tensor,
conv_state: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
activation: Optional[Literal["silu", "swish"]] = None,
cache_seqlens: Optional[torch.Tensor] = None,
conv_state_indices: Optional[torch.Tensor] = None,
pad_slot_id: int = PAD_SLOT_ID,
):
"""
x: (batch, dim) or (batch, dim, seqlen)
[shape=2: single token prediction]
[shape=3: multiple tokens prediction]
conv_state: (batch, dim, state_len), where state_len >= width - 1
weight: (dim, width)
bias: (dim,)
cache_seqlens: (batch,), dtype int32.
If not None, the conv_state is treated as a circular buffer.
The conv_state will be updated by copying x to the conv_state
starting at the index
        cache_seqlens % state_len.
conv_state_indices: (batch,), dtype int32
If not None, the conv_state is a larger tensor along the batch dim,
and we are selecting the batch coords specified by conv_state_indices.
Useful for a continuous batching scenario.
pad_slot_id: int
        if conv_state_indices is passed, lets the kernel identify padded
        entries that will not be processed,
        for example: conv_state_indices = [pad_slot_id, 1, 20, pad_slot_id]
        in this case, the kernel will not process entries at
        indices 0 and 3
out: (batch, dim) or (batch, dim, seqlen)
"""
    assert cache_seqlens is None  # circular-buffer mode is not supported yet
    # TODO: adopt vLLM's approach of writing the result into `x` in place, rather than allocating a new output tensor `o`
o = causal_conv1d_update(
x,
conv_state,
weight,
bias=bias,
activation=activation,
cache_seqlens=cache_seqlens,
conv_state_indices=conv_state_indices,
pad_slot_id=pad_slot_id,
)
return o
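

# Usage sketch (hypothetical shapes, requires a CUDA device): exercises the
# continuous-batching path, where row conv_state_indices[i] of conv_state is paired
# with sequence x[i], and pad_slot_id marks slots the kernel skips entirely.
if __name__ == "__main__":
    batch, dim, width, state_len = 4, 256, 4, 8
    num_cache_lines = 16  # the cache may hold more rows than the current batch
    device = "cuda"

    x = torch.randn(batch, dim, device=device)  # one new token per sequence
    conv_state = torch.randn(num_cache_lines, dim, state_len, device=device)
    weight = torch.randn(dim, width, device=device)
    bias = torch.randn(dim, device=device)
    # sequences 0 and 3 are padding: their cache rows are neither read nor written,
    # and their rows of the output are left uninitialized
    conv_state_indices = torch.tensor(
        [PAD_SLOT_ID, 1, 5, PAD_SLOT_ID], dtype=torch.int32, device=device
    )

    out = causal_conv1d_update_vllm(
        x,
        conv_state,
        weight,
        bias=bias,
        activation="silu",
        conv_state_indices=conv_state_indices,
    )
    print(out.shape)  # torch.Size([4, 256])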