From 280b3784d45b3094a8e398ef432cb993940044c5 Mon Sep 17 00:00:00 2001
From: L Lllvvuu
Date: Wed, 21 Aug 2024 00:46:19 +0800
Subject: [PATCH 1/9] feat: support batch input in `generate()`

The `prompt` argument can now be either a `str` or `list[str]`.

The change to `generate()` is backwards-compatible.

The changes to `generate_step()`, `top_p_sampling()`, and
`min_p_sampling()` are backwards-incompatible in order to unify shapes;
this could be changed by adding a few if-statements, if preferred.
---
 llms/mlx_lm/sample_utils.py |  64 +++++++++------
 llms/mlx_lm/server.py       |   7 +-
 llms/mlx_lm/utils.py        | 155 ++++++++++++++++++++++--------------
 3 files changed, 138 insertions(+), 88 deletions(-)

diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py
index 20b008fac..1b403b393 100644
--- a/llms/mlx_lm/sample_utils.py
+++ b/llms/mlx_lm/sample_utils.py
@@ -26,7 +26,10 @@ def min_p_sampling(
         0.99-0.8 range.
         min_tokens_to_keep (int, optional): Minimum number of tokens that
             cannot be filtered. Default: ``1``.
-
+        temperature: Temperature parameter for softmax distribution reshaping.
+    Returns:
+        token(s) selected based on the min-p criterion.
+        Shape: same as ``logits``, but with the last (vocab) dimension removed.
     """
     if not (0 <= min_p <= 1.0):
         raise ValueError(
@@ -39,14 +42,14 @@ def min_p_sampling(
     # reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605

     # Softmax probabilities
-    probs = mx.softmax(logits * (1 / temperature), axis=-1)
+    probs = mx.softmax(logits / temperature, axis=-1)

     # Indices sorted in decreasing order
-    sorted_indices = mx.argsort(-logits).squeeze(0)
-    sorted_probs = probs[..., sorted_indices]
+    sorted_indices = mx.argsort(-logits)
+    sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1)

     # Top probability
-    top_probs = probs[..., sorted_indices[0]]
+    top_probs = mx.expand_dims(sorted_probs[..., 0], axis=-1)

     # Calculate the min_p threshold
     scaled_min_p = min_p * top_probs
@@ -58,13 +61,18 @@ def min_p_sampling(
     # Create pool of tokens with probability less than scaled min_p
     selected_probs = mx.where(tokens_to_remove, 0, sorted_probs)

-    # Return sampled token
-    sorted_token = mx.random.categorical(mx.log(selected_probs))
-    return sorted_indices[sorted_token]
+    # Return sampled token(s)
+    sampled_indices = mx.random.categorical(mx.log(selected_probs))
+    tokens = mx.take_along_axis(
+        sorted_indices, mx.expand_dims(sampled_indices, axis=-1), axis=-1
+    )
+    return tokens.squeeze(-1)


 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
-def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
+def top_p_sampling(
+    logits: mx.array, top_p: float, temperature: float, axis: int = -1
+) -> mx.array:
     """
     Apply top-p (nucleus) sampling to logits.

@@ -72,29 +80,35 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
         logits: The logits from the model's output.
         top_p: The cumulative probability threshold for top-p filtering.
         temperature: Temperature parameter for softmax distribution reshaping.
+        axis: The axis along which to apply top-p sampling.

     Returns:
-        token selected based on the top-p criterion.
+        token(s) selected based on the top-p criterion.
""" - # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460 - probs = mx.softmax(logits * (1 / temperature), axis=-1) + # Apply temperature and compute softmax + probs = mx.softmax(logits / temperature, axis=axis) - # sort probs in ascending order - sorted_indices = mx.argsort(probs, axis=-1) - sorted_probs = probs[..., sorted_indices.squeeze(0)] + # Sort probs in descending order + sorted_indices = mx.argsort(-probs, axis=axis) + sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=axis) - cumulative_probs = mx.cumsum(sorted_probs, axis=-1) + # Compute cumulative probabilities + cumulative_probs = mx.cumsum(sorted_probs, axis=axis) - # select tokens with cumulative probs below threshold - top_probs = mx.where( - cumulative_probs > 1 - top_p, - sorted_probs, - 0, - ) + # Create a mask for probs above the threshold + mask = cumulative_probs <= top_p + + # Apply the mask to the sorted probabilities + masked_probs = sorted_probs * mask - sorted_token = mx.random.categorical(mx.log(top_probs)) - token = sorted_indices.squeeze(0)[sorted_token] + # Sample from the normalized probabilities + sampled_indices = mx.random.categorical(mx.log(masked_probs), axis=axis) + + # Gather the original token indices + tokens = mx.take_along_axis( + sorted_indices, mx.expand_dims(sampled_indices, axis=axis), axis=axis + ) - return token + return tokens.squeeze(axis) @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index 79ac18361..aa2c5ed7c 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -410,7 +410,7 @@ def handle_completion( top_tokens = [] for (token, logprobs), _ in zip( generate_step( - prompt=prompt, + prompts=prompt[None], model=self.model, temp=self.temperature, top_p=self.top_p, @@ -420,6 +420,8 @@ def handle_completion( ), range(self.max_tokens), ): + token = token.item() + logprobs = logprobs.squeeze(0) detokenizer.add_token(token) logging.debug(detokenizer.text) tokens.append(token) @@ -497,7 +499,7 @@ def handle_stream( for (token, _), _ in zip( generate_step( - prompt=prompt, + prompts=prompt[None], model=self.model, temp=self.temperature, top_p=self.top_p, @@ -506,6 +508,7 @@ def handle_stream( ), range(self.max_tokens), ): + token = token.item() detokenizer.add_token(token) logging.debug(detokenizer.text) tokens.append(token) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 71476df3e..95e6fd462 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -117,12 +117,12 @@ def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: f logits (mx.array): Logits with repetition penalty applied to generated tokens. """ if len(generated_tokens) > 0: - indices = mx.array([token for token in generated_tokens]) - selected_logits = logits[:, indices] + indices = generated_tokens + selected_logits = mx.take_along_axis(logits, indices, axis=-1) selected_logits = mx.where( selected_logits < 0, selected_logits * penalty, selected_logits / penalty ) - logits[:, indices] = selected_logits + logits[mx.arange(indices.shape[0])[:, None], indices] = selected_logits return logits @@ -147,7 +147,7 @@ def make_kv_caches( def generate_step( - prompt: mx.array, + prompts: mx.array, model: nn.Module, temp: float = 0.0, repetition_penalty: Optional[float] = None, @@ -164,7 +164,7 @@ def generate_step( A generator producing token ids based on the given prompt from the model. 
Args: - prompt (mx.array): The input prompt. + prompts (mx.array): The input prompt(s). Shape: ``(bs, seq_len)``. model (nn.Module): The model to use for generation. temp (float): The temperature for sampling, if 0 the argmax is used. Default: ``0``. @@ -185,27 +185,33 @@ def generate_step( Yields: Generator[Tuple[mx.array, mx.array], None, None]: A generator producing - one token and a vector of log probabilities. + one token and a vector of log probabilities per prompt. + Shapes: ``(bs, 1), (bs, vocab_size)``. """ - def sample(logits: mx.array) -> Tuple[mx.array, float]: + if prompts.ndim != 2: + raise ValueError( + f"Shape of prompts should be (bs, seq_len), got {prompts.shape}" + ) + + def sample(logits: mx.array) -> Tuple[mx.array, mx.array]: if logit_bias: indices = mx.array(list(logit_bias.keys())) values = mx.array(list(logit_bias.values())) logits[:, indices] += values - logprobs = logits - mx.logsumexp(logits) + logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) if temp == 0: - token = mx.argmax(logits, axis=-1) + tokens = mx.argmax(logits, axis=-1) else: if top_p > 0 and top_p < 1.0: - token = top_p_sampling(logits, top_p, temp) + tokens = top_p_sampling(logits, top_p, temp) elif min_p != 0.0: - token = min_p_sampling(logits, min_p, min_tokens_to_keep, temp) + tokens = min_p_sampling(logits, min_p, min_tokens_to_keep, temp) else: - token = categorical_sampling(logits, temp) + tokens = categorical_sampling(logits, temp) - return token, logprobs + return mx.expand_dims(tokens, axis=-1), logprobs if repetition_penalty and ( repetition_penalty < 0 or not isinstance(repetition_penalty, float) @@ -214,7 +220,7 @@ def sample(logits: mx.array) -> Tuple[mx.array, float]: f"repetition_penalty must be a non-negative float, got {repetition_penalty}" ) - y = prompt + y = prompts # Create the KV cache for generation cache = make_kv_caches(model, max_kv_size) @@ -229,14 +235,14 @@ def sample(logits: mx.array) -> Tuple[mx.array, float]: c.update_and_fetch(h[0], h[1]) mx.eval([c.state for c in cache]) - repetition_context = prompt.tolist() + repetition_context = prompts if repetition_context_size: - repetition_context = repetition_context[-repetition_context_size:] + repetition_context = repetition_context[:, -repetition_context_size:] def _step(y): nonlocal repetition_context - logits = model(y[None], cache=cache) + logits = model(y, cache=cache) logits = logits[:, -1, :] if repetition_penalty: @@ -244,27 +250,27 @@ def _step(y): logits, repetition_context, repetition_penalty ) y, logprobs = sample(logits) - repetition_context.append(y.item()) + repetition_context = mx.concatenate([repetition_context, y], axis=-1) else: y, logprobs = sample(logits) if repetition_context_size: - if len(repetition_context) > repetition_context_size: - repetition_context = repetition_context[-repetition_context_size:] - return y, logprobs.squeeze(0) + if repetition_context.shape[1] > repetition_context_size: + repetition_context = repetition_context[:, -repetition_context_size:] + return y, logprobs - while y.size > prefill_step_size: - model(y[:prefill_step_size][None], cache=cache) + while y.shape[1] > prefill_step_size: + model(y[:, :prefill_step_size], cache=cache) mx.eval([c.state for c in cache]) - y = y[prefill_step_size:] + y = y[:, prefill_step_size:] y, logprobs = _step(y) - mx.async_eval(y) while True: next_y, next_logprobs = _step(y) mx.async_eval(next_y) - yield y.item(), logprobs + mx.eval(y) + yield y, logprobs y, logprobs = next_y, next_logprobs @@ -296,9 +302,10 @@ def 
stream_generate( detokenizer.reset() for (token, _), n in zip( - generate_step(prompt_tokens, model, **kwargs), + generate_step(prompt_tokens[None], model, **kwargs), range(max_tokens), ): + token = token.item() if token == tokenizer.eos_token_id: break detokenizer.add_token(token) @@ -313,19 +320,19 @@ def stream_generate( def generate( model: nn.Module, tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], - prompt: str, + prompt: Union[str, List[str]], max_tokens: int = 100, verbose: bool = False, formatter: Optional[Callable] = None, **kwargs, -) -> Union[str, Generator[str, None, None]]: +) -> Union[str, List[str]]: """ Generate a complete response from the model. Args: model (nn.Module): The language model. tokenizer (PreTrainedTokenizer): The tokenizer. - prompt (str): The string prompt. + prompts (str): The string prompt(s). max_tokens (int): The maximum number of tokens. Default: ``100``. verbose (bool): If ``True``, print tokens and timing information. Default: ``False``. @@ -334,56 +341,82 @@ def generate( kwargs: The remaining options get passed to :func:`generate_step`. See :func:`generate_step` for more details. """ + is_batch = isinstance(prompt, list) if not isinstance(tokenizer, TokenizerWrapper): tokenizer = TokenizerWrapper(tokenizer) - if verbose: - print("=" * 10) - print("Prompt:", prompt) - - prompt_tokens = mx.array(tokenizer.encode(prompt)) - detokenizer = tokenizer.detokenizer + if is_batch: + tokenizer._tokenizer.padding_side = "left" + if tokenizer.pad_token is None: + tokenizer._tokenizer.pad_token = tokenizer.eos_token + tokenizer._tokenizer.pad_token_id = tokenizer.eos_token_id + prompt_tokens = mx.array( + tokenizer._tokenizer(prompt, padding=True)["input_ids"] + ) + output_toks = [] + else: + prompt_tokens = mx.array(tokenizer.encode(prompt))[None] + detokenizer = tokenizer.detokenizer + detokenizer.reset() + if verbose: + print("=" * 10) + print("Prompt:", prompt) tic = time.perf_counter() - detokenizer.reset() - for (token, logprobs), n in zip( + for (tokens, logprobs), n in zip( generate_step(prompt_tokens, model, **kwargs), range(max_tokens), ): if n == 0: prompt_time = time.perf_counter() - tic tic = time.perf_counter() - if token == tokenizer.eos_token_id: + if (tokens == tokenizer.eos_token_id).all(): break - detokenizer.add_token(token) - - if verbose: - if formatter: - # We have to finalize so that the prob corresponds to the last segment - detokenizer.finalize() - formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item()) - else: - print(detokenizer.last_segment, end="", flush=True) - - token_count = n + 1 - detokenizer.finalize() + if is_batch: + output_toks.append(tokens) + else: + token = tokens.item() + logprobs = logprobs.squeeze(0) + detokenizer.add_token(token) + if verbose: + if formatter: + # We have to finalize so that the prob corresponds to the last segment + detokenizer.finalize() + formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item()) + else: + print(detokenizer.last_segment, end="", flush=True) + + if is_batch: + output_toks = mx.concatenate(output_toks, axis=1) + token_count = output_toks.size + response = [ + response.split(tokenizer.eos_token)[0].split(tokenizer.pad_token)[0] + for response in tokenizer.batch_decode(output_toks.tolist()) + ] + else: + token_count = n + detokenizer.finalize() + response = detokenizer.text if verbose: gen_time = time.perf_counter() - tic - print(detokenizer.last_segment, flush=True) - print("=" * 10) - if token_count == 0: + if token_count <= 0: print("No tokens generated 
for this prompt") - return + if is_batch: + for p, resp in zip(prompt, response): + print("=" * 10) + print("Prompt:", p) + print(resp) + else: + print(detokenizer.last_segment, flush=True) prompt_tps = prompt_tokens.size / prompt_time - gen_tps = (token_count - 1) / gen_time - print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec") - print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec") - peak_mem = mx.metal.get_peak_memory() / 2**30 - print(f"Peak memory: {peak_mem:.3f} GB") + gen_tps = token_count / gen_time + print("=" * 10) + print(f"Prompt: {prompt_tps:.3f} tokens-per-sec") + print(f"Generation: {gen_tps:.3f} tokens-per-sec") - return detokenizer.text + return response def load_config(model_path: Path) -> dict: From 2caa8329c050c855e48b018118b638332f8c184d Mon Sep 17 00:00:00 2001 From: L Lllvvuu Date: Fri, 23 Aug 2024 16:27:50 +0900 Subject: [PATCH 2/9] feat: show batch generation progress --- llms/mlx_lm/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 95e6fd462..899f20e8f 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -375,6 +375,8 @@ def generate( break if is_batch: output_toks.append(tokens) + if verbose: + print(".", end="", flush=True) else: token = tokens.item() logprobs = logprobs.squeeze(0) @@ -404,6 +406,7 @@ def generate( if token_count <= 0: print("No tokens generated for this prompt") if is_batch: + print() for p, resp in zip(prompt, response): print("=" * 10) print("Prompt:", p) From a28ca03e0429112e46030ae33a940550935156d2 Mon Sep 17 00:00:00 2001 From: L Lllvvuu Date: Fri, 27 Dec 2024 01:51:28 -0800 Subject: [PATCH 3/9] update generate_step callsites --- llms/mlx_lm/cache_prompt.py | 2 +- llms/mlx_lm/utils.py | 6 +++++- llms/tests/test_prompt_cache.py | 33 ++++++++++++++++++++------------- 3 files changed, 26 insertions(+), 15 deletions(-) diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py index 9d7d1603d..4f88061e8 100644 --- a/llms/mlx_lm/cache_prompt.py +++ b/llms/mlx_lm/cache_prompt.py @@ -132,7 +132,7 @@ def main(): prompt = args.prompt cache = make_prompt_cache(model, args.max_kv_size) - y = mx.array(tokenizer.encode(prompt)) + y = mx.array(tokenizer.encode(prompt))[None] # Process the prompt start = time.time() diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index ec52e283d..b4f7728d6 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -296,7 +296,7 @@ def generate_step( def _step(y): with mx.stream(generation_stream): - if y.ndims == 1: + if y.ndim == 1: y = mx.expand_dims(y, axis=-1) logits = model( y, @@ -390,12 +390,16 @@ def stream_generate( prompt if isinstance(prompt, list) else tokenizer.encode(prompt) ) + if prompt.ndim == 1: + prompt = prompt[None] + detokenizer = tokenizer.detokenizer with wired_limit(model, [generation_stream]): detokenizer.reset() tic = time.perf_counter() for n, (token, logprobs) in enumerate(generate_step(prompt, model, **kwargs)): + token, logprobs = token.item(), logprobs.squeeze(0) if n == 0: prompt_time = time.perf_counter() - tic prompt_tps = prompt.size / prompt_time diff --git a/llms/tests/test_prompt_cache.py b/llms/tests/test_prompt_cache.py index de5694d58..6acab5a71 100644 --- a/llms/tests/test_prompt_cache.py +++ b/llms/tests/test_prompt_cache.py @@ -121,21 +121,24 @@ def test_save_load_mixed_cache(self): def test_cache_with_generate(self): model, tokenizer = load(HF_MODEL_PATH) prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0] - 
results = list(generate_step(prompt, model, max_tokens=4)) + results = list(generate_step(prompt[None], model, max_tokens=4)) + results = [(t.item(), l.squeeze(0)) for t, l in results] toks, all_logits = zip(*results) prompt_cache = make_prompt_cache(model) i = 0 for tok, logits in generate_step( - prompt, model, prompt_cache=prompt_cache, max_tokens=2 + prompt[None], model, prompt_cache=prompt_cache, max_tokens=2 ): + tok, logits = tok.item(), logits.squeeze(0) self.assertEqual(tok, toks[i]) self.assertTrue(mx.allclose(logits, all_logits[i])) i += 1 for tok, logits in generate_step( - mx.array([toks[i]]), model, prompt_cache=prompt_cache, max_tokens=1 + mx.array([[toks[i]]]), model, prompt_cache=prompt_cache, max_tokens=1 ): + tok, logits = tok.item(), logits.squeeze(0) i += 1 self.assertEqual(tok, toks[i]) self.assertTrue(mx.allclose(logits, all_logits[i])) @@ -205,14 +208,14 @@ def test_trim_cache_with_generate(self): prompt_cache = make_prompt_cache(model) # Generate one token so we process the full prompt - last_tok, _ = next(generate_step(prompt, model, prompt_cache=prompt_cache)) - last_tok = mx.array([last_tok]) + last_tok, _ = next(generate_step(prompt[None], model, prompt_cache=prompt_cache)) # Generate two more tokens results = zip( - range(2), generate_step(last_tok, model, prompt_cache=prompt_cache) + range(2), generate_step(last_tok[None], model, prompt_cache=prompt_cache) ) - toks, all_logits = zip(*(r[1] for r in results)) + results = [(t.item(), l.squeeze(0)) for _, (t, l) in results] + toks, all_logits = zip(*results) # To get back to the cache just after processing the prompt, # trim by 3 tokens @@ -220,9 +223,10 @@ def test_trim_cache_with_generate(self): # Generate the same thing again results = zip( - range(2), generate_step(last_tok, model, prompt_cache=prompt_cache) + range(2), generate_step(last_tok[None], model, prompt_cache=prompt_cache) ) - second_toks, second_all_logits = zip(*(r[1] for r in results)) + results = [(t.item(), l.squeeze(0)) for _, (t, l) in results] + second_toks, second_all_logits = zip(*results) self.assertEqual(toks, second_toks) self.assertTrue( all(mx.allclose(l, l2) for l, l2 in zip(all_logits, second_all_logits)) @@ -278,14 +282,16 @@ def test_save_load_quantized_cache(self): def test_cache_to_quantized(self): model, tokenizer = load(HF_MODEL_PATH) prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0] - results = zip(range(4), generate_step(prompt, model)) - toks, all_logits = zip(*(r[1] for r in results)) + results = zip(range(4), generate_step(prompt[None], model)) + results = [(t.item(), l.squeeze(0)) for _, (t, l) in results] + toks, all_logits = zip(*results) prompt_cache = make_prompt_cache(model) i = 0 for _, (tok, logits) in zip( - range(2), generate_step(prompt, model, prompt_cache=prompt_cache) + range(2), generate_step(prompt[None], model, prompt_cache=prompt_cache) ): + tok, logits = tok.item(), logits.squeeze(0) self.assertEqual(tok, toks[i]) self.assertTrue(mx.allclose(logits, all_logits[i])) i += 1 @@ -294,8 +300,9 @@ def test_cache_to_quantized(self): for _, (tok, logits) in zip( range(1), - generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache), + generate_step(mx.array([[toks[i]]]), model, prompt_cache=prompt_cache), ): + tok, logits = tok.item(), logits.squeeze(0) i += 1 self.assertEqual(tok, toks[i]) self.assertTrue(mx.allclose(logits, all_logits[i], rtol=2e-2)) From cded14988ca0c1bd88f7b9f19f8e7b902ed50504 Mon Sep 17 00:00:00 2001 From: L Lllvvuu Date: Fri, 27 Dec 2024 01:53:15 -0800 
Subject: [PATCH 4/9] fix test_generate --- llms/tests/test_generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llms/tests/test_generate.py b/llms/tests/test_generate.py index f23453943..b069edefb 100644 --- a/llms/tests/test_generate.py +++ b/llms/tests/test_generate.py @@ -49,7 +49,7 @@ def logits_processor(toks, logits): verbose=False, logits_processors=[logits_processor], ) - self.assertEqual(len(all_toks), len(init_toks) + 5) + self.assertEqual(all_toks.shape[-1], len(init_toks) + 5) if __name__ == "__main__": From 465eb79fff9812c5c8d5bedb2769ecee3bfab769 Mon Sep 17 00:00:00 2001 From: L Lllvvuu Date: Fri, 27 Dec 2024 13:47:09 -0800 Subject: [PATCH 5/9] implement batch_generate --- llms/mlx_lm/utils.py | 83 ++++++++++++++++++++++++++++++++++++- llms/tests/test_generate.py | 22 ++++++++-- 2 files changed, 100 insertions(+), 5 deletions(-) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index b4f7728d6..f28fd8306 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -21,6 +21,7 @@ # Local imports from .models import cache +from .models.base import create_causal_mask from .sample_utils import make_logits_processors, make_sampler from .tokenizer_utils import TokenizerWrapper, load_tokenizer from .tuner.utils import dequantize as dequantize_model @@ -355,6 +356,7 @@ def _step(y): prompt_progress_callback(total_prompt_tokens, total_prompt_tokens) if n == max_tokens: break + mx.eval(y) yield y, logprobs if n % 256 == 0: mx.metal.clear_cache() @@ -488,8 +490,85 @@ def generate( return text -def batch_generate(): - pass +def batch_generate( + model: nn.Module, + tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], + prompts: list[str], + verbose: bool = False, + **kwargs, +) -> list[str]: + """ + Generate complete responses from the model for a list of prompts. + + Args: + model (nn.Module): The language model. + tokenizer (PreTrainedTokenizer): The tokenizer. + prompts (List[str]): The string prompts. + verbose (bool): If ``True``, print tokens and timing information. + Default: ``False``. + kwargs: The remaining options get passed to :func:`generate_step`. + See :func:`generate_step` for more details. + """ + if 'prompt_cache' in kwargs: + # TODO: Handle `prompt_cache` and `prompt` both left-padded, so that + # we have texttext. 
Should involve taking `prompt_cache_lens` + # to extend `mask` below, and handling position_ids (see TODO below) + raise ValueError("Batch generation does not support prompt_cache yet.") + if not isinstance(tokenizer, TokenizerWrapper): + tokenizer = TokenizerWrapper(tokenizer) + # TODO: left-shift position_ids for absolute/rotary positional encodings + # Example: https://github.com/huggingface/transformers/issues/26072#issuecomment-2101209470 + tokenizer._tokenizer.padding_side = "left" + if tokenizer.pad_token is None: + tokenizer._tokenizer.pad_token = tokenizer.eos_token + tokenizer._tokenizer.pad_token_id = tokenizer.eos_token_id + res = tokenizer._tokenizer(prompts, padding=True) + input_ids, token_mask = mx.array(res["input_ids"]), mx.array(res["attention_mask"]) + causal_mask = create_causal_mask(token_mask.shape[-1]) + mask = mx.where(token_mask[:, None, None, :], causal_mask, -1e9) + + output_toks = [] + prompt_time = None + ended = mx.zeros(len(prompts), dtype=mx.bool_) + tic = time.perf_counter() + # TODO: non-generator version of `generate_step` so that we can + # add or remove prompts from the batch as they start/finish + for tokens, _ in generate_step(input_ids, model, mask=mask, **kwargs): + if not prompt_time: + prompt_time = time.perf_counter() - tic + tic = time.perf_counter() + ended = ended | (tokens == tokenizer.eos_token_id) + if ended.all(): + break + output_toks.append(tokens) + if verbose: + print(".", end="", flush=True) + output_toks = mx.stack(output_toks, axis=-1) + token_count = output_toks.size + response = [ + response.split(tokenizer.eos_token)[0].split(tokenizer.pad_token)[0] + for response in tokenizer.batch_decode(output_toks.tolist()) + ] + if verbose: + gen_time = time.perf_counter() - tic + if token_count <= 0: + print("No tokens generated for this prompt") + else: + print() + for p, resp in zip(prompts, response): + print("=" * 10) + print("Prompt:", p) + print(resp) + print("=" * 10) + if prompt_time: + prompt_tps = input_ids.size / prompt_time + print(f"Prompt: {input_ids.size} tokens, {prompt_tps:.3f} tokens-per-sec") + gen_tps = token_count / gen_time + print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec") + peak_mem = mx.metal.get_peak_memory() / 2**30 + print(f"Peak memory: {peak_mem:.3f} GB") + + return response def load_config(model_path: Path) -> dict: diff --git a/llms/tests/test_generate.py b/llms/tests/test_generate.py index b069edefb..14fa75e91 100644 --- a/llms/tests/test_generate.py +++ b/llms/tests/test_generate.py @@ -2,12 +2,11 @@ import unittest -from mlx_lm.sample_utils import make_logits_processors -from mlx_lm.utils import generate, load +from mlx_lm.sample_utils import make_logits_processors, make_sampler +from mlx_lm.utils import generate, batch_generate, load class TestGenerate(unittest.TestCase): - @classmethod def setUpClass(cls): HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit" @@ -51,6 +50,23 @@ def logits_processor(toks, logits): ) self.assertEqual(all_toks.shape[-1], len(init_toks) + 5) + def test_batch_generate(self): + logit_bias = {0: 20.0, 1: -20.0} + texts = batch_generate( + self.model, + self.tokenizer, + [ + "hello", + "this is a longer prompt to test out the padding and masking. 
hello", + ], + max_tokens=5, + prefill_step_size=4, + sampler=make_sampler(temp=1.0, min_p=0.1), + logits_processors=make_logits_processors(logit_bias, repetition_penalty=2.0), + verbose=False, + ) + self.assertEqual(texts, ['!', '!']) + if __name__ == "__main__": unittest.main() From fdd16caf7a021360ded434fa65060e4691d79441 Mon Sep 17 00:00:00 2001 From: L Lllvvuu Date: Fri, 27 Dec 2024 15:26:43 -0800 Subject: [PATCH 6/9] mask dtype --- llms/mlx_lm/models/base.py | 5 ++++- llms/mlx_lm/utils.py | 11 +++++++++-- llms/tests/test_generate.py | 2 +- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index ad7a4a65a..3d402aa1b 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -28,6 +28,7 @@ def create_causal_mask( offset: int = 0, window_size: Optional[int] = None, lengths: Optional[mx.array] = None, + dtype: mx.Dtype = mx.float32, ): rinds = mx.arange(offset + N) linds = mx.arange(offset, offset + N) if offset else rinds @@ -39,7 +40,9 @@ def create_causal_mask( if lengths is not None: lengths = lengths[:, None, None, None] mask = mask | (rinds >= lengths) - return mask * -1e9 + # HACK: sometimes see NaN logprobs if no divide by 2 here + # return mask * (mx.finfo(dtype).min / 2) + return mask.astype(dtype) * (-65504. / 2) def create_attention_mask(h: mx.array, cache: Optional[Any] = None): diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index f28fd8306..185a06986 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -524,8 +524,15 @@ def batch_generate( tokenizer._tokenizer.pad_token_id = tokenizer.eos_token_id res = tokenizer._tokenizer(prompts, padding=True) input_ids, token_mask = mx.array(res["input_ids"]), mx.array(res["attention_mask"]) - causal_mask = create_causal_mask(token_mask.shape[-1]) - mask = mx.where(token_mask[:, None, None, :], causal_mask, -1e9) + dtype = None + for module in model.modules(): + if isinstance(module, nn.QuantizedEmbedding) or isinstance(module, nn.Embedding): + dtype = module(mx.zeros(1, dtype=input_ids.dtype)).dtype + break + causal_mask = create_causal_mask(token_mask.shape[-1], dtype=dtype) + # HACK: sometimes see NaN logprobs if no divide by 2 here + # mask = mx.where(token_mask[:, None, None, :], causal_mask, mx.finfo(dtype).min / 2) + mask = mx.where(token_mask[:, None, None, :], causal_mask, -65504. 
/ 2) output_toks = [] prompt_time = None diff --git a/llms/tests/test_generate.py b/llms/tests/test_generate.py index 14fa75e91..41f9704f2 100644 --- a/llms/tests/test_generate.py +++ b/llms/tests/test_generate.py @@ -61,7 +61,7 @@ def test_batch_generate(self): ], max_tokens=5, prefill_step_size=4, - sampler=make_sampler(temp=1.0, min_p=0.1), + sampler=make_sampler(temp=1., min_p=0.5), logits_processors=make_logits_processors(logit_bias, repetition_penalty=2.0), verbose=False, ) From 30e98c85c12865d56bcf28171f7105b21c385f42 Mon Sep 17 00:00:00 2001 From: L Lllvvuu Date: Fri, 27 Dec 2024 15:52:42 -0800 Subject: [PATCH 7/9] tweaks --- llms/mlx_lm/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 185a06986..0925e4693 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -296,9 +296,9 @@ def generate_step( prompt_progress_callback = prompt_progress_callback or (lambda *_: None) def _step(y): + if y.ndim == 1: + y = mx.expand_dims(y, axis=-1) with mx.stream(generation_stream): - if y.ndim == 1: - y = mx.expand_dims(y, axis=-1) logits = model( y, cache=prompt_cache, @@ -514,6 +514,7 @@ def batch_generate( # we have texttext. Should involve taking `prompt_cache_lens` # to extend `mask` below, and handling position_ids (see TODO below) raise ValueError("Batch generation does not support prompt_cache yet.") + tokenizer = copy.deepcopy(tokenizer) if not isinstance(tokenizer, TokenizerWrapper): tokenizer = TokenizerWrapper(tokenizer) # TODO: left-shift position_ids for absolute/rotary positional encodings From 089480878fabbdd2215816bf3dc303602772b48d Mon Sep 17 00:00:00 2001 From: L Lllvvuu Date: Fri, 27 Dec 2024 16:01:52 -0800 Subject: [PATCH 8/9] dtype fix --- llms/mlx_lm/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 0925e4693..9e3d37780 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -525,7 +525,7 @@ def batch_generate( tokenizer._tokenizer.pad_token_id = tokenizer.eos_token_id res = tokenizer._tokenizer(prompts, padding=True) input_ids, token_mask = mx.array(res["input_ids"]), mx.array(res["attention_mask"]) - dtype = None + dtype = mx.float32 for module in model.modules(): if isinstance(module, nn.QuantizedEmbedding) or isinstance(module, nn.Embedding): dtype = module(mx.zeros(1, dtype=input_ids.dtype)).dtype From 2541f13907c4908335baf05bcdda5b2a04f94105 Mon Sep 17 00:00:00 2001 From: L Lllvvuu Date: Fri, 27 Dec 2024 16:21:20 -0800 Subject: [PATCH 9/9] format --- llms/mlx_lm/models/base.py | 2 +- llms/mlx_lm/utils.py | 16 ++++++++++------ llms/tests/test_generate.py | 8 +++++--- llms/tests/test_prompt_cache.py | 4 +++- 4 files changed, 19 insertions(+), 11 deletions(-) diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index 3d402aa1b..568b85abb 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -42,7 +42,7 @@ def create_causal_mask( mask = mask | (rinds >= lengths) # HACK: sometimes see NaN logprobs if no divide by 2 here # return mask * (mx.finfo(dtype).min / 2) - return mask.astype(dtype) * (-65504. 
/ 2) + return mask.astype(dtype) * (-65504.0 / 2) def create_attention_mask(h: mx.array, cache: Optional[Any] = None): diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 9e3d37780..714609f59 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -329,9 +329,11 @@ def _step(y): model( y[:, :prefill_step_size], cache=prompt_cache, - mask=mask[:, :, :prefill_step_size, : offset + prefill_step_size] - if mask is not None - else None, + mask=( + mask[:, :, :prefill_step_size, : offset + prefill_step_size] + if mask is not None + else None + ), ) maybe_quantize_kv_cache( prompt_cache, quantized_kv_start, kv_group_size, kv_bits @@ -509,7 +511,7 @@ def batch_generate( kwargs: The remaining options get passed to :func:`generate_step`. See :func:`generate_step` for more details. """ - if 'prompt_cache' in kwargs: + if "prompt_cache" in kwargs: # TODO: Handle `prompt_cache` and `prompt` both left-padded, so that # we have texttext. Should involve taking `prompt_cache_lens` # to extend `mask` below, and handling position_ids (see TODO below) @@ -527,13 +529,15 @@ def batch_generate( input_ids, token_mask = mx.array(res["input_ids"]), mx.array(res["attention_mask"]) dtype = mx.float32 for module in model.modules(): - if isinstance(module, nn.QuantizedEmbedding) or isinstance(module, nn.Embedding): + if isinstance(module, nn.QuantizedEmbedding) or isinstance( + module, nn.Embedding + ): dtype = module(mx.zeros(1, dtype=input_ids.dtype)).dtype break causal_mask = create_causal_mask(token_mask.shape[-1], dtype=dtype) # HACK: sometimes see NaN logprobs if no divide by 2 here # mask = mx.where(token_mask[:, None, None, :], causal_mask, mx.finfo(dtype).min / 2) - mask = mx.where(token_mask[:, None, None, :], causal_mask, -65504. / 2) + mask = mx.where(token_mask[:, None, None, :], causal_mask, -65504.0 / 2) output_toks = [] prompt_time = None diff --git a/llms/tests/test_generate.py b/llms/tests/test_generate.py index 41f9704f2..bcdb0d9f7 100644 --- a/llms/tests/test_generate.py +++ b/llms/tests/test_generate.py @@ -61,11 +61,13 @@ def test_batch_generate(self): ], max_tokens=5, prefill_step_size=4, - sampler=make_sampler(temp=1., min_p=0.5), - logits_processors=make_logits_processors(logit_bias, repetition_penalty=2.0), + sampler=make_sampler(temp=1.0, min_p=0.5), + logits_processors=make_logits_processors( + logit_bias, repetition_penalty=2.0 + ), verbose=False, ) - self.assertEqual(texts, ['!', '!']) + self.assertEqual(texts, ["!", "!"]) if __name__ == "__main__": diff --git a/llms/tests/test_prompt_cache.py b/llms/tests/test_prompt_cache.py index 6acab5a71..5fcc6834a 100644 --- a/llms/tests/test_prompt_cache.py +++ b/llms/tests/test_prompt_cache.py @@ -208,7 +208,9 @@ def test_trim_cache_with_generate(self): prompt_cache = make_prompt_cache(model) # Generate one token so we process the full prompt - last_tok, _ = next(generate_step(prompt[None], model, prompt_cache=prompt_cache)) + last_tok, _ = next( + generate_step(prompt[None], model, prompt_cache=prompt_cache) + ) # Generate two more tokens results = zip(
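
A minimal usage sketch of the batched API added in this series, following the call pattern in llms/tests/test_generate.py (importing `batch_generate` and `load` from `mlx_lm.utils` mirrors that test; the model name is the test's HF_MODEL_PATH, and the prompt strings and max_tokens value are illustrative):

    from mlx_lm.utils import batch_generate, load

    model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")

    # Prompts are left-padded to a common length and masked internally;
    # responses come back as a list[str], one entry per prompt, in order.
    responses = batch_generate(
        model,
        tokenizer,
        prompts=["hello", "write one sentence about the ocean"],
        max_tokens=32,
        verbose=True,
    )
    for r in responses:
        print(r)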