Skip to content

Commit bebc280

Browse files
committed
support for transparent images
1 parent fcd35de commit bebc280

File tree

5 files changed: +19 lines added, -7 lines removed

dalle_pytorch/dalle_pytorch.py

+1
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def __init__(
119119
assert num_layers >= 1, 'number of layers must be greater than or equal to 1'
120120
has_resblocks = num_resnet_blocks > 0
121121

122+
self.channels = channels
122123
self.image_size = image_size
123124
self.num_tokens = num_tokens
124125
self.num_layers = num_layers

dalle_pytorch/vae.py

+3
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def __init__(self):
108108
self.dec = load_model(download(OPENAI_VAE_DECODER_PATH))
109109
make_contiguous(self)
110110

111+
self.channels = 3
111112
self.num_layers = 3
112113
self.image_size = 256
113114
self.num_tokens = 8192
@@ -175,7 +176,9 @@ def __init__(self, vqgan_model_path=None, vqgan_config_path=None):
175176

176177
# f as used in https://github.com/CompVis/taming-transformers#overview-of-pretrained-models
177178
f = config.model.params.ddconfig.resolution / config.model.params.ddconfig.attn_resolutions[0]
179+
178180
self.num_layers = int(log(f)/log(2))
181+
self.channels = 3
179182
self.image_size = 256
180183
self.num_tokens = config.model.params.n_embed
181184
self.is_gumbel = isinstance(self.model, GumbelVQ)

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
name = 'dalle-pytorch',
55
packages = find_packages(),
66
include_package_data = True,
7-
version = '1.5.2',
7+
version = '1.6.0',
88
license='MIT',
99
description = 'DALL-E - Pytorch',
1010
author = 'Phil Wang',

train_dalle.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,6 @@ def cp_path_to_dir(cp_path, tag):
268268
else:
269269
vae = OpenAIDiscreteVAE()
270270

271-
IMAGE_SIZE = vae.image_size
272271
resume_epoch = loaded_obj.get('epoch', 0)
273272
else:
274273
if exists(VAE_PATH):
@@ -296,8 +295,6 @@ def cp_path_to_dir(cp_path, tag):
296295
else:
297296
vae = OpenAIDiscreteVAE()
298297

299-
IMAGE_SIZE = vae.image_size
300-
301298
dalle_params = dict(
302299
num_text_tokens=tokenizer.vocab_size,
303300
text_seq_len=TEXT_SEQ_LEN,
@@ -319,6 +316,10 @@ def cp_path_to_dir(cp_path, tag):
319316
)
320317
resume_epoch = 0
321318

319+
IMAGE_SIZE = vae.image_size
320+
CHANNELS = vae.channels
321+
IMAGE_MODE = 'RGBA' if CHANNELS == 4 else 'RGB'
322+
322323
# configure OpenAI VAE for float16s
323324

324325
if isinstance(vae, OpenAIDiscreteVAE) and args.fp16:
@@ -345,8 +346,8 @@ def group_weight(model):
345346
is_shuffle = not distributed_utils.using_backend(distributed_utils.HorovodBackend)
346347

347348
imagepreproc = T.Compose([
348-
T.Lambda(lambda img: img.convert('RGB')
349-
if img.mode != 'RGB' else img),
349+
T.Lambda(lambda img: img.convert(IMAGE_MODE)
350+
if img.mode != IMAGE_MODE else img),
350351
T.RandomResizedCrop(IMAGE_SIZE,
351352
scale=(args.resize_ratio, 1.),
352353
ratio=(1., 1.)),

train_vae.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@
6868

6969
model_group.add_argument('--kl_loss_weight', type = float, default = 0., help = 'KL loss weight')
7070

71+
model_group.add_argument('--transparent', dest = 'transparent', action = 'store_true')
72+
7173
args = parser.parse_args()
7274

7375
# constants
@@ -88,6 +90,10 @@
8890
HIDDEN_DIM = args.hidden_dim
8991
KL_LOSS_WEIGHT = args.kl_loss_weight
9092

93+
TRANSPARENT = args.transparent
94+
CHANNELS = 4 if TRANSPARENT else 3
95+
IMAGE_MODE = 'RGBA' if TRANSPARENT else 'RGB'
96+
9197
STARTING_TEMP = args.starting_temp
9298
TEMP_MIN = args.temp_min
9399
ANNEAL_RATE = args.anneal_rate
@@ -107,7 +113,7 @@
107113
ds = ImageFolder(
108114
IMAGE_PATH,
109115
T.Compose([
110-
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
116+
T.Lambda(lambda img: img.convert(IMAGE_MODE) if img.mode != IMAGE_MODE else img),
111117
T.Resize(IMAGE_SIZE),
112118
T.CenterCrop(IMAGE_SIZE),
113119
T.ToTensor()
@@ -127,6 +133,7 @@
127133
image_size = IMAGE_SIZE,
128134
num_layers = NUM_LAYERS,
129135
num_tokens = NUM_TOKENS,
136+
channels = CHANNELS,
130137
codebook_dim = EMB_DIM,
131138
hidden_dim = HIDDEN_DIM,
132139
num_resnet_blocks = NUM_RESNET_BLOCKS

0 commit comments

Comments (0)