darioShar/PDMP: Repository for Piecewise Deterministic Generative Models
Generative models with PDMPs

Piecewise Deterministic Generative Models are a class of generative models that replace traditional Gaussian noise-based diffusion with piecewise deterministic Markov processes (PDMPs).
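To give a concrete, self-contained picture of the kind of process involved, the sketch below simulates a one-dimensional Zig-Zag process whose invariant law is a standard Gaussian. This is a toy illustration of PDMP dynamics only, not the code used in this repository:

```python
import numpy as np

def zigzag_1d_gaussian(n_events=10000, seed=0):
    """Simulate a 1D Zig-Zag process targeting the standard Gaussian N(0, 1).

    Between events the state moves linearly with velocity v in {-1, +1};
    velocity flips occur at rate lambda(x, v) = max(0, v * x), and event
    times are drawn exactly by inverting the integrated rate.
    """
    rng = np.random.default_rng(seed)
    x, v = 0.0, 1.0
    positions = []
    for _ in range(n_events):
        a = v * x                # rate at the current state is max(0, a)
        e = rng.exponential()    # Exp(1) draw for the next event time
        tau = -a + np.sqrt(max(a, 0.0) ** 2 + 2.0 * e)
        x += v * tau             # deterministic linear drift until the event
        v = -v                   # flip the velocity at the event
        positions.append(x)
    return np.array(positions)
```

In the generative-model setting, the reverse-time process is driven by learned event characteristics rather than the fixed Gaussian rate used in this toy example.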

This repository contains the full implementation of these models, providing the tools for training, evaluation, and generation of data. It includes a modular structure, allowing users to customize different components like the model, logging mechanisms, and experiment setups.

For further details on the theoretical background and methodology, please refer to our paper, Piecewise deterministic generative models (arXiv:2407.19448).

Repository Overview

We use BEM (Better Experimentation Manager) to manage our experiments.

  • Generative Model Implementation: Located in PDMP/methods/pdmp.py and PDMP/methods/Diffusion.py, these files contain the core logic of the PDMP generative model and of the improved DDPM generative model (including its heavy-tailed extension, DLPM). Users interested in understanding or modifying the underlying generative processes should start here.

  • Neural Network Architecture: If you want to change the neural networks used in the PDMP, head to the PDMP/models directory. This is where all the neural network models are defined and can be customized according to your needs.

  • Logging Configuration: To customize how logging is handled, you can inherit from bem/Logger.py for integrating your own logger. An example of a custom logging setup is available in PDMP/NeptuneLogger.py.

  • Experiment Workflow: The PDMP/pdmp_experiment file orchestrates the initialization of training, evaluation, and data generation. To integrate your modifications into the experiment flow, update the init functions here; they are passed to the Experiment class from the bem library.

  • Configuration Parameters: Parameters for specific runs are passed in a dictionary called p, which is loaded from configuration files located in PDMP/configs/. Adjust these files to fine-tune the behavior of the model and experiment settings.

  • Comparison between diffusion and PDMP methods: This repository supports both traditional diffusion models and PDMP, which allows for structured sampling using piecewise deterministic processes. When using the PDMP method, users can specify different samplers such as 'ZigZag', 'HMC', and 'BPS', and choose from various loss types. Our paper provides a detailed discussion on the advantages of PDMP, though users can experiment with both approaches here.
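For the logging bullet above, a custom logger can be as simple as the following stand-in. This is a sketch only: a real integration should inherit from the base class in bem/Logger.py, and the method name and signature here are hypothetical, not the bem interface:

```python
class ConsoleLogger:
    """Toy stand-in for a custom logger (hypothetical interface)."""

    def __init__(self, run_name):
        self.run_name = run_name

    def log(self, key, value, step=None):
        # Format and print a single metric; returns the line for convenience.
        prefix = f"[{self.run_name}]"
        if step is not None:
            prefix += f" step {step}"
        line = f"{prefix} {key} = {value}"
        print(line)
        return line
```

For example, `ConsoleLogger("pdmp_test").log("loss", 0.42, step=100)` prints a single tagged metric line; see PDMP/NeptuneLogger.py for a real logger built against the bem interface.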
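The parameter dictionary p described above can be pictured as a plain nested dictionary once a config file is loaded. The key names below are purely illustrative and do not reproduce the repository's actual config schema:

```python
# Illustrative shape only -- the real keys live in PDMP/configs/*.yml.
p = {
    "data": {"dataset": "mnist", "batch_size": 256},
    "pdmp": {"sampler": "ZigZag", "loss": "hyvarinen"},
    "training": {"epochs": 100, "reverse_steps": 1000},
}

# Command-line flags such as --sampler conceptually override these entries:
p["pdmp"]["sampler"] = "HMC"
```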

Supported Datasets

Here’s a brief overview of the supported datasets, as provided by BEM, and how to specify them:

  • 2D Datasets: The repository supports synthetic 2D datasets. See Generator.available_distributions in bem/datasets/Data.py.

  • Image Datasets: You can use standard image datasets (MNIST, CIFAR-10, its long-tailed version CIFAR-10-LT, CelebA, etc.). See bem/datasets/__init__.py.

Configuration files for some of these datasets are provided in the PDMP/configs/ directory: mnist.yml for MNIST, cifar10.yml for CIFAR-10, and cifar10_lt.yml for CIFAR-10-LT.

You can modify the configuration files to adjust data loading settings, such as the batch size or data augmentation options, according to your experiment needs.

Using the Provided Scripts

This repository includes scripts that simplify training, evaluating, and visualizing the results of PDMP models. Below is a description of each script and how to use it:

1. run.py

This script is used to train a model. It accepts various command-line arguments to control the training process, including configuration settings and experiment parameters.

Example Command:

python ./run.py --config mnist --name pdmp_test --method pdmp --sampler ZigZag --loss hyvarinen --epochs 100 --eval 50 --check 50 --train_reverse_steps 1000

Explanation:

  • --config: Specifies the configuration file to use (e.g., mnist).
  • --name: The name of the experiment run, used for logging and identification. Here, the checkpointed models will be stored in /models/pdmp_test/.
  • --method: Specifies the generative method to use (either diffusion or pdmp), in this case, pdmp.
  • --sampler: Specifies the sampler to use with PDMP (options: 'ZigZag', 'HMC', 'BPS'). Required when --method pdmp is selected.
  • --loss: Specifies the loss type for training (options: 'square', 'kl', 'logistic', 'hyvarinen', 'ml', 'hyvarinen_simple', 'kl_simple'). We recommend 'hyvarinen' for ZigZag and 'ml' for HMC and BPS.
  • --epochs: The total number of training epochs.
  • --eval: Specifies the interval (in epochs) for running evaluations during training.
  • --check: Interval for model checkpointing (in epochs).
  • --train_reverse_steps: The number of reverse steps to use during training.
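The flags above can be mirrored with a minimal argparse sketch. This is a hypothetical reconstruction for illustration; the actual parser in run.py may differ in defaults and carry additional options:

```python
import argparse

# Hypothetical parser mirroring the documented run.py flags.
parser = argparse.ArgumentParser(description="Train a PDMP or diffusion model")
parser.add_argument("--config", required=True)
parser.add_argument("--name", required=True)
parser.add_argument("--method", choices=["diffusion", "pdmp"])
parser.add_argument("--sampler", choices=["ZigZag", "HMC", "BPS"])
parser.add_argument("--loss", choices=["square", "kl", "logistic", "hyvarinen",
                                       "ml", "hyvarinen_simple", "kl_simple"])
parser.add_argument("--epochs", type=int)
parser.add_argument("--eval", type=int)
parser.add_argument("--check", type=int)
parser.add_argument("--train_reverse_steps", type=int)

# Parse the example command from the section above.
args = parser.parse_args(
    "--config mnist --name pdmp_test --method pdmp --sampler ZigZag "
    "--loss hyvarinen --epochs 100 --eval 50 --check 50 "
    "--train_reverse_steps 1000".split()
)
```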

2. eval.py

This script evaluates a pre-trained model and can also be used for generating samples from the trained model.

Example Command:

python ./eval.py --config mnist --name pdmp_test --method pdmp --sampler HMC --loss ml --epochs 100 --eval 100 --generate 2000 --reverse_steps 1000

Explanation:

  • --config, --name, --method, --sampler, --loss, and --epochs: Same as in run.py.
  • --eval: Specifies the evaluation checkpoint to use.
  • --generate: Number of samples to generate.
  • --reverse_steps: Number of reverse steps to use during the generation process.

3. display.py

This script is used to visualize the generated samples or the results from an experiment.

Example Command:

python ./display.py --config mnist --name pdmp_test --method pdmp --sampler ZigZag --loss hyvarinen --epochs 100 --reverse_steps 1000 --generate 1

Explanation:

  • --config, --name, --method, --sampler, --loss, --epochs, and --reverse_steps: Same as in the previous scripts.
  • --generate: Specifies the number of samples to visualize (e.g., 1 for displaying a single sample).

Citation

@misc{bertazzi2024piecewisedeterministicgenerativemodels,
      title={Piecewise deterministic generative models}, 
      author={Andrea Bertazzi and Alain Oliviero-Durmus and Dario Shariatian and Umut Simsekli and Eric Moulines},
      year={2024},
      eprint={2407.19448},
      archivePrefix={arXiv},
      primaryClass={stat.ML},
      url={https://arxiv.org/abs/2407.19448}, 
}

Contribute

We welcome issues, pull requests, and contributions. We will try our best to improve readability and answer questions.
