Cherry-pick of "Selective merged prefill #643" #893

Open · wants to merge 13 commits into habana_main from cherrypick_merged_prefill

Conversation

kamil-kaczor

Cherry-pick of #643

@kamil-kaczor force-pushed the cherrypick_merged_prefill branch from 8356d0f to 4f64e8a on March 10, 2025 13:33
@@ -227,6 +228,33 @@ def find_rope_layer(parent, path):
    return path_to_rope


class HPUBucketingContextWithMergedPrefill(HPUBucketingContext):

+1

self.max_num_prefill_seqs,
self.block_size,
self.max_num_batched_tokens)
self.enable_merged_prefill = os.environ.get('VLLM_MERGED_PREFILL',

It would be good to add this flag to the README with a small explanation.
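
For illustration, a hypothetical README entry for the flag could read roughly as follows (the wording and the default value are assumptions, not taken from this PR):

VLLM_MERGED_PREFILL: when set to 'true', enables selective merged prefill, which merges several prompt sequences into a single prefill batch on HPU. Assumed to be disabled by default.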

@szutenberg requested a review from xuechendi on March 11, 2025 15:13
Comment on lines +36 to +58
def prompt_fsdpa(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[torch.Tensor] = None,
    p: float = 0.0,
    scale: Optional[float] = None,
    matmul_qk_op=torch.matmul,
    softmax_op=torch.softmax,
    matmul_av_op=torch.matmul,
    valid_seq_lengths: Optional[torch.Tensor] = None,
    fsdpa_op=None,
) -> torch.Tensor:
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)
    softmax_mode = 'fast'
    recompute_mode = True
    attn_weights = fsdpa_op(query, key, value, attn_bias, 0.0, False, scale,
                            softmax_mode, recompute_mode, None, 'right')
    attn_weights = attn_weights.transpose(1, 2)
    return attn_weights

This should be in hpu-extension/ops.py, where all our attn implementations are. What's the difference between this and the other implementations? Is it only that is_causal is False?

Yes, this duplicates that function so it can pass 'is_causal' as False.
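
For reference, a minimal sketch of how a single shared helper could cover both cases. prompt_fsdpa_shared is a hypothetical name, the fsdpa_op argument order follows the snippet above, and this is not the actual hpu-extension API:

from typing import Optional

import torch


def prompt_fsdpa_shared(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    fsdpa_op=None,
    is_causal: bool = False,
) -> torch.Tensor:
    # fsdpa_op expects (batch, heads, seq, head_dim), hence the transposes.
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)
    # The only difference between the causal and non-causal variants is the
    # is_causal argument passed through to fsdpa_op.
    attn_weights = fsdpa_op(query, key, value, attn_bias, 0.0, is_causal,
                            scale, 'fast', True, None, 'right')
    return attn_weights.transpose(1, 2)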

Comment on lines +2047 to +2051
origin_enable_merged_prefill = self.enable_merged_prefill
self.enable_merged_prefill = False
self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches,
                     False, True)
self.enable_merged_prefill = origin_enable_merged_prefill

This is a code smell. Could we do it differently?
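
One possible alternative, sketched only as an illustration (merged_prefill_disabled is a hypothetical helper, not part of this diff): scope the override with a context manager so the flag is restored even if warmup raises.

from contextlib import contextmanager


@contextmanager
def merged_prefill_disabled(runner):
    # Temporarily turn off merged prefill on the runner and always restore
    # the previous value, even if the wrapped call raises.
    saved = runner.enable_merged_prefill
    runner.enable_merged_prefill = False
    try:
        yield
    finally:
        runner.enable_merged_prefill = saved

# Usage at the call site shown above would then be:
# with merged_prefill_disabled(self):
#     self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches,
#                          False, True)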

Comment on lines +1311 to +1329
if computed_block_nums is not None and len(
        computed_block_nums) > 0 and self.sliding_window is None:
    # Prefix is not supported with sliding_window
    context_len = len(computed_block_nums) * self.block_size
    prompt_tokens = prompt_tokens[context_len:]
    prefix_block_tables.append(computed_block_nums)
elif self.scheduler_config.chunked_prefill_enabled:
    if seq_group_metadata.block_tables is not None:
        # Prefill has chunked before.
        block_table = seq_group_metadata.block_tables[seq_id]
        prefix_block_tables.append(block_table)
    else:
        # The first prefill.
        prefix_block_tables.append([])
else:
    prefix_block_tables.append([])
    # Right now, prefill start is always 0. However, this
    # assumption can be changed once chunked prefill is introduced.
    assert context_len == 0

do we support prefix caching in merged prefill?

No, I didn't enable that. Yang and I discussed how to add context_length as well; we would need to rework the attn_mask if we do so.

Comment on lines +1360 to +1364
if self.sliding_window is not None:
    assert context_len == 0, (
        "Prefix caching is currently not supported with "
        "sliding window attention")
    start_idx = max(0, seq_len - self.sliding_window)

do we support sliding window attention in merged prefill?

no

dtype=torch.long,
device='cpu')

max_prefill_bs = int(os.environ.get('VLLM_PROMPT_BS_BUCKET_MAX', '8'))

We already have a parameter in scheduler_config for that: max_num_prefill_seqs.
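
A minimal sketch of that suggestion, assuming self.scheduler_config is available on this code path (this line is not part of the diff):

# Reuse the existing scheduler_config value instead of re-reading the
# VLLM_PROMPT_BS_BUCKET_MAX environment variable.
max_prefill_bs = self.scheduler_config.max_num_prefill_seqs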

Comment on lines +1293 to +1299
if (self.scheduler_config is not None
        and self.scheduler_config.chunked_prefill_enabled
        and not (computed_block_nums is None
                 or computed_block_nums == [])):
    raise RuntimeError(
        "chunked prefill cannot be used with prefix caching "
        "now.")

we don't support chunked prefill at all

That is correct; there is this PR, however: HabanaAI/vllm-hpu-extension#94

sampling_metadata = None
sampling_metadata.selected_token_indices = \
    torch.cat((sampling_metadata.selected_token_indices, paddings),
              dim=0)

Can you add the else portion? For the pooler, there is no sampling_metadata.
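
A hedged sketch of the requested else branch (variable names follow the snippet above; it is assumed that the pooler path leaves sampling_metadata as None):

if sampling_metadata is not None:
    sampling_metadata.selected_token_indices = \
        torch.cat((sampling_metadata.selected_token_indices, paddings),
                  dim=0)
else:
    # Pooling models produce no sampling metadata, so there are no
    # selected_token_indices to pad in that case.
    pass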
