Skip to content

Commit 6dfd77e

Browse files
authored
Update train.py
1 parent 559cbcd commit 6dfd77e

File tree

1 file changed

+2
-10
lines changed

1 file changed

+2
-10
lines changed

Diff for: train.py

+2-10
Original file line numberDiff line numberDiff line change
@@ -84,13 +84,13 @@ def main(args):
8484
shuffle=True,
8585
batch_size=args.batch_size,
8686
)
87-
87+
if args.learning_rate != 2e-5:
88+
accelerator.print(f"Warning: You also need to modify accelerate_configs/zero3_offload.json to change the learning rate")
8889
optim = DummyOptim(model.parameters(), lr=args.learning_rate)
8990
scheduler = DummyScheduler(
9091
optim,
9192
num_training_steps=args.max_train_steps,
9293
total_num_steps=args.max_train_steps,
93-
num_warmup_steps=args.warmup_steps,
9494
)
9595
model, optim, scheduler = accelerator.prepare(model, optim, scheduler)
9696
train_loader = prepare_dataloader(args.parallel_mode, train_loader, accelerator)
@@ -192,11 +192,9 @@ def main(args):
192192
args.add_argument("--batch-size", type=int, default=1)
193193
args.add_argument("--gradient-accumulate-every", type=int, default=8)
194194
args.add_argument("--output-dir", type=str, required=True)
195-
args.add_argument("--lora", action="store_true")
196195
args.add_argument("--wandb", type=str)
197196
args.add_argument("--seed", type=int, default=42)
198197
args.add_argument("--max-train-steps", type=int, default=400)
199-
args.add_argument("--warmup-steps", type=int, default=20)
200198
args.add_argument("--learning-rate", type=float, default=2e-5)
201199
args.add_argument("--rope-theta", type=float, default=100000)
202200
args.add_argument("--model", type=str, default="meta-llama/Llama-2-7b-hf")
@@ -205,13 +203,7 @@ def main(args):
205203
type=str,
206204
default="emozilla/pg_books-tokenized-bos-eos-chunked-65536",
207205
)
208-
args.add_argument("--num-proc", type=int, default=32)
209-
args.add_argument(
210-
"--lr-schedule", type=str, choices=["linear", "constant"], default="linear"
211-
)
212-
args.add_argument("--log-loss", type=str)
213206
args.add_argument("--seq-length", type=int, default=16384)
214-
args.add_argument("--debug", action="store_true")
215207
args.add_argument(
216208
"--parallel_mode",
217209
type=str,

0 commit comments

Comments
 (0)