[WIP]Chunked Prefill #188
@@ -97,6 +97,7 @@
 from jetstream.engine import engine_api, tokenizer_api, token_utils
 from jetstream.core.metrics.prometheus import JetstreamMetricsCollector
 import numpy as np
+import jax.numpy as jnp

 root = logging.getLogger()
 root.setLevel(logging.WARNING)
@@ -498,26 +499,55 @@ def _process_prefill_content(
      tokenizer: tokenizer_api.Tokenizer,
      is_bos: bool,
      max_prefill_length: int,
-  ) -> Tuple[jax.Array | np.ndarray, int]:
+      chunked_prefill: bool = False,
+      chunk_size: Optional[int] = None,
+  ) -> Tuple[jax.Array | np.ndarray, int, jax.Array | np.ndarray]:
    content = request.prefill_content
    if isinstance(content, str):
      # If it's text input, tokenize and pad the input.
-      return tokenizer.encode(
+      tokens, true_length = tokenizer.encode(
          content,
          is_bos=is_bos,
          max_prefill_length=max_prefill_length,
          jax_padding=self._jax_padding,
      )
+      positions = jnp.expand_dims(jnp.arange(0, len(tokens), dtype=jnp.int32), 0)
+
+      if chunked_prefill:
+        return token_utils.chunk_and_pad_tokens(
+            tokens[:true_length],
+            tokenizer.bos_id,
+            tokenizer.pad_id,
+            is_bos=is_bos,
+            max_prefill_length=max_prefill_length,
+            chunk_size=chunk_size,
+            jax_padding=self._jax_padding,
+        )
+      return tokens, true_length, positions
+
    else:
+      if chunked_prefill:
+        return token_utils.chunk_and_pad_tokens(
+            content,
+            tokenizer.bos_id,
+            tokenizer.pad_id,
+            is_bos=is_bos,
+            max_prefill_length=max_prefill_length,
+            chunk_size=chunk_size,
+            jax_padding=self._jax_padding,
+        )
+
      # If it's token input, pad the input.
-      return token_utils.pad_tokens(
+      tokens, true_length = token_utils.pad_tokens(
          content,
          tokenizer.bos_id,
          tokenizer.pad_id,
          is_bos=is_bos,
          max_prefill_length=max_prefill_length,
          jax_padding=self._jax_padding,
      )
+      positions = jnp.expand_dims(jnp.arange(0, len(tokens), dtype=jnp.int32), 0)
+      return tokens, true_length, positions

  def _prefill_thread(self, idx: int):
    """Thread which runs in the background performing prefills."""
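As a reading aid (not part of the diff), here is a rough sketch of the shape difference this change introduces in the values returned by `_process_prefill_content`; all sizes below are made up for illustration:

```python
import jax.numpy as jnp

# Non-chunked path: a single padded array, its true length, and one positions row.
padded_tokens = jnp.zeros((1024,), dtype=jnp.int32)
true_length = 520
positions = jnp.expand_dims(jnp.arange(0, len(padded_tokens), dtype=jnp.int32), 0)
print(padded_tokens.shape, true_length, positions.shape)  # (1024,) 520 (1, 1024)

# Chunked path (chunked_prefill=True): parallel lists with one entry per chunk,
# as produced by token_utils.chunk_and_pad_tokens later in this PR. Actual chunk
# padding depends on the prefill buckets; the uniform 256 here is illustrative.
chunk_tokens = [jnp.zeros((256,), dtype=jnp.int32) for _ in range(3)]
chunk_true_lengths = [256, 256, 8]
chunk_positions = [
    jnp.expand_dims(jnp.arange(i * 256, i * 256 + 256, dtype=jnp.int32), 0)
    for i in range(3)
]
print(len(chunk_tokens), chunk_true_lengths, [p.shape for p in chunk_positions])
```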
@@ -545,16 +575,48 @@ def _prefill_thread(self, idx: int):
          is_bos,
      )
      # Tokenize and padding the text or token input.
-      padded_tokens, true_length = self._process_prefill_content(
-          request, tokenizer, is_bos, prefill_engine.max_prefill_length
+      padded_tokens, true_length, positions = self._process_prefill_content(
+          request, tokenizer, is_bos, prefill_engine.max_prefill_length, False,
      )

-      # Compute new kv cache for the prefill_content.
-      prefill_result, first_token = prefill_engine.prefill(
+      # If chunked prefill is used, re-process the content into chunks and
+      # prefill them one at a time, threading the KV cache through the chunks.
+      if prefill_engine.use_chunked_prefill:
+        padded_tokens, true_lengths, positions = self._process_prefill_content(
+            request, tokenizer, is_bos, prefill_engine.max_prefill_length,
+            prefill_engine.use_chunked_prefill, prefill_engine.chunk_size)
+        prefill_result = None
+        next_pos = 0
+        for chunk_num, _ in enumerate(padded_tokens):
+          if prefill_result is None:
+            jax.debug.print("calling chunked_prefill for {chunk_num}", chunk_num=chunk_num)
+            prefill_result, first_token = prefill_engine.prefill(
+                params=prefill_params,
+                padded_tokens=padded_tokens[chunk_num],
+                true_length=true_lengths[chunk_num],
+                positions=positions[chunk_num],
+                all_true_length=true_length,
+                previous_chunk=prefill_result,
+            )
+          else:
+            jax.debug.print("calling chunked_prefill for {chunk_num}", chunk_num=chunk_num)
+            prefill_result, first_token = prefill_engine.prefill(
+                params=prefill_params | {"cache": prefill_result["cache"]},
+                padded_tokens=padded_tokens[chunk_num],
+                true_length=true_lengths[chunk_num],
+                positions=positions[chunk_num],
+                all_true_length=true_length,
+                previous_chunk=prefill_result,
+            )

Review comment (on the `params=prefill_params | {"cache": prefill_result["cache"]}` line): Is "cache" supposed to represent KV cache from previous chunks so far? Can we rename it to "cache_so_far"?
Review comment on lines +589 to +606: You can get rid of forking to make the code more readable:
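The reviewer's suggested snippet is not preserved on this page. A minimal sketch of the de-forked loop, under the assumption (already true of the code above) that `prefill_engine.prefill` accepts `previous_chunk=None` for the first chunk:

```python
# Sketch only: a single prefill call per chunk, with the cache threaded in
# conditionally instead of duplicating the call in two branches.
prefill_result = None
next_pos = 0
for chunk_num, _ in enumerate(padded_tokens):
  chunk_params = prefill_params
  if prefill_result is not None:
    # Reuse the KV cache accumulated from the chunks processed so far.
    chunk_params = prefill_params | {"cache": prefill_result["cache"]}
  jax.debug.print("calling chunked_prefill for {chunk_num}", chunk_num=chunk_num)
  prefill_result, first_token = prefill_engine.prefill(
      params=chunk_params,
      padded_tokens=padded_tokens[chunk_num],
      true_length=true_lengths[chunk_num],
      positions=positions[chunk_num],
      all_true_length=true_length,
      previous_chunk=prefill_result,
  )
```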
+          t_l_array = jnp.expand_dims(
+              jnp.arange(0, chunk_num * prefill_engine.chunk_size + true_lengths[chunk_num]), 0)
+          prefill_result['t_l_array'] = t_l_array
+          prefill_result['next_pos'] = jnp.full(
+              (1, 1), next_pos + true_lengths[chunk_num], dtype=jnp.int32)
+          next_pos = next_pos + true_lengths[chunk_num]

Review comment on lines +607 to +608: Isn't …
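A small worked example of the bookkeeping above (values assumed, not from the PR): with a chunk size of 256 and a 520-token prompt, `next_pos` and `t_l_array` grow per chunk as follows:

```python
import jax.numpy as jnp

chunk_size = 256
true_lengths = [256, 256, 8]  # per-chunk true lengths for a 520-token prompt
next_pos = 0
for chunk_num, t in enumerate(true_lengths):
  t_l_array = jnp.expand_dims(jnp.arange(0, chunk_num * chunk_size + t), 0)
  next_pos = next_pos + t
  print(chunk_num, t_l_array.shape, next_pos)
# 0 (1, 256) 256
# 1 (1, 512) 512
# 2 (1, 520) 520
```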
+      else:
+        # Compute new kv cache for the prefill_content.
+        prefill_result, first_token = prefill_engine.prefill(
            params=prefill_params,
            padded_tokens=padded_tokens,
            true_length=true_length,
-      )
+        )

      request.prefill_result = prefill_result

      # put first token to detokenize queue
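The prefill calls above pass several keyword arguments that are not part of the engine changes shown on this page; below is a stub of the interface they appear to assume. The names are inferred from the call sites, not from the engine code:

```python
from typing import Any, Optional, Tuple

import jax


class ChunkedPrefillEngineStub:
  """Illustrative stub; the real engine API is defined elsewhere in the PR."""

  use_chunked_prefill: bool = True
  chunk_size: int = 256
  max_prefill_length: int = 1024

  def prefill(
      self,
      *,
      params: Any,
      padded_tokens: jax.Array,
      true_length: int,
      positions: Optional[jax.Array] = None,
      all_true_length: Optional[int] = None,
      previous_chunk: Optional[Any] = None,
  ) -> Tuple[Any, Any]:
    # Returns (prefill_result, first_token) in the orchestrator's usage above.
    raise NotImplementedError
```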
@@ -21,6 +21,7 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
+import math
 from seqio.vocabularies import SentencePieceVocabulary
 from seqio.vocabularies import Vocabulary
@@ -34,10 +35,8 @@
 ResultTokens = Any

 DEFAULT_PREFILL_BUCKETS = [
     16,
     32,
-    64,
-    128,
+    # 64,
+    # 128,
     256,
     512,
     1024,

Review comment on lines +38 to +39: We should jit compile 128 bucket size
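Context for the comment above: with 64 and 128 commented out, prompts (or chunks) longer than 32 tokens would pad up to 256. A rough sketch of nearest-bucket padding, using an illustrative subset of the buckets (not the library's actual helper):

```python
import bisect

buckets = [16, 32, 256, 512, 1024]  # 64 and 128 commented out, as in this diff

def nearest_bucket(true_length: int) -> int:
  # Smallest bucket that can hold the sequence; assumes true_length <= buckets[-1].
  return buckets[bisect.bisect_left(buckets, true_length)]

print(nearest_bucket(100))  # 256 -- would have been 128 with the full bucket list
print(nearest_bucket(32))   # 32
```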
@@ -97,6 +96,65 @@ def tokenize_and_pad(
  )
  return padded_tokens, true_length

+def chunk_and_pad_tokens(
+    tokens,
+    bos_id: int,
+    pad_id: int,
+    is_bos: bool = True,
+    prefill_lengths: Optional[List[int]] = None,
+    max_prefill_length: Optional[int] = None,
+    chunk_size: Optional[int] = None,
+    jax_padding: bool = True,
+) -> Tuple[
+    List[Union[jax.Array, np.ndarray]],
+    List[Union[jax.Array, np.ndarray]],
+    List[Union[jax.Array, np.ndarray]],
+]:
+  """Chunks and pads tokens for chunked prefill.
+
+  If the total token size is 520 and the chunk size is 256, the function
+  will return 3 chunks, and the returned tuple is as follows:
+  [[t0,..t255], [t256,..t511], [t512,..t519...(padding)]],
+  [256, 256, 7+padding],

Review comment on the true-lengths line above: nit: the true lengths returned should be …

+  [[0,..255], [256,..511], [512..518..]]

Review comment on the positions line above: nit: …

+
+  Args:
+    tokens: Tokens.
+    bos_id: Bos ID.
+    pad_id: Pad ID.
+    is_bos: Add a beginning of sequence token if this is true.
+    prefill_lengths: Buckets to pad the sequence to for static compilation.
+    max_prefill_length: Maximum bucket to use.
+    chunk_size: Maximum size of each chunk.
+    jax_padding: Convert to JAX padded tokens if True.
+
+  Returns:
+    chunk_padded_tokens: List of chunked and padded tokens.
+    padded_chunk_true_lengths: List of integers, the true length of each chunk.
+    positions: List of the positions of each token in the chunk.
+  """
+
+  num_tokens = len(tokens)
+  num_chunks = int(math.ceil(num_tokens / chunk_size))
+  # every entry in chunk_padded_tokens is a padded chunk
+  chunk_padded_tokens = []
+
+  # true lengths for each chunk
+  padded_chunk_true_lengths = []
+
+  # positions of tokens in each chunk
+  positions = []
+  # convert to a jnp array to be able to slice the tokens
+  tokens = jnp.array(tokens)
+  for chunk_num in range(num_chunks):
+    start = int(chunk_num * chunk_size)
+    end = jnp.minimum((chunk_num + 1) * chunk_size, num_tokens)
+    chunk_tokens = jax.lax.slice(tokens, (start,), (end,))
+    if chunk_num == 0:
+      padded_chunk, padded_chunk_true_length = pad_tokens(
+          chunk_tokens, bos_id, pad_id, is_bos,
+          prefill_lengths, max_prefill_length, jax_padding)
+    else:
+      # is_bos should be false in subsequent chunks.
+      padded_chunk, padded_chunk_true_length = pad_tokens(
+          chunk_tokens, bos_id, pad_id, False,
+          prefill_lengths, max_prefill_length, jax_padding)
+
+    positions_chunk = jnp.expand_dims(
+        jnp.arange(start, start + len(padded_chunk), dtype=jnp.int32), 0)
+    chunk_padded_tokens.append(padded_chunk)
+    padded_chunk_true_lengths.append(padded_chunk_true_length)
+    positions.append(positions_chunk)
+
+  return chunk_padded_tokens, padded_chunk_true_lengths, positions
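A hypothetical call (not part of the diff) showing how the new helper might be exercised on its own; the token IDs, bos/pad IDs, and lengths are made up:

```python
import numpy as np

tokens = np.arange(1, 521, dtype=np.int32)  # a 520-token prompt
chunks, true_lengths, positions = chunk_and_pad_tokens(
    tokens,
    bos_id=1,
    pad_id=0,
    is_bos=True,
    max_prefill_length=1024,
    chunk_size=256,
    jax_padding=True,
)
# One entry per chunk: padded token arrays, per-chunk true lengths, and
# absolute positions; the exact padded sizes depend on the prefill buckets.
print(len(chunks), true_lengths)
print([p.shape for p in positions])
```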

 def pad_tokens(
     tokens: np.ndarray,

@@ -157,7 +215,6 @@ def pad_tokens(
  padded_tokens = jnp.array(padded_tokens)
  return padded_tokens, true_length

-
 def process_result_tokens(
     tokenizer: tokenizer_api.Tokenizer,
     slot: int,
Review comment: delete? or use log.debug()
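The comment presumably refers to the `jax.debug.print` tracing added in the chunked prefill loop; if it is replaced with standard logging, a minimal sketch (logger name illustrative):

```python
import logging

logger = logging.getLogger(__name__)

for chunk_num in range(3):  # stand-in for the per-chunk prefill loop
  logger.debug("calling chunked_prefill for chunk %d", chunk_num)
```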