[WIP] Chunked Prefill #188

Open · wants to merge 1 commit into base: main
1 change: 1 addition & 0 deletions benchmarks/benchmark_serving.py
@@ -518,6 +518,7 @@ async def send_request(
"""Send the request to JetStream server."""
# Tokenize on client side following MLPerf standard.
token_ids = tokenizer.encode(input_request.prompt)
print("len token_ids ", len(token_ids))
Collaborator:
delete? or use log.debug()
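A minimal sketch of that suggestion, assuming the standard logging module is acceptable here (the logger name is illustrative):

    import logging

    logger = logging.getLogger("benchmark_serving")

    # Lazy %-formatting: the message is only built when DEBUG is enabled.
    logger.debug("len token_ids %d", len(token_ids))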


# Send the request
request = jetstream_pb2.DecodeRequest(
78 changes: 70 additions & 8 deletions jetstream/core/orchestrator.py
@@ -97,6 +97,7 @@
from jetstream.engine import engine_api, tokenizer_api, token_utils
from jetstream.core.metrics.prometheus import JetstreamMetricsCollector
import numpy as np
import jax.numpy as jnp

root = logging.getLogger()
root.setLevel(logging.WARNING)
@@ -498,26 +499,55 @@ def _process_prefill_content(
tokenizer: tokenizer_api.Tokenizer,
is_bos: bool,
max_prefill_length: int,
) -> Tuple[jax.Array | np.ndarray, int]:
chunked_prefill: bool = False,
chunk_size: Optional[int] = None,
) -> Tuple[jax.Array | np.ndarray, int, jax.Array | np.ndarray]:
content = request.prefill_content
if isinstance(content, str):
# If it's text input, tokenize and pad the input.
return tokenizer.encode(
tokens, true_length = tokenizer.encode(
content,
is_bos=is_bos,
max_prefill_length=max_prefill_length,
jax_padding=self._jax_padding,
)
positions = jnp.expand_dims(jnp.arange(0, len(tokens), dtype=jnp.int32), 0)

if chunked_prefill:
return token_utils.chunk_and_pad_tokens(
tokens[:true_length],
tokenizer.bos_id,
tokenizer.pad_id,
is_bos=is_bos,
max_prefill_length=max_prefill_length,
chunk_size=chunk_size,
            jax_padding=self._jax_padding,
        )
return tokens, true_length, positions

else:
if chunked_prefill:
return token_utils.chunk_and_pad_tokens(
content,
tokenizer.bos_id,
tokenizer.pad_id,
is_bos=is_bos,
max_prefill_length=max_prefill_length,
chunk_size=chunk_size,
            jax_padding=self._jax_padding,
        )

# If it's token input, pad the input.
return token_utils.pad_tokens(
tokens, true_length = token_utils.pad_tokens(
content,
tokenizer.bos_id,
tokenizer.pad_id,
is_bos=is_bos,
max_prefill_length=max_prefill_length,
jax_padding=self._jax_padding,
)
positions = jnp.expand_dims(jnp.arange(0, len(tokens), dtype=jnp.int32), 0)
return tokens, true_length, positions
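For reference, the positions array built above is a (1, seq_len) row of token indices; a minimal sketch of its shape, with a stand-in token array:

    import jax.numpy as jnp

    tokens = jnp.zeros((8,), dtype=jnp.int32)  # stand-in for padded tokens
    positions = jnp.expand_dims(jnp.arange(0, len(tokens), dtype=jnp.int32), 0)
    assert positions.shape == (1, 8)  # batch dimension of 1, positions 0..7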



def _prefill_thread(self, idx: int):
"""Thread which runs in the background performing prefills."""
@@ -545,16 +575,48 @@ def _prefill_thread(self, idx: int):
is_bos,
)
      # Tokenize and pad the text or token input.
padded_tokens, true_length = self._process_prefill_content(
request, tokenizer, is_bos, prefill_engine.max_prefill_length
padded_tokens, true_length, positions = self._process_prefill_content(
request, tokenizer, is_bos, prefill_engine.max_prefill_length, False,
)

      # If chunked prefill is enabled, re-process the content chunk by chunk.
      if prefill_engine.use_chunked_prefill:
        padded_tokens, true_lengths, positions = self._process_prefill_content(
            request, tokenizer, is_bos, prefill_engine.max_prefill_length,
            prefill_engine.use_chunked_prefill, prefill_engine.chunk_size)
prefill_result = None
next_pos = 0
for chunk_num, _ in enumerate(padded_tokens):
if prefill_result is None:
jax.debug.print("calling chunked_prefill for {chunk_num}", chunk_num=chunk_num)
            prefill_result, first_token = prefill_engine.prefill(
                params=prefill_params,
padded_tokens=padded_tokens[chunk_num],
true_length=true_lengths[chunk_num],
positions=positions[chunk_num],
all_true_length=true_length,
previous_chunk=prefill_result,
)
else:
jax.debug.print("calling chunked_prefill for {chunk_num}", chunk_num=chunk_num)
prefill_result, first_token = prefill_engine.prefill(params=prefill_params | {"cache": prefill_result["cache"]},
Collaborator:
Is "cache" supposed to represent KV cache from previous chunks so far? Can we rename it to "cache_so_far"?

padded_tokens=padded_tokens[chunk_num],
true_length=true_lengths[chunk_num],
positions=positions[chunk_num],
all_true_length=true_length,
previous_chunk=prefill_result,
)
Collaborator, on lines +589 to +606:

You can get rid of forking to make the code more readable:

cache_so_far = {} if prefill_result is None else {"cache_so_far": prefill_result["cache"]}
prefill_result, first_token = prefill_engine.prefill(params=prefill_params | cache_so_far, ....)
...
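A fuller sketch of that suggestion, assuming prefill() accepts previous_chunk=None on the first call (the current fork avoids this, so it is worth verifying against the engine API before adopting):

    prefill_result = None
    next_pos = 0
    for chunk_num, _ in enumerate(padded_tokens):
      # Empty dict on the first chunk; running KV cache afterwards.
      extra = {} if prefill_result is None else {"cache": prefill_result["cache"]}
      prefill_result, first_token = prefill_engine.prefill(
          params=prefill_params | extra,
          padded_tokens=padded_tokens[chunk_num],
          true_length=true_lengths[chunk_num],
          positions=positions[chunk_num],
          all_true_length=true_length,
          previous_chunk=prefill_result,
      )
      next_pos += true_lengths[chunk_num]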

          t_l_array = jnp.expand_dims(
              jnp.arange(0, chunk_num * prefill_engine.chunk_size + true_lengths[chunk_num]), 0)
prefill_result['t_l_array'] = t_l_array
Collaborator, on lines +607 to +608:
Isn't t_l_array the same as the positions array? Where is this used?

          prefill_result['next_pos'] = jnp.full(
              (1, 1), next_pos + true_lengths[chunk_num], dtype=jnp.int32)
          next_pos = next_pos + true_lengths[chunk_num]

else:
# Compute new kv cache for the prefill_content.
prefill_result, first_token = prefill_engine.prefill(
params=prefill_params,
padded_tokens=padded_tokens,
true_length=true_length,
)
)

request.prefill_result = prefill_result

# put first token to detokenize queue
67 changes: 62 additions & 5 deletions jetstream/engine/token_utils.py
@@ -21,6 +21,7 @@
import jax
import jax.numpy as jnp
import numpy as np
import math
from seqio.vocabularies import SentencePieceVocabulary
from seqio.vocabularies import Vocabulary

@@ -34,10 +35,8 @@
ResultTokens = Any

DEFAULT_PREFILL_BUCKETS = [
16,
32,
64,
128,
# 64,
# 128,
Collaborator, on lines +38 to +39:
We should jit-compile the 128 bucket size.

256,
512,
1024,
@@ -97,6 +96,65 @@ def tokenize_and_pad(
)
return padded_tokens, true_length

def chunk_and_pad_tokens(
    tokens,
    bos_id: int,
    pad_id: int,
    is_bos: bool = True,
    prefill_lengths: Optional[List[int]] = None,
    max_prefill_length: Optional[int] = None,
    chunk_size: Optional[int] = None,
    jax_padding: bool = True,
) -> Tuple[
    List[Union[jax.Array, np.ndarray]],
    List[int],
    List[Union[jax.Array, np.ndarray]],
]:
"""Chunks and pads tokens for chunked prefill
if total token size is 520 and chunk size is 256,
the function will return 3 chunks and return tuple is as follows-
[[t0,..t255][t256,..t511][t512,..t519...(padding)]],
[256, 256, 7+padding],
Collaborator:
nit: the true lengths returned should be [256, 256, 8] (no padding)

[[0,..255],[256,..511],[512..518..]]
Collaborator:
nit: [512..519..] (ends at 519)


  Args:
    tokens: Tokens.
    bos_id: Bos ID.
    pad_id: Pad ID.
    is_bos: Add a beginning-of-sequence token if this is true.
    prefill_lengths: Buckets to pad the sequence to for static compilation.
    max_prefill_length: Maximum bucket to use.
    chunk_size: Maximum size of each chunk.
    jax_padding: Convert to JAX padded tokens if True.

  Returns:
    chunk_padded_tokens: List of chunked and padded tokens.
    padded_chunk_true_lengths: List of integers - the true length of each chunk.
    positions: List of the positions of each token in its chunk.
  """

num_tokens = len(tokens)
  num_chunks = int(math.ceil(num_tokens / chunk_size))
# every entry in chunk_padded_tokens is a padded chunk
chunk_padded_tokens = []

# true lengths for each chunk
padded_chunk_true_lengths = []

# positions of tokens in each chunk
positions = []
# to be able to slice the tokens
tokens = jnp.array(tokens)
for chunk_num in range(num_chunks):
    start = chunk_num * chunk_size
    # Use Python min() so the slice bounds stay static ints for jax.lax.slice.
    end = min((chunk_num + 1) * chunk_size, num_tokens)
    chunk_tokens = jax.lax.slice(tokens, (start,), (end,))
    if chunk_num == 0:
      padded_chunk, padded_chunk_true_length = pad_tokens(
          chunk_tokens, bos_id, pad_id, is_bos,
          prefill_lengths, max_prefill_length, jax_padding)
    else:
      # is_bos should be false in subsequent chunks.
      padded_chunk, padded_chunk_true_length = pad_tokens(
          chunk_tokens, bos_id, pad_id, False,
          prefill_lengths, max_prefill_length, jax_padding)

    positions_chunk = jnp.expand_dims(
        jnp.arange(start, start + len(padded_chunk), dtype=jnp.int32), 0)
chunk_padded_tokens.append(padded_chunk)
padded_chunk_true_lengths.append(padded_chunk_true_length)
positions.append(positions_chunk)

return chunk_padded_tokens, padded_chunk_true_lengths, positions
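A hedged usage sketch matching the docstring's 520-token example (token IDs, bos/pad IDs, and the bucket defaults are assumptions for illustration, not an authoritative test):

    import numpy as np

    tokens = np.arange(520)  # stand-in token IDs t0..t519
    chunks, true_lengths, positions = chunk_and_pad_tokens(
        tokens,
        bos_id=1,
        pad_id=0,
        is_bos=False,
        max_prefill_length=1024,
        chunk_size=256,
    )
    assert len(chunks) == 3            # ceil(520 / 256) == 3
    assert true_lengths[0] == 256 and true_lengths[1] == 256
    assert true_lengths[2] == 8        # t512..t519, padding excluded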


def pad_tokens(
tokens: np.ndarray,
@@ -157,7 +215,6 @@ def pad_tokens(
padded_tokens = jnp.array(padded_tokens)
return padded_tokens, true_length


def process_result_tokens(
tokenizer: tokenizer_api.Tokenizer,
slot: int,