Cherry-pick of "Selective merged prefill #643" #893

Open · wants to merge 13 commits into base: habana_main
52 changes: 49 additions & 3 deletions vllm/attention/backends/hpu_attn.py
@@ -24,6 +24,38 @@

logger = init_logger(__name__)

HPUFusedSDPA = None
try:
    from habana_frameworks.torch.hpex.kernels import FusedSDPA
    HPUFusedSDPA = FusedSDPA
except ImportError:
    logger.warning("Could not import HPU FusedSDPA kernel. "
                   "vLLM will use native implementation.")


def prompt_fsdpa(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[torch.Tensor] = None,
    p: float = 0.0,
    scale: Optional[float] = None,
    matmul_qk_op=torch.matmul,
    softmax_op=torch.softmax,
    matmul_av_op=torch.matmul,
    valid_seq_lengths: Optional[torch.Tensor] = None,
    fsdpa_op=None,
) -> torch.Tensor:
    # FusedSDPA expects (batch, heads, seq, head_dim), so swap the sequence
    # and head dimensions before the call and swap them back afterwards.
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)
    softmax_mode = 'fast'
    recompute_mode = True
    # is_causal is passed as False: causality for merged prefill is encoded
    # in the explicit attn_bias instead of the kernel's causal flag.
    attn_weights = fsdpa_op(query, key, value, attn_bias, 0.0, False, scale,
                            softmax_mode, recompute_mode, None, 'right')
    attn_weights = attn_weights.transpose(1, 2)
    return attn_weights

Comment on lines +36 to +58

This should be in hpu-extension/ops.py where all our attn implementations are. What's the difference between this and other implementations? Is it only because is_causal is False?

Yes, this duplicates that function so it can call FusedSDPA with is_causal set to False.
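
For readers following that exchange: in a merged prefill, several prompts are packed into a single attention call, so the kernel's built-in causal flag alone would let a later prompt attend to an earlier one. Below is a minimal sketch of the kind of block-diagonal causal bias that takes its place; the helper name, dtype, and shapes are illustrative assumptions, not code from this PR.

```python
import torch


def make_merged_prefill_bias(seq_lens, dtype=torch.bfloat16):
    # Additive bias of shape (1, 1, total, total): 0 where attention is
    # allowed, -inf elsewhere. Each prompt only sees its own earlier tokens.
    total = sum(seq_lens)
    bias = torch.full((1, 1, total, total), float('-inf'), dtype=dtype)
    start = 0
    for n in seq_lens:
        # Strictly upper-triangular positions inside the block stay masked.
        future = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
        block = torch.zeros(n, n, dtype=dtype).masked_fill(future,
                                                           float('-inf'))
        bias[0, 0, start:start + n, start:start + n] = block
        start += n
    return bias


# Three prompts of lengths 4, 2 and 3 merged into a single 9-token prefill.
mask = make_merged_prefill_bias([4, 2, 3])
print(mask.shape)  # torch.Size([1, 1, 9, 9])
```

In the diff above, the merged-prefill branch leaves attn_metadata.attn_bias untouched and hands it to the attention call, which is why prompt_fsdpa invokes FusedSDPA with is_causal=False.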


class HPUAttentionBackend(AttentionBackend):

@@ -78,6 +110,9 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
    attn_bias: Optional[torch.Tensor]
    seq_lens_tensor: Optional[torch.Tensor]
    context_lens_tensor: Optional[torch.Tensor]
    enable_merged_prefill: bool = False
    actual_num_prefills: Optional[torch.Tensor] = None
    repeated_idx_tensor: Optional[torch.Tensor] = None
    seq_lens: Optional[List[int]] = None
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None
@@ -214,10 +249,12 @@ def forward(
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)
        enable_merged_prefill = attn_metadata.enable_merged_prefill
        block_indices = attn_metadata.block_indices
        block_offsets = attn_metadata.block_offsets
        attn_bias = attn_metadata.attn_bias
        if attn_metadata.is_prompt and self.attn_type \
                is not AttentionType.ENCODER_ONLY:
                is not AttentionType.ENCODER_ONLY and not enable_merged_prefill:
            key = key.unflatten(0, (block_indices.size(0), -1))
            value = value.unflatten(0, (block_indices.size(0), -1))
        if kv_cache is not None and isinstance(kv_cache, tuple):
@@ -244,19 +281,28 @@ def forward(
            # TODO: move this outside of model
            assert attn_metadata.attn_bias is not None, \
                'attn_bias must be set before calling model.forward'
            attn_bias = attn_metadata.attn_bias
            if self.alibi_slopes is not None:
                assert attn_bias is not None
                assert attn_bias.dtype is not None
                assert attn_bias.shape is not None
                assert attn_bias.tile is not None
                position_bias = _make_alibi_bias(
                    self.alibi_slopes, self.num_kv_heads,
                    attn_bias.dtype, attn_bias.shape[-1])
                attn_bias = attn_bias.tile(
                    (1, self.num_kv_heads, 1, 1))
                attn_bias.add_(position_bias)
            elif enable_merged_prefill:
                # Merged prefill: the bias from attn_metadata is used as-is.
                pass
            else:
                attn_bias = attn_metadata.attn_bias

            if not self.prefill_use_flex_attention:
                out = ops.prompt_attention(
                if enable_merged_prefill and self.prefill_use_fusedsdpa:
                    prompt_attn_func = prompt_fsdpa
                else:
                    prompt_attn_func = ops.prompt_attention
                out = prompt_attn_func(
                    query.view(query_shape),
                    key.view(kv_shape),
                    value.view(kv_shape),
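
A quick way to sanity-check the tensor layout handled by the new prompt_fsdpa on a machine without an HPU is to stub fsdpa_op with PyTorch's scaled_dot_product_attention. The stub name, the shapes, and the assumption that prompt_fsdpa is importable are illustrative only; the real kernel is the Habana FusedSDPA shown in the diff.

```python
import torch
import torch.nn.functional as F


# Stand-in for the fused kernel call made inside prompt_fsdpa: it accepts the
# same leading positional arguments (q, k, v, attn_mask, dropout_p, is_causal,
# scale) and ignores the HPU-specific trailing ones.
def fake_fsdpa(q, k, v, attn_mask, dropout_p, is_causal, scale, *_):
    return F.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_mask, dropout_p=dropout_p,
        is_causal=is_causal, scale=scale)


batch, seq, heads, head_dim = 2, 16, 8, 64
q = torch.randn(batch, seq, heads, head_dim)
k = torch.randn(batch, seq, heads, head_dim)
v = torch.randn(batch, seq, heads, head_dim)

# prompt_fsdpa transposes to (batch, heads, seq, head_dim) for the kernel and
# transposes back, so the output layout matches the input layout.
out = prompt_fsdpa(q, k, v, attn_bias=None, scale=head_dim ** -0.5,
                   fsdpa_op=fake_fsdpa)
print(out.shape)  # torch.Size([2, 16, 8, 64])
```

The output comes back in the same (batch, seq, heads, head_dim) layout as the inputs because the function transposes before and after the fused call.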