APC fully cached prefill moved to decode #944

Draft: wants to merge 8 commits into habana_main
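This draft touches two places in vllm/worker/hpu_model_runner.py: prepare_input_tensors learns to classify a sequence group by how much of it the automatic prefix cache (APC) already covers, and _prepare_decode then accepts such a group and prepares it as a single-token decode step. As a reading aid only, here is a minimal standalone sketch of that classification; the helper name route_seq_group, its signature, and the example values are invented here, while block_size, computed_block_nums, and the comparison itself mirror the diff below.

from typing import Optional, Sequence


def route_seq_group(is_prompt: bool,
                    computed_block_nums: Optional[Sequence[int]],
                    block_size: int,
                    seq_len: int,
                    enable_prefix_caching: bool) -> str:
    """Standalone mirror of the classification added to prepare_input_tensors."""
    if (enable_prefix_caching and computed_block_nums is not None
            and len(computed_block_nums) > 0):
        # Number of tokens already covered by cached KV blocks.
        prefix_cached_len = len(computed_block_nums) * block_size
        # Same comparison as in the draft: the group stays a prompt only if
        # the cached prefix extends past the current sequence length.
        is_seq_prompt = prefix_cached_len > seq_len
    else:
        is_seq_prompt = is_prompt
    return "prefill" if is_seq_prompt else "decode"


# A 256-token prompt fully covered by two cached 128-token blocks is
# routed to the decode path.
assert route_seq_group(True, [0, 1], 128, 256, True) == "decode"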
43 changes: 40 additions & 3 deletions vllm/worker/hpu_model_runner.py
@@ -1302,9 +1302,18 @@ def _prepare_decode(
dummy_slots = itertools.cycle(
range(_PAD_SLOT_ID, _PAD_SLOT_ID + self.block_size))

is_apc = self.vllm_config.cache_config.enable_prefix_caching

for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt
assert seq_group_metadata.token_chunk_size == 1
if seq_group_metadata.is_prompt and is_apc:
    # This prompt group was routed here by the prefix-caching check in
    # prepare_input_tensors: prepare it as a single-token decode step.
    logger.debug("APC fully cached prefill handled as decode "
                 "(token_chunk_size=%s)",
                 seq_group_metadata.token_chunk_size)
    seq_group_metadata.token_chunk_size = 1
    seq_group_metadata.is_prompt = False
else:
    assert not seq_group_metadata.is_prompt
    assert seq_group_metadata.token_chunk_size == 1

seq_ids = list(seq_group_metadata.seq_data.keys())
lora_id = seq_group_metadata.lora_int_id
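What does the flip to the decode path mean for the prepared inputs? A minimal sketch, assuming a made-up 256-token prompt and plain Python lists in place of vLLM's SeqData: the decode path ends up with one input token, the last token id, at position seq_len - 1, matching the get_last_token_id() and seq_len - 1 logic visible in the context lines of the next hunk.

# Hypothetical stand-in data; not the vLLM SeqData class.
prompt_token_ids = list(range(1000, 1256))   # a 256-token prompt, fully cached
seq_len = len(prompt_token_ids)

# Single-token decode inputs for the reclassified group:
input_tokens = [[prompt_token_ids[-1]]]      # get_last_token_id() equivalent
input_positions = [[seq_len - 1]]

assert input_tokens == [[1255]]
assert input_positions == [[255]]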
@@ -1327,10 +1336,12 @@ def _prepare_decode(
if output is None:
generation_token = seq_data.get_last_token_id()
input_tokens.append([generation_token])
logger.debug("decode input_tokens: %s", input_tokens)

seq_len = seq_data.get_len()
position = seq_len - 1
input_positions.append([position])
logger.debug("decode input_positions: %s", input_positions)

if self.model_is_mrope:
if seq_data.mrope_position_delta is not None:
@@ -1547,6 +1558,18 @@ def _prepare_decode(
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
)

print("toks", input_tokens, "posits", input_positions, "cached shit", attn_metadata.block_list)

print(PrepareDecodeMetadata(input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
lora_index_mapping=lora_index_mapping,
lora_prompt_mapping=lora_prompt_mapping,
lora_requests=lora_requests,
slot_mapping=slot_mapping,
lora_ids=lora_ids))

return PrepareDecodeMetadata(input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
@@ -1585,11 +1608,25 @@ def prepare_input_tensors(
prefill_reqs = []
decode_reqs = []
for seq_group_meta in seq_group_metadata_list:
if seq_group_meta.is_prompt:
# With prefix caching (APC) enabled, classify the group by how much of
# the sequence the cached blocks already cover instead of relying on
# is_prompt alone.
if seq_group_meta.computed_block_nums is not None and len(
        seq_group_meta.computed_block_nums
) > 0 and self.vllm_config.cache_config.enable_prefix_caching:
    prefix_cached_len = len(
        seq_group_meta.computed_block_nums) * self.block_size
    seq_len = seq_group_meta.seq_data[list(
        seq_group_meta.seq_data.keys())[0]].get_len()
    is_seq_prompt = prefix_cached_len > seq_len
else:
    is_seq_prompt = seq_group_meta.is_prompt
if is_seq_prompt:
    prefill_reqs.append(seq_group_meta)
else:
    # is_prompt is flipped to False for fully cached groups later,
    # in _prepare_decode.
    decode_reqs.append(seq_group_meta)

logger.debug("prepare_input_tensors: %d prefill reqs, %d decode reqs",
             len(prefill_reqs), len(decode_reqs))

# Prepare input tensors.
(
input_tokens,
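A closing illustration, again hypothetical: reusing the route_seq_group sketch from above with made-up batch contents, a batch holding one fully cached prompt and one ordinary decode request lands entirely in decode_reqs, which is the prefill/decode split that the debug log at the end of the classification loop reports.

# Hypothetical batch entries: (is_prompt, computed_block_nums, seq_len).
batch = [
    (True, [0, 1], 256),   # prompt fully covered by two cached 128-token blocks
    (False, None, 73),     # ordinary decode request
]

prefill_reqs, decode_reqs = [], []
for is_prompt, blocks, seq_len in batch:
    route = route_seq_group(is_prompt, blocks, 128, seq_len, True)
    (prefill_reqs if route == "prefill" else decode_reqs).append(seq_len)

assert (len(prefill_reqs), len(decode_reqs)) == (0, 2)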