wip: trying to enable schedule free
ex3ndr committed Jul 1, 2024
1 parent 867b0dd commit f1fdd1e
Showing 5 changed files with 403 additions and 355 deletions.
2 changes: 1 addition & 1 deletion supervoice_valle/model_nar.py
@@ -141,7 +141,7 @@ def forward(self, *, condition_text, condition_audio, audio, codec, loss = False
# Transform
#

x = self.transformer(x)
x = self.transformer(x, mask = m)

#
# Predict
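For context on the change above, here is a minimal sketch of what passing a padding mask into a transformer call usually does. The shape and meaning of m and the transformer's signature in model_nar.py are assumptions; the sketch uses a stock PyTorch encoder rather than this repository's module.

# Hypothetical sketch: mask out padded positions of variable-length rows so that
# attention ignores them. The real self.transformer(x, mask = m) may use a
# different mask convention.
import torch
import torch.nn as nn

batch, max_len, dim = 2, 6, 16
lengths = torch.tensor([6, 4])                       # assumed per-row valid lengths
x = torch.randn(batch, max_len, dim)

# True marks padded positions that attention should ignore
pad_mask = torch.arange(max_len)[None, :] >= lengths[:, None]

layer = nn.TransformerEncoderLayer(d_model=dim, nhead=4, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=1)
y = encoder(x, src_key_padding_mask=pad_mask)
print(y.shape)                                       # torch.Size([2, 6, 16])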
61 changes: 39 additions & 22 deletions train.py
@@ -24,28 +24,30 @@
from einops import rearrange, reduce, repeat
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.utils import set_seed
import schedulefree

# Local
from supervoice_valle import SupervoceNARModel, Tokenizer
from train.dataset import load_sampler, create_async_loader

# Train parameters
train_experiment = "valle-18"
train_experiment = "valle-25"
train_project="supervoice-valle"
train_auto_resume = True

# We speculate that the original paper uses about 6k tokens per GPU
# 6k tokens is roughly 2 batches, because a single row is 1000-2500 tokens
# 6k tokens is roughly 3 rows, because a single row is 1500-2500 tokens
# We have MUCH faster GPUs and therefore instead of gradient accumulation,
# we increase batch size 4x and reduce number of gradients to just 2x
train_grad_accum_every = 2
# we increase batch size 4x and reduce number of gradients to just 4x
train_grad_accum_every = 8
train_batch_size = 8

# We speculate that the learning rate is given for all GPUs, so we divide it by the number of GPUs
train_lr_start = 1e-10
train_lr_max = 1e-4
train_lr_start = 1e-12
train_lr_max = 1e-5
train_steps = 600000
train_warmup_steps = 5000 # I am using faster warmup - it is more natural for me after working on voicebox
train_warmup_steps = 32000 # I am using faster warmup - it is more natural for me after working on voicebox
train_schedule_free = False

train_loader_workers = 32
train_log_every = 1
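As a quick sanity check on the batching comments above, a standalone sketch of the implied per-GPU rows and tokens per optimizer step; the 2000 tokens-per-row figure is just the midpoint of the 1500-2500 range quoted in the comment, not a measured value.

# Rough arithmetic implied by the comments above (assumed ~2000 tokens per row).
tokens_per_row = 2000
train_batch_size = 8
train_grad_accum_every = 8

rows_per_step = train_batch_size * train_grad_accum_every      # 64 rows per GPU per optimizer step
tokens_per_step = rows_per_step * tokens_per_row                # ~128,000 tokens per GPU per optimizer step
print(rows_per_step, tokens_per_step)                           # 64 128000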
@@ -77,24 +79,28 @@ def main():
# Prepare dataset
accelerator.print("Loading dataset...")
tokenizer = Tokenizer("./tokenizer_text.model")
train_sampler = load_sampler("./external_datasets/libriheavy/libriheavy_cuts_medium.jsonl.gz", "./external_datasets/libriheavy-medium-encodec/", train_batch_size, tokenizer)
# train_sampler = load_sampler("./external_datasets/libriheavy/libriheavy_cuts_medium.jsonl.gz", "./external_datasets/libriheavy-medium-encodec/", train_batch_size, tokenizer)
train_sampler = load_sampler("./external_datasets/libriheavy/libriheavy_cuts_large.jsonl.gz", "./external_datasets/libriheavy-large-encodec/", train_batch_size, tokenizer)
# train_sampler = load_sampler("./external_datasets/libriheavy/libriheavy_cuts_small.jsonl.gz", "./external_datasets/libriheavy-encodec/", train_batch_size, tokenizer)
train_loader = create_async_loader(train_sampler, num_workers = train_loader_workers)
train_cycle = cycle(train_loader)

# Model
accelerator.print("Loading model...")
step = 0
step = 1
model = SupervoceNARModel().to(device)
raw_model = model
wd_params, no_wd_params = [], []
for param in model.parameters():
param_list = no_wd_params if param.ndim < 2 else wd_params
param_list.append(param)
optim = torch.optim.AdamW([{'params': wd_params}, {'params': no_wd_params, 'weight_decay': 0}], train_lr_start, betas=[0.9, 0.99],weight_decay=0.01)
if not train_schedule_free:
optim = torch.optim.AdamW([{'params': wd_params}, {'params': no_wd_params, 'weight_decay': 0}], train_lr_start, betas=[0.9, 0.95],weight_decay=0.01, eps=1e-6)
else:
optim = schedulefree.AdamWScheduleFree([{'params': wd_params}, {'params': no_wd_params, 'weight_decay': 0}], lr=train_lr_max, betas=[0.9, 0.95],weight_decay=0.01, eps=1e-6)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max = train_steps)
if train_compile:
model = torch.compile(model)
model = torch.compile(model, mode="reduce-overhead")

# Checkpoint
checkpoint = None
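For reference, a minimal, self-contained sketch of the schedule-free pattern this commit starts wiring in: schedulefree.AdamWScheduleFree takes over the role of the external LR scheduler, and the optimizer itself has to be switched between train() and eval() modes. The model and values below are illustrative, not this run's configuration.

# Illustrative schedule-free loop (toy model and toy values, not this repo's setup).
# The schedule-free optimizer interpolates between fast and averaged weights internally,
# so it must be told when training happens vs. when the model is evaluated or saved.
import torch
import torch.nn as nn
import schedulefree

model = nn.Linear(10, 1)
optim = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3, warmup_steps=100)

for step in range(1000):
    optim.train()                      # optimizer in training mode before stepping
    x = torch.randn(32, 10)
    loss = model(x).pow(2).mean()
    loss.backward()
    optim.step()
    optim.zero_grad()

optim.eval()                           # optimizer in eval mode before saving or evaluating
torch.save({"model": model.state_dict(), "optim": optim.state_dict()}, "toy_ckpt.pt")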
@@ -104,7 +110,10 @@ def main():
run_id = checkpoint['run_id']

# Accelerate
model, optim = accelerator.prepare(model, optim)
if not train_schedule_free:
model, optim = accelerator.prepare(model, optim)
else:
model = accelerator.prepare(model)
hps = {
"train_lr_start": train_lr_start,
"train_lr_max": train_lr_max,
@@ -120,6 +129,8 @@

# Save
def save():
if train_schedule_free:
optim.eval()
# Save step checkpoint
fname = str(output_dir / f"{train_experiment}.pt")
fname_step = str(output_dir / f"{train_experiment}.{step}.pt")
@@ -151,16 +162,21 @@ def save():
# Train step
def train_step():
model.train()
if train_schedule_free:
optim.train()

# Update LR
if step < train_warmup_steps:
lr = (lr_start + ((lr_max - lr_start) * step) / train_warmup_steps)
for param_group in optim.param_groups:
param_group['lr'] = lr
lr = lr / accelerator.num_processes
if not train_schedule_free:
if step < train_warmup_steps:
lr = (lr_start + ((lr_max - lr_start) * step) / train_warmup_steps)
for param_group in optim.param_groups:
param_group['lr'] = lr
lr = lr / accelerator.num_processes
else:
scheduler.step()
lr = scheduler.get_last_lr()[0] / accelerator.num_processes
else:
scheduler.step()
lr = scheduler.get_last_lr()[0] / accelerator.num_processes
lr = lr_max / accelerator.num_processes

# Load batch
for _ in range(train_grad_accum_every):
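To make the learning-rate branch above easier to follow, here is a standalone sketch of the intended shape of the non-schedule-free schedule: linear warmup from train_lr_start to train_lr_max, then cosine decay towards zero. The closed form below mirrors what CosineAnnealingLR computes with eta_min = 0, but the script's scheduler keeps its own step counter, so the post-warmup curve is close to, not identical to, this sketch.

# Standalone sketch of the intended LR shape (this file's constants, simplified bookkeeping).
import math

train_lr_start = 1e-12
train_lr_max = 1e-5
train_warmup_steps = 32000
train_steps = 600000

def intended_lr(step: int) -> float:
    if step < train_warmup_steps:
        # Linear warmup, as in the warmup branch above
        return train_lr_start + (train_lr_max - train_lr_start) * step / train_warmup_steps
    # Cosine decay from train_lr_max towards 0 over the remaining steps
    progress = (step - train_warmup_steps) / (train_steps - train_warmup_steps)
    return train_lr_max * 0.5 * (1.0 + math.cos(math.pi * progress))

for s in (0, 16000, 32000, 316000, 600000):
    print(s, intended_lr(s))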
@@ -211,8 +227,9 @@ def train_step():
optim.step()

# Log skipping step
if optim.step_was_skipped:
accelerator.print("Step was skipped")
if not train_schedule_free:
if optim.step_was_skipped:
accelerator.print("Step was skipped")

return loss, lr
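A small aside on the step_was_skipped guard above: that flag is exposed by the optimizer wrapper returned from accelerator.prepare(), and in the schedule-free branch the optimizer is (presumably for that reason) never wrapped, so the attribute would be missing. A hypothetical, self-contained defensive variant:

# Hypothetical variant: getattr() works whether or not accelerate wrapped the optimizer
# (a plain torch optimizer has no step_was_skipped attribute, so this prints nothing).
import torch

optim = torch.optim.AdamW(torch.nn.Linear(2, 2).parameters(), lr=1e-3)
if getattr(optim, "step_was_skipped", False):
    print("Step was skipped")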

@@ -234,7 +251,7 @@ def train_step():
# Summary
if step % train_log_every == 0:
accelerator.log({
"learning_rate": lr,
"learning_rate": lr if not train_schedule_free else train_lr_max,
"loss": loss,
"scale": accelerator.scaler.get_scale() if accelerator.scaler is not None else 1.0
}, step=step)
1 change: 1 addition & 0 deletions train.sh
@@ -1,4 +1,5 @@
set -e
export 'PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512'
while true; do
accelerate launch ./train.py || true
done
4 changes: 2 additions & 2 deletions train/dataset.py
@@ -33,9 +33,9 @@ def sample():
# Load text
with open(dir + id + ".txt", 'r') as file:
text = file.read()
if text == "":
raise Exception("Empty file")
text = tokenizer.encode(text) if random.random() < 0.3 else tokenizer.encode_sample(text) # 30% chance of sampling optimal
if text.shape[0] == 0:
raise Exception("Empty file")

# Load encoded
encoded = torch.load(dir + id + ".pt")
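To illustrate the dataset change above, a tiny stand-in sketch of why checking the tokenized length is stricter than checking the raw string; fake_encode below is a placeholder, not the repository's Tokenizer.

# Stand-in sketch: a transcript can be a non-empty string (e.g. whitespace only)
# yet still encode to zero tokens, so the shape check after encoding catches more cases.
import torch

def fake_encode(text: str) -> torch.Tensor:         # placeholder for tokenizer.encode(...)
    return torch.tensor([ord(ch) for ch in text.strip()], dtype=torch.long)

for raw in ["hello world", "   ", ""]:
    tokens = fake_encode(raw)
    print(repr(raw), "empty string:", raw == "", "zero tokens:", tokens.shape[0] == 0)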
690 changes: 360 additions & 330 deletions welcome.ipynb

Large diffs are not rendered by default.
