Commit 4759b89

sync
Signed-off-by: NickLucche <[email protected]>
1 parent 85ad5cc commit 4759b89

File tree

2 files changed: +1, -74 lines

tests/v1/tpu/test_sampler.py

-57 lines
@@ -1,6 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-import tempfile
-from time import time
 
 import pytest
 
@@ -15,61 +13,6 @@
 )
 
 
-# TODO remove this test once VLLM_XLA_CHECK_RECOMPILATION does not error out
-@pytest.mark.parametrize("model_name", ["D4nt3/Qwen2.5-two-layers"])
-@pytest.mark.skipif(not current_platform.is_tpu(),
-                    reason="This test needs a TPU")
-def test_sampler_compilation(model_name: str, monkeypatch):
-    """
-    Check that no recompilation happens despite changing sampling parameters.
-    We can't read XLA metrics from the engine process, hence we measure time.
-    """
-    with tempfile.TemporaryDirectory() as temp_dir:
-        monkeypatch.setenv("VLLM_XLA_CACHE_PATH", temp_dir)
-        # Compiling model init may still take some time, enforce_eager to skip.
-        llm = LLM(model_name,
-                  enforce_eager=True,
-                  max_num_seqs=16,
-                  max_model_len=1024,
-                  gpu_memory_utilization=0.5)
-        prompts = [
-            "A robot may not injure a human being",
-            "It is only with the heart that one can see rightly;",
-        ]
-        # First inference should be slow
-        sampling_params = SamplingParams(
-            temperature=0.7,
-            # top_p=0.6, # TODO too slow!
-            top_k=10,
-            min_p=0.2,
-            max_tokens=16)
-        s = time()
-        _ = llm.generate(prompts, sampling_params)
-        run1 = time() - s
-
-        # Second request with different params, but which we already
-        # compiled for in the previous eager iteration.
-        sampling_params = SamplingParams(temperature=0.1,
-                                         top_k=12,
-                                         min_p=0.8,
-                                         max_tokens=24)
-        s = time()
-        _ = llm.generate(prompts, sampling_params)
-        run2 = time() - s
-        # Much faster after compiling
-        assert run1 * 0.1 > run2
-        print("TIMES", run1, run2)
-
-        # Third request with min_p set to "None". It will not trigger
-        # recompilation as a default 0 value will be used.
-        sampling_params = SamplingParams(max_tokens=24, temperature=0.0)
-        s = time()
-        _ = llm.generate(prompts, sampling_params)
-        run3 = time() - s
-        assert run1 * 0.1 > run3
-        print("TIMES", run1, run3)
-
-
 @pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
 @pytest.mark.skipif(not current_platform.is_tpu(),
                     reason="This test needs a TPU")
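Note on the deleted test: it approximated "no recompilation" through wall-clock time, since XLA compile metrics cannot be read from the engine process. The snippet below is a condensed, hedged restatement of that timing heuristic; the model name is illustrative (the removed test used a tiny two-layer model), not the removed test verbatim:

# Sketch only: restates the timing heuristic of the removed test_sampler_compilation.
from time import time

from vllm import LLM, SamplingParams

llm = LLM("Qwen/Qwen2.5-1.5B-Instruct", enforce_eager=True, max_model_len=1024)
prompts = ["A robot may not injure a human being"]

start = time()
llm.generate(prompts, SamplingParams(temperature=0.7, top_k=10, max_tokens=16))
cold = time() - start  # first run pays for XLA graph compilation

start = time()
llm.generate(prompts, SamplingParams(temperature=0.1, top_k=12, max_tokens=16))
warm = time() - start  # new sampling params should reuse the compiled graph

# Same 10x margin the removed test asserted.
assert cold * 0.1 > warm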

vllm/v1/sample/ops/topk_topp_sampler.py

+1, -17 lines
@@ -118,28 +118,12 @@ def forward_tpu(
         # If only top-k is specified, use pytorch's builtin topk op. This leads
         # to significant speed up on TPU compared to using apply_top_k_top_p.
         if k is not None and p is None:
-            logits = top_k_only(logits, k)
+            logits = apply_top_k_only(logits, k)
         # TODO Add TPU optimized topp kernel and topk+topp
         probs = logits.softmax(dim=-1, dtype=torch.float32)
         return random_sample(probs, generators)
 
 
-def top_k_only(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
-    # Avoid sorting vocab for top-k only case.
-    no_top_k_mask = k == logits.shape[1]
-    # Set non-top-k rows to 1 so that we can gather.
-    k = k.masked_fill(no_top_k_mask, 1)
-    max_top_k = k.max()
-    # topk.values tensor has shape [batch_size, max_top_k].
-    # Convert top k to 0-based index in range [0, max_top_k).
-    k_index = k.sub_(1).unsqueeze(1)
-    top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index)
-    # Handle non-topk rows.
-    top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
-    logits.masked_fill_(logits < top_k_mask, -float("inf"))
-    return logits
-
-
 def apply_top_k_top_p(
     logits: torch.Tensor,
     k: Optional[torch.Tensor],
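For reference, the masking trick from the deleted top_k_only helper (threshold each row at its own k-th largest logit rather than sorting the full vocab) can be restated as a self-contained snippet. apply_top_k_only, which the diff switches to, is assumed to follow the same idea; the function name below is illustrative, not the library implementation:

# Standalone sketch of the top-k-only masking approach from the removed helper.
import torch


def top_k_only_sketch(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # Rows whose k equals the vocab size effectively request "no top-k".
    no_top_k_mask = k == logits.shape[1]
    # Use k=1 for those rows so the gather below stays in bounds.
    k = k.masked_fill(no_top_k_mask, 1)
    max_top_k = int(k.max())
    # Take the max_top_k largest logits per row, then pick each row's own
    # k-th value (0-based index k-1) as that row's cutoff threshold.
    k_index = (k - 1).unsqueeze(1)
    thresholds = logits.topk(max_top_k, dim=1).values.gather(1, k_index)
    # Disable the cutoff for "no top-k" rows, then mask everything below it.
    thresholds = thresholds.masked_fill(no_top_k_mask.unsqueeze(1), -float("inf"))
    return logits.masked_fill(logits < thresholds, -float("inf"))


torch.manual_seed(0)
logits = torch.randn(2, 8)
k = torch.tensor([3, 8])  # row 0: keep top-3; row 1: keep all
out = top_k_only_sketch(logits, k)
print((out > -float("inf")).sum(dim=1))  # tensor([3, 8])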
