Add speculative decoding #83
Conversation
🚀
```python
    chunk_size: int = 512,
):
    """
    Fill a KV cache for a specific model

    Args:
        model: The model to use for cache filling
        cache: The cache to fill
        tokens: Tokens to process
        progress_callback: Callback for reporting progress
        start_progress: Starting progress percentage
        end_progress: Ending progress percentage
    """
    chunk_size = 512  # Default chunk size
```
The `chunk_size` keyword argument is immediately overwritten here, so any value the caller passes in is ignored.
Good catch
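For illustration, a minimal sketch of the kind of fix implied by this thread; the function name `fill_cache`, the parameter types, and the loop body below are assumptions rather than the actual mlx-engine code. The point is simply that the keyword default should be left to do its job instead of being re-assigned inside the body:

```python
from typing import Any, Callable, List, Optional


def fill_cache(
    model: Any,
    cache: Any,
    tokens: List[int],
    progress_callback: Optional[Callable[[float], None]] = None,
    start_progress: float = 0.0,
    end_progress: float = 100.0,
    chunk_size: int = 512,
) -> None:
    """Fill a KV cache for a specific model, processing `tokens` in chunks."""
    # The keyword default already supplies 512; re-assigning `chunk_size = 512`
    # here would silently discard any caller-provided value.
    total = max(len(tokens), 1)
    for start in range(0, len(tokens), chunk_size):
        chunk = tokens[start : start + chunk_size]
        model(chunk, cache=cache)  # hypothetical forward pass that extends the cache
        if progress_callback is not None:
            done = min(start + chunk_size, len(tokens)) / total
            progress_callback(start_progress + done * (end_progress - start_progress))
```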
I tried to use this with DeepSeek-R1 distills, with the following code block:

```shell
LMS_MODELS_DIR=${HOME}/.lmstudio/models
# NOTE: cannot use gguf models because they lack required config.json
DRAFT_MODEL=mlx-community/DeepSeek-R1-Distill-Llama-8B-4bit
TARGET_MODEL=mlx-community/DeepSeek-R1-Distill-Llama-70B-8bit
DRAFT_TOKENS=10
python demo.py \
    --model ${LMS_MODELS_DIR}/${TARGET_MODEL%@*} \
    --draft-model ${LMS_MODELS_DIR}/${DRAFT_MODEL%@*} \
    --num-draft-tokens ${DRAFT_TOKENS} \
    --prompt "<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
Write a shell script that reverses the order of string that start with the letter Y<|im_end|>
<|im_start|>assistant
"
```

My results show that when execution succeeds (sometimes it just terminates mid-answer with a

Test system is a Mac Studio M2 with 192GB of RAM (loading both of these models uses ~90GB, or 45% of the available memory).
@deftdawg thanks for posting your results here. We’ve noticed that speculative decoding on M1/M2 chips is not as performant as expected. Feel free to check out ml-explore/mlx-examples#1281 and the fantastic work at mlx-lm as we seek to improve this situation.
Overview
- `load_draft_model(model_kit: ModelKit | VisionModelKit, path: str | Path) -> None`
- `is_draft_model_compatible(model_kit: ModelKit | VisionModelKit, path: str | Path) -> bool`
- `unload_draft_model(model_kit: ModelKit | VisionModelKit) -> None`

Adds `num_draft_tokens` to `create_generator`, which is the number of tokens to draft when using speculative decoding. Throws if no draft model is loaded.

`demo.py`: adds a `--draft-model` argument (optional) and a `--num-draft-tokens` argument (optional); see the usage sketch below.
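A minimal usage sketch of the API above. The import path, the `load_model` and `tokenize` helpers, the model paths, and the exact `create_generator` signature and result fields are assumptions here; only the draft-model functions and the `num_draft_tokens` parameter come from this PR's description:

```python
from pathlib import Path

# Assumed import location and helper names (not confirmed by this PR).
from mlx_engine.generate import (
    create_generator,
    is_draft_model_compatible,
    load_draft_model,
    load_model,
    tokenize,
    unload_draft_model,
)

main_path = Path("~/models/Qwen2.5-7B-Instruct-4bit").expanduser()        # example path
draft_path = Path("~/models/Qwen2.5-0.5B-Instruct-MLX-8bit").expanduser() # example path

model_kit = load_model(main_path)  # assumed helper for loading the main model
prompt_tokens = tokenize(model_kit, "Explain speculative decoding in one sentence.")

# The draft model must be loaded before drafting can be used; otherwise
# create_generator throws when num_draft_tokens is set.
if is_draft_model_compatible(model_kit, draft_path):
    load_draft_model(model_kit, draft_path)

for result in create_generator(model_kit, prompt_tokens, num_draft_tokens=5):
    print(result.text, end="", flush=True)  # assumed result field

unload_draft_model(model_kit)
```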
Results
Test machine: Apple M3 Pro 36GB
Tested with `Qwen2.5-7B-Instruct-4bit` as the main model and `Qwen2.5-0.5B-Instruct-MLX-8bit` as the draft model.

Exact same output as `mlx_lm.generate` with the same model/params.

See a 1.83x tok/sec increase in this test case when speculative decoding is enabled:
Without speculative decoding: 29.21 tok/s
With speculative decoding: 53.53 tok/s