Add speculative decoding #83

Merged
mattjcly merged 9 commits into main on Jan 23, 2025

Conversation

@mattjcly (Member) commented on Jan 23, 2025

Overview

  • Adds the following methods to mlx_engine.generate's API (see the usage sketch after this list):
    • load_draft_model(model_kit: ModelKit | VisionModelKit, path: str | Path) -> None
    • is_draft_model_compatible(model_kit: ModelKit | VisionModelKit, path: str | Path) -> bool
    • unload_draft_model(model_kit: ModelKit | VisionModelKit) -> None
  • These methods allow a speculative decoding draft model to be loaded/unloaded at any point during the lifetime of a ModelKit
  • Adds an optional generation parameter num_draft_tokens to create_generator, which is the number of tokens to draft when using speculative decoding. Throws if no draft model is loaded.
  • Adds unit tests (and confirmed that all prior unit tests still pass)
  • Adds the capability to run a speculative decoding example in demo.py
  • Adds generation statistics info to demo.py
  • Adds prompt processing callback printouts to demo.py (optional)
  • Adds a temp argument to demo.py (optional)
  • Adds simple but helpful log methods
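
For reference, a minimal usage sketch of the new API. This is illustrative only: the load_model and tokenize helpers, the generation result field, and the model paths are assumptions based on how demo.py drives the engine, not part of this diff.

from pathlib import Path

from mlx_engine.generate import (
    load_model,                  # assumed loader used by demo.py
    tokenize,                    # assumed tokenization helper
    load_draft_model,
    is_draft_model_compatible,
    unload_draft_model,
    create_generator,
)

MODEL_PATH = Path("models/Qwen2.5-7B-Instruct-4bit")              # hypothetical path
DRAFT_MODEL_PATH = Path("models/Qwen2.5-0.5B-Instruct-MLX-8bit")  # hypothetical path

model_kit = load_model(MODEL_PATH)

# Attach the draft model only if it is compatible with the main model
if is_draft_model_compatible(model_kit, DRAFT_MODEL_PATH):
    load_draft_model(model_kit, DRAFT_MODEL_PATH)

prompt_tokens = tokenize(model_kit, "Write a quick sort in C++")

# num_draft_tokens is the number of tokens the draft model proposes per step;
# create_generator throws if it is set while no draft model is loaded
for result in create_generator(model_kit, prompt_tokens, num_draft_tokens=4):
    print(result.text, end="", flush=True)  # result field name is an assumption

# The draft model can be unloaded at any point during the ModelKit's lifetime
unload_draft_model(model_kit)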

Results

Test machine: Apple M3 Pro 36GB

Tested with Qwen2.5-7B-Instruct-4bit as main model, and Qwen2.5-0.5B-Instruct-MLX-8bit as draft model

Exact same output as mlx_lm.generate with the same model/params.

A 1.83x tok/sec increase (53.53 vs 29.21 tok/s) is seen in this test case when speculative decoding is enabled.

Without speculative decoding: 29.21 tok/s

-> % python demo.py --model /Users/matt/.cache/lm-studio/models/mlx-community/Qwen2.5-7B-Instruct-4bit --prompt "<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
Write a quick sort in C++<|im_end|>
<|im_start|>assistant
" --temp 0.0
[ModelKit][INFO] Loading model from /Users/matt/.cache/lm-studio/models/mlx-community/Qwen2.5-7B-Instruct-4bit...
[ModelKit][INFO] Model loaded successfully
Certainly! Below is a simple implementation of the Quick Sort algorithm in C++:
...

Generation stats:
 - Time to first token: 0.25s
 - Total tokens generated: 594
 - Total time: 20.34s
 - Tokens per second: 29.21

With speculative decoding: 53.53 tok/s

-> % python demo.py --model /Users/matt/.cache/lm-studio/models/mlx-community/Qwen2.5-7B-Instruct-4bit --draft-model /Users/matt/.cache/lm-studio/models/lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit --prompt "<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
Write a quick sort in C++<|im_end|>
<|im_start|>assistant
" --temp 0.0
[ModelKit][INFO] Loading model from /Users/matt/.cache/lm-studio/models/mlx-community/Qwen2.5-7B-Instruct-4bit...
[ModelKit][INFO] Model loaded successfully
[ModelKit][INFO] Loading draft model from /Users/matt/.cache/lm-studio/models/lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit...
[ModelKit][INFO] Draft model loaded
Certainly! Below is a simple implementation of the Quick Sort algorithm in C++:
...

Generation stats:
 - Time to first token: 0.28s
 - Total tokens generated: 595
 - Total time: 11.12s
 - Tokens per second: 53.53

@mattjcly requested review from neilmehta24 and yagil on January 23, 2025 at 20:24

@yagil (Member) left a comment:

🚀

@mattjcly merged commit cb1b880 into main on Jan 23, 2025
@mattjcly deleted the matt/speculative-decoding branch on January 23, 2025 at 23:40
Comment on lines +132 to +145
chunk_size: int = 512,
):
"""
Fill a KV cache for a specific model

Args:
model: The model to use for cache filling
cache: The cache to fill
tokens: Tokens to process
progress_callback: Callback for reporting progress
start_progress: Starting progress percentage
end_progress: Ending progress percentage
"""
chunk_size = 512 # Default chunk size

Member:
Chunk size is immediately overwritten

mattjcly (Member Author):
Good catch
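
For context, a hedged sketch of what the corrected chunked prefill could look like once the redundant assignment is removed. The forward-call shape, cache evaluation, and progress math here are assumptions for illustration, not the actual repo code:

import mlx.core as mx

def fill_kv_cache(
    model,
    cache,
    tokens,
    progress_callback=None,
    start_progress: float = 0.0,
    end_progress: float = 100.0,
    chunk_size: int = 512,
):
    """Fill a KV cache for a specific model, processing the prompt in chunks."""
    num_tokens = len(tokens)
    for start in range(0, num_tokens, chunk_size):  # respect the chunk_size parameter
        chunk = tokens[start : start + chunk_size]
        model(mx.array(chunk)[None], cache=cache)   # assumed forward-call shape
        mx.eval([c.state for c in cache])           # assumed per-layer cache API
        if progress_callback is not None:
            done = min(start + chunk_size, num_tokens)
            progress = start_progress + (end_progress - start_progress) * done / num_tokens
            progress_callback(progress)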

@deftdawg commented on Feb 6, 2025

I tried to use this with DeepSeek-R1 distills, with the following code block:

LMS_MODELS_DIR=${HOME}/.lmstudio/models
# NOTE: cannot use gguf models because they lack required config.json
DRAFT_MODEL=mlx-community/DeepSeek-R1-Distill-Llama-8B-4bit
TARGET_MODEL=mlx-community/DeepSeek-R1-Distill-Llama-70B-8bit

DRAFT_TOKENS=10

python demo.py \
    --model ${LMS_MODELS_DIR}/${TARGET_MODEL%@*} \
    --draft-model ${LMS_MODELS_DIR}/${DRAFT_MODEL%@*} \
    --num-draft-tokens ${DRAFT_TOKENS} \
    --prompt "<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
Write a shell script that reverses the order of string that start with the letter Y<|im_end|>
<|im_start|>assistant
"

My results show that when execution succeeds (sometimes it just terminates mid-answer with a % character), it is slower than running without a draft model... Is this something to do with the R1 models I'm using, am I missing a parameter that would make it go faster, or is there a bug in the implementation?

num-draft-tokens      Tokens per second (t/s)
12                    6.98
10                    8.03
8                     7.96
5                     7.31
4                     5.49
3                     5.54
0                     8.32, 8.37, 8.24
Without Spec Dec      8.64, 8.68, 8.76

Test system: Mac Studio M2 with 192GB of RAM (loading both of these models uses ~90GB, or 45% of the available memory).

@mattjcly (Member Author)

@deftdawg thanks for posting your results here. We've noticed that speculative decoding on M1/M2 chips is not as performant as expected. Feel free to check out ml-explore/mlx-examples#1281 and the fantastic work at mlx-lm as we seek to improve this situation.
