
Pluralistic Image Inpainting with Latent Codes

paper | arXiv

This repository contains the code (in PyTorch) for "Don't Look into the Dark: Latent Codes for Pluralistic Image Inpainting" (CVPR 2024) by Haiwei Chen and Yajie Zhao.

Contents

  1. Overview
  2. Requirements
  3. Getting Started
  4. Quick Test
  5. Training
  6. Contact

Overview

The inpainting method in this repository utilizes priors learnt from discrete latent codes to diversely complete a masked image. The method works in both free-form and large-hole mask settings.
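
The high-level idea can be illustrated with a toy sketch (illustration only, with made-up shapes and stand-in distributions rather than the repository's actual classes): a masked image is represented as a grid of discrete codebook indices, the indices hidden by the mask are sampled from a learned prior, and a decoder turns the completed code grid back into pixels. Drawing several samples yields several distinct completions, which is where the pluralism comes from.

import torch

torch.manual_seed(0)

num_codes, grid = 512, 16                        # codebook size, 16x16 latent grid
codes = torch.randint(num_codes, (grid, grid))   # stand-in for the encoder's code grid
mask = torch.zeros(grid, grid, dtype=torch.bool)
mask[4:12, 4:12] = True                          # latent positions covered by the hole

# Stand-in for a learned prior over codebook entries (a transformer in practice).
prior_probs = torch.softmax(torch.randn(num_codes), dim=0)

for sample_id in range(3):                       # draw three diverse completions
    filled = codes.clone()
    n_missing = int(mask.sum())
    filled[mask] = torch.multinomial(prior_probs, n_missing, replacement=True)
    # In the real pipeline the completed code grid is passed to a decoder
    # to produce an RGB image; here we only report how many codes were replaced.
    print(f"sample {sample_id}: replaced {n_missing} of {grid * grid} latent codes")

Roughly speaking, the VQGAN backbone provides the codebook, the transformer module plays the role of the prior, and the decoder module renders the final image (see the training stages below).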

Requirements

The code has been tested with Python 3.11, PyTorch 2.1.0, and CUDA 12.1. Additional dependencies can be installed with:

pip install -r environment.txt
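
To confirm that your environment matches the tested versions, a quick check like the following can help (a convenience snippet, not part of the repository):

import sys
import torch

print(sys.version.split()[0])       # expected: 3.11.x
print(torch.__version__)            # expected: 2.1.0 (a +cu suffix may appear)
print(torch.version.cuda)           # expected: 12.1
print(torch.cuda.is_available())    # True if a CUDA device is visible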

Getting Started

Our models are trained on data from the Places365-Standard and CelebA-HQ datasets.

As a first step, please download the respective pretrained models (Places | CelebA-HQ) and place the checkpoint files under the ckpts/ folder in the root directory.
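
A small convenience check (not part of the repository) can confirm that the downloaded checkpoints ended up where they are expected; the exact file names depend on which pretrained models you downloaded:

from pathlib import Path

ckpt_dir = Path("ckpts")
assert ckpt_dir.is_dir(), "create the ckpts/ folder in the repository root first"
for ckpt in sorted(ckpt_dir.glob("*")):
    # print each checkpoint file with its size in megabytes
    print(ckpt.name, f"{ckpt.stat().st_size / 1e6:.1f} MB")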

Quick Test

We provide a demo notebook, eval.ipynb, for quickly testing the inpainting models. Please follow the instructions in the notebook to set up inference with your desired configuration.
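
If you would like to try the notebook on your own image, the inputs can be prepared along these lines (generic preprocessing, not the notebook's exact loading code; in particular, the mask convention of 1 = missing region and the 256x256 resolution are assumptions, so check the notebook before reusing this):

import torch
from PIL import Image
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),               # float tensor in [0, 1], shape (3, 256, 256)
])

image = preprocess(Image.open("my_photo.png").convert("RGB")).unsqueeze(0)
mask = torch.zeros(1, 1, 256, 256)
mask[..., 64:192, 64:192] = 1.0          # hypothetical square hole in the center
masked_image = image * (1.0 - mask)      # zero out the masked region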

Training

If you are interested in training our models on custom data, please refer to the training configurations under the folder training_configs/. Training the complete model from scratch involves a total of four stages; the stages and their respective configuration templates are listed below:

Stage 1: training the VQGAN backbone 
   - training_configs/places_vqgan.yaml 
Stage 2: training the encoder module
   - training_configs/places_partialencoder.yaml 
Stage 3: training the transformer module
   - training_configs/places_transformer.yaml 
Stage 4: training the decoder module
   - training_configs/places_unet_256.yaml 
   - training_configs/places_unet_512.yaml 

Note that the modules for stages 2, 3, and 4 can be trained independently or concurrently, as each of these stages only requires the pretrained VQGAN backbone from stage 1.
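
Before editing a template, it can be handy to inspect which top-level fields it exposes, for example with PyYAML (a convenience snippet; the actual keys are defined by the repository's configuration files, not by this sketch):

import yaml

with open("training_configs/places_vqgan.yaml") as f:
    cfg = yaml.safe_load(f)

print(list(cfg.keys()))    # top-level sections of the template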

Please modify the path to the dataset, the path to the pretrained model, and optionally other hyperparameters in these configuration files to suit your needs. The basic command for training these models is as follows:

python train.py --base PATH_TO_CONFIG -n NAME --gpus GPU_INDEX 

For instance, to train the VQGAN backbone on a single GPU at index 0:

python train.py --base training_configs/places_vqgan.yaml -n my_vqgan_backbone --gpus 0, 

To train the transformer on multiple GPUs at indices 1, 2, and 3:

python train.py --base training_configs/places_transformer.yaml -n my_transformer --gpus 1,2,3 

To evaluate the trained model, please follow the configuration files in configs/ and modify the respective paths to each module's checkpoint.

Contact

Haiwei Chen: [email protected]. Any discussions or concerns are welcome!

Citation

If you find our project useful in your research, please consider citing:

@article{chen2024don,
  title={Don't Look into the Dark: Latent Codes for Pluralistic Image Inpainting},
  author={Chen, Haiwei and Zhao, Yajie},
  journal={arXiv preprint arXiv:2403.18186},
  year={2024}
}

License and Acknowledgement

The code and models in this repo are for research purposes only. Our code is built upon VQGAN.
