Cot loss masking #1298

Open · wants to merge 9 commits into base: main
6 changes: 6 additions & 0 deletions llms/mlx_lm/examples/lora_config.yaml
@@ -64,6 +64,12 @@ lora_parameters:
scale: 20.0
dropout: 0.0

# CoT loss masking training
# cot:
# use_cot: true
# special: true
# additional_tokens: ["[REASONING]", "[DATA]"]

# Schedule can only be specified in a config file, uncomment to use.
#lr_schedule:
# name: cosine_decay
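
For context, here is a minimal sketch (not part of the PR) of how a `cot:` block like the commented example above is consumed once the YAML file is merged into the training arguments; it mirrors the `run()` changes further down in this diff, and the config path is just a placeholder:

```python
import yaml

# placeholder path; the cot block matches the commented example above
with open("lora_config.yaml") as f:
    config = yaml.safe_load(f)

cot = config.get("cot", False)  # CONFIG_DEFAULTS falls back to False when the block is absent
if cot:
    tokens = cot.get("additional_tokens", [])  # e.g. ["[REASONING]", "[DATA]"]
    special = bool(cot.get("special", False))  # add them as special tokens if True
    print(f"CoT loss masking enabled; extra tokens: {tokens} (special={special})")
```
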
5 changes: 3 additions & 2 deletions llms/mlx_lm/fuse.py
@@ -77,7 +77,8 @@ def main() -> None:
model, config, tokenizer = fetch_from_hub(model_path)

model.freeze()
model = load_adapters(model, args.adapter_path)

model, tokenizer = load_adapters(model, tokenizer, args.adapter_path)

fused_linears = [
(n, m.fuse()) for n, m in model.named_modules() if hasattr(m, "fuse")
@@ -105,7 +106,7 @@ def main() -> None:
if args.de_quantize:
config.pop("quantization", None)

save_config(config, config_path=save_path / "config.json")
save_config(config, tokenizer, config_path=save_path / "config.json")

if args.export_gguf:
model_type = config["model_type"]
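
Passing the tokenizer through `load_adapters` and `save_config` appears intended to carry the updated vocabulary (the added CoT tokens) through to the fused model. A quick sanity check one might run on the fused output, assuming it was saved to a local `fused_model` directory (path and token strings are just the example values used elsewhere in this PR):

```python
from mlx_lm import load

model, tokenizer = load("fused_model")  # placeholder path for the fused output

for tok in ["[REASONING]", "[DATA]"]:
    ids = tokenizer.encode(tok, add_special_tokens=False)
    print(tok, "->", ids)  # each added token should now map to a single id
```
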
32 changes: 27 additions & 5 deletions llms/mlx_lm/lora.py
@@ -62,6 +62,7 @@
"grad_checkpoint": False,
"lr_schedule": None,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
"cot": False,
}


@@ -94,14 +95,12 @@ def build_parser():
choices=["lora", "dora", "full"],
help="Type of fine-tuning to perform: lora, dora, or full.",
)

parser.add_argument(
"--mask-prompt",
action="store_true",
help="Mask the prompt in the loss when training",
default=False,
)

parser.add_argument(
"--num-layers",
type=int,
@@ -169,6 +168,11 @@ def build_parser():
default=None,
)
parser.add_argument("--seed", type=int, help="The PRNG seed")
parser.add_argument(
"--cot",
type=bool,
help="Use CoT loss masking",
)
return parser


@@ -186,7 +190,6 @@ def train_model(
f"Requested to train {args.num_layers} layers "
f"but the model only has {len(model.layers)} layers."
)

if args.fine_tune_type == "full":
for l in model.layers[-max(args.num_layers, 0) :]:
l.unfreeze()
@@ -212,7 +215,7 @@ def train_model(
adapter_path.mkdir(parents=True, exist_ok=True)

adapter_file = adapter_path / "adapters.safetensors"
save_config(vars(args), adapter_path / "adapter_config.json")
save_config(vars(args), tokenizer, adapter_path / "adapter_config.json")

# init training args
training_args = TrainingArgs(
@@ -225,10 +228,13 @@
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint,
cot=(cot := args.cot),
)

model.train()
opt = optim.Adam(
# todo optimizer from args

opt = optim.AdamW(
learning_rate=(
build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
)
@@ -269,6 +275,21 @@ def run(args, training_callback: TrainingCallback = None):
print("Loading pretrained model")
model, tokenizer = load(args.model)

if cot := args.cot:
print("Using CoT loss masking")
if tokens := cot.get("additional_tokens"):
from .tuner.new_tokens import implement_new_tokens

special = False
if (special_arg := cot.get("special")) and isinstance(special_arg, bool):
print("Updating model and tokenizer with new special tokens")
special = special_arg
else:
print("Updating model and tokenizer with new tokens")
model, tokenizer = implement_new_tokens(
model=model, tokenizer=tokenizer, tokens=tokens, special=special
)

print("Loading datasets")
train_set, valid_set, test_set = load_dataset(args, tokenizer)

@@ -293,6 +314,7 @@ def main():
parser = build_parser()
args = parser.parse_args()
config = args.config

args = vars(args)
if config:
print("Loading configuration file", config)
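
The actual loss masking is driven by the new `cot` field on `TrainingArgs` and lives in the trainer, which is not shown in this part of the diff. Purely as an illustration of the idea (not the PR's implementation), a mask that zeroes out the loss for tokens inside a [REASONING] ... [DATA] span could look like this:

```python
import mlx.core as mx

def cot_loss_mask(tokens: mx.array, start_id: int, end_id: int) -> mx.array:
    """Illustrative sketch: weight 0.0 for tokens inside the reasoning span
    (delimiters included), 1.0 elsewhere; multiply into the per-token loss."""
    weights = []
    in_reasoning = False
    for t in tokens.tolist():
        if t == start_id:
            in_reasoning = True
        weights.append(0.0 if in_reasoning else 1.0)
        if t == end_id:
            in_reasoning = False
    return mx.array(weights)
```

Whether the reasoning span or everything outside it is masked is a policy choice left to the trainer.
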
195 changes: 195 additions & 0 deletions llms/mlx_lm/tuner/new_tokens.py
@@ -0,0 +1,195 @@
import mlx.nn as nn
import mlx.core as mx
from mlx_lm.tokenizer_utils import TokenizerWrapper


def resize_embeddings(model: nn.Module, tokenizer: TokenizerWrapper) -> nn.Module:
"""
Resizes model embeddings to accommodate new tokens, minimizing dequantization.
"""
old_embedding = model.model.embed_tokens
old_vocab_size = old_embedding.num_embeddings
new_vocab_size = len(tokenizer._tokenizer)

if old_vocab_size == new_vocab_size:
print("Vocab already sized right.")
return model

if new_vocab_size < old_vocab_size:
print("Warning: New vocab size is smaller than original. Proceeding with trim.")

if (
hasattr(old_embedding, "weight")
and hasattr(old_embedding, "scales")
and hasattr(old_embedding, "biases")
and hasattr(old_embedding, "group_size")
and hasattr(old_embedding, "bits")
):
# quantized embedding case: minimize dequantization

new_embedding = nn.QuantizedEmbedding(
new_vocab_size,
old_embedding.dims,
group_size=old_embedding.group_size,
bits=old_embedding.bits,
)
if new_vocab_size > old_vocab_size:
# Add new rows
new_row_count = new_vocab_size - old_vocab_size
new_rows = mx.random.normal((new_row_count, old_embedding.dims), scale=0.02)
new_rows_q, new_rows_scales, new_rows_biases = mx.quantize(
new_rows, old_embedding.group_size, old_embedding.bits
)

new_embedding.weight = mx.concatenate(
[old_embedding.weight, new_rows_q], axis=0
)
new_embedding.scales = mx.concatenate(
[old_embedding.scales, new_rows_scales], axis=0
)
new_embedding.biases = mx.concatenate(
[old_embedding.biases, new_rows_biases], axis=0
)

else: # new_vocab_size < old_vocab_size: Slice existing
new_embedding.weight = old_embedding.weight[:new_vocab_size]
new_embedding.scales = old_embedding.scales[:new_vocab_size]
new_embedding.biases = old_embedding.biases[:new_vocab_size]

else:
# non-quantized embedding case (fallback, less efficient)
# dequantize ONLY if necessary
# should ideally be avoided entirely for quantized models.
try:
dequantized_weights = mx.dequantize(
old_embedding.weight,
scales=old_embedding.scales,
biases=old_embedding.biases,
group_size=old_embedding.group_size,
bits=old_embedding.bits,
)
# handle missing quantization attributes
except (AttributeError, TypeError):
print("Falling back to random weights for embed_tokens.")
dequantized_weights = mx.random.normal(
(old_vocab_size, old_embedding.dims), loc=0.0, scale=0.02
)

new_embedding = nn.Embedding(new_vocab_size, old_embedding.dims)
new_weights = mx.zeros((new_vocab_size, old_embedding.dims))
min_vocab_size = min(old_vocab_size, new_vocab_size)
new_weights[:min_vocab_size] = dequantized_weights[:min_vocab_size]
if new_vocab_size > old_vocab_size:
new_weights[old_vocab_size:] = mx.random.normal(
(new_vocab_size - old_vocab_size, old_embedding.dims),
loc=0.0,
scale=0.02,
)
new_embedding.weight = new_weights

model.model.embed_tokens = new_embedding

# handle lm_head
if hasattr(model, "args") and getattr(model.args, "tie_word_embeddings", False):
if hasattr(new_embedding, "weight") and not isinstance(
new_embedding, nn.QuantizedEmbedding
):
model.model.embed_tokens.weight = new_embedding.weight

elif hasattr(model, "lm_head"):
old_lm_head = model.lm_head
if isinstance(old_lm_head, nn.QuantizedLinear):
output_dims, compressed_input_dims = old_lm_head.weight.shape
bits = old_lm_head.bits
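            # each uint32 in a quantized weight packs 32 // bits values, so recover the true input width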
input_dims = compressed_input_dims * (32 // bits)
group_size = old_lm_head.group_size

new_lm_head = nn.QuantizedLinear(
input_dims=input_dims,
output_dims=new_vocab_size,
bias="bias" in old_lm_head,
group_size=group_size,
bits=bits,
)

if new_vocab_size > old_vocab_size:
new_row_count = new_vocab_size - old_vocab_size
new_rows = mx.random.normal((new_row_count, input_dims), scale=0.02)
new_rows_q, new_rows_scales, new_rows_biases = mx.quantize(
new_rows, group_size, bits
)
new_lm_head.weight = mx.concatenate(
[old_lm_head.weight, new_rows_q], axis=0
)
new_lm_head.scales = mx.concatenate(
[old_lm_head.scales, new_rows_scales], axis=0
)
new_lm_head.biases = mx.concatenate(
[old_lm_head.biases, new_rows_biases], axis=0
)
else:
new_lm_head.weight = old_lm_head.weight[:new_vocab_size]
new_lm_head.scales = old_lm_head.scales[:new_vocab_size]
new_lm_head.biases = old_lm_head.biases[:new_vocab_size]

if "bias" in old_lm_head:
if new_vocab_size > old_vocab_size:
new_bias = mx.concatenate(
[old_lm_head.bias, mx.zeros(new_vocab_size - old_vocab_size)]
)
else:
new_bias = old_lm_head.bias[:new_vocab_size]
new_lm_head.bias = new_bias
# nn.Linear case
else:
new_lm_head = nn.Linear(
old_lm_head.input_dims, new_vocab_size, bias="bias" in old_lm_head
)
new_weights_lm = mx.zeros((new_vocab_size, old_lm_head.input_dims))
min_vocab_size = min(old_vocab_size, new_vocab_size)
new_weights_lm[:min_vocab_size] = old_lm_head.weight[:min_vocab_size]
if new_vocab_size > old_vocab_size:
new_weights_lm[old_vocab_size:] = mx.random.normal(
(new_vocab_size - old_vocab_size, old_lm_head.input_dims),
loc=0.0,
scale=0.02,
)
new_lm_head.weight = new_weights_lm
if "bias" in old_lm_head:
new_lm_head.bias = mx.zeros((new_vocab_size,))
new_lm_head.bias[:min_vocab_size] = old_lm_head.bias[:min_vocab_size]

model.lm_head = new_lm_head

return model


def update_tokenizer(
tokenizer: TokenizerWrapper, tokens: list[str], special: bool
) -> TokenizerWrapper:
"""
Appends new tokens to the end of the tokenizer vocab
"""
if special:
# todo TokenizerWrapper access method
tokenizer._tokenizer.add_special_tokens({"additional_special_tokens": tokens})
print(f"Tokenizer updated with special tokens: {tokens}")
print(f"Tokenizer vocab size after append: {len(tokenizer._tokenizer)}")
else:
# todo add regular tokens
pass
return tokenizer


def implement_new_tokens(
model: nn.Module,
tokenizer: TokenizerWrapper,
tokens: list[str],
special: bool = False,
) -> tuple[nn.Module, TokenizerWrapper]:
"""
    Update the model's tokenizer and embeddings with new tokens.
"""
tokenizer = update_tokenizer(tokenizer=tokenizer, tokens=tokens, special=special)
model = resize_embeddings(model=model, tokenizer=tokenizer)
return model, tokenizer
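
End to end, the helper is expected to be used roughly like this (a sketch assuming the package is installed from this branch; the model id is only an example):

```python
from mlx_lm import load
from mlx_lm.tuner.new_tokens import implement_new_tokens

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")  # example model id

model, tokenizer = implement_new_tokens(
    model=model,
    tokenizer=tokenizer,
    tokens=["[REASONING]", "[DATA]"],
    special=True,  # the non-special path is still a todo in update_tokenizer
)

print(len(tokenizer._tokenizer))  # vocab size should have grown by two
```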