EQ-VAE: Equivariance Regularized Latent Space for Improved Generative Image Modeling

Theodoros Kouzelis (1,3) · Ioannis Kakogeorgiou (1) · Spyros Gidaris (2) · Nikos Komodakis (1,4,5)
(1) Archimedes/Athena RC · (2) valeo.ai · (3) National Technical University of Athens · (4) University of Crete · (5) IACM-Forth

[Teaser figure: teaser.png]


TL;DR: We propose EQ-VAE, a straightforward regularization objective that promotes equivariance in the latent space of pretrained autoencoders under scaling and rotation. This leads to a more structured latent distribution, which accelerates generative model training and improves performance.

0. Quick Start with Hugging Face

If you just want to use EQ-VAE to speed up 🚀 the training of your diffusion model, you can use our Hugging Face checkpoints 🤗. We provide two models: eq-vae and eq-vae-ema.

Model       Base model  Dataset     Epochs  rFID  PSNR   LPIPS  SSIM
eq-vae      SD-VAE      OpenImages  5       0.82  25.95  0.141  0.72
eq-vae-ema  SD-VAE      ImageNet    44      0.55  26.15  0.133  0.72

from diffusers import AutoencoderKL
eqvae = AutoencoderKL.from_pretrained("zelaki/eq-vae")
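
As a rough usage sketch, here is one way to encode an image into scaled latents and decode it back with diffusers. The preprocessing choices and the file name example.jpg are placeholders, not part of the repo:

import torch
from PIL import Image
from torchvision import transforms
from diffusers import AutoencoderKL

eqvae = AutoencoderKL.from_pretrained("zelaki/eq-vae").eval()

# Placeholder preprocessing: resize/crop to 256 and normalize to [-1, 1].
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
image = preprocess(Image.open("example.jpg").convert("RGB")).unsqueeze(0)

with torch.no_grad():
    # scaling_factor keeps latent variance close to 1, as diffusion models expect.
    latents = eqvae.encode(image).latent_dist.sample() * eqvae.config.scaling_factor
    # Decode back as a quick sanity check.
    recon = eqvae.decode(latents / eqvae.config.scaling_factor).sample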

If you are looking for the weights in the original LDM format, you can find them here: eq-vae-ldm, eq-vae-ema-ldm

1. Environment setup

conda env create -f environment.yml
conda activate eqvae

2. Train EQ-VAE

We provide a training script to finetune SD-VAE with EQ-VAE regularization. For a detailed guide, see train_eqvae.
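
The core idea, as a schematic sketch: decoding a scaled or rotated latent should match the equally transformed image, i.e. D(t(E(x))) ≈ t(x). This is not the repo's exact training code (train_eqvae uses the full LDM objective, including perceptual and adversarial terms), and the specific transforms below are illustrative:

import random
import torch
import torch.nn.functional as F

def equivariance_loss(encode, decode, x):
    """Schematic EQ-VAE regularizer: apply one random spatial transform
    identically to the input image and to its latent, then compare.
    Assumes x and the latent are (N, C, H, W) tensors."""
    z = encode(x)
    # Sample one transform: a rotation by a multiple of 90 degrees and a scale.
    k = random.choice([0, 1, 2, 3])
    s = random.choice([0.5, 0.75, 1.0])

    def tfm(t):
        t = torch.rot90(t, k, dims=(2, 3))
        if s != 1.0:
            t = F.interpolate(t, scale_factor=s, mode="bilinear", align_corners=False)
        return t

    # decode(tfm(z)) and tfm(x) have matching spatial sizes since the decoder
    # upsamples by a fixed factor.
    recon = F.mse_loss(decode(z), x)            # standard reconstruction term
    equiv = F.mse_loss(decode(tfm(z)), tfm(x))  # equivariance term
    return recon + equiv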

3. Evaluate Reconstruction

To evaluate the reconstruction quality of EQ-VAE, compute rFID, LPIPS, SSIM, and PSNR on a validation set (we use the ImageNet validation set in our paper) with the following:

torchrun --nproc_per_node=8 eval.py \
  --data_path /path/to/imagenet/validation \
  --output_path results \
  --ckpt_path /path/to/your/ckpt
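
For a quick standalone sanity check outside eval.py, the per-image metrics could be computed roughly as below. Using torchmetrics is an assumption on our part; eval.py remains the authoritative implementation, and rFID additionally requires the full set of reconstructions:

import torch
from torchmetrics.image import (
    PeakSignalNoiseRatio,
    StructuralSimilarityIndexMeasure,
    LearnedPerceptualImagePatchSimilarity,
)

psnr = PeakSignalNoiseRatio(data_range=1.0)
ssim = StructuralSimilarityIndexMeasure(data_range=1.0)
lpips = LearnedPerceptualImagePatchSimilarity(net_type="vgg")

def reconstruction_metrics(original: torch.Tensor, recon: torch.Tensor) -> dict:
    """original, recon: (N, 3, H, W) tensors with values in [0, 1]."""
    return {
        "psnr": psnr(recon, original).item(),
        "ssim": ssim(recon, original).item(),
        # LPIPS expects inputs normalized to [-1, 1].
        "lpips": lpips(recon * 2 - 1, original * 2 - 1).item(),
    }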

4. Train DiT with EQ-VAE

To train a DiT model with EQ-VAE on ImageNet:

  • First extract the latent representations:
torchrun --nnodes=1 --nproc_per_node=8  train_gen/extract_features.py \
    --data-path /path/to/imagenet/train \
    --features-path /path/to/latents \
    --vae-ckpt /path/to/eqvae.ckpt \
    --vae-config configs/eqvae_config.yaml 
  • Then train DiT on the precomputed latents (a sketch of loading them in a custom loop appears at the end of this section):
accelerate launch --mixed_precision fp16 train_gen/train.py \
    --model DiT-XL/2 \
    --feature-path /path/to/latents \
    --results-dir results
  • Evaluate generation as follows:
torchrun --nnodes=1 --nproc_per_node=8 sample_ddp.py \
    --model DiT-XL/2 \
    --num-fid-samples 50000 \
    --ckpt /path/to/dit.ckpt \
    --sample-dir samples \
    --vae-ckpt /path/to/eqvae.ckpt \
    --vae-config configs/eqvae_config.yaml \
    --ddpm True \
    --cfg-scale 1.0

This script generates a folder of 50K samples as well as a .npz file that can be used directly with ADM's TensorFlow evaluation suite to compute gFID.
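
If you prefer to consume the precomputed latents in your own training loop instead of train_gen/train.py, a loader could look roughly like the sketch below. The one-.npy-file-per-sample layout in parallel feature/label directories is an assumption; verify it against what train_gen/extract_features.py actually writes:

import os
import numpy as np
import torch
from torch.utils.data import Dataset

class LatentDataset(Dataset):
    """Assumed layout: parallel directories of per-sample .npy files,
    one for latent features and one for class labels."""
    def __init__(self, features_dir: str, labels_dir: str):
        self.features_dir = features_dir
        self.labels_dir = labels_dir
        self.files = sorted(os.listdir(features_dir))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        latent = np.load(os.path.join(self.features_dir, self.files[idx]))
        label = np.load(os.path.join(self.labels_dir, self.files[idx]))
        return torch.from_numpy(latent), torch.from_numpy(label)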

Acknowledgement

This code is mainly built upon LDM and fastDiT.

Citation

@inproceedings{eqvae,
  title     = {EQ-VAE: Equivariance Regularized Latent Space for Improved Generative Image Modeling},
  author    = {Kouzelis, Theodoros and Kakogeorgiou, Ioannis and Gidaris, Spyros and Komodakis, Nikos},
  booktitle = {arXiv},
  year      = {2025},
}

