Add speculative decoding (#83)
mattjcly authored Jan 23, 2025
1 parent e152e38 commit cb1b880
Showing 12 changed files with 457 additions and 56 deletions.
18 changes: 17 additions & 1 deletion README.md
@@ -54,7 +54,7 @@ pip install -U -r requirements.txt

### Text Model Demo
Download models with the `lms` CLI tool. The `lms` CLI documentation can be found here: https://lmstudio.ai/docs/cli
Run the `demo.py` script with an MLX text model:
Run the `demo.py` script with an MLX text generation model:
```bash
lms get mlx-community/Meta-Llama-3.1-8B-Instruct-4bit
python demo.py --model ~/.cache/lm-studio/models/mlx-community/Meta-Llama-3.1-8B-Instruct-4bit
@@ -84,6 +84,22 @@ Currently supported vision models include:
- [Llava-v1.6](https://model.lmstudio.ai/download/mlx-community/llava-v1.6-mistral-7b-4bit)
- `lms get mlx-community/llava-v1.6-mistral-7b-4bit`

### Speculative Decoding Demo
Run the `demo.py` script with an MLX text generation model and a compatible `--draft-model`:
```bash
lms get mlx-community/Qwen2.5-7B-Instruct-4bit
lms get lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit
python demo.py \
--model ~/.lmstudio/models/mlx-community/Qwen2.5-7B-Instruct-4bit \
--draft-model ~/.lmstudio/models/lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit \
--prompt "<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
Write a quick sort algorithm in C++<|im_end|>
<|im_start|>assistant
"
```
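
The same flow can be driven from Python rather than through `demo.py`. The sketch below mirrors the CLI demo above using the `mlx_engine` calls added in this commit; the shapes of the `load_draft_model`, `tokenize`, and `create_generator` calls follow the `demo.py` changes further down, while the model paths, the single path argument to `load_model`, and the `num_draft_tokens` value of 4 are placeholder assumptions rather than anything this commit documents.
```python
# Sketch only: Python equivalent of the CLI demo above.
# Paths, the load_model argument shape, and num_draft_tokens=4 are assumptions.
import os

from mlx_engine.generate import (
    create_generator,
    load_draft_model,
    load_model,
    tokenize,
)

main_path = os.path.expanduser(
    "~/.lmstudio/models/mlx-community/Qwen2.5-7B-Instruct-4bit"
)
draft_path = os.path.expanduser(
    "~/.lmstudio/models/lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit"
)

model_kit = load_model(main_path)  # assumed: model path as the first argument
load_draft_model(model_kit=model_kit, path=draft_path)  # as called in demo.py

prompt = (
    "<|im_start|>user\n"
    "Write a quick sort algorithm in C++<|im_end|>\n"
    "<|im_start|>assistant\n"
)
prompt_tokens = tokenize(model_kit, prompt)

for result in create_generator(
    model_kit,
    prompt_tokens,
    max_tokens=1024,
    num_draft_tokens=4,  # tokens the draft model proposes per step (arbitrary pick)
    temp=0.8,
):
    print(result.text, end="", flush=True)
    if result.stop_condition:
        break
```
As a rule of thumb, a larger `num_draft_tokens` can yield more accepted tokens per main-model pass but wastes more work when the draft diverges; the sweet spot depends on how closely the draft model tracks the main model.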

## Testing

To run tests, run the following command from the root of this repo:
78 changes: 75 additions & 3 deletions demo.py
@@ -1,14 +1,16 @@
import argparse
import base64
import time

from mlx_engine.generate import load_model, create_generator, tokenize
from mlx_engine.generate import load_model, load_draft_model, create_generator, tokenize
from mlx_engine.model_kit import VALID_KV_BITS, VALID_KV_GROUP_SIZE


DEFAULT_PROMPT = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
Explain the rules of chess in one sentence.<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
DEFAULT_TEMP = 0.8


def setup_arg_parser():
@@ -20,7 +22,7 @@ def setup_arg_parser():
"--model",
required=True,
type=str,
help="The path to the local model directory.",
help="The file system path to the model",
)
parser.add_argument(
"--prompt",
@@ -34,6 +36,12 @@ def setup_arg_parser():
nargs="+",
help="Path of the images to process",
)
parser.add_argument(
"--temp",
default=DEFAULT_TEMP,
type=float,
help="Sampling temperature",
)
parser.add_argument(
"--stop-strings",
type=str,
@@ -68,6 +76,21 @@ def setup_arg_parser():
type=int,
help="When --kv-bits is set, start quantizing the KV cache from this step onwards",
)
parser.add_argument(
"--draft-model",
type=str,
help="The file system path to the draft model for speculative decoding.",
)
parser.add_argument(
"--num-draft-tokens",
type=int,
help="Number of tokens to draft when using speculative decoding.",
)
parser.add_argument(
"--print-prompt-progress",
action="store_true",
help="Enable printed prompt processing progress callback",
)
return parser


@@ -76,13 +99,48 @@ def image_to_base64(image_path):
return base64.b64encode(image_file.read()).decode("utf-8")


class GenerationStatsCollector:
def __init__(self):
self.start_time = time.time()
self.first_token_time = None
self.total_tokens = 0

def add_tokens(self, tokens):
"""Record new tokens and their timing."""
if self.first_token_time is None:
self.first_token_time = time.time()
self.total_tokens += len(tokens)

def print_stats(self):
"""Print generation statistics."""
end_time = time.time()
total_time = end_time - self.start_time
print(f"\n\nGeneration stats:")
print(f" - Time to first token: {self.first_token_time - self.start_time:.2f}s")
print(f" - Total tokens generated: {self.total_tokens}")
print(f" - Total time: {total_time:.2f}s")
print(f" - Tokens per second: {self.total_tokens / total_time:.2f}")


if __name__ == "__main__":
# Parse arguments
parser = setup_arg_parser()
args = parser.parse_args()
if isinstance(args.images, str):
args.images = [args.images]

# Set up prompt processing callback
def prompt_progress_callback(percent):
if args.print_prompt_progress:
width = 40 # bar width
filled = int(width * percent / 100)
bar = "█" * filled + "░" * (width - filled)
print(f"\rProcessing prompt: |{bar}| ({percent:.1f}%)", end="", flush=True)
if percent >= 100:
print() # new line when done
else:
pass

# Load the model
model_path = args.model
model_kit = load_model(
@@ -94,6 +152,10 @@ def image_to_base64(image_path):
quantized_kv_start=args.quantized_kv_start,
)

# Load draft model if requested
if args.draft_model:
load_draft_model(model_kit=model_kit, path=args.draft_model)

# Tokenize the prompt
prompt = args.prompt
prompt_tokens = tokenize(model_kit, prompt)
@@ -108,6 +170,9 @@ def image_to_base64(image_path):
# Record top logprobs
logprobs_list = []

# Initialize generation stats collector
stats_collector = GenerationStatsCollector()

# Generate the response
generator = create_generator(
model_kit,
@@ -116,15 +181,22 @@ def image_to_base64(image_path):
stop_strings=args.stop_strings,
max_tokens=1024,
top_logprobs=args.top_logprobs,
prompt_progress_callback=prompt_progress_callback,
num_draft_tokens=args.num_draft_tokens,
temp=args.temp,
)
for generation_result in generator:
print(generation_result.text, end="", flush=True)
stats_collector.add_tokens(generation_result.tokens)
logprobs_list.extend(generation_result.top_logprobs)

if generation_result.stop_condition:
stats_collector.print_stats()
print(
f"\n\nStopped generation due to: {generation_result.stop_condition.stop_reason}"
f"\nStopped generation due to: {generation_result.stop_condition.stop_reason}"
)
if generation_result.stop_condition.stop_string:
print(f"Stop string: {generation_result.stop_condition.stop_string}")

if args.top_logprobs:
[print(x) for x in logprobs_list]
9 changes: 8 additions & 1 deletion mlx_engine/__init__.py
@@ -26,4 +26,11 @@ def _set_outlines_cache_dir(cache_dir: Path | str):
"""
The API for `mlx_engine` is specified in generate.py
"""
from .generate import load_model, create_generator, tokenize
from .generate import (
load_model,
load_draft_model,
is_draft_model_compatible,
unload_draft_model,
create_generator,
tokenize,
)
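
Two of the new exports, `is_draft_model_compatible` and `unload_draft_model`, are not exercised anywhere in the visible diff. A plausible usage sketch follows; the signatures used here (a `model_kit` plus a draft-model path for the compatibility check, and a `model_kit` alone for unloading) are assumptions inferred from the naming and from the `load_draft_model(model_kit=..., path=...)` call in `demo.py`, not something this commit shows.
```python
# Assumed signatures for the two exports not shown in use in this diff.
from mlx_engine import (
    is_draft_model_compatible,
    load_draft_model,
    load_model,
    unload_draft_model,
)

model_kit = load_model("/path/to/main-model")  # placeholder path
draft_path = "/path/to/draft-model"            # placeholder path

# Presumably checks that the draft model's vocabulary lines up with the main model.
if is_draft_model_compatible(model_kit, draft_path):  # assumed signature
    load_draft_model(model_kit=model_kit, path=draft_path)

# ... run generation as in the demo ...

# Detach the draft model to return to plain (non-speculative) decoding.
unload_draft_model(model_kit)  # assumed signature
```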
108 changes: 93 additions & 15 deletions mlx_engine/cache_wrapper.py
@@ -1,5 +1,6 @@
from typing import List, Optional, Any

from mlx_engine.logging import log_info, log_warn
from mlx_lm.models.cache import (
make_prompt_cache,
trim_prompt_cache,
@@ -17,7 +18,10 @@ class CacheWrapper:
"""

def __init__(
self, model: nn.Module, max_kv_size: Optional[int], verbose: bool = False
self,
model: nn.Module,
max_kv_size: Optional[int],
verbose: bool = False,
):
"""
Initialize the CacheWrapper.
@@ -30,6 +34,7 @@ def __init__(
self.tokens: Optional[mx.array] = None
self.cache: List[Any] = make_prompt_cache(model, max_kv_size)
self.model = model
self.draft_model: Optional[nn.Module] = None
self.max_kv_size = max_kv_size
self.verbose = verbose

@@ -116,6 +121,72 @@ def _get_unprocessed_tokens(
# All of the common tokens are now in the cache, so we can return the remaining tokens that still need to be processed
return prompt_tokens[common_prefix:]

def _prefill(
self,
model,
cache,
tokens,
progress_callback,
start_progress: float,
end_progress: float,
chunk_size: int = 512,
):
"""
Fill a KV cache for a specific model
Args:
model: The model to use for cache filling
cache: The cache to fill
tokens: Tokens to process
progress_callback: Callback for reporting progress
start_progress: Starting progress percentage
end_progress: Ending progress percentage
"""
chunk_size = 512 # Default chunk size
remaining_tokens = tokens
num_processed = 0
total_tokens = len(tokens)

while remaining_tokens.size > 0:
current_chunk_size = min(chunk_size, remaining_tokens.size)
current_chunk = remaining_tokens[:current_chunk_size]

model(current_chunk[None], cache=cache)
mx.eval([c.state for c in cache])

remaining_tokens = remaining_tokens[current_chunk_size:]
num_processed += current_chunk_size

# Scale progress to fit between start_progress and end_progress
progress = start_progress + (end_progress - start_progress) * (
num_processed / total_tokens
)
progress_callback(progress)
mx.metal.clear_cache()

def set_draft_model(self, draft_model: nn.Module):
if self.model is None:
raise ValueError("Cannot add a draft model to cache without a main model")
if self.max_kv_size is not None:
log_warn(prefix="CacheWrapper", message="Disabling max_kv_size when adding a draft model")
self.max_kv_size = None

# clear the current cache, append draft model cache to the end of the main model cache as per
# https://github.com/ml-explore/mlx-examples/blob/514502da22f0dc4c1ac439bdf78c07d5ec41acf7/llms/mlx_lm/utils.py#L381-L382
self.cache: List[Any] = make_prompt_cache(self.model)
if draft_model is not None:
self.cache += make_prompt_cache(draft_model)
self.draft_model = draft_model

def unset_draft_model(self):
if self.draft_model is None:
log_info(
prefix="CacheWrapper", message="No draft model to remove from cache"
)
return
self.draft_model = None
self.cache = self.cache[: len(self.model.layers)]

def update_cache(
self,
prompt_tokens: mx.array,
@@ -144,21 +215,28 @@ def update_cache(
# Prefill the cache with the non-excluded prompt tokens
prompt_progress_callback(0)
prefill_tokens = prompt_tokens[:-num_tokens_to_exclude]
num_total_prefill_tokens = len(prefill_tokens)
processed: int = 0
chunk_default_size: int = 512
with mx.stream(generation_stream):
while prefill_tokens.size > 0:
chunk_size = min(chunk_default_size, prefill_tokens.size)
chunk = prefill_tokens[:chunk_size]

self.model(chunk[None], cache=self.cache)
mx.eval([c.state for c in self.cache])

prefill_tokens = prefill_tokens[chunk_size:]
processed += chunk_size
prompt_progress_callback((processed / num_total_prefill_tokens) * 100)
mx.metal.clear_cache()
if self.draft_model is not None:
# Fill draft model cache (0% to 50% progress)
draft_cache = self.cache[len(self.model.layers) :]
self._prefill(
model=self.draft_model,
cache=draft_cache,
tokens=prefill_tokens,
progress_callback=prompt_progress_callback,
start_progress=0,
end_progress=50,
)
        # Fill main model cache (50% to 100% progress when a draft model is loaded, 0% to 100% otherwise)
main_cache = self.cache[: len(self.model.layers)]
self._prefill(
model=self.model,
cache=main_cache,
tokens=prefill_tokens,
progress_callback=prompt_progress_callback,
start_progress=50 if self.draft_model is not None else 0,
end_progress=100,
)

# Return the tokens that must still be processed outside of the cache
non_prefill_tokens = prompt_tokens[-num_tokens_to_exclude:]
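
To make the combined-cache bookkeeping above concrete: once a draft model is set, the single `self.cache` list holds one entry per main-model layer followed by one per draft-model layer, and `update_cache` prefills each half separately while mapping its local progress onto a distinct slice of the 0 to 100 percent range. Below is a small standalone sketch of that progress arithmetic (no MLX required); the helper name is made up for illustration.
```python
# Standalone illustration of the progress scaling used in _prefill above.
def scaled_progress(
    num_processed: int,
    total_tokens: int,
    start_progress: float,
    end_progress: float,
) -> float:
    """Map local prefill progress onto the [start_progress, end_progress] window."""
    return start_progress + (end_progress - start_progress) * (
        num_processed / total_tokens
    )

# With a draft model: draft prefill reports 0-50%, main prefill reports 50-100%.
print(scaled_progress(256, 512, start_progress=0, end_progress=50))    # 25.0
print(scaled_progress(256, 512, start_progress=50, end_progress=100))  # 75.0
# Without a draft model the main prefill spans the whole range.
print(scaled_progress(256, 512, start_progress=0, end_progress=100))   # 50.0
```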