From 9e39c544eb7263518d88923592e87f718a23873a Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Fri, 31 Jan 2025 17:03:40 +0100 Subject: [PATCH 1/4] initial commit --- llms/mlx_lm/tuner/ppo_trainer.py | 327 +++++++++++++++++++++++++++++++ 1 file changed, 327 insertions(+) create mode 100644 llms/mlx_lm/tuner/ppo_trainer.py diff --git a/llms/mlx_lm/tuner/ppo_trainer.py b/llms/mlx_lm/tuner/ppo_trainer.py new file mode 100644 index 000000000..63ca58bb3 --- /dev/null +++ b/llms/mlx_lm/tuner/ppo_trainer.py @@ -0,0 +1,327 @@ +# Copyright © 2024 Apple Inc. + +import glob +import shutil +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Union + +import mlx.core as mx +import mlx.nn as nn +import numpy as np +from mlx.nn.utils import average_gradients +from mlx.utils import tree_flatten + + +def grad_checkpoint(layer): + """ + Update all instances of type(layer) to use gradient checkpointing. + """ + fn = type(layer).__call__ + + def checkpointed_fn(model, *args, **kwargs): + def inner_fn(params, *args, **kwargs): + model.update(params) + return fn(model, *args, **kwargs) + + return mx.checkpoint(inner_fn)(model.trainable_parameters(), *args, **kwargs) + + type(layer).__call__ = checkpointed_fn + + +@dataclass +class TrainingArgs: + batch_size: int = field(default=4, metadata={"help": "Minibatch size."}) + iters: int = field(default=100, metadata={"help": "Iterations to train for."}) + val_batches: int = field( + default=25, + metadata={ + "help": "Number of validation batches, -1 uses the entire validation set." + }, + ) + steps_per_report: int = field( + default=10, + metadata={"help": "Number of training steps between loss reporting."}, + ) + steps_per_eval: int = field( + default=200, metadata={"help": "Number of training steps between validations."} + ) + steps_per_save: int = field( + default=100, metadata={"help": "Save the model every number steps"} + ) + max_seq_length: int = field( + default=2048, metadata={"help": "Maximum sequence length."} + ) + adapter_file: str = field( + default="adapters.safetensors", + metadata={"help": "Save/load path for the trained adapter weights."}, + ) + grad_checkpoint: bool = field( + default=False, + metadata={"help": "Use gradient checkpointing to reduce memory use."}, + ) + + +def default_loss(model, inputs, targets, lengths): + logits = model(inputs) + logits = logits.astype(mx.float32) + + length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None] + + ce = nn.losses.cross_entropy(logits, targets) * length_mask + ntoks = length_mask.sum() + ce = ce.sum() / ntoks + + return ce, ntoks + + +def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False): + # Sort by length: + idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx])) + if len(dataset) < batch_size: + raise ValueError( + f"Dataset must have at least batch_size={batch_size}" + f" examples but only has {len(dataset)}." 
+ ) + + # If running in distributed mode (N machines) then each one should skip N-1 + # samples + step = mx.distributed.init().size() + if batch_size % step != 0: + raise ValueError("The batch size must be divisible by the number of workers") + + # Make the batches: + batch_idx = [ + idx[i : i + batch_size : step] + for i in range(0, len(idx) - batch_size + 1, batch_size) + ] + + while True: + indices = np.random.permutation(len(batch_idx)) + for i in indices: + batch = [dataset[j] for j in batch_idx[i]] + lengths = [len(x) for x in batch] + if max(lengths) > max_seq_length: + print( + f"[WARNING] Some sequences are longer than {max_seq_length} tokens. " + f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. " + "Consider pre-splitting your data to save memory." + ) + + # Pad to the nearest multiple of 8 or the maximum length + pad_to = 8 + max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to) + max_length_in_batch = min(max_length_in_batch, max_seq_length) + + batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32) + + for j in range(batch_size // step): + truncated_length = min(lengths[j], max_seq_length) + batch_arr[j, :truncated_length] = batch[j][:truncated_length] + lengths[j] = ( + truncated_length # Update lengths to match truncated lengths + ) + batch = mx.array(batch_arr) + + yield batch[:, :-1], batch[:, 1:], mx.array(lengths) + + if not train: + break + + +def evaluate( + model, + dataset, + tokenizer, + batch_size, + num_batches, + max_seq_length=2048, + loss: callable = default_loss, + iterate_batches: callable = iterate_batches, +): + all_losses = 0 + ntokens = 0 + + index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1) + + for _, batch in zip( + index_iterator, + iterate_batches( + dataset=dataset, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_length=max_seq_length, + ), + ): + losses, toks = loss(model, *batch) + all_losses += losses * toks + ntokens += toks + mx.eval(all_losses, ntokens) + + all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu) + ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu) + + return (all_losses / ntokens).item() + + +class TrainingCallback: + + def on_train_loss_report(self, train_info: dict): + """Called to report training loss at specified intervals.""" + pass + + def on_val_loss_report(self, val_info: dict): + """Called to report validation loss at specified intervals or the beginning.""" + pass + + +def train( + model, + tokenizer, + optimizer, + train_dataset, + val_dataset, + args: TrainingArgs = TrainingArgs(), + loss: callable = default_loss, + iterate_batches: callable = iterate_batches, + training_callback: TrainingCallback = None, +): + print(f"Starting training..., iters: {args.iters}") + world = mx.distributed.init() + world_size = world.size() + rank = world.rank() + if world_size > 1: + print(f"Node {rank} of {world_size}") + + if args.grad_checkpoint: + grad_checkpoint(model.layers[0]) + + state = [model.state, optimizer.state] + + def step(batch): + # Forward and backward pass + (lvalue, toks), grad = loss_value_and_grad(model, *batch) + + # All reduce the gradients if running in distributed mode + grad = average_gradients(grad) + + # Model update + optimizer.update(model, grad) + + return lvalue, toks + + loss_value_and_grad = nn.value_and_grad(model, loss) + + losses = 0 + n_tokens = 0 + steps = 0 + trained_tokens = 0 + # Main training loop + start = time.perf_counter() + for it, batch in zip( + range(1, args.iters + 
1), + iterate_batches( + dataset=train_dataset, + tokenizer=tokenizer, + batch_size=args.batch_size, + max_seq_length=args.max_seq_length, + train=True, + ), + ): + # Report validation loss if needed, the first validation loss + # is always measured before any training. + if it == 1 or it % args.steps_per_eval == 0 or it == args.iters: + stop = time.perf_counter() + val_loss = evaluate( + model=model, + dataset=val_dataset, + loss=loss, + tokenizer=tokenizer, + batch_size=args.batch_size, + num_batches=args.val_batches, + max_seq_length=args.max_seq_length, + iterate_batches=iterate_batches, + ) + val_time = time.perf_counter() - stop + if rank == 0: + print( + f"Iter {it}: " + f"Val loss {val_loss:.3f}, " + f"Val took {val_time:.3f}s", + flush=True, + ) + + if training_callback is not None: + val_info = { + "iteration": it, + "val_loss": val_loss, + "val_time": val_time, + } + training_callback.on_val_loss_report(val_info) + + start = time.perf_counter() + + lvalue, toks = step(batch) + losses += lvalue + n_tokens += toks + steps += 1 + mx.eval(state, losses, n_tokens) + + # Report training loss if needed + if it % args.steps_per_report == 0 or it == args.iters: + stop = time.perf_counter() + + train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item() + train_loss /= steps * mx.distributed.init().size() + n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item() + learning_rate = optimizer.learning_rate.item() + it_sec = args.steps_per_report / (stop - start) + tokens_sec = float(n_tokens) / (stop - start) + trained_tokens += n_tokens + peak_mem = mx.metal.get_peak_memory() / 1e9 + if rank == 0: + print( + f"Iter {it}: Train loss {train_loss:.3f}, " + f"Learning Rate {learning_rate:.3e}, " + f"It/sec {it_sec:.3f}, " + f"Tokens/sec {tokens_sec:.3f}, " + f"Trained Tokens {trained_tokens}, " + f"Peak mem {peak_mem:.3f} GB", + flush=True, + ) + + if training_callback is not None: + train_info = { + "iteration": it, + "train_loss": train_loss, + "learning_rate": learning_rate, + "iterations_per_second": it_sec, + "tokens_per_second": tokens_sec, + "trained_tokens": trained_tokens, + "peak_memory": peak_mem, + } + training_callback.on_train_loss_report(train_info) + + losses = 0 + n_tokens = 0 + steps = 0 + start = time.perf_counter() + + # Save adapter weights + if it % args.steps_per_save == 0: + adapter_weights = dict(tree_flatten(model.trainable_parameters())) + mx.save_safetensors(str(args.adapter_file), adapter_weights) + checkpoint = ( + Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors" + ) + mx.save_safetensors(str(checkpoint), adapter_weights) + print( + f"Iter {it}: Saved adapter weights to " + f"{args.adapter_file} and {checkpoint}." 
+ ) + + # Save final weights + adapter_weights = dict(tree_flatten(model.trainable_parameters())) + mx.save_safetensors(str(args.adapter_file), adapter_weights) + print(f"Saved final weights to {args.adapter_file}.") From 595125ad4ebe9a0c4f08bf092f8c3a783dc464ff Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Fri, 31 Jan 2025 17:19:05 +0100 Subject: [PATCH 2/4] updates --- llms/mlx_lm/tuner/ppo_trainer.py | 235 ++++++++++++++++++------------- 1 file changed, 137 insertions(+), 98 deletions(-) diff --git a/llms/mlx_lm/tuner/ppo_trainer.py b/llms/mlx_lm/tuner/ppo_trainer.py index 63ca58bb3..40dffe630 100644 --- a/llms/mlx_lm/tuner/ppo_trainer.py +++ b/llms/mlx_lm/tuner/ppo_trainer.py @@ -13,67 +13,92 @@ from mlx.nn.utils import average_gradients from mlx.utils import tree_flatten - -def grad_checkpoint(layer): - """ - Update all instances of type(layer) to use gradient checkpointing. - """ - fn = type(layer).__call__ - - def checkpointed_fn(model, *args, **kwargs): - def inner_fn(params, *args, **kwargs): - model.update(params) - return fn(model, *args, **kwargs) - - return mx.checkpoint(inner_fn)(model.trainable_parameters(), *args, **kwargs) - - type(layer).__call__ = checkpointed_fn - - -@dataclass -class TrainingArgs: - batch_size: int = field(default=4, metadata={"help": "Minibatch size."}) - iters: int = field(default=100, metadata={"help": "Iterations to train for."}) - val_batches: int = field( - default=25, - metadata={ - "help": "Number of validation batches, -1 uses the entire validation set." - }, - ) - steps_per_report: int = field( - default=10, - metadata={"help": "Number of training steps between loss reporting."}, +from trainer import TrainingArgs, TrainingCallback, grad_checkpoint + + + +def compute_ppo_loss( + new_logprobs: mx.array, + old_logprobs: mx.array, + values: mx.array, + old_values: mx.array, + advantages: mx.array, + returns: mx.array, + padding_mask: mx.array, + padding_mask_p1: mx.array = None, + vf_coef: float = 0.5, + cliprange: float = 0.2, + cliprange_value: float = 0.2 +) -> tuple[mx.array, mx.array, mx.array]: + """Compute PPO loss with policy and value components and masking""" + padding_mask_p1 = padding_mask_p1 if padding_mask_p1 is not None else padding_mask + + # Value loss + vpred_clipped = mx.clip(values, old_values - cliprange_value, old_values + cliprange_value) + vf_losses = mx.maximum( + mx.square(values - returns), + mx.square(vpred_clipped - returns) ) - steps_per_eval: int = field( - default=200, metadata={"help": "Number of training steps between validations."} - ) - steps_per_save: int = field( - default=100, metadata={"help": "Save the model every number steps"} - ) - max_seq_length: int = field( - default=2048, metadata={"help": "Maximum sequence length."} - ) - adapter_file: str = field( - default="adapters.safetensors", - metadata={"help": "Save/load path for the trained adapter weights."}, - ) - grad_checkpoint: bool = field( - default=False, - metadata={"help": "Use gradient checkpointing to reduce memory use."}, + vf_loss = 0.5 * mx.mean(mx.where(~padding_mask_p1, vf_losses, 0)) + + # Policy loss + ratio = mx.exp(new_logprobs - old_logprobs) + pg_losses = mx.maximum( + -advantages * ratio, + -advantages * mx.clip(ratio, 1.0 - cliprange, 1.0 + cliprange) ) + pg_loss = mx.mean(mx.where(~padding_mask, pg_losses, 0)) + + total_loss = pg_loss + vf_coef * vf_loss + return total_loss, pg_loss, vf_loss -def default_loss(model, inputs, targets, lengths): - logits = model(inputs) - logits = logits.astype(mx.float32) - - length_mask = 
mx.arange(inputs.shape[1])[None, :] < lengths[:, None] +@dataclass +class PPOTrainingArgs(TrainingArgs): + vf_coef: float = field(default=0.5, metadata={"help": "Value function coefficient"}) + cliprange: float = field(default=0.2, metadata={"help": "Policy gradient clipping range"}) + cliprange_value: float = field(default=0.2, metadata={"help": "Value function clipping range"}) - ce = nn.losses.cross_entropy(logits, targets) * length_mask - ntoks = length_mask.sum() - ce = ce.sum() / ntoks - return ce, ntoks +def ppo_loss( + model, + inputs, + targets, + lengths, + old_logprobs, + values, + old_values, + advantages, + returns, + vf_coef=0.5, + cliprange=0.2, + cliprange_value=0.2 +): + # Get new logits and create length mask + logits = model(inputs).astype(mx.float32) + length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None] + + # Get new log probs + new_logprobs = nn.losses.cross_entropy(logits, targets) * length_mask + ntoks = length_mask.sum() + new_logprobs = new_logprobs.sum() / ntoks + + # Value loss with clipping + vpred_clipped = mx.clip(values, old_values - cliprange_value, old_values + cliprange_value) + vf_loss = 0.5 * mx.maximum( + mx.square(values - returns), + mx.square(vpred_clipped - returns) + ).mean() + + # Policy loss with clipping + ratio = mx.exp(new_logprobs - old_logprobs) + pg_loss = mx.maximum( + -advantages * ratio, + -advantages * mx.clip(ratio, 1.0 - cliprange, 1.0 + cliprange) + ).mean() + + total_loss = pg_loss + vf_coef * vf_loss + return total_loss, pg_loss, vf_loss, ntoks def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False): @@ -131,49 +156,63 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False) def evaluate( - model, - dataset, - tokenizer, - batch_size, - num_batches, - max_seq_length=2048, - loss: callable = default_loss, - iterate_batches: callable = iterate_batches, + model, + dataset, + tokenizer, + batch_size, + num_batches, + max_seq_length=2048, + old_logprobs=None, + values=None, + old_values=None, + advantages=None, + returns=None, + vf_coef=0.5, + cliprange=0.2, + cliprange_value=0.2, + loss: callable = compute_ppo_loss, + iterate_batches: callable = iterate_batches, ): - all_losses = 0 - ntokens = 0 - - index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1) - - for _, batch in zip( - index_iterator, - iterate_batches( - dataset=dataset, - tokenizer=tokenizer, - batch_size=batch_size, - max_seq_length=max_seq_length, - ), - ): - losses, toks = loss(model, *batch) - all_losses += losses * toks - ntokens += toks - mx.eval(all_losses, ntokens) - - all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu) - ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu) - - return (all_losses / ntokens).item() - - -class TrainingCallback: - - def on_train_loss_report(self, train_info: dict): - """Called to report training loss at specified intervals.""" - pass - - def on_val_loss_report(self, val_info: dict): - """Called to report validation loss at specified intervals or the beginning.""" - pass + total_loss = 0 + total_pg_loss = 0 + total_vf_loss = 0 + ntokens = 0 + + index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1) + + for _, batch in zip( + index_iterator, + iterate_batches( + dataset=dataset, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_length=max_seq_length, + ), + ): + losses, pg_loss, vf_loss, toks = loss( + model, *batch, + old_logprobs=old_logprobs, + values=values, + old_values=old_values, + 
advantages=advantages, + returns=returns, + vf_coef=vf_coef, + cliprange=cliprange, + cliprange_value=cliprange_value + ) + + total_loss += losses * toks + total_pg_loss += pg_loss * toks + total_vf_loss += vf_loss * toks + ntokens += toks + mx.eval(total_loss, total_pg_loss, total_vf_loss, ntokens) + + total_loss = mx.distributed.all_sum(total_loss, stream=mx.cpu) + total_pg_loss = mx.distributed.all_sum(total_pg_loss, stream=mx.cpu) + total_vf_loss = mx.distributed.all_sum(total_vf_loss, stream=mx.cpu) + ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu) + + return (total_loss / ntokens).item(), (total_pg_loss / ntokens).item(), (total_vf_loss / ntokens).item() def train( @@ -183,7 +222,7 @@ def train( train_dataset, val_dataset, args: TrainingArgs = TrainingArgs(), - loss: callable = default_loss, + loss: callable = ppo_loss, iterate_batches: callable = iterate_batches, training_callback: TrainingCallback = None, ): From aa7a11c7535741b5167439ffd5345722539203ad Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Fri, 31 Jan 2025 17:38:01 +0100 Subject: [PATCH 3/4] updates --- llms/mlx_lm/tuner/ppo_trainer.py | 264 +++++++++++++++++-------------- 1 file changed, 148 insertions(+), 116 deletions(-) diff --git a/llms/mlx_lm/tuner/ppo_trainer.py b/llms/mlx_lm/tuner/ppo_trainer.py index 40dffe630..9c5f0a902 100644 --- a/llms/mlx_lm/tuner/ppo_trainer.py +++ b/llms/mlx_lm/tuner/ppo_trainer.py @@ -16,55 +16,77 @@ from trainer import TrainingArgs, TrainingCallback, grad_checkpoint - -def compute_ppo_loss( - new_logprobs: mx.array, - old_logprobs: mx.array, - values: mx.array, - old_values: mx.array, - advantages: mx.array, - returns: mx.array, - padding_mask: mx.array, - padding_mask_p1: mx.array = None, - vf_coef: float = 0.5, - cliprange: float = 0.2, - cliprange_value: float = 0.2 -) -> tuple[mx.array, mx.array, mx.array]: - """Compute PPO loss with policy and value components and masking""" - padding_mask_p1 = padding_mask_p1 if padding_mask_p1 is not None else padding_mask - - # Value loss - vpred_clipped = mx.clip(values, old_values - cliprange_value, old_values + cliprange_value) - vf_losses = mx.maximum( - mx.square(values - returns), - mx.square(vpred_clipped - returns) - ) - vf_loss = 0.5 * mx.mean(mx.where(~padding_mask_p1, vf_losses, 0)) - - # Policy loss - ratio = mx.exp(new_logprobs - old_logprobs) - pg_losses = mx.maximum( - -advantages * ratio, - -advantages * mx.clip(ratio, 1.0 - cliprange, 1.0 + cliprange) - ) - pg_loss = mx.mean(mx.where(~padding_mask, pg_losses, 0)) - - total_loss = pg_loss + vf_coef * vf_loss - return total_loss, pg_loss, vf_loss - - @dataclass class PPOTrainingArgs(TrainingArgs): vf_coef: float = field(default=0.5, metadata={"help": "Value function coefficient"}) cliprange: float = field(default=0.2, metadata={"help": "Policy gradient clipping range"}) cliprange_value: float = field(default=0.2, metadata={"help": "Value function clipping range"}) + gamma: float = field(default=0.99, metadata={"help": "Discount factor"}) + lambda_: float = field(default=0.95, metadata={"help": "GAE lambda"}) +def compute_returns( + rewards: mx.array, + gamma: float = 0.99 +) -> mx.array: + """Compute returns with Generalized Advantage Estimation""" + returns = mx.zeros_like(rewards) + running_return = 0 + + for t in reversed(range(len(rewards))): + running_return = rewards[t] + gamma * running_return + returns = returns.at[t].set(running_return) + + return returns + +def compute_advantages( + values: mx.array, + returns: mx.array, + rewards: mx.array, + gamma: float 
= 0.99, + lambda_: float = 0.95 +) -> mx.array: + """Compute advantages using GAE""" + advantages = mx.zeros_like(returns) + running_advantage = 0 + + for t in reversed(range(len(returns))): + if t < len(returns) - 1: + delta = rewards[t] + gamma * values[t + 1] - values[t] + else: + delta = rewards[t] - values[t] + + running_advantage = delta + gamma * lambda_ * running_advantage + advantages = advantages.at[t].set(running_advantage) + + return (advantages - advantages.mean()) / (advantages.std() + 1e-8) + +def make_predictions(model, x, mask): + inputs = x[:, :-1] + targets = x[:, 1:] + + logits = model(inputs) + logits = logits.astype(mx.float32) + + return -nn.losses.cross_entropy(logits, targets) * mask[:, :-1] + +def compute_rewards(model, x, mask, reward_scale=1.0): + """ + Compute rewards based on model predictions and actual targets. + Basic implementation using log probabilities as rewards. + """ + logits = model(x[:, :-1]) + targets = x[:, 1:] + + log_probs = -nn.losses.cross_entropy(logits, targets, reduction='none') + rewards = log_probs * mask[:, :-1] * reward_scale + + return rewards + def ppo_loss( model, inputs, - targets, - lengths, + mask, old_logprobs, values, old_values, @@ -73,14 +95,10 @@ def ppo_loss( vf_coef=0.5, cliprange=0.2, cliprange_value=0.2 -): - # Get new logits and create length mask - logits = model(inputs).astype(mx.float32) - length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None] - +): # Get new log probs - new_logprobs = nn.losses.cross_entropy(logits, targets) * length_mask - ntoks = length_mask.sum() + new_logprobs = make_predictions(model, inputs, mask) + ntoks = mask[:, :-1].sum() new_logprobs = new_logprobs.sum() / ntoks # Value loss with clipping @@ -101,58 +119,52 @@ def ppo_loss( return total_loss, pg_loss, vf_loss, ntoks -def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False): - # Sort by length: - idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx])) - if len(dataset) < batch_size: - raise ValueError( - f"Dataset must have at least batch_size={batch_size}" - f" examples but only has {len(dataset)}." - ) - - # If running in distributed mode (N machines) then each one should skip N-1 - # samples - step = mx.distributed.init().size() - if batch_size % step != 0: - raise ValueError("The batch size must be divisible by the number of workers") - - # Make the batches: - batch_idx = [ - idx[i : i + batch_size : step] - for i in range(0, len(idx) - batch_size + 1, batch_size) - ] - - while True: - indices = np.random.permutation(len(batch_idx)) - for i in indices: - batch = [dataset[j] for j in batch_idx[i]] - lengths = [len(x) for x in batch] - if max(lengths) > max_seq_length: - print( - f"[WARNING] Some sequences are longer than {max_seq_length} tokens. " - f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. " - "Consider pre-splitting your data to save memory." 
- ) - - # Pad to the nearest multiple of 8 or the maximum length - pad_to = 8 - max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to) - max_length_in_batch = min(max_length_in_batch, max_seq_length) - - batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32) - - for j in range(batch_size // step): - truncated_length = min(lengths[j], max_seq_length) - batch_arr[j, :truncated_length] = batch[j][:truncated_length] - lengths[j] = ( - truncated_length # Update lengths to match truncated lengths - ) - batch = mx.array(batch_arr) - - yield batch[:, :-1], batch[:, 1:], mx.array(lengths) - - if not train: - break +def iterate_ppo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False): + # Sort by length + idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx])) + if len(dataset) < batch_size: + raise ValueError(f"Dataset must have at least batch_size={batch_size} examples but only has {len(dataset)}.") + + # Handle distributed training + step = mx.distributed.init().size() + if batch_size % step != 0: + raise ValueError("The batch size must be divisible by the number of workers") + + # Make batches + batch_idx = [idx[i:i+batch_size:step] for i in range(0, len(idx)-batch_size+1, batch_size)] + + while True: + indices = np.random.permutation(len(batch_idx)) + for i in indices: + batch = [dataset[j] for j in batch_idx[i]] + lengths = [len(x) for x in batch] + + # Handle sequence length + if max(lengths) > max_seq_length: + print(f"[WARNING] Truncating sequences longer than {max_seq_length}") + + # Pad to multiple of 8 + pad_to = 8 + max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to) + max_length_in_batch = min(max_length_in_batch, max_seq_length) + + # Create batch array + batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32) + mask = np.zeros((batch_size // step, max_length_in_batch), np.int32) + + for j in range(batch_size // step): + truncated_length = min(lengths[j], max_seq_length) + batch_arr[j, :truncated_length] = batch[j][:truncated_length] + mask[j, :truncated_length] = 1 + lengths[j] = truncated_length + + batch = mx.array(batch_arr) + mask = mx.array(mask) + + yield batch, mask + + if not train: + break def evaluate( @@ -170,8 +182,8 @@ def evaluate( vf_coef=0.5, cliprange=0.2, cliprange_value=0.2, - loss: callable = compute_ppo_loss, - iterate_batches: callable = iterate_batches, + loss: callable = ppo_loss, + iterate_ppo_batches: callable = iterate_ppo_batches, ): total_loss = 0 total_pg_loss = 0 @@ -182,7 +194,7 @@ def evaluate( for _, batch in zip( index_iterator, - iterate_batches( + iterate_ppo_batches( dataset=dataset, tokenizer=tokenizer, batch_size=batch_size, @@ -221,12 +233,12 @@ def train( optimizer, train_dataset, val_dataset, - args: TrainingArgs = TrainingArgs(), + args: PPOTrainingArgs = PPOTrainingArgs(), loss: callable = ppo_loss, - iterate_batches: callable = iterate_batches, + iterate_ppo_batches: callable = iterate_ppo_batches, training_callback: TrainingCallback = None, ): - print(f"Starting training..., iters: {args.iters}") + print(f"Starting PPO training..., iters: {args.iters}") world = mx.distributed.init() world_size = world.size() rank = world.rank() @@ -239,18 +251,38 @@ def train( state = [model.state, optimizer.state] def step(batch): - # Forward and backward pass - (lvalue, toks), grad = loss_value_and_grad(model, *batch) - - # All reduce the gradients if running in distributed mode + x, mask = batch + + # Initial forward pass + old_logprobs = 
make_predictions(model, x, mask) + values = model.value_head(x[:, :-1]) + old_values = values.copy() + + # Compute rewards (implement reward calculation based on your task) + rewards = compute_rewards(model, x, mask) + + # Compute returns and advantages + returns = compute_returns(rewards, values, gamma=args.gamma) + advantages = compute_advantages(values, returns, rewards, + gamma=args.gamma, + lambda_=args.lambda_) + + def loss_fn(model, x, mask): + total_loss, pg_loss, vf_loss, ntoks = ppo_loss( + model, x, mask, + old_logprobs, values, old_values, + advantages, returns, + vf_coef=args.vf_coef, + cliprange=args.cliprange, + cliprange_value=args.cliprange_value + ) + return total_loss, ntoks, pg_loss, vf_loss + + (loss_val, toks, pg_loss, vf_loss), grad = nn.value_and_grad(model, loss_fn)(x, mask) grad = average_gradients(grad) - - # Model update optimizer.update(model, grad) - - return lvalue, toks - - loss_value_and_grad = nn.value_and_grad(model, loss) + + return loss_val, toks, pg_loss, vf_loss losses = 0 n_tokens = 0 @@ -260,7 +292,7 @@ def step(batch): start = time.perf_counter() for it, batch in zip( range(1, args.iters + 1), - iterate_batches( + iterate_ppo_batches( dataset=train_dataset, tokenizer=tokenizer, batch_size=args.batch_size, @@ -280,7 +312,7 @@ def step(batch): batch_size=args.batch_size, num_batches=args.val_batches, max_seq_length=args.max_seq_length, - iterate_batches=iterate_batches, + iterate_ppo_batches=iterate_ppo_batches, ) val_time = time.perf_counter() - stop if rank == 0: From a7e414687e89d0c457ae0a555053593cf4aa5d37 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Tue, 4 Feb 2025 10:45:23 +0100 Subject: [PATCH 4/4] update create_dataset --- llms/mlx_lm/tuner/datasets.py | 36 +++++++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index 377e7cae1..f753b3f6f 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -102,8 +102,35 @@ def create_dataset( "https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data." ) +def create_dataset( + args, + data, + tokenizer: PreTrainedTokenizer, + prompt_feature: Optional[str] = None, + completion_feature: Optional[str] = None, +): + prompt_feature = prompt_feature or "prompt" + completion_feature = completion_feature or "completion" + sample = data[0] + + if args.training_mode == "normal": + if "messages" in sample: + return ChatDataset(data, tokenizer) + elif prompt_feature in sample and completion_feature in sample: + return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature) + elif "text" in sample: + return Dataset(data, tokenizer) + else: + raise ValueError( + "Unsupported data format, check the supported formats here:\n" + "https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data." 
+ ) + else: + return "" + def load_local_dataset( + args, data_path: Path, tokenizer: PreTrainedTokenizer, prompt_feature: Optional[str] = None, @@ -114,7 +141,7 @@ def load_subset(path): return [] with open(path, "r") as fid: data = [json.loads(l) for l in fid] - return create_dataset(data, tokenizer, prompt_feature, completion_feature) + return create_dataset(args, data, tokenizer, prompt_feature, completion_feature) names = ("train", "valid", "test") train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names] @@ -122,6 +149,7 @@ def load_subset(path): def load_hf_dataset( + args, data_id: str, tokenizer: PreTrainedTokenizer, prompt_feature: Optional[str] = None, @@ -137,7 +165,7 @@ def load_hf_dataset( train, valid, test = [ ( create_dataset( - dataset[n], tokenizer, prompt_feature, completion_feature + args, dataset[n], tokenizer, prompt_feature, completion_feature ) if n in dataset.keys() else [] @@ -202,12 +230,12 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer): completion_feature = getattr(args, "completion_feature", None) if data_path.exists(): train, valid, test = load_local_dataset( - data_path, tokenizer, prompt_feature, completion_feature + args, data_path, tokenizer, prompt_feature, completion_feature ) else: print(f"Loading Hugging Face dataset {args.data}.") train, valid, test = load_hf_dataset( - args.data, tokenizer, prompt_feature, completion_feature + args, args.data, tokenizer, prompt_feature, completion_feature ) if args.train and len(train) == 0: