Commit 06abbaa (parent 9dd755e)
Showing 68 changed files with 6,855 additions and 1,036 deletions.
New file: training/inference config for gmem_b_vavae_f16d32 (LightningDiT-B)
@@ -0,0 +1,90 @@
# checkpoint path, only used during inference
ckpt_path: 'output/gmem_b_vavae_f16d32/checkpoints/0080000.pt'

# ImageNet safetensor data; see datasets/img_latent_dataset.py for details
data:
  data_path: 'data/preprocessed/in1k256/vavae_f16d32/imagenet_train_256'
  # FID reference file; see ADM <https://github.com/openai/guided-diffusion> for details
  fid_reference_file: 'data/fids/VIRTUAL_imagenet256_labeled.npz'
  image_size: 256
  num_classes: 1000
  num_workers: 8
  # Latent normalization, originating from our previous work FasterDiT <https://arxiv.org/abs/2410.10356>.
  # The standard deviation of the latents directly affects the SNR distribution during training.
  # Channel-wise normalization provides stability but may not be optimal for all cases.
  latent_norm: true
  latent_multiplier: 1.0

# our pre-trained VAE, aligned with a vision foundation model; see VA-VAE <to be released> for details
vae:
  model_name: 'vavae_f16d32'
  downsample_ratio: 16

# We explored several optimization techniques for transformers:
model:
  model_type: LightningDiT-B/1
  use_qknorm: false
  use_swiglu: true
  use_rope: true
  use_rmsnorm: true
  wo_shift: false
  in_chans: 32

# training parameters
train:
  max_steps: 80000
  # We use large-batch training (1024) with the learning rate and beta2
  # adjusted accordingly; this is inspired by AuraFlow and muP.
  global_batch_size: 1024
  global_seed: 0
  output_dir: 'output'
  exp_name: 'gmem_b_vavae_f16d32'
  ckpt: null
  log_every: 100
  ckpt_every: 20000

optimizer:
  lr: 0.0002
  beta2: 0.95
  max_grad_norm: 1.0

# we use rectified flow for fast training
transport:
  # We inherit these settings from SiT; no parameters are changed.
  path_type: Linear
  prediction: velocity
  loss_weight: null
  sample_eps: null
  train_eps: null

  # Inspired by SD3 and our previous work FasterDiT:
  # in small-scale experiments we enable lognorm;
  # in large-scale experiments we disable lognorm midway through training.
  use_lognorm: true
  # cosine loss is enabled at all times
  use_cosine_loss: true

  # REPA settings
  proj_loss_weight: 0.5

sample:
  mode: ODE
  # Here we mainly adopt two settings: dopri5 and euler.
  # dopri5 uses an adaptive step size; it is faster but comes with a slight performance drop.
  sampling_method: euler
  atol: 0.000001
  rtol: 0.001
  reverse: false
  likelihood: false
  num_sampling_steps: 50
  cfg_scale: 1
  per_proc_batch_size: 32
  fid_num: 50000

  # CFG interval, inspired by <https://arxiv.org/abs/2404.07724>
  cfg_interval_start: 0.89
  # Timestep shift, inspired by FLUX; see the ode function in transport/integrators.py for details.
  timestep_shift: 0.3

GMem:
  bank_type: 'full'
  bank_path: 'data/preprocessed/in1k256/banks/in1k256_Kfull_dim768_init_seed0.pth.pth'
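A note on `latent_norm` and `latent_multiplier` above: because the latents' standard deviation shapes the training-time SNR distribution, channel-wise normalization rescales each of the 32 latent channels toward unit variance. Below is a minimal sketch of the idea, assuming per-channel statistics precomputed over the dataset; the function and argument names are illustrative, not the repository's actual API.

```python
import torch

def normalize_latents(latents: torch.Tensor,
                      channel_mean: torch.Tensor,
                      channel_std: torch.Tensor,
                      multiplier: float = 1.0) -> torch.Tensor:
    """Channel-wise latent normalization (illustrative sketch).

    latents:      (B, C, H, W) VAE latents, C = 32 for vavae_f16d32
    channel_mean: (C,) per-channel mean over the training set
    channel_std:  (C,) per-channel std over the training set
    """
    mean = channel_mean.view(1, -1, 1, 1)
    std = channel_std.view(1, -1, 1, 1)
    # Unit per-channel std stabilizes the SNR distribution during
    # training; `multiplier` then rescales all channels globally.
    return (latents - mean) / std * multiplier
```

With `latent_multiplier: 1.0` the normalized latents are used as-is; a different multiplier would globally rescale them, shifting the effective SNR at every noise level.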
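Both configs also enable `use_lognorm`, which, per the comments, follows SD3 and FasterDiT: training timesteps are drawn from a logit-normal distribution rather than a uniform one, concentrating samples at intermediate noise levels. A minimal sketch of the idea; the function name and where it hooks into the training loop are assumptions, not the repository's API.

```python
import torch

def sample_train_timesteps(batch_size: int, use_lognorm: bool = True,
                           device: str = "cpu") -> torch.Tensor:
    """Draw rectified-flow training timesteps in (0, 1)."""
    if use_lognorm:
        # Logit-normal ("lognorm") sampling as in SD3: the sigmoid of a
        # standard normal concentrates t around 0.5, the regime where
        # the velocity target is hardest to fit.
        return torch.sigmoid(torch.randn(batch_size, device=device))
    # Uniform fallback, e.g. after lognorm is switched off midway
    # through large-scale training, as the config comments describe.
    return torch.rand(batch_size, device=device)
```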
New file: training/inference config for gmem_xl_vavae_f16d32 (LightningDiT-XL)
@@ -0,0 +1,89 @@
# checkpoint path, only used during inference
ckpt_path: 'output/gmem_xl_vavae_f16d32/checkpoints/0200000.pt'

# ImageNet safetensor data; see datasets/img_latent_dataset.py for details
data:
  data_path: 'data/preprocessed/in1k256/vavae_f16d32/imagenet_train_256'
  # FID reference file; see ADM <https://github.com/openai/guided-diffusion> for details
  fid_reference_file: 'data/fids/VIRTUAL_imagenet256_labeled.npz'
  image_size: 256
  num_classes: 1000
  num_workers: 8
  # Latent normalization, originating from our previous work FasterDiT <https://arxiv.org/abs/2410.10356>.
  # The standard deviation of the latents directly affects the SNR distribution during training.
  # Channel-wise normalization provides stability but may not be optimal for all cases.
  latent_norm: true
  latent_multiplier: 1.0

# our pre-trained VAE, aligned with a vision foundation model; see VA-VAE <to be released> for details
vae:
  model_name: 'vavae_f16d32'
  downsample_ratio: 16

# We explored several optimization techniques for transformers:
model:
  model_type: LightningDiT-XL/1
  use_qknorm: false
  use_swiglu: true
  use_rope: true
  use_rmsnorm: true
  wo_shift: false
  in_chans: 32

# training parameters
train:
  max_steps: 200000
  # We use large-batch training (1024) with the learning rate and beta2
  # adjusted accordingly; this is inspired by AuraFlow and muP.
  global_batch_size: 1024
  global_seed: 0
  output_dir: 'output'
  exp_name: 'gmem_xl_vavae_f16d32'
  ckpt: null
  log_every: 100
  ckpt_every: 2500

optimizer:
  lr: 0.0002
  beta2: 0.95
  max_grad_norm: 1.0

# we use rectified flow for fast training
transport:
  # We inherit these settings from SiT; no parameters are changed.
  path_type: Linear
  prediction: velocity
  loss_weight: null
  sample_eps: null
  train_eps: null

  # Inspired by SD3 and our previous work FasterDiT:
  # in small-scale experiments we enable lognorm;
  # in large-scale experiments we disable lognorm midway through training.
  use_lognorm: true
  # cosine loss is enabled at all times
  use_cosine_loss: true

  # REPA settings
  proj_loss_weight: 0.5

sample:
  mode: ODE
  # Here we mainly adopt two settings: dopri5 and euler.
  # dopri5 uses an adaptive step size; it is faster but comes with a slight performance drop.
  sampling_method: euler
  atol: 0.000001
  rtol: 0.001
  reverse: false
  likelihood: false
  cfg_scale: 1
  num_sampling_steps: 50
  per_proc_batch_size: 32
  fid_num: 50000

  cfg_interval_start: 0.89
  timestep_shift: 0.3

GMem:
  bank_type: 'full'
  bank_path: 'data/preprocessed/in1k256/banks/in1k256_Kfull_dim768_init_seed0.pth.pth'
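Finally, the `cfg_interval_start` and `timestep_shift` values deserve a brief illustration. The interval trick from <https://arxiv.org/abs/2404.07724> applies classifier-free guidance over only part of the sampling trajectory, and the FLUX-style shift warps the timestep grid so the sampler spends more of its steps where they matter. A minimal sketch of the commonly used shift mapping follows, under the assumption that the ode function in transport/integrators.py implements something equivalent; the names and the time convention below are assumptions, not the repository's API.

```python
import torch

def shift_timesteps(t: torch.Tensor, shift: float = 0.3) -> torch.Tensor:
    # SD3/FLUX-style timestep shift: a monotone warp of [0, 1] onto itself
    # (f(0) = 0, f(1) = 1). With shift < 1 the warped values are biased
    # toward t = 0; which end of the trajectory that favors depends on the
    # integrator's time convention.
    return shift * t / (1.0 + (shift - 1.0) * t)

def cfg_active(t: float, cfg_interval_start: float = 0.89) -> bool:
    # Apply classifier-free guidance only on part of the trajectory,
    # following the interval trick of <https://arxiv.org/abs/2404.07724>.
    # Whether the interval is measured from the noise end or the data end
    # is an assumption here; check the repository's sampler for the
    # actual convention.
    return t <= cfg_interval_start
```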