open-LiteVAE



Implementation of "LiteVAE: Lightweight and Efficient Variational Autoencoders for Latent Diffusion Models" [2024]. The paper introduces an efficient wavelet-encoder-based variational autoencoder that demonstrates a significant performance improvement and more stable training compared with previous work. This implementation aims to replicate and extend its findings using GPU-accelerated wavelet transformations (torch-dwt), stochastic image rescaling, and improved discriminator models.
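
The encoder operates on a multi-level wavelet decomposition of the input rather than on raw pixels. As a rough illustration of the transform only (a hand-rolled Haar filter bank, not the torch-dwt API used in this repo), a single-level 2D DWT can be written with a grouped convolution:

```python
import torch
import torch.nn.functional as F

def haar_dwt_level(x: torch.Tensor) -> torch.Tensor:
    """Single-level 2D Haar DWT of a (B, C, H, W) tensor.

    Returns (B, 4*C, H/2, W/2) with the LL, LH, HL, HH subbands stacked
    per channel. Illustrative only; the repo uses torch-dwt for
    GPU-accelerated multi-level transforms.
    """
    c = x.shape[1]
    lo = torch.tensor([1.0, 1.0]) / 2 ** 0.5   # low-pass Haar filter
    hi = torch.tensor([1.0, -1.0]) / 2 ** 0.5  # high-pass Haar filter
    # Four separable 2x2 kernels: LL, LH, HL, HH.
    kernels = torch.stack([
        torch.outer(lo, lo), torch.outer(lo, hi),
        torch.outer(hi, lo), torch.outer(hi, hi),
    ])                                                # (4, 2, 2)
    weight = kernels.unsqueeze(1).repeat(c, 1, 1, 1)  # (4*C, 1, 2, 2)
    # Grouped conv with stride 2 applies the filter bank to each channel.
    return F.conv2d(x, weight.to(x), stride=2, groups=c)

print(haar_dwt_level(torch.randn(1, 3, 256, 256)).shape)  # (1, 12, 128, 128)
```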

Note

This implementation was independently developed before the authors provided pseudocode in the appendix of their paper. As a result, the approach here may differ slightly in details but adheres to the paper's methodology and goals.

Please Cite the Original Paper

@inproceedings{sadat2024litevae,
  title={Lite{VAE}: Lightweight and Efficient Variational Autoencoders for Latent Diffusion Models},
  author={Seyedmorteza Sadat and Jakob Buhmann and Derek Bradley and Otmar Hilliges and Romann M. Weber},
  booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
  year={2024},
  url={https://openreview.net/forum?id=mTAbl8kUzq}
}

Model Configurations

Encoder

Comparison of encoder model configurations for nz=12, following the structure proposed by LiteVAE. FLOPs are reported at a resolution of 256x256. For comparison, the SD-VAE encoder has 34.16 M parameters and uses 137 GFLOPs at 256x256. LiteVAE showed equivalent performance to the SD-VAE with the B-scale encoder when matching nz.

| Scale | Params (LiteVAE) | Params (Ours) | FLOPs (Ours) | Extractor (C / Mult / Blocks) | Aggregator (C / Mult / Blocks) |
|-------|------------------|---------------|--------------|-------------------------------|--------------------------------|
| S | 1.03 M | 0.97 M | 5.5 G | 16 / [1,2,2] / 3 | 16 / [1,2,2] / 3 |
| B | 6.75 M | 6.60 M | 37.4 G | 32 / [1,2,3] / 4 | 32 / [1,2,3] / 4 |
| M | 32.75 M | 34.00 M | 234.9 G | 64 / [1,2,4] / 5 | 32 / [1,2,3] / 4 |
| L | 41.42 M | 41.15 M | 242.1 G | 64 / [1,2,4] / 5 | 64 / [1,2,4] / 4 |
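
Reading the table above, each DWT level is processed by its own feature extractor with the listed base channels (C), channel multipliers (Mult), and residual blocks per stage (Blocks), and an aggregator fuses the per-level features into the latent. The sketch below shows that data flow under our reading of the paper; the class and argument names are placeholders, not the modules defined in olvae/.

```python
import torch
import torch.nn as nn

class LiteVAEEncoderSketch(nn.Module):
    """Hypothetical outline of the extractor/aggregator split (not olvae code)."""

    def __init__(self, extractors: nn.ModuleList, aggregator: nn.Module, nz: int, agg_ch: int):
        super().__init__()
        self.extractors = extractors  # one small CNN per wavelet level
        self.aggregator = aggregator  # fuses the concatenated level features
        self.to_latent = nn.Conv2d(agg_ch, 2 * nz, kernel_size=3, padding=1)

    def forward(self, dwt_levels: list) -> tuple:
        # dwt_levels: per-level subband tensors, resampled to a common
        # spatial size so their features can be concatenated.
        feats = [ext(lvl) for ext, lvl in zip(self.extractors, dwt_levels)]
        fused = self.aggregator(torch.cat(feats, dim=1))
        mean, logvar = self.to_latent(fused).chunk(2, dim=1)  # Gaussian posterior
        return mean, logvar
```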

Decoder

We follow LiteVAE in basing our decoder on the SD-VAE decoder, replacing Conv-Norm pairs with SMC layers. Skip connections retain the original Conv2d since they only adapt feature dimensions. We also replace the single-headed attention layer of the SD-VAE with an additional ResNet block, since this layer has little effect and limits the decoder resolution. FLOPs are reported at a resolution of 256x256.

| Model | Params | FLOPs | C | Mult | Blocks | Attn |
|--------|---------|-------|-----|---------------|---------------|---------|
| SD-VAE | 49.49 M | 312 G | 128 | [1,2,4,4] + 4 | [3,3,3,3] + 2 | [ ] + 1 |
| Ours | 53.33 M | 324 G | 128 | [1,2,4,4] + 4 | [3,3,3,3] + 3 | None |
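
The SMC layers mentioned above are self-modulated convolutions. As a heavily simplified sketch of the idea (our reading: a per-channel modulation scale is predicted from the feature map itself and stands in for the normalization layer; this is not the exact layer in olvae/modules/litevae):

```python
import torch
import torch.nn as nn

class SelfModulatedConvSketch(nn.Module):
    """Minimal stand-in for an SMC block replacing a Conv-Norm pair."""

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.to_scale = nn.Linear(in_ch, in_ch)  # modulation derived from the input
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        s = self.to_scale(x.mean(dim=(2, 3)))        # (B, in_ch) global statistics
        x = x * (1 + s).unsqueeze(-1).unsqueeze(-1)  # self-modulate before the conv
        return self.conv(x)
```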

Discriminator

Comparison of discriminators, scaled to roughly match FLOPs across variants. Notably, the GigaGAN discriminator contains more parameters due to its multiple predictor levels and soft filter-bank convolution layers. FLOPs are reported at a resolution of 256x256.

| Model | Params | FLOPs | Config (256x256) |
|-----------|---------|---------|--------------------------------------------------|
| PatchGAN | 2.77 M | 3.15 G | nlayers=3, ndf=64 |
| GigaGAN | 14.38 M | 3.23 G | Cbase=4096, Cmax=256, nblocks=2, attn=[8,16] |
| UNetGAN-S | 2.75 M | 2.31 G | Dch=16, attn=None |
| UNetGAN-M | 11.0 M | 9.13 G | Dch=32, attn=None |
| UNetGAN-L | 44.1 M | 36.34 G | Dch=64, attn=None |
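
For reference, the nlayers=3, ndf=64 PatchGAN row corresponds to the standard pix2pix-style discriminator: strided Conv/Norm/LeakyReLU stages followed by a 1-channel patch-logit map. A condensed sketch (not the exact module in olvae/modules/basicgan):

```python
import torch.nn as nn

def patchgan_sketch(in_ch: int = 3, ndf: int = 64, n_layers: int = 3) -> nn.Sequential:
    """Condensed pix2pix-style PatchGAN discriminator (illustrative)."""
    layers = [nn.Conv2d(in_ch, ndf, 4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
    ch = ndf
    for i in range(1, n_layers + 1):
        prev, ch = ch, min(ndf * 2 ** i, ndf * 8)
        stride = 2 if i < n_layers else 1        # last stage keeps resolution
        layers += [nn.Conv2d(prev, ch, 4, stride=stride, padding=1, bias=False),
                   nn.BatchNorm2d(ch),
                   nn.LeakyReLU(0.2, True)]
    layers += [nn.Conv2d(ch, 1, 4, stride=1, padding=1)]  # per-patch real/fake logits
    return nn.Sequential(*layers)

disc = patchgan_sketch()  # roughly 2.8 M parameters at the defaults above
```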

Comparisons

Evaluation Metrics

Models were trained on the ImageNet training set with the B-scale encoder and nz=12. Training is conducted in two phases: (A) pre-training at 128x128 without a discriminator for 100k steps, and (B) fine-tuning at 256x256 with a discriminator for 50k steps.

All metrics are computed at 256x256 on the full ImageNet-1k validation set (50k images) using bicubic rescaling and center cropping. We compare the VAE (retrained SD-VAE) and LiteVAE results reported in the paper against configurations from this repo, alongside the SD1-MSE VAE (nz=4) and SD3-VAE (nz=16). Notably, results are highly dependent on the loss weights (including the discriminator weight), trading rFID against LPIPS, PSNR, and SSIM.
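
As a concrete reference for the preprocessing, a minimal torchvision equivalent of the bicubic resize plus center crop might look as follows (the actual transform is defined by the test loader in the config):

```python
from torchvision import transforms

# Assumed 256x256 evaluation preprocessing: bicubic resize of the short
# side followed by a center crop. The exact pipeline lives in the config.
eval_transform = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
])
```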

| Method | Disc Type | wkl | wwave | wgauss | LPIPS | PSNR | rFID | SSIM |
|-----------|-----------|------|-------|--------|-------|-------|------|------|
| SD1-VAE | PatchGAN | ? | N/A | N/A | 0.138 | 25.70 | 0.75 | 0.72 |
| SD3-VAE | PatchGAN? | ? | N/A | N/A | 0.069 | 29.59 | 0.22 | 0.86 |
| VAE | PatchGAN | ? | N/A | N/A | 0.069 | 29.25 | 0.95 | 0.86 |
| LiteVAE | UNetGAN-? | ? | ? | ? | 0.069 | 29.55 | 0.94 | 0.87 |
| A1 (Ours) | N/A | 0.01 | 0.1 | 0.1 | 0.081 | 29.38 | 1.86 | 0.85 |
| A2 (Ours) | N/A | 0.01 | 1.0 | 0.5 | 0.084 | 29.78 | 1.12 | 0.86 |
| B0 (Ours) | None | 0.01 | 1.0 | 0.5 | 0.076 | 30.03 | 1.29 | 0.86 |
| B1 (Ours) | PatchGAN | 0.01 | 1.0 | 0.5 | 0.080 | 29.12 | 0.30 | 0.84 |
| B2 (Ours) | GigaGAN | 0.01 | 1.0 | 0.5 | 0.080 | 29.27 | 0.31 | 0.85 |
| B3 (Ours) | UNetGAN-S | 0.01 | 1.0 | 0.5 | 0.080 | 29.15 | 0.38 | 0.85 |
| B4 (Ours) | UNetGAN-M | 0.01 | 1.0 | 0.5 | 0.080 | 29.07 | 0.30 | 0.84 |
| B5 (Ours) | UNetGAN-L | 0.01 | 1.0 | 0.5 | 0.084 | 28.74 | 0.24 | 0.84 |
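
A rough sketch of how the weight columns combine, under our reading of LiteVAE: wkl scales the KL term, while wwave and wgauss scale auxiliary reconstruction terms in the wavelet and Gaussian-blurred domains; the adversarial term only enters in Phase B. The function and argument names below are illustrative, not the loss module in this repo.

```python
def total_loss(rec, kl, wave, gauss, adv=0.0,
               w_kl=0.01, w_wave=1.0, w_gauss=0.5, w_adv=0.0):
    """Illustrative composition of the training objective (not olvae code)."""
    return (rec                # pixel/perceptual reconstruction
            + w_kl * kl        # latent prior adherence (trades rFID vs. PSNR/LPIPS)
            + w_wave * wave    # wavelet-domain reconstruction
            + w_gauss * gauss  # Gaussian-blurred reconstruction
            + w_adv * adv)     # discriminator term, Phase B only
```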

Latent Distributions

[Figure: latent distribution comparison]

Comparison Findings

  • We achieve results similar to the original LiteVAE paper, which are in turn competitive with the SD3-VAE; the remaining differences can be attributed to the value of nz.
  • The differing LPIPS and rFID scores from LiteVAE may be attributed to the KL loss weight (wkl): a higher value worsens reconstruction but strengthens adherence to the unit-Gaussian latent objective.
  • rFID and the other metrics (LPIPS, SSIM, and PSNR) are inversely related: improving rFID degrades the others. However, while FID itself has many issues, it appears to be a stronger indicator of perceptual image quality compared with the other metrics.
  • Overall, the discriminator type has a minor impact, with the original PatchGAN discriminator performing well. UNetGAN may perform better at larger scale, but this becomes difficult to justify given the increased training FLOPs.

Repository Structure

configs/         
    - Configuration files for setting up the experiments and models.

olvae/
    - Main directory for the Open LiteVAE project.
    ├── data/        
    │   - Contains the dataloader for data preparation and augmentation. 
    ├── models/      
    │   - Contains PyTorch Lightning models for training and evaluation. 
    ├── modules/     
        - Contains model-specific layers and architectures.
        ├── litevae/
        │   - Specialized layers for the LiteVAE model. 
        ├── basicgan/
        │   - Common GAN components and loss functions [includes the PatchGAN discriminator]. 
        ├── gigagan/
        │   - Layers specific to the GigaGAN discriminator. 
        ├── unetgan/
        │   - Layers and components for the UNetGAN discriminator. 

scripts/
    - Code for evaluation, testing, and utilities.

Setup

Dependencies

  • Python >= 3.9
  • PyTorch >= 2.0
  • Torch-DWT

Installation

Install via requirements.txt

pip install -r requirements.txt
pip install -e .

Model Checkpoints

Coming Soon.


Evaluation

Run the compute_metrics.py script to compute the evaluation metrics. It computes LPIPS, PSNR, rFID, and SSIM using the test loader specified in the configuration file.

The script expects the parent directory structure to be of the form <parent>/checkpoints/last.ckpt and searches for <config.yaml> in the <config directory>, where <parent> = <timestamp>_<config>_<precision>. The search still works if the timestamp and precision are not included, since <config> is extracted by splitting on _ and removing the first and last elements of the list.
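
For clarity, the name-parsing step described above amounts to something like the following sketch (the run-directory name is made up, and the helper is not the exact code in compute_metrics.py):

```python
from pathlib import Path

def extract_config_name(run_parent: str) -> str:
    """Recover <config> from <timestamp>_<config>_<precision> (sketch)."""
    parts = Path(run_parent).name.split("_")
    return "_".join(parts[1:-1])  # drop the leading timestamp and trailing precision

# Hypothetical run directory name:
print(extract_config_name("2024-06-01T120000_litevae_b_nz12_bf16"))  # litevae_b_nz12
```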

python scripts/compute_metrics.py --config_base <config directory> -B <batch_size> <run parent directory>

Alternatively, you can specify the configuration and checkpoint directly.

python scripts/compute_metrics.py --config <config.yaml> -B <batch_size> <checkpoint.ckpt>

Training

The example code uses OmegaConf YAML files to specify and dynamically construct the models. By default, we use the WebDataset format to efficiently stream large collections of data from disk, behind an interface exposing train_loader(), val_loader(), and test_loader() constructors. Notably, this pattern can also wrap other dataset types, such as a typical ImageDataset or a default torchvision dataset.
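
As an illustration of that interface, a data module only needs to expose the three loader constructors. The sketch below wraps a torchvision ImageFolder instead of WebDataset shards; the class name and paths are placeholders, not olvae/data code.

```python
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class FolderDataModuleSketch:
    """Minimal stand-in exposing train/val/test loader constructors."""

    def __init__(self, root: str, image_size: int = 256, batch_size: int = 16):
        self.root, self.batch_size = root, batch_size
        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
        ])

    def _loader(self, split: str, shuffle: bool) -> DataLoader:
        ds = datasets.ImageFolder(f"{self.root}/{split}", transform=self.transform)
        return DataLoader(ds, batch_size=self.batch_size, shuffle=shuffle, num_workers=4)

    def train_loader(self) -> DataLoader:
        return self._loader("train", shuffle=True)

    def val_loader(self) -> DataLoader:
        return self._loader("val", shuffle=False)

    def test_loader(self) -> DataLoader:
        return self._loader("test", shuffle=False)
```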

Training uses PyTorch Lightning to handle multi-GPU communication, mixed-precision training, and gradient accumulation, which keeps the training loop relatively simple. The current codebase supports full precision (32), half precision (16), and bfloat16 (bf16).

python train.py --base <config.yaml> \
                --logdir <log directory> \
                --precision <training precision> \
                --gpus <gpu count or list> \
                --seed <random seed> \
                --name <config_name> 

This codebase also supports resuming from a stopped run. Note: we recommend changing the random seed when resuming.

python train.py --base <config.yaml> \
                --logdir <log directory> \
                --precision <training precision> \
                --gpus <gpu count or list> \
                --seed <random seed> \
                --resume <previous run directory>

For transitioning between Phase A and Phase B training, we recommend using --actual_resume, which resets the optimizer states. This trick is adapted from Textual-Inversion.

python train.py --base <config.yaml> \
                --logdir <log directory> \
                --precision <training precision> \
                --gpus <gpu count or list> \
                --seed <random seed> \
                --actual_resume <previous.ckpt> \
                --name <config_name> 
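
The key difference from --resume is that only the model weights are restored; optimizer and scheduler states start fresh for the new phase. Conceptually, the restore amounts to the sketch below (not the exact code in train.py):

```python
import torch
import torch.nn as nn

def load_weights_only(model: nn.Module, ckpt_path: str) -> nn.Module:
    """Restore model weights from a Lightning checkpoint, dropping optimizer state."""
    ckpt = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(ckpt["state_dict"], strict=False)
    return model  # training then proceeds with a freshly constructed optimizer
```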

TODO

  • Add Description of Improved Methods
  • Add Training Code
  • Add Evaluation Code
  • More Experiments

References

LiteVAE

@inproceedings{sadat2024litevae,
  title={Lite{VAE}: Lightweight and Efficient Variational Autoencoders for Latent Diffusion Models},
  author={Seyedmorteza Sadat and Jakob Buhmann and Derek Bradley and Otmar Hilliges and Romann M. Weber},
  booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
  year={2024},
  url={https://openreview.net/forum?id=mTAbl8kUzq}
}

SD-VAE

@article{Rombach2021HighResolutionIS,
  title={High-Resolution Image Synthesis with Latent Diffusion Models},
  author={Robin Rombach and A. Blattmann and Dominik Lorenz and Patrick Esser and Bj{\"o}rn Ommer},
  journal={2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2021},
  pages={10674-10685},
}

PatchGAN

@article{Isola2016ImagetoImageTW,
  title={Image-to-Image Translation with Conditional Adversarial Networks},
  author={Phillip Isola and Jun-Yan Zhu and Tinghui Zhou and Alexei A. Efros},
  journal={2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2016},
  pages={5967-5976},
}

GigaGAN

@article{Kang2023ScalingUG,
  title={Scaling up GANs for Text-to-Image Synthesis},
  author={Minguk Kang and Jun-Yan Zhu and Richard Zhang and Jaesik Park and Eli Shechtman and Sylvain Paris and Taesung Park},
  journal={2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2023},
  pages={10124-10134},
}

U-NetGAN

@article{Schnfeld2020AUB,
  title={A U-Net Based Discriminator for Generative Adversarial Networks},
  author={Edgar Sch{\"o}nfeld and Bernt Schiele and Anna Khoreva},
  journal={2020 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2020},
  pages={8204-8213},
}