Code for the paper "Nonparametric Generative Modeling with Conditional Sliced-Wasserstein Flows" (ICML 2023).
This project has been tested under the following environment settings:
- OS: Ubuntu 20.04.5
- GPU: NVIDIA A100 80GB (× 8)
- Python: 3.8.10
- jax: 0.3.24
- jaxlib: 0.3.24+cuda11.cudnn805
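To compare a local setup against the tested environment listed above, a small stdlib-only check like the following may help. It is a hypothetical helper, not part of the repository: `environment_diff`, `TESTED` and `TESTED_PYTHON` are illustrative names, and the version strings are copied from the list. Note that the local suffix of a CUDA `jaxlib` wheel may be reported differently on your machine.

```python
import sys
from importlib import metadata

# Versions copied from the tested-environment list (illustrative constants).
TESTED = {"jax": "0.3.24", "jaxlib": "0.3.24+cuda11.cudnn805"}
TESTED_PYTHON = "3.8.10"


def environment_diff(tested=TESTED, tested_python=TESTED_PYTHON):
    """Return {name: (installed, tested)} for every entry that differs from the tested setup."""
    diff = {}
    python = ".".join(map(str, sys.version_info[:3]))
    if python != tested_python:
        diff["python"] = (python, tested_python)
    for name, want in tested.items():
        try:
            have = metadata.version(name)
        except metadata.PackageNotFoundError:
            have = None  # package not installed
        if have != want:
            diff[name] = (have, want)
    return diff


if __name__ == "__main__":
    for name, (have, want) in environment_diff().items():
        print(f"{name}: installed {have}, tested with {want}")
```

A mismatch does not mean the code will fail; it only flags where your setup departs from the one reported here.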
We use the standard MNIST, FashionMNIST, and CIFAR10 datasets from torchvision. For the CelebA 64×64 dataset, we follow the preprocessing of yang-song/score_sde_pytorch, which first crops the original images to 140×140 and then resizes them to 64×64:
```bash
cd data
python get_celebA.py
```
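The crop-then-resize step can be sketched as follows. This is not the repository's `get_celebA.py`. It assumes a *center* crop (as score_sde's `central_crop` does), uses Pillow (an assumption; the README does not name an image library), and picks bilinear resampling as an illustrative choice. `center_crop_box` and `preprocess_celeba` are hypothetical names.

```python
def center_crop_box(width, height, crop=140):
    """Return a PIL-style (left, upper, right, lower) box for a centered crop x crop window."""
    left = (width - crop) // 2
    upper = (height - crop) // 2
    return (left, upper, left + crop, upper + crop)


def preprocess_celeba(path, crop=140, size=64):
    """Center-crop one image to crop x crop, then resize it to size x size (requires Pillow)."""
    from PIL import Image  # imported lazily so center_crop_box works without Pillow

    with Image.open(path) as src:
        rgb = src.convert("RGB")  # convert() loads pixels into a new, independent image
    cropped = rgb.crop(center_crop_box(rgb.width, rgb.height, crop))
    return cropped.resize((size, size), Image.BILINEAR)
```

For a 178×218 image (the size of the aligned CelebA release), `center_crop_box(178, 218)` gives `(19, 39, 159, 179)`, i.e. 19 pixels trimmed from the left and 39 from the top.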
The configuration files for all experiments can be found in the configs folder. For example, the following command reproduces the class-conditional generation results on MNIST:
```bash
python main.py -c configs/mnist_class_cond.yaml
```
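To run several experiments in a row, a small driver can build one `main.py -c <config>` command per configuration file. This is a hypothetical convenience script, not part of the repository; it only uses the `-c` flag shown in the example command and assumes every experiment config is a `.yaml` file directly inside `configs/`.

```python
import subprocess
import sys
from pathlib import Path


def config_commands(config_dir="configs"):
    """Build one `python main.py -c <config>` command per YAML file in config_dir, sorted by name."""
    return [
        [sys.executable, "main.py", "-c", str(path)]
        for path in sorted(Path(config_dir).glob("*.yaml"))
    ]


def run_all(config_dir="configs", dry_run=True):
    """Print each command; actually launch them only when dry_run is False."""
    for cmd in config_commands(config_dir):
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop at the first failing experiment


if __name__ == "__main__":
    run_all()
```

Calling `run_all(dry_run=False)` from the repository root launches the experiments one after another; the dry-run default lets you inspect the list first.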
Our method relies on the cumulative distribution functions (CDFs) of the projected target data distributions, which we estimate using empirical distributions represented as sorted arrays. The advantage of sorted arrays, besides being nonparametric, lies in their efficient updating capability, achieved in
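As a one-dimensional illustration of the idea, the empirical CDF of a sample can be kept as a sorted list: evaluating the CDF is a binary search, the inverse CDF (quantile) is an index lookup, and new observations are inserted in place without refitting anything. This stdlib sketch uses hypothetical names (`SortedEmpiricalCDF`, `update`, `cdf`, `quantile`) and is not the repository's data structure; with a Python list, `bisect.insort` locates the position in O(log n) but shifting elements costs O(n).

```python
import bisect
import math


class SortedEmpiricalCDF:
    """Empirical CDF of a 1-D sample, stored as a sorted Python list (illustrative)."""

    def __init__(self, values=()):
        self.data = sorted(values)

    def update(self, new_values):
        # Insert each new observation at its sorted position (binary search + list shift).
        for v in new_values:
            bisect.insort(self.data, v)

    def cdf(self, x):
        # Fraction of stored samples that are <= x.
        return bisect.bisect_right(self.data, x) / len(self.data)

    def quantile(self, u):
        # Generalized inverse: smallest stored value v with cdf(v) >= u (for 0 < u <= 1).
        k = max(math.ceil(u * len(self.data)) - 1, 0)
        return self.data[k]
```

For example, after `F = SortedEmpiricalCDF([3.0, 1.0]); F.update([2.0, 4.0])`, `F.cdf(2.5)` is `0.5` and `F.quantile(0.75)` is `3.0`.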
Online Unconditional Generative Modeling
In the unconditional setting, we assume that the data are observed sequentially, class by class: first the ten MNIST classes, from digit '0' to digit '9', followed by the ten Fashion-MNIST classes from the first to the tenth. The CDFs of the projected target data distributions are continuously updated as new data arrive, which changes the batched samples over time (i.e., the samples obtained by running Algorithm 1 with the latest CDFs), as shown below.
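The online loop can be sketched in NumPy as a generic (unconditional) sliced-Wasserstein flow update driven by sorted projected arrays; it is not a reproduction of the paper's Algorithm 1. Assumptions, all ours: a fixed bank of random unit directions drawn once, so each direction's sorted array can be merged incrementally; particle CDFs from mid-rank empirical estimates; target quantiles by linear interpolation over the sorted samples; and a step that averages the 1-D transport displacements over directions. `OnlineSlicedTarget`, `update`, and `flow_step` are hypothetical names.

```python
import numpy as np


class OnlineSlicedTarget:
    """Sorted 1-D projections of all target data seen so far (hypothetical helper)."""

    def __init__(self, dim, num_projections=64, seed=0):
        rng = np.random.default_rng(seed)
        theta = rng.standard_normal((dim, num_projections))
        self.theta = theta / np.linalg.norm(theta, axis=0, keepdims=True)  # unit directions, (d, L)
        self.sorted_proj = np.empty((0, num_projections))  # (m, L); every column is sorted

    def update(self, batch):
        """Merge a new batch of target data with shape (k, d) into the sorted columns."""
        new = np.sort(batch @ self.theta, axis=0)  # (k, L)
        cols = []
        for l in range(self.theta.shape[1]):
            old = self.sorted_proj[:, l]
            pos = np.searchsorted(old, new[:, l])  # insertion points that keep the column sorted
            cols.append(np.insert(old, pos, new[:, l]))
        self.sorted_proj = np.stack(cols, axis=1)

    def flow_step(self, particles, step_size=1.0):
        """Move particles (n, d) along the direction-averaged 1-D optimal transport displacement."""
        n = particles.shape[0]
        m, num_proj = self.sorted_proj.shape
        z = particles @ self.theta  # projected particles, (n, L)
        u = (np.argsort(np.argsort(z, axis=0), axis=0) + 0.5) / n  # particles' empirical CDF
        grid = (np.arange(m) + 0.5) / m  # quantile levels of the sorted target samples
        disp = np.empty_like(z)
        for l in range(num_proj):
            # Target quantile at the particle's CDF value, minus the particle's own projection.
            disp[:, l] = np.interp(u[:, l], grid, self.sorted_proj[:, l]) - z[:, l]
        return particles + step_size * (disp @ self.theta.T) / num_proj
```

Calling `update()` once per arriving class and then iterating `flow_step()` from fresh noise mimics the schedule described above. Because averaging over uniformly random unit directions shrinks the move by roughly a factor of the dimension, `step_size=dim` is a natural choice in this sketch.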
Online Conditional Generative Modeling
In the class-conditional setting, we begin with the fully observed MNIST dataset and then observe Fashion-MNIST sequentially, class by class, from the first class (T-shirt/top) to the tenth (Ankle boot). The resulting changes in the class-conditional batched samples are illustrated below.
If you find this project helpful in your research, please consider citing our paper:
```bibtex
@inproceedings{du2023nonparametric,
  title={Nonparametric Generative Modeling with Conditional Sliced-Wasserstein Flows},
  author={Du, Chao and Li, Tianbo and Pang, Tianyu and Yan, Shuicheng and Lin, Min},
  booktitle={International Conference on Machine Learning (ICML)},
  year={2023}
}
```