|
1 | 1 | # SPDX-License-Identifier: Apache-2.0
|
2 |
| -import tempfile |
3 |
| -from time import time |
4 | 2 |
|
5 | 3 | import pytest
|
6 | 4 |
|
|
15 | 13 | )
|
16 | 14 |
|
17 | 15 |
|
18 |
# TODO remove this test once VLLM_XLA_CHECK_RECOMPILATION does not error out
@pytest.mark.parametrize("model_name", ["D4nt3/Qwen2.5-two-layers"])
@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This test needs a TPU")
def test_sampler_compilation(model_name: str, monkeypatch) -> None:
    """
    Check that no recompilation happens despite changing sampling parameters.
    We can't read XLA metrics from the engine process, hence we measure time.

    Three generate() calls are timed: the first pays the compilation cost;
    the later two use different sampling parameters and must reuse the
    already-compiled graphs, so each must be at least 10x faster than run 1.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        # Point the XLA cache at a fresh directory so a warm cache from an
        # earlier run cannot mask a recompilation.
        monkeypatch.setenv("VLLM_XLA_CACHE_PATH", temp_dir)
        # Compiling model init may still take some time, enforce_eager to skip.
        llm = LLM(model_name,
                  enforce_eager=True,
                  max_num_seqs=16,
                  max_model_len=1024,
                  gpu_memory_utilization=0.5)
        prompts = [
            "A robot may not injure a human being",
            "It is only with the heart that one can see rightly;",
        ]
        # First inference should be slow
        sampling_params = SamplingParams(
            temperature=0.7,
            # top_p=0.6, # TODO too slow!
            top_k=10,
            min_p=0.2,
            max_tokens=16)
        s = time()
        _ = llm.generate(prompts, sampling_params)
        run1 = time() - s

        # Second request with different params, but for which we
        # compiled for in previous eager iteration.
        sampling_params = SamplingParams(temperature=0.1,
                                         top_k=12,
                                         min_p=0.8,
                                         max_tokens=24)
        s = time()
        _ = llm.generate(prompts, sampling_params)
        run2 = time() - s
        # Print BEFORE asserting (the original printed after, so the timings
        # were lost exactly when the assert failed) and repeat them in the
        # assert message so CI logs capture them too.
        print("TIMES", run1, run2)
        # Much faster after compiling
        assert run1 * 0.1 > run2, (
            f"suspected recompilation: run1={run1:.3f}s run2={run2:.3f}s")

        # Third request with min_p set to "None". It will not trigger
        # recompilation as a default 0 value will be used.
        sampling_params = SamplingParams(max_tokens=24, temperature=0.0)
        s = time()
        _ = llm.generate(prompts, sampling_params)
        run3 = time() - s
        print("TIMES", run1, run3)
        assert run1 * 0.1 > run3, (
            f"suspected recompilation: run1={run1:.3f}s run3={run3:.3f}s")
73 | 16 | @pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
|
74 | 17 | @pytest.mark.skipif(not current_platform.is_tpu(),
|
75 | 18 | reason="This test needs a TPU")
|
|
0 commit comments