PyTorch Implementation of Denoising Diffusion Probabilistic Models [paper] [official repo]


  • Original DDPM [1] training & sampling
  • DDIM [2] sampler
  • Standard evaluation metrics
    • Fréchet Inception Distance [3] (FID)
    • Precision & Recall [4]
  • Distributed Data Parallel [5] (DDP) multi-GPU training


  • torch>=1.12.0
  • torchvision>=0.13.0
  • scipy>=1.7.3

Code usage

The help messages below are listed in the following order:

  • Toy data: Training
  • Real-world data: Training, Generation, Evaluation

usage: [-h] [--dataset {gaussian8,gaussian25,swissroll}]      
                    [--size SIZE] [--root ROOT] [--epochs EPOCHS] [--lr LR]
                    [--beta1 BETA1] [--beta2 BETA2] [--lr-warmup LR_WARMUP]
                    [--batch-size BATCH_SIZE] [--timesteps TIMESTEPS]      
                    [--beta-schedule {quad,linear,warmup10,warmup50,jsd}]  
                    [--beta-start BETA_START] [--beta-end BETA_END]        
                    [--model-mean-type {mean,x_0,eps}]                     
                    [--model-var-type {learned,fixed-small,fixed-large}]   
                    [--loss-type {kl,mse}] [--image-dir IMAGE_DIR]         
                    [--chkpt-dir CHKPT_DIR] [--chkpt-intv CHKPT_INTV]      
                    [--eval-intv EVAL_INTV] [--seed SEED] [--resume]       
                    [--device DEVICE] [--mid-features MID_FEATURES]        
                    [--num-temporal-layers NUM_TEMPORAL_LAYERS]            
optional arguments:                                                        
  -h, --help            show this help message and exit                    
  --dataset {gaussian8,gaussian25,swissroll}                               
  --size SIZE                                                              
  --root ROOT           root directory of datasets                         
  --epochs EPOCHS       total number of training epochs                    
  --lr LR               learning rate                                      
  --beta1 BETA1         beta_1 in Adam                                     
  --beta2 BETA2         beta_2 in Adam                                     
  --lr-warmup LR_WARMUP                                                    
                        number of warming-up epochs                        
  --batch-size BATCH_SIZE                                                  
  --timesteps TIMESTEPS                                                    
                        number of diffusion steps                          
  --beta-schedule {quad,linear,warmup10,warmup50,jsd}                      
  --beta-start BETA_START                                                  
  --beta-end BETA_END                                                      
  --model-mean-type {mean,x_0,eps}
  --model-var-type {learned,fixed-small,fixed-large}
  --loss-type {kl,mse}
  --image-dir IMAGE_DIR
  --chkpt-dir CHKPT_DIR
  --chkpt-intv CHKPT_INTV
                        frequency of saving a checkpoint
  --eval-intv EVAL_INTV
  --seed SEED           random seed
  --resume              to resume training from a checkpoint
  --device DEVICE
  --mid-features MID_FEATURES
  --num-temporal-layers NUM_TEMPORAL_LAYERS
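
The --beta-schedule, --beta-start, --beta-end and --timesteps options together define the noise schedule of the forward diffusion process. The following is a minimal sketch of the linear and quad choices in plain Python, not the repository's code: the helper names are hypothetical, the default values (1e-4, 0.02, 1000 steps) are those of the DDPM paper and may differ from this repository's defaults, and the warmup10, warmup50 and jsd schedules are left out.

```python
import math


def linspace(start, stop, num):
    """Return `num` evenly spaced values from start to stop (inclusive)."""
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def get_betas(schedule, beta_start=1e-4, beta_end=0.02, timesteps=1000):
    """Hypothetical helper: per-step noise variances beta_1 .. beta_T."""
    if schedule == "linear":
        return linspace(beta_start, beta_end, timesteps)
    if schedule == "quad":
        # Linear in sqrt(beta), then squared: smaller betas early on.
        roots = linspace(math.sqrt(beta_start), math.sqrt(beta_end), timesteps)
        return [r ** 2 for r in roots]
    raise ValueError(f"schedule not covered by this sketch: {schedule}")


def alpha_bars(betas):
    """Cumulative products alpha_bar_t = prod_{s<=t} (1 - beta_s)."""
    out, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        out.append(prod)
    return out


def q_sample(x0, t, abar, noise):
    """Forward process: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps."""
    scale, sigma = math.sqrt(abar[t]), math.sqrt(1.0 - abar[t])
    return [scale * x + sigma * e for x, e in zip(x0, noise)]
```

With the linear schedule and these defaults, alpha_bar at the last step is close to zero, so x_T is almost pure noise regardless of x_0.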

usage: [-h] [--dataset {mnist,cifar10,celeba,celebahq}] [--root ROOT]
                [--epochs EPOCHS] [--lr LR] [--beta1 BETA1] [--beta2 BETA2]   
                [--batch-size BATCH_SIZE] [--num-accum NUM_ACCUM]
                [--block-size BLOCK_SIZE] [--timesteps TIMESTEPS]
                [--beta-schedule {quad,linear,warmup10,warmup50,jsd}]
                [--beta-start BETA_START] [--beta-end BETA_END]
                [--model-mean-type {mean,x_0,eps}]
                [--model-var-type {learned,fixed-small,fixed-large}]
                [--loss-type {kl,mse}] [--num-workers NUM_WORKERS]
                [--train-device TRAIN_DEVICE] [--eval-device EVAL_DEVICE]
                [--image-dir IMAGE_DIR] [--image-intv IMAGE_INTV]
                [--num-save-images NUM_SAVE_IMAGES] [--config-dir CONFIG_DIR]
                [--chkpt-dir CHKPT_DIR] [--chkpt-name CHKPT_NAME]
                [--chkpt-intv CHKPT_INTV] [--seed SEED] [--resume]
                [--chkpt-path CHKPT_PATH] [--eval] [--use-ema]
                [--ema-decay EMA_DECAY] [--distributed] [--rigid-launch]
                [--num-gpus NUM_GPUS] [--dry-run]
optional arguments:
  -h, --help            show this help message and exit
  --dataset {mnist,cifar10,celeba,celebahq}
  --root ROOT           root directory of datasets
  --epochs EPOCHS       total number of training epochs
  --lr LR               learning rate
  --beta1 BETA1         beta_1 in Adam
  --beta2 BETA2         beta_2 in Adam
  --batch-size BATCH_SIZE
  --num-accum NUM_ACCUM
                        number of mini-batches before an update
  --block-size BLOCK_SIZE
                        block size used for pixel shuffle
  --timesteps TIMESTEPS
                        number of diffusion steps
  --beta-schedule {quad,linear,warmup10,warmup50,jsd}
  --beta-start BETA_START
  --beta-end BETA_END
  --model-mean-type {mean,x_0,eps}
  --model-var-type {learned,fixed-small,fixed-large}
  --loss-type {kl,mse}
  --chkpt-path CHKPT_PATH
                        checkpoint path used to resume training
  --eval                whether to evaluate fid during training
  --use-ema             whether to use exponential moving average
  --ema-decay EMA_DECAY
                        decay factor of ema
  --distributed         whether to use distributed training
  --rigid-launch        whether to use torch multiprocessing spawn
  --num-gpus NUM_GPUS   number of gpus for distributed training
  --dry-run             test-run till the first model update completes
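
To illustrate how --num-accum and --use-ema / --ema-decay interact, here is a toy sketch that fits a single scalar by gradient descent. It is an assumption-laden illustration, not the repository's training loop: the function names are hypothetical, the model is one parameter, and the EMA decay is set to 0.9 (instead of the 0.9999 mentioned in this README) so that the effect is visible within a few steps.

```python
def grad_mse(w, batch):
    """Gradient of mean((w - x)^2) over the batch with respect to w."""
    return sum(2.0 * (w - x) for x in batch) / len(batch)


def train(data, batch_size=4, num_accum=2, lr=0.1, ema_decay=0.9, epochs=20):
    """Toy loop: accumulate `num_accum` mini-batch gradients per update."""
    w, w_ema = 0.0, 0.0
    batches = [data[i:i + batch_size] for i in range(0, len(data), batch_size)]
    for _ in range(epochs):
        accum, count = 0.0, 0  # leftover partial accumulations are dropped
        for batch in batches:
            # Divide by num_accum so the accumulated value is an average.
            accum += grad_mse(w, batch) / num_accum
            count += 1
            if count == num_accum:
                w -= lr * accum  # one parameter update
                # EMA of the parameter, refreshed after every update.
                w_ema = ema_decay * w_ema + (1.0 - ema_decay) * w
                accum, count = 0.0, 0
    return w, w_ema
```

For equally sized mini-batches, the averaged accumulated gradient equals the gradient of the concatenated batch, which is why accumulation multiplies the effective batch size. The EMA copy lags behind the raw parameter and is smoother.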

usage: [-h] [--dataset {mnist,cifar10,celeba}]
                   [--batch-size BATCH_SIZE] [--total-size TOTAL_SIZE]
                   [--config-dir CONFIG_DIR] [--chkpt-dir CHKPT_DIR]
                   [--chkpt-path CHKPT_PATH] [--save-dir SAVE_DIR]
                   [--device DEVICE] [--use-ema] [--use-ddim] [--eta ETA]
                   [--skip-schedule SKIP_SCHEDULE] [--subseq-size SUBSEQ_SIZE]
                   [--suffix SUFFIX] [--max-workers MAX_WORKERS]
                   [--num-gpus NUM_GPUS]
optional arguments:
  -h, --help            show this help message and exit
  --dataset {mnist,cifar10,celeba}
  --batch-size BATCH_SIZE
  --total-size TOTAL_SIZE
  --config-dir CONFIG_DIR
  --chkpt-dir CHKPT_DIR
  --chkpt-path CHKPT_PATH
  --save-dir SAVE_DIR
  --device DEVICE
  --eta ETA
  --skip-schedule SKIP_SCHEDULE
  --subseq-size SUBSEQ_SIZE
  --suffix SUFFIX
  --max-workers MAX_WORKERS
  --num-gpus NUM_GPUS
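
The --skip-schedule, --subseq-size and --eta options control which timesteps the DDIM sampler visits and how much noise it re-injects. The sketch below shows one plausible reading of these options; the function names are hypothetical, and the exact rounding and spacing used by this repository may differ. The sigma formula is the one from the DDIM paper.

```python
import math


def linear_subsequence(timesteps, subseq_size):
    """Evenly spaced timesteps: 0, skip, 2 * skip, ..."""
    skip = timesteps // subseq_size
    return list(range(0, timesteps, skip))[:subseq_size]


def quadratic_subsequence(timesteps, subseq_size):
    """Quadratically spaced timesteps, denser near t = 0.

    Rounding can produce duplicates at small t; they are removed, so the
    result may be shorter than subseq_size.
    """
    top = math.sqrt(timesteps - 1)
    steps = {round((top * i / (subseq_size - 1)) ** 2) for i in range(subseq_size)}
    return sorted(steps)


def ddim_sigma(abar_t, abar_prev, eta):
    """Std of the noise added when stepping from t to the previous timestep.

    eta = 0 gives a deterministic sampler; eta = 1 recovers the DDPM
    posterior standard deviation.
    """
    return (
        eta
        * math.sqrt((1.0 - abar_prev) / (1.0 - abar_t))
        * math.sqrt(1.0 - abar_t / abar_prev)
    )
```

Here abar_t and abar_prev stand for the cumulative products of (1 - beta) at the current and the previously visited timestep of the chosen sub-sequence.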

usage: [-h] [--root ROOT] [--dataset {mnist,cifar10,celeba}]
               [--model-device MODEL_DEVICE] [--eval-device EVAL_DEVICE]
               [--eval-batch-size EVAL_BATCH_SIZE]
               [--eval-total-size EVAL_TOTAL_SIZE] [--num-workers NUM_WORKERS]
               [--nhood-size NHOOD_SIZE] [--row-batch-size ROW_BATCH_SIZE]
               [--col-batch-size COL_BATCH_SIZE] [--device DEVICE]
               [--eval-dir EVAL_DIR] [--precomputed-dir PRECOMPUTED_DIR]
               [--metrics METRICS [METRICS ...]] [--seed SEED]
               [--folder-name FOLDER_NAME]
optional arguments:
  -h, --help            show this help message and exit
  --root ROOT
  --dataset {mnist,cifar10,celeba}
  --model-device MODEL_DEVICE
  --eval-device EVAL_DEVICE
  --eval-batch-size EVAL_BATCH_SIZE
  --eval-total-size EVAL_TOTAL_SIZE
  --num-workers NUM_WORKERS
  --nhood-size NHOOD_SIZE
  --row-batch-size ROW_BATCH_SIZE
  --col-batch-size COL_BATCH_SIZE
  --device DEVICE
  --eval-dir EVAL_DIR
  --precomputed-dir PRECOMPUTED_DIR
  --metrics METRICS [METRICS ...]
  --seed SEED
  --folder-name FOLDER_NAME
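
As a rough illustration of the Precision & Recall metric, the following sketch uses a k-nearest-neighbour manifold estimate on raw points. It assumes that --nhood-size plays the role of k, which this README does not confirm; the function names are hypothetical, and a real evaluation would compare deep feature embeddings of images in batches rather than raw coordinates.

```python
import math


def kth_nn_radii(points, k):
    """Distance from each point to its k-th nearest neighbour in the same set."""
    radii = []
    for i, p in enumerate(points):
        dists = sorted(math.dist(p, q) for j, q in enumerate(points) if j != i)
        radii.append(dists[k - 1])
    return radii


def coverage(queries, refs, k):
    """Fraction of queries inside at least one k-NN ball around a reference."""
    radii = kth_nn_radii(refs, k)
    hits = sum(
        any(math.dist(q, r) <= rad for r, rad in zip(refs, radii))
        for q in queries
    )
    return hits / len(queries)


def precision_recall(real, fake, k=3):
    """Precision: fake samples covered by the real manifold; recall: the reverse."""
    return coverage(fake, real, k), coverage(real, fake, k)
```

Usage on a tiny example: with the four corners of the unit square as real data and one nearby plus one distant generated point, half of the generated points are covered by the real manifold.

```python
real = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0)]
fake = [(0.5, 0.5), (10.0, 10.0)]
precision, recall = precision_recall(real, fake, k=1)
```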


  • Train a 25-Gaussian toy model with single GPU (device id: 0) for a total of 100 epochs

    python --dataset gaussian25 --device cuda:0 --epochs 100
  • Train CIFAR-10 model with single GPU (device id: 0) for a total of 50 epochs

    python --dataset cifar10 --train-device cuda:0 --epochs 50

(You can always pass --dry-run for testing or tuning purposes; the run stops once the first model update completes.)

  • Train a CelebA model with an effective batch size of 64 x 2 x 4 = 512 on a four-card machine (single node) using shared file-system initialization

    python --dataset celeba --use-ema --num-accum 2 --num-gpus 4 --distributed --rigid-launch
    • use-ema: use exponential moving average (0.9999 decay by default)
    • num-accum 2: accumulate gradients over 2 mini-batches
    • num-gpus: number of GPUs to use for training, i.e. WORLD_SIZE of the process group
    • distributed: enable multi-GPU DDP training
    • rigid-launch: use shared file-system initialization and torch.multiprocessing
  • (Recommended) Train a CelebA model with an effective batch size of 64 x 1 x 2 = 128 using only two GPUs with torchrun Elastic Launch [6] (TCP initialization)

    export CUDA_VISIBLE_DEVICES=0,1 && torchrun --standalone --nproc_per_node 2 --rdzv_backend c10d --dataset celeba --distributed
  • Generate 50,000 samples (128 per mini-batch) from the EMA checkpoint located at ./chkpts/train/ in parallel, using 4 GPUs and the DDIM sampler. The results are stored in ./images/eval/cifar10_2160.

    python --dataset cifar10 --chkpt-path ./chkpts/train/ --use-ema --use-ddim --skip-schedule quadratic --subseq-size 100 --suffix _2160 --num-gpus 4
    • use-ddim: use DDIM
    • skip-schedule quadratic: use the quadratic schedule
    • subseq-size: length of sub-sequence, i.e. DDIM timesteps
    • suffix: suffix string to the dataset name in the folder name
    • num-gpus: number of GPU(s) to use for generation
  • Evaluate FID and Precision/Recall of the generated samples in ./images/eval/cifar10_2160

     python --dataset cifar10 --folder-name cifar10_2160
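
For intuition about what the FID number measures, here is a simplified sketch of the Fréchet distance between two Gaussians fitted to feature vectors. It is deliberately reduced to diagonal covariances so that it needs no matrix square root; the actual metric uses full covariance matrices of Inception features, and the function names here are hypothetical.

```python
import math
from statistics import mean, variance


def diagonal_stats(features):
    """Per-dimension mean and sample variance of a list of feature vectors."""
    columns = list(zip(*features))
    return [mean(c) for c in columns], [variance(c) for c in columns]


def frechet_distance_diag(feats_a, feats_b):
    """Frechet distance between two diagonal-covariance Gaussians.

    ||mu_a - mu_b||^2 + sum_i (var_a_i + var_b_i - 2 * sqrt(var_a_i * var_b_i))
    """
    mu_a, var_a = diagonal_stats(feats_a)
    mu_b, var_b = diagonal_stats(feats_b)
    mean_term = sum((a - b) ** 2 for a, b in zip(mu_a, mu_b))
    cov_term = sum(a + b - 2.0 * math.sqrt(a * b) for a, b in zip(var_a, var_b))
    return mean_term + cov_term
```

Two identical feature sets give a distance of zero, and shifting one set by a constant vector adds the squared length of that shift, which matches the intuition that lower FID means closer feature statistics.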

Experiment results

Toy data

Dataset     8 Gaussian                 25 Gaussian                 Swiss Roll
True        gaussian8_true_thumbnail   gaussian25_true_thumbnail   swissroll_true_thumbnail
Generated   gaussian8_true_thumbnail   gaussian25_true_thumbnail   swissroll_true_thumbnail

Training process (animated)

Dataset     8 Gaussian                  25 Gaussian                  Swiss Roll
Generated   gaussian8_train_thumbnail   gaussian25_train_thumbnail   swissroll_train_thumbnail

Real-world data

Table of evaluated metrics

Dataset    FID (↓)   Precision (↑)   Recall (↑)   Training steps   Training loss   Checkpoint
CIFAR-10   9.23      0.692           0.473        46.8k            0.0302          -
|__        6.02      0.693           0.510        93.6k            0.0291          -
|__        4.04      0.701           0.550        234.0k           0.0298          -
|__        3.36      0.717           0.559        468.0k           0.0284          -
|__        3.25      0.736           0.548        842.4k           0.0277          [Link]
CelebA     4.81      0.766           0.490        189.8k           0.0153          -
|__        3.88      0.760           0.516        379.7k           0.0151          -
|__        3.07      0.754           0.540        949.2k           0.0147          [Link]

Dataset            CIFAR-10      CelebA
Generated images   cifar10_gen   celeba_gen_thumbnail

Training process (animated)

Dataset            CIFAR-10        CelebA
Generated images   cifar10_train   celeba_train_thumbnail

Denoising process (animated)

Dataset            CIFAR-10          CelebA
Generated images   cifar10_denoise   celeba_denoise_thumbnail

