forked from orpatashnik/StyleCLIP
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9645102
commit d5e3c44
Showing
32 changed files
with
1,403 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import torch | ||
from torch import nn | ||
|
||
from models.facial_recognition.model_irse import Backbone | ||
|
||
|
||
class IDLoss(nn.Module): | ||
def __init__(self, opts): | ||
super(IDLoss, self).__init__() | ||
print('Loading ResNet ArcFace') | ||
self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') | ||
self.facenet.load_state_dict(torch.load(opts.ir_se50_weights)) | ||
self.pool = torch.nn.AdaptiveAvgPool2d((256, 256)) | ||
self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) | ||
self.facenet.eval() | ||
self.opts = opts | ||
|
||
def extract_feats(self, x): | ||
if x.shape[2] != 256: | ||
x = self.pool(x) | ||
x = x[:, :, 35:223, 32:220] # Crop interesting region | ||
x = self.face_pool(x) | ||
x_feats = self.facenet(x) | ||
return x_feats | ||
|
||
def forward(self, y_hat, y): | ||
n_samples = y.shape[0] | ||
y_feats = self.extract_feats(y) # Otherwise use the feature from there | ||
y_hat_feats = self.extract_feats(y_hat) | ||
y_feats = y_feats.detach() | ||
loss = 0 | ||
sim_improvement = 0 | ||
count = 0 | ||
for i in range(n_samples): | ||
diff_target = y_hat_feats[i].dot(y_feats[i]) | ||
loss += 1 - diff_target | ||
count += 1 | ||
|
||
return loss / count, sim_improvement / count |
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from torch.utils.data import Dataset | ||
|
||
|
||
class LatentsDataset(Dataset): | ||
|
||
def __init__(self, latents, opts): | ||
self.latents = latents | ||
self.opts = opts | ||
|
||
def __len__(self): | ||
return self.latents.shape[0] | ||
|
||
def __getitem__(self, index): | ||
|
||
return self.latents[index] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
import torch | ||
from torch import nn | ||
from torch.nn import Module | ||
|
||
from models.stylegan2.model import EqualLinear, PixelNorm | ||
|
||
|
||
class Mapper(Module): | ||
|
||
def __init__(self, opts): | ||
super(Mapper, self).__init__() | ||
|
||
self.opts = opts | ||
layers = [PixelNorm()] | ||
|
||
for i in range(4): | ||
layers.append( | ||
EqualLinear( | ||
512, 512, lr_mul=0.01, activation='fused_lrelu' | ||
) | ||
) | ||
|
||
self.mapping = nn.Sequential(*layers) | ||
|
||
|
||
def forward(self, x): | ||
x = self.mapping(x) | ||
return x | ||
|
||
|
||
class SingleMapper(Module): | ||
|
||
def __init__(self, opts): | ||
super(SingleMapper, self).__init__() | ||
|
||
self.opts = opts | ||
|
||
self.mapping = Mapper(opts) | ||
|
||
def forward(self, x): | ||
out = self.mapping(x) | ||
return out | ||
|
||
|
||
class LevelsMapper(Module): | ||
|
||
def __init__(self, opts): | ||
super(LevelsMapper, self).__init__() | ||
|
||
self.opts = opts | ||
|
||
if not opts.no_coarse_mapper: | ||
self.course_mapping = Mapper(opts) | ||
if not opts.no_medium_mapper: | ||
self.medium_mapping = Mapper(opts) | ||
if not opts.no_fine_mapper: | ||
self.fine_mapping = Mapper(opts) | ||
|
||
def forward(self, x): | ||
x_coarse = x[:, :4, :] | ||
x_medium = x[:, 4:8, :] | ||
x_fine = x[:, 8:, :] | ||
|
||
if not self.opts.no_coarse_mapper: | ||
x_coarse = self.course_mapping(x_coarse) | ||
else: | ||
x_coarse = torch.zeros_like(x_coarse) | ||
if not self.opts.no_medium_mapper: | ||
x_medium = self.medium_mapping(x_medium) | ||
else: | ||
x_medium = torch.zeros_like(x_medium) | ||
if not self.opts.no_fine_mapper: | ||
x_fine = self.fine_mapping(x_fine) | ||
else: | ||
x_fine = torch.zeros_like(x_fine) | ||
|
||
|
||
out = torch.cat([x_coarse, x_medium, x_fine], dim=1) | ||
|
||
return out | ||
|
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from argparse import ArgumentParser | ||
|
||
|
||
class TestOptions: | ||
|
||
def __init__(self): | ||
self.parser = ArgumentParser() | ||
self.initialize() | ||
|
||
def initialize(self): | ||
# arguments for inference script | ||
self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory') | ||
self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to model checkpoint') | ||
self.parser.add_argument('--couple_outputs', action='store_true', help='Whether to also save inputs + outputs side-by-side') | ||
|
||
self.parser.add_argument('--mapper_type', default='LevelsMapper', type=str, help='Which mapper to use') | ||
self.parser.add_argument('--no_coarse_mapper', default=False, action="store_true") | ||
self.parser.add_argument('--no_medium_mapper', default=False, action="store_true") | ||
self.parser.add_argument('--no_fine_mapper', default=False, action="store_true") | ||
self.parser.add_argument('--stylegan_size', default=1024, type=int) | ||
|
||
|
||
self.parser.add_argument('--test_batch_size', default=2, type=int, help='Batch size for testing and inference') | ||
self.parser.add_argument('--latents_test_path', default=None, type=str, help="The latents for the validation") | ||
self.parser.add_argument('--test_workers', default=2, type=int, help='Number of test/inference dataloader workers') | ||
|
||
self.parser.add_argument('--n_images', type=int, default=None, help='Number of images to output. If None, run on all data') | ||
|
||
def parse(self): | ||
opts = self.parser.parse_args() | ||
return opts |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from argparse import ArgumentParser | ||
|
||
|
||
class TrainOptions: | ||
|
||
def __init__(self): | ||
self.parser = ArgumentParser() | ||
self.initialize() | ||
|
||
def initialize(self): | ||
self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory') | ||
self.parser.add_argument('--mapper_type', default='LevelsMapper', type=str, help='Which mapper to use') | ||
self.parser.add_argument('--no_coarse_mapper', default=False, action="store_true") | ||
self.parser.add_argument('--no_medium_mapper', default=False, action="store_true") | ||
self.parser.add_argument('--no_fine_mapper', default=False, action="store_true") | ||
self.parser.add_argument('--latents_train_path', default="train_faces.pt", type=str, help="The latents for the training") | ||
self.parser.add_argument('--latents_test_path', default="test_faces.pt", type=str, help="The latents for the validation") | ||
self.parser.add_argument('--train_dataset_size', default=5000, type=int, help="Will be used only if no latents are given") | ||
self.parser.add_argument('--test_dataset_size', default=1000, type=int, help="Will be used only if no latents are given") | ||
|
||
self.parser.add_argument('--batch_size', default=2, type=int, help='Batch size for training') | ||
self.parser.add_argument('--test_batch_size', default=1, type=int, help='Batch size for testing and inference') | ||
self.parser.add_argument('--workers', default=4, type=int, help='Number of train dataloader workers') | ||
self.parser.add_argument('--test_workers', default=2, type=int, help='Number of test/inference dataloader workers') | ||
|
||
self.parser.add_argument('--learning_rate', default=0.5, type=float, help='Optimizer learning rate') | ||
self.parser.add_argument('--optim_name', default='ranger', type=str, help='Which optimizer to use') | ||
|
||
self.parser.add_argument('--id_lambda', default=0.1, type=float, help='ID loss multiplier factor') | ||
self.parser.add_argument('--clip_lambda', default=1.0, type=float, help='CLIP loss multiplier factor') | ||
self.parser.add_argument('--latent_l2_lambda', default=0.8, type=float, help='Latent L2 loss multiplier factor') | ||
|
||
self.parser.add_argument('--stylegan_weights', default='../pretrained_models/stylegan2-ffhq-config-f.pt', type=str, help='Path to StyleGAN model weights') | ||
self.parser.add_argument('--stylegan_size', default=1024, type=int) | ||
self.parser.add_argument('--ir_se50_weights', default='../pretrained_models/model_ir_se50.pth', type=str, help="Path to facial recognition network used in ID loss") | ||
self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to StyleCLIPModel model checkpoint') | ||
|
||
self.parser.add_argument('--max_steps', default=50000, type=int, help='Maximum number of training steps') | ||
self.parser.add_argument('--image_interval', default=100, type=int, help='Interval for logging train images during training') | ||
self.parser.add_argument('--board_interval', default=50, type=int, help='Interval for logging metrics to tensorboard') | ||
self.parser.add_argument('--val_interval', default=2000, type=int, help='Validation interval') | ||
self.parser.add_argument('--save_interval', default=2000, type=int, help='Model checkpoint interval') | ||
|
||
self.parser.add_argument('--description', required=True, type=str, help='Driving text prompt') | ||
|
||
|
||
def parse(self): | ||
opts = self.parser.parse_args() | ||
return opts |
Oops, something went wrong.