Archimedes/Athena RC · valeo.ai · National Technical University of Athens · University of Crete · IACM-Forth
TL;DR: We propose EQ-VAE, a straightforward regularization objective that promotes equivariance in the latent space of pretrained autoencoders under scaling and rotation. This leads to a more structured latent distribution, which accelerates generative model training and improves performance.
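To make the objective concrete, here is a minimal PyTorch sketch of the equivariance regularization, assuming a diffusers-style `AutoencoderKL`. The transform sampling (90° rotations plus a uniform downscale) and the plain MSE are illustrative simplifications, not the paper's exact loss:

```python
import torch
import torch.nn.functional as F

def eq_vae_reg(vae, x, f=8):
    """Sketch of the EQ-VAE idea: decoding a transformed latent should
    reproduce the identically transformed image. `f` is the VAE's spatial
    downsampling factor (8 for SD-VAE)."""
    z = vae.encode(x).latent_dist.sample()

    # Sample a random transform tau: a 90-degree rotation and a downscale
    # (illustrative choices; the paper regularizes under scaling and rotation).
    k = int(torch.randint(0, 4, (1,)))
    s = float(torch.empty(1).uniform_(0.5, 1.0))

    z_r = torch.rot90(z, k, dims=(-2, -1))
    h, w = max(1, int(z_r.shape[-2] * s)), max(1, int(z_r.shape[-1] * s))
    z_t = F.interpolate(z_r, size=(h, w), mode="bilinear", align_corners=False)

    # Apply the same transform to the image, sized so shapes match after decoding.
    x_r = torch.rot90(x, k, dims=(-2, -1))
    x_t = F.interpolate(x_r, size=(h * f, w * f), mode="bilinear", align_corners=False)

    # Reconstruction loss between decode(tau(z)) and tau(x); the full method
    # combines this with the usual VAE reconstruction/perceptual/GAN losses.
    return F.mse_loss(vae.decode(z_t).sample, x_t)
```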
If you just want to use EQ-VAE to speed up 🚀 the training of your diffusion model, you can use our HuggingFace checkpoints 🤗. We provide two models: eq-vae and eq-vae-ema.
| Model | Base model | Dataset | Epochs | rFID ↓ | PSNR ↑ | LPIPS ↓ | SSIM ↑ |
|---|---|---|---|---|---|---|---|
| eq-vae | SD-VAE | OpenImages | 5 | 0.82 | 25.95 | 0.141 | 0.72 |
| eq-vae-ema | SD-VAE | ImageNet | 44 | 0.55 | 26.15 | 0.133 | 0.72 |
```python
from diffusers import AutoencoderKL

eqvae = AutoencoderKL.from_pretrained("zelaki/eq-vae")
```
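For example, a full encode/decode round trip looks like this (a minimal sketch; the shapes assume the standard SD-VAE downsampling factor of 8):

```python
import torch
from diffusers import AutoencoderKL

eqvae = AutoencoderKL.from_pretrained("zelaki/eq-vae").eval()

# Dummy batch in [-1, 1]; replace with a real preprocessed image.
x = torch.randn(1, 3, 256, 256)

with torch.no_grad():
    z = eqvae.encode(x).latent_dist.sample()   # (1, 4, 32, 32) latents
    x_rec = eqvae.decode(z).sample             # reconstruction, same shape as x
```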
If you are looking for the weights in the original LDM format, you can find them here: eq-vae-ldm, eq-vae-ema-ldm.
```bash
conda env create -f environment.yml
conda activate eqvae
```
We provide a training script to fine-tune SD-VAE with EQ-VAE regularization. For a detailed guide, see train_eqvae.
To evaluate the reconstruction quality of EQ-VAE, compute rFID, LPIPS, SSIM, and PSNR on a validation set (we use the ImageNet validation set in our paper) as follows:
```bash
torchrun --nproc_per_node=8 eval.py \
  --data_path /path/to/imagenet/validation \
  --output_path results \
  --ckpt_path /path/to/your/ckpt
```
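For a quick sanity check on a single image pair without the full script, PSNR, SSIM, and LPIPS can be computed with standard libraries. This is a sketch using skimage and the lpips package, not the repo's eval.py; note that rFID requires the whole validation set, which is what the script above handles:

```python
import lpips
import torch
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

def quick_metrics(img, rec):
    """img, rec: HxWx3 uint8 arrays holding the original and its reconstruction."""
    psnr = peak_signal_noise_ratio(img, rec, data_range=255)
    ssim = structural_similarity(img, rec, channel_axis=-1, data_range=255)
    # LPIPS expects NCHW float tensors in [-1, 1].
    to_t = lambda a: torch.from_numpy(a).permute(2, 0, 1)[None].float() / 127.5 - 1.0
    lp = lpips.LPIPS(net="vgg")(to_t(img), to_t(rec)).item()
    return psnr, ssim, lp
```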
To train a DiT model with EQ-VAE on ImageNet:
- First, extract the latent representations (a sketch of what this step does under the hood appears after this list):

```bash
torchrun --nnodes=1 --nproc_per_node=8 train_gen/extract_features.py \
  --data-path /path/to/imagenet/train \
  --features-path /path/to/latents \
  --vae-ckpt /path/to/eqvae.ckpt \
  --vae-config configs/eqvae_config.yaml
```
- Then train DiT on the precomputed latents:

```bash
accelerate launch --mixed_precision fp16 train_gen/train.py \
  --model DiT-XL/2 \
  --feature-path /path/to/latents \
  --results-dir results
```
- Evaluate generation as follows:

```bash
torchrun --nnodes=1 --nproc_per_node=8 sample_ddp.py \
  --model DiT-XL/2 \
  --num-fid-samples 50000 \
  --ckpt /path/to/dit.ckpt \
  --sample-dir samples \
  --vae-ckpt /path/to/eqvae.ckpt \
  --vae-config configs/eqvae_config.yaml \
  --ddpm True \
  --cfg-scale 1.0
```
This script generates a folder of 50k samples as well as a .npz file that can be used directly with ADM's TensorFlow evaluation suite to compute gFID.
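For reference, the latent-extraction step above amounts to encoding every training image once and caching the scaled latent as a .npy file, so the DiT trainer can read latents instead of re-encoding images each epoch. A minimal single-GPU sketch follows; the glob pattern and output paths are placeholders, and the provided script additionally shards the work across GPUs:

```python
import glob

import numpy as np
import torch
from diffusers import AutoencoderKL
from PIL import Image

vae = AutoencoderKL.from_pretrained("zelaki/eq-vae").eval().cuda()

paths = sorted(glob.glob("/path/to/imagenet/train/*/*.JPEG"))  # placeholder layout
for i, path in enumerate(paths):
    img = Image.open(path).convert("RGB").resize((256, 256))   # naive resize for brevity
    x = torch.from_numpy(np.array(img)).permute(2, 0, 1)[None].float().cuda()
    x = x / 127.5 - 1.0                                        # map to [-1, 1]
    with torch.no_grad():
        # Scale latents by the VAE's scaling factor, as in standard DiT preprocessing.
        z = vae.encode(x).latent_dist.sample() * vae.config.scaling_factor
    np.save(f"/path/to/latents/{i}.npy", z.squeeze(0).cpu().numpy())
```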
This code is mainly built upon LDM and fastDiT.
```bibtex
@inproceedings{eqvae,
  title     = {EQ-VAE: Equivariance Regularized Latent Space for Improved Generative Image Modeling},
  author    = {Kouzelis, Theodoros and Kakogeorgiou, Ioannis and Gidaris, Spyros and Komodakis, Nikos},
  booktitle = {arXiv},
  year      = {2025},
}
```