Skip to content

Commit 133ef08

Browse files
committed
updated.
1 parent 59de981 commit 133ef08

File tree

2 files changed

+7
-11
lines changed

2 files changed

+7
-11
lines changed

examples/text_to_image/main.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def main(dataset_name='lambdalabs/naruto-blip-captions',
122122
model_name_or_path='stabilityai/stable-diffusion-2-1-base',
123123
init_ckpt_dir='./stable-diffusion-2-1-base',
124124
n_model_shards=1,
125-
n_epochs=3,
125+
n_epochs=8,
126126
global_batch_size=8,
127127
per_device_batch_size=1,
128128
learning_rate=1e-5,
@@ -156,13 +156,13 @@ def main(dataset_name='lambdalabs/naruto-blip-captions',
156156
text_encoder = FlaxCLIPTextModel(
157157
config=CLIPTextConfig.from_pretrained(
158158
model_name_or_path, subfolder="text_encoder"),
159-
dtype=jnp.float16, _do_init=False)
159+
dtype=jnp.float32, _do_init=False)
160160
vae = FlaxAutoencoderKL.from_config(
161161
config=FlaxAutoencoderKL.load_config(
162-
model_name_or_path, subfolder='vae'), dtype=jnp.float16)
162+
model_name_or_path, subfolder='vae'), dtype=jnp.float32)
163163
unet = FlaxUNet2DConditionModel.from_config(
164164
config=FlaxUNet2DConditionModel.load_config(
165-
model_name_or_path, subfolder='unet'), dtype=jnp.float16)
165+
model_name_or_path, subfolder='unet'), dtype=jnp.float32)
166166
noise_scheduler, noise_scheduler_state = \
167167
FlaxPNDMScheduler.from_pretrained(
168168
model_name_or_path, subfolder='scheduler')

examples/text_to_image/save_init_ckpt.py

+3-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import fire
22
import jax
3-
import jax.numpy as jnp
43
from transformers import FlaxCLIPTextModel
54
from diffusers import FlaxAutoencoderKL, FlaxUNet2DConditionModel
65
from redco import Deployer
@@ -14,14 +13,11 @@ def main(model_name_or_path='stabilityai/stable-diffusion-2-1-base'):
1413

1514
with jax.default_device(jax.local_devices(backend='cpu')[0]):
1615
text_encoder = FlaxCLIPTextModel.from_pretrained(
17-
model_name_or_path,
18-
subfolder="text_encoder", from_pt=True, dtype=jnp.float16)
16+
model_name_or_path, subfolder="text_encoder", from_pt=True)
1917
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
20-
model_name_or_path,
21-
subfolder="vae", from_pt=True, dtype=jnp.float16)
18+
model_name_or_path, subfolder="vae", from_pt=True)
2219
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
23-
model_name_or_path,
24-
subfolder="unet", from_pt=True, dtype=jnp.float32)
20+
model_name_or_path, subfolder="unet", from_pt=True)
2521
params = {
2622
'text_encoder': text_encoder.params,
2723
'unet': unet_params,

0 commit comments

Comments
 (0)