Commit 9ed6ee9

[Bugfix] EAGLE output norm bug (vllm-project#14464)

Authored by Bryan Lu on Mar 15, 2025
Signed-off-by: Bryan Lu <[email protected]>
1 parent: ee3778d

File tree: 8 files changed, +152 −35 lines

‎docs/source/features/spec_decode.md

+1 −1

@@ -162,7 +162,7 @@ A variety of speculative models of this type are available on HF hub:
 ## Speculating using EAGLE based draft models
 
 The following code configures vLLM to use speculative decoding where proposals are generated by
-an [EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency)](https://arxiv.org/pdf/2401.15077) based draft model.
+an [EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency)](https://arxiv.org/pdf/2401.15077) based draft model. A more detailed example for offline mode, including how to extract request level acceptance rate, can be found [here](<gh-file:examples/offline_inference/eagle.py>).
 
 ```python
 from vllm import LLM, SamplingParams
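
For context, the configuration this documentation section goes on to show boils down to passing an EAGLE checkpoint as the speculative model. A minimal sketch, assuming the same model names and `speculative_model`/`num_speculative_tokens` arguments used by the example script added in this commit:

```python
from vllm import LLM, SamplingParams

# Minimal EAGLE speculative-decoding setup, mirroring the arguments used in
# examples/offline_inference/eagle.py from this commit.
llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    speculative_model="abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm",
    num_speculative_tokens=2,
)

outputs = llm.generate(["The future of AI is"],
                       SamplingParams(temperature=0, max_tokens=64))
print(outputs[0].outputs[0].text)
```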

‎examples/offline_inference/eagle.py

+93 (new file)

@@ -0,0 +1,93 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
import json
import os

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams

parser = argparse.ArgumentParser()

parser.add_argument(
    "--dataset",
    type=str,
    default="./examples/data/gsm8k.jsonl",
    help="downloaded from the eagle repo " \
    "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/"
)
parser.add_argument("--max_num_seqs", type=int, default=8)
parser.add_argument("--num_prompts", type=int, default=80)
parser.add_argument("--num_spec_tokens", type=int, default=2)
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--draft_tp", type=int, default=1)
parser.add_argument("--enforce_eager", action='store_true')
parser.add_argument("--enable_chunked_prefill", action='store_true')
parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
parser.add_argument("--temp", type=float, default=0)

args = parser.parse_args()

print(args)

model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm"

max_model_len = 2048

tokenizer = AutoTokenizer.from_pretrained(model_dir)

if os.path.exists(args.dataset):
    prompts = []
    num_prompts = args.num_prompts
    with open(args.dataset) as f:
        for line in f:
            data = json.loads(line)
            prompts.append(data["turns"][0])
else:
    prompts = ["The future of AI is", "The president of the United States is"]

prompts = prompts[:args.num_prompts]
num_prompts = len(prompts)

prompt_ids = [
    tokenizer.apply_chat_template([{
        "role": "user",
        "content": prompt
    }],
                                  add_generation_prompt=True)
    for prompt in prompts
]

llm = LLM(
    model=model_dir,
    trust_remote_code=True,
    tensor_parallel_size=args.tp,
    enable_chunked_prefill=args.enable_chunked_prefill,
    max_num_batched_tokens=args.max_num_batched_tokens,
    enforce_eager=args.enforce_eager,
    max_model_len=max_model_len,
    max_num_seqs=args.max_num_seqs,
    gpu_memory_utilization=0.8,
    speculative_model=eagle_dir,
    num_speculative_tokens=args.num_spec_tokens,
    speculative_draft_tensor_parallel_size=args.draft_tp,
    speculative_max_model_len=max_model_len,
    disable_log_stats=False,
)

sampling_params = SamplingParams(temperature=args.temp, max_tokens=256)

outputs = llm.generate(prompt_token_ids=prompt_ids,
                       sampling_params=sampling_params)

# calculate the average number of accepted tokens per forward pass, +1 is
# to account for the token from the target model that's always going to be
# accepted
acceptance_counts = [0] * (args.num_spec_tokens + 1)
for output in outputs:
    for step, count in enumerate(output.metrics.spec_token_acceptance_counts):
        acceptance_counts[step] += count

print(f"mean acceptance length: \
{sum(acceptance_counts) / acceptance_counts[0]:.2f}")
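
For reference, once the GSM8K jsonl from the EAGLE repo has been downloaded to the default `--dataset` path (the script otherwise falls back to two built-in prompts), the new example can be run from the repository root with an invocation along these lines (a suggested command, not part of the commit):

    python examples/offline_inference/eagle.py --num_spec_tokens 2 --num_prompts 80 --temp 0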

‎vllm/engine/llm_engine.py

+6 −1

@@ -853,6 +853,10 @@ def _create_sequence_group_with_sampling(
             self.generation_config_fields, seq.eos_token_id)
 
         # Create the sequence group.
+        draft_size = 1
+        if self.vllm_config.speculative_config is not None:
+            draft_size = \
+                self.vllm_config.speculative_config.num_speculative_tokens + 1
         seq_group = SequenceGroup(
             request_id=request_id,
             seqs=[seq],
@@ -862,7 +866,8 @@ def _create_sequence_group_with_sampling(
             trace_headers=trace_headers,
             prompt_adapter_request=prompt_adapter_request,
             encoder_seq=encoder_seq,
-            priority=priority)
+            priority=priority,
+            draft_size=draft_size)
 
         return seq_group
 

‎vllm/engine/output_processor/multi_step.py

+5

@@ -100,6 +100,11 @@ def process_outputs(self,
             seqs = sequence_group.get_seqs(
                 status=SequenceStatus.FINISHED_ABORTED)
 
+        for output in outputs:
+            if output.samples[0].output_token != VLLM_INVALID_TOKEN_ID:
+                sequence_group.metrics.spec_token_acceptance_counts[
+                    output.step_index] += 1
+
         assert seqs, "Expected RUNNING or FINISHED_ABORTED sequences"
         assert len(seqs) == 1, (
             "Beam search not supported in multi-step decoding.")

‎vllm/model_executor/models/eagle.py

+1 −1

@@ -38,7 +38,7 @@ def forward(self, x, residual):
         if residual is None:
             return x
         else:
-            return x, residual
+            return x + residual, None
 
 
 class EAGLE(nn.Module):
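
This one-line change is the output-norm fix named in the commit title. The module being edited appears to be the pass-through that the EAGLE wrapper installs in place of the decoder's final norm; a norm layer called as `hidden_states, _ = norm(hidden_states, residual)` is expected to fold the residual into its first return value, but the old pass-through returned the pair unmerged, so the residual stream was silently dropped. A minimal sketch of the fixed module (the class name and surrounding scaffolding are reconstructed, not shown in the diff):

```python
from typing import Optional

import torch
from torch import nn


class DummyOutputNorm(nn.Module):
    """Pass-through used where the target decoder would apply its final norm."""

    def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor]):
        if residual is None:
            return x
        # Before this commit the else-branch returned (x, residual); a caller
        # doing `hidden_states, _ = norm(hidden_states, residual)` then lost
        # the residual contribution. Folding it back in fixes the outputs.
        return x + residual, None
```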

‎vllm/sequence.py

+29 −15

@@ -111,6 +111,13 @@ class RequestMetrics:
         model_execute_time: The time spent in the model execute function. This
                             will include model forward, block/sync across
                             workers, cpu-gpu sync time and sampling time.
+        spec_token_acceptance_counts: number of accepted speculative tokens at
+                                      each position; the first token is from
+                                      the target model and is always accepted;
+                                      e.g., when it's [10, 8, 4, 2] for a req,
+                                      it means there were 10 forward passes in
+                                      total, and there were 8, 4, 2 accepted
+                                      tokens at 1st, 2nd, 3rd speculation step.
     """
     arrival_time: float
     last_token_time: float
@@ -121,6 +128,7 @@ class RequestMetrics:
     scheduler_time: Optional[float] = None
     model_forward_time: Optional[float] = None
     model_execute_time: Optional[float] = None
+    spec_token_acceptance_counts: Optional[list[int]] = None
 
 
 class SequenceDataDelta(
@@ -639,22 +647,25 @@ class SequenceGroup:
         trace_headers: OpenTelemetry trace headers.
         prompt_adapter_request: Prompt Adapter request.
         priority: User-defined priority of the request.
+        draft_size: The number of speculative tokens plus one from the target
+                    model; equal to max number of tokens a step can generate
+                    for single-draft speculative decoding but larger than
+                    that for multi-draft SD (currently not supported).
     """
 
-    def __init__(
-        self,
-        request_id: str,
-        seqs: list[Sequence],
-        arrival_time: float,
-        sampling_params: Optional[SamplingParams] = None,
-        lora_request: Optional[LoRARequest] = None,
-        pooling_params: Optional[PoolingParams] = None,
-        pooled_data: Optional[torch.Tensor] = None,
-        encoder_seq: Optional[Sequence] = None,
-        trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-        priority: int = 0,
-    ) -> None:
+    def __init__(self,
+                 request_id: str,
+                 seqs: list[Sequence],
+                 arrival_time: float,
+                 sampling_params: Optional[SamplingParams] = None,
+                 lora_request: Optional[LoRARequest] = None,
+                 pooling_params: Optional[PoolingParams] = None,
+                 pooled_data: Optional[torch.Tensor] = None,
+                 encoder_seq: Optional[Sequence] = None,
+                 trace_headers: Optional[Mapping[str, str]] = None,
+                 prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+                 priority: int = 0,
+                 draft_size: int = 1) -> None:
         self.request_id = request_id
         self.seqs = seqs
         self.first_seq = seqs[0]
@@ -667,7 +678,9 @@ def __init__(
                                       last_token_time=arrival_time,
                                       first_scheduled_time=None,
                                       first_token_time=None,
-                                      time_in_queue=None)
+                                      time_in_queue=None,
+                                      spec_token_acceptance_counts=[0] *
+                                      draft_size)
         self.last_token_latency = 0.0
         self.lora_request = lora_request
         self.prompt_logprobs: Optional[PromptLogprobs] = None
@@ -1079,6 +1092,7 @@ class CompletionSequenceGroupOutput(
     samples: list[SequenceOutput]
     # Prompt logprob for each prompt query token.
    prompt_logprobs: Optional[PromptLogprobs]
+    step_index: Optional[int] = 0
 
     def __repr__(self) -> str:
         return (f"CompletionSequenceGroupOutput(samples={self.samples}, "

‎vllm/spec_decode/spec_decode_worker.py

+1 −1

@@ -1080,7 +1080,7 @@ def _create_output_sampler_list(
                         [sequence_index][:num_logprobs],
                         topk_logprobs=topk_logprobs_by_step[step_index]
                         [sequence_index][:num_logprobs],
-                    ))
+                        step_index=step_index))
             sampler_output_list.append(
                 SamplerOutput(outputs=step_output_token_ids))
 
‎vllm/spec_decode/util.py

+16 −16

@@ -93,14 +93,14 @@ def create_logprobs_output(
 
 
 def create_sequence_group_output(
-    token_id: int,
-    token_id_logprob_rank: int,
-    token_id_logprob: float,
-    seq_id: SeqId,
-    topk_token_ids: List[Optional[int]],
-    topk_logprobs: List[Optional[float]],
-    prompt_logprobs: Optional[PromptLogprobs] = None,
-) -> CompletionSequenceGroupOutput:
+        token_id: int,
+        token_id_logprob_rank: int,
+        token_id_logprob: float,
+        seq_id: SeqId,
+        topk_token_ids: List[Optional[int]],
+        topk_logprobs: List[Optional[float]],
+        prompt_logprobs: Optional[PromptLogprobs] = None,
+        step_index: Optional[int] = 0) -> CompletionSequenceGroupOutput:
     """Create a SequenceGroupOutput given the sampling results.
 
     Args:
@@ -110,6 +110,7 @@ def create_sequence_group_output(
         seq_id (int): The sequence id.
         topk_token_ids (List[Optional[int]]): The list of top-k token ids.
         topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
+        step_index: (Optional[int]): The index of the speculative token.
     """
 
     logprobs = create_logprobs_output(
@@ -120,14 +121,13 @@ def create_sequence_group_output(
         topk_logprobs,
     )
 
-    return CompletionSequenceGroupOutput(
-        samples=[
-            SequenceOutput(parent_seq_id=seq_id,
-                           output_token=token_id,
-                           logprobs=logprobs)
-        ],
-        prompt_logprobs=prompt_logprobs,
-    )
+    return CompletionSequenceGroupOutput(samples=[
+        SequenceOutput(parent_seq_id=seq_id,
+                       output_token=token_id,
+                       logprobs=logprobs)
+    ],
+                                         prompt_logprobs=prompt_logprobs,
+                                         step_index=step_index)
 
 
 def split_batch_by_proposal_len(
