-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain_sd.py
342 lines (294 loc) · 12.8 KB
/
train_sd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
import os
import random
import numpy as np
from glob import glob
from tqdm.auto import tqdm
from contextlib import contextmanager
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import v2, InterpolationMode
from torchvision.io import read_image
from torchvision.utils import save_image
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from diffusers.training_utils import compute_snr
class T2iDataset(Dataset):
    """In-memory text-to-image dataset built from ``<stem>.png`` / ``<stem>.txt`` pairs.

    Every ``*.png`` in ``root_folder`` is decoded to RGB at construction time and
    kept in memory together with the text of the matching ``.txt`` file.
    Items are dicts ``{"pixels": float tensor in [-1, 1], "caption": str}``.

    Args:
        root_folder: Folder containing the image/caption pairs.
        resolution: Target size handed to the resize/crop transforms.
        random_crop: If True, use ``RandomResizedCrop`` (area scale 0.25-1.0,
            aspect ratio 0.9-1.1) as augmentation; otherwise resize and
            center-crop deterministically.
    """

    def __init__(self, root_folder, resolution=512, random_crop=False):
        self.root_folder = root_folder
        self.resolution = resolution
        self.images = []
        self.captions = []
        # glob() yields files in filesystem order; sort so that sample indices
        # (and therefore seeded shuffling) are reproducible across machines.
        for image_path in sorted(glob(os.path.join(root_folder, "*.png"))):
            # convert() returns a fully loaded copy, so the source file can be
            # closed right away instead of waiting for garbage collection.
            with Image.open(image_path) as image:
                self.images.append(image.convert('RGB'))
            caption_path = os.path.splitext(image_path)[0] + ".txt"
            with open(caption_path, "r", encoding="utf-8") as capfile:
                self.captions.append(capfile.read())
        if random_crop:
            self.transforms = v2.Compose([
                v2.ToImage(),
                v2.ToDtype(torch.float32, scale=True),
                v2.RandomResizedCrop(
                    size=self.resolution,
                    scale=(0.25, 1.0),
                    ratio=(0.9, 1.1),
                ),
            ])
        else:
            self.transforms = v2.Compose([
                v2.ToImage(),
                v2.ToDtype(torch.float32, scale=True),
                v2.Resize(size=self.resolution),
                v2.CenterCrop(size=self.resolution),
            ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return the transformed image rescaled to [-1, 1] and its raw caption text."""
        # ToDtype(scale=True) produces values in [0, 1]; map them to [-1, 1].
        pixels = self.transforms(self.images[idx]) * 2 - 1
        # Defensive: replace NaNs and clamp into the expected [-1, 1] range.
        pixels = torch.clamp(torch.nan_to_num(pixels), min=-1, max=1)
        return {"pixels": pixels, "caption": self.captions[idx]}
@contextmanager
def temp_rng(new_seed=None):
    """Temporarily reseed, then restore, the PyTorch, NumPy and Python RNGs.

    Adapted from https://github.com/fpgaminer/bigasp-training/blob/main/utils.py#L73

    The global RNG states of PyTorch (CPU, plus the current CUDA device when
    CUDA is available), NumPy and Python's ``random`` module are saved on
    entry and restored on exit. Restoration also happens when the body
    raises, so a fixed-seed block (e.g. a validation pass) never perturbs
    the random stream seen by the surrounding code.

    Args:
        new_seed: If not None, every RNG is seeded with this value before the
            body runs; otherwise the body continues from the current state.
    """
    old_torch_rng_state = torch.get_rng_state()
    # torch.cuda.get_rng_state() initialises CUDA and raises on CPU-only
    # machines, so only touch the CUDA generator when a device is present.
    cuda_available = torch.cuda.is_available()
    old_torch_cuda_rng_state = torch.cuda.get_rng_state() if cuda_available else None
    old_numpy_rng_state = np.random.get_state()
    old_python_rng_state = random.getstate()
    try:
        if new_seed is not None:
            torch.manual_seed(new_seed)
            if cuda_available:
                torch.cuda.manual_seed(new_seed)
            np.random.seed(new_seed)
            random.seed(new_seed)
        yield
    finally:
        # Runs on normal exit and on exceptions alike.
        torch.set_rng_state(old_torch_rng_state)
        if cuda_available:
            torch.cuda.set_rng_state(old_torch_cuda_rng_state)
        np.random.set_state(old_numpy_rng_state)
        random.setstate(old_python_rng_state)
def train(
    output_path="./experiments/",
    dataset_path=None,
    lr=1e-4,
    train_te=False,
    te_lr_mult=0.5,
    train_steps=1000,
    save_steps=1000,
    val_steps=100,
    stable_train_loss=True,
    seed=None,
    val_seed=1234,
    batch_size=1,
    val_repeats=4,
    device="cuda",
):
    """Fine-tune the Stable Diffusion v1.5 UNet, optionally with the CLIP text encoder.

    Reads the ``train``, ``test`` and ``val`` sub-folders of ``dataset_path``
    with :class:`T2iDataset`. The training objective is the MSE between the
    UNet output and the noise added by the DDPM scheduler. Every ``val_steps``
    steps the model is evaluated with a fixed RNG seed over ``val_repeats``
    equal-width timestep bands. Scalars and loss-vs-timestep scatter plots are
    logged to TensorBoard in ``output_path``, and ``checkpoint-XXXXXXXX``
    folders are saved there.

    Args:
        output_path: Directory for TensorBoard event files and checkpoints.
        dataset_path: Root folder holding ``train/``, ``test/`` and ``val/``.
        lr: Base learning rate; the optimizer uses ``lr * sqrt(batch_size)``.
        train_te: Also train the text encoder.
        te_lr_mult: Text-encoder LR as a fraction of the UNet LR. Only used
            when ``train_te`` is set.
        train_steps: Total number of optimizer steps.
        save_steps: Save a checkpoint every this many steps, and at the end.
        val_steps: Evaluate every this many steps, and after the first step.
        stable_train_loss: Also evaluate the ``test`` split with the same
            fixed-seed, banded-timestep protocol used for ``val``.
        seed: Optional seed for Python's and PyTorch's global RNGs.
        val_seed: Seed applied (via :func:`temp_rng`) for every evaluation.
        batch_size: Training batch size. Evaluation always uses 1.
        val_repeats: Number of equal-width timestep bands evaluated per split.
        device: Torch device that all models are moved to.

    Raises:
        TypeError: If ``dataset_path`` is not provided.
    """
    if dataset_path is None:
        # Fail early with a clear message instead of inside os.path.join.
        raise TypeError("train() requires dataset_path (a folder with train/, test/ and val/ subfolders)")

    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)

    os.makedirs(output_path, exist_ok=True)

    def collate_batch(batch):
        """Stack pixel tensors into one batch tensor; keep captions as a list."""
        pixels = torch.stack([sample["pixels"] for sample in batch], dim=0)
        captions = [sample["caption"] for sample in batch]
        return pixels, captions

    def make_loader(split, loader_batch_size, shuffle, **dataset_kwargs):
        """Build a single-process DataLoader over ``dataset_path/<split>``."""
        return DataLoader(
            dataset=T2iDataset(os.path.join(dataset_path, split), **dataset_kwargs),
            batch_size=loader_batch_size,
            shuffle=shuffle,
            collate_fn=collate_batch,
            num_workers=0,
            pin_memory=True,
            drop_last=False,
        )

    train_dataloader = make_loader("train", batch_size, True, random_crop=False)
    test_dataloader = make_loader("test", 1, False)
    val_dataloader = make_loader("val", 1, False)

    hf_identifier = "stable-diffusion-v1-5/stable-diffusion-v1-5"
    noise_scheduler = DDPMScheduler.from_pretrained(hf_identifier, subfolder="scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(hf_identifier, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(hf_identifier, subfolder="text_encoder").to(device)
    vae = AutoencoderKL.from_pretrained(hf_identifier, subfolder="vae").to(device)
    unet = UNet2DConditionModel.from_pretrained(hf_identifier, subfolder="unet").to(device)

    text_encoder.requires_grad_(train_te)
    if train_te:
        text_encoder.train()
    vae.requires_grad_(False)
    unet.requires_grad_(True)
    unet.train()

    # Square-root learning-rate scaling with batch size.
    train_lr = lr * (batch_size ** 0.5)
    # Other optimizers that can be swapped in here: torch.optim.Adafactor,
    # torch.optim.SGD, bitsandbytes.optim.AdamW8bit.
    optim_cls = torch.optim.AdamW
    if train_te:
        param_groups = [
            {"params": unet.parameters()},
            {"params": text_encoder.parameters(), "lr": train_lr * te_lr_mult},
        ]
    else:
        param_groups = unet.parameters()
    optimizer = optim_cls(param_groups, lr=train_lr)

    global_step = 0
    # Per-sample records (step, loss, timestep) used for the scatter plots.
    train_logs = {"train_step": [], "train_loss": [], "train_timestep": []}
    test_logs = {"train_step": [], "train_loss": [], "train_timestep": []}
    val_logs = {"train_step": [], "train_loss": [], "train_timestep": []}

    def encode_captions(captions, dropout=0):
        """Tokenize and encode captions; each is blanked with probability ``dropout``."""
        input_ids = []
        for caption in captions:
            # torch.rand is drawn for every caption, even when dropout == 0,
            # so the RNG stream does not depend on the dropout setting.
            if torch.rand(1) < dropout:
                caption = ""  # caption dropout for better CFG
            ids = tokenizer(
                caption,
                max_length=tokenizer.model_max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            ).input_ids
            input_ids.append(ids)
        # Each `ids` holds a single row; concatenate into one (batch, seq_len) tensor.
        input_ids = torch.cat(input_ids, dim=0).to(device)
        return text_encoder(input_ids, return_dict=False)[0]

    def vae_encode(pixels):
        """Sample VAE latents for ``pixels`` and apply the VAE scaling factor."""
        latents = vae.encode(pixels.to(device)).latent_dist.sample()
        return latents * vae.config.scaling_factor

    def sample_timesteps(latents, timestep_range=None):
        """Draw one integer timestep per sample from [min, max) (default: the full schedule)."""
        if timestep_range is None:
            min_timestep, max_timestep = 0, noise_scheduler.config.num_train_timesteps
        else:
            min_timestep, max_timestep = timestep_range[0], timestep_range[1]
        return torch.randint(
            min_timestep,
            max_timestep,
            (latents.shape[0],),
            device=latents.device,
        ).long()

    def sample_noise(latents, offset=0):
        """Gaussian noise like ``latents``, plus a per-channel constant term when ``offset > 0``."""
        noise = torch.randn_like(latents)
        if offset > 0:
            noise += offset * torch.randn_like(latents[..., 0, 0])[..., None, None]
        return noise

    def mse_loss(pred, target, timesteps, log_to=None):
        """Return (mean MSE, mean MSE divided by a fitted loss-vs-timestep curve).

        Per-sample losses are appended to ``log_to`` when it is given.
        """
        loss = F.mse_loss(pred.float(), target.float(), reduction="none")
        loss = loss.mean(dim=list(range(1, len(loss.shape))))  # reduce over all dimensions except batch
        if log_to is not None:
            for sample_loss, sample_timestep in zip(loss, timesteps):
                log_to["train_step"].append(global_step)
                log_to["train_loss"].append(sample_loss.item())
                log_to["train_timestep"].append(sample_timestep.item())
        debiased_loss = loss / (0.7365 * torch.exp(-0.0052 * timesteps))  # debias by loss/timestep fit function
        return loss.mean(), debiased_loss.mean()

    def get_pred(batch, dropout=0, offset=0, timestep_range=None, log_to=None):
        """Run one noise-prediction forward pass on ``batch`` and return ``mse_loss``'s pair."""
        pixels, captions = batch
        encoder_hidden_states = encode_captions(captions, dropout=dropout)
        latents = vae_encode(pixels)
        timesteps = sample_timesteps(latents, timestep_range=timestep_range)
        noise = sample_noise(latents, offset=offset)
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
        return mse_loss(model_pred, noise, timesteps, log_to=log_to)

    def plot_logs(log_dict):
        """Scatter loss vs. timestep (coloured by step, log y-axis) on the current pyplot figure."""
        plt.scatter(log_dict["train_timestep"], log_dict["train_loss"], s=3, c=log_dict["train_step"], marker=".", cmap='cool')
        plt.xlabel("timestep")
        plt.ylabel("loss")
        plt.yscale("log")

    def run_validation(t_writer):
        """Fixed-seed evaluation over ``val_repeats`` equal-width timestep bands."""
        num_timesteps = noise_scheduler.config.num_train_timesteps
        n_test = len(test_dataloader) * val_repeats
        n_val = len(val_dataloader) * val_repeats
        test_loss = 0.0
        val_loss = 0.0
        with torch.inference_mode(), temp_rng(val_seed):
            inference_steps = n_val + (n_test if stable_train_loss else 0)
            with tqdm(range(inference_steps), desc="validation", leave=False) as temp_pbar:
                # Keep the test/val interleaving per band: it fixes the order in
                # which the seeded RNG is consumed, and therefore the results.
                for i in range(val_repeats):
                    timestep_range = (
                        int(i * num_timesteps / val_repeats),
                        int((i + 1) * num_timesteps / val_repeats),
                    )
                    if stable_train_loss:
                        for batch in test_dataloader:
                            loss, _ = get_pred(batch, timestep_range=timestep_range, log_to=test_logs)
                            test_loss += loss.item()
                            temp_pbar.update(1)
                    for batch in val_dataloader:
                        loss, _ = get_pred(batch, timestep_range=timestep_range, log_to=val_logs)
                        val_loss += loss.item()
                        temp_pbar.update(1)

        samples_seen = global_step * batch_size
        plot_logs(train_logs)
        t_writer.add_figure("train_loss", plt.gcf(), samples_seen)
        # The test split is only evaluated when stable_train_loss is set; do not
        # log an empty plot and a fake 0.0 loss otherwise.
        if stable_train_loss:
            plot_logs(test_logs)
            t_writer.add_figure("test_loss", plt.gcf(), samples_seen)
            if n_test:
                t_writer.add_scalar("test/test", test_loss / n_test, samples_seen)
        plot_logs(val_logs)
        t_writer.add_figure("val_loss", plt.gcf(), samples_seen)
        if n_val:
            t_writer.add_scalar("test/val", val_loss / n_val, samples_seen)

    def save_checkpoint():
        """Save the UNet (and the text encoder when trained) as safetensors."""
        checkpoint_path = os.path.join(output_path, f"checkpoint-{global_step:08}")
        unet.save_pretrained(os.path.join(checkpoint_path, "unet"), safe_serialization=True)
        if train_te:
            text_encoder.save_pretrained(os.path.join(checkpoint_path, "text_encoder"), safe_serialization=True)

    # Context managers guarantee that the writer flushes and closes its event
    # file, and that the progress bar closes, even if training raises.
    with SummaryWriter(log_dir=output_path, flush_secs=60) as t_writer, \
            tqdm(range(0, train_steps)) as progress_bar:
        while global_step < train_steps:
            for batch in train_dataloader:
                optimizer.zero_grad()
                loss, debiased_loss = get_pred(batch, log_to=train_logs)
                # x-axis of the TensorBoard curves is "samples seen".
                samples_seen = global_step * batch_size
                t_writer.add_scalar("loss/train", loss.detach().item(), samples_seen)
                t_writer.add_scalar("loss/debiased", debiased_loss.detach().item(), samples_seen)
                loss.backward()
                optimizer.step()
                progress_bar.update(1)
                global_step += 1
                if global_step == 1 or global_step % val_steps == 0:
                    run_validation(t_writer)
                if global_step >= train_steps or global_step % save_steps == 0:
                    save_checkpoint()
                if global_step >= train_steps:
                    break
if __name__ == "__main__":
    experiment_name = "./experiments/example_finetune"
    # Expected dataset layout: <dataset_path>/train, <dataset_path>/test and
    # <dataset_path>/val, each holding *.png images with same-named .txt captions.
    dataset_path = "./datasets/example"
    example_config = {
        "output_path": experiment_name,
        "dataset_path": dataset_path,
        "lr": 5e-7,
        "train_steps": 25_000,
        "save_steps": 5_000,
        "val_steps": 500,
        "seed": 1234,
        "val_seed": 1234,
        "batch_size": 1,
        "val_repeats": 4,
        "device": "cuda",
    }
    train(**example_config)