Implementation of "LiteVAE: Lightweight and Efficient Variational Autoencoders for Latent Diffusion Models" [2024]. The paper introduces an efficient wavelet-encoder-based variational autoencoder that demonstrates significant performance improvements and more stable training compared with prior work. This implementation aims to replicate and extend its findings using GPU-accelerated wavelet transformations (torch-dwt), stochastic image rescaling, and improved discriminator models.
This implementation was developed independently, before the authors provided pseudocode in the appendix of their paper. As a result, the approach here may differ in small details but adheres to the paper's methodology and goals.
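LiteVAE's encoder operates on a multi-level 2D discrete wavelet transform (DWT) of the image rather than on raw pixels. As a quick illustration of the transform itself, the sketch below implements a single-level Haar DWT with a strided grouped convolution; it is a minimal stand-in for the GPU-accelerated transform provided by torch-dwt (whose API is not shown here), not code from this repo.

```python
import torch
import torch.nn.functional as F

def haar_dwt2d(x: torch.Tensor) -> torch.Tensor:
    """Single-level 2D Haar DWT.

    x: (B, C, H, W) with even H and W.
    Returns (B, 4*C, H/2, W/2) where every group of 4 consecutive channels
    holds the LL, LH, HL, HH subbands of one input channel.
    """
    # Orthonormal 2D Haar analysis filters (outer products of [1, 1]/sqrt(2) and [1, -1]/sqrt(2)).
    ll = torch.tensor([[0.5, 0.5], [0.5, 0.5]])
    lh = torch.tensor([[0.5, 0.5], [-0.5, -0.5]])
    hl = torch.tensor([[0.5, -0.5], [0.5, -0.5]])
    hh = torch.tensor([[0.5, -0.5], [-0.5, 0.5]])
    kernels = torch.stack([ll, lh, hl, hh]).to(x)  # (4, 2, 2)

    c = x.shape[1]
    # One group per input channel so each channel is decomposed independently.
    weight = kernels.unsqueeze(1).repeat(c, 1, 1, 1)  # (4*C, 1, 2, 2)
    return F.conv2d(x, weight, stride=2, groups=c)

# Example: the first two levels of a decomposition of a 256x256 image.
img = torch.randn(1, 3, 256, 256)
level1 = haar_dwt2d(img)                   # (1, 12, 128, 128)
level2 = haar_dwt2d(level1[:, 0::4])       # recurse on the LL subbands -> (1, 12, 64, 64)
```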
@inproceedings{
sadat2024litevae,
title={Lite{VAE}: Lightweight and Efficient Variational Autoencoders for Latent Diffusion Models},
author={Seyedmorteza Sadat and Jakob Buhmann and Derek Bradley and Otmar Hilliges and Romann M. Weber},
booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
year={2024},
url={https://openreview.net/forum?id=mTAbl8kUzq}
}
Comparison of encoder model configurations for nz=12, following the structure proposed by LiteVAE. FLOPs are reported at a resolution of 256x256. For comparison, the SD-VAE encoder has 34.16 M parameters and uses 137 GFLOPs at 256x256. LiteVAE reports equivalent performance with the B-scale encoder when matching nz.
Scale | Params (LiteVAE) | Params (Ours) | FLOPs (Ours) | Extractor C | Extractor Mult | Extractor Blocks | Aggregator C | Aggregator Mult | Aggregator Blocks |
---|---|---|---|---|---|---|---|---|---|
S | 1.03 M | 0.97 M | 5.5 G | 16 | [1,2,2] | 3 | 16 | [1,2,2] | 3 |
B | 6.75 M | 6.60 M | 37.4 G | 32 | [1,2,3] | 4 | 32 | [1,2,3] | 4 |
M | 32.75 M | 34.00 M | 234.9 G | 64 | [1,2,4] | 5 | 32 | [1,2,3] | 4 |
L | 41.42 M | 41.15 M | 242.1 G | 64 | [1,2,4] | 5 | 64 | [1,2,4] | 4 |
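To make the Extractor/Aggregator split in the table concrete, here is a very rough structural sketch of the encoder as we read it: one feature-extraction module per DWT level, features pooled to the latent resolution, and an aggregation module that outputs the latent mean and log-variance. The plain conv blocks, widths, and class name below are placeholders for illustration and do not reproduce the repo's actual modules.

```python
import torch
import torch.nn as nn

def conv_block(cin, cout):
    # Stand-in for the residual blocks configured by C / Mult / Blocks above.
    return nn.Sequential(nn.Conv2d(cin, cout, 3, padding=1), nn.SiLU())

class TinyWaveletEncoder(nn.Module):
    """Rough structural sketch: one feature extractor per DWT level,
    features pooled to the latent resolution, then a shared aggregator
    that predicts the latent mean and log-variance (2 * nz channels)."""

    def __init__(self, nz=12, width=32):
        super().__init__()
        # Each level of a 3-level Haar DWT of an RGB image has 4 subbands x 3 colors = 12 channels.
        self.extractors = nn.ModuleList([conv_block(12, width) for _ in range(3)])
        self.aggregator = nn.Sequential(conv_block(3 * width, width), nn.Conv2d(width, 2 * nz, 1))

    def forward(self, dwt_levels):
        # dwt_levels[l] has spatial size H / 2^(l+1); pool everything to the deepest level.
        target = dwt_levels[-1].shape[-2:]
        feats = [nn.functional.adaptive_avg_pool2d(ext(x), target)
                 for ext, x in zip(self.extractors, dwt_levels)]
        moments = self.aggregator(torch.cat(feats, dim=1))
        mean, logvar = moments.chunk(2, dim=1)
        return mean, logvar
```

For a 256x256 input with three DWT levels, the deepest level (and hence the latent) is 32x32, matching the usual f=8 downsampling of the SD-VAE.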
Following LiteVAE, we base our decoder on the SD-VAE decoder, replacing Conv-Norm pairs with SMC layers. Skip connections retain the original Conv2d, as they are used to adapt feature dimensions. We also replace the single-headed attention layer in the VAE with an additional ResNet block, since this layer has little effect and limits the decoder resolution. FLOPs are reported at a resolution of 256x256.
Model | Params | FLOPs | C | Mult | Blocks | Attn |
---|---|---|---|---|---|---|
SDVAE | 49.49 M | 312 G | 128 | [1,2,4,4] + 4 | [3,3,3,3] + 2 | [ ] + 1 |
Ours | 53.33 M | 324 G | 128 | [1,2,4,4] + 4 | [3,3,3,3] + 3 | None |
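For readers unfamiliar with self-modulated convolutions (SMC), the sketch below shows one way such a layer can be built: the convolution weights are scaled by per-channel factors predicted from the input feature map itself and then demodulated, StyleGAN2-style, so no normalization layer (and hence no resolution constraint) is needed. This is our own simplified reading, not necessarily the exact SMC layer used in LiteVAE or in this repo.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfModulatedConv2d(nn.Module):
    """Loose sketch of a self-modulated conv: a per-input-channel scale is
    predicted from the feature map itself (global average pool + linear),
    applied to the conv weights, and the weights are re-normalized
    (StyleGAN2-style demodulation)."""

    def __init__(self, cin, cout, k=3):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(cout, cin, k, k) * 0.02)
        self.bias = nn.Parameter(torch.zeros(cout))
        self.to_scale = nn.Linear(cin, cin)

    def forward(self, x):
        b, cin, _, _ = x.shape
        # Per-sample, per-input-channel modulation derived from the input itself.
        s = self.to_scale(x.mean(dim=(2, 3))) + 1.0               # (B, Cin)
        w = self.weight.unsqueeze(0) * s[:, None, :, None, None]  # (B, Cout, Cin, k, k)
        # Demodulation keeps activations at unit scale without a norm layer.
        w = w * torch.rsqrt(w.pow(2).sum(dim=(2, 3, 4), keepdim=True) + 1e-8)
        # Grouped-conv trick to apply a different kernel per sample in the batch.
        w = w.reshape(-1, cin, *self.weight.shape[2:])
        out = F.conv2d(x.reshape(1, -1, *x.shape[2:]), w,
                       padding=self.weight.shape[-1] // 2, groups=b)
        return out.reshape(b, -1, *out.shape[2:]) + self.bias[None, :, None, None]
```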
Comparison of discriminators, scaled to roughly match FLOPs between variants. Notably, the GigaGAN discriminator contains more parameters due to its multiple predictor levels and soft filter-bank convolution layers. FLOPs are reported at a resolution of 256x256.
Model | Params | FLOPs | Config (256x256) |
---|---|---|---|
PatchGAN | 2.77M | 3.15G | nlayers=3, ndf=64 |
GigaGAN | 14.38M | 3.23G | Cbase=4096, Cmax=256, nblocks=2, attn=[8,16] |
UNetGAN-S | 2.75M | 2.31G | Dch=16, attn=None |
UNetGAN-M | 11.0M | 9.13G | Dch=32, attn=None |
UNetGAN-L | 44.1M | 36.34G | Dch=64, attn=None |
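The PatchGAN row above (nlayers=3, ndf=64) corresponds to the standard pix2pix-style discriminator; a minimal version is sketched below for reference. The repo's basicgan module may differ in normalization and initialization details.

```python
import torch.nn as nn

def patchgan_discriminator(in_ch=3, ndf=64, n_layers=3):
    """70x70 PatchGAN-style discriminator (pix2pix); outputs a map of
    real/fake logits, one per overlapping image patch."""
    layers = [nn.Conv2d(in_ch, ndf, 4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
    ch = ndf
    for i in range(1, n_layers + 1):
        prev, ch = ch, min(ndf * 2 ** i, ndf * 8)
        stride = 2 if i < n_layers else 1          # last block keeps resolution
        layers += [nn.Conv2d(prev, ch, 4, stride=stride, padding=1, bias=False),
                   nn.BatchNorm2d(ch), nn.LeakyReLU(0.2, True)]
    layers += [nn.Conv2d(ch, 1, 4, stride=1, padding=1)]  # per-patch logit map
    return nn.Sequential(*layers)
```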
Models were trained on the ImageNet-1k training set with the B-scale encoder and nz=12. Training is conducted in two phases: (A) pre-training at 128x128 without a discriminator for 100k steps, and (B) fine-tuning at 256x256 with a discriminator for 50k steps.
All metrics are computed on the full ImageNet-1k validation set (50k images) using bicubic rescaling and center cropping. We compare the reported VAE (retrained SD-VAE) and LiteVAE results to configurations in this repo, and also include the SD1-MSE VAE (nz=4) and SD3-VAE (nz=16). Notably, results appear to be highly dependent on the loss weights (including the discriminator weight), trading rFID against LPIPS, PSNR, and SSIM.
In the table below, wkl, wwave, and wgauss are loss weights; LPIPS, PSNR, rFID, and SSIM are evaluated at 256x256.

Method | Disc Type | wkl | wwave | wgauss | LPIPS | PSNR | rFID | SSIM |
---|---|---|---|---|---|---|---|---|
SD1-VAE | PatchGAN | ? | N/A | N/A | 0.138 | 25.70 | 0.75 | 0.72 |
SD3-VAE | PatchGAN? | ? | N/A | N/A | 0.069 | 29.59 | 0.22 | 0.86 |
VAE | PatchGAN | ? | N/A | N/A | 0.069 | 29.25 | 0.95 | 0.86 |
LiteVAE | UNetGAN-? | ? | ? | ? | 0.069 | 29.55 | 0.94 | 0.87 |
A1 (Ours) | N/A | 0.01 | 0.1 | 0.1 | 0.081 | 29.38 | 1.86 | 0.85 |
A2 (Ours) | N/A | 0.01 | 1.0 | 0.5 | 0.084 | 29.78 | 1.12 | 0.86 |
B0 (Ours) | None | 0.01 | 1.0 | 0.5 | 0.076 | 30.03 | 1.29 | 0.86 |
B1 (Ours) | PatchGAN | 0.01 | 1.0 | 0.5 | 0.080 | 29.12 | 0.30 | 0.84 |
B2 (Ours) | GigaGAN | 0.01 | 1.0 | 0.5 | 0.080 | 29.27 | 0.31 | 0.85 |
B3 (Ours) | UNetGAN-S | 0.01 | 1.0 | 0.5 | 0.080 | 29.15 | 0.38 | 0.85 |
B4 (Ours) | UNetGAN-M | 0.01 | 1.0 | 0.5 | 0.080 | 29.07 | 0.30 | 0.84 |
B5 (Ours) | UNetGAN-L | 0.01 | 1.0 | 0.5 | 0.084 | 28.74 | 0.24 | 0.84 |
- We achieve similar results to the original LiteVAE paper, which are in turn competitive with the SD3-VAE, with differences that can be attributed to the value of nz.
- The differing LPIPS and rFID scores from LiteVAE may be attributed to the KL loss weight (wkl), where a higher value will result in worse reconstruction but stronger adherence to the latent space unit Gaussian objective.
- rFID and the other metrics (LPIPS, SSIM, and PSNR) trade off against each other: improving rFID tends to degrade the rest. However, while FID itself has many issues, it appears to be a stronger indicator of perceived image quality compared with the other metrics.
- Overall, the discriminator type has a minor impact, with the original PatchGAN discriminator performing well. UNetGAN may perform better when considering a larger version, but this becomes difficult to justify given the increased training FLOPs.
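To show how the weight columns in the table enter training, the sketch below combines pixel, wavelet-domain, Gaussian-blurred, and KL terms with the weights wkl, wwave, and wgauss. The choice and naming of the individual terms follow our reading of LiteVAE; LPIPS and adversarial losses are omitted for brevity, and the repo's actual loss code may differ.

```python
import torch
import torch.nn.functional as F

def gaussian_blur(x, k=5, sigma=1.0):
    """Separable Gaussian blur used for the low-frequency reconstruction term."""
    grid = torch.arange(k, dtype=x.dtype, device=x.device) - (k - 1) / 2
    g = torch.exp(-grid ** 2 / (2 * sigma ** 2))
    g = (g / g.sum()).view(1, 1, 1, k)
    c = x.shape[1]
    x = F.conv2d(x, g.repeat(c, 1, 1, 1), padding=(0, k // 2), groups=c)           # horizontal pass
    return F.conv2d(x, g.transpose(2, 3).repeat(c, 1, 1, 1), padding=(k // 2, 0), groups=c)  # vertical pass

def vae_loss(x, x_hat, mean, logvar, x_dwt, x_hat_dwt,
             w_kl=0.01, w_wave=1.0, w_gauss=0.5):
    rec = F.l1_loss(x_hat, x)                                       # pixel-space reconstruction
    wave = F.l1_loss(x_hat_dwt, x_dwt)                              # wavelet-domain term (wwave)
    gauss = F.l1_loss(gaussian_blur(x_hat), gaussian_blur(x))       # blurred / low-frequency term (wgauss)
    kl = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp()) # KL to unit Gaussian (wkl)
    return rec + w_wave * wave + w_gauss * gauss + w_kl * kl
```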
configs/
- Configuration files for setting up the experiments and models.
olvae/
- Main directory for the Open LiteVAE project.
├── data/
│   - Contains the dataloaders for data preparation and augmentation.
├── models/
│   - Contains PyTorch Lightning models for training and evaluation.
└── modules/
    - Contains model-specific layers and architectures.
    ├── litevae/
    │   - Specialized layers for the LiteVAE model.
    ├── basicgan/
    │   - Common GAN components and loss functions (includes the PatchGAN discriminator).
    ├── gigagan/
    │   - Layers specific to the GigaGAN discriminator.
    └── unetgan/
        - Layers and components for the UNetGAN discriminator.
scripts/
- Code for evaluation, testing, and utilities.
- Python >= 3.9
- PyTorch >= 2.0
- Torch-DWT
Install via requirements.txt
pip install -r requirements.txt
pip install -e .
Coming Soon.
Run the `compute_metrics.py` script to compute the evaluation metrics. It computes LPIPS, PSNR, rFID, and SSIM using the test loader specified in the configuration file.
The script expects the parent directory structure to be of the form `<parent>/checkpoints/last.ckpt` and will search for the `<config.yaml>` in the `<config directory>`, where `<parent> = <timestamp>_<config>_<precision>`. This search still works if the timestamp and precision are not included, as `<config>` is extracted by splitting on `_` and removing the first and last elements of the list.
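As a concrete illustration of the name-parsing rule described above, a minimal rendition (our own, not the script's exact code) looks like this:

```python
from pathlib import Path

def config_name_from_run_dir(run_dir: str) -> str:
    """Recover <config> from <parent> = <timestamp>_<config>_<precision>
    by splitting on '_' and dropping the first and last elements."""
    parts = Path(run_dir).name.split("_")
    return "_".join(parts[1:-1])

# Hypothetical run directory name for illustration:
# "2024-06-01T12-00-00_litevae_b_imagenet_bf16" -> "litevae_b_imagenet"
print(config_name_from_run_dir("logs/2024-06-01T12-00-00_litevae_b_imagenet_bf16"))
```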
python script/compute_metrics.py --config_base <config directory> -B <batch_size> <run parent directory>
Alternatively, you can specify the configuration and checkpoint directly.
python script/compute_metrics.py --config <config.yaml> -B <batch_size> <checkpoint.ckpt>
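The reported metrics can also be reproduced with off-the-shelf torchmetrics implementations; the evaluation loop below is a rough sketch under the assumption that the model returns a reconstruction in [0, 1], and is not the repo's `compute_metrics.py`.

```python
import torch
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

@torch.no_grad()
def evaluate(vae, loader, device="cuda"):
    psnr = PeakSignalNoiseRatio(data_range=1.0).to(device)
    ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
    lpips = LearnedPerceptualImagePatchSimilarity(net_type="vgg", normalize=True).to(device)
    fid = FrechetInceptionDistance(normalize=True).to(device)  # rFID: real vs. reconstructed
    for x in loader:                      # x: float images in [0, 1], shape (B, 3, 256, 256)
        x = x.to(device)
        x_hat = vae(x).clamp(0, 1)        # assumes the model returns the reconstruction
        psnr.update(x_hat, x)
        ssim.update(x_hat, x)
        lpips.update(x_hat, x)
        fid.update(x, real=True)
        fid.update(x_hat, real=False)
    return {"PSNR": psnr.compute().item(), "SSIM": ssim.compute().item(),
            "LPIPS": lpips.compute().item(), "rFID": fid.compute().item()}
```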
The example code uses OmegaConf YAML files to specify and dynamically construct the models. By default, we use the WebDataset format to efficiently stream large collections of data from disk, with an interface constructor providing `train_loader()`, `val_loader()`, and `test_loader()`. Notably, this pattern can wrap other dataloader types, such as a typical ImageDataset or a default torchvision dataset.
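The loader interface can be backed by any dataset type; the wrapper below is a hypothetical example built on a torchvision ImageFolder that exposes the same `train_loader()`/`val_loader()`/`test_loader()` methods (class and directory names are placeholders, not the repo's data module).

```python
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class ImageFolderData:
    """Minimal stand-in exposing the train_loader()/val_loader()/test_loader()
    interface expected by the training code (hypothetical example)."""

    def __init__(self, root, size=256, batch_size=16, num_workers=4):
        self.tf = transforms.Compose([
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(size),
            transforms.ToTensor(),
        ])
        self.root, self.batch_size, self.num_workers = root, batch_size, num_workers

    def _loader(self, split, shuffle):
        ds = datasets.ImageFolder(f"{self.root}/{split}", transform=self.tf)
        return DataLoader(ds, batch_size=self.batch_size, shuffle=shuffle,
                          num_workers=self.num_workers, pin_memory=True)

    def train_loader(self):
        return self._loader("train", shuffle=True)

    def val_loader(self):
        return self._loader("val", shuffle=False)

    def test_loader(self):
        return self._loader("val", shuffle=False)
```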
Training utilizes PyTorch Lightning to handle multi-GPU communication, mixed-precision training, and gradient accumulation, which keeps the training setup relatively simple. The current codebase supports training in full precision (32), half precision (16), and bfloat16 (bf16).
python train.py --base <config.yaml> \
--logdir <log directory> \
--precision <training precision> \
--gpus <gpu count or list> \
--seed <random seed> \
--name <config_name>
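For reference, the Lightning features mentioned above map onto standard Trainer arguments roughly as follows; the repo's `train.py` may wire these up differently, and the values shown are illustrative only.

```python
import pytorch_lightning as pl

# Illustrative only: precision, devices, and accumulation as exposed by Lightning.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=4,                    # corresponds to --gpus
    precision="bf16",             # or 16 / 32, matching the --precision flag
    accumulate_grad_batches=2,    # gradient accumulation
    max_steps=100_000,
)
# trainer.fit(model, datamodule=data)
```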
This codebase also supports resuming from a stopped run. Note: we recommend changing the random seed when resuming.
python train.py --base <config.yaml> \
--logdir <log directory> \
--precision <training precision> \
--gpus <gpu count or list> \
--seed <random seed> \
--resume <previous run directory>
For transitioning between Phase A and Phase B training, we recommend using `--actual_resume`, which resets the optimizer state. This trick is adapted from Textual-Inversion.
python train.py --base <config.yaml> \
--logdir <log directory> \
--precision <training precision> \
--gpus <gpu count or list> \
--seed <random seed> \
--actual_resume <previous.ckpt> \
--name <config_name>
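Conceptually, `--actual_resume` amounts to loading only the model weights from the Phase A checkpoint and letting fresh optimizer state be created for Phase B. A bare-bones version of that idea (not the repo's exact code):

```python
import torch

def load_weights_only(model, ckpt_path):
    """Initialize model weights from a Lightning checkpoint without restoring
    optimizer/scheduler state, so Phase B starts with fresh optimizer moments."""
    ckpt = torch.load(ckpt_path, map_location="cpu")
    state = ckpt.get("state_dict", ckpt)       # Lightning stores weights under "state_dict"
    missing, unexpected = model.load_state_dict(state, strict=False)
    print(f"loaded weights; missing={len(missing)} unexpected={len(unexpected)}")
    return model
```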
- Add Description of Improved Methods
- Add Training Code
- Add Evaluation Code
- More Experiments
@inproceedings{
sadat2024litevae,
title={Lite{VAE}: Lightweight and Efficient Variational Autoencoders for Latent Diffusion Models},
author={Seyedmorteza Sadat and Jakob Buhmann and Derek Bradley and Otmar Hilliges and Romann M. Weber},
booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
year={2024},
url={https://openreview.net/forum?id=mTAbl8kUzq}
}
@article{Rombach2021HighResolutionIS,
title={High-Resolution Image Synthesis with Latent Diffusion Models},
author={Robin Rombach and A. Blattmann and Dominik Lorenz and Patrick Esser and Bj{\"o}rn Ommer},
journal={2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year={2021},
pages={10674-10685},
}
@article{Isola2016ImagetoImageTW,
title={Image-to-Image Translation with Conditional Adversarial Networks},
author={Phillip Isola and Jun-Yan Zhu and Tinghui Zhou and Alexei A. Efros},
journal={2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
year={2016},
pages={5967-5976},
}
@article{Kang2023ScalingUG,
title={Scaling up GANs for Text-to-Image Synthesis},
author={Minguk Kang and Jun-Yan Zhu and Richard Zhang and Jaesik Park and Eli Shechtman and Sylvain Paris and Taesung Park},
journal={2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year={2023},
pages={10124-10134},
}
@article{Schoenfeld2020AUB,
title={A U-Net Based Discriminator for Generative Adversarial Networks},
author={Edgar Sch{\"o}nfeld and Bernt Schiele and Anna Khoreva},
journal={2020 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year={2020},
pages={8204-8213},
}