Adding grpo training #1233
base: main
Conversation
Absolute HERO! Been trying to figure this out myself the past week but made pretty much no progress whatsoever, other than to make a script that fills up all the RAM on my Mac 🤣 Is there any way to run this yet? I assume no since at the mo it's still marked as in draft + there isn't a lora_config.yaml like in the DPO example yet (not sure if it's needed)? |
No, not yet. I still have to implement the Dataset Wrapper and some other stuff; I'll tell you when it's done. |
Possible need to use expanded_prompts, expanded_answers in both reward and loss
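A minimal sketch of what that expansion could look like (names such as expanded_prompts, expanded_answers, and group_size follow the comment above and are assumptions about the PR's internals, not its actual code): each prompt and reference answer is repeated group_size times so the i-th sampled completion lines up with its own prompt/answer pair in both the reward call and the loss.

```python
# Hypothetical sketch, not the PR's code: repeat each prompt/answer
# `group_size` times so the i-th completion pairs with the i-th entry
# in both the reward computation and the loss.
group_size = 2
prompts = ["prompt A", "prompt B"]
answers = ["answer A", "answer B"]

expanded_prompts = [p for p in prompts for _ in range(group_size)]
expanded_answers = [a for a in answers for _ in range(group_size)]
# expanded_prompts -> ["prompt A", "prompt A", "prompt B", "prompt B"]
```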
python -m mlx_lm.lora \
--model Qwen/Qwen2.5-0.5B \
--train \
--data /Users/gokdenizgulmez/Desktop/test_grpo \
--iters 5 \
--batch-size 1 \
--num-layers 4 \
--val-batches 1 \
--steps-per-report 1 \
--adapter-path /Users/gokdenizgulmez/Desktop/test-grpo-full \
--max-seq-length 128 \
--grad-checkpoint \
--training-mode grpo \
--fine-tune-type lora \
--beta 0.1 \
--steps-per-eval 500 \
--group-size 2
Output:
But after that my 32 GB of RAM gets fully used. I tried to add some memory optimisations but the memory usage is still too high. |
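For reference, MLX exposes a couple of knobs for the Metal buffer cache that are sometimes used to rein in memory during training. This is only a hedged sketch of those calls, not code taken from this PR, and depending on your MLX version they may live under mx.metal.* (mx.metal.set_cache_limit / mx.metal.clear_cache) instead of the top level shown here.

```python
import mlx.core as mx

# Hedged sketch: cap the reusable Metal buffer cache and release unused
# cached buffers between steps. On older MLX versions use the mx.metal.*
# equivalents instead of the top-level functions shown here.
mx.set_cache_limit(2 * 1024**3)  # keep at most ~2 GB of cached buffers

def after_optimizer_step():
    # Frees cached (currently unused) buffers; later steps re-allocate them,
    # trading a bit of speed for a lower steady-state memory footprint.
    mx.clear_cache()
```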
Still uses too much memory. |
So I tried using TRL and it used the same amount of RAM, so the issue is not on my side. |
🚀 Would you be able to share the datasets you used for the training? Will give it a go on my machine as soon as I can 🙌 |
Will do that tomorrow 🤝 |
I created a quick one only for testing the code |
python -m mlx_lm.lora \
--model Qwen/Qwen2.5-0.5B \
--train \
--data /Users/gokdenizgulmez/Desktop/test_grpo \
--iters 5 \
--batch-size 1 \
--num-layers 8 \
--val-batches 1 \
--steps-per-report 1 \
--adapter-path /Users/gokdenizgulmez/Desktop/test-grpo-full \
--max-seq-length 255 \
--grad-checkpoint \
--training-mode grpo \
--fine-tune-type lora \
--beta 0.1 \
--steps-per-eval 500 \
--group-size 2 \
--max-completion-length 6
Output:
|
@wangcheng0825 It's normal to see zero loss during RL fine-tuning of LLMs as long as rewards are improving. Here is Unsloth: |
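To illustrate why a near-zero loss is expected early on (a toy sketch, not the PR's actual loss code): GRPO's advantages are computed relative to the group mean, and at the first steps the policy still matches the reference policy, so the importance ratio is ~1 and the mean-centred advantages average out to roughly zero, even though the gradient signal is non-zero.

```python
import mlx.core as mx

# Toy illustration (not the PR's loss implementation): group-relative
# advantages are mean-centred, so with the importance ratio at ~1 the
# surrogate loss for a group averages to ~0 while its gradient does not.
rewards = mx.array([0.0, 1.0, 0.0, 2.0])                  # one group of 4 completions
advantages = (rewards - mx.mean(rewards)) / (mx.std(rewards) + 1e-8)

ratio = mx.ones_like(advantages)                           # pi_theta / pi_old == 1 at step 0
surrogate_loss = -mx.mean(ratio * advantages)
print(surrogate_loss.item())                               # ~0.0, yet rewards can still improve
```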
@kiratp that's a great idea! I'll push the update later today. |
…s (generation is now faster, with the same RAM usage), fix for the identical generations, separated the reward functions into a separate file.
Huge changes! The generations are different now (the identical outputs were because I used argmax instead of mx.random.categorical), and the dataset loader now has system message support too.
|
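A minimal sketch of the sampling change described above (toy logits, not the PR's exact generation code): argmax is deterministic, so every completion in a group comes out identical, while mx.random.categorical draws from the temperature-scaled logits and gives the diverse completions GRPO needs.

```python
import mlx.core as mx

# Greedy vs. stochastic next-token selection from toy logits.
logits = mx.array([2.0, 1.5, 0.5, -1.0])
temperature = 0.8

greedy_token = mx.argmax(logits)                             # same token every time
sampled_token = mx.random.categorical(logits / temperature)  # varies between samples
```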
This looks like a solid improvement! I’ve heard that the mx.random.categorical method produces more diverse results. Adding system message support is a great enhancement—really useful. Thanks for all your effort in improving this feature! |
Thanks for your kind words @lin72h! This really motivates me to know that my efforts are appreciated and that there's a clear desire within the community for these enhancements. Special thanks to @Guo-astro @deathcoder @kiratp and everyone else here!!!!! Your support means a lot. |
Just wanted to pop up again and express my support again 😁 The efforts are very much appreciated!!! (Been dealing with some personal issues lately so haven't been near as active in community as I'd like, but have been keeping an eye on this repo every day regardless pahahahaha) Thanks for the awesome work @Goekdeniz-Guelmez 😁 |
You only ran 250 iterations with batch size 1, which is likely insufficient for meaningful changes in model behavior, especially for a 3B parameter model. Can you also show me the logs from the training? If the rewards go up, that means the model is learning. Is the adapter path correct? Also, is the system prompt you used correct? The default system prompt is |
Thanks @Goekdeniz-Guelmez, it's a problem with my system prompt. I tried to use
|
Thanks!! It should NOT be indented because it should execute regardless of whether the weights were provided or defaulted. |
@Goekdeniz-Guelmez I am testing on the latest commit (first of all, again, amazing improvements). I was running training for
I'm saving adapters every 10 steps, so this wasn't the first time they were being saved... Not really sure what else I can add about this; unfortunately it didn't print a stacktrace. |
@deathcoder Probably has something to do with the memory handling and clearing; I'll look into it when I'm home. |
Hello, thanks a lot for all the work, it is a pleasure to be able to play with GRPO locally! After a few tests, it seems to work perfectly fine for very short prompts, but I struggle with prompts of 1000 tokens, even with small models like mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-8bit. My args:
With the model Qwen/Qwen2.5-0.5B, I get 5 tokens/sec, so with a group-size of 2 generating 500 tokens = 1000 tokens generated => 200 sec per iteration. I have an M4 Max with 128 GB: my GPU usage peaks at 10% sometimes, but stays very low. The memory used is close to 40 GB with the 1.5B. @Goekdeniz-Guelmez, do you have any idea how to improve the performance? |
@Vi-cs Have you tried reducing the num-layers? You are tuning all layers with -1. Also make sure you are on the latest commit. I never tried with -1, but in my tests with 8 layers I get much higher speeds than that,
|
My commit was a few days old but I pulled the latest commit just to be sure, and I changed to --num-layers 4. Edit:
|
Not sure what it is on your side that is slowing you down; I just ran the exact same command you sent, the only difference is the dataset:
and training is still going. I'm now on iteration 14, and I just launched it before I started writing this message. Edit: just to confirm you are actually on the latest commit, do you have the validation sample details in your logs?
|
@Goekdeniz-Guelmez Thanks for the nice job! I wonder if the current code supports GRPO training with |
Thanks @deathcoder. The dataset is this one: https://huggingface.co/datasets/vi-c/test/.
Also, I think I am on the latest commit based on the log:
|
Hi @Goekdeniz-Guelmez , I am using a R1 distill model which is not an instruct model. It doesn't behave correctly with --use-chat-template. I set the --use-prompt on purpose and it works fine (I tested this on Unsloth GRPO). |
@Vi-cs You're using the wrong dataset! Your dataset doesn't match the normal reward functions. The reward functions are looking for specific XML tags and formatted answers, but your dataset contains JSON with Chinese text instead. If you really need to, you have to create new reward functions and a prompt that work with the JSON data, and train your model via code. The dataset is not suited for GRPO training, since GRPO needs structured data with clear evaluation criteria to optimize against, which your mixed-language JSON data doesn't provide. Look into the LORA.md documentation to understand the dataset format that should be used; your dataset is more suited for basic SFT training than for GRPO.
Bad dataset: (screenshot omitted)
Good dataset: Goastro/mlx-grpo-dataset |
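For illustration, a hedged example of the kind of structured record GRPO training expects: a prompt plus a verifiable reference answer the reward functions can check. The field names are assumptions based on this thread, not verified against LORA.md or the exact schema of Goastro/mlx-grpo-dataset.

```python
import json

# Hypothetical GRPO-style record: the answer is a single verifiable integer
# that format/correctness reward functions can check. Field names are an
# assumption, not the verified schema of the linked dataset.
record = {
    "prompt": "Natalia sold clips to 48 friends in April and half as many in May. "
              "How many clips did she sell altogether?",
    "answer": "72",
}

with open("train.jsonl", "a") as f:
    f.write(json.dumps(record) + "\n")
```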
Totally agree! The reward functions provided only reward the model if the completion contains a strict XML structure, an int inside the answer tag, and the correct int. The training doesn't work with my reward functions. Just to make sure the issue is not related to my custom reward functions, I used the ones provided (which don't make sense with my dataset, but should not break the training). Still, the training doesn't work. @Goekdeniz-Guelmez do you see a reason why the training would not succeed in computing iterations with my dataset? Any chance you could try executing my args (with my dataset), to see if on your side you can process at least a few iterations? Thanks! |
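A hedged sketch of the kind of reward function being discussed (strict XML structure plus a correct integer inside the answer tag); the tag names and scoring are assumptions for illustration, not the PR's actual reward code.

```python
import re

# Expect a completion shaped like: <think>...</think><answer>42</answer>
XML_PATTERN = re.compile(
    r"^<think>.*?</think>\s*<answer>\s*(-?\d+)\s*</answer>\s*$", re.S
)

def xml_int_reward(completion: str, reference: str) -> float:
    """Reward strict formatting, plus a bonus if the integer matches."""
    match = XML_PATTERN.match(completion.strip())
    if match is None:
        return 0.0                      # wrong structure: no reward
    reward = 0.5                        # structure bonus
    if match.group(1) == reference.strip():
        reward += 1.0                   # correct integer answer
    return reward
```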