[WIP]Chunked Prefill #188
@@ -518,6 +518,7 @@ async def send_request(
"""Send the request to JetStream server."""
# Tokenize on client side following MLPerf standard.
token_ids = tokenizer.encode(input_request.prompt)
print("len token_ids ", len(token_ids))
Delete this, or use log.debug() instead?
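A minimal sketch of the second option, assuming the standard library `logging` module is acceptable here. The logger name `benchmark_serving` and the helper `report_token_count` are hypothetical and only serve the illustration.

```python
import logging

# Hypothetical logger name; the real module may already define its own logger.
logger = logging.getLogger("benchmark_serving")


def report_token_count(token_ids):
    """Log the prompt length at DEBUG level instead of printing it."""
    # Lazy %-formatting: the message is only built if DEBUG is enabled.
    logger.debug("len token_ids %d", len(token_ids))
    return len(token_ids)
```

The message stays silent by default and shows up only when the caller opts in, for example with `logging.basicConfig(level=logging.DEBUG)`.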
# 64,
# 128,
We should JIT-compile the 128 bucket size.
  )
else:
  jax.debug.print("calling chunked_prefill for {chunk_num}", chunk_num=chunk_num)
  prefill_result, first_token = prefill_engine.prefill(params=prefill_params | {"cache": prefill_result["cache"]},
Is "cache" supposed to represent KV cache from previous chunks so far? Can we rename it to "cache_so_far"?
if prefill_result is None:
  jax.debug.print("calling chunked_prefill for {chunk_num}", chunk_num=chunk_num)
  prefill_result, first_token = prefill_engine.prefill(params=prefill_params,
    padded_tokens=padded_tokens[chunk_num],
    true_length=true_lengths[chunk_num],
    positions=positions[chunk_num],
    all_true_length=true_length,
    previous_chunk=prefill_result,
  )
else:
  jax.debug.print("calling chunked_prefill for {chunk_num}", chunk_num=chunk_num)
  prefill_result, first_token = prefill_engine.prefill(params=prefill_params | {"cache": prefill_result["cache"]},
    padded_tokens=padded_tokens[chunk_num],
    true_length=true_lengths[chunk_num],
    positions=positions[chunk_num],
    all_true_length=true_length,
    previous_chunk=prefill_result,
  )
You can get rid of forking to make the code more readable:
cache_so_far = {} if prefill_result is None else {"cache_so_far": prefill_result["cache"]}
prefill_result, first_token = prefill_engine.prefill(params=prefill_params | cache_so_far, ....)
...
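A self-contained sketch of that single-call shape, runnable without the engine. `fake_prefill` is a hypothetical stand-in for `prefill_engine.prefill`, and the `"cache_so_far"` key follows the rename proposed in this review rather than anything the PR currently accepts. The dict union operator `|` needs Python 3.9 or newer.

```python
def fake_prefill(params, padded_tokens, true_length):
    """Hypothetical stand-in for prefill_engine.prefill.

    It appends the real (unpadded) tokens of this chunk to whatever
    cache the caller passed in, and returns the last real token.
    """
    seen = list(params.get("cache_so_far", []))
    cache = seen + padded_tokens[:true_length]
    return {"cache": cache}, cache[-1]


def run_chunks(prefill_params, chunks, true_lengths):
    """Call prefill once per chunk without an if/else around the call."""
    prefill_result, first_token = None, None
    for chunk, length in zip(chunks, true_lengths):
        # Empty dict on the first chunk, previous cache afterwards.
        cache_so_far = (
            {} if prefill_result is None
            else {"cache_so_far": prefill_result["cache"]}
        )
        prefill_result, first_token = fake_prefill(
            params=prefill_params | cache_so_far,  # builds a new dict
            padded_tokens=chunk,
            true_length=length,
        )
    return prefill_result, first_token
```

Because `|` returns a new dict, `prefill_params` itself is never mutated between chunks.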
if total token size is 520 and chunk size is 256,
the function will return 3 chunks and return tuple is as follows-
[[t0,..t255][t256,..t511][t512,..t519...(padding)]],
[256, 256, 7+padding],
nit: the true lengths returned should be [256, 256, 8] (no padding; t512..t519 is 8 tokens, since 520 - 512 = 8)
the function will return 3 chunks and return tuple is as follows-
[[t0,..t255][t256,..t511][t512,..t519...(padding)]],
[256, 256, 7+padding],
[[0,..255],[256,..511],[512..518..]]
nit: [512..519..] (ends at 519)
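The docstring example can be checked with a small pure-Python sketch. `chunk_tokens` and `pad_id` are hypothetical names, and padding every chunk to exactly `chunk_size` is an assumption; the PR may pad to a bucket size instead. For 520 tokens and a chunk size of 256, the last chunk holds 520 - 512 = 8 real tokens (t512 through t519).

```python
def chunk_tokens(tokens, chunk_size, pad_id=0):
    """Split tokens into fixed-size chunks.

    Returns (padded_chunks, true_lengths, positions), where true_lengths
    counts only real tokens and positions are absolute token indices.
    """
    padded_chunks, true_lengths, positions = [], [], []
    for start in range(0, len(tokens), chunk_size):
        chunk = tokens[start:start + chunk_size]
        true_lengths.append(len(chunk))  # real tokens only, no padding
        padded_chunks.append(chunk + [pad_id] * (chunk_size - len(chunk)))
        # Assumption: positions keep counting through the padded tail.
        positions.append(list(range(start, start + chunk_size)))
    return padded_chunks, true_lengths, positions


tokens = list(range(520))
chunks, lengths, pos = chunk_tokens(tokens, chunk_size=256)
```

With this input, `lengths` carries no padding and the last real position in `pos[2]` is 519, which matches the point of both nits.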
t_l_array = jnp.expand_dims(jnp.arange(0, chunk_num*prefill_engine.chunk_size + true_lengths[chunk_num]), 0)
prefill_result['t_l_array'] = t_l_array
Isn't t_l_array the same as the positions array? Where is this used?
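Going only by the two quoted lines and the docstring example, the arrays may differ after the first chunk: the `arange` starts at 0 on every chunk, while the docstring's positions start at the chunk offset. A pure-Python sketch of that reading, with hypothetical helper names and without the `expand_dims` batch axis:

```python
def cumulative_positions(chunk_num, chunk_size, true_lengths):
    """Mirrors the quoted arange: 0 up to the end of this chunk's real tokens."""
    return list(range(0, chunk_num * chunk_size + true_lengths[chunk_num]))


def chunk_positions(chunk_num, chunk_size, true_lengths):
    """Assumed shape of positions[chunk_num]: this chunk's real tokens only."""
    start = chunk_num * chunk_size
    return list(range(start, start + true_lengths[chunk_num]))


true_lengths = [256, 256, 8]
first_same = cumulative_positions(0, 256, true_lengths) == chunk_positions(0, 256, true_lengths)
last_same = cumulative_positions(2, 256, true_lengths) == chunk_positions(2, 256, true_lengths)
```

Under these assumptions the two agree on chunk 0 and diverge afterwards, so whether `t_l_array` is redundant depends on how the engine consumes it.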
If chunked prefill is set to True,
Chunk and appropriately pad the tokens
Call prefill for each chunk