
Commit

add stylespace mode to optimization and mapper; add id loss to optimization
orpatashnik committed Aug 14, 2021
1 parent 0d17712 commit ea80a46
Showing 11 changed files with 248 additions and 62 deletions.
1 change: 1 addition & 0 deletions criteria/id_loss.py
@@ -13,6 +13,7 @@ def __init__(self, opts):
self.pool = torch.nn.AdaptiveAvgPool2d((256, 256))
self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
self.facenet.eval()
self.facenet.cuda()
self.opts = opts

def extract_feats(self, x):
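Note: the only change here moves the ArcFace backbone to the GPU at construction time; the ID loss itself (used by the new optimization mode, per the commit message) compares face embeddings of the edited and source images. A minimal sketch of that term, assuming `extract_feats` returns L2-normalized embeddings as in this module:

```python
import torch

def id_loss_term(id_loss_module, x_hat, x):
    # Cosine distance between face embeddings of the edited image (x_hat)
    # and the source image (x); embeddings are assumed L2-normalized.
    feats_hat = id_loss_module.extract_feats(x_hat)
    feats = id_loss_module.extract_feats(x).detach()
    return (1 - (feats_hat * feats).sum(dim=1)).mean()
```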
22 changes: 22 additions & 0 deletions mapper/datasets/latents_dataset.py
@@ -1,3 +1,4 @@
import torch
from torch.utils.data import Dataset


@@ -13,3 +14,24 @@ def __len__(self):
def __getitem__(self, index):

return self.latents[index]

class StyleSpaceLatentsDataset(Dataset):

def __init__(self, latents, opts):
padded_latents = []
for latent in latents:
latent = latent.cpu()
if latent.shape[2] == 512:
padded_latents.append(latent)
else:
padding = torch.zeros((latent.shape[0], 1, 512 - latent.shape[2], 1, 1))
padded_latent = torch.cat([latent, padding], dim=2)
padded_latents.append(padded_latent)
self.latents = torch.cat(padded_latents, dim=2)
self.opts = opts

def __len__(self):
return len(self.latents)

def __getitem__(self, index):
return self.latents[index]
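Note: `StyleSpaceLatentsDataset` flattens the 26 S-space tensors into a single tensor by zero-padding each layer's channel dimension to 512 and concatenating along dim 2. A quick shape check, assuming each latent has shape `(N, 1, C, 1, 1)` as the padding code implies:

```python
import torch
from mapper.datasets.latents_dataset import StyleSpaceLatentsDataset

STYLESPACE_DIMENSIONS = [512] * 15 + [256] * 3 + [128] * 3 + [64] * 3 + [32] * 2

N = 4  # number of samples
latents = [torch.randn(N, 1, c, 1, 1) for c in STYLESPACE_DIMENSIONS]

dataset = StyleSpaceLatentsDataset(latents=latents, opts=None)
print(dataset.latents.shape)  # torch.Size([4, 1, 13312, 1, 1]); 26 * 512 = 13312
print(len(dataset))           # 4
```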
51 changes: 49 additions & 2 deletions mapper/latent_mappers.py
@@ -4,10 +4,12 @@

from models.stylegan2.model import EqualLinear, PixelNorm

STYLESPACE_DIMENSIONS = [512 for _ in range(15)] + [256, 256, 256] + [128, 128, 128] + [64, 64, 64] + [32, 32]


class Mapper(Module):

def __init__(self, opts):
def __init__(self, opts, latent_dim=512):
super(Mapper, self).__init__()

self.opts = opts
@@ -16,7 +18,7 @@ def __init__(self, opts):
for i in range(4):
layers.append(
EqualLinear(
512, 512, lr_mul=0.01, activation='fused_lrelu'
latent_dim, latent_dim, lr_mul=0.01, activation='fused_lrelu'
)
)

@@ -79,3 +81,48 @@ def forward(self, x):

return out

class FullStyleSpaceMapper(Module):

def __init__(self, opts):
super(FullStyleSpaceMapper, self).__init__()

self.opts = opts

for c, c_dim in enumerate(STYLESPACE_DIMENSIONS):
setattr(self, f"mapper_{c}", Mapper(opts, latent_dim=c_dim))

def forward(self, x):
out = []
for c, x_c in enumerate(x):
curr_mapper = getattr(self, f"mapper_{c}")
x_c_res = curr_mapper(x_c.view(x_c.shape[0], -1)).view(x_c.shape)
out.append(x_c_res)

return out


class WithoutToRGBStyleSpaceMapper(Module):

def __init__(self, opts):
super(WithoutToRGBStyleSpaceMapper, self).__init__()

self.opts = opts

indices_without_torgb = list(range(1, len(STYLESPACE_DIMENSIONS), 3))
self.STYLESPACE_INDICES_WITHOUT_TORGB = [i for i in range(len(STYLESPACE_DIMENSIONS)) if i not in indices_without_torgb]

for c in self.STYLESPACE_INDICES_WITHOUT_TORGB:
setattr(self, f"mapper_{c}", Mapper(opts, latent_dim=STYLESPACE_DIMENSIONS[c]))

def forward(self, x):
out = []
for c in range(len(STYLESPACE_DIMENSIONS)):
x_c = x[c]
if c in self.STYLESPACE_INDICES_WITHOUT_TORGB:
curr_mapper = getattr(self, f"mapper_{c}")
x_c_res = curr_mapper(x_c.view(x_c.shape[0], -1)).view(x_c.shape)
else:
x_c_res = torch.zeros_like(x_c)
out.append(x_c_res)

return out
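Note on the index arithmetic: despite its name, `indices_without_torgb` collects the toRGB positions (every third style vector, starting at index 1), and `STYLESPACE_INDICES_WITHOUT_TORGB` is its complement, i.e. the conv styles that actually receive a mapper. A quick check against the constants above:

```python
STYLESPACE_DIMENSIONS = [512] * 15 + [256] * 3 + [128] * 3 + [64] * 3 + [32] * 2

torgb = list(range(1, len(STYLESPACE_DIMENSIONS), 3))
conv = [i for i in range(len(STYLESPACE_DIMENSIONS)) if i not in torgb]

print(torgb)      # [1, 4, 7, 10, 13, 16, 19, 22, 25] -> the 9 toRGB styles
print(len(conv))  # 17 conv styles, one Mapper each
```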
1 change: 1 addition & 0 deletions mapper/options/test_options.py
@@ -23,6 +23,7 @@ def initialize(self):
self.parser.add_argument('--test_batch_size', default=2, type=int, help='Batch size for testing and inference')
self.parser.add_argument('--latents_test_path', default=None, type=str, help="The latents for the validation")
self.parser.add_argument('--test_workers', default=2, type=int, help='Number of test/inference dataloader workers')
self.parser.add_argument('--work_in_stylespace', default=False, action='store_true')

self.parser.add_argument('--n_images', type=int, default=None, help='Number of images to output. If None, run on all data')

1 change: 1 addition & 0 deletions mapper/options/train_options.py
@@ -17,6 +17,7 @@ def initialize(self):
self.parser.add_argument('--latents_test_path', default="test_faces.pt", type=str, help="The latents for the validation")
self.parser.add_argument('--train_dataset_size', default=5000, type=int, help="Will be used only if no latents are given")
self.parser.add_argument('--test_dataset_size', default=1000, type=int, help="Will be used only if no latents are given")
self.parser.add_argument('--work_in_stylespace', default=False, action='store_true', help="trains a mapper in S instead of W+")

self.parser.add_argument('--batch_size', default=2, type=int, help='Batch size for training')
self.parser.add_argument('--test_batch_size', default=1, type=int, help='Batch size for testing and inference')
35 changes: 26 additions & 9 deletions mapper/scripts/inference.py
@@ -10,10 +10,12 @@

from tqdm import tqdm

from mapper.training.train_utils import convert_s_tensor_to_list

sys.path.append(".")
sys.path.append("..")

from mapper.datasets.latents_dataset import LatentsDataset
from mapper.datasets.latents_dataset import LatentsDataset, StyleSpaceLatentsDataset

from mapper.options.test_options import TestOptions
from mapper.styleclip_mapper import StyleCLIPMapper
@@ -34,8 +36,10 @@ def run(test_opts):
net.cuda()

test_latents = torch.load(opts.latents_test_path)
dataset = LatentsDataset(latents=test_latents.cpu(),
opts=opts)
if opts.work_in_stylespace:
dataset = StyleSpaceLatentsDataset(latents=[l.cpu() for l in test_latents], opts=opts)
else:
dataset = LatentsDataset(latents=test_latents.cpu(), opts=opts)
dataloader = DataLoader(dataset,
batch_size=opts.test_batch_size,
shuffle=False,
@@ -51,9 +55,15 @@
if global_i >= opts.n_images:
break
with torch.no_grad():
input_cuda = input_batch.cuda().float()
if opts.work_in_stylespace:
input_cuda = convert_s_tensor_to_list(input_batch)
input_cuda = [c.cuda() for c in input_cuda]
else:
input_cuda = input_batch
input_cuda = input_cuda.cuda()

tic = time.time()
result_batch = run_on_batch(input_cuda, net, test_opts.couple_outputs)
result_batch = run_on_batch(input_cuda, net, opts.couple_outputs, opts.work_in_stylespace)
toc = time.time()
global_time.append(toc - tic)

@@ -76,14 +86,21 @@
f.write(result_str)


def run_on_batch(inputs, net, couple_outputs=False):
def run_on_batch(inputs, net, couple_outputs=False, stylespace=False):
w = inputs
with torch.no_grad():
w_hat = w + 0.1 * net.mapper(w)
x_hat, w_hat = net.decoder([w_hat], input_is_latent=True, return_latents=True, randomize_noise=False, truncation=1)
if stylespace:
delta = net.mapper(w)
w_hat = [c + 0.1 * delta_c for (c, delta_c) in zip(w, delta)]
x_hat, _, w_hat = net.decoder([w_hat], input_is_latent=True, return_latents=True,
randomize_noise=False, truncation=1, input_is_stylespace=True)
else:
w_hat = w + 0.1 * net.mapper(w)
x_hat, w_hat, _ = net.decoder([w_hat], input_is_latent=True, return_latents=True,
randomize_noise=False, truncation=1)
result_batch = (x_hat, w_hat)
if couple_outputs:
x, _ = net.decoder([w], input_is_latent=True, randomize_noise=False, truncation=1)
x, _ = net.decoder([w], input_is_latent=True, randomize_noise=False, truncation=1, input_is_stylespace=stylespace)
result_batch = (x_hat, w_hat, x)
return result_batch
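A hypothetical driver for the new S-space path (`net` and `batch` stand in for a loaded StyleCLIPMapper and a StyleSpaceLatentsDataset batch; note that the decoder calls now unpack three return values, consistent with a modified StyleGAN2 forward that also returns style-space latents):

```python
from mapper.training.train_utils import convert_s_tensor_to_list

# batch: (B, 1, 26 * 512, 1, 1) tensor from a StyleSpaceLatentsDataset loader
s_list = [c.cuda() for c in convert_s_tensor_to_list(batch)]  # 26 per-layer tensors
x_hat, s_hat, x = run_on_batch(s_list, net, couple_outputs=True, stylespace=True)
```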

4 changes: 3 additions & 1 deletion mapper/styleclip_mapper.py
@@ -24,7 +24,9 @@ def __init__(self, opts):
self.load_weights()

def set_mapper(self):
if self.opts.mapper_type == 'SingleMapper':
if self.opts.work_in_stylespace:
mapper = latent_mappers.WithoutToRGBStyleSpaceMapper(self.opts)
elif self.opts.mapper_type == 'SingleMapper':
mapper = latent_mappers.SingleMapper(self.opts)
elif self.opts.mapper_type == 'LevelsMapper':
mapper = latent_mappers.LevelsMapper(self.opts)
62 changes: 46 additions & 16 deletions mapper/training/coach.py
@@ -9,10 +9,11 @@

import criteria.clip_loss as clip_loss
from criteria import id_loss
from mapper.datasets.latents_dataset import LatentsDataset
from mapper.datasets.latents_dataset import LatentsDataset, StyleSpaceLatentsDataset
from mapper.styleclip_mapper import StyleCLIPMapper
from mapper.training.ranger import Ranger
from mapper.training import train_utils
from mapper.training.train_utils import convert_s_tensor_to_list


class Coach:
@@ -71,12 +72,21 @@ def train(self):
while self.global_step < self.opts.max_steps:
for batch_idx, batch in enumerate(self.train_dataloader):
self.optimizer.zero_grad()
w = batch
w = w.to(self.device)
if self.opts.work_in_stylespace:
w = convert_s_tensor_to_list(batch)
w = [c.to(self.device) for c in w]
else:
w = batch
w = w.to(self.device)
with torch.no_grad():
x, _ = self.net.decoder([w], input_is_latent=True, randomize_noise=False, truncation=1)
w_hat = w + 0.1 * self.net.mapper(w)
x_hat, w_hat = self.net.decoder([w_hat], input_is_latent=True, return_latents=True, randomize_noise=False, truncation=1)
x, _ = self.net.decoder([w], input_is_latent=True, randomize_noise=False, truncation=1, input_is_stylespace=self.opts.work_in_stylespace)
if self.opts.work_in_stylespace:
delta = self.net.mapper(w)
w_hat = [c + 0.1 * delta_c for (c, delta_c) in zip(w, delta)]
x_hat, _, w_hat = self.net.decoder([w_hat], input_is_latent=True, return_latents=True, randomize_noise=False, truncation=1, input_is_stylespace=True)
else:
w_hat = w + 0.1 * self.net.mapper(w)
x_hat, w_hat, _ = self.net.decoder([w_hat], input_is_latent=True, return_latents=True, randomize_noise=False, truncation=1)
loss, loss_dict = self.calc_loss(w, x, w_hat, x_hat)
loss.backward()
self.optimizer.step()
@@ -116,13 +126,22 @@ def validate(self):
if batch_idx > 200:
break

w = batch
if self.opts.work_in_stylespace:
w = convert_s_tensor_to_list(batch)
w = [c.to(self.device) for c in w]
else:
w = batch
w = w.to(self.device)

with torch.no_grad():
w = w.to(self.device).float()
x, _ = self.net.decoder([w], input_is_latent=True, randomize_noise=True, truncation=1)
w_hat = w + 0.1 * self.net.mapper(w)
x_hat, _ = self.net.decoder([w_hat], input_is_latent=True, randomize_noise=True, truncation=1)
x, _ = self.net.decoder([w], input_is_latent=True, randomize_noise=False, truncation=1, input_is_stylespace=self.opts.work_in_stylespace)
if self.opts.work_in_stylespace:
delta = self.net.mapper(w)
w_hat = [c + 0.1 * delta_c for (c, delta_c) in zip(w, delta)]
x_hat, _, w_hat = self.net.decoder([w_hat], input_is_latent=True, return_latents=True, randomize_noise=False, truncation=1, input_is_stylespace=True)
else:
w_hat = w + 0.1 * self.net.mapper(w)
x_hat, w_hat, _ = self.net.decoder([w_hat], input_is_latent=True, return_latents=True, randomize_noise=False, truncation=1)
loss, cur_loss_dict = self.calc_loss(w, x, w_hat, x_hat)
agg_loss_dict.append(cur_loss_dict)

@@ -185,10 +204,16 @@ def configure_datasets(self):
test_latents.append(test_latents_b)
test_latents = torch.cat(test_latents)

train_dataset_celeba = LatentsDataset(latents=train_latents.cpu(),
opts=self.opts)
test_dataset_celeba = LatentsDataset(latents=test_latents.cpu(),
opts=self.opts)
if self.opts.work_in_stylespace:
train_dataset_celeba = StyleSpaceLatentsDataset(latents=[l.cpu() for l in train_latents],
opts=self.opts)
test_dataset_celeba = StyleSpaceLatentsDataset(latents=[l.cpu() for l in test_latents],
opts=self.opts)
else:
train_dataset_celeba = LatentsDataset(latents=train_latents.cpu(),
opts=self.opts)
test_dataset_celeba = LatentsDataset(latents=test_latents.cpu(),
opts=self.opts)
train_dataset = train_dataset_celeba
test_dataset = test_dataset_celeba
print("Number of training samples: {}".format(len(train_dataset)))
@@ -208,7 +233,12 @@ def calc_loss(self, w, x, w_hat, x_hat):
loss_dict['loss_clip'] = float(loss_clip)
loss += loss_clip * self.opts.clip_lambda
if self.opts.latent_l2_lambda > 0:
loss_l2_latent = self.latent_l2_loss(w_hat, w)
if self.opts.work_in_stylespace:
loss_l2_latent = 0
for c_hat, c in zip(w_hat, w):
loss_l2_latent += self.latent_l2_loss(c_hat, c)
else:
loss_l2_latent = self.latent_l2_loss(w_hat, w)
loss_dict['loss_l2_latent'] = float(loss_l2_latent)
loss += loss_l2_latent * self.opts.latent_l2_lambda
loss_dict['loss'] = float(loss)
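Note: summing per-layer MSE terms weights each of the 26 style vectors equally, regardless of channel width; a single MSE over the concatenated channels would instead weight per channel. A small illustration of the difference (assuming `latent_l2_loss` is a plain MSE):

```python
import torch
import torch.nn.functional as F

dims = [512, 256, 32]  # a few layer widths, for illustration
s = [torch.randn(2, 1, c, 1, 1) for c in dims]
s_hat = [c + 0.01 * torch.randn_like(c) for c in s]

per_layer = sum(F.mse_loss(a, b) for a, b in zip(s_hat, s))      # as in calc_loss
flat = F.mse_loss(torch.cat(s_hat, dim=2), torch.cat(s, dim=2))  # per-channel weighting
print(float(per_layer), float(flat))  # generally differ
```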
8 changes: 8 additions & 0 deletions mapper/training/train_utils.py
@@ -1,3 +1,4 @@
STYLESPACE_DIMENSIONS = [512 for _ in range(15)] + [256, 256, 256] + [128, 128, 128] + [64, 64, 64] + [32, 32]

def aggregate_loss_dict(agg_loss_dict):
mean_vals = {}
@@ -11,3 +12,10 @@ def aggregate_loss_dict(agg_loss_dict):
print('{} has no value'.format(key))
mean_vals[key] = 0
return mean_vals


def convert_s_tensor_to_list(batch):
s_list = []
for i in range(len(STYLESPACE_DIMENSIONS)):
s_list.append(batch[:, :, 512 * i: 512 * i + STYLESPACE_DIMENSIONS[i]])
return s_list
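Note: `convert_s_tensor_to_list` inverts the padding done by `StyleSpaceLatentsDataset`: layer `i` starts at channel offset `512 * i`, and only the first `STYLESPACE_DIMENSIONS[i]` channels are kept, so the zero padding is dropped. Continuing the dataset sketch above:

```python
from mapper.training.train_utils import convert_s_tensor_to_list

batch = dataset.latents[:2]  # (2, 1, 13312, 1, 1), as a dataloader would yield
s_list = convert_s_tensor_to_list(batch)
assert [s.shape[2] for s in s_list] == STYLESPACE_DIMENSIONS
```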
