
Commit ef64044

Authored Mar 8, 2025
[V1] Prompt logprobs + APC compatibility; prompt logprobs reqs cannot fill APC (vllm-project#13949)
1 parent 66e16a0 commit ef64044

9 files changed: +292 −162 lines changed
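The tests below exercise the behavior this commit enables: requests with prompt logprobs are now accepted alongside automatic prefix caching (APC), but they never take prefix-cache hits. A minimal, hedged usage sketch of the now-allowed combination (the model name, sampling values, and the `VLLM_USE_V1` toggle are illustrative, not part of this diff):

```python
# Hedged sketch (not part of this commit's diff). Prior to this change, vLLM V1
# rejected prompt_logprobs when automatic prefix caching (APC) was enabled; the
# removed tests below asserted that ValueError.
import os

os.environ["VLLM_USE_V1"] = "1"  # assumption: opt into the V1 engine

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)
params = SamplingParams(max_tokens=16, prompt_logprobs=5)

# The request is now served; its prompt blocks simply get no prefix-cache hit.
out = llm.generate("Hello, my name is", params)[0]
print(out.prompt_logprobs)
```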
 
tests/v1/core/test_prefix_caching.py
+110 −2
@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 """Compare the with and without prefix caching."""
 
+from typing import Optional
+
 import pytest
 
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
@@ -15,7 +17,8 @@
 def make_request(request_id,
                  prompt_token_ids,
                  mm_positions=None,
-                 mm_hashes=None):
+                 mm_hashes=None,
+                 prompt_logprobs: Optional[int] = None):
     if mm_positions is None:
         multi_modal_inputs = None
     else:
@@ -28,7 +31,8 @@ def make_request(request_id,
        multi_modal_inputs=multi_modal_inputs,
        multi_modal_hashes=mm_hashes,
        multi_modal_placeholders=mm_positions,
-       sampling_params=SamplingParams(max_tokens=17),
+       sampling_params=SamplingParams(max_tokens=17,
+                                      prompt_logprobs=prompt_logprobs),
        eos_token_id=100,
        arrival_time=0,
        lora_request=None,
@@ -144,6 +148,110 @@ def test_prefill():
     assert manager.block_pool.free_block_queue.free_list_tail is None
 
 
+def test_prefill_plp():
+    '''Test prefill with APC and some prompt logprobs (plp) requests.
+
+    1. Schedule plp request and validate APC block allocation
+    2. Schedule non-plp request and validate blocks
+    3. Schedule plp request; no hit should occur; validate blocks
+    '''
+    manager = KVCacheManager(
+        block_size=16,
+        num_gpu_blocks=10,
+        max_model_len=8192,
+        sliding_window=None,
+        enable_caching=True,
+        num_preallocate_tokens=16,
+    )
+
+    # Complete 3 blocks (48 tokens)
+    common_token_ids = [i for i in range(3) for _ in range(16)]
+
+    # Request #0 is a prompt logprobs request
+    # Fully cache miss
+    # Incomplete 1 block (7 tokens)
+    unique_token_ids = [3] * 7
+    all_token_ids = common_token_ids + unique_token_ids
+    req0 = make_request("0", all_token_ids, prompt_logprobs=5)
+    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
+    assert len(manager.req_to_block_hashes[req0.request_id]) == 3
+    assert not computed_blocks
+    assert num_computed_tokens == 0
+    blocks = manager.allocate_slots(req0, 55, computed_blocks)
+    assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
+    req0_block_hashes = [b.block_hash for b in blocks]
+
+    # Check full block metadata
+    parent_block_hash = None
+    for block_id in (0, 1, 2):
+        block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
+        block_hash = hash_block_tokens(parent_block_hash, block_tokens)
+        assert manager.block_pool.blocks[block_id].block_hash == block_hash
+        assert manager.block_pool.blocks[block_id].ref_cnt == 1
+        parent_block_hash = block_hash.hash_value
+
+    # Check partial/preallocated block metadata
+    for block_id in (3, 4):
+        assert manager.block_pool.blocks[block_id].block_hash is None
+        assert manager.block_pool.blocks[block_id].ref_cnt == 1
+
+    # Request #1 is a non-prompt-logprobs request:
+    # Cache hit in the common prefix when the original block is still in use.
+    # Incomplete 1 block (5 tokens)
+    unique_token_ids = [3] * 5
+    req1 = make_request("1", common_token_ids + unique_token_ids)
+    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
+    assert len(manager.req_to_block_hashes[req1.request_id]) == 3
+    assert [b.block_id for b in computed_blocks] == [0, 1, 2]
+    assert num_computed_tokens == 3 * 16
+    num_new_tokens = 53 - 3 * 16
+    blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
+    assert [b.block_id for b in blocks] == [5, 6]
+    for block in computed_blocks:
+        assert block.ref_cnt == 2
+
+    # At this point, we should have 3 free blocks left.
+    assert manager.block_pool.free_block_queue.num_free_blocks == 3
+
+    manager.free(req0)
+    manager.free(req1)
+
+    # All blocks should be available.
+    assert manager.block_pool.free_block_queue.num_free_blocks == 10
+    # The order should be
+    # [unallocated (7, 8, 9)]
+    # [unique_req0 (4, 3)]
+    # [unique_req1 (6, 5)]
+    # [common (2, 1, 0)]
+    assert [
+        b.block_id
+        for b in manager.block_pool.free_block_queue.get_all_free_blocks()
+    ] == [7, 8, 9, 4, 3, 6, 5, 2, 1, 0]
+
+    # Request #2 is a prompt-logprobs request:
+    # NO cache hit in the common prefix; duplicates request #0 cached blocks
+    unique_token_ids = [3] * 6
+    req2 = make_request("2",
+                        common_token_ids + unique_token_ids,
+                        prompt_logprobs=5)
+    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
+    assert len(manager.req_to_block_hashes[req2.request_id]) == 3
+    assert not computed_blocks
+    assert num_computed_tokens == 0
+    blocks = manager.allocate_slots(req2, 55, computed_blocks)
+    block_ids = [b.block_id for b in blocks]
+    # Duplicate cached blocks have different ids but same hashes vs request #0
+    assert [b.block_hash for b in blocks] == req0_block_hashes
+    assert block_ids != [0, 1, 2, 3, 4]
+
+    # Request #2 block hashes are valid since request #0 hashes are.
+    # Check block reference counts.
+    for block_id in block_ids:
+        assert manager.block_pool.blocks[block_id].ref_cnt == 1
+
+    manager.free(req2)
+
+
 def test_decode():
     manager = KVCacheManager(
         block_size=16,
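The test above encodes the core rule of this commit: a request with prompt_logprobs set never gets a prefix-cache hit, even when its prefix blocks are already cached, because every prompt token must be prefilled to produce a logprob for it. A rough sketch of that rule as a standalone predicate (the helper name `prefix_cache_hit_allowed` is hypothetical; the real check lives inside vLLM's V1 KV-cache manager and may be structured differently):

```python
# Hedged sketch of the rule test_prefill_plp exercises; it mirrors the test's
# expectations, not the actual vLLM KVCacheManager internals.
from vllm import SamplingParams


def prefix_cache_hit_allowed(sampling_params: SamplingParams,
                             enable_caching: bool) -> bool:
    """Return True if a request may reuse cached prefix blocks."""
    # Prompt-logprobs requests must prefill every prompt token so a logprob
    # can be emitted for each one, so they never take prefix-cache hits.
    return enable_caching and sampling_params.prompt_logprobs is None
```

Under this rule, req0 and req2 above (prompt_logprobs=5) see no computed blocks, while the non-logprobs req1 reuses blocks 0-2 of the common prefix that req0 populated.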

tests/v1/core/test_scheduler.py
+49 −12
@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 from typing import Optional
 
+import pytest
+
 from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
@@ -16,7 +18,21 @@ def create_scheduler(
     model: str = "facebook/opt-125m",
     max_num_seqs: int = 16,
     max_num_batched_tokens: int = 8192,
+    enable_prefix_caching: Optional[bool] = None,
 ) -> Scheduler:
+    '''Create scheduler under test.
+
+    Args:
+      model: model under test
+      max_num_seqs: max sequences to schedule
+      max_num_batch_tokens: max num tokens to batch
+      enable_prefix_caching: optionally force APC config
+                             (True/False) or use default
+                             (None)
+
+    Returns:
+      :class:`Scheduler` instance
+    '''
     scheduler_config = SchedulerConfig(
         max_num_seqs=max_num_seqs,
         max_num_batched_tokens=max_num_batched_tokens,
@@ -31,11 +47,16 @@ def create_scheduler(
         dtype="float16",
         seed=42,
     )
+    # Cache config, optionally force APC
+    kwargs_cache = ({} if enable_prefix_caching is None else {
+        'enable_prefix_caching': enable_prefix_caching
+    })
     cache_config = CacheConfig(
         block_size=16,
         gpu_memory_utilization=0.9,
         swap_space=0,
         cache_dtype="auto",
+        **kwargs_cache,
     )
     vllm_config = VllmConfig(
         scheduler_config=scheduler_config,
@@ -54,16 +75,16 @@ def create_scheduler(
     )
 
 
-def create_requests(
-    num_requests: int,
-    num_tokens: int = 10,
-    mm_positions: Optional[list[PlaceholderRange]] = None,
-    max_tokens: int = 16,
-    stop_token_ids: Optional[list[int]] = None,
-):
+def create_requests(num_requests: int,
+                    num_tokens: int = 10,
+                    mm_positions: Optional[list[PlaceholderRange]] = None,
+                    max_tokens: int = 16,
+                    stop_token_ids: Optional[list[int]] = None,
+                    prompt_logprobs: Optional[int] = None):
     sampling_params = SamplingParams(ignore_eos=False,
                                      max_tokens=max_tokens,
-                                     stop_token_ids=stop_token_ids)
+                                     stop_token_ids=stop_token_ids,
+                                     prompt_logprobs=prompt_logprobs)
     requests = []
     for i in range(num_requests):
         if mm_positions is not None:
@@ -122,9 +143,18 @@ def test_get_num_unfinished_requests():
         assert scheduler.get_num_unfinished_requests() == len(requests) - i - 1
 
 
-def test_schedule():
-    scheduler = create_scheduler()
-    requests = create_requests(num_requests=10)
+@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [
+    (None, None),
+    (True, 5),
+])
+def test_schedule(enable_prefix_caching: Optional[bool],
+                  prompt_logprobs: Optional[int]):
+    '''Test scheduling.
+    Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
+    '''
+    scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching)
+    requests = create_requests(num_requests=10,
+                               prompt_logprobs=prompt_logprobs)
     for request in requests:
         scheduler.add_request(request)
 
@@ -427,14 +457,21 @@ def test_stop_via_update_from_output():
     assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]
 
 
-def test_schedule_concurrent_batches():
+@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [
+    (None, None),
+    (True, 5),
+])
+def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
+                                     prompt_logprobs: Optional[int]):
     scheduler = create_scheduler(
         max_num_batched_tokens=1024,
         max_num_seqs=2,
+        enable_prefix_caching=enable_prefix_caching,
     )
     requests = create_requests(
         num_requests=2,
         num_tokens=512,
+        prompt_logprobs=prompt_logprobs,
     )
 
     # Schedule the first request.

tests/v1/engine/test_async_llm.py
−36
@@ -6,7 +6,6 @@
 
 import pytest
 
-from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG
 from vllm import SamplingParams
 from vllm.assets.image import ImageAsset
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -72,41 +71,6 @@ async def generate(engine: AsyncLLM,
     return count, request_id
 
 
-@pytest.mark.parametrize(
-    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
-@pytest.mark.asyncio
-async def test_async_llm_refuses_prompt_logprobs_with_apc(
-        monkeypatch, output_kind: RequestOutputKind):
-    """Test passes if AsyncLLM raises an exception when it is configured
-    for automatic prefix caching and it receives a request with
-    prompt_logprobs enabled, which is incompatible."""
-    # TODO(rickyx): Remove monkeypatch VLLM_USE_V1 setting once we have a
-    # better way to test V1 so that in the future when we switch, we don't
-    # have to change all the tests.
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    # Create AsyncLLM engine with APC
-    apc_engine_args = AsyncEngineArgs(model="facebook/opt-125m",
-                                      enable_prefix_caching=True,
-                                      gpu_memory_utilization=0.8,
-                                      disable_log_requests=True)
-    engine = AsyncLLM.from_engine_args(apc_engine_args)
-    try:
-        with pytest.raises(ValueError) as excinfo:
-            # Issue a request with prompt logprobs enabled, which should fail
-            await asyncio.create_task(
-                generate(engine,
-                         "request-0",
-                         TEXT_PROMPT,
-                         output_kind,
-                         10,
-                         prompt_logprobs=5))
-        # Validate exception string is correct
-        assert str(excinfo.value) == PLP_APC_UNSUPPORTED_MSG
-    finally:
-        # Shut down engine
-        engine.shutdown()
-
-
 @pytest.mark.parametrize(
     "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
 @pytest.mark.parametrize("engine_args_and_prompt",

tests/v1/engine/test_llm_engine.py
−15
@@ -5,7 +5,6 @@
 
 import pytest
 
-from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG
 from vllm import LLM, SamplingParams
 
 MODEL = "facebook/opt-125m"
@@ -98,17 +97,3 @@ def test_parallel_sampling(vllm_model, example_prompts) -> None:
         raise AssertionError(
             f"{len(completion_counts)} unique completions; expected"
             f" {n}. Repeats: {repeats}")
-
-
-def test_llm_engine_refuses_prompt_logprobs_with_apc(vllm_model_apc):
-    """Test passes if LLMEngine raises an exception when it is configured
-    for automatic prefix caching and it receives a request with
-    prompt_logprobs enabled, which is incompatible."""
-    model: LLM = vllm_model_apc.model
-    with pytest.raises(ValueError) as excinfo:
-        model.generate(
-            "Hello, my name is",
-            SamplingParams(temperature=0.8, top_p=0.95, prompt_logprobs=5))
-
-    # Validate exception string is correct
-    assert str(excinfo.value) == PLP_APC_UNSUPPORTED_MSG

tests/v1/engine/utils.py
−3
@@ -30,9 +30,6 @@
 STOP_STRINGS = ["I love working on", "company by far", "brother in"]
 PROMPT_LEN = 5
 
-PLP_APC_UNSUPPORTED_MSG = ("Prefix caching with prompt logprobs not yet "
-                           "supported on VLLM V1.")
-
 random.seed(42)
 
 