update: training code
TangentOne committed Feb 11, 2025
1 parent 9dd755e commit 06abbaa
Showing 68 changed files with 6,855 additions and 1,036 deletions.
107 changes: 87 additions & 20 deletions README.md
100755 → 100644
# GMem: A Modular Approach for Ultra-Efficient Generative Models

<div align="center">
Yi Tang :man_student:, <a href="https://sp12138.github.io/">Peng Sun :man_artist:</a>, Zhenglin Cheng :man_student:, <a href="https://tlin-taolin.github.io/">Tao Lin :skier:</a>

<a href="https://arxiv.org/abs/2412.08781">[arXiv] :page_facing_up:</a> | <a href="#bibliography">[BibTeX] :label:</a>
</div>

![Teaser image](./assets/docs/selected_pics.png)

**ImageNet generation (without CFG or any other guidance technique):**
- **$256\times 256$**: ~**$20\text{h}$ total training time** ($160$ epochs) → $100$ NFE → **FID $1.59$**
- **$512\times 512$**: ~**$50\text{h}$ total training time** ($400$ epochs) → $100$ NFE → **FID $1.89$**



### Abstract

Recent studies indicate that the denoising process in deep generative diffusion models implicitly learns and memorizes semantic information from the data distribution.
These findings suggest that capturing more complex data distributions requires larger neural networks, leading to a substantial increase in computational demands, which in turn become the primary bottleneck in both training and inference of diffusion models.
To this end, we introduce GMem: A Modular Approach for Ultra-Efficient Generative Models.
Our approach, GMem, decouples memory capacity from the model and implements it as a separate, immutable memory set that preserves the essential semantic information in the data.
This design reduces the network's reliance on memorizing complex data distributions, thereby improving training efficiency, sampling efficiency, and generation diversity.
The results are significant: on ImageNet at $256 \times 256$ resolution, GMem achieves a $50\times$ training speedup over SiT, reaching **FID $=7.66$** in fewer than $28$ epochs (**$\sim 4$ hours** of training), whereas SiT requires $1400$ epochs.
Without classifier-free guidance, GMem attains state-of-the-art (SoTA) **FID $=1.59$** in $160$ epochs with **only $\sim 20$ hours** of training, outperforming LightningDiT, which requires $800$ epochs and $\sim 95$ hours to reach FID $=2.17$.

---


---

### Evaluation

To set up evaluation and sample images from the pretrained GMem-XL model, follow these steps:

#### 1. **Download the Pretrained Weights:**

- **Pretrained model**: Download the pretrained weights for the network and the corresponding memory bank from the Hugging Face link below (a one-line download helper is sketched after this list):

| Backbone | Training Epochs | Dataset | Bank Size | FID | Download |
|----------------|----------------|---------------------------|-----------|-----|----------|
| LightningDiT-XL| 160 | ImageNet $256\times 256$ | 1.28M |1.53 | [Huggingface](https://huggingface.co/Tangentone/GMem) |

- **VA-VAE Tokenizer**: You also need the VA-VAE tokenizer. Download the tokenizer from the official repository at [VA-VAE on GitHub](https://github.com/hustvl/LightningDiT/tree/main?tab=readme-ov-file#inference-with-pre-trained-models).
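
If you prefer to fetch the weights from the command line, the short snippet below uses `huggingface_hub` to download the repository linked in the table above. The target directory `checkpoints/GMem-XL` is just an example, and the VA-VAE tokenizer still has to be downloaded separately as described above.

```python
# Optional download helper (requires `pip install huggingface_hub`).
# The local_dir is an example path; adjust it to your setup.
from huggingface_hub import snapshot_download

snapshot_download(repo_id="Tangentone/GMem", local_dir="checkpoints/GMem-XL")
```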

#### 2. **Modify Config Files:**

- Once you’ve downloaded the necessary pretrained models and tokenizers, modify the following configuration files with the correct paths (an optional path sanity check is sketched after this list):

- **For the GMem model (`configs/gmem_sde_xl.yaml`)**:
    - Update the `ckpt_path` with the location where you saved the pretrained weights.
    - Update the `GMem:bank_path` with the location of the downloaded memory bank.
    - Also, specify the path to the reference file (`VIRTUAL_imagenet256_labeled.npz`, see [ADM](https://github.com/openai/guided-diffusion) for details) for FID calculation in the `data:fid_reference_file` argument.

- **For the VA-VAE Tokenizer (`tokenizer/configs/vavae_f16d32.yaml`)**:
    - Specify the path to the tokenizer in the `ckpt_path` section of the configuration.
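
Before launching evaluation, it can help to confirm that every configured path actually exists. The snippet below is an optional sketch, not part of the released scripts; it assumes `pyyaml` is installed, that it is run from the repository root, and that keys nest exactly as written above (e.g. `data:fid_reference_file` maps to `cfg["data"]["fid_reference_file"]`).

```python
# check_paths.py -- optional helper, not part of the GMem repository.
# Verifies the paths referenced in the evaluation configs before running
# scripts/evaluation_gmem_xl.sh.
from pathlib import Path

import yaml  # pip install pyyaml


def check(config_file: str, dotted_keys: list[str]) -> None:
    cfg = yaml.safe_load(Path(config_file).read_text())
    for dotted in dotted_keys:
        node = cfg
        for part in dotted.split(":"):
            node = node[part]
        status = "ok" if Path(node).exists() else "MISSING"
        print(f"{config_file} -> {dotted}: {node} [{status}]")


if __name__ == "__main__":
    # Keys taken from the config-editing steps above.
    check("configs/gmem_sde_xl.yaml",
          ["ckpt_path", "data:fid_reference_file", "GMem:bank_path"])
    check("tokenizer/configs/vavae_f16d32.yaml", ["ckpt_path"])
```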

#### 3. **Run Evaluation Scripts:**

- Use the provided script to sample images and automatically calculate the FID score:
```bash
bash scripts/evaluation_gmem_xl.sh
```

---

### Memory Manipulation

#### **External Knowledge Manipulation**

To incorporate external knowledge from previously unseen images, follow the steps below (a conceptual sketch of extending the memory bank comes after them):

1. Store the new images in the `assets/novel_images` directory.
2. Execute the script to generate new images:
```bash
bash scripts/external_knowledge_generation.sh
```
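
For intuition, the sketch below shows one way the memory bank could be extended with snippets from new images. It is illustrative only: the bank is assumed to be a single `[K, 768]` tensor saved with `torch.save` (768 is suggested by the bank filename `...dim768...`), and `encode_snippet` is a hypothetical stand-in for the project's snippet encoder. The released pipeline is `scripts/external_knowledge_generation.sh`.

```python
# Illustrative sketch only -- the released pipeline is
# scripts/external_knowledge_generation.sh. Assumes the memory bank is a
# single [K, 768] float tensor saved with torch.save (an assumption).
from pathlib import Path

import torch


def encode_snippet(image_path: Path) -> torch.Tensor:
    """Hypothetical stand-in for the project's snippet encoder (768-d output)."""
    return torch.randn(768)  # placeholder; replace with a real feature extractor


def append_snippets(bank_path: str, image_dir: str, out_path: str) -> None:
    bank = torch.load(bank_path, map_location="cpu")    # assumed shape [K, 768]
    new = torch.stack([encode_snippet(p) for p in sorted(Path(image_dir).glob("*"))])
    torch.save(torch.cat([bank, new], dim=0), out_path)  # extended bank


# append_snippets("data/.../in1k256_Kfull_dim768_init_seed0.pth.pth",
#                 "assets/novel_images", "data/.../bank_extended.pth")
```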

#### **Internal Knowledge Manipulation**

To generate new memory snippets by interpolating between two existing images, follow these steps (a minimal interpolation sketch follows):

1. Place the source images in the `assets/interpolation/lerp/a` and `assets/interpolation/lerp/b` directories, ensuring both images have identical filenames.
2. Run the script to create interpolated images:
```bash
bash scripts/internal_knowledge_generation.sh
```
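
The directory name `lerp` points to linear interpolation between the memory snippets of the paired images. As a minimal conceptual sketch, treating the snippets as plain feature vectors (their extraction is handled by the script above):

```python
# Minimal sketch of linear interpolation (lerp) between two memory snippets.
import torch


def lerp_snippets(snippet_a: torch.Tensor, snippet_b: torch.Tensor,
                  num_steps: int = 5) -> torch.Tensor:
    """Blend snippet_a -> snippet_b at evenly spaced weights; returns [num_steps, D]."""
    alphas = torch.linspace(0.0, 1.0, num_steps).view(-1, 1)
    return (1.0 - alphas) * snippet_a + alphas * snippet_b


# Stand-in 768-d vectors (matching the bank dimensionality in this repo):
a, b = torch.randn(768), torch.randn(768)
blends = lerp_snippets(a, b)  # shape [5, 768]
```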

---

### Preparing Data

1. **Set up VA-VAE**: Follow the instructions in the **Evaluation** and [VA-VAE tutorial](https://github.com/hustvl/LightningDiT/blob/main/docs/tutorial.md) to properly set up and configure the VA-VAE model.

2. **Extract Latents**: Once VA-VAE is set up, you can run the following script to extract the latents for all ImageNet images:
```bash
bash scripts/preprocessing.sh
```
This script processes all ImageNet images and stores their corresponding latents; a conceptual sketch of the per-image encoding step is given after this list.

3. **Modify the Configuration**: After extracting the latents, you need to update the `data:data_path` in the `configs/gmem_sde_xl.yaml` file. Set this path to the location where the extracted latents are stored. This ensures that GMem-XL can access the processed latents during training.
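
For orientation, the sketch below shows what per-image latent extraction roughly amounts to with an f16, 32-channel tokenizer (values taken from the configs in this commit: `downsample_ratio: 16`, `in_chans: 32`, safetensors storage). The `vavae.encode` call and the on-disk layout are assumptions; `scripts/preprocessing.sh` is the authoritative implementation.

```python
# Conceptual sketch of latent extraction; scripts/preprocessing.sh is the real
# implementation. The vavae.encode API and file layout are assumptions.
import torch
from safetensors.torch import save_file


@torch.no_grad()
def extract_latent(vavae, image: torch.Tensor) -> torch.Tensor:
    """image: [3, 256, 256] in [-1, 1] -> latent: [32, 16, 16] (f16 downsampling)."""
    return vavae.encode(image.unsqueeze(0)).squeeze(0)


def save_latents(latents: dict[str, torch.Tensor], out_path: str) -> None:
    # The configs describe the data as "imagenet safetensor data".
    save_file(latents, out_path)
```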

---

### Train GMem

With the data prepared and the latents extracted, you can train the GMem-XL model by running the following script:

```bash
bash scripts/train_gmem_xl.sh
```
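
The configs in this commit describe the training objective: a linear (rectified-flow) interpolation path, velocity prediction, and log-normal timestep sampling (`use_lognorm: true`), plus cosine and REPA-style projection losses not shown here. The sketch below illustrates only the core velocity loss for a generic network; how the memory snippet enters the model is reduced to a conditioning argument and is an assumption.

```python
# Sketch of one rectified-flow training step with velocity prediction and
# logit-normal ("lognorm") timestep sampling, mirroring the config flags
# (path_type: Linear, prediction: velocity, use_lognorm: true).
# model(x_t, t, snippet) is a stand-in for GMem's actual conditioning interface.
import torch
import torch.nn.functional as F


def training_step(model, x0: torch.Tensor, snippet: torch.Tensor) -> torch.Tensor:
    """x0: clean latents [B, 32, 16, 16]; snippet: memory snippets [B, 768]."""
    b = x0.shape[0]
    t = torch.sigmoid(torch.randn(b, device=x0.device))  # lognorm sampling in (0, 1)
    t_ = t.view(b, 1, 1, 1)

    noise = torch.randn_like(x0)
    x_t = (1.0 - t_) * x0 + t_ * noise                   # linear interpolation path
    v_target = noise - x0                                 # velocity of that path

    v_pred = model(x_t, t, snippet)                       # assumed interface
    return F.mse_loss(v_pred, v_target)
```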


---

If you find this repository helpful for your project, please consider citing our work.

### Acknowledgement

This code is mainly built upon [VA-VAE](https://github.com/hustvl/LightningDiT), [SiT](https://github.com/willisma/SiT), [edm2](https://github.com/NVlabs/edm2), and [REPA](https://github.com/sihyun-yu/REPA) repositories.


Binary file added assets/docs/selected_pics.png
Binary file added assets/interplotation/lerp/a/hated_dog_a.jpg
Binary file added assets/interplotation/lerp/a/horse_maple_a.jpg
Binary file added assets/interplotation/lerp/a/picasso_bottle_a.jpg
Binary file added assets/interplotation/lerp/a/toy_cat_a.webp
Binary file added assets/interplotation/lerp/b/hated_dog_b.jpg
Binary file added assets/interplotation/lerp/b/horse_maple_b.webp
Binary file added assets/interplotation/lerp/b/picasso_bottle_b.jpg
Binary file added assets/interplotation/lerp/b/toy_cat_b.jpg
Binary file added assets/novel_images/cat_neon.jpg
Binary file added assets/novel_images/dog_lowpoly.jpg
Binary file added assets/novel_images/dog_silhouette.jpg
Binary file added assets/novel_images/mug_graffiti.jpg
Binary file added assets/novel_images/shoe_graffiti.jpg
Binary file added assets/novel_images/tower_lego.jpg
Binary file added assets/output/interploation.png
Binary file added assets/output/new_snippet_generation.png
90 changes: 90 additions & 0 deletions configs/gmem_ode_b.yaml
# checkpoint path, only enabled during inference
ckpt_path: 'output/gmem_b_vavae_f16d32/checkpoints/0080000.pt'

# imagenet safetensor data, see datasets/img_latent_dataset.py for details
data:
  data_path: 'data/preprocessed/in1k256/vavae_f16d32/imagenet_train_256'
  # fid reference file, see ADM <https://github.com/openai/guided-diffusion> for details
  fid_reference_file: 'data/fids/VIRTUAL_imagenet256_labeled.npz'
  image_size: 256
  num_classes: 1000
  num_workers: 8
  # latent normalization, originated from our previous research FasterDiT <https://arxiv.org/abs/2410.10356>
  # The standard deviation of latents directly affects the SNR distribution during training
  # Channel-wise normalization provides stability but may not be optimal for all cases.
  latent_norm: true
  latent_multiplier: 1.0

# our pre-trained vision foundation model aligned VAE. see VA-VAE <to be released> for details.
vae:
  model_name: 'vavae_f16d32'
  downsample_ratio: 16

# We explored several optimization techniques for transformers:
model:
  model_type: LightningDiT-B/1
  use_qknorm: false
  use_swiglu: true
  use_rope: true
  use_rmsnorm: true
  wo_shift: false
  in_chans: 32

# training parameters
train:
  max_steps: 80000
  # We use large batch training (1024) with adjusted learning rate and beta2 accordingly
  # this is inspired by AuraFlow and muP.
  global_batch_size: 1024
  global_seed: 0
  output_dir: 'output'
  exp_name: 'gmem_b_vavae_f16d32'
  ckpt: null
  log_every: 100
  ckpt_every: 20000

optimizer:
  lr: 0.0002
  beta2: 0.95
  max_grad_norm: 1.0

# we use rectified flow for fast training.
transport:
  # We inherit these settings from SiT, no parameters are changed
  path_type: Linear
  prediction: velocity
  loss_weight: null
  sample_eps: null
  train_eps: null

  # Inspired by SD3 and our previous work FasterDiT
  # In small-scale experiments, we enable lognorm
  # In large-scale experiments, we disable lognorm at the mid of training
  use_lognorm: true
  # cosine loss is enabled at all times
  use_cosine_loss: true

  # REPA settings
  proj_loss_weight: 0.5

sample:
  mode: ODE
  # here we mainly adopt 2 settings: 1. dopri5, 2. euler
  # dopri5 has adaptive step size, which is faster but has a slight performance drop
  sampling_method: euler
  atol: 0.000001
  rtol: 0.001
  reverse: false
  likelihood: false
  num_sampling_steps: 50
  cfg_scale: 1
  per_proc_batch_size: 32
  fid_num: 50000

  # cfg interval, it is inspired by <https://arxiv.org/abs/2404.07724>
  cfg_interval_start: 0.89
  # timestep shift, it is inspired by FLUX. please refer to transport/integrators.py ode function for details.
  timestep_shift: 0.3

GMem:
  bank_type: 'full'
  bank_path: 'data/preprocessed/in1k256/banks/in1k256_Kfull_dim768_init_seed0.pth.pth'
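
The `timestep_shift` entry above is described as FLUX-inspired and implemented in the ODE function of `transport/integrators.py`. Purely for intuition, one common formulation of such a shift remaps a uniform time grid before Euler integration; the exact formula used in this repository may differ.

```python
# Illustration of a FLUX-style timestep shift on a uniform Euler grid,
# using the commonly cited remapping t' = s*t / (1 + (s - 1)*t).
# This is a sketch; see transport/integrators.py for the actual implementation.
import torch


def shifted_timesteps(num_steps: int, shift: float = 0.3) -> torch.Tensor:
    t = torch.linspace(0.0, 1.0, num_steps + 1)
    return shift * t / (1.0 + (shift - 1.0) * t)


steps = shifted_timesteps(50, shift=0.3)  # matches num_sampling_steps / timestep_shift above
```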
89 changes: 89 additions & 0 deletions configs/gmem_ode_xl.yaml
# checkpoint path, only enabled during inference
ckpt_path: 'output/gmem_xl_vavae_f16d32/checkpoints/0200000.pt'

# imagenet safetensor data, see datasets/img_latent_dataset.py for details
data:
  data_path: 'data/preprocessed/in1k256/vavae_f16d32/imagenet_train_256'
  # fid reference file, see ADM <https://github.com/openai/guided-diffusion> for details
  fid_reference_file: 'data/fids/VIRTUAL_imagenet256_labeled.npz'
  image_size: 256
  num_classes: 1000
  num_workers: 8
  # latent normalization, originated from our previous research FasterDiT <https://arxiv.org/abs/2410.10356>
  # The standard deviation of latents directly affects the SNR distribution during training
  # Channel-wise normalization provides stability but may not be optimal for all cases.
  latent_norm: true
  latent_multiplier: 1.0

# our pre-trained vision foundation model aligned VAE. see VA-VAE <to be released> for details.
vae:
  model_name: 'vavae_f16d32'
  downsample_ratio: 16

# We explored several optimization techniques for transformers:
model:
  model_type: LightningDiT-XL/1
  use_qknorm: false
  use_swiglu: true
  use_rope: true
  use_rmsnorm: true
  wo_shift: false
  in_chans: 32

# training parameters
train:
  max_steps: 200000
  # We use large batch training (1024) with adjusted learning rate and beta2 accordingly
  # this is inspired by AuraFlow and muP.
  global_batch_size: 1024
  global_seed: 0
  output_dir: 'output'
  exp_name: 'gmem_xl_vavae_f16d32'
  ckpt: null
  log_every: 100
  ckpt_every: 2500

optimizer:
  lr: 0.0002
  beta2: 0.95
  max_grad_norm: 1.0

# we use rectified flow for fast training.
transport:
  # We inherit these settings from SiT, no parameters are changed
  path_type: Linear
  prediction: velocity
  loss_weight: null
  sample_eps: null
  train_eps: null

  # Inspired by SD3 and our previous work FasterDiT
  # In small-scale experiments, we enable lognorm
  # In large-scale experiments, we disable lognorm at the mid of training
  use_lognorm: true
  # cosine loss is enabled at all times
  use_cosine_loss: true

  # REPA settings
  proj_loss_weight: 0.5

sample:
  mode: ODE
  # here we mainly adopt 2 settings: 1. dopri5, 2. euler
  # dopri5 has adaptive step size, which is faster but has a slight performance drop
  sampling_method: euler
  atol: 0.000001
  rtol: 0.001
  reverse: false
  likelihood: false
  cfg_scale: 1
  num_sampling_steps: 50
  per_proc_batch_size: 32
  fid_num: 50000

  cfg_interval_start: 0.89
  timestep_shift: 0.3

GMem:
  bank_type: 'full'
  bank_path: 'data/preprocessed/in1k256/banks/in1k256_Kfull_dim768_init_seed0.pth.pth'