diff --git a/tests/v1/tpu/test_sampler.py b/tests/v1/tpu/test_sampler.py
index f535abedea2..4c653e1a81d 100644
--- a/tests/v1/tpu/test_sampler.py
+++ b/tests/v1/tpu/test_sampler.py
@@ -34,3 +34,16 @@ def test_sampler_different(model_name: str):
     sampling_params = SamplingParams(temperature=0.1, min_p=0.8, max_tokens=64)
     output2 = llm.generate(prompts, sampling_params)
     assert output[0].outputs[0].text != output2[0].outputs[0].text
+
+    # Batched case with top-k; one SamplingParams per request ([...] * B aliases).
+    for B in [4, 16]:
+        p = prompts * B
+        sampling_params = [
+            SamplingParams(temperature=0.1, min_p=0.8, max_tokens=64, top_k=12)
+            for _ in range(B)
+        ]
+        # Disable top-k on the first prompt to check the kernel handles it.
+        sampling_params[0].top_k = -1
+        sampling_params[0].min_p = 0
+        output = llm.generate(p, sampling_params)
+        assert output[0].outputs[0].text != output[-1].outputs[0].text
diff --git a/vllm/envs.py b/vllm/envs.py
index 8a03ba329b0..b54bb527c6c 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -103,7 +103,6 @@
     VLLM_DP_MASTER_PORT: int = 0
     VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
     VLLM_V0_USE_OUTLINES_CACHE: bool = False
-    VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
     VLLM_TPU_BUCKET_PADDING_GAP: int = 0
 
 
@@ -671,11 +670,6 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     "VLLM_V0_USE_OUTLINES_CACHE":
     lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1",
 
-    # If set, disables TPU-specific optimization for top-k & top-p sampling
-    "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION":
-    lambda: bool(int(os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"]))
-    if "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION" in os.environ else None,
-
     # Gap between padding buckets for the forward pass. So we have
     # 8, we will run forward pass with [16, 24, 32, ...].
     "VLLM_TPU_BUCKET_PADDING_GAP":
diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py
index 5dfcae08b17..1aeea110765 100644
--- a/vllm/v1/sample/ops/topk_topp_sampler.py
+++ b/vllm/v1/sample/ops/topk_topp_sampler.py
@@ -72,14 +72,7 @@ def __init__(self):
                 "best performance, please install FlashInfer.")
             self.forward = self.forward_native
         elif current_platform.is_tpu():
-            if envs.VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION:
-                logger.warning(
-                    "TPU-specific optimization for top-k & top-p sampling are "
-                    "disabled, falling back to PyTorch-native implementation "
-                    "which could be very slow.")
-                self.forward = self.forward_native
-            else:
-                self.forward = self.forward_tpu
+            self.forward = self.forward_tpu
         else:
             self.forward = self.forward_native
 
@@ -125,16 +118,8 @@ def forward_tpu(
         # If only top-k is specified, use pytorch's builtin topk op. This leads
         # to significant speed up on TPU compared to using apply_top_k_top_p.
         if k is not None and p is None:
-            topk_values, topk_indices = torch.topk(logits, k, dim=-1)
-
-            mask = torch.ones_like(logits, dtype=torch.bool)
-            mask.scatter_(-1, topk_indices, False)
-            logits.masked_fill_(mask, float('-inf'))
-        else:
-            # TODO Placeholder for TPU optimized topp kernel
-            # logits = apply_top_k_top_p(logits, k, p)
-            pass
-
+            logits = apply_top_k_only(logits, k)
+        # TODO: add a TPU-optimized top-p kernel and a combined top-k+top-p path.
         probs = logits.softmax(dim=-1, dtype=torch.float32)
         return random_sample(probs, generators)
 
@@ -200,6 +185,8 @@ def apply_top_k_only(
     # topk.values tensor has shape [batch_size, max_top_k].
     # Convert top k to 0-based index in range [0, max_top_k).
     k_index = k.sub_(1).unsqueeze(1)
+    # TODO: topk values have shape [B, max_top_k], and max_top_k is
+    # data-dependent here (a dynamic shape), which is problematic on TPU.
     top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index)
     # Handle non-topk rows.
     top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
diff --git a/vllm/v1/sample/tpu/metadata.py b/vllm/v1/sample/tpu/metadata.py
index 89d3ddf51d7..cbad3cdbba0 100644
--- a/vllm/v1/sample/tpu/metadata.py
+++ b/vllm/v1/sample/tpu/metadata.py
@@ -11,7 +11,7 @@
     temperature=-1.0,
     min_p=0.0,
     # strictly disabled for now
-    # top_k=-1,
+    top_k=0,
     # top_p=0.0,
     # frequency_penalties=0.0,
     # presence_penalties=0.0,
@@ -98,7 +98,8 @@ def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor,
                    DEFAULT_SAMPLING_PARAMS["temperature"])
         # TODO Temporarily disabled until sampling options are enabled
         # copy_slice(input_batch.top_p_cpu_tensor, input_batch.top_p)
-        # copy_slice(input_batch.top_k_cpu_tensor, input_batch.top_k)
+        copy_slice(input_batch.top_k_cpu_tensor, input_batch.top_k,
+                   DEFAULT_SAMPLING_PARAMS["top_k"])
         copy_slice(input_batch.min_p_cpu_tensor, input_batch.min_p,
                    DEFAULT_SAMPLING_PARAMS["min_p"])
 
@@ -114,7 +115,7 @@ def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor,
             device=input_batch.device),
         # TODO enable more and avoid returning None values
         top_p=None,  # input_batch.top_p[:padded_num_reqs],
-        top_k=None,  # input_batch.top_k[:padded_num_reqs],
+        top_k=input_batch.top_k[:padded_num_reqs],
        min_p=input_batch.min_p[:padded_num_reqs],
         generators=input_batch.generators,
         indices_do_sample=indices_do_sample)
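For context on the path `forward_tpu` now reuses: `apply_top_k_only` masks each row of logits against its own threshold using a single batched `topk`, instead of sorting the full vocabulary or scattering per-row index masks. The sketch below illustrates the technique in a self-contained form; the standalone function name and the `k <= 0` "top-k disabled" convention are assumptions for illustration, not the exact vLLM implementation.

```python
import torch


def apply_top_k_only_sketch(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Per-row top-k masking with one batched topk (no full vocab sort)."""
    # Assumption for this sketch: k <= 0 marks "top-k disabled" for a row.
    no_top_k_mask = k <= 0
    # Use k=1 for disabled rows so the gather below stays in bounds.
    k = k.masked_fill(no_top_k_mask, 1)
    # One topk for the whole batch; note max_top_k is data-dependent,
    # the dynamic-shape concern flagged in the TODO above.
    max_top_k = int(k.max())
    # topk values have shape [B, max_top_k]; each row's threshold is its
    # k-th largest logit, i.e. 0-based column index k - 1.
    k_index = (k - 1).unsqueeze(1)
    top_k_mask = logits.topk(max_top_k, dim=-1).values.gather(1, k_index)
    # Disabled rows get a -inf threshold, so nothing is filtered there.
    top_k_mask = top_k_mask.masked_fill(no_top_k_mask.unsqueeze(1), float("-inf"))
    return logits.masked_fill(logits < top_k_mask, float("-inf"))


logits = torch.randn(2, 8)
k = torch.tensor([3, -1])  # row 0: keep top-3; row 1: top-k disabled
print(apply_top_k_only_sketch(logits, k))
```

Unlike the removed `scatter_`-based masking, which assumed a single scalar k, this handles a different k per request, which is exactly what the new batched test exercises.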
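The metadata.py side follows a copy-then-pad pattern: the first `num_reqs` entries come from the CPU tensor, and the padded tail is filled with the parameter's default from `DEFAULT_SAMPLING_PARAMS` (`top_k=0` here), so padded slots do not affect sampling while on-device tensor shapes stay fixed for XLA. A minimal sketch of that pattern, assuming a standalone signature (in the diff, `copy_slice` closes over the request counts rather than taking them as arguments):

```python
import torch


def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor,
               num_reqs: int, padded_num_reqs: int, fill_val) -> None:
    # Pad the tail on the host first, then do one contiguous copy to device.
    cpu_tensor[num_reqs:padded_num_reqs] = fill_val
    tpu_tensor[:padded_num_reqs] = cpu_tensor[:padded_num_reqs]


# Example: 3 real requests padded to a bucket of 4; fill_val=0 disables
# top-k for the padded slot, matching DEFAULT_SAMPLING_PARAMS["top_k"].
top_k_cpu = torch.tensor([12, 40, 7, 99])     # slot 3 holds stale data
top_k_tpu = torch.zeros(4, dtype=torch.long)  # stands in for the TPU tensor
copy_slice(top_k_cpu, top_k_tpu, num_reqs=3, padded_num_reqs=4, fill_val=0)
print(top_k_tpu)  # tensor([12, 40,  7,  0])
```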