[V1][TPU] Enable Top K #15489

Open · wants to merge 5 commits into base: main
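For context, a minimal sketch of what this change enables on the V1 TPU backend: requesting top-k sampling through the public API. The model name and prompt are placeholders, not part of this PR.

from vllm import LLM, SamplingParams

# Placeholder model; any model supported by the TPU backend works the same way.
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")
params = SamplingParams(temperature=0.8, top_k=12, max_tokens=64)
outputs = llm.generate(["Write a short story about a robot."], params)
print(outputs[0].outputs[0].text)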
12 changes: 12 additions & 0 deletions tests/v1/tpu/test_sampler.py
@@ -34,3 +34,15 @@ def test_sampler_different(model_name: str):
     sampling_params = SamplingParams(temperature=0.1, min_p=0.8, max_tokens=64)
     output2 = llm.generate(prompts, sampling_params)
     assert output[0].outputs[0].text != output2[0].outputs[0].text
+
+    # Batched case with top-k enabled.
+    for B in [4, 16]:
+        p = prompts * B
+        sampling_params = [
+            SamplingParams(temperature=0.1, min_p=0.8, max_tokens=64, top_k=12)
+            for _ in range(B)]  # independent objects; [params] * B would alias
+        # Disable top-k and min-p on the first prompt to check mixed batches.
+        sampling_params[0].top_k = -1
+        sampling_params[0].min_p = 0
+        output = llm.generate(p, sampling_params)
+        assert output[0].outputs[0].text != output[-1].outputs[0].text
6 changes: 0 additions & 6 deletions vllm/envs.py
@@ -103,7 +103,6 @@
 VLLM_DP_MASTER_PORT: int = 0
 VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
 VLLM_V0_USE_OUTLINES_CACHE: bool = False
-VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
 VLLM_TPU_BUCKET_PADDING_GAP: int = 0


@@ -671,11 +670,6 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     "VLLM_V0_USE_OUTLINES_CACHE":
     lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1",
 
-    # If set, disables TPU-specific optimization for top-k & top-p sampling
-    "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION":
-    lambda: bool(int(os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"]))
-    if "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION" in os.environ else None,
-
     # Gap between padding buckets for the forward pass. So, if the gap is 8,
     # the forward pass runs with buckets of [16, 24, 32, ...].
     "VLLM_TPU_BUCKET_PADDING_GAP":
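As a rough illustration of the bucketing comment above, a sketch under the assumption that buckets start at 16 and grow by the configured gap (the actual logic lives in the TPU model runner, outside this diff):

def next_bucket(n: int, gap: int = 8, start: int = 16) -> int:
    # Round n up to the nearest bucket in [start, start + gap, start + 2*gap, ...].
    if n <= start:
        return start
    steps = -(-(n - start) // gap)  # ceiling division
    return start + steps * gap

assert next_bucket(20, gap=8) == 24
assert next_bucket(33, gap=8) == 40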
22 changes: 4 additions & 18 deletions vllm/v1/sample/ops/topk_topp_sampler.py
@@ -72,14 +72,7 @@ def __init__(self):
                     "best performance, please install FlashInfer.")
                 self.forward = self.forward_native
         elif current_platform.is_tpu():
-            if envs.VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION:
-                logger.warning(
-                    "TPU-specific optimization for top-k & top-p sampling are "
-                    "disabled, falling back to PyTorch-native implementation "
-                    "which could be very slow.")
-                self.forward = self.forward_native
-            else:
-                self.forward = self.forward_tpu
+            self.forward = self.forward_tpu
         else:
             self.forward = self.forward_native

@@ -125,16 +118,8 @@ def forward_tpu(
         # If only top-k is specified, use pytorch's builtin topk op. This leads
         # to significant speed up on TPU compared to using apply_top_k_top_p.
         if k is not None and p is None:
-            topk_values, topk_indices = torch.topk(logits, k, dim=-1)
-
-            mask = torch.ones_like(logits, dtype=torch.bool)
-            mask.scatter_(-1, topk_indices, False)
-            logits.masked_fill_(mask, float('-inf'))
-        else:
-            # TODO Placeholder for TPU optimized topp kernel
-            # logits = apply_top_k_top_p(logits, k, p)
-            pass
-
+            logits = apply_top_k_only(logits, k)
+        # TODO: Add a TPU-optimized top-p kernel and a combined top-k + top-p path.
         probs = logits.softmax(dim=-1, dtype=torch.float32)
         return random_sample(probs, generators)
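The fast path above delegates to apply_top_k_only. A toy, non-vLLM illustration of the masking it performs before softmax; this mirrors the inline code removed by this PR, which apply_top_k_only generalizes to per-request k:

import torch

logits = torch.tensor([[2.0, 0.5, 1.0, -1.0]])
_, topk_indices = torch.topk(logits, k=2, dim=-1)
mask = torch.ones_like(logits, dtype=torch.bool)
mask.scatter_(-1, topk_indices, False)       # keep only the top-2 positions
masked = logits.masked_fill(mask, float("-inf"))
probs = masked.softmax(dim=-1)               # non-top-k entries get probability 0
print(probs)  # tensor([[0.7311, 0.0000, 0.2689, 0.0000]])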

@@ -200,6 +185,7 @@ def apply_top_k_only(
     # topk.values tensor has shape [batch_size, max_top_k].
     # Convert top k to 0-based index in range [0, max_top_k).
     k_index = k.sub_(1).unsqueeze(1)
+    # TODO: values has shape [B, max_top_k], but max_top_k is dynamic here!
     top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index)
     # Handle non-topk rows.
     top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
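To make the gather trick above concrete, a small standalone example with per-row k: each row keeps its own top-k, and values below that row's k-th largest become -inf. The dynamic shape flagged in the TODO is visible here, since max_top_k depends on the contents of k:

import torch

logits = torch.tensor([[3.0, 1.0, 2.0, 0.0],
                       [5.0, 4.0, 3.0, 2.0]])
k = torch.tensor([2, 3])
max_top_k = int(k.max())            # data-dependent, hence dynamic on TPU
k_index = (k - 1).unsqueeze(1)      # 0-based k-th index per row, shape [B, 1]
cutoff = logits.topk(max_top_k, dim=1).values.gather(1, k_index)
print(logits.masked_fill(logits < cutoff, float("-inf")))
# tensor([[3., -inf, 2., -inf],
#         [5., 4., 3., -inf]])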
7 changes: 4 additions & 3 deletions vllm/v1/sample/tpu/metadata.py
@@ -11,7 +11,7 @@
     temperature=-1.0,
     min_p=0.0,
     # strictly disabled for now
-    # top_k=-1,
+    top_k=0,
     # top_p=0.0,
     # frequency_penalties=0.0,
     # presence_penalties=0.0,
@@ -98,7 +98,8 @@ def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor,
                    DEFAULT_SAMPLING_PARAMS["temperature"])
         # TODO Temporarily disabled until sampling options are enabled
         # copy_slice(input_batch.top_p_cpu_tensor, input_batch.top_p)
-        # copy_slice(input_batch.top_k_cpu_tensor, input_batch.top_k)
+        copy_slice(input_batch.top_k_cpu_tensor, input_batch.top_k,
+                   DEFAULT_SAMPLING_PARAMS["top_k"])
         copy_slice(input_batch.min_p_cpu_tensor, input_batch.min_p,
                    DEFAULT_SAMPLING_PARAMS["min_p"])
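For readers outside this diff: copy_slice is defined earlier in this file, and its body is not shown here. A hypothetical sketch of its behavior, assuming it copies the live requests and pads the tail with the provided default so padded slots sample with neutral parameters:

import torch

num_reqs, padded_num_reqs = 3, 4  # illustrative sizes, normally from the batch

def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor,
               fill_val) -> None:
    # Assumed behavior: pad the tail with the default value, then copy the
    # padded slice into the pre-allocated device tensor.
    cpu_tensor[num_reqs:padded_num_reqs] = fill_val
    tpu_tensor[:padded_num_reqs] = cpu_tensor[:padded_num_reqs]

top_k_cpu = torch.tensor([12, 12, 7, 99])
top_k_tpu = torch.empty(4, dtype=torch.long)
copy_slice(top_k_cpu, top_k_tpu, 0)  # 0 is DEFAULT_SAMPLING_PARAMS["top_k"]
print(top_k_tpu)  # tensor([12, 12,  7,  0]): the padded slot got the default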

@@ -114,7 +115,7 @@ def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor,
                                  device=input_batch.device),
             # TODO enable more and avoid returning None values
             top_p=None,  # input_batch.top_p[:padded_num_reqs],
-            top_k=None,  # input_batch.top_k[:padded_num_reqs],
+            top_k=input_batch.top_k[:padded_num_reqs],
             min_p=input_batch.min_p[:padded_num_reqs],
             generators=input_batch.generators,
             indices_do_sample=indices_do_sample)