[Bugfix] EAGLE output norm bug #14464
Changes from all commits: 7dee20f, 143b13a, acaf894, 5373496, 85d5a91, a16e424, 9f22b9f, d1ea6da
@@ -0,0 +1,93 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
import json
import os

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams

parser = argparse.ArgumentParser()

parser.add_argument(
    "--dataset",
    type=str,
    default="./examples/data/gsm8k.jsonl",
    help="downloaded from the eagle repo " \
    "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/"
)
parser.add_argument("--max_num_seqs", type=int, default=8)
parser.add_argument("--num_prompts", type=int, default=80)
parser.add_argument("--num_spec_tokens", type=int, default=2)
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--draft_tp", type=int, default=1)
parser.add_argument("--enforce_eager", action='store_true')
parser.add_argument("--enable_chunked_prefill", action='store_true')
parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
parser.add_argument("--temp", type=float, default=0)

args = parser.parse_args()

print(args)

model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm"

max_model_len = 2048

tokenizer = AutoTokenizer.from_pretrained(model_dir)

if os.path.exists(args.dataset):
    prompts = []
    num_prompts = args.num_prompts
    with open(args.dataset) as f:
        for line in f:
            data = json.loads(line)
            prompts.append(data["turns"][0])
else:
    prompts = ["The future of AI is", "The president of the United States is"]

prompts = prompts[:args.num_prompts]
num_prompts = len(prompts)

prompt_ids = [
    tokenizer.apply_chat_template([{
        "role": "user",
        "content": prompt
    }],
                                  add_generation_prompt=True)
    for prompt in prompts
]

llm = LLM(
    model=model_dir,
    trust_remote_code=True,
    tensor_parallel_size=args.tp,
    enable_chunked_prefill=args.enable_chunked_prefill,
    max_num_batched_tokens=args.max_num_batched_tokens,
    enforce_eager=args.enforce_eager,
    max_model_len=max_model_len,
    max_num_seqs=args.max_num_seqs,
    gpu_memory_utilization=0.8,
    speculative_model=eagle_dir,
    num_speculative_tokens=args.num_spec_tokens,
    speculative_draft_tensor_parallel_size=args.draft_tp,
    speculative_max_model_len=max_model_len,
    disable_log_stats=False,
)

sampling_params = SamplingParams(temperature=args.temp, max_tokens=256)

outputs = llm.generate(prompt_token_ids=prompt_ids,
                       sampling_params=sampling_params)

# calculate the average number of accepted tokens per forward pass, +1 is
# to account for the token from the target model that's always going to be
# accepted
acceptance_counts = [0] * (args.num_spec_tokens + 1)
for output in outputs:
    for step, count in enumerate(output.metrics.spec_token_acceptance_counts):
        acceptance_counts[step] += count

print(f"mean acceptance length: \
{sum(acceptance_counts) / acceptance_counts[0]:.2f}")
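
For intuition on the final print statement: the mean acceptance length is the total number of accepted tokens (the target-model token plus any accepted speculative tokens) divided by the number of target-model forward passes, which is exactly acceptance_counts[0]. A minimal sketch with made-up counts (not real measurements), assuming num_spec_tokens = 3:

# Index 0 counts the token the target model always contributes; indices 1..3
# count how often the 1st/2nd/3rd speculative token was accepted.
acceptance_counts = [100, 80, 45, 20]  # hypothetical tallies over 100 steps

mean_acceptance_length = sum(acceptance_counts) / acceptance_counts[0]
print(f"mean acceptance length: {mean_acceptance_length:.2f}")  # 2.45

The same [100, 80, 45, 20] shape comes up later in the review thread as a suggested example to document.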
@@ -830,6 +830,10 @@ def _create_sequence_group_with_sampling(
            self.generation_config_fields, seq.eos_token_id)

        # Create the sequence group.
+       draft_size = 1
Review thread on the draft_size = 1 line:
Reviewer: a better name?
Author: I named it draft_size because in the future there might be multi-draft & tree attention support. In those cases, …
Reviewer: I see! Then could you comment it here and explain why it's different from num_spec_token.
+       if self.vllm_config.speculative_config is not None:
+           draft_size = \
+               self.vllm_config.speculative_config.num_speculative_tokens + 1
        seq_group = SequenceGroup(
            request_id=request_id,
            seqs=[seq],

@@ -839,7 +843,8 @@ def _create_sequence_group_with_sampling(
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            encoder_seq=encoder_seq,
-           priority=priority)
+           priority=priority,
+           draft_size=draft_size)

        return seq_group
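
How draft_size is consumed is not visible in this hunk, so the sketch below is an assumption about the intent: it reserves one acceptance-counter slot per step index, slot 0 for the token the target model always produces and slots 1..k for the k speculative tokens.

# Hypothetical values mirroring the logic added above.
num_speculative_tokens = 2
draft_size = num_speculative_tokens + 1  # +1 for the target-model token

# Presumed downstream use (not shown in this diff): per-step acceptance tallies.
spec_token_acceptance_counts = [0] * draft_size
assert spec_token_acceptance_counts == [0, 0, 0]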
@@ -100,6 +100,11 @@ def process_outputs(self,
            seqs = sequence_group.get_seqs(
                status=SequenceStatus.FINISHED_ABORTED)

+       for output in outputs:
+           if output.samples[0].output_token != VLLM_INVALID_TOKEN_ID:
+               sequence_group.metrics.spec_token_acceptance_counts[
+                   output.step_index] += 1
Review thread on the acceptance-count accumulation:
Reviewer: Why do you need to add based on …
Reviewer: Also is it …
Author: Yes, I'm a bit confused regarding the grouping idea. If I understood it correctly, we only have single draft spec decoding atm, so …
Reviewer: I see, got it. Yeah then please also add [100, 80, 45, 20] example in the comments when you add comments to …

        assert seqs, "Expected RUNNING or FINISHED_ABORTED sequences"
        assert len(seqs) == 1, (
            "Beam search not supported in multi-step decoding.")
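
A toy illustration of the bookkeeping added above. Step and Sample are stand-in classes, not vLLM types, and the sentinel value of VLLM_INVALID_TOKEN_ID is assumed; the point is only that rejected speculative positions are skipped while accepted ones bump the counter for their step index.

VLLM_INVALID_TOKEN_ID = -1  # sentinel assumed for rejected positions


class Sample:

    def __init__(self, output_token):
        self.output_token = output_token


class Step:

    def __init__(self, step_index, token):
        self.step_index = step_index
        self.samples = [Sample(token)]


# One engine step: target token accepted (step 0), first speculative token
# accepted (step 1), second speculative token rejected (step 2).
outputs = [Step(0, 42), Step(1, 7), Step(2, VLLM_INVALID_TOKEN_ID)]

spec_token_acceptance_counts = [0, 0, 0]
for output in outputs:
    if output.samples[0].output_token != VLLM_INVALID_TOKEN_ID:
        spec_token_acceptance_counts[output.step_index] += 1

print(spec_token_acceptance_counts)  # [1, 1, 0]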
@@ -38,7 +38,7 @@ def forward(self, x, residual):
        if residual is None:
            return x
        else:
-           return x, residual
+           return x + residual, None
Review comment: Thanks for the catch!


class EAGLE(nn.Module):
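
Why the one-line fix matters, under one assumption about the surrounding code: the patched forward belongs to a pass-through module that replaces the decoder's final RMSNorm, and vLLM's fused RMSNorm takes (x, residual) and adds them before normalizing. Returning (x, residual) unchanged therefore drops the residual add that every caller expects to have happened; returning x + residual restores it while still skipping the normalization. A self-contained sketch (DummyOutputNorm is an assumed name for the patched module):

import torch
from torch import nn


class DummyOutputNorm(nn.Module):
    # Stand-in for the final norm that EAGLE skips: no normalization, but the
    # residual still has to be folded into the hidden states.
    def forward(self, x, residual):
        if residual is None:
            return x
        # Buggy version: `return x, residual`, which silently dropped the add.
        return x + residual, None


x = torch.randn(2, 8)
residual = torch.randn(2, 8)
hidden, leftover = DummyOutputNorm()(x, residual)
assert torch.allclose(hidden, x + residual) and leftover is None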
@@ -93,14 +93,14 @@ def create_logprobs_output(


def create_sequence_group_output(
-    token_id: int,
-    token_id_logprob_rank: int,
-    token_id_logprob: float,
-    seq_id: SeqId,
-    topk_token_ids: List[Optional[int]],
-    topk_logprobs: List[Optional[float]],
-    prompt_logprobs: Optional[PromptLogprobs] = None,
-) -> CompletionSequenceGroupOutput:
+        token_id: int,
+        token_id_logprob_rank: int,
+        token_id_logprob: float,
+        seq_id: SeqId,
+        topk_token_ids: List[Optional[int]],
+        topk_logprobs: List[Optional[float]],
+        prompt_logprobs: Optional[PromptLogprobs] = None,
+        step_index: Optional[int] = 0) -> CompletionSequenceGroupOutput:
Review thread on the step_index parameter:
Reviewer: Same question as above, why is it an optional field?
Author: Ah I made it optional to minimize changes since I don't have to modify the single step worker code to accommodate this additional field this way. What's a better way to do this?
    """Create a SequenceGroupOutput given the sampling results.

    Args:

@@ -110,6 +110,7 @@ def create_sequence_group_output(
        seq_id (int): The sequence id.
        topk_token_ids (List[Optional[int]]): The list of top-k token ids.
        topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
+       step_index: (Optional[int]): The index of the speculative token.
    """

    logprobs = create_logprobs_output(

@@ -120,14 +121,13 @@ def create_sequence_group_output(
        topk_logprobs,
    )

-   return CompletionSequenceGroupOutput(
-       samples=[
-           SequenceOutput(parent_seq_id=seq_id,
-                          output_token=token_id,
-                          logprobs=logprobs)
-       ],
-       prompt_logprobs=prompt_logprobs,
-   )
+   return CompletionSequenceGroupOutput(samples=[
+       SequenceOutput(parent_seq_id=seq_id,
+                      output_token=token_id,
+                      logprobs=logprobs)
+   ],
+                                        prompt_logprobs=prompt_logprobs,
+                                        step_index=step_index)


def split_batch_by_proposal_len(
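
To see why the default matters (per the review exchange above), here is a simplified stand-in, not the real vLLM function, keeping only the parameter under discussion: single-step callers can stay untouched because an omitted step_index yields 0, while the multi-step speculative path tags each output with its position.

from typing import Optional


def create_sequence_group_output_stub(token_id: int,
                                      step_index: Optional[int] = 0) -> dict:
    # Stand-in only; the real function also builds logprobs and returns a
    # CompletionSequenceGroupOutput.
    return {"token_id": token_id, "step_index": step_index}


print(create_sequence_group_output_stub(42))  # {'token_id': 42, 'step_index': 0}
print(create_sequence_group_output_stub(7, step_index=2))  # {'token_id': 7, 'step_index': 2}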
Reviewer: Great! Could you also add a sentence in the doc (https://github.com/vllm-project/vllm/blob/main/docs/source/features/spec_decode.md), referring to this example?