
Cot loss masking #1298

Open

wants to merge 9 commits into main
Conversation

paNikitin

This PR implements CoT loss masking and appending additional tokens to the model vocabulary, with embedding resizing.
For now, only special tokens can be added.

@paNikitin changed the title from "Cot loss" to "Cot loss masking" on Feb 23, 2025
@awni
Member

awni commented Feb 27, 2025

We should think carefully here about how to integrate this in a way that doesn't make the code too difficult to modify and maintain going forward. Right now there are lots of little places that need to be updated to manage the CoT loss masking you implemented, which makes things quite brittle.

First suggestion: I think it's fine if resizing the model is a separate step (a separate script) from the actual training:

  • Step 1: resize the model/tokenizer into a new directory
  • Step 2: train a model from the directory you made
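
A minimal sketch of what that Step 1 script could look like, assuming an mlx_lm Llama-style model where the embedding table lives at model.model.embed_tokens and the tokenizer wraps a Hugging Face tokenizer; the paths, attribute names, and token names below are assumptions, not the PR's code:

    # Hypothetical Step 1: add special tokens, grow the embedding matrix
    # to match, and write everything to a new directory.
    import mlx.core as mx
    from mlx_lm import load

    model, tokenizer = load("path/to/base-model")  # hypothetical path

    # Assumes the mlx_lm wrapper exposes the underlying HF tokenizer.
    added = tokenizer._tokenizer.add_special_tokens(
        {"additional_special_tokens": ["[REASONING]", "[DATA]"]}
    )

    if added > 0:
        embed = model.model.embed_tokens  # assumed attribute path
        old = embed.weight
        # Initialize the new rows from the mean embedding (a common heuristic).
        mean = mx.mean(old, axis=0, keepdims=True)
        new_rows = mx.broadcast_to(mean, (added, old.shape[1]))
        embed.weight = mx.concatenate([old, new_rows], axis=0)
        # NB: a model with an untied lm_head would need the same treatment.

    tokenizer._tokenizer.save_pretrained("path/to/resized-model")
    # Persisting the resized weights would go through mlx_lm's save/convert
    # utilities; omitted here.

Training (Step 2) then loads "path/to/resized-model" like any other model directory.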

Second suggestion: I'm wondering if we can handle the loss masking by having the dataset generate the right start/stop values for each sequence. That way we wouldn't need to have all the downstream code be aware of the "reasoning" and "data" tokens. This should look similar to the way we do completion-only fine-tuning; a rough sketch follows.
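
A rough sketch of that idea, assuming the dataset yields, per sequence, the offset where loss should start (everything here, including the function name and mask construction, is an assumption rather than the PR's code):

    import mlx.core as mx
    import mlx.nn as nn

    def masked_loss(model, inputs, targets, loss_start, lengths):
        # loss_start[i]: first target position that should incur loss
        # (e.g. the position right after "[DATA]"); lengths[i]: sequence length.
        logits = model(inputs)
        ce = nn.losses.cross_entropy(logits, targets, reduction="none")

        steps = mx.arange(targets.shape[1])[None, :]
        mask = (steps >= loss_start[:, None]) & (steps < lengths[:, None])
        mask = mask.astype(ce.dtype)

        ntoks = mx.maximum(mask.sum(axis=1), 1)  # avoid divide-by-zero
        return ((ce * mask).sum(axis=1) / ntoks).mean()

With that, the trainer only ever sees (start, length) pairs, exactly like the completion-only path, and none of the downstream code needs to know about the special tokens.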

Also, I'm wondering if you can explain the loss a bit. I notice it only incurs loss after the "[DATA]" token, but I don't understand this part:

    # masking loss before [DATA]; applying penalty for invalid seq
    valid_loss = (ce * loss_mask).sum(axis=1) / (mx.sum(loss_mask, axis=1) + 1e-8)
    final_loss = mx.where(valid_seq, valid_loss, penalty)  # 10.0 as invalid penalty

Could you explain it or point to a reference?

@paNikitin
Author

paNikitin commented Feb 27, 2025

  1. I completely agree; there is a lot to refine.

  2. Yes, that makes sense.

  3. The idea is to align the model's pretrained knowledge with a teacher model's reasoning without breaking that knowledge (rather than aligning on the intermediate reasoning steps). That is why the loss is only computed after the final-response token, i.e. after [DATA]; see the sketch below.
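
To make that concrete, a toy sketch of where the mask turns on; the token id is hypothetical and this is an illustration, not the PR's actual code:

    import mlx.core as mx

    DATA_ID = 32001  # hypothetical id of the added "[DATA]" token

    def loss_mask_after_data(tokens):
        # 1.0 for every position strictly after the first [DATA], else 0.0.
        is_data = (tokens == DATA_ID).astype(mx.int32)
        seen = mx.cumsum(is_data, axis=1)
        mask = (seen > 0).astype(mx.float32) * (1 - is_data.astype(mx.float32))
        # Sequences with no [DATA] at all are "invalid" and receive the
        # fixed penalty instead of a token-averaged loss.
        valid_seq = mx.any(tokens == DATA_ID, axis=1)
        return mask, valid_seq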

I'd better show results on some benchmarks in comparison to the base model first.
