@@ -84,13 +84,13 @@ def main(args):
84
84
shuffle = True ,
85
85
batch_size = args .batch_size ,
86
86
)
87
-
87
+ if args .learning_rate != 2e-5 :
88
+ accelerator .print (f"Warning: You also need to modify accelerate_configs/zero3_offload.json to change the learning rate" )
88
89
optim = DummyOptim (model .parameters (), lr = args .learning_rate )
89
90
scheduler = DummyScheduler (
90
91
optim ,
91
92
num_training_steps = args .max_train_steps ,
92
93
total_num_steps = args .max_train_steps ,
93
- num_warmup_steps = args .warmup_steps ,
94
94
)
95
95
model , optim , scheduler = accelerator .prepare (model , optim , scheduler )
96
96
train_loader = prepare_dataloader (args .parallel_mode , train_loader , accelerator )
@@ -192,11 +192,9 @@ def main(args):
192
192
args .add_argument ("--batch-size" , type = int , default = 1 )
193
193
args .add_argument ("--gradient-accumulate-every" , type = int , default = 8 )
194
194
args .add_argument ("--output-dir" , type = str , required = True )
195
- args .add_argument ("--lora" , action = "store_true" )
196
195
args .add_argument ("--wandb" , type = str )
197
196
args .add_argument ("--seed" , type = int , default = 42 )
198
197
args .add_argument ("--max-train-steps" , type = int , default = 400 )
199
- args .add_argument ("--warmup-steps" , type = int , default = 20 )
200
198
args .add_argument ("--learning-rate" , type = float , default = 2e-5 )
201
199
args .add_argument ("--rope-theta" , type = float , default = 100000 )
202
200
args .add_argument ("--model" , type = str , default = "meta-llama/Llama-2-7b-hf" )
@@ -205,13 +203,7 @@ def main(args):
205
203
type = str ,
206
204
default = "emozilla/pg_books-tokenized-bos-eos-chunked-65536" ,
207
205
)
208
- args .add_argument ("--num-proc" , type = int , default = 32 )
209
- args .add_argument (
210
- "--lr-schedule" , type = str , choices = ["linear" , "constant" ], default = "linear"
211
- )
212
- args .add_argument ("--log-loss" , type = str )
213
206
args .add_argument ("--seq-length" , type = int , default = 16384 )
214
- args .add_argument ("--debug" , action = "store_true" )
215
207
args .add_argument (
216
208
"--parallel_mode" ,
217
209
type = str ,
0 commit comments