Commit

add mapper code and notebook

orpatashnik committed Apr 6, 2021
1 parent 9645102 commit d5e3c44
Showing 32 changed files with 1,403 additions and 24 deletions.
55 changes: 50 additions & 5 deletions README.md
@@ -1,7 +1,8 @@
# StyleCLIP: Text-Driven Manipulation of StyleGAN Imagery

Optimization notebook: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/orpatashnik/StyleCLIP/blob/main/optimization_playground.ipynb)
Global directions notebook: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/orpatashnik/StyleCLIP/blob/main/StyleCLIP_global.ipynb)
Optimization notebook: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/orpatashnik/StyleCLIP/blob/main/notebooks/optimization_playground.ipynb)
Global directions notebook: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/orpatashnik/StyleCLIP/blob/main/notebooks/StyleCLIP_global.ipynb)
Mapper notebook: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/orpatashnik/StyleCLIP/blob/main/notebooks/mapper_playground.ipynb)

<p align="center">
<a href="https://www.youtube.com/watch?v=5icI0NgALnQ"><img src='https://github.com/orpatashnik/StyleCLIP/blob/main/img/StyleCLIP_gif.gif' width=600 ></a>
@@ -56,6 +57,8 @@ Currently, the repository contains the code for the optimization and for the glo
The work is still in progress -- stay tuned!

## Updates
**6/4/2021** Add mapper training and inference code (including a Jupyter notebook)

**6/4/2021** Add support for custom StyleGAN2 and StyleGAN2-ada models, and also custom images

**2/4/2021** Add the global directions code (a local GUI and a colab notebook)
@@ -90,7 +93,7 @@ In addition to the requirements mentioned before, a pretrained StyleGAN2 generat
### Usage

Given a textual description, one can either edit a given image or generate a random image that best fits the description.
Both operations can be done through the `main.py` script, or the `optimization_playground.ipynb` notebook ([![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/orpatashnik/StyleCLIP/blob/main/optimization_playground.ipynb)).
Both operations can be done through the `main.py` script, or the `optimization_playground.ipynb` notebook ([![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/orpatashnik/StyleCLIP/blob/main/notebooks/optimization_playground.ipynb)).

#### Editing
To edit an image, set `--mode=edit`. Editing can be done both on a provided latent vector and on a random latent vector from StyleGAN's latent space.
@@ -99,11 +102,51 @@ It is recommended to adjust the `--l2_lambda` according to the desired edit.
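
For example, a hypothetical editing invocation might look like the following; `--mode` and `--l2_lambda` are documented above, while `--description` and the illustrative lambda value are assumptions to be checked against the argument parser in `main.py`:
```bash
# Hypothetical example -- verify flag names against main.py's argument parser.
python main.py --mode=edit --description "a person with purple hair" --l2_lambda 0.008
```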
#### Generating Free-style Images
To generate a free-style image, set `--mode=free_generation`.
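
A corresponding free-generation call might look like this (again, flag names other than `--mode` are assumptions to be checked against `main.py`):
```bash
# Hypothetical example -- only --mode=free_generation is documented above.
python main.py --mode=free_generation --description "a man with blue eyes"
```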

## Editing via Latent Mapper
Here, we provide the code for the latent mapper. The mapper is trained to learn *residuals* from a given latent vector, according to the driving text.
The code for the mapper is in `mapper/`.
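
Conceptually, an edit applies the predicted residual back to the input latent with a small step size. The following is a toy sketch of that idea only; the step size and the exact way the released inference code combines the residual are assumptions:
```python
import torch

# Toy illustration of the residual edit (not repo code).
w = torch.randn(1, 18, 512)         # inverted W+ latent code of the input image (assumed shape)
residual = torch.randn(1, 18, 512)  # stands in for mapper(w), the text-conditioned residual
alpha = 0.1                         # assumed step size; see the inference code for the value actually used
w_edited = w + alpha * residual     # edited latent, fed back into the StyleGAN generator
```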

### Setup
As in the optimization, the code relies on the [Rosinality](https://github.com/rosinality/stylegan2-pytorch/) PyTorch implementation of StyleGAN2.
In addition to the StyleGAN weights, it is necessary to have weights for the facial recognition network used in the ID loss.
The weights can be downloaded from [here](https://drive.google.com/file/d/1KW7bjndL3QG3sxBbZxreGHigcCCpsDgn/view?usp=sharing).

The mapper is trained on latent vectors. It is recommended to train on *inverted real images*.
To this end, we provide the CelebA-HQ dataset that was inverted by e4e:
[train set](https://drive.google.com/file/d/1gof8kYc_gDLUT4wQlmUdAtPnQIlCO26q/view?usp=sharing), [test set](https://drive.google.com/file/d/1j7RIfmrCoisxx3t-r-KC02Qc8barBecr/view?usp=sharing).
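
A minimal sketch of loading these latents and wrapping them with the `LatentsDataset` added in this commit (run from the repository root); the tensor layout of one W+ code per image is an assumption:
```python
import torch
from torch.utils.data import DataLoader

from mapper.datasets.latents_dataset import LatentsDataset

# Assumed layout: a single tensor of inverted W+ codes, one per CelebA-HQ image.
latents = torch.load("train_faces.pt")        # e.g. shape [N, 18, 512]
dataset = LatentsDataset(latents, opts=None)  # opts is stored by the dataset but not used here
loader = DataLoader(dataset, batch_size=2, shuffle=True)

for batch in loader:
    print(batch.shape)                        # torch.Size([2, 18, 512])
    break
```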

### Usage

#### Training
- The main training script is placed in `mapper/scripts/train.py`.
- Training arguments can be found at `mapper/options/train_options.py`.
- Intermediate training results are saved to `opts.exp_dir`. This includes checkpoints, train outputs, and test outputs.
Additionally, if you have tensorboard installed, you can visualize tensorboard logs in `opts.exp_dir/logs` (see the command after this list).
- To resume training, provide `--checkpoint_path`.
- `--description` is where you provide the driving text.
- If you perform an edit that is not supposed to change "colors" in the image, it is recommended to use the flag `--no_fine_mapper`.
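
For example, assuming the experiment directory used in the training command below, the logs can be viewed with:
```bash
tensorboard --logdir ../results/mohawk_hairstyle/logs
```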

Example for training a mapper for the mohawk hairstyle:
```bash
cd mapper
python scripts/train.py --exp_dir ../results/mohawk_hairstyle --no_fine_mapper --description "mohawk hairstyle"
```
All configurations for the examples shown in the paper are provided there.

#### Inference
- The main inference script is placed in `mapper/scripts/inference.py`.
- Inference arguments can be found at `mapper/options/test_options.py`.
- Adding the flag `--couple_outputs` will save an image containing the input and output images side-by-side.
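
An example invocation, based on the arguments defined in `mapper/options/test_options.py` (the checkpoint and latents paths are placeholders):
```bash
cd mapper
python scripts/inference.py \
    --exp_dir ../results/mohawk_inference \
    --checkpoint_path ../results/mohawk_hairstyle/checkpoints/best_model.pt \
    --latents_test_path ../data/test_faces.pt \
    --couple_outputs \
    --n_images 100
```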

Pretrained models for various edits are provided. Please refer to `utils.py` for the complete list of links.

We also provide a notebook for performing inference with the mapper: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/orpatashnik/StyleCLIP/blob/main/notebooks/mapper_playground.ipynb)

## Editing via Global Direction

Here we provide a GUI for editing images with the global directions.
We provide both a jupyter notebook [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/orpatashnik/StyleCLIP/blob/main/StyleCLIP_global.ipynb),
We provide both a jupyter notebook [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/orpatashnik/StyleCLIP/blob/main/notebooks/StyleCLIP_global.ipynb),
and the GUI used in the [video](https://www.youtube.com/watch?v=5icI0NgALnQ).
For both, the linear directions are computed in **real time**.
The code is located at `global/`.
@@ -218,7 +261,9 @@ The driving text that was used for each edit appears below or above each image.

The global directions we find for editing are directions in the _S Space_, which was introduced and analyzed in [StyleSpace](https://arxiv.org/abs/2011.12799) (Wu et al).

To edit real images, we inverted them to the StyleGAN's latent space using [e4e](https://arxiv.org/abs/2102.02766) (Tov et al.).
To edit real images, we inverted them to the StyleGAN's latent space using [e4e](https://arxiv.org/abs/2102.02766) (Tov et al.).

The code structure of the mapper is heavily based on [pSp](https://github.com/eladrich/pixel2style2pixel).

## Citation

Empty file added criteria/__init__.py
Empty file.
4 changes: 2 additions & 2 deletions clip_loss.py → criteria/clip_loss.py
@@ -5,11 +5,11 @@

class CLIPLoss(torch.nn.Module):

def __init__(self):
def __init__(self, opts):
super(CLIPLoss, self).__init__()
self.model, self.preprocess = clip.load("ViT-B/32", device="cuda")
self.upsample = torch.nn.Upsample(scale_factor=7)
self.avg_pool = torch.nn.AvgPool2d(kernel_size=32)
self.avg_pool = torch.nn.AvgPool2d(kernel_size=opts.stylegan_size // 32)

def forward(self, image, text):
image = self.avg_pool(self.upsample(image))
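
The change above makes the CLIP preprocessing independent of the generator resolution: upsampling by a factor of 7 and then average-pooling with kernel `stylegan_size // 32` always yields the 224×224 input CLIP expects (7 × 32 = 224). A quick sanity check, not part of the commit:
```python
import torch

stylegan_size = 256  # works the same for 512, 1024, ...
upsample = torch.nn.Upsample(scale_factor=7)
avg_pool = torch.nn.AvgPool2d(kernel_size=stylegan_size // 32)

img = torch.randn(1, 3, stylegan_size, stylegan_size)
out = avg_pool(upsample(img))
print(out.shape)  # torch.Size([1, 3, 224, 224]), CLIP's expected input resolution
```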
39 changes: 39 additions & 0 deletions criteria/id_loss.py
@@ -0,0 +1,39 @@
import torch
from torch import nn

from models.facial_recognition.model_irse import Backbone


class IDLoss(nn.Module):
def __init__(self, opts):
super(IDLoss, self).__init__()
print('Loading ResNet ArcFace')
self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
self.facenet.load_state_dict(torch.load(opts.ir_se50_weights))
self.pool = torch.nn.AdaptiveAvgPool2d((256, 256))
self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
self.facenet.eval()
self.opts = opts

def extract_feats(self, x):
if x.shape[2] != 256:
x = self.pool(x)
x = x[:, :, 35:223, 32:220] # Crop interesting region
x = self.face_pool(x)
x_feats = self.facenet(x)
return x_feats

def forward(self, y_hat, y):
n_samples = y.shape[0]
y_feats = self.extract_feats(y)  # identity features of the target images
y_hat_feats = self.extract_feats(y_hat)
y_feats = y_feats.detach()
loss = 0
sim_improvement = 0
count = 0
for i in range(n_samples):
diff_target = y_hat_feats[i].dot(y_feats[i])
loss += 1 - diff_target
count += 1

return loss / count, sim_improvement / count
Empty file added mapper/__init__.py
Empty file.
Empty file added mapper/datasets/__init__.py
Empty file.
15 changes: 15 additions & 0 deletions mapper/datasets/latents_dataset.py
@@ -0,0 +1,15 @@
from torch.utils.data import Dataset


class LatentsDataset(Dataset):

def __init__(self, latents, opts):
self.latents = latents
self.opts = opts

def __len__(self):
return self.latents.shape[0]

def __getitem__(self, index):

return self.latents[index]
81 changes: 81 additions & 0 deletions mapper/latent_mappers.py
@@ -0,0 +1,81 @@
import torch
from torch import nn
from torch.nn import Module

from models.stylegan2.model import EqualLinear, PixelNorm


class Mapper(Module):

def __init__(self, opts):
super(Mapper, self).__init__()

self.opts = opts
layers = [PixelNorm()]

for i in range(4):
layers.append(
EqualLinear(
512, 512, lr_mul=0.01, activation='fused_lrelu'
)
)

self.mapping = nn.Sequential(*layers)


def forward(self, x):
x = self.mapping(x)
return x


class SingleMapper(Module):

def __init__(self, opts):
super(SingleMapper, self).__init__()

self.opts = opts

self.mapping = Mapper(opts)

def forward(self, x):
out = self.mapping(x)
return out


class LevelsMapper(Module):

def __init__(self, opts):
super(LevelsMapper, self).__init__()

self.opts = opts

if not opts.no_coarse_mapper:
self.course_mapping = Mapper(opts)
if not opts.no_medium_mapper:
self.medium_mapping = Mapper(opts)
if not opts.no_fine_mapper:
self.fine_mapping = Mapper(opts)

def forward(self, x):
x_coarse = x[:, :4, :]
x_medium = x[:, 4:8, :]
x_fine = x[:, 8:, :]

if not self.opts.no_coarse_mapper:
x_coarse = self.course_mapping(x_coarse)
else:
x_coarse = torch.zeros_like(x_coarse)
if not self.opts.no_medium_mapper:
x_medium = self.medium_mapping(x_medium)
else:
x_medium = torch.zeros_like(x_medium)
if not self.opts.no_fine_mapper:
x_fine = self.fine_mapping(x_fine)
else:
x_fine = torch.zeros_like(x_fine)


out = torch.cat([x_coarse, x_medium, x_fine], dim=1)

return out
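
A minimal usage sketch for the mapper above (not part of the commit); the `opts` namespace mimics the `--no_*_mapper` flags from the options files, and the 18×512 W+ shape assumes a 1024px StyleGAN2 generator:
```python
import torch
from types import SimpleNamespace

from mapper.latent_mappers import LevelsMapper

# Hypothetical opts; field names mirror the flags defined in the train/test options.
opts = SimpleNamespace(no_coarse_mapper=False, no_medium_mapper=False, no_fine_mapper=True)

mapper = LevelsMapper(opts)
w = torch.randn(2, 18, 512)  # batch of W+ latent codes (assumed 18 layers of 512 dims)
delta = mapper(w)            # residuals; the fine layers are zeroed since no_fine_mapper=True
print(delta.shape)           # torch.Size([2, 18, 512])
```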

Empty file added mapper/options/__init__.py
Empty file.
31 changes: 31 additions & 0 deletions mapper/options/test_options.py
@@ -0,0 +1,31 @@
from argparse import ArgumentParser


class TestOptions:

def __init__(self):
self.parser = ArgumentParser()
self.initialize()

def initialize(self):
# arguments for inference script
self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory')
self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to model checkpoint')
self.parser.add_argument('--couple_outputs', action='store_true', help='Whether to also save inputs + outputs side-by-side')

self.parser.add_argument('--mapper_type', default='LevelsMapper', type=str, help='Which mapper to use')
self.parser.add_argument('--no_coarse_mapper', default=False, action="store_true")
self.parser.add_argument('--no_medium_mapper', default=False, action="store_true")
self.parser.add_argument('--no_fine_mapper', default=False, action="store_true")
self.parser.add_argument('--stylegan_size', default=1024, type=int)


self.parser.add_argument('--test_batch_size', default=2, type=int, help='Batch size for testing and inference')
self.parser.add_argument('--latents_test_path', default=None, type=str, help="The latents for the validation")
self.parser.add_argument('--test_workers', default=2, type=int, help='Number of test/inference dataloader workers')

self.parser.add_argument('--n_images', type=int, default=None, help='Number of images to output. If None, run on all data')

def parse(self):
opts = self.parser.parse_args()
return opts
49 changes: 49 additions & 0 deletions mapper/options/train_options.py
@@ -0,0 +1,49 @@
from argparse import ArgumentParser


class TrainOptions:

def __init__(self):
self.parser = ArgumentParser()
self.initialize()

def initialize(self):
self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory')
self.parser.add_argument('--mapper_type', default='LevelsMapper', type=str, help='Which mapper to use')
self.parser.add_argument('--no_coarse_mapper', default=False, action="store_true")
self.parser.add_argument('--no_medium_mapper', default=False, action="store_true")
self.parser.add_argument('--no_fine_mapper', default=False, action="store_true")
self.parser.add_argument('--latents_train_path', default="train_faces.pt", type=str, help="The latents for the training")
self.parser.add_argument('--latents_test_path', default="test_faces.pt", type=str, help="The latents for the validation")
self.parser.add_argument('--train_dataset_size', default=5000, type=int, help="Will be used only if no latents are given")
self.parser.add_argument('--test_dataset_size', default=1000, type=int, help="Will be used only if no latents are given")

self.parser.add_argument('--batch_size', default=2, type=int, help='Batch size for training')
self.parser.add_argument('--test_batch_size', default=1, type=int, help='Batch size for testing and inference')
self.parser.add_argument('--workers', default=4, type=int, help='Number of train dataloader workers')
self.parser.add_argument('--test_workers', default=2, type=int, help='Number of test/inference dataloader workers')

self.parser.add_argument('--learning_rate', default=0.5, type=float, help='Optimizer learning rate')
self.parser.add_argument('--optim_name', default='ranger', type=str, help='Which optimizer to use')

self.parser.add_argument('--id_lambda', default=0.1, type=float, help='ID loss multiplier factor')
self.parser.add_argument('--clip_lambda', default=1.0, type=float, help='CLIP loss multiplier factor')
self.parser.add_argument('--latent_l2_lambda', default=0.8, type=float, help='Latent L2 loss multiplier factor')

self.parser.add_argument('--stylegan_weights', default='../pretrained_models/stylegan2-ffhq-config-f.pt', type=str, help='Path to StyleGAN model weights')
self.parser.add_argument('--stylegan_size', default=1024, type=int)
self.parser.add_argument('--ir_se50_weights', default='../pretrained_models/model_ir_se50.pth', type=str, help="Path to facial recognition network used in ID loss")
self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to StyleCLIPModel model checkpoint')

self.parser.add_argument('--max_steps', default=50000, type=int, help='Maximum number of training steps')
self.parser.add_argument('--image_interval', default=100, type=int, help='Interval for logging train images during training')
self.parser.add_argument('--board_interval', default=50, type=int, help='Interval for logging metrics to tensorboard')
self.parser.add_argument('--val_interval', default=2000, type=int, help='Validation interval')
self.parser.add_argument('--save_interval', default=2000, type=int, help='Model checkpoint interval')

self.parser.add_argument('--description', required=True, type=str, help='Driving text prompt')


def parse(self):
opts = self.parser.parse_args()
return opts
