1
1
# SPDX-License-Identifier: Apache-2.0
2
2
"""Compare the with and without prefix caching."""
3
3
4
+ from typing import Optional
5
+
4
6
import pytest
5
7
6
8
from vllm .multimodal .inputs import MultiModalKwargs , PlaceholderRange
15
17
def make_request (request_id ,
16
18
prompt_token_ids ,
17
19
mm_positions = None ,
18
- mm_hashes = None ):
20
+ mm_hashes = None ,
21
+ prompt_logprobs : Optional [int ] = None ):
19
22
if mm_positions is None :
20
23
multi_modal_inputs = None
21
24
else :
@@ -28,7 +31,8 @@ def make_request(request_id,
28
31
multi_modal_inputs = multi_modal_inputs ,
29
32
multi_modal_hashes = mm_hashes ,
30
33
multi_modal_placeholders = mm_positions ,
31
- sampling_params = SamplingParams (max_tokens = 17 ),
34
+ sampling_params = SamplingParams (max_tokens = 17 ,
35
+ prompt_logprobs = prompt_logprobs ),
32
36
eos_token_id = 100 ,
33
37
arrival_time = 0 ,
34
38
lora_request = None ,
@@ -144,6 +148,110 @@ def test_prefill():
144
148
assert manager .block_pool .free_block_queue .free_list_tail is None
145
149
146
150
151
def test_prefill_plp():
    """Test prefill with APC and some prompt logprobs (plp) requests.

    1. Schedule plp request and validate APC block allocation
    2. Schedule non-plp request and validate blocks
    3. Schedule plp request; no hit should occur; validate blocks
    """
    manager = KVCacheManager(
        block_size=16,
        num_gpu_blocks=10,
        max_model_len=8192,
        sliding_window=None,
        enable_caching=True,
        num_preallocate_tokens=16,
    )

    # Complete 3 blocks (48 tokens)
    common_token_ids = [i for i in range(3) for _ in range(16)]

    # Request #0 is a prompt logprobs request
    # Full cache miss
    # Incomplete 1 block (7 tokens)
    unique_token_ids = [3] * 7
    all_token_ids = common_token_ids + unique_token_ids
    req0 = make_request("0", all_token_ids, prompt_logprobs=5)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    assert len(manager.req_to_block_hashes[req0.request_id]) == 3
    assert not computed_blocks
    assert num_computed_tokens == 0
    # Nothing was computed, so the whole 55-token prompt must be allocated.
    blocks = manager.allocate_slots(req0, len(all_token_ids), computed_blocks)
    assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
    # Remembered so request #2's freshly allocated blocks can be compared
    # against these hashes further down.
    req0_block_hashes = [b.block_hash for b in blocks]

    # Check full block metadata
    parent_block_hash = None
    for block_id in (0, 1, 2):
        block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
        block_hash = hash_block_tokens(parent_block_hash, block_tokens)
        assert manager.block_pool.blocks[block_id].block_hash == block_hash
        assert manager.block_pool.blocks[block_id].ref_cnt == 1
        # Each block's hash is chained on the previous block's hash value.
        parent_block_hash = block_hash.hash_value

    # Check partial/preallocated block metadata
    for block_id in (3, 4):
        assert manager.block_pool.blocks[block_id].block_hash is None
        assert manager.block_pool.blocks[block_id].ref_cnt == 1

    # Request #1 is a non-prompt-logprobs request:
    # Cache hit in the common prefix when the original block is still in use.
    # Incomplete 1 block (5 tokens)
    unique_token_ids = [3] * 5
    req1_token_ids = common_token_ids + unique_token_ids
    req1 = make_request("1", req1_token_ids)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert len(manager.req_to_block_hashes[req1.request_id]) == 3
    assert [b.block_id for b in computed_blocks] == [0, 1, 2]
    assert num_computed_tokens == 3 * 16
    # Only the tokens past the cached prefix still need slots (53 - 48 = 5).
    num_new_tokens = len(req1_token_ids) - num_computed_tokens
    blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
    assert [b.block_id for b in blocks] == [5, 6]
    for block in computed_blocks:
        assert block.ref_cnt == 2

    # At this point, we should have 3 free blocks left.
    assert manager.block_pool.free_block_queue.num_free_blocks == 3

    manager.free(req0)
    manager.free(req1)

    # All blocks should be available.
    assert manager.block_pool.free_block_queue.num_free_blocks == 10
    # The order should be
    # [unallocated (7, 8, 9)]
    # [unique_req0 (4, 3)]
    # [unique_req1 (6, 5)]
    # [common (2, 1, 0)]
    assert [
        b.block_id
        for b in manager.block_pool.free_block_queue.get_all_free_blocks()
    ] == [7, 8, 9, 4, 3, 6, 5, 2, 1, 0]

    # Request #2 is a prompt-logprobs request:
    # NO cache hit in the common prefix; duplicates request #0 cached blocks
    unique_token_ids = [3] * 6
    req2_token_ids = common_token_ids + unique_token_ids
    req2 = make_request("2", req2_token_ids, prompt_logprobs=5)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
    assert len(manager.req_to_block_hashes[req2.request_id]) == 3
    assert not computed_blocks
    assert num_computed_tokens == 0
    # Allocate for this request's own prompt length (48 + 6 = 54 tokens)
    # rather than reusing request #0's token count.
    blocks = manager.allocate_slots(req2, len(req2_token_ids), computed_blocks)
    block_ids = [b.block_id for b in blocks]
    # Duplicate cached blocks have different ids but same hashes vs request #0
    assert [b.block_hash for b in blocks] == req0_block_hashes
    assert block_ids != [0, 1, 2, 3, 4]

    # Request #2 block hashes are valid since request #0 hashes are.
    # Check block reference counts.
    for block_id in block_ids:
        assert manager.block_pool.blocks[block_id].ref_cnt == 1

    manager.free(req2)
253
+
254
+
147
255
def test_decode ():
148
256
manager = KVCacheManager (
149
257
block_size = 16 ,
0 commit comments