|
@@ -34,7 +34,9 @@
 
     # disable custom dispatcher, let Dynamo takes over
     # all the control
-    llm = LLM(model="google/gemma-2b",
+    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
+              max_model_len=512,
+              max_num_seqs=64,
               enforce_eager=True,
               compilation_config={"level": CompilationLevel.DYNAMO_AS_IS})
     outputs = llm.generate(prompts, sampling_params)
@@ -44,38 +46,51 @@
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
         assert generated_text.startswith(answer)
 
-compiled_code = sorted(
+compiled_codes = sorted(
     glob.glob(os.path.join(temp_dir, "__transformed_code*.py")))
 
-# we should only trigger Dynamo compilation three times:
-# one for the profiling phase without kv cache
-# one for the prefill phase with symbolic shapes
-# one for the decode phase with symbolic shapes
+for i, compiled_code in enumerate(compiled_codes):
+    print("{} file: {}".format(i + 1, compiled_code))
+
+# We should only trigger Dynamo compilation 4 times:
+# 1. forward pass (symbolic)
+# 2. compute_logits (symbolic)
+# 3. forward pass (shape 16)
+# 4. forward pass (shape 32)
 # and later calls should not trigger Dynamo compilation again.
-# NOTE: it might still trigger XLA compilation.
+# NOTE: It might still trigger XLA compilation.
+
+# Check we have 4 compiled codes
+assert len(compiled_codes) == 4
 
-# check we have three compiled code
-# this is the assumption when we use the custom dispatcher
-assert len(compiled_code) == 3
+kv_cache_prefix = "kv_cache"
+attn_prefix = "ragged_paged_attention"
 
-# check all the compilations are as expected
-compiled_fn = sorted(
+# Check all the compilations are as expected
+compiled_fns = sorted(
     glob.glob(os.path.join(temp_dir, "__compiled_fn*Captured*.py")))
 
-# the first compilation is the profiling phase,
-# it should not have any kv cache
-with open(compiled_fn[0]) as f:
+for i, compiled_fn in enumerate(compiled_fns):
+    print("{} file: {}".format(i + 1, compiled_fn))
+
+# The first compilation is symbolic, so it should not have any kv_caches
+with open(compiled_fns[0]) as f:
+    content = f.read()
+    assert kv_cache_prefix not in content
+
+# The second compilation is symbolic, so it should not have any kv_caches
+with open(compiled_fns[1]) as f:
     content = f.read()
-    assert "kv_caches" not in content
+    assert kv_cache_prefix not in content
 
-# the second compilation is the prefill phase,
-# it should have kv cache and the flash_attention op
-with open(compiled_fn[1]) as f:
+# The third compilation is shape 16, so it should have kv_caches and the
+# ragged_paged_attention
+with open(compiled_fns[2]) as f:
     content = f.read()
-    assert "kv_caches" in content and "torch.ops.xla.flash_attention" in content
+    assert (kv_cache_prefix in content and attn_prefix in content)
 
-# the third compilation is the decode phase,
-# it should have kv cache and the paged_attention op
-with open(compiled_fn[2]) as f:
+# The fourth compilation is shape 32, so it should have kv_caches and the
+# ragged_paged_attention
+with open(compiled_fns[3]) as f:
     content = f.read()
-    assert "kv_caches" in content and "torch.ops.xla.paged_attention" in content
+    assert (kv_cache_prefix in content and attn_prefix in content)
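For context (not part of the diff above, just a sketch): the __transformed_code*.py and __compiled_fn*Captured*.py files that the test globs for are the source dumps that depyf writes when the run is wrapped in depyf.prepare_debug. Assuming that is the setup used here (the wrapping code sits outside this hunk), the surrounding scaffolding looks roughly like this:

# Sketch only -- assumes the generation run is wrapped in depyf.prepare_debug,
# which is what produces the __transformed_code*.py and
# __compiled_fn*Captured*.py dumps asserted on in the diff.
import glob
import os
import tempfile

import depyf

temp_dir = tempfile.mkdtemp()
with depyf.prepare_debug(temp_dir):
    # build the LLM and run llm.generate(prompts, sampling_params)
    # exactly as shown in the first hunk above
    ...

# After leaving the context manager, the dumped sources can be inspected.
compiled_codes = sorted(
    glob.glob(os.path.join(temp_dir, "__transformed_code*.py")))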