@@ -122,7 +122,7 @@ def main(dataset_name='lambdalabs/naruto-blip-captions',
122
122
model_name_or_path = 'stabilityai/stable-diffusion-2-1-base' ,
123
123
init_ckpt_dir = './stable-diffusion-2-1-base' ,
124
124
n_model_shards = 1 ,
125
- n_epochs = 3 ,
125
+ n_epochs = 8 ,
126
126
global_batch_size = 8 ,
127
127
per_device_batch_size = 1 ,
128
128
learning_rate = 1e-5 ,
@@ -156,13 +156,13 @@ def main(dataset_name='lambdalabs/naruto-blip-captions',
156
156
text_encoder = FlaxCLIPTextModel (
157
157
config = CLIPTextConfig .from_pretrained (
158
158
model_name_or_path , subfolder = "text_encoder" ),
159
- dtype = jnp .float16 , _do_init = False )
159
+ dtype = jnp .float32 , _do_init = False )
160
160
vae = FlaxAutoencoderKL .from_config (
161
161
config = FlaxAutoencoderKL .load_config (
162
- model_name_or_path , subfolder = 'vae' ), dtype = jnp .float16 )
162
+ model_name_or_path , subfolder = 'vae' ), dtype = jnp .float32 )
163
163
unet = FlaxUNet2DConditionModel .from_config (
164
164
config = FlaxUNet2DConditionModel .load_config (
165
- model_name_or_path , subfolder = 'unet' ), dtype = jnp .float16 )
165
+ model_name_or_path , subfolder = 'unet' ), dtype = jnp .float32 )
166
166
noise_scheduler , noise_scheduler_state = \
167
167
FlaxPNDMScheduler .from_pretrained (
168
168
model_name_or_path , subfolder = 'scheduler' )
0 commit comments