Added training code for the single-image translation.
taesungp committed Sep 13, 2020
1 parent 24e03ce commit a138080
Showing 12 changed files with 312 additions and 66 deletions.
34 changes: 32 additions & 2 deletions README.md
@@ -79,6 +79,10 @@ def PatchNCELoss(f_q, f_k, tau=0.07):
- Python 3
- CPU or NVIDIA GPU + CUDA CuDNN

### Update log

9/12/2020: Added single-image translation.

### Getting started

- Clone this repo:
@@ -183,14 +187,40 @@ The tutorial for using pretrained models will be released soon.

### SinCUT Single Image Unpaired Training

To train SinCUT (single-image translation, shown in Figs. 9, 13, and 14 of the paper), you need to

1. set the `--model` option as `--model sincut`, which invokes the configuration and code at `./models/sincut_model.py`, and
2. specify the dataset directory of one image in each domain, such as the example dataset included in this repo at `./datasets/single_image_monet_etretat/`.

For example, to train a model for the [Etretat cliff (first image of Figure 13)](https://github.com/taesungp/contrastive-unpaired-translation/blob/master/imgs/singleimage.gif), please use the following command.

```bash
python train.py --model sincut --name singleimage_monet_etretat --dataroot ./datasets/single_image_monet_etretat
```

or by using the experiment launcher script:
```bash
python -m experiments singleimage run 0
```

For single-image translation, we adopt architectural components of [StyleGAN2](https://github.com/NVlabs/stylegan2), as well as the pixel identity preservation loss used in [DTN](https://arxiv.org/abs/1611.02200) and [CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/cycle_gan_model.py#L160). In particular, we adapted the StyleGAN2 implementation of [rosinality](https://github.com/rosinality/stylegan2-pytorch), which lives at `models/stylegan_networks.py`.
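For orientation, here is a minimal, self-contained sketch of that pixel identity preservation loss. All names (`netG`, `real_B`, `lambda_identity`) are illustrative stand-ins, not the exact ones used in `./models/sincut_model.py`:

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-ins for the real generator and a real domain-B image.
netG = torch.nn.Identity()            # placeholder generator
real_B = torch.randn(1, 3, 256, 256)  # one image from the target domain
lambda_identity = 1.0                 # illustrative loss weight

# Identity preservation: feeding a target-domain image through G should
# approximately reproduce it (an L1 penalty, as in CycleGAN's identity loss).
idt_B = netG(real_B)
loss_idt = F.l1_loss(idt_B, real_B) * lambda_identity
```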

Training takes several hours. To generate the final image from the saved checkpoint, run:

```bash
python test.py --model sincut --name singleimage_monet_etretat --dataroot ./datasets/single_image_monet_etretat
```

or simply

```bash
python -m experiments singleimage run_test 0
```
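The translated result should appear under `./results/singleimage_monet_etretat/`, assuming the default `--results_dir ./results/` inherited from the pix2pix/CycleGAN codebase.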

### [Datasets](./docs/datasets.md)
Download CUT/CycleGAN/pix2pix datasets and learn how to create your own datasets.



### Citation
If you use this code for your research, please cite our [paper](https://arxiv.org/pdf/2007.15651).
```
2 changes: 1 addition & 1 deletion data/__init__.py
@@ -77,7 +77,7 @@ def __init__(self, opt):
            batch_size=opt.batch_size,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.num_threads),
-            drop_last=True
+            drop_last=True if opt.isTrain else False,
        )

    def set_epoch(self, epoch):
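(The effect of this change: the final incomplete batch is still dropped during training, so every training minibatch holds exactly `opt.batch_size` samples — which the new `SingleImageDataset` below relies on when indexing zoom levels tiled by `batch_size` — while at test time no images are silently skipped.)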
108 changes: 108 additions & 0 deletions data/singleimage_dataset.py
@@ -0,0 +1,108 @@
import numpy as np
import os.path
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image
import random
import util.util as util


class SingleImageDataset(BaseDataset):
    """
    This dataset class loads a single image from each of two domains.
    It requires two directories, '/path/to/data/trainA' and '/path/to/data/trainB',
    each containing exactly one training image.
    You can train the model with the dataset flag '--dataroot /path/to/data'.
    """

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseDataset.__init__(self, opt)

        self.dir_A = os.path.join(opt.dataroot, 'trainA')  # create a path '/path/to/data/trainA'
        self.dir_B = os.path.join(opt.dataroot, 'trainB')  # create a path '/path/to/data/trainB'

        if os.path.exists(self.dir_A) and os.path.exists(self.dir_B):
            self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size))  # load images from '/path/to/data/trainA'
            self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size))  # load images from '/path/to/data/trainB'
            self.A_size = len(self.A_paths)  # get the size of dataset A
            self.B_size = len(self.B_paths)  # get the size of dataset B

        assert len(self.A_paths) == 1 and len(self.B_paths) == 1,\
            "SingleImageDataset class should be used with one image in each domain"
        A_img = Image.open(self.A_paths[0]).convert('RGB')
        B_img = Image.open(self.B_paths[0]).convert('RGB')
        print("Image sizes %s and %s" % (str(A_img.size), str(B_img.size)))

        self.A_img = A_img
        self.B_img = B_img

        # In single-image translation, we augment the data loader by applying
        # random scaling. Still, we design the data loader such that the
        # amount of scaling is the same within a minibatch. To do this,
        # we precompute the random scaling values, and repeat them by |batch_size|.
        A_zoom = 1 / self.opt.random_scale_max
        zoom_levels_A = np.random.uniform(A_zoom, 1.0, size=(len(self) // opt.batch_size + 1, 1, 2))
        self.zoom_levels_A = np.reshape(np.tile(zoom_levels_A, (1, opt.batch_size, 1)), [-1, 2])

        B_zoom = 1 / self.opt.random_scale_max
        zoom_levels_B = np.random.uniform(B_zoom, 1.0, size=(len(self) // opt.batch_size + 1, 1, 2))
        self.zoom_levels_B = np.reshape(np.tile(zoom_levels_B, (1, opt.batch_size, 1)), [-1, 2])

        # While the crop locations are randomized, the negative samples should
        # not come from the same location. To do this, we precompute the
        # crop locations with no repetition.
        self.patch_indices_A = list(range(len(self)))
        random.shuffle(self.patch_indices_A)
        self.patch_indices_B = list(range(len(self)))
        random.shuffle(self.patch_indices_B)

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index (int) -- a random integer for data indexing

        Returns a dictionary that contains A, B, A_paths and B_paths
            A (tensor) -- an image in the input domain
            B (tensor) -- its corresponding image in the target domain
            A_paths (str) -- image paths
            B_paths (str) -- image paths
        """
        A_path = self.A_paths[0]
        B_path = self.B_paths[0]
        A_img = self.A_img
        B_img = self.B_img

        # apply image transformation
        if self.opt.phase == "train":
            param = {'scale_factor': self.zoom_levels_A[index],
                     'patch_index': self.patch_indices_A[index],
                     'flip': random.random() > 0.5}

            transform_A = get_transform(self.opt, params=param, method=Image.BILINEAR)
            A = transform_A(A_img)

            param = {'scale_factor': self.zoom_levels_B[index],
                     'patch_index': self.patch_indices_B[index],
                     'flip': random.random() > 0.5}
            transform_B = get_transform(self.opt, params=param, method=Image.BILINEAR)
            B = transform_B(B_img)
        else:
            transform = get_transform(self.opt, method=Image.BILINEAR)
            A = transform(A_img)
            B = transform(B_img)

        return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}

    def __len__(self):
        """Let's pretend the single image contains 100,000 crops for convenience.
        """
        return 100000
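As an editorial aside, the zoom-level precomputation in `__init__` above can be sanity-checked in isolation. A minimal standalone sketch, assuming hypothetical values `opt.batch_size = 4` and `opt.random_scale_max = 4.0`:

```python
import numpy as np

# Mirror of the precomputation in SingleImageDataset.__init__: draw one
# (h, w) zoom pair per minibatch, then repeat it batch_size times.
batch_size = 4   # assumed opt.batch_size
zoom = 1 / 4.0   # assumed opt.random_scale_max = 4.0
levels = np.random.uniform(zoom, 1.0, size=(100000 // batch_size + 1, 1, 2))
levels = np.reshape(np.tile(levels, (1, batch_size, 1)), [-1, 2])

assert np.array_equal(levels[0], levels[3])  # same minibatch, same zoom
print(levels.shape)  # (100004, 2): one zoom pair per crop index, plus padding
```

Every group of `batch_size` consecutive crop indices therefore shares one scale factor, which keeps the rescaled images uniform within a minibatch before they are cropped to fixed-size patches.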
(Two binary files — presumably the single training images under `./datasets/single_image_monet_etretat/` — cannot be displayed in this view.)
18 changes: 18 additions & 0 deletions experiments/singleimage_launcher.py
@@ -0,0 +1,18 @@
from .tmux_launcher import Options, TmuxLauncher


class Launcher(TmuxLauncher):
    def common_options(self):
        return [
            Options(
                name="singleimage_monet_etretat",
                dataroot="./datasets/single_image_monet_etretat",
                model="sincut"
            )
        ]

    def commands(self):
        return ["python train.py " + str(opt) for opt in self.common_options()]

    def test_commands(self):
        return ["python test.py " + str(opt) for opt in self.common_options()]
70 changes: 36 additions & 34 deletions models/cut_model.py
@@ -23,9 +23,11 @@ def modify_commandline_options(parser, is_train=True):

        parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
        parser.add_argument('--lambda_NCE', type=float, default=1.0, help='weight for NCE loss: NCE(G(X), X)')
-
        parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y)')
        parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
+        parser.add_argument('--nce_includes_all_negatives_from_minibatch',
+                            type=util.str2bool, nargs='?', const=True, default=False,
+                            help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.')
        parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map')
        parser.add_argument('--netF_nc', type=int, default=256)
        parser.add_argument('--nce_T', type=float, default=0.07, help='temperature for NCE loss')
@@ -101,27 +103,31 @@ def data_dependent_initialize(self):
        self.real_B = self.real_B[:bs_per_gpu]
        self.forward()                     # compute fake images: G(A)
        if self.opt.isTrain:
-            self.backward_D()                  # calculate gradients for D
-            self.backward_G()                  # calculate gradients for G
+            self.compute_D_loss().backward()   # calculate gradients for D
+            self.compute_G_loss().backward()   # calculate gradients for G
            if self.opt.lambda_NCE > 0.0:
                self.optimizer_F = torch.optim.Adam(self.netF.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, self.opt.beta2))
                self.optimizers.append(self.optimizer_F)

    def optimize_parameters(self):
        # forward
-        self.forward()                    # compute fake images: G(A)
+        self.forward()

        # update D
-        self.set_requires_grad(self.netD, True)   # enable backprop for D
-        self.optimizer_D.zero_grad()              # set D's gradients to zero
-        self.backward_D()                         # calculate gradients for D
-        self.optimizer_D.step()                   # update D's weights
+        self.set_requires_grad(self.netD, True)
+        self.optimizer_D.zero_grad()
+        self.loss_D = self.compute_D_loss()
+        self.loss_D.backward()
+        self.optimizer_D.step()

        # update G
-        self.set_requires_grad(self.netD, False)  # D requires no gradients when optimizing G
-        self.optimizer_G.zero_grad()              # set G's gradients to zero
+        self.set_requires_grad(self.netD, False)
+        self.optimizer_G.zero_grad()
        if self.opt.netF == 'mlp_sample':
            self.optimizer_F.zero_grad()
-        self.backward_G()                 # calculate gradients for G
-        self.optimizer_G.step()           # update G's weights
+        self.loss_G = self.compute_G_loss()
+        self.loss_G.backward()
+        self.optimizer_G.step()
        if self.opt.netF == 'mlp_sample':
            self.optimizer_F.step()
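(A plausible reading of this refactor: `backward_D`/`backward_G`, which computed losses and called `.backward()` internally, become `compute_D_loss`/`compute_G_loss`, which only return the losses — so `data_dependent_initialize` can reuse them above, and `optimize_parameters` keeps `self.loss_D`/`self.loss_G` around for logging.)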

@@ -138,7 +144,7 @@ def set_input(self, input):

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
-        self.real = torch.cat((self.real_A, self.real_B), dim=0) if self.opt.nce_idt else self.real_A
+        self.real = torch.cat((self.real_A, self.real_B), dim=0) if self.opt.nce_idt and self.opt.isTrain else self.real_A
        if self.opt.flip_equivariance:
            self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5)
            if self.flipped_for_equivariance:
@@ -149,25 +155,22 @@ def forward(self):
        if self.opt.nce_idt:
            self.idt_B = self.fake[self.real_A.size(0):]

-    def backward_D(self):
-        if self.opt.lambda_GAN > 0.0:
-            """Calculate GAN loss for the discriminator"""
-            fake = self.fake_B.detach()
-            # Fake; stop backprop to the generator by detaching fake_B
-            pred_fake = self.netD(fake)
-            self.loss_D_fake = self.criterionGAN(pred_fake, False).mean()
-            # Real
-            pred_real = self.netD(self.real_B)
-            loss_D_real_unweighted = self.criterionGAN(pred_real, True)
-            self.loss_D_real = loss_D_real_unweighted.mean()
-
-            # combine loss and calculate gradients
-            self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
-            self.loss_D.backward()
-        else:
-            self.loss_D_real, self.loss_D_fake, self.loss_D = 0.0, 0.0, 0.0
-
-    def backward_G(self):
+    def compute_D_loss(self):
+        """Calculate GAN loss for the discriminator"""
+        fake = self.fake_B.detach()
+        # Fake; stop backprop to the generator by detaching fake_B
+        pred_fake = self.netD(fake)
+        self.loss_D_fake = self.criterionGAN(pred_fake, False).mean()
+        # Real
+        self.pred_real = self.netD(self.real_B)
+        loss_D_real = self.criterionGAN(self.pred_real, True)
+        self.loss_D_real = loss_D_real.mean()
+
+        # combine loss and calculate gradients
+        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
+        return self.loss_D
+
+    def compute_G_loss(self):
        """Calculate GAN and NCE loss for the generator"""
        fake = self.fake_B
        # First, G(A) should fake the discriminator
@@ -189,8 +192,7 @@ def backward_G(self):
        loss_NCE_both = self.loss_NCE

        self.loss_G = self.loss_G_GAN + loss_NCE_both
-
-        self.loss_G.backward()
+        return self.loss_G

    def calculate_NCE_loss(self, src, tgt):
        n_layers = len(self.nce_layers)
10 changes: 3 additions & 7 deletions models/networks.py
@@ -257,9 +257,9 @@ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, in
    elif netG == 'unet_256':
        net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
    elif netG == 'stylegan2':
-        net = StyleGAN2Generator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, opt=opt)
+        net = StyleGAN2Generator(input_nc, output_nc, ngf, use_dropout=use_dropout, opt=opt)
    elif netG == 'smallstylegan2':
-        net = StyleGAN2Generator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=2, opt=opt)
+        net = StyleGAN2Generator(input_nc, output_nc, ngf, use_dropout=use_dropout, n_blocks=2, opt=opt)
    elif netG == 'resnet_cat':
        n_blocks = 8
        net = G_Resnet(input_nc, output_nc, opt.nz, num_downs=2, n_res=n_blocks - 4, ngf=ngf, norm='inst', nl_layer='relu')
@@ -323,12 +323,8 @@ def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal'
        net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, no_antialias=no_antialias,)
    elif netD == 'pixel':  # classify if each pixel is real or fake
        net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
-    elif netD == "patch":
-        net = PatchDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, no_antialias=no_antialias)
-    elif netD == "tilestylegan2":
-        net = TileStyleGAN2Discriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, no_antialias=no_antialias, size=opt.D_patch_size, opt=opt)
    elif 'stylegan2' in netD:
-        net = StyleGAN2Discriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, no_antialias=no_antialias, opt=opt)
+        net = StyleGAN2Discriminator(input_nc, ndf, n_layers_D, no_antialias=no_antialias, opt=opt)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
    return init_net(net, init_type, init_gain, gpu_ids,
20 changes: 17 additions & 3 deletions models/patchnce.py
@@ -19,10 +19,24 @@ def forward(self, feat_q, feat_k):
        l_pos = torch.bmm(feat_q.view(batchSize, 1, -1), feat_k.view(batchSize, -1, 1))
        l_pos = l_pos.view(batchSize, 1)

-        # neg logit -- current batch
+        # neg logit

+        # Should the negatives from the other samples of a minibatch be utilized?
+        # In CUT and FastCUT, we found that it's best to only include negatives
+        # from the same image. Therefore, we set
+        # --nce_includes_all_negatives_from_minibatch as False
+        # However, for single-image translation, the minibatch consists of
+        # crops from the "same" high-resolution image.
+        # Therefore, we will include the negatives from the entire minibatch.
+        if self.opt.nce_includes_all_negatives_from_minibatch:
+            # reshape features as if they are all negatives of minibatch of size 1.
+            batch_dim_for_bmm = 1
+        else:
+            batch_dim_for_bmm = self.opt.batch_size

        # reshape features to batch size
-        feat_q = feat_q.view(self.opt.batch_size, -1, dim)
-        feat_k = feat_k.view(self.opt.batch_size, -1, dim)
+        feat_q = feat_q.view(batch_dim_for_bmm, -1, dim)
+        feat_k = feat_k.view(batch_dim_for_bmm, -1, dim)
        npatches = feat_q.size(1)
        l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1))
(The remaining changed files — presumably including `models/sincut_model.py` and `models/stylegan_networks.py`, both referenced in the README above — did not load in this view.)
