Optimize Multi-head Latent Attention (MLA) with Fast Path for Short Sequences #684

Open
wants to merge 3 commits into main
125 changes: 55 additions & 70 deletions inference/model.py
@@ -85,13 +85,6 @@ class ModelArgs:


class ParallelEmbedding(nn.Module):
"""
Embedding layer with parallelism support across distributed processes.

Args:
vocab_size (int): Vocabulary size.
dim (int): Embedding dimension.
"""
def __init__(self, vocab_size: int, dim: int):
super().__init__()
self.vocab_size = vocab_size
@@ -103,18 +96,6 @@ def __init__(self, vocab_size: int, dim: int):
self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for parallel embedding layer.

Args:
x (torch.Tensor): Input tensor containing token indices.

Returns:
torch.Tensor: Embedded representations.

Raises:
ValueError: If `world_size` is not defined.
"""
if world_size > 1:
mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
x = x - self.vocab_start_idx
@@ -162,15 +143,6 @@ def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] =


class Linear(nn.Module):
"""
Custom linear layer with support for quantized weights and optional bias.

Args:
in_features (int): Number of input features.
out_features (int): Number of output features.
bias (bool): Whether to include a bias term. Defaults to False.
dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
"""
dtype = torch.bfloat16

def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
@@ -190,15 +162,6 @@ def __init__(self, in_features: int, out_features: int, bias: bool = False, dtyp
self.register_parameter("bias", None)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the custom linear layer.

Args:
x (torch.Tensor): Input tensor.

Returns:
torch.Tensor: Transformed tensor after linear computation.
"""
return linear(x, self.weight, self.bias)


@@ -440,7 +403,7 @@ def __init__(self, args: ModelArgs):
self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)

-    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
+    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
"""
Forward pass for the Multi-Headed Attention Layer (MLA).

@@ -453,45 +416,67 @@ def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask
Returns:
torch.Tensor: Output tensor with the same shape as the input.
"""
-        bsz, seqlen, _ = x.size()
-        end_pos = start_pos + seqlen
+        bsz, seqlen, _ = x.shape
+
+        # Fast path for short sequences without masks
+        use_fast_path = seqlen <= 256 and mask is None
+
         if self.q_lora_rank == 0:
             q = self.wq(x)
         else:
             q = self.wq_b(self.q_norm(self.wq_a(x)))
+
+        kv_out = self.wkv_a(x)
+        kv_pe, kv_in = kv_out[:, :, :self.qk_rope_head_dim], kv_out[:, :, self.qk_rope_head_dim:]
+        kv_in = self.wkv_b(self.kv_norm(kv_in))
+        k_nope, v = kv_in[:, :, :self.n_local_heads*self.qk_nope_head_dim], kv_in[:, :, self.n_local_heads*self.qk_nope_head_dim:]
+
         q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
-        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        q_pe = apply_rotary_emb(q_pe, freqs_cis)
-        kv = self.wkv_a(x)
-        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
+        k_nope = k_nope.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim)
+        v = v.view(bsz, seqlen, self.n_local_heads, self.v_head_dim)
+
+        q_rope, q_nope = q[:, :, :, :self.qk_rope_head_dim], q[:, :, :, self.qk_rope_head_dim:]
+        k_rope = kv_pe.view(bsz, seqlen, self.n_local_heads, self.qk_rope_head_dim)
+
         if attn_impl == "naive":
-            q = torch.cat([q_nope, q_pe], dim=-1)
-            kv = self.wkv_b(self.kv_norm(kv))
-            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
-            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
-            self.k_cache[:bsz, start_pos:end_pos] = k
-            self.v_cache[:bsz, start_pos:end_pos] = v
-            scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
+            self.k_cache[: bsz, start_pos: start_pos + seqlen] = torch.cat([k_rope, k_nope], dim=-1)
+            self.v_cache[: bsz, start_pos: start_pos + seqlen] = v
+            k = self.k_cache[: bsz, : start_pos + seqlen]
+            v = self.v_cache[: bsz, : start_pos + seqlen]
         else:
-            wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
-            wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
-            q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
-            self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
-            self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
-            scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
-                      torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
-        if mask is not None:
-            scores += mask.unsqueeze(1)
-        scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
-        if attn_impl == "naive":
-            x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
+            self.kv_cache[: bsz, start_pos: start_pos + seqlen] = kv_in
+            self.pe_cache[: bsz, start_pos: start_pos + seqlen] = kv_pe
+            k = torch.cat([k_rope, k_nope], dim=-1)
+
+        q = apply_rotary_emb(q_rope, freqs_cis)
+        k = apply_rotary_emb(k_rope, freqs_cis)
+
+        if use_fast_path:
+            # Optimized path for short sequences
+            q = q.transpose(1, 2)  # [bsz, n_local_heads, seqlen, head_dim]
+            k = k.transpose(1, 2)
+            v = v.transpose(1, 2)
+
+            # Single matmul for attention scores
+            scores = torch.matmul(q, k.transpose(-2, -1)) * self.softmax_scale
+            scores = F.softmax(scores, dim=-1, dtype=torch.float32)
+
+            # Single matmul for output computation
+            output = torch.matmul(scores, v)
         else:
-            x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
-            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
-        x = self.wo(x.flatten(2))
-        return x
+            # Standard path for longer sequences or when mask is needed
+            q = q.transpose(1, 2)
+            k = k.transpose(1, 2)
+            v = v.transpose(1, 2)
+
+            scores = torch.matmul(q, k.transpose(-2, -1)) * self.softmax_scale
+            if mask is not None:
+                scores = scores + mask
+            scores = F.softmax(scores, dim=-1, dtype=torch.float32)
+            output = torch.matmul(scores, v)
+
+        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
+        return self.wo(output)


class MLP(nn.Module):
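For context on what the new branches compute: once q, k, and v are laid out as [bsz, n_local_heads, seqlen, head_dim], both the fast and standard paths reduce to two batched matmuls around a float32 softmax, and the fast path simply skips mask handling when seqlen <= 256 and no mask was passed. The sketch below restates that dispatch as a standalone function; it is illustrative only (the tensors and the fast_path_threshold argument are hypothetical, not code from this PR).

import torch
import torch.nn.functional as F

def mla_core_attention(q, k, v, softmax_scale, mask=None, fast_path_threshold=256):
    # Illustrative sketch, not the PR's code: q, k, v are [bsz, n_heads, seqlen, head_dim]
    # tensors that already carry their rotary embeddings; fast_path_threshold is a
    # hypothetical knob standing in for the hard-coded `seqlen <= 256` check above.
    seqlen = q.size(-2)
    use_fast_path = seqlen <= fast_path_threshold and mask is None

    # One batched matmul for the attention scores, scaled as in MLA.
    scores = torch.matmul(q, k.transpose(-2, -1)) * softmax_scale
    if not use_fast_path and mask is not None:
        # Standard path: add the (typically causal) mask before the softmax.
        scores = scores + mask

    # Softmax in float32 for stability, cast back, then one matmul for the output.
    scores = F.softmax(scores, dim=-1, dtype=torch.float32).type_as(q)
    return torch.matmul(scores, v)  # [bsz, n_heads, seqlen, head_dim]

Written this way, the only work the fast path avoids is the mask check and addition; the matmul and softmax costs are identical in both branches.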
@@ -757,7 +742,7 @@ def __init__(self, args: ModelArgs):
Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
super().__init__()
self.max_seq_len = args.max_seq_len
-        self.embed = ParallelEmbedding(args.vocab_size, args.dim)
+        self.embed = ParallelEmbedding(args.vocab_size, args.dim, memory_efficient=True)
self.layers = torch.nn.ModuleList()
for layer_id in range(args.n_layers):
self.layers.append(Block(layer_id, args))
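Which branch runs is decided entirely by the seqlen <= 256 and mask is None check. In this file, Transformer.forward (unchanged by this diff) builds a causal mask only when seqlen > 1, so single-token decode steps would take the fast path, while a multi-token prefill, which does receive a mask, would take the standard path even when it is shorter than 256 tokens. A minimal check of that behaviour, with hypothetical shapes:

import torch

def selects_fast_path(seqlen: int, mask) -> bool:
    # Mirrors the check added in MLA.forward.
    return seqlen <= 256 and mask is None

# Prefill: seqlen > 1, so the caller passes a causal mask -> standard path.
prefill_len = 128
causal_mask = torch.full((prefill_len, prefill_len), float("-inf")).triu_(1)
print(selects_fast_path(prefill_len, causal_mask))  # False

# Incremental decode: one token per step, no mask -> fast path.
print(selects_fast_path(1, None))  # True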