Skip to content

Commit 0a049c7

Browse files
[CI/Build] Add tests for the V1 tpu_model_runner. (#14843)
Signed-off-by: Yarong Mu <[email protected]>
1 parent d0cfec7 commit 0a049c7

File tree

3 files changed

+310
-1
lines changed

3 files changed

+310
-1
lines changed

.buildkite/run-tpu-v1-test.sh

+3-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ docker run --privileged --net host --shm-size=16G -it \
3030
&& echo TEST_4 \
3131
&& pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
3232
&& echo TEST_5 \
33-
&& python3 /workspace/vllm/examples/offline_inference/tpu.py" \
33+
&& python3 /workspace/vllm/examples/offline_inference/tpu.py \
34+
&& echo TEST_6 \
35+
&& pytest -s -v /workspace/vllm/tests/tpu/worker/test_tpu_model_runner.py" \
3436

3537

3638
# TODO: This test fails because it uses RANDOM_SEED sampling

tests/v1/tpu/worker/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,307 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
import unittest.mock as mock
3+
4+
import pytest
5+
6+
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
7+
from vllm.sampling_params import SamplingParams
8+
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
9+
SchedulerOutput)
10+
from vllm.v1.sample.metadata import SamplingMetadata
11+
from vllm.v1.worker.tpu_model_runner import TPUModelRunner
12+
13+
# torch_xla is not installed in every CI environment, so stub the package
# (and the submodules tpu_model_runner pulls in) before it gets imported.
torch_xla_patcher = mock.patch.dict(
    "sys.modules",
    {
        "torch_xla": mock.MagicMock(),
        "torch_xla.core.xla_model": mock.MagicMock(),
        "torch_xla.runtime": mock.MagicMock(),
    },
)
torch_xla_patcher.start()

# The real Pallas attention backend requires TPU support; mock it out too.
pallas_attention_backend_patcher = mock.patch(
    "vllm.v1.worker.tpu_model_runner.PallasAttentionBackend", )
pallas_attention_backend_patcher.start()
26+
27+
28+
@pytest.fixture
def model_runner():
    """Construct a TPUModelRunner without touching real TPU hardware.

    The module-level patchers are already active; here the runner's own
    torch/xm/xr handles are additionally mocked for the constructor call.
    """
    vllm_config = VllmConfig(
        model_config=ModelConfig(
            model="facebook/opt-125m",
            task="generate",
            tokenizer="facebook/opt-125m",
            tokenizer_mode="auto",
            trust_remote_code=True,
            dtype="bfloat16",  # TPUs typically use bfloat16
            seed=42,
        ),
        cache_config=CacheConfig(
            block_size=16,
            gpu_memory_utilization=0.9,
            swap_space=0,
            cache_dtype="auto",
        ),
        scheduler_config=SchedulerConfig(
            max_num_seqs=10,
            max_num_batched_tokens=512,
            max_model_len=512,
        ),
    )
    device = "xla:0"  # Mocking TPU device
    with mock.patch("vllm.v1.worker.tpu_model_runner.torch"), \
            mock.patch("vllm.v1.worker.tpu_model_runner.xm"), \
            mock.patch("vllm.v1.worker.tpu_model_runner.xr"):
        return TPUModelRunner(vllm_config, device)
61+
62+
63+
@pytest.fixture(autouse=True, scope="session")
def cleanup_patches():
    """Undo the module-level patches once the whole test session finishes."""
    yield
    for patcher in (torch_xla_patcher, pallas_attention_backend_patcher):
        patcher.stop()
68+
69+
70+
def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
    """Build a SchedulerOutput that schedules every given req_id as a
    brand-new request with a fixed 3-token prompt."""
    num_scheduled_tokens = {req_id: 3 for req_id in req_ids}
    new_reqs = [
        NewRequestData(
            req_id=req_id,
            prompt_token_ids=[1, 2, 3],
            prompt="test",
            mm_inputs=[],
            mm_hashes=[],
            mm_positions=[],
            sampling_params=SamplingParams(),
            block_ids=[0],
            num_computed_tokens=0,
            lora_request=None,
        ) for req_id in req_ids
    ]

    return SchedulerOutput(
        scheduled_new_reqs=new_reqs,
        scheduled_cached_reqs=[],
        num_scheduled_tokens=num_scheduled_tokens,
        total_num_scheduled_tokens=sum(num_scheduled_tokens.values()),
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )
104+
105+
106+
def _is_req_scheduled(model_runner, req_id: str) -> bool:
107+
return req_id in model_runner.input_batch.req_id_to_index
108+
109+
110+
def _is_req_added(model_runner, req_id: str) -> bool:
111+
return req_id in model_runner.requests
112+
113+
114+
def _is_sampling_metadata_changed(model_runner,
                                  sampling_metadata_before: SamplingMetadata):
    """True when the input batch rebuilt its sampling metadata object
    (identity comparison, not equality)."""
    current = model_runner.input_batch.sampling_metadata
    return current is not sampling_metadata_before
118+
119+
120+
def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
121+
req_index = model_runner.input_batch.req_id_to_index[req_id]
122+
block_table = model_runner.input_batch.block_table
123+
req_state = model_runner.requests[req_id]
124+
if block_table.num_blocks_per_row[req_index] != len(req_state.block_ids):
125+
return False
126+
num_blocks = block_table.num_blocks_per_row[req_index]
127+
return (block_table.block_table_np[req_index, :num_blocks] ==
128+
req_state.block_ids).all()
129+
130+
131+
def test_update_states_new_request(model_runner):
    """A newly scheduled request must be added, batched, and must refresh
    the cached sampling metadata."""
    req_id = "req_0"

    # Schedule the request as new.
    output = _schedule_new_request(req_id)

    prev_metadata = model_runner.input_batch.sampling_metadata
    model_runner._update_states(output)

    assert _is_sampling_metadata_changed(model_runner, prev_metadata)
    assert _is_req_added(model_runner, req_id)
    assert _is_req_scheduled(model_runner, req_id)
    assert _is_req_state_block_table_match(model_runner, req_id)
144+
145+
146+
def test_update_states_request_finished(model_runner):
    """Finishing a request must evict it from both the request map and the
    input batch, and rebuild sampling metadata."""
    req_id = "req_0"

    # Schedule the request first so there is something to finish.
    model_runner._update_states(_schedule_new_request(req_id))
    assert _is_req_added(model_runner, req_id)
    assert _is_req_scheduled(model_runner, req_id)

    # An otherwise-empty scheduler step that reports the request finished.
    finish_step = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=[],
        num_scheduled_tokens={},
        total_num_scheduled_tokens=0,
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids={req_id},
        free_encoder_input_ids=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

    prev_metadata = model_runner.input_batch.sampling_metadata
    model_runner._update_states(finish_step)
    assert _is_sampling_metadata_changed(model_runner, prev_metadata)
    assert not _is_req_added(model_runner, req_id)
    assert not _is_req_scheduled(model_runner, req_id)
176+
177+
178+
def test_update_states_request_resumed(model_runner):
    """A request can sit out a step and later be resumed as a cached
    request; its block table must match the cached state afterwards."""
    req_id = "req_0"

    # Schedule the request as new.
    model_runner._update_states(_schedule_new_request(req_id))
    assert _is_req_added(model_runner, req_id)
    assert _is_req_scheduled(model_runner, req_id)

    # A step that schedules nothing: the request stays known but leaves
    # the input batch.
    idle_step = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=[],
        num_scheduled_tokens={},
        total_num_scheduled_tokens=0,
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

    model_runner._update_states(idle_step)
    assert _is_req_added(model_runner, req_id)
    assert not _is_req_scheduled(model_runner, req_id)

    # Resume the request as a cached request.
    cached_req_data = CachedRequestData(
        req_id=req_id,
        resumed_from_preemption=False,
        new_token_ids=[],
        new_block_ids=[],
        num_computed_tokens=0,
    )
    resume_step = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=[cached_req_data],
        num_scheduled_tokens={req_id: 1},
        total_num_scheduled_tokens=1,
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

    prev_metadata = model_runner.input_batch.sampling_metadata
    model_runner._update_states(resume_step)
    assert _is_sampling_metadata_changed(model_runner, prev_metadata)
    assert _is_req_added(model_runner, req_id)
    assert _is_req_scheduled(model_runner, req_id)
    assert _is_req_state_block_table_match(model_runner, req_id)
236+
237+
238+
def test_update_states_no_changes(model_runner):
    """Re-scheduling the same request with no new state must keep the
    cached sampling metadata object untouched."""
    req_id = "req_0"

    # Schedule the request as new.
    model_runner._update_states(_schedule_new_request(req_id))
    assert _is_req_added(model_runner, req_id)
    assert _is_req_scheduled(model_runner, req_id)

    # Schedule one more token for the same request; nothing else changes.
    repeat_step = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=[],
        num_scheduled_tokens={req_id: 1},
        total_num_scheduled_tokens=1,
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

    prev_metadata = model_runner.input_batch.sampling_metadata
    model_runner._update_states(repeat_step)
    assert not _is_sampling_metadata_changed(model_runner, prev_metadata)
    assert _is_req_added(model_runner, req_id)
    assert _is_req_scheduled(model_runner, req_id)
    assert _is_req_state_block_table_match(model_runner, req_id)
269+
270+
271+
def test_update_states_request_unscheduled(model_runner):
    """Leaving a request out of a step keeps its cached state but removes
    it from the input batch (unscheduled, not finished)."""
    req_ids = ("req_0", "req_1")

    # Schedule both requests as new.
    scheduler_output = _schedule_new_request(*req_ids)

    model_runner._update_states(scheduler_output)

    assert _is_req_added(model_runner, req_ids[0])
    assert _is_req_scheduled(model_runner, req_ids[0])

    assert _is_req_added(model_runner, req_ids[1])
    assert _is_req_scheduled(model_runner, req_ids[1])

    # Unschedule req_1: only req_0 receives tokens in this step.
    scheduler_output = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=[],
        num_scheduled_tokens={req_ids[0]: 1},
        total_num_scheduled_tokens=1,
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

    # Capture the metadata *before* the update, as the sibling tests do.
    # The original assigned `model_runner._update_states(...)` to
    # metadata_before, which compares against the method's return value and
    # made the changed-metadata assertion vacuously true.
    metadata_before = model_runner.input_batch.sampling_metadata
    model_runner._update_states(scheduler_output)
    assert _is_sampling_metadata_changed(model_runner, metadata_before)

    assert _is_req_added(model_runner, req_ids[0])
    assert _is_req_scheduled(model_runner, req_ids[0])

    assert _is_req_added(model_runner, req_ids[1])
    assert not _is_req_scheduled(model_runner, req_ids[1])

0 commit comments

Comments
 (0)