
[WIP] Chunked Prefill #188

Open: wants to merge 1 commit into base: main
Conversation

mailvijayasingh (Collaborator)

If chunked prefill is set to True:
- Chunk and appropriately pad the tokens
- Call prefill for each chunk
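A minimal sketch of the chunk-and-pad step described above (function name, return shape, and NumPy padding are assumptions for illustration, not the PR's actual API):

```python
import numpy as np

def chunk_and_pad(token_ids, chunk_size):
    """Split token_ids into fixed-size chunks, padding only the last one.

    Returns (chunks, true_lengths, positions), mirroring the padded_tokens /
    true_lengths / positions arguments passed to prefill in the diff below.
    """
    chunks, true_lengths, positions = [], [], []
    for start in range(0, len(token_ids), chunk_size):
        chunk = token_ids[start:start + chunk_size]
        true_lengths.append(len(chunk))  # unpadded length of this chunk
        positions.append(np.arange(start, start + len(chunk)))
        if len(chunk) < chunk_size:
            # Right-pad the final, partial chunk with zeros to chunk_size.
            chunk = np.pad(chunk, (0, chunk_size - len(chunk)))
        chunks.append(np.asarray(chunk))
    return chunks, true_lengths, positions
```

For 520 tokens and chunk size 256 this yields three chunks with true lengths [256, 256, 8], the last chunk zero-padded back up to 256.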

@mailvijayasingh mailvijayasingh changed the title Chunked Prefill [WIP]Chunked Prefill Feb 13, 2025
@@ -518,6 +518,7 @@ async def send_request(
"""Send the request to JetStream server."""
# Tokenize on client side following MLPerf standard.
token_ids = tokenizer.encode(input_request.prompt)
print("len token_ids ", len(token_ids))
Collaborator:
delete? or use log.debug()

Comment on lines +38 to +39
# 64,
# 128,
Collaborator:

We should jit compile 128 bucket size

)
else:
jax.debug.print("calling chunked_prefill for {chunk_num}", chunk_num=chunk_num)
prefill_result, first_token = prefill_engine.prefill(params=prefill_params | {"cache": prefill_result["cache"]},
Collaborator:

Is "cache" supposed to represent KV cache from previous chunks so far? Can we rename it to "cache_so_far"?

Comment on lines +589 to +606
if prefill_result is None:
  jax.debug.print("calling chunked_prefill for {chunk_num}", chunk_num=chunk_num)
  prefill_result, first_token = prefill_engine.prefill(
      params=prefill_params,
      padded_tokens=padded_tokens[chunk_num],
      true_length=true_lengths[chunk_num],
      positions=positions[chunk_num],
      all_true_length=true_length,
      previous_chunk=prefill_result,
  )
else:
  jax.debug.print("calling chunked_prefill for {chunk_num}", chunk_num=chunk_num)
  prefill_result, first_token = prefill_engine.prefill(
      params=prefill_params | {"cache": prefill_result["cache"]},
      padded_tokens=padded_tokens[chunk_num],
      true_length=true_lengths[chunk_num],
      positions=positions[chunk_num],
      all_true_length=true_length,
      previous_chunk=prefill_result,
  )
Collaborator:

You can get rid of forking to make the code more readable:

cache_so_far = {} if prefill_result is None else {"cache_so_far": prefill_result["cache"]}
prefill_result, first_token = prefill_engine.prefill(params=prefill_params | cache_so_far, ....)
...
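A runnable sketch of that de-forked pattern, using a stub engine (the stub `prefill`, its arguments, and the cache arithmetic are illustrative assumptions; only the dict-union call shape comes from the suggestion above):

```python
def prefill(params, tokens):
    # Stub engine: a real prefill would run the model; here we just
    # thread an integer "cache" through to show the merged call site.
    cache = params.get("cache", 0) + len(tokens)
    return {"cache": cache}, tokens[0]

prefill_params = {"weights": "w"}
prefill_result = None
for chunk in ([1, 2], [3, 4], [5]):
    # Single call site: merge in the running cache only when one exists,
    # instead of forking the whole prefill call on `prefill_result is None`.
    extra = {} if prefill_result is None else {"cache": prefill_result["cache"]}
    prefill_result, first_token = prefill(prefill_params | extra, chunk)
```

Note the `|` dict-union operator requires Python 3.9+, which the original diff already relies on.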

if total token size is 520 and chunk size is 256,
the function will return 3 chunks and return tuple is as follows-
[[t0,..t255][t256,..t511][t512,..t519...(padding)]],
[256, 256, 7+padding],
Collaborator:

nit: the true lengths returned should be [256, 256, 8] (no padding; the last chunk holds t512..t519, i.e. 8 tokens)

the function will return 3 chunks and return tuple is as follows-
[[t0,..t255][t256,..t511][t512,..t519...(padding)]],
[256, 256, 7+padding],
[[0,..255],[256,..511],[512..518..]]
Collaborator:

nit: should be [512..519] (the last position is 519)

Comment on lines +607 to +608
t_l_array = jnp.expand_dims(
    jnp.arange(0, chunk_num * prefill_engine.chunk_size + true_lengths[chunk_num]), 0)
prefill_result['t_l_array'] = t_l_array
Collaborator:

Isn't t_l_array the same as the positions array? Where is this used?
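Reading the snippet, the two arrays do not coincide: t_l_array spans every position processed so far, while positions[chunk_num] covers only the current chunk. A small sketch with the 520-token example from the docstring (NumPy stands in for jnp; the concrete values are illustrative):

```python
import numpy as np

chunk_size = 256
true_lengths = [256, 256, 8]
chunk_num = 2  # final chunk of the 520-token example

# Per-chunk positions, as passed to prefill: only this chunk's offsets.
positions = np.arange(chunk_num * chunk_size,
                      chunk_num * chunk_size + true_lengths[chunk_num])

# t_l_array, as built in the snippet: all positions from 0 up to the
# total number of true tokens seen so far.
t_l_array = np.arange(0, chunk_num * chunk_size + true_lengths[chunk_num])
```

Here positions covers [512, 520) while t_l_array covers [0, 520), so renaming or documenting t_l_array would help clarify the distinction.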
