Cherry-pick of "Selective merged prefill #643" #893
base: habana_main
Conversation
8356d0f to 4f64e8a
@@ -227,6 +228,33 @@ def find_rope_layer(parent, path):
    return path_to_rope


class HPUBucketingContextWithMergedPrefill(HPUBucketingContext):
I think this should be a part of https://github.com/HabanaAI/vllm-hpu-extension/blob/main/vllm_hpu_extension/bucketing.py
+1
            self.max_num_prefill_seqs,
            self.block_size,
            self.max_num_batched_tokens)
        self.enable_merged_prefill = os.environ.get('VLLM_MERGED_PREFILL',
It would be good to add this flag to the README with a small explanation.
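For illustration, a minimal sketch of how such a flag is typically read and could be described; the default value, the accepted truthy strings, and the described semantics are assumptions based on the PR title, not taken from this change:

import os

# VLLM_MERGED_PREFILL: when enabled, multiple prompt sequences are merged
# into a single prefill pass on HPU (assumed semantics; disabled by default
# in this sketch).
enable_merged_prefill = os.environ.get('VLLM_MERGED_PREFILL',
                                       'false').lower() in ('1', 'true')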
def prompt_fsdpa(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[torch.Tensor] = None,
    p: float = 0.0,
    scale: Optional[float] = None,
    matmul_qk_op=torch.matmul,
    softmax_op=torch.softmax,
    matmul_av_op=torch.matmul,
    valid_seq_lengths: Optional[torch.Tensor] = None,
    fsdpa_op=None,
) -> torch.Tensor:
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)
    softmax_mode = 'fast'
    recompute_mode = True
    attn_weights = fsdpa_op(query, key, value, attn_bias, 0.0, False, scale,
                            softmax_mode, recompute_mode, None, 'right')
    attn_weights = attn_weights.transpose(1, 2)
    return attn_weights
This should be in hpu-extension/ops.py where all our attn implementations are. What's the difference between this and other implementations? Is it only because is_causal is False?
Yes, the function is duplicated so that it can call fsdpa_op with is_causal set to False.
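One way to avoid the duplication (a sketch only, assuming fsdpa_op keeps the positional signature used in the call above) is to thread is_causal through a single helper instead of copying the function; the name prompt_fsdpa_single is hypothetical:

from typing import Optional

import torch


def prompt_fsdpa_single(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    fsdpa_op=None,
    is_causal: bool = False,
) -> torch.Tensor:
    # Same layout handling as the function above; only is_causal becomes
    # a parameter instead of a hard-coded False.
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)
    attn_weights = fsdpa_op(query, key, value, attn_bias, 0.0, is_causal,
                            scale, 'fast', True, None, 'right')
    return attn_weights.transpose(1, 2)

The causal and non-causal prompt paths could then share this single helper.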
        origin_enable_merged_prefill = self.enable_merged_prefill
        self.enable_merged_prefill = False
        self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches,
                             False, True)
        self.enable_merged_prefill = origin_enable_merged_prefill
This is a code smell. Could we do it differently?
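One possible alternative (a sketch, not part of this PR) is to wrap the temporary toggle in a context manager so warmup cannot leave the flag in an inconsistent state; _merged_prefill_disabled is a hypothetical helper name:

from contextlib import contextmanager


@contextmanager
def _merged_prefill_disabled(runner):
    # Temporarily force non-merged prefill and restore the original value
    # even if the warmup scenario raises.
    original = runner.enable_merged_prefill
    runner.enable_merged_prefill = False
    try:
        yield
    finally:
        runner.enable_merged_prefill = original

# Hypothetical usage inside warmup:
# with _merged_prefill_disabled(self):
#     self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches,
#                          False, True)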
        if computed_block_nums is not None and len(
                computed_block_nums) > 0 and self.sliding_window is None:
            # Prefix is not supported with sliding_window
            context_len = len(computed_block_nums) * self.block_size
            prompt_tokens = prompt_tokens[context_len:]
            prefix_block_tables.append(computed_block_nums)
        elif self.scheduler_config.chunked_prefill_enabled:
            if seq_group_metadata.block_tables is not None:
                # Prefill has chunked before.
                block_table = seq_group_metadata.block_tables[seq_id]
                prefix_block_tables.append(block_table)
            else:
                # The first prefill.
                prefix_block_tables.append([])
        else:
            prefix_block_tables.append([])
            # Right now, prefill start is always 0. However, this
            # assumption can be changed once chunked prefill is introduced.
            assert context_len == 0
do we support prefix caching in merged prefill?
No, I didn't enable that. Yang and I discussed how to add context_length as well; we would need to rework the attn_mask if we do so.
        if self.sliding_window is not None:
            assert context_len == 0, (
                "Prefix caching is currently not supported with "
                "sliding window attention")
            start_idx = max(0, seq_len - self.sliding_window)
do we support sliding window attention in merged prefill?
no
                                     dtype=torch.long,
                                     device='cpu')

        max_prefill_bs = int(os.environ.get('VLLM_PROMPT_BS_BUCKET_MAX', '8'))
We already have a parameter in scheduler_config for that: max_num_prefill_seqs.
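A sketch of what the reviewer is suggesting, assuming the scheduler config is available on the runner (not verified against this PR); _max_prefill_bs is a hypothetical helper:

def _max_prefill_bs(runner) -> int:
    # Reuse the existing scheduler_config parameter instead of reading a
    # separate VLLM_PROMPT_BS_BUCKET_MAX environment variable.
    return runner.scheduler_config.max_num_prefill_seqs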
        if (self.scheduler_config is not None
                and self.scheduler_config.chunked_prefill_enabled
                and not (computed_block_nums is None
                         or computed_block_nums == [])):
            raise RuntimeError(
                "chunked prefill cannot be used with prefix caching "
                "now.")
we don't support chunked prefill at all
That is correct, although there is this PR: HabanaAI/vllm-hpu-extension#94.
        sampling_metadata = None
        sampling_metadata.selected_token_indices = \
            torch.cat((sampling_metadata.selected_token_indices, paddings),
                      dim=0)
Can you add the else portion? For the pooler, there is no sampling_metadata.
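A sketch of the guard being asked for, assuming the padding should simply be skipped when there is no sampling metadata (the pooler path); this is not the PR's final code:

if sampling_metadata is not None:
    sampling_metadata.selected_token_indices = \
        torch.cat((sampling_metadata.selected_token_indices, paddings),
                  dim=0)
else:
    # Pooling models carry no sampling_metadata, so there are no selected
    # token indices to pad here.
    pass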
Cherry-pick of #643