latent_mappers.py (forked from orpatashnik/StyleCLIP)
import torch
from torch import nn
from torch.nn import Module

from models.stylegan2.model import EqualLinear, PixelNorm

# Per-layer channel dimensions of the StyleGAN2 style space (S space) for a
# 1024x1024 generator: 26 style vectors in total.
STYLESPACE_DIMENSIONS = [512 for _ in range(15)] + [256, 256, 256] + [128, 128, 128] + [64, 64, 64] + [32, 32]


class Mapper(Module):
    """A 4-layer MLP (PixelNorm followed by EqualLinear blocks) that maps a latent
    code to a residual edit of the same dimensionality."""

    def __init__(self, opts, latent_dim=512):
        super(Mapper, self).__init__()

        self.opts = opts
        layers = [PixelNorm()]

        for i in range(4):
            layers.append(
                EqualLinear(
                    latent_dim, latent_dim, lr_mul=0.01, activation='fused_lrelu'
                )
            )

        self.mapping = nn.Sequential(*layers)

    def forward(self, x):
        x = self.mapping(x)
        return x


class SingleMapper(Module):
    """Applies one shared Mapper to all latent codes."""

    def __init__(self, opts):
        super(SingleMapper, self).__init__()

        self.opts = opts
        self.mapping = Mapper(opts)

    def forward(self, x):
        out = self.mapping(x)
        return out


class LevelsMapper(Module):
    """Splits a W+ code into coarse (layers 0-3), medium (4-7) and fine (8+) groups,
    each handled by its own Mapper; disabled groups contribute zeros."""

    def __init__(self, opts):
        super(LevelsMapper, self).__init__()

        self.opts = opts

        if not opts.no_coarse_mapper:
            self.course_mapping = Mapper(opts)
        if not opts.no_medium_mapper:
            self.medium_mapping = Mapper(opts)
        if not opts.no_fine_mapper:
            self.fine_mapping = Mapper(opts)

    def forward(self, x):
        x_coarse = x[:, :4, :]
        x_medium = x[:, 4:8, :]
        x_fine = x[:, 8:, :]

        if not self.opts.no_coarse_mapper:
            x_coarse = self.course_mapping(x_coarse)
        else:
            x_coarse = torch.zeros_like(x_coarse)

        if not self.opts.no_medium_mapper:
            x_medium = self.medium_mapping(x_medium)
        else:
            x_medium = torch.zeros_like(x_medium)

        if not self.opts.no_fine_mapper:
            x_fine = self.fine_mapping(x_fine)
        else:
            x_fine = torch.zeros_like(x_fine)

        out = torch.cat([x_coarse, x_medium, x_fine], dim=1)

        return out


class FullStyleSpaceMapper(Module):
    """One Mapper per style-space entry; expects x as a list of per-layer style codes."""

    def __init__(self, opts):
        super(FullStyleSpaceMapper, self).__init__()

        self.opts = opts

        for c, c_dim in enumerate(STYLESPACE_DIMENSIONS):
            setattr(self, f"mapper_{c}", Mapper(opts, latent_dim=c_dim))

    def forward(self, x):
        out = []
        for c, x_c in enumerate(x):
            curr_mapper = getattr(self, f"mapper_{c}")
            x_c_res = curr_mapper(x_c.view(x_c.shape[0], -1)).view(x_c.shape)
            out.append(x_c_res)

        return out


class WithoutToRGBStyleSpaceMapper(Module):
    """Like FullStyleSpaceMapper, but skips the ToRGB style entries (indices 1, 4, 7, ...),
    returning zeros for them instead."""

    def __init__(self, opts):
        super(WithoutToRGBStyleSpaceMapper, self).__init__()

        self.opts = opts

        # Indices 1, 4, 7, ... are the ToRGB styles in this layout; they are excluded below.
        indices_without_torgb = list(range(1, len(STYLESPACE_DIMENSIONS), 3))
        self.STYLESPACE_INDICES_WITHOUT_TORGB = [i for i in range(len(STYLESPACE_DIMENSIONS)) if i not in indices_without_torgb]

        for c in self.STYLESPACE_INDICES_WITHOUT_TORGB:
            setattr(self, f"mapper_{c}", Mapper(opts, latent_dim=STYLESPACE_DIMENSIONS[c]))

    def forward(self, x):
        out = []
        for c in range(len(STYLESPACE_DIMENSIONS)):
            x_c = x[c]
            if c in self.STYLESPACE_INDICES_WITHOUT_TORGB:
                curr_mapper = getattr(self, f"mapper_{c}")
                x_c_res = curr_mapper(x_c.view(x_c.shape[0], -1)).view(x_c.shape)
            else:
                x_c_res = torch.zeros_like(x_c)
            out.append(x_c_res)

        return out
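

# ----------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal, hypothetical example
# of driving LevelsMapper with a W+ latent. `Namespace` stands in for the
# training options object (`opts`); only the three flags read above are set.
# Note that EqualLinear's 'fused_lrelu' activation relies on the StyleGAN2
# CUDA ops bundled with this repo, so running this typically requires a GPU.
# The 0.1 step scale below is an illustrative choice for applying the edit.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace

    opts = Namespace(no_coarse_mapper=False, no_medium_mapper=False, no_fine_mapper=False)
    mapper = LevelsMapper(opts)

    # W+ latent for a 1024px StyleGAN2 generator: batch of 2, 18 layers, 512 dims.
    w_plus = torch.randn(2, 18, 512)
    delta = mapper(w_plus)            # predicted residual edit, same shape as w_plus
    w_edited = w_plus + 0.1 * delta   # take a small step along the predicted direction
    print(delta.shape, w_edited.shape)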