[Bugfix] EAGLE output norm bug #14464

Merged: 8 commits, Mar 15, 2025
2 changes: 1 addition & 1 deletion docs/source/features/spec_decode.md
@@ -162,7 +162,7 @@ A variety of speculative models of this type are available on HF hub:
## Speculating using EAGLE based draft models

The following code configures vLLM to use speculative decoding where proposals are generated by
an [EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency)](https://arxiv.org/pdf/2401.15077) based draft model.
an [EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency)](https://arxiv.org/pdf/2401.15077) based draft model. A more detailed example for offline mode, including how to extract the request-level acceptance rate, can be found [here](<gh-file:examples/offline_inference/eagle.py>).

```python
from vllm import LLM, SamplingParams
```
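The doc's full example is truncated by the diff view above. As a point of reference, a minimal sketch of an EAGLE speculative-decoding configuration is shown below; it mirrors the offline example added in this PR, so the model names and parameter values are illustrative rather than the exact contents of spec_decode.md.

```python
from vllm import LLM, SamplingParams

# Minimal sketch: target model plus an EAGLE draft model as the proposer.
# Model names and token counts are borrowed from the example script in this PR.
llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    speculative_model="abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm",
    num_speculative_tokens=2,
)

outputs = llm.generate(
    ["The future of AI is"],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```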
93 changes: 93 additions & 0 deletions examples/offline_inference/eagle.py
@@ -0,0 +1,93 @@
# SPDX-License-Identifier: Apache-2.0
Collaborator:
Great! Could you also add a sentence in the doc (https://github.com/vllm-project/vllm/blob/main/docs/source/features/spec_decode.md) referring to this example?

import argparse
import json
import os

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams

parser = argparse.ArgumentParser()

parser.add_argument(
"--dataset",
type=str,
default="./examples/data/gsm8k.jsonl",
help="downloaded from the eagle repo " \
"https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/"
)
parser.add_argument("--max_num_seqs", type=int, default=8)
parser.add_argument("--num_prompts", type=int, default=80)
parser.add_argument("--num_spec_tokens", type=int, default=2)
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--draft_tp", type=int, default=1)
parser.add_argument("--enforce_eager", action='store_true')
parser.add_argument("--enable_chunked_prefill", action='store_true')
parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
parser.add_argument("--temp", type=float, default=0)

args = parser.parse_args()

print(args)

model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm"

max_model_len = 2048

tokenizer = AutoTokenizer.from_pretrained(model_dir)

if os.path.exists(args.dataset):
prompts = []
num_prompts = args.num_prompts
with open(args.dataset) as f:
for line in f:
data = json.loads(line)
prompts.append(data["turns"][0])
else:
prompts = ["The future of AI is", "The president of the United States is"]

prompts = prompts[:args.num_prompts]
num_prompts = len(prompts)

prompt_ids = [
tokenizer.apply_chat_template([{
"role": "user",
"content": prompt
}],
add_generation_prompt=True)
for prompt in prompts
]

llm = LLM(
model=model_dir,
trust_remote_code=True,
tensor_parallel_size=args.tp,
enable_chunked_prefill=args.enable_chunked_prefill,
max_num_batched_tokens=args.max_num_batched_tokens,
enforce_eager=args.enforce_eager,
max_model_len=max_model_len,
max_num_seqs=args.max_num_seqs,
gpu_memory_utilization=0.8,
speculative_model=eagle_dir,
num_speculative_tokens=args.num_spec_tokens,
speculative_draft_tensor_parallel_size=args.draft_tp,
speculative_max_model_len=max_model_len,
disable_log_stats=False,
)

sampling_params = SamplingParams(temperature=args.temp, max_tokens=256)

outputs = llm.generate(prompt_token_ids=prompt_ids,
sampling_params=sampling_params)

# Calculate the average number of accepted tokens per forward pass; the +1
# accounts for the token from the target model, which is always accepted.
acceptance_counts = [0] * (args.num_spec_tokens + 1)
for output in outputs:
for step, count in enumerate(output.metrics.spec_token_acceptance_counts):
acceptance_counts[step] += count

print(f"mean acceptance length: \
{sum(acceptance_counts) / acceptance_counts[0]:.2f}")
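A natural extension of the script above is to report a per-position acceptance rate as well; this short sketch reuses the acceptance_counts list the example already builds.

```python
# Per-position breakdown: acceptance_counts[0] equals the number of forward
# passes (the target-model token is always accepted), so acceptance_counts[i]
# divided by it gives the acceptance rate of the i-th speculative token.
total_steps = acceptance_counts[0]
for pos, count in enumerate(acceptance_counts[1:], start=1):
    print(f"speculative token {pos}: acceptance rate {count / total_steps:.2%}")
```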
7 changes: 6 additions & 1 deletion vllm/engine/llm_engine.py
@@ -830,6 +830,10 @@ def _create_sequence_group_with_sampling(
self.generation_config_fields, seq.eos_token_id)

# Create the sequence group.
draft_size = 1
Collaborator:
Could we use a better name? Strictly speaking, it's not a draft size; it's the maximum number of tokens a step can generate.

Contributor Author:
I named it draft_size because in the future there might be multi-draft & tree-attention support. In those cases, draft_size will not be the max number of tokens a step can generate, but rather the number of nodes in the draft tree. If you believe there is a better name, I am more than happy to change it!

Collaborator:
I see! Then could you add a comment here explaining why it's different from num_spec_tokens?

if self.vllm_config.speculative_config is not None:
draft_size = \
self.vllm_config.speculative_config.num_speculative_tokens + 1
seq_group = SequenceGroup(
request_id=request_id,
seqs=[seq],
@@ -839,7 +843,8 @@
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq,
priority=priority)
priority=priority,
draft_size=draft_size)

return seq_group

5 changes: 5 additions & 0 deletions vllm/engine/output_processor/multi_step.py
@@ -100,6 +100,11 @@ def process_outputs(self,
seqs = sequence_group.get_seqs(
status=SequenceStatus.FINISHED_ABORTED)

for output in outputs:
if output.samples[0].output_token != VLLM_INVALID_TOKEN_ID:
sequence_group.metrics.spec_token_acceptance_counts[
output.step_index] += 1
Collaborator:
Why do you need to add based on step_index? Can you just add the number of generated tokens for this step?

Collaborator:
Also, can step_index only be 0, 1, 2, ..., num_spec_tokens? Do you want to group the number of accepted tokens based on the position at which they were proposed?

Contributor Author:
Yes, step_index can be in [0, num_spec_tokens]. This way, we can analyze the additional accepted tokens for each additional speculation step. For example, for a finished request we might observe node_acceptance_counts of [100, 80, 45, 20]. With this granularity, we can tell there were 100 forward passes, and that the third speculated token's acceptance rate is only 20%, which may not be worth the verification overhead.

I'm a bit confused about the grouping idea. If I understood correctly, we only have single-draft speculative decoding at the moment, so step_index is essentially the position (the n-th speculated token)?

Collaborator:
I see, got it. Then please also add the [100, 80, 45, 20] example in the comments when you document node_acceptance_counts.


assert seqs, "Expected RUNNING or FINISHED_ABORTED sequences"
assert len(seqs) == 1, (
"Beam search not supported in multi-step decoding.")
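The hunk above counts a position as accepted whenever its sampled token is not VLLM_INVALID_TOKEN_ID. A toy sketch of that bookkeeping, using stand-in values for the real output objects and sentinel, illustrates how the per-step counters accumulate.

```python
# Toy illustration of the accumulation in process_outputs(): rejected
# speculative positions carry an invalid token id, so only accepted steps
# bump their slot in spec_token_acceptance_counts. The sentinel value here
# is a stand-in for vLLM's real VLLM_INVALID_TOKEN_ID constant.
VLLM_INVALID_TOKEN_ID = -1

num_spec_tokens = 3
spec_token_acceptance_counts = [0] * (num_spec_tokens + 1)

# One verification step for one request: the target token plus three draft
# tokens, of which only the first draft token was accepted.
sampled_token_per_step = [1234, 5678, VLLM_INVALID_TOKEN_ID, VLLM_INVALID_TOKEN_ID]
for step_index, token_id in enumerate(sampled_token_per_step):
    if token_id != VLLM_INVALID_TOKEN_ID:
        spec_token_acceptance_counts[step_index] += 1

print(spec_token_acceptance_counts)  # [1, 1, 0, 0]
```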
2 changes: 1 addition & 1 deletion vllm/model_executor/models/eagle.py
@@ -38,7 +38,7 @@ def forward(self, x, residual):
if residual is None:
return x
else:
return x, residual
return x + residual, None
Collaborator:
Thanks for the catch!



class EAGLE(nn.Module):
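For context, the method above belongs to the small identity-norm module that the EAGLE draft head uses in place of the target model's final norm. A minimal sketch of the fixed behaviour, reconstructed from the hunk (the class name and type hints are assumptions), is:

```python
from typing import Optional

import torch
from torch import nn


class DummyOutputNorm(nn.Module):
    """Identity stand-in for the final norm in the EAGLE draft head.

    Before this fix the module returned (x, residual) unchanged, dropping the
    residual add that callers expect a fused norm to have performed; the fix
    folds the residual into x and returns None so it is not added twice.
    """

    def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor]):
        if residual is None:
            return x
        return x + residual, None
```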
44 changes: 29 additions & 15 deletions vllm/sequence.py
@@ -111,6 +111,13 @@ class RequestMetrics:
model_execute_time: The time spent in the model execute function. This
will include model forward, block/sync across
workers, cpu-gpu sync time and sampling time.
spec_token_acceptance_counts: number of accepted speculative tokens at
each position; the first token is from
the target model and is always accepted;
e.g., when it's [10, 8, 4, 2] for a req,
it means there were 10 forward passes in
total, and there were 8, 4, 2 accepted
tokens at 1st, 2nd, 3rd speculation step.
"""
arrival_time: float
last_token_time: float
Expand All @@ -121,6 +128,7 @@ class RequestMetrics:
scheduler_time: Optional[float] = None
model_forward_time: Optional[float] = None
model_execute_time: Optional[float] = None
spec_token_acceptance_counts: Optional[list[int]] = None


class SequenceDataDelta(
@@ -639,22 +647,25 @@ class SequenceGroup:
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request.
priority: User-defined priority of the request.
draft_size: The number of speculative tokens plus one from the target
model; equal to max number of tokens a step can generate
for single-draft speculative decoding but larger than
that for multi-draft SD (currently not supported).
"""

def __init__(
self,
request_id: str,
seqs: list[Sequence],
arrival_time: float,
sampling_params: Optional[SamplingParams] = None,
lora_request: Optional[LoRARequest] = None,
pooling_params: Optional[PoolingParams] = None,
pooled_data: Optional[torch.Tensor] = None,
encoder_seq: Optional[Sequence] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
def __init__(self,
request_id: str,
seqs: list[Sequence],
arrival_time: float,
sampling_params: Optional[SamplingParams] = None,
lora_request: Optional[LoRARequest] = None,
pooling_params: Optional[PoolingParams] = None,
pooled_data: Optional[torch.Tensor] = None,
encoder_seq: Optional[Sequence] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
draft_size: int = 1) -> None:
self.request_id = request_id
self.seqs = seqs
self.first_seq = seqs[0]
@@ -667,7 +678,9 @@ def __init__(
last_token_time=arrival_time,
first_scheduled_time=None,
first_token_time=None,
time_in_queue=None)
time_in_queue=None,
spec_token_acceptance_counts=[0] *
draft_size)
self.last_token_latency = 0.0
self.lora_request = lora_request
self.prompt_logprobs: Optional[PromptLogprobs] = None
@@ -1079,6 +1092,7 @@ class CompletionSequenceGroupOutput(
samples: list[SequenceOutput]
# Prompt logprob for each prompt query token.
prompt_logprobs: Optional[PromptLogprobs]
step_index: Optional[int] = 0

def __repr__(self) -> str:
return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
2 changes: 1 addition & 1 deletion vllm/spec_decode/spec_decode_worker.py
@@ -1080,7 +1080,7 @@ def _create_output_sampler_list(
[sequence_index][:num_logprobs],
topk_logprobs=topk_logprobs_by_step[step_index]
[sequence_index][:num_logprobs],
))
step_index=step_index))
sampler_output_list.append(
SamplerOutput(outputs=step_output_token_ids))

32 changes: 16 additions & 16 deletions vllm/spec_decode/util.py
@@ -93,14 +93,14 @@ def create_logprobs_output(


def create_sequence_group_output(
token_id: int,
token_id_logprob_rank: int,
token_id_logprob: float,
seq_id: SeqId,
topk_token_ids: List[Optional[int]],
topk_logprobs: List[Optional[float]],
prompt_logprobs: Optional[PromptLogprobs] = None,
) -> CompletionSequenceGroupOutput:
token_id: int,
token_id_logprob_rank: int,
token_id_logprob: float,
seq_id: SeqId,
topk_token_ids: List[Optional[int]],
topk_logprobs: List[Optional[float]],
prompt_logprobs: Optional[PromptLogprobs] = None,
step_index: Optional[int] = 0) -> CompletionSequenceGroupOutput:
Collaborator:
Same question as above: why is it an optional field?

Contributor Author:
Ah, I made it optional to minimize changes: this way I don't have to modify the single-step worker code to accommodate the additional field. What's a better way to do this?

"""Create a SequenceGroupOutput given the sampling results.

Args:
@@ -110,6 +110,7 @@ def create_sequence_group_output(
seq_id (int): The sequence id.
topk_token_ids (List[Optional[int]]): The list of top-k token ids.
topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
step_index (Optional[int]): The index of the speculative token.
"""

logprobs = create_logprobs_output(
@@ -120,14 +121,13 @@
topk_logprobs,
)

return CompletionSequenceGroupOutput(
samples=[
SequenceOutput(parent_seq_id=seq_id,
output_token=token_id,
logprobs=logprobs)
],
prompt_logprobs=prompt_logprobs,
)
return CompletionSequenceGroupOutput(samples=[
SequenceOutput(parent_seq_id=seq_id,
output_token=token_id,
logprobs=logprobs)
],
prompt_logprobs=prompt_logprobs,
step_index=step_index)


def split_batch_by_proposal_len(
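To make the backward-compatibility point above concrete, a hypothetical call site might look like the following; the argument values are placeholders, and only step_index is new.

```python
from vllm.spec_decode.util import create_sequence_group_output

# Hypothetical call: existing single-step callers simply omit step_index and
# keep the old behaviour (it defaults to 0); the speculative-decode worker
# now passes the index of the proposal step explicitly.
output = create_sequence_group_output(
    token_id=42,
    token_id_logprob_rank=1,
    token_id_logprob=-0.05,
    seq_id=0,
    topk_token_ids=[],
    topk_logprobs=[],
    step_index=2,  # third position of this verification pass
)
```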