We provide our PyTorch implementation of unpaired image-to-image translation based on patchwise contrastive learning and adversarial learning. No hand-crafted loss or inverse network is required. Compared to CycleGAN, model training is faster and less memory-intensive. In addition, our method can be extended to single-image training, where each “domain” is only a single image.
Contrastive Learning for Unpaired Image-to-Image Translation
Taesung Park, Alexei A. Efros, Richard Zhang, Jun-Yan Zhu
UC Berkeley and Adobe Research
In ECCV 2020
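The code below computes the PatchNCE loss: for each sampled patch feature v with its positive v+ and N negatives v-, it is an (N+1)-way cross-entropy classification with temperature tau (Eq. 2 of the paper); in the code, each of the S locations uses the other S - 1 locations as negatives:

\ell(v, v^+, v^-) = -\log \frac{\exp(v \cdot v^+ / \tau)}{\exp(v \cdot v^+ / \tau) + \sum_{n=1}^{N} \exp(v \cdot v^-_n / \tau)}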
import torch

cross_entropy_loss = torch.nn.CrossEntropyLoss()

# Input: f_q (BxCxS) are sampled features from H(G_enc(x))
# Input: f_k (BxCxS) are sampled features from H(G_enc(G(x)))
# Input: tau is the temperature used in the NCE loss.
# Output: PatchNCE loss
def PatchNCELoss(f_q, f_k, tau=0.07):
    # batch size, channel size, and number of sample locations
    B, C, S = f_q.shape

    # calculate v * v+: BxSx1
    l_pos = (f_k * f_q).sum(dim=1)[:, :, None]

    # calculate v * v-: BxSxS
    l_neg = torch.bmm(f_q.transpose(1, 2), f_k)

    # The diagonal entries are the positive pairs, not negatives. Remove them.
    # (masked_fill_ expects a boolean mask.)
    identity_matrix = torch.eye(S, dtype=torch.bool, device=f_q.device)[None, :, :]
    l_neg.masked_fill_(identity_matrix, -float('inf'))

    # calculate logits: BxSx(S+1); the positive sits at index 0
    logits = torch.cat((l_pos, l_neg), dim=2) / tau

    # return NCE loss: (S+1)-way classification where class 0 is correct
    predictions = logits.flatten(0, 1)
    targets = torch.zeros(B * S, dtype=torch.long, device=f_q.device)
    return cross_entropy_loss(predictions, targets)
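A minimal sanity check of the loss above; the shapes and tensors here are illustrative, not part of the repo:

# Example (illustrative): B=4 images, C=256-dim features, S=64 patch locations
f_q = torch.randn(4, 256, 64)
f_k = torch.randn(4, 256, 64)
# In the paper, features are L2-normalized onto the unit sphere before the loss, e.g.:
# f_q = torch.nn.functional.normalize(f_q, dim=1)
loss = PatchNCELoss(f_q, f_k, tau=0.07)
print(loss)  # scalar tensor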
- Linux or macOS
- Python 3
- CPU or NVIDIA GPU + CUDA cuDNN
- Clone this repo:
git clone https://github.com/taesungp/contrastive-unpaired-translation CUT
cd CUT
- Install PyTorch 1.4 and other dependencies (e.g., torchvision, func-timeout, gputil).
  For pip users, please type the command `pip install -r requirements.txt`.
  For Conda users, we provide an installation script `scripts/conda_deps.sh`. Alternatively, you can create a new Conda environment using `conda env create -f environment.yml`.
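After installing, you can sanity-check the environment (an optional check, not part of the repo's instructions):

python -c "import torch; print(torch.__version__, torch.cuda.is_available())"

This should print the PyTorch version (1.4.x) and `True` if a CUDA GPU is visible.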
- Download the grumpify dataset (Fig. 8 of the paper; Russian Blue -> Grumpy Cats):
bash ./datasets/download_cut_dataset.sh grumpifycat
The dataset is downloaded and unzipped at `./datasets/grumpifycat/` (see the layout sketched below). The other datasets can be downloaded using `bash ./datasets/download_cut_dataset.sh [dataset_name]`, a script provided by the CycleGAN repo.
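The download script comes from the CycleGAN/pix2pix codebase, which uses an unaligned two-folder layout. Assuming that convention, the unzipped dataset should look roughly like:

./datasets/grumpifycat/
├── trainA/   # source-domain images (Russian Blue cats)
└── trainB/   # target-domain images (Grumpy cats)

(Datasets that ship test splits also contain testA/ and testB/.)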
- Train the model:
# Trains the CUT model
python train.py --dataroot ./datasets/grumpifycat --name grumpycat_CUT --CUT_mode CUT
# Trains the FastCUT model
python train.py --dataroot ./datasets/grumpifycat --name grumpycat_FastCUT --CUT_mode FastCUT
The checkpoints are stored at `./checkpoints/grumpycat_*/web`.
Please see `experiments/grumpifycat_launcher.py`, which generates the above command-line arguments. The launcher scripts are useful for configuring the rather complicated command-line arguments for training and testing.
Using the launcher, the commands below generate and run the training commands for CUT and FastCUT, respectively:
python -m experiments grumpifycat train 0
python -m experiments grumpifycat train 1
To test using the launcher:
python -m experiments grumpifycat test 0
python -m experiments grumpifycat test 1
Possible commands are `run`, `run_test`, `launch`, `close`, and so on. Please see `experiments/main.py` for all commands.
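You can also call the test script directly. A sketch, assuming the flags mirror the training command above and the pytorch-CycleGAN-and-pix2pix conventions (run `python test.py --help` for the authoritative list):

python test.py --dataroot ./datasets/grumpifycat --name grumpycat_CUT --CUT_mode CUT

In that codebase, results are typically written under `./results/[name]/`.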
The tutorial for applying pretrained models will be released soon.
The tutorial for the Single-Image Translation will be released soon.
If you use this code for your research, please cite our paper.
@inproceedings{park2020cut,
  title={Contrastive Learning for Unpaired Image-to-Image Translation},
  author={Taesung Park and Alexei A. Efros and Richard Zhang and Jun-Yan Zhu},
  booktitle={European Conference on Computer Vision},
  year={2020}
}
We thank Allan Jabri and Phillip Isola for helpful discussion and feedback. Our code is developed based on pytorch-CycleGAN-and-pix2pix. We also thank pytorch-fid for FID computation and drn for mIoU computation, and stylegan2-pytorch for the PyTorch implementation of StyleGAN2 used in single-image translation.