From 68403f5577b1fac4c34d2805aff3b448351a72be Mon Sep 17 00:00:00 2001 From: paNikitin <115797306+paNikitin@users.noreply.github.com> Date: Sun, 23 Feb 2025 12:31:44 +0300 Subject: [PATCH 1/9] added cot loss masking training --- llms/mlx_lm/examples/lora_config.yaml | 6 + llms/mlx_lm/lora.py | 46 +++++--- llms/mlx_lm/tuner/new_tokens.py | 162 ++++++++++++++++++++++++++ llms/mlx_lm/tuner/trainer.py | 118 +++++++++++++------ 4 files changed, 279 insertions(+), 53 deletions(-) create mode 100644 llms/mlx_lm/tuner/new_tokens.py diff --git a/llms/mlx_lm/examples/lora_config.yaml b/llms/mlx_lm/examples/lora_config.yaml index 530272c7a..ca799f8d7 100644 --- a/llms/mlx_lm/examples/lora_config.yaml +++ b/llms/mlx_lm/examples/lora_config.yaml @@ -64,6 +64,12 @@ lora_parameters: scale: 20.0 dropout: 0.0 +# cot loss masking training +# cot: +# use_cot: true +# special: true +# additional_tokens: ["[REASONING]", "[DATA]"] + # Schedule can only be specified in a config file, uncomment to use. #lr_schedule: # name: cosine_decay diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index def3b6ddf..45119025b 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -62,6 +62,7 @@ "grad_checkpoint": False, "lr_schedule": None, "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}, + "cot": False, } @@ -78,7 +79,6 @@ def build_parser(): "--train", action="store_true", help="Do training", - default=None, ) parser.add_argument( "--data", @@ -94,14 +94,6 @@ def build_parser(): choices=["lora", "dora", "full"], help="Type of fine-tuning to perform: lora, dora, or full.", ) - - parser.add_argument( - "--mask-prompt", - action="store_true", - help="Mask the prompt in the loss when training", - default=False, - ) - parser.add_argument( "--num-layers", type=int, @@ -144,7 +136,6 @@ def build_parser(): "--test", action="store_true", help="Evaluate on the test set after training", - default=None, ) parser.add_argument( "--test-batches", @@ -166,9 +157,13 @@ def build_parser(): "--grad-checkpoint", action="store_true", help="Use gradient checkpointing to reduce memory use.", - default=None, ) parser.add_argument("--seed", type=int, help="The PRNG seed") + parser.add_argument( + "--cot", + type=bool, + help="Use CoT loss masking", + ) return parser @@ -181,14 +176,8 @@ def train_model( training_callback: TrainingCallback = None, ): model.freeze() - if args.num_layers > len(model.layers): - raise ValueError( - f"Requested to train {args.num_layers} layers " - f"but the model only has {len(model.layers)} layers." 
- ) - if args.fine_tune_type == "full": - for l in model.layers[-max(args.num_layers, 0) :]: + for l in model.layers[-min(args.num_layers, 0) :]: l.unfreeze() elif args.fine_tune_type in ["lora", "dora"]: # Convert linear layers to lora/dora layers and unfreeze in the process @@ -225,10 +214,13 @@ def train_model( adapter_file=adapter_file, max_seq_length=args.max_seq_length, grad_checkpoint=args.grad_checkpoint, + cot=(cot := args.cot), ) model.train() - opt = optim.Adam( + # todo optimizer from args + + opt = optim.AdamW( learning_rate=( build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate ) @@ -269,6 +261,21 @@ def run(args, training_callback: TrainingCallback = None): print("Loading pretrained model") model, tokenizer = load(args.model) + if cot := args.cot: + print("Using CoT loss masking") + if tokens := cot.get("additional_tokens"): + from .tuner.new_tokens import implement_new_tokens + + special = False + if (special_arg := cot.get("special")) and isinstance(special_arg, bool): + print("Updating model and tokenizer with new special tokens") + special = special_arg + else: + print("Updating model and tokenizer with new tokens") + model, tokenizer = implement_new_tokens( + model=model, tokenizer=tokenizer, tokens=tokens, special=special + ) + print("Loading datasets") train_set, valid_set, test_set = load_dataset(args, tokenizer) @@ -293,6 +300,7 @@ def main(): parser = build_parser() args = parser.parse_args() config = args.config + args = vars(args) if config: print("Loading configuration file", config) diff --git a/llms/mlx_lm/tuner/new_tokens.py b/llms/mlx_lm/tuner/new_tokens.py new file mode 100644 index 000000000..489a6705a --- /dev/null +++ b/llms/mlx_lm/tuner/new_tokens.py @@ -0,0 +1,162 @@ +import mlx.nn as nn +import mlx.core as mx +from mlx_lm.tokenizer_utils import TokenizerWrapper + + +def resize_embeddings(model: nn.Module, tokenizer: TokenizerWrapper) -> nn.Module: + """ + Resizes model embeddings to accommodate new tokens + """ + old_embedding = model.model.embed_tokens + + old_vocab_size = old_embedding.num_embeddings + new_vocab_size = len(tokenizer._tokenizer) + + if old_vocab_size != new_vocab_size: + if new_vocab_size < old_vocab_size: + print( + "Warning: New vocab size is smaller than original. Proceeding with trim." + ) + + # check if QuantizedEmbedding has required attributes for dequantization + try: + dequantized_weights = mx.dequantize( + old_embedding.weight, + scales=old_embedding.scales, + biases=old_embedding.biases, + group_size=old_embedding.group_size, + bits=old_embedding.bits, + ) + except AttributeError as e: + print(f"Error: Cannot dequantize embed_tokens. 
Missing attributes: {e}") + print("Falling back to random weights for embed_tokens.") + dequantized_weights = mx.random.normal( + (old_vocab_size, old_embedding.dims), loc=0.0, scale=0.02 + ) + + # resize embed_tokens + new_embedding = nn.Embedding(new_vocab_size, old_embedding.dims) + new_weights = mx.zeros((new_vocab_size, old_embedding.dims)) + min_vocab_size = min(old_vocab_size, new_vocab_size) + new_weights[:min_vocab_size] = dequantized_weights[:min_vocab_size] + if new_vocab_size > old_vocab_size: + new_weights[old_vocab_size:] = mx.random.normal( + (new_vocab_size - old_vocab_size, old_embedding.dims), + loc=0.0, + scale=0.02, + ) + new_embedding.weight = new_weights + model.model.embed_tokens = new_embedding + + # attention layers handling + if hasattr(model, "args") and getattr(model.args, "tie_word_embeddings", False): + model.model.embed_tokens.weight = new_weights + elif hasattr(model, "lm_head"): + old_lm_head = model.lm_head + if isinstance(old_lm_head, nn.QuantizedLinear): + # resize nn.QuantizedLinear + output_dims, compressed_input_dims = old_lm_head.weight.shape + bits = old_lm_head.bits + input_dims = compressed_input_dims * (32 // bits) + + # dequantize lm_head weights + try: + dequantized_lm_weights = mx.dequantize( + old_lm_head.weight, + scales=old_lm_head.scales, + biases=old_lm_head.biases, + group_size=old_lm_head.group_size, + bits=old_lm_head.bits, + ) + except AttributeError as e: + print(f"Error: Cannot dequantize lm_head. Missing attributes: {e}") + print("Falling back to random weights for lm_head.") + dequantized_lm_weights = mx.random.normal( + (output_dims, input_dims), loc=0.0, scale=0.02 + ) + + new_lm_head = nn.QuantizedLinear( + input_dims=input_dims, + output_dims=new_vocab_size, + bias="bias" in old_lm_head, + group_size=old_lm_head.group_size, + bits=old_lm_head.bits, + ) + new_weights_lm = mx.zeros((new_vocab_size, input_dims)) + new_weights_lm[:min_vocab_size] = dequantized_lm_weights[ + :min_vocab_size + ] + if new_vocab_size > output_dims: + new_weights_lm[output_dims:] = mx.random.normal( + (new_vocab_size - output_dims, input_dims), loc=0.0, scale=0.02 + ) + new_lm_head.weight, new_lm_head.scales, new_lm_head.biases = ( + mx.quantize( + new_weights_lm, new_lm_head.group_size, new_lm_head.bits + ) + ) + if "bias" in old_lm_head: + new_lm_head.bias = mx.zeros((new_vocab_size,)) + new_lm_head.bias[:min_vocab_size] = old_lm_head.bias[ + :min_vocab_size + ] + else: + # resize nn.Linear + new_lm_head = nn.Linear( + old_lm_head.input_dims, new_vocab_size, bias="bias" in old_lm_head + ) + new_weights_lm = mx.zeros((new_vocab_size, old_lm_head.input_dims)) + min_vocab_size = min(old_lm_head.weight.shape[0], new_vocab_size) + new_weights_lm[:min_vocab_size] = old_lm_head.weight[:min_vocab_size] + if new_vocab_size > old_lm_head.weight.shape[0]: + new_weights_lm[old_lm_head.weight.shape[0] :] = mx.random.normal( + ( + new_vocab_size - old_lm_head.weight.shape[0], + old_lm_head.input_dims, + ), + loc=0.0, + scale=0.02, + ) + new_lm_head.weight = new_weights_lm + # todo typechecking + if "bias" in old_lm_head: + new_lm_head.bias = mx.zeros((new_vocab_size,)) + new_lm_head.bias[:min_vocab_size] = old_lm_head.bias[ + :min_vocab_size + ] + + model.lm_head = new_lm_head + else: + print("Vocab already sized right.") + return model + + +def update_tokenizer( + tokenizer: TokenizerWrapper, tokens: list[str], special: bool +) -> TokenizerWrapper: + """ + Appends new tokens to the end of the tokenizer vocab + """ + if special: + # todo TokenizerWrapper access 
method + tokenizer._tokenizer.add_special_tokens({"additional_special_tokens": tokens}) + print(f"Tokenizer updated with special tokens: {tokens}") + print(f"Tokenizer vocab size after append: {len(tokenizer._tokenizer)}") + else: + # todo add regular tokens + pass + return tokenizer + + +def implement_new_tokens( + model: nn.Module, + tokenizer: TokenizerWrapper, + tokens: list[str], + special: bool = False, +) -> tuple[nn.Module, TokenizerWrapper]: + """ + Update model`s tokenizer and embeddings with new tokens accordingly + """ + tokenizer = update_tokenizer(tokenizer=tokenizer, tokens=tokens, special=special) + model = resize_embeddings(model=model, tokenizer=tokenizer) + return model, tokenizer \ No newline at end of file diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index 64e26af8d..4f5db2ea7 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -1,20 +1,20 @@ # Copyright © 2024 Apple Inc. +from functools import partial import glob import shutil import time from dataclasses import dataclass, field from pathlib import Path -from typing import List, Optional, Tuple +from typing import Union import mlx.core as mx import mlx.nn as nn import numpy as np from mlx.nn.utils import average_gradients from mlx.utils import tree_flatten -from transformers import PreTrainedTokenizer -from .datasets import CompletionsDataset +from mlx_lm.tokenizer_utils import TokenizerWrapper def grad_checkpoint(layer): @@ -64,32 +64,80 @@ class TrainingArgs: default=False, metadata={"help": "Use gradient checkpointing to reduce memory use."}, ) + cot: bool = field( + default=False, + metadata={"help": "Use CoT loss masking with positioning penalty"}, + ) -def default_loss(model, batch, lengths): - inputs = batch[:, :-1] - targets = batch[:, 1:] - +def default_loss(model, inputs, targets, lengths): logits = model(inputs) logits = logits.astype(mx.float32) - steps = mx.arange(1, targets.shape[1] + 1) - mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:]) + length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None] - ce = nn.losses.cross_entropy(logits, targets) * mask - ntoks = mask.sum() + ce = nn.losses.cross_entropy(logits, targets) * length_mask + ntoks = length_mask.sum() ce = ce.sum() / ntoks return ce, ntoks -def iterate_batches( - dataset, - tokenizer, - batch_size, - max_seq_length, - train=False, -): +@dataclass +class CotTrainingArgs: + cot: bool = False + reasoning_token: str = "[REASONING]" + data_token: str = "[DATA]" + + +def cot_loss( + model: nn.Module, + inputs: mx.array, + targets: mx.array, + lengths: int, + tokenizer: TokenizerWrapper, + penalty: mx.float32 = 10.0, +) -> tuple[mx.array, mx.array]: + logits = model(inputs).astype(mx.float32) + + reasoning_token_id = tokenizer.encode(CotTrainingArgs.reasoning_token)[0] + data_token_id = tokenizer.encode(CotTrainingArgs.data_token)[0] + + reasoning_positions = mx.argmax(targets == reasoning_token_id, axis=1) + data_positions = mx.argmax(targets == data_token_id, axis=1) + + seq_indices = mx.arange(targets.shape[1])[None, :] + + # base CoT mask: starts at [DATA] + cot_mask = (seq_indices >= data_positions[:, None]).astype(mx.float32) + + # length mask: limits to non-padded regions + length_mask = (seq_indices < lengths[:, None]).astype(mx.float32) + + # combine masks: only include tokens after [DATA] AND within sequence length + loss_mask = cot_mask * length_mask + + # validate sequence structure + valid_seq = ( + (reasoning_positions < data_positions) + & 
mx.any(targets == reasoning_token_id, axis=1) + & mx.any(targets == data_token_id, axis=1) + ) + + # compute base cross-entropy loss + ce = nn.losses.cross_entropy(logits, targets) + + # masking loss before [DATA]; applying penalty for invalid seq + valid_loss = (ce * loss_mask).sum(axis=1) / (mx.sum(loss_mask, axis=1) + 1e-8) + final_loss = mx.where(valid_seq, valid_loss, penalty) # 10.0 as invalid penalty + loss = mx.mean(final_loss) + + valid_tokens = mx.sum(loss_mask) + 1e-8 + + return loss, valid_tokens + + +def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False): # Sort by length: idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx])) if len(dataset) < batch_size: @@ -114,10 +162,6 @@ def iterate_batches( indices = np.random.permutation(len(batch_idx)) for i in indices: batch = [dataset[j] for j in batch_idx[i]] - if len(batch[0]) == 2: - batch, offsets = zip(*batch) - else: - offsets = [0] * len(batch) lengths = [len(x) for x in batch] if max(lengths) > max_seq_length: print( @@ -140,7 +184,8 @@ def iterate_batches( truncated_length # Update lengths to match truncated lengths ) batch = mx.array(batch_arr) - yield batch, mx.array(list(zip(offsets, lengths))) + + yield batch[:, :-1], batch[:, 1:], mx.array(lengths) if not train: break @@ -156,8 +201,8 @@ def evaluate( loss: callable = default_loss, iterate_batches: callable = iterate_batches, ): - all_losses = mx.array(0.0) - ntokens = mx.array(0) + all_losses = 0 + ntokens = 0 index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1) @@ -213,6 +258,11 @@ def train( if args.grad_checkpoint: grad_checkpoint(model.layers[0]) + if args.cot: + loss = partial(cot_loss, tokenizer=tokenizer, penalty=10.0) + else: + loss = default_loss + state = [model.state, optimizer.state] def step(batch): @@ -233,8 +283,8 @@ def step(batch): n_tokens = 0 steps = 0 trained_tokens = 0 - train_time = 0 # Main training loop + start = time.perf_counter() for it, batch in zip( range(1, args.iters + 1), iterate_batches( @@ -245,11 +295,10 @@ def step(batch): train=True, ), ): - tic = time.perf_counter() # Report validation loss if needed, the first validation loss # is always measured before any training. 
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters: - tic = time.perf_counter() + stop = time.perf_counter() val_loss = evaluate( model=model, dataset=val_dataset, @@ -260,7 +309,7 @@ def step(batch): max_seq_length=args.max_seq_length, iterate_batches=iterate_batches, ) - val_time = time.perf_counter() - tic + val_time = time.perf_counter() - stop if rank == 0: print( f"Iter {it}: " @@ -277,23 +326,24 @@ def step(batch): } training_callback.on_val_loss_report(val_info) - tic = time.perf_counter() + start = time.perf_counter() lvalue, toks = step(batch) losses += lvalue n_tokens += toks steps += 1 mx.eval(state, losses, n_tokens) - train_time += time.perf_counter() - tic # Report training loss if needed if it % args.steps_per_report == 0 or it == args.iters: + stop = time.perf_counter() + train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item() train_loss /= steps * mx.distributed.init().size() n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item() learning_rate = optimizer.learning_rate.item() - it_sec = args.steps_per_report / train_time - tokens_sec = float(n_tokens) / train_time + it_sec = args.steps_per_report / (stop - start) + tokens_sec = float(n_tokens) / (stop - start) trained_tokens += n_tokens peak_mem = mx.metal.get_peak_memory() / 1e9 if rank == 0: @@ -322,7 +372,7 @@ def step(batch): losses = 0 n_tokens = 0 steps = 0 - train_time = 0 + start = time.perf_counter() # Save adapter weights if it % args.steps_per_save == 0: From 95d44228c942af9157f5115a509615403cac8b76 Mon Sep 17 00:00:00 2001 From: paNikitin <115797306+paNikitin@users.noreply.github.com> Date: Sun, 23 Feb 2025 12:34:17 +0300 Subject: [PATCH 2/9] Update new_tokens.py --- llms/mlx_lm/tuner/new_tokens.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llms/mlx_lm/tuner/new_tokens.py b/llms/mlx_lm/tuner/new_tokens.py index 489a6705a..981d4a56c 100644 --- a/llms/mlx_lm/tuner/new_tokens.py +++ b/llms/mlx_lm/tuner/new_tokens.py @@ -159,4 +159,4 @@ def implement_new_tokens( """ tokenizer = update_tokenizer(tokenizer=tokenizer, tokens=tokens, special=special) model = resize_embeddings(model=model, tokenizer=tokenizer) - return model, tokenizer \ No newline at end of file + return model, tokenizer From 0f790c4c84f05d12ce477ccc286766f1e8ec7adc Mon Sep 17 00:00:00 2001 From: paNikitin <115797306+paNikitin@users.noreply.github.com> Date: Sun, 23 Feb 2025 12:37:49 +0300 Subject: [PATCH 3/9] Update lora.py --- llms/mlx_lm/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index 45119025b..6edea28dc 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -177,7 +177,7 @@ def train_model( ): model.freeze() if args.fine_tune_type == "full": - for l in model.layers[-min(args.num_layers, 0) :]: + for l in model.layers[-max(args.num_layers, 0) :]: l.unfreeze() elif args.fine_tune_type in ["lora", "dora"]: # Convert linear layers to lora/dora layers and unfreeze in the process From a2b61afd05ac934c422978979394e46ec770f6f1 Mon Sep 17 00:00:00 2001 From: paNikitin <115797306+paNikitin@users.noreply.github.com> Date: Sun, 23 Feb 2025 12:48:25 +0300 Subject: [PATCH 4/9] update upstream --- llms/mlx_lm/lora.py | 14 ++++++++++++ llms/mlx_lm/tuner/trainer.py | 42 ++++++++++++++++++++++-------------- 2 files changed, 40 insertions(+), 16 deletions(-) diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index 6edea28dc..eb0a279e5 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -79,6 +79,7 @@ def 
build_parser(): "--train", action="store_true", help="Do training", + default=None, ) parser.add_argument( "--data", @@ -94,6 +95,12 @@ def build_parser(): choices=["lora", "dora", "full"], help="Type of fine-tuning to perform: lora, dora, or full.", ) + parser.add_argument( + "--mask-prompt", + action="store_true", + help="Mask the prompt in the loss when training", + default=False, + ) parser.add_argument( "--num-layers", type=int, @@ -136,6 +143,7 @@ def build_parser(): "--test", action="store_true", help="Evaluate on the test set after training", + default=None, ) parser.add_argument( "--test-batches", @@ -157,6 +165,7 @@ def build_parser(): "--grad-checkpoint", action="store_true", help="Use gradient checkpointing to reduce memory use.", + default=None, ) parser.add_argument("--seed", type=int, help="The PRNG seed") parser.add_argument( @@ -176,6 +185,11 @@ def train_model( training_callback: TrainingCallback = None, ): model.freeze() + if args.num_layers > len(model.layers): + raise ValueError( + f"Requested to train {args.num_layers} layers " + f"but the model only has {len(model.layers)} layers." + ) if args.fine_tune_type == "full": for l in model.layers[-max(args.num_layers, 0) :]: l.unfreeze() diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index 4f5db2ea7..24e93f92a 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -6,6 +6,7 @@ import time from dataclasses import dataclass, field from pathlib import Path +from typing import List, Optional, Tuple from typing import Union import mlx.core as mx @@ -13,7 +14,7 @@ import numpy as np from mlx.nn.utils import average_gradients from mlx.utils import tree_flatten - +from transformers import PreTrainedTokenizer from mlx_lm.tokenizer_utils import TokenizerWrapper @@ -70,14 +71,18 @@ class TrainingArgs: ) -def default_loss(model, inputs, targets, lengths): +def default_loss(model, batch, lengths): + inputs = batch[:, :-1] + targets = batch[:, 1:] + logits = model(inputs) logits = logits.astype(mx.float32) - length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None] + steps = mx.arange(1, targets.shape[1] + 1) + mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:]) - ce = nn.losses.cross_entropy(logits, targets) * length_mask - ntoks = length_mask.sum() + ce = nn.losses.cross_entropy(logits, targets) * mask + ntoks = mask.sum() ce = ce.sum() / ntoks return ce, ntoks @@ -162,6 +167,10 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False) indices = np.random.permutation(len(batch_idx)) for i in indices: batch = [dataset[j] for j in batch_idx[i]] + if len(batch[0]) == 2: + batch, offsets = zip(*batch) + else: + offsets = [0] * len(batch) lengths = [len(x) for x in batch] if max(lengths) > max_seq_length: print( @@ -185,7 +194,7 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False) ) batch = mx.array(batch_arr) - yield batch[:, :-1], batch[:, 1:], mx.array(lengths) + yield batch, mx.array(list(zip(offsets, lengths))) if not train: break @@ -201,8 +210,8 @@ def evaluate( loss: callable = default_loss, iterate_batches: callable = iterate_batches, ): - all_losses = 0 - ntokens = 0 + all_losses = mx.array(0.0) + ntokens = mx.array(0) index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1) @@ -283,6 +292,7 @@ def step(batch): n_tokens = 0 steps = 0 trained_tokens = 0 + train_time = 0 # Main training loop start = time.perf_counter() for it, batch in zip( @@ -295,10 +305,11 @@ def 
step(batch): train=True, ), ): + tic = time.perf_counter() # Report validation loss if needed, the first validation loss # is always measured before any training. if it == 1 or it % args.steps_per_eval == 0 or it == args.iters: - stop = time.perf_counter() + tic = time.perf_counter() val_loss = evaluate( model=model, dataset=val_dataset, @@ -309,7 +320,7 @@ def step(batch): max_seq_length=args.max_seq_length, iterate_batches=iterate_batches, ) - val_time = time.perf_counter() - stop + val_time = time.perf_counter() - tic if rank == 0: print( f"Iter {it}: " @@ -326,24 +337,23 @@ def step(batch): } training_callback.on_val_loss_report(val_info) - start = time.perf_counter() + tic = time.perf_counter() lvalue, toks = step(batch) losses += lvalue n_tokens += toks steps += 1 mx.eval(state, losses, n_tokens) + train_time += time.perf_counter() - tic # Report training loss if needed if it % args.steps_per_report == 0 or it == args.iters: - stop = time.perf_counter() - train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item() train_loss /= steps * mx.distributed.init().size() n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item() learning_rate = optimizer.learning_rate.item() - it_sec = args.steps_per_report / (stop - start) - tokens_sec = float(n_tokens) / (stop - start) + it_sec = args.steps_per_report / train_time + tokens_sec = float(n_tokens) / train_time trained_tokens += n_tokens peak_mem = mx.metal.get_peak_memory() / 1e9 if rank == 0: @@ -372,7 +382,7 @@ def step(batch): losses = 0 n_tokens = 0 steps = 0 - start = time.perf_counter() + train_time = 0 # Save adapter weights if it % args.steps_per_save == 0: From 5b7581f41ceb012d6a1a5b36f6ac5dfe5154580d Mon Sep 17 00:00:00 2001 From: paNikitin <115797306+paNikitin@users.noreply.github.com> Date: Sun, 23 Feb 2025 12:56:09 +0300 Subject: [PATCH 5/9] Update trainer.py --- llms/mlx_lm/tuner/trainer.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index 24e93f92a..cd8f513c6 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -69,6 +69,14 @@ class TrainingArgs: default=False, metadata={"help": "Use CoT loss masking with positioning penalty"}, ) + reasoning_token: str = field( + default="[REASONING]", + metadata={"help": "Reasoning token"}, + ) + data_token: str = field( + default="[DATA]", + metadata={"help": "Final answer token"}, + ) def default_loss(model, batch, lengths): @@ -88,25 +96,19 @@ def default_loss(model, batch, lengths): return ce, ntoks -@dataclass -class CotTrainingArgs: - cot: bool = False - reasoning_token: str = "[REASONING]" - data_token: str = "[DATA]" - - def cot_loss( model: nn.Module, inputs: mx.array, targets: mx.array, lengths: int, tokenizer: TokenizerWrapper, + args: TrainingArgs, penalty: mx.float32 = 10.0, ) -> tuple[mx.array, mx.array]: logits = model(inputs).astype(mx.float32) - reasoning_token_id = tokenizer.encode(CotTrainingArgs.reasoning_token)[0] - data_token_id = tokenizer.encode(CotTrainingArgs.data_token)[0] + reasoning_token_id = tokenizer.encode(args.reasoning_token)[0] + data_token_id = tokenizer.encode(args.data_token)[0] reasoning_positions = mx.argmax(targets == reasoning_token_id, axis=1) data_positions = mx.argmax(targets == data_token_id, axis=1) @@ -268,7 +270,7 @@ def train( grad_checkpoint(model.layers[0]) if args.cot: - loss = partial(cot_loss, tokenizer=tokenizer, penalty=10.0) + loss = partial(cot_loss, tokenizer=tokenizer, penalty=10.0, 
args=args) else: loss = default_loss From 231f5e870e3e00a08094b9bd73d738b53e80c414 Mon Sep 17 00:00:00 2001 From: paNikitin <115797306+paNikitin@users.noreply.github.com> Date: Sun, 23 Feb 2025 15:24:59 +0300 Subject: [PATCH 6/9] fixed full dequantization mem leak --- llms/mlx_lm/tuner/new_tokens.py | 207 ++++++++++++++++++-------------- 1 file changed, 120 insertions(+), 87 deletions(-) diff --git a/llms/mlx_lm/tuner/new_tokens.py b/llms/mlx_lm/tuner/new_tokens.py index 981d4a56c..dc7355409 100644 --- a/llms/mlx_lm/tuner/new_tokens.py +++ b/llms/mlx_lm/tuner/new_tokens.py @@ -5,20 +5,61 @@ def resize_embeddings(model: nn.Module, tokenizer: TokenizerWrapper) -> nn.Module: """ - Resizes model embeddings to accommodate new tokens + Resizes model embeddings to accommodate new tokens, minimizing dequantization. """ old_embedding = model.model.embed_tokens - old_vocab_size = old_embedding.num_embeddings new_vocab_size = len(tokenizer._tokenizer) - if old_vocab_size != new_vocab_size: - if new_vocab_size < old_vocab_size: - print( - "Warning: New vocab size is smaller than original. Proceeding with trim." + if old_vocab_size == new_vocab_size: + print("Vocab already sized right.") + return model + + if new_vocab_size < old_vocab_size: + print("Warning: New vocab size is smaller than original. Proceeding with trim.") + + if ( + hasattr(old_embedding, "weight") + and hasattr(old_embedding, "scales") + and hasattr(old_embedding, "biases") + and hasattr(old_embedding, "group_size") + and hasattr(old_embedding, "bits") + ): + # quantized embedding case: minimize dequantization + + new_embedding = nn.QuantizedEmbedding( + new_vocab_size, + old_embedding.dims, + group_size=old_embedding.group_size, + bits=old_embedding.bits, + ) + if new_vocab_size > old_vocab_size: + # Add new rows + new_row_count = new_vocab_size - old_vocab_size + new_rows = mx.random.normal((new_row_count, old_embedding.dims), scale=0.02) + new_rows_q, new_rows_scales, new_rows_biases = mx.quantize( + new_rows, old_embedding.group_size, old_embedding.bits + ) + + new_embedding.weight = mx.concatenate( + [old_embedding.weight, new_rows_q], axis=0 + ) + new_embedding.scales = mx.concatenate( + [old_embedding.scales, new_rows_scales], axis=0 + ) + new_embedding.biases = mx.concatenate( + [old_embedding.biases, new_rows_biases], axis=0 ) - # check if QuantizedEmbedding has required attributes for dequantization + else: # new_vocab_size < old_vocab_size: Slice existing + new_embedding.weight = old_embedding.weight[:new_vocab_size] + new_embedding.scales = old_embedding.scales[:new_vocab_size] + new_embedding.biases = old_embedding.biases[:new_vocab_size] + + else: + # non-quantized embedding case (fallback, less efficient) + # dequantize ONLY if necessary + # should ideally be avoided entirely for quantized models. try: dequantized_weights = mx.dequantize( old_embedding.weight, @@ -27,14 +68,13 @@ def resize_embeddings(model: nn.Module, tokenizer: TokenizerWrapper) -> nn.Modul group_size=old_embedding.group_size, bits=old_embedding.bits, ) - except AttributeError as e: - print(f"Error: Cannot dequantize embed_tokens. 
Missing attributes: {e}") + # handle missing quantization attributes + except (AttributeError, TypeError): print("Falling back to random weights for embed_tokens.") dequantized_weights = mx.random.normal( (old_vocab_size, old_embedding.dims), loc=0.0, scale=0.02 ) - # resize embed_tokens new_embedding = nn.Embedding(new_vocab_size, old_embedding.dims) new_weights = mx.zeros((new_vocab_size, old_embedding.dims)) min_vocab_size = min(old_vocab_size, new_vocab_size) @@ -46,88 +86,81 @@ def resize_embeddings(model: nn.Module, tokenizer: TokenizerWrapper) -> nn.Modul scale=0.02, ) new_embedding.weight = new_weights - model.model.embed_tokens = new_embedding - - # attention layers handling - if hasattr(model, "args") and getattr(model.args, "tie_word_embeddings", False): - model.model.embed_tokens.weight = new_weights - elif hasattr(model, "lm_head"): - old_lm_head = model.lm_head - if isinstance(old_lm_head, nn.QuantizedLinear): - # resize nn.QuantizedLinear - output_dims, compressed_input_dims = old_lm_head.weight.shape - bits = old_lm_head.bits - input_dims = compressed_input_dims * (32 // bits) - - # dequantize lm_head weights - try: - dequantized_lm_weights = mx.dequantize( - old_lm_head.weight, - scales=old_lm_head.scales, - biases=old_lm_head.biases, - group_size=old_lm_head.group_size, - bits=old_lm_head.bits, - ) - except AttributeError as e: - print(f"Error: Cannot dequantize lm_head. Missing attributes: {e}") - print("Falling back to random weights for lm_head.") - dequantized_lm_weights = mx.random.normal( - (output_dims, input_dims), loc=0.0, scale=0.02 - ) - new_lm_head = nn.QuantizedLinear( - input_dims=input_dims, - output_dims=new_vocab_size, - bias="bias" in old_lm_head, - group_size=old_lm_head.group_size, - bits=old_lm_head.bits, + model.model.embed_tokens = new_embedding + + # handle lm_head + if hasattr(model, "args") and getattr(model.args, "tie_word_embeddings", False): + if hasattr(new_embedding, "weight") and not isinstance( + new_embedding, nn.QuantizedEmbedding + ): + model.model.embed_tokens.weight = new_embedding.weight + + elif hasattr(model, "lm_head"): + old_lm_head = model.lm_head + if isinstance(old_lm_head, nn.QuantizedLinear): + output_dims, compressed_input_dims = old_lm_head.weight.shape + bits = old_lm_head.bits + input_dims = compressed_input_dims * (32 // bits) + group_size = old_lm_head.group_size + + new_lm_head = nn.QuantizedLinear( + input_dims=input_dims, + output_dims=new_vocab_size, + bias="bias" in old_lm_head, + group_size=group_size, + bits=bits, + ) + + if new_vocab_size > old_vocab_size: + new_row_count = new_vocab_size - old_vocab_size + new_rows = mx.random.normal((new_row_count, input_dims), scale=0.02) + new_rows_q, new_rows_scales, new_rows_biases = mx.quantize( + new_rows, group_size, bits ) - new_weights_lm = mx.zeros((new_vocab_size, input_dims)) - new_weights_lm[:min_vocab_size] = dequantized_lm_weights[ - :min_vocab_size - ] - if new_vocab_size > output_dims: - new_weights_lm[output_dims:] = mx.random.normal( - (new_vocab_size - output_dims, input_dims), loc=0.0, scale=0.02 - ) - new_lm_head.weight, new_lm_head.scales, new_lm_head.biases = ( - mx.quantize( - new_weights_lm, new_lm_head.group_size, new_lm_head.bits - ) + new_lm_head.weight = mx.concatenate( + [old_lm_head.weight, new_rows_q], axis=0 ) - if "bias" in old_lm_head: - new_lm_head.bias = mx.zeros((new_vocab_size,)) - new_lm_head.bias[:min_vocab_size] = old_lm_head.bias[ - :min_vocab_size - ] - else: - # resize nn.Linear - new_lm_head = nn.Linear( - old_lm_head.input_dims, 
new_vocab_size, bias="bias" in old_lm_head + new_lm_head.scales = mx.concatenate( + [old_lm_head.scales, new_rows_scales], axis=0 ) - new_weights_lm = mx.zeros((new_vocab_size, old_lm_head.input_dims)) - min_vocab_size = min(old_lm_head.weight.shape[0], new_vocab_size) - new_weights_lm[:min_vocab_size] = old_lm_head.weight[:min_vocab_size] - if new_vocab_size > old_lm_head.weight.shape[0]: - new_weights_lm[old_lm_head.weight.shape[0] :] = mx.random.normal( - ( - new_vocab_size - old_lm_head.weight.shape[0], - old_lm_head.input_dims, - ), - loc=0.0, - scale=0.02, + new_lm_head.biases = mx.concatenate( + [old_lm_head.biases, new_rows_biases], axis=0 + ) + else: + new_lm_head.weight = old_lm_head.weight[:new_vocab_size] + new_lm_head.scales = old_lm_head.scales[:new_vocab_size] + new_lm_head.biases = old_lm_head.biases[:new_vocab_size] + + if "bias" in old_lm_head: + if new_vocab_size > old_vocab_size: + new_bias = mx.concatenate( + [old_lm_head.bias, mx.zeros(new_vocab_size - old_vocab_size)] ) - new_lm_head.weight = new_weights_lm - # todo typechecking - if "bias" in old_lm_head: - new_lm_head.bias = mx.zeros((new_vocab_size,)) - new_lm_head.bias[:min_vocab_size] = old_lm_head.bias[ - :min_vocab_size - ] - - model.lm_head = new_lm_head - else: - print("Vocab already sized right.") + else: + new_bias = old_lm_head.bias[:new_vocab_size] + new_lm_head.bias = new_bias + # nn.Linear case + else: + new_lm_head = nn.Linear( + old_lm_head.input_dims, new_vocab_size, bias="bias" in old_lm_head + ) + new_weights_lm = mx.zeros((new_vocab_size, old_lm_head.input_dims)) + min_vocab_size = min(old_vocab_size, new_vocab_size) + new_weights_lm[:min_vocab_size] = old_lm_head.weight[:min_vocab_size] + if new_vocab_size > old_vocab_size: + new_weights_lm[old_vocab_size:] = mx.random.normal( + (new_vocab_size - old_vocab_size, old_lm_head.input_dims), + loc=0.0, + scale=0.02, + ) + new_lm_head.weight = new_weights_lm + if "bias" in old_lm_head: + new_lm_head.bias = mx.zeros((new_vocab_size,)) + new_lm_head.bias[:min_vocab_size] = old_lm_head.bias[:min_vocab_size] + + model.lm_head = new_lm_head + return model From e2ace6fb0f1411e28b82d0c07337a72a6233317e Mon Sep 17 00:00:00 2001 From: paNikitin <115797306+paNikitin@users.noreply.github.com> Date: Mon, 24 Feb 2025 09:12:31 +0300 Subject: [PATCH 7/9] Update trainer.py --- llms/mlx_lm/tuner/trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index cd8f513c6..04933fe70 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -111,7 +111,9 @@ def cot_loss( data_token_id = tokenizer.encode(args.data_token)[0] reasoning_positions = mx.argmax(targets == reasoning_token_id, axis=1) - data_positions = mx.argmax(targets == data_token_id, axis=1) + # find the LAST occurrence of data_token_id using slicing (in case generated dataset has multiple occurrences of [DATA]) + data_positions = mx.argmax(targets[:, ::-1] == data_token_id, axis=1) + data_positions = targets.shape[1] - 1 - data_positions seq_indices = mx.arange(targets.shape[1])[None, :] From b2ab37238ee194552e305309d38eaa126b537289 Mon Sep 17 00:00:00 2001 From: paNikitin <115797306+paNikitin@users.noreply.github.com> Date: Mon, 24 Feb 2025 12:21:30 +0300 Subject: [PATCH 8/9] added adapter additional tokens load on fuse --- llms/mlx_lm/fuse.py | 5 +++-- llms/mlx_lm/tuner/utils.py | 22 ++++++++++++++++++++-- llms/mlx_lm/utils.py | 4 ++++ 3 files changed, 27 insertions(+), 4 deletions(-) diff --git 
a/llms/mlx_lm/fuse.py b/llms/mlx_lm/fuse.py index b0c46a748..47abd717d 100644 --- a/llms/mlx_lm/fuse.py +++ b/llms/mlx_lm/fuse.py @@ -77,7 +77,8 @@ def main() -> None: model, config, tokenizer = fetch_from_hub(model_path) model.freeze() - model = load_adapters(model, args.adapter_path) + + model, tokenizer = load_adapters(model, tokenizer, args.adapter_path) fused_linears = [ (n, m.fuse()) for n, m in model.named_modules() if hasattr(m, "fuse") @@ -105,7 +106,7 @@ def main() -> None: if args.de_quantize: config.pop("quantization", None) - save_config(config, config_path=save_path / "config.json") + save_config(config, tokenizer, config_path=save_path / "config.json") if args.export_gguf: model_type = config["model_type"] diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index f5df11e32..6948df15c 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -8,6 +8,8 @@ import mlx.nn as nn import mlx.optimizers as opt from mlx.utils import tree_flatten, tree_unflatten +from ..tokenizer_utils import TokenizerWrapper + from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear from .dora import DoRAEmbedding, DoRALinear @@ -159,7 +161,9 @@ def to_lora(layer): model.update_modules(tree_unflatten(lora_modules)) -def load_adapters(model: nn.Module, adapter_path: str) -> nn.Module: +def load_adapters( + model: nn.Module, tokenizer: TokenizerWrapper, adapter_path: str +) -> nn.Module: """ Load any fine-tuned adapters / layers. @@ -184,7 +188,21 @@ def load_adapters(model: nn.Module, adapter_path: str) -> nn.Module: use_dora=(fine_tune_type == "dora"), ) model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False) - return model + if cot := config.cot: + print("Loading additional tokens") + if tokens := cot.get("additional_tokens"): + from .new_tokens import implement_new_tokens + + special = False + if (special_arg := cot.get("special")) and isinstance(special_arg, bool): + print("Updating model and tokenizer with new special tokens") + special = special_arg + else: + print("Updating model and tokenizer with new tokens") + model, tokenizer = implement_new_tokens( + model=model, tokenizer=tokenizer, tokens=tokens, special=special + ) + return model, tokenizer def dequantize(model: nn.Module) -> nn.Module: diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 1fae76fa7..b7bd1cba8 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -993,6 +993,7 @@ def _class_predicate(p, m): def save_config( config: dict, + tokenizer: TokenizerWrapper, config_path: Union[str, Path], ) -> None: """Save the model configuration to the ``config_path``. 
@@ -1009,6 +1010,9 @@ def save_config( # sort the config for better readability config = dict(sorted(config.items())) + if config["vocab_size"] != (cur := len(tokenizer._tokenizer)): + config["vocab_size"] = cur + print("Updated model`s config.json to match new tokenizer") # write the updated config to the config_path (if provided) with open(config_path, "w") as fid: json.dump(config, fid, indent=4) From 09f5add15149ccc27a67c680605da3113a4db0d7 Mon Sep 17 00:00:00 2001 From: paNikitin <115797306+paNikitin@users.noreply.github.com> Date: Mon, 24 Feb 2025 19:30:17 +0300 Subject: [PATCH 9/9] Update lora.py --- llms/mlx_lm/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index eb0a279e5..5f24da61b 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -215,7 +215,7 @@ def train_model( adapter_path.mkdir(parents=True, exist_ok=True) adapter_file = adapter_path / "adapters.safetensors" - save_config(vars(args), adapter_path / "adapter_config.json") + save_config(vars(args), tokenizer, adapter_path / "adapter_config.json") # init training args training_args = TrainingArgs(
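
A minimal usage sketch, assuming the commented cot keys from the lora_config.yaml example in PATCH 1/9 (key names beyond that example, and the exact CLI invocation, are not defined by these patches):

    # lora_config.yaml
    cot:
      use_cot: true
      special: true
      additional_tokens: ["[REASONING]", "[DATA]"]

    # then, e.g.:
    # python -m mlx_lm.lora --config lora_config.yaml --train

With a config like this, run() updates the tokenizer and resizes the embeddings and lm_head via implement_new_tokens, and the trainer swaps default_loss for cot_loss, which restricts the loss to target tokens from the [DATA] marker onward (within the unpadded length) and substitutes a fixed penalty of 10.0 for any sequence where [REASONING] does not precede [DATA] or either marker is missing. Note that the --cot CLI flag added in lora.py is parsed as a plain bool, so the dictionary form with additional_tokens is in practice only reachable through a config file.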