forked from XPixelGroup/BasicSR
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
905 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,206 @@ | ||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import torch.nn.utils.spectral_norm as SpectralNorm | ||
|
||
from basicsr.models.archs.dfdnet_util import (AttentionBlock, Blur, | ||
MSDilationBlock, UpResBlock, | ||
adaptive_instance_normalization) | ||
from basicsr.models.archs.vgg_arch import VGGFeatureExtractor | ||
|
||
|
||
class SFTUpBlock(nn.Module): | ||
"""Spatial feature transform (SFT) with upsampling block.""" | ||
|
||
def __init__(self, in_channel, out_channel, kernel_size=3, padding=1): | ||
super(SFTUpBlock, self).__init__() | ||
self.conv1 = nn.Sequential( | ||
Blur(in_channel), | ||
SpectralNorm( | ||
nn.Conv2d( | ||
in_channel, out_channel, kernel_size, padding=padding)), | ||
nn.LeakyReLU(0.04, True), | ||
# The official codes use two LeakyReLU here, so 0.04 for equivalent | ||
) | ||
self.convup = nn.Sequential( | ||
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), | ||
SpectralNorm( | ||
nn.Conv2d( | ||
out_channel, out_channel, kernel_size, padding=padding)), | ||
nn.LeakyReLU(0.2, True), | ||
) | ||
|
||
# for SFT scale and shift | ||
self.scale_block = nn.Sequential( | ||
SpectralNorm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), | ||
nn.LeakyReLU(0.2, True), | ||
SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1))) | ||
self.shift_block = nn.Sequential( | ||
SpectralNorm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), | ||
nn.LeakyReLU(0.2, True), | ||
SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)), | ||
nn.Sigmoid()) | ||
# The official codes use sigmoid for shift block, do not know why | ||
|
||
def forward(self, x, updated_feat): | ||
out = self.conv1(x) | ||
# SFT | ||
scale = self.scale_block(updated_feat) | ||
shift = self.shift_block(updated_feat) | ||
out = out * scale + shift | ||
# upsample | ||
out = self.convup(out) | ||
return out | ||
|
||
|
||
class VGGFaceFeatureExtractor(VGGFeatureExtractor): | ||
|
||
def preprocess(self, x): | ||
# norm to [0, 1] | ||
x = (x + 1) / 2 | ||
if self.use_input_norm: | ||
x = (x - self.mean) / self.std | ||
if x.shape[3] < 224: | ||
x = torch.nn.functional.interpolate( | ||
x, size=(224, 224), mode='bilinear', align_corners=False) | ||
return x | ||
|
||
def forward(self, x): | ||
x = self.preprocess(x) | ||
features = [] | ||
for key, layer in self.vgg_net._modules.items(): | ||
x = layer(x) | ||
if key in self.layer_name_list: | ||
features.append(x) | ||
return features | ||
|
||
|
||
class DFDNet(nn.Module): | ||
"""DFDNet: Deep Face Dictionary Network. | ||
It only processes faces with 512x512 size. | ||
""" | ||
|
||
def __init__(self, num_feat, dict_path): | ||
super().__init__() | ||
self.parts = ['left_eye', 'right_eye', 'nose', 'mouth'] | ||
# part_sizes: [80, 80, 50, 110] | ||
channel_sizes = [128, 256, 512, 512] | ||
self.feature_sizes = np.array([256, 128, 64, 32]) | ||
self.flag_dict_device = False | ||
|
||
# dict | ||
self.dict = torch.load(dict_path) | ||
|
||
# vgg face extractor | ||
self.vgg_extractor = VGGFaceFeatureExtractor( | ||
layer_name_list=['conv2_2', 'conv3_4', 'conv4_4', 'conv5_4'], | ||
vgg_type='vgg19', | ||
use_input_norm=True, | ||
requires_grad=False) | ||
|
||
# attention block for fusing dictionary features and input features | ||
self.attn_blocks = nn.ModuleDict() | ||
for idx, feat_size in enumerate(self.feature_sizes): | ||
for name in self.parts: | ||
self.attn_blocks[f'{name}_{feat_size}'] = AttentionBlock( | ||
channel_sizes[idx]) | ||
|
||
# multi scale dilation block | ||
self.multi_scale_dilation = MSDilationBlock( | ||
num_feat * 8, dilation=[4, 3, 2, 1]) | ||
|
||
# upsampling and reconstruction | ||
self.upsample0 = SFTUpBlock(num_feat * 8, num_feat * 8) | ||
self.upsample1 = SFTUpBlock(num_feat * 8, num_feat * 4) | ||
self.upsample2 = SFTUpBlock(num_feat * 4, num_feat * 2) | ||
self.upsample3 = SFTUpBlock(num_feat * 2, num_feat) | ||
self.upsample4 = nn.Sequential( | ||
SpectralNorm(nn.Conv2d(num_feat, num_feat, 3, 1, 1)), | ||
nn.LeakyReLU(0.2, True), UpResBlock(num_feat), | ||
UpResBlock(num_feat), | ||
nn.Conv2d(num_feat, 3, kernel_size=3, stride=1, padding=1), | ||
nn.Tanh()) | ||
|
||
def swap_feat(self, vgg_feat, updated_feat, dict_feat, location, part_name, | ||
f_size): | ||
"""swap the features from the dictionary.""" | ||
# get the original vgg features | ||
part_feat = vgg_feat[:, :, location[1]:location[3], | ||
location[0]:location[2]].clone() | ||
# resize original vgg features | ||
part_resize_feat = F.interpolate( | ||
part_feat, | ||
dict_feat.size()[2:4], | ||
mode='bilinear', | ||
align_corners=False) | ||
# use adaptive instance normalization to adjust color and illuminations | ||
dict_feat = adaptive_instance_normalization(dict_feat, | ||
part_resize_feat) | ||
# get similarity scores | ||
similarity_score = F.conv2d(part_resize_feat, dict_feat) | ||
similarity_score = F.softmax(similarity_score.view(-1), dim=0) | ||
# select the most similar features in the dict (after norm) | ||
select_idx = torch.argmax(similarity_score) | ||
swap_feat = F.interpolate(dict_feat[select_idx:select_idx + 1], | ||
part_feat.size()[2:4]) | ||
# attention | ||
attn = self.attn_blocks[f'{part_name}_' + str(f_size)]( | ||
swap_feat - part_feat) | ||
attn_feat = attn * swap_feat | ||
# update features | ||
updated_feat[:, :, location[1]:location[3], | ||
location[0]:location[2]] = attn_feat + part_feat | ||
return updated_feat | ||
|
||
def put_dict_to_device(self, x): | ||
if self.flag_dict_device is False: | ||
for k, v in self.dict.items(): | ||
for kk, vv in v.items(): | ||
self.dict[k][kk] = vv.to(x) | ||
self.flag_dict_device = True | ||
|
||
def forward(self, x, part_locations): | ||
""" | ||
Now only support testing with batch size = 0. | ||
Args: | ||
x (Tensor): Input faces with shape (b, c, 512, 512). | ||
part_locations (list[Tensor]): Part locations. | ||
""" | ||
self.put_dict_to_device(x) | ||
# extract vggface features | ||
vgg_features = self.vgg_extractor(x) | ||
# update vggface features using the dictionary for each part | ||
updated_vgg_features = [] | ||
batch = 0 # only supports testing with batch size = 0 | ||
for i, f_size in enumerate(self.feature_sizes): | ||
dict_features = self.dict[f'{f_size}'] | ||
vgg_feat = vgg_features[i] | ||
updated_feat = vgg_feat.clone() | ||
|
||
# swap features from dictionary | ||
for part_idx, part_name in enumerate(self.parts): | ||
location = (part_locations[part_idx][batch] // | ||
(512 / f_size)).int() | ||
updated_feat = self.swap_feat(vgg_feat, updated_feat, | ||
dict_features[part_name], | ||
location, part_name, f_size) | ||
|
||
updated_vgg_features.append(updated_feat) | ||
|
||
vgg_feat_dilation = self.multi_scale_dilation(vgg_features[3]) | ||
# use updated vgg features to modulate the upsampled features with | ||
# SFT (Spatial Feature Transform) scaling and shifting manner. | ||
upsampled_feat = self.upsample0(vgg_feat_dilation, | ||
updated_vgg_features[3]) | ||
upsampled_feat = self.upsample1(upsampled_feat, | ||
updated_vgg_features[2]) | ||
upsampled_feat = self.upsample2(upsampled_feat, | ||
updated_vgg_features[1]) | ||
upsampled_feat = self.upsample3(upsampled_feat, | ||
updated_vgg_features[0]) | ||
out = self.upsample4(upsampled_feat) | ||
|
||
return out |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,186 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import torch.nn.utils.spectral_norm as SpectralNorm | ||
from torch.autograd import Function | ||
|
||
|
||
class BlurFunctionBackward(Function): | ||
|
||
@staticmethod | ||
def forward(ctx, grad_output, kernel, kernel_flip): | ||
ctx.save_for_backward(kernel, kernel_flip) | ||
grad_input = F.conv2d( | ||
grad_output, kernel_flip, padding=1, groups=grad_output.shape[1]) | ||
return grad_input | ||
|
||
@staticmethod | ||
def backward(ctx, gradgrad_output): | ||
kernel, kernel_flip = ctx.saved_tensors | ||
grad_input = F.conv2d( | ||
gradgrad_output, | ||
kernel, | ||
padding=1, | ||
groups=gradgrad_output.shape[1]) | ||
return grad_input, None, None | ||
|
||
|
||
class BlurFunction(Function): | ||
|
||
@staticmethod | ||
def forward(ctx, x, kernel, kernel_flip): | ||
ctx.save_for_backward(kernel, kernel_flip) | ||
output = F.conv2d(x, kernel, padding=1, groups=x.shape[1]) | ||
return output | ||
|
||
@staticmethod | ||
def backward(ctx, grad_output): | ||
kernel, kernel_flip = ctx.saved_tensors | ||
grad_input = BlurFunctionBackward.apply(grad_output, kernel, | ||
kernel_flip) | ||
return grad_input, None, None | ||
|
||
|
||
blur = BlurFunction.apply | ||
|
||
|
||
class Blur(nn.Module): | ||
|
||
def __init__(self, channel): | ||
super().__init__() | ||
kernel = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], | ||
dtype=torch.float32) | ||
kernel = kernel.view(1, 1, 3, 3) | ||
kernel = kernel / kernel.sum() | ||
kernel_flip = torch.flip(kernel, [2, 3]) | ||
|
||
self.kernel = kernel.repeat(channel, 1, 1, 1) | ||
self.kernel_flip = kernel_flip.repeat(channel, 1, 1, 1) | ||
|
||
def forward(self, x): | ||
return blur(x, self.kernel.type_as(x), self.kernel_flip.type_as(x)) | ||
|
||
|
||
def calc_mean_std(feat, eps=1e-5): | ||
"""Calculate mean and std for adaptive_instance_normalization. | ||
Args: | ||
feat (Tensor): 4D tensor. | ||
eps (float): A small value added to the variance to avoid | ||
divide-by-zero. Default: 1e-5. | ||
""" | ||
size = feat.size() | ||
assert len(size) == 4, 'The input feature should be 4D tensor.' | ||
n, c = size[:2] | ||
feat_var = feat.view(n, c, -1).var(dim=2) + eps | ||
feat_std = feat_var.sqrt().view(n, c, 1, 1) | ||
feat_mean = feat.view(n, c, -1).mean(dim=2).view(n, c, 1, 1) | ||
return feat_mean, feat_std | ||
|
||
|
||
def adaptive_instance_normalization(content_feat, style_feat): | ||
"""Adaptive instance normalization. | ||
Adjust the reference features to have the similar color and illuminations | ||
as those in the degradate features. | ||
Args: | ||
content_feat (Tensor): The reference feature. | ||
style_feat (Tensor): The degradate features. | ||
""" | ||
size = content_feat.size() | ||
style_mean, style_std = calc_mean_std(style_feat) | ||
content_mean, content_std = calc_mean_std(content_feat) | ||
normalized_feat = (content_feat - | ||
content_mean.expand(size)) / content_std.expand(size) | ||
return normalized_feat * style_std.expand(size) + style_mean.expand(size) | ||
|
||
|
||
def AttentionBlock(in_channel): | ||
return nn.Sequential( | ||
SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)), | ||
nn.LeakyReLU(0.2, True), | ||
SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1))) | ||
|
||
|
||
def conv_block(in_channels, | ||
out_channels, | ||
kernel_size=3, | ||
stride=1, | ||
dilation=1, | ||
bias=True): | ||
"""Conv block used in MSDilationBlock.""" | ||
|
||
return nn.Sequential( | ||
SpectralNorm( | ||
nn.Conv2d( | ||
in_channels, | ||
out_channels, | ||
kernel_size=kernel_size, | ||
stride=stride, | ||
dilation=dilation, | ||
padding=((kernel_size - 1) // 2) * dilation, | ||
bias=bias)), | ||
nn.LeakyReLU(0.2), | ||
SpectralNorm( | ||
nn.Conv2d( | ||
out_channels, | ||
out_channels, | ||
kernel_size=kernel_size, | ||
stride=stride, | ||
dilation=dilation, | ||
padding=((kernel_size - 1) // 2) * dilation, | ||
bias=bias)), | ||
) | ||
|
||
|
||
class MSDilationBlock(nn.Module): | ||
"""Multi-scale dilation block.""" | ||
|
||
def __init__(self, | ||
in_channels, | ||
kernel_size=3, | ||
dilation=[1, 1, 1, 1], | ||
bias=True): | ||
super(MSDilationBlock, self).__init__() | ||
|
||
self.conv_blocks = nn.ModuleList() | ||
for i in range(4): | ||
self.conv_blocks.append( | ||
conv_block( | ||
in_channels, | ||
in_channels, | ||
kernel_size, | ||
dilation=dilation[i], | ||
bias=bias)) | ||
self.conv_fusion = SpectralNorm( | ||
nn.Conv2d( | ||
in_channels * 4, | ||
in_channels, | ||
kernel_size=kernel_size, | ||
stride=1, | ||
padding=(kernel_size - 1) // 2, | ||
bias=bias)) | ||
|
||
def forward(self, x): | ||
out = [] | ||
for i in range(4): | ||
out.append(self.conv_blocks[i](x)) | ||
out = torch.cat(out, 1) | ||
out = self.conv_fusion(out) + x | ||
return out | ||
|
||
|
||
class UpResBlock(nn.Module): | ||
|
||
def __init__(self, in_channel): | ||
super(UpResBlock, self).__init__() | ||
self.body = nn.Sequential( | ||
nn.Conv2d(in_channel, in_channel, 3, 1, 1), | ||
nn.LeakyReLU(0.2, True), | ||
nn.Conv2d(in_channel, in_channel, 3, 1, 1), | ||
) | ||
|
||
def forward(self, x): | ||
out = x + self.body(x) | ||
return out |
Oops, something went wrong.