diff --git a/README.md b/README.md index 599e396a..726135ee 100644 --- a/README.md +++ b/README.md @@ -79,6 +79,10 @@ def PatchNCELoss(f_q, f_k, tau=0.07): - Python 3 - CPU or NVIDIA GPU + CUDA CuDNN +### Update log + +9/12/2020: Added single-image translation. + ### Getting started - Clone this repo: @@ -183,14 +187,40 @@ The tutorial for using pretrained models will be released soon. ### SinCUT Single Image Unpaired Training -The tutorial for the Single-Image Translation will be released soon. +To train SinCUT (single-image translation, shown in Fig 9, 13 and 14 of the paper), you need to + +1. set the `--model` option as `--model sincut`, which invokes the configuration and codes at `./models/sincut_model.py`, and +2. specify the dataset directory of one image in each domain, such as the example dataset included in this repo at `./datasets/single_image_monet_etretat/`. + +For example, to train a model for the [Etretat cliff (first image of Figure 13)](https://github.com/taesungp/contrastive-unpaired-translation/blob/master/imgs/singleimage.gif), please use the following command. + +```bash +python train.py --model sincut --name singleimage_monet_etretat --dataroot ./datasets/single_image_monet_etretat +``` + +or by using the experiment launcher script, +```bash +python -m experiments singleimage run 0 +``` +For single-image translation, we adopt network architectural components of [StyleGAN2](https://github.com/NVlabs/stylegan2), as well as the pixel identity preservation loss used in [DTN](https://arxiv.org/abs/1611.02200) and [CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/cycle_gan_model.py#L160). In particular, we adopted the code of [rosinality](https://github.com/rosinality/stylegan2-pytorch), which exists at `models/stylegan_networks.py`. + +The training takes several hours. To generate the final image using the checkpoint, + +```bash +python test.py --model sincut --name singleimage_monet_etretat --dataroot ./datasets/single_image_monet_etretat +``` + +or simply + +```bash +python -m experiments singleimage run_test 0 +``` ### [Datasets](./docs/datasets.md) Download CUT/CycleGAN/pix2pix datasets and learn how to create your own datasets. - ### Citation If you use this code for your research, please cite our [paper](https://arxiv.org/pdf/2007.15651). ``` diff --git a/data/__init__.py b/data/__init__.py index 236e70d9..a7dd29b4 100644 --- a/data/__init__.py +++ b/data/__init__.py @@ -77,7 +77,7 @@ def __init__(self, opt): batch_size=opt.batch_size, shuffle=not opt.serial_batches, num_workers=int(opt.num_threads), - drop_last=True + drop_last=True if opt.isTrain else False, ) def set_epoch(self, epoch): diff --git a/data/singleimage_dataset.py b/data/singleimage_dataset.py new file mode 100644 index 00000000..0a9f1b55 --- /dev/null +++ b/data/singleimage_dataset.py @@ -0,0 +1,108 @@ +import numpy as np +import os.path +from data.base_dataset import BaseDataset, get_transform +from data.image_folder import make_dataset +from PIL import Image +import random +import util.util as util + + +class SingleImageDataset(BaseDataset): + """ + This dataset class can load unaligned/unpaired datasets. + + It requires two directories to host training images from domain A '/path/to/data/trainA' + and from domain B '/path/to/data/trainB' respectively. + You can train the model with the dataset flag '--dataroot /path/to/data'. + Similarly, you need to prepare two directories: + '/path/to/data/testA' and '/path/to/data/testB' during test time. 
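+    For single-image translation, 'trainA' and 'trainB' are each expected to contain
+    exactly one image; random zoomed crops of those two images serve as training samples.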
+ """ + + def __init__(self, opt): + """Initialize this dataset class. + + Parameters: + opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + BaseDataset.__init__(self, opt) + + self.dir_A = os.path.join(opt.dataroot, 'trainA') # create a path '/path/to/data/trainA' + self.dir_B = os.path.join(opt.dataroot, 'trainB') # create a path '/path/to/data/trainB' + + if os.path.exists(self.dir_A) and os.path.exists(self.dir_B): + self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) # load images from '/path/to/data/trainA' + self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) # load images from '/path/to/data/trainB' + self.A_size = len(self.A_paths) # get the size of dataset A + self.B_size = len(self.B_paths) # get the size of dataset B + + assert len(self.A_paths) == 1 and len(self.B_paths) == 1,\ + "SingleImageDataset class should be used with one image in each domain" + A_img = Image.open(self.A_paths[0]).convert('RGB') + B_img = Image.open(self.B_paths[0]).convert('RGB') + print("Image sizes %s and %s" % (str(A_img.size), str(B_img.size))) + + self.A_img = A_img + self.B_img = B_img + + # In single-image translation, we augment the data loader by applying + # random scaling. Still, we design the data loader such that the + # amount of scaling is the same within a minibatch. To do this, + # we precompute the random scaling values, and repeat them by |batch_size|. + A_zoom = 1 / self.opt.random_scale_max + zoom_levels_A = np.random.uniform(A_zoom, 1.0, size=(len(self) // opt.batch_size + 1, 1, 2)) + self.zoom_levels_A = np.reshape(np.tile(zoom_levels_A, (1, opt.batch_size, 1)), [-1, 2]) + + B_zoom = 1 / self.opt.random_scale_max + zoom_levels_B = np.random.uniform(B_zoom, 1.0, size=(len(self) // opt.batch_size + 1, 1, 2)) + self.zoom_levels_B = np.reshape(np.tile(zoom_levels_B, (1, opt.batch_size, 1)), [-1, 2]) + + # While the crop locations are randomized, the negative samples should + # not come from the same location. To do this, we precompute the + # crop locations with no repetition. + self.patch_indices_A = list(range(len(self))) + random.shuffle(self.patch_indices_A) + self.patch_indices_B = list(range(len(self))) + random.shuffle(self.patch_indices_B) + + def __getitem__(self, index): + """Return a data point and its metadata information. 
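+        During training, A and B are random zoom-and-crop patches (optionally flipped) of the two source images.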
+ + Parameters: + index (int) -- a random integer for data indexing + + Returns a dictionary that contains A, B, A_paths and B_paths + A (tensor) -- an image in the input domain + B (tensor) -- its corresponding image in the target domain + A_paths (str) -- image paths + B_paths (str) -- image paths + """ + A_path = self.A_paths[0] + B_path = self.B_paths[0] + A_img = self.A_img + B_img = self.B_img + + # apply image transformation + if self.opt.phase == "train": + param = {'scale_factor': self.zoom_levels_A[index], + 'patch_index': self.patch_indices_A[index], + 'flip': random.random() > 0.5} + + transform_A = get_transform(self.opt, params=param, method=Image.BILINEAR) + A = transform_A(A_img) + + param = {'scale_factor': self.zoom_levels_B[index], + 'patch_index': self.patch_indices_B[index], + 'flip': random.random() > 0.5} + transform_B = get_transform(self.opt, params=param, method=Image.BILINEAR) + B = transform_B(B_img) + else: + transform = get_transform(self.opt, method=Image.BILINEAR) + A = transform(A_img) + B = transform(B_img) + + return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path} + + def __len__(self): + """ Let's pretend the single image contains 100,000 crops for convenience. + """ + return 100000 diff --git a/datasets/single_image_monet_etretat/trainA/monet.jpg b/datasets/single_image_monet_etretat/trainA/monet.jpg new file mode 100644 index 00000000..738c1cd8 Binary files /dev/null and b/datasets/single_image_monet_etretat/trainA/monet.jpg differ diff --git a/datasets/single_image_monet_etretat/trainB/etretat-normandy-france.jpg b/datasets/single_image_monet_etretat/trainB/etretat-normandy-france.jpg new file mode 100644 index 00000000..41aabf6c Binary files /dev/null and b/datasets/single_image_monet_etretat/trainB/etretat-normandy-france.jpg differ diff --git a/experiments/singleimage_launcher.py b/experiments/singleimage_launcher.py new file mode 100644 index 00000000..5d286c91 --- /dev/null +++ b/experiments/singleimage_launcher.py @@ -0,0 +1,18 @@ +from .tmux_launcher import Options, TmuxLauncher + + +class Launcher(TmuxLauncher): + def common_options(self): + return [ + Options( + name="singleimage_monet_etretat", + dataroot="./datasets/single_image_monet_etretat", + model="sincut" + ) + ] + + def commands(self): + return ["python train.py " + str(opt) for opt in self.common_options()] + + def test_commands(self): + return ["python test.py " + str(opt) for opt in self.common_options()] diff --git a/models/cut_model.py b/models/cut_model.py index 46e57bd4..9d748057 100644 --- a/models/cut_model.py +++ b/models/cut_model.py @@ -23,9 +23,11 @@ def modify_commandline_options(parser, is_train=True): parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss:GAN(G(X))') parser.add_argument('--lambda_NCE', type=float, default=1.0, help='weight for NCE loss: NCE(G(X), X)') - parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))') parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers') + parser.add_argument('--nce_includes_all_negatives_from_minibatch', + type=util.str2bool, nargs='?', const=True, default=False, + help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. 
Please see models/patchnce.py for more details.') parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map') parser.add_argument('--netF_nc', type=int, default=256) parser.add_argument('--nce_T', type=float, default=0.07, help='temperature for NCE loss') @@ -101,27 +103,31 @@ def data_dependent_initialize(self): self.real_B = self.real_B[:bs_per_gpu] self.forward() # compute fake images: G(A) if self.opt.isTrain: - self.backward_D() # calculate gradients for D - self.backward_G() # calculate graidents for G + self.compute_D_loss().backward() # calculate gradients for D + self.compute_G_loss().backward() # calculate graidents for G if self.opt.lambda_NCE > 0.0: self.optimizer_F = torch.optim.Adam(self.netF.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, self.opt.beta2)) self.optimizers.append(self.optimizer_F) def optimize_parameters(self): # forward - self.forward() # compute fake images: G(A) + self.forward() + # update D - self.set_requires_grad(self.netD, True) # enable backprop for D - self.optimizer_D.zero_grad() # set D's gradients to zero - self.backward_D() # calculate gradients for D - self.optimizer_D.step() # update D's weights + self.set_requires_grad(self.netD, True) + self.optimizer_D.zero_grad() + self.loss_D = self.compute_D_loss() + self.loss_D.backward() + self.optimizer_D.step() + # update G - self.set_requires_grad(self.netD, False) # D requires no gradients when optimizing G - self.optimizer_G.zero_grad() # set G's gradients to zero + self.set_requires_grad(self.netD, False) + self.optimizer_G.zero_grad() if self.opt.netF == 'mlp_sample': self.optimizer_F.zero_grad() - self.backward_G() # calculate graidents for G - self.optimizer_G.step() # udpate G's weights + self.loss_G = self.compute_G_loss() + self.loss_G.backward() + self.optimizer_G.step() if self.opt.netF == 'mlp_sample': self.optimizer_F.step() @@ -138,7 +144,7 @@ def set_input(self, input): def forward(self): """Run forward pass; called by both functions and .""" - self.real = torch.cat((self.real_A, self.real_B), dim=0) if self.opt.nce_idt else self.real_A + self.real = torch.cat((self.real_A, self.real_B), dim=0) if self.opt.nce_idt and self.opt.isTrain else self.real_A if self.opt.flip_equivariance: self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5) if self.flipped_for_equivariance: @@ -149,25 +155,22 @@ def forward(self): if self.opt.nce_idt: self.idt_B = self.fake[self.real_A.size(0):] - def backward_D(self): - if self.opt.lambda_GAN > 0.0: - """Calculate GAN loss for the discriminator""" - fake = self.fake_B.detach() - # Fake; stop backprop to the generator by detaching fake_B - pred_fake = self.netD(fake) - self.loss_D_fake = self.criterionGAN(pred_fake, False).mean() - # Real - pred_real = self.netD(self.real_B) - loss_D_real_unweighted = self.criterionGAN(pred_real, True) - self.loss_D_real = loss_D_real_unweighted.mean() - - # combine loss and calculate gradients - self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 - self.loss_D.backward() - else: - self.loss_D_real, self.loss_D_fake, self.loss_D = 0.0, 0.0, 0.0 - - def backward_G(self): + def compute_D_loss(self): + """Calculate GAN loss for the discriminator""" + fake = self.fake_B.detach() + # Fake; stop backprop to the generator by detaching fake_B + pred_fake = self.netD(fake) + self.loss_D_fake = self.criterionGAN(pred_fake, False).mean() + # Real + self.pred_real = self.netD(self.real_B) + loss_D_real = 
self.criterionGAN(self.pred_real, True) + self.loss_D_real = loss_D_real.mean() + + # combine loss and calculate gradients + self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 + return self.loss_D + + def compute_G_loss(self): """Calculate GAN and NCE loss for the generator""" fake = self.fake_B # First, G(A) should fake the discriminator @@ -189,8 +192,7 @@ def backward_G(self): loss_NCE_both = self.loss_NCE self.loss_G = self.loss_G_GAN + loss_NCE_both - - self.loss_G.backward() + return self.loss_G def calculate_NCE_loss(self, src, tgt): n_layers = len(self.nce_layers) diff --git a/models/networks.py b/models/networks.py index 4c80bac1..88fb161c 100644 --- a/models/networks.py +++ b/models/networks.py @@ -257,9 +257,9 @@ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, in elif netG == 'unet_256': net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout) elif netG == 'stylegan2': - net = StyleGAN2Generator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, opt=opt) + net = StyleGAN2Generator(input_nc, output_nc, ngf, use_dropout=use_dropout, opt=opt) elif netG == 'smallstylegan2': - net = StyleGAN2Generator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=2, opt=opt) + net = StyleGAN2Generator(input_nc, output_nc, ngf, use_dropout=use_dropout, n_blocks=2, opt=opt) elif netG == 'resnet_cat': n_blocks = 8 net = G_Resnet(input_nc, output_nc, opt.nz, num_downs=2, n_res=n_blocks - 4, ngf=ngf, norm='inst', nl_layer='relu') @@ -323,12 +323,8 @@ def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal' net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, no_antialias=no_antialias,) elif netD == 'pixel': # classify if each pixel is real or fake net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer) - elif netD == "patch": - net = PatchDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, no_antialias=no_antialias) - elif netD == "tilestylegan2": - net = TileStyleGAN2Discriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, no_antialias=no_antialias, size=opt.D_patch_size, opt=opt) elif 'stylegan2' in netD: - net = StyleGAN2Discriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, no_antialias=no_antialias, opt=opt) + net = StyleGAN2Discriminator(input_nc, ndf, n_layers_D, no_antialias=no_antialias, opt=opt) else: raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD) return init_net(net, init_type, init_gain, gpu_ids, diff --git a/models/patchnce.py b/models/patchnce.py index 1c7bf76d..14b53d9c 100644 --- a/models/patchnce.py +++ b/models/patchnce.py @@ -19,10 +19,24 @@ def forward(self, feat_q, feat_k): l_pos = torch.bmm(feat_q.view(batchSize, 1, -1), feat_k.view(batchSize, -1, 1)) l_pos = l_pos.view(batchSize, 1) - # neg logit -- current batch + # neg logit + + # Should the negatives from the other samples of a minibatch be utilized? + # In CUT and FastCUT, we found that it's best to only include negatives + # from the same image. Therefore, we set + # --nce_includes_all_negatives_from_minibatch as False + # However, for single-image translation, the minibatch consists of + # crops from the "same" high-resolution image. + # Therefore, we will include the negatives from the entire minibatch. + if self.opt.nce_includes_all_negatives_from_minibatch: + # reshape features as if they are all negatives of minibatch of size 1. 
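+            # i.e. every sampled patch in the minibatch then serves as a negative for every other patch.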
+ batch_dim_for_bmm = 1 + else: + batch_dim_for_bmm = self.opt.batch_size + # reshape features to batch size - feat_q = feat_q.view(self.opt.batch_size, -1, dim) - feat_k = feat_k.view(self.opt.batch_size, -1, dim) + feat_q = feat_q.view(batch_dim_for_bmm, -1, dim) + feat_k = feat_k.view(batch_dim_for_bmm, -1, dim) npatches = feat_q.size(1) l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1)) diff --git a/models/sincut_model.py b/models/sincut_model.py new file mode 100644 index 00000000..7e54bcc9 --- /dev/null +++ b/models/sincut_model.py @@ -0,0 +1,79 @@ +import torch +from .cut_model import CUTModel + + +class SinCUTModel(CUTModel): + """ This class implements the single image translation model (Fig 9) of + Contrastive Learning for Unpaired Image-to-Image Translation + Taesung Park, Alexei A. Efros, Richard Zhang, Jun-Yan Zhu + ECCV, 2020 + """ + + @staticmethod + def modify_commandline_options(parser, is_train=True): + parser = CUTModel.modify_commandline_options(parser, is_train) + parser.add_argument('--lambda_R1', type=float, default=1.0, + help='weight for the R1 gradient penalty') + parser.add_argument('--lambda_identity', type=float, default=1.0, + help='the "identity preservation loss"') + + parser.set_defaults(nce_includes_all_negatives_from_minibatch=True, + dataset_mode="singleimage", + netG="stylegan2", + stylegan2_G_num_downsampling=1, + netD="stylegan2", + gan_mode="nonsaturating", + num_patches=1, + nce_layers="0,2,4", + lambda_NCE=4.0, + ngf=10, + ndf=8, + lr=0.002, + beta1=0.0, + beta2=0.99, + load_size=1024, + crop_size=64, + preprocess="zoom_and_patch", + ) + + if is_train: + parser.set_defaults(preprocess="zoom_and_patch", + batch_size=16, + save_epoch_freq=1, + save_latest_freq=20000, + n_epochs=8, + n_epochs_decay=8, + + ) + else: + parser.set_defaults(preprocess="none", # load the whole image as it is + batch_size=1, + num_test=1, + ) + + return parser + + def __init__(self, opt): + super().__init__(opt) + if self.isTrain: + if opt.lambda_R1 > 0.0: + self.loss_names += ['D_R1'] + if opt.lambda_identity > 0.0: + self.loss_names += ['idt'] + + def compute_D_loss(self): + self.real_B.requires_grad_() + GAN_loss_D = super().compute_D_loss() + self.loss_D_R1 = self.R1_loss(self.pred_real, self.real_B) + self.loss_D = GAN_loss_D + self.loss_D_R1 + return self.loss_D + + def compute_G_loss(self): + CUT_loss_G = super().compute_G_loss() + self.loss_idt = torch.nn.functional.l1_loss(self.idt_B, self.real_B) * self.opt.lambda_identity + return CUT_loss_G + self.loss_idt + + def R1_loss(self, real_pred, real_img): + grad_real, = torch.autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True, retain_graph=True) + grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean() + return grad_penalty * (self.opt.lambda_R1 * 0.5) diff --git a/models/stylegan_networks.py b/models/stylegan_networks.py index 230dc11e..a3c625da 100644 --- a/models/stylegan_networks.py +++ b/models/stylegan_networks.py @@ -693,7 +693,7 @@ def forward(self, input): class StyleGAN2Discriminator(nn.Module): - def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, no_antialias=False, size=None, opt=None): + def __init__(self, input_nc, ndf=64, n_layers=3, no_antialias=False, size=None, opt=None): super().__init__() self.opt = opt self.stddev_group = 16 @@ -795,7 +795,7 @@ def forward(self, input): class StyleGAN2Encoder(nn.Module): - def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, 
padding_type='reflect', no_antialias=False, opt=None): + def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None): super().__init__() assert opt is not None self.opt = opt @@ -818,18 +818,16 @@ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_d convs = [nn.Identity(), ConvLayer(3, channels[cur_res], 1)] - num_downsampling = self.opt.G_n_downsampling + num_downsampling = self.opt.stylegan2_G_num_downsampling for i in range(num_downsampling): in_channel = channels[cur_res] out_channel = channels[cur_res // 2] - convs.append(ResBlock(in_channel, out_channel, blur_kernel, downsample=True, - skip_gain=opt.resnet_skip_gain)) + convs.append(ResBlock(in_channel, out_channel, blur_kernel, downsample=True)) cur_res = cur_res // 2 for i in range(n_blocks // 2): n_channel = channels[cur_res] - convs.append(ResBlock(n_channel, n_channel, downsample=False, - skip_gain=opt.resnet_skip_gain)) + convs.append(ResBlock(n_channel, n_channel, downsample=False)) self.convs = nn.Sequential(*convs) @@ -851,7 +849,7 @@ def forward(self, input, layers=[], get_features=False): class StyleGAN2Decoder(nn.Module): - def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None): + def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None): super().__init__() assert opt is not None self.opt = opt @@ -871,13 +869,13 @@ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_d 1024: int(round(16 * channel_multiplier)), } - num_downsampling = self.opt.G_n_downsampling + num_downsampling = self.opt.stylegan2_G_num_downsampling cur_res = 2 ** int((np.rint(np.log2(min(opt.load_size, opt.crop_size))))) // (2 ** num_downsampling) convs = [] for i in range(n_blocks // 2): n_channel = channels[cur_res] - convs.append(ResBlock(n_channel, n_channel, downsample=False, skip_gain=opt.resnet_skip_gain)) + convs.append(ResBlock(n_channel, n_channel, downsample=False)) for i in range(num_downsampling): in_channel = channels[cur_res] @@ -897,14 +895,11 @@ def forward(self, input): class StyleGAN2Generator(nn.Module): - def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None): + def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None): super().__init__() self.opt = opt - self.encoder = StyleGAN2Encoder(input_nc, output_nc, ngf, norm_layer, use_dropout, n_blocks, padding_type, no_antialias, opt) - self.decoder = StyleGAN2Decoder(input_nc, output_nc, ngf, norm_layer, use_dropout, n_blocks, padding_type, no_antialias, opt) - - if self.opt.G_learns_residual: - self.residual_gain = nn.Parameter(torch.from_numpy(np.array([1, -1], dtype=np.float32)).cuda()) + self.encoder = StyleGAN2Encoder(input_nc, output_nc, ngf, use_dropout, n_blocks, padding_type, no_antialias, opt) + self.decoder = StyleGAN2Decoder(input_nc, output_nc, ngf, use_dropout, n_blocks, padding_type, no_antialias, opt) def forward(self, input, layers=[], encode_only=False): feat, feats = self.encoder(input, layers, True) @@ -912,9 +907,6 @@ def forward(self, input, layers=[], encode_only=False): return feats else: fake = self.decoder(feat) - if self.opt.G_learns_residual: - ratio = F.softmax(self.residual_gain, 
dim=0) - feat = input * ratio[0] + fake * ratio[1] if len(layers) > 0: return fake, feats diff --git a/options/base_options.py b/options/base_options.py index 1b69cba1..e6d72e34 100644 --- a/options/base_options.py +++ b/options/base_options.py @@ -57,11 +57,18 @@ def initialize(self, parser): parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]') parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML') + parser.add_argument('--random_scale_max', type=float, default=3.0, + help='(used for single image translation) Randomly scale the image by the specified factor as data augmentation.') # additional parameters parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') + # parameters related to StyleGAN2-based networks + parser.add_argument('--stylegan2_G_num_downsampling', + default=1, type=int, + help='Number of downsampling layers used by StyleGAN2Generator') + self.initialized = True return parser
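
A note on how the new `--nce_includes_all_negatives_from_minibatch` flag interacts with the SinCUT defaults set in `models/sincut_model.py` (`batch_size=16`, `num_patches=1`): because every minibatch element is a crop of the same image, reshaping the sampled features into a single batch element lets each crop's patches act as negatives for all the others. The sketch below is not the repository's `PatchNCELoss`; it only mimics the reshaping step from `models/patchnce.py`, with shapes chosen to match those defaults for illustration.

```python
# Illustrative sketch of the negative-reshaping step in models/patchnce.py
# (shapes follow the SinCUT defaults: batch_size=16, num_patches=1, feature dim=256).
import torch

def negative_logits(feat_q, feat_k, batch_size, include_all_negatives_from_minibatch):
    dim = feat_q.shape[1]
    # Treat the whole minibatch as one set of patches (SinCUT) or keep each image's patches separate (CUT/FastCUT).
    batch_dim_for_bmm = 1 if include_all_negatives_from_minibatch else batch_size
    feat_q = feat_q.view(batch_dim_for_bmm, -1, dim)
    feat_k = feat_k.view(batch_dim_for_bmm, -1, dim)
    # Within each batch element, every query patch is compared against every key patch.
    return torch.bmm(feat_q, feat_k.transpose(2, 1))

feat_q = torch.randn(16 * 1, 256)  # (batch_size * num_patches, dim)
feat_k = torch.randn(16 * 1, 256)
print(negative_logits(feat_q, feat_k, 16, False).shape)  # torch.Size([16, 1, 1])  -> each patch is compared only with itself
print(negative_logits(feat_q, feat_k, 16, True).shape)   # torch.Size([1, 16, 16]) -> negatives come from all 16 crops
```

With these defaults, leaving the flag off would give each sampled patch no negatives beyond itself, which appears to be why `sincut_model.py` enables `nce_includes_all_negatives_from_minibatch` by default.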