FDL.py

import torch
import torch.nn as nn
import torch.nn.functional as F

from .models import VGG, ResNet, Inception, EffNet


class FDL_loss(torch.nn.Module):
def __init__(
self, patch_size=5, stride=1, num_proj=256, model="VGG", phase_weight=1.0
):
"""
patch_size, stride, num_proj: SWD slice parameters
model: feature extractor, support VGG, ResNet, Inception, EffNet
phase_weight: weight for phase branch
"""
        super().__init__()
        if model == "ResNet":
            self.model = ResNet()
        elif model == "EffNet":
            self.model = EffNet()
        elif model == "Inception":
            self.model = Inception()
        elif model == "VGG":
            self.model = VGG()
        else:
            raise ValueError(
                "Invalid model type! Valid models: VGG, Inception, EffNet, ResNet"
            )
        self.phase_weight = phase_weight
        self.stride = stride
        # One bank of random projection kernels per feature stage; each kernel
        # is normalized to unit L2 norm so every SWD slice has comparable scale.
        for i in range(len(self.model.chns)):
            rand = torch.randn(num_proj, self.model.chns[i], patch_size, patch_size)
            rand = rand / rand.view(rand.shape[0], -1).norm(dim=1).unsqueeze(
                1
            ).unsqueeze(2).unsqueeze(3)
            self.register_buffer("rand_{}".format(i), rand)
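        # Shape walkthrough for one stage (illustrative, assuming the defaults
        # patch_size=5, num_proj=256; H' and W' depend on stride and input size):
        #   features:  (N, C, H, W)
        #   rand_i:    (256, C, 5, 5)
        #   conv2d ->  (N, 256, H', W') -> reshape -> (N, 256, H'*W')
        # Each of the 256 rows is one 1D "slice" of the local patch distribution,
        # which is what forward_once compares below.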

    def forward_once(self, x, y, idx):
        """
        x, y: feature tensors of shape (N, C, H, W) from stage `idx`;
        returns the sliced Wasserstein distance between them, one value per sample.
        """
        rand = getattr(self, "rand_{}".format(idx))
        # Project local patches onto random directions, then flatten the
        # spatial dimensions so each projection becomes a 1D slice
        projx = F.conv2d(x, rand, stride=self.stride)
        projx = projx.reshape(projx.shape[0], projx.shape[1], -1)
        projy = F.conv2d(y, rand, stride=self.stride)
        projy = projy.reshape(projy.shape[0], projy.shape[1], -1)
        # Sort each slice; for two 1D empirical distributions of equal size,
        # the Wasserstein-1 distance is the mean absolute difference of the
        # sorted samples
        projx, _ = torch.sort(projx, dim=-1)
        projy, _ = torch.sort(projy, dim=-1)
        s = torch.abs(projx - projy).mean([1, 2])
        return s

    def forward(self, x, y):
        x = self.model(x)
        y = self.model(y)
        score = []
        for i in range(len(x)):
            # Transform the stage-i feature maps to the Fourier domain
            fft_x = torch.fft.fftn(x[i], dim=(-2, -1))
            fft_y = torch.fft.fftn(y[i], dim=(-2, -1))
            # Split the spectrum into magnitude (amplitude) and phase
            x_mag = torch.abs(fft_x)
            x_phase = torch.angle(fft_x)
            y_mag = torch.abs(fft_y)
            y_phase = torch.angle(fft_y)
            # SWD on each branch, combined as amplitude + weighted phase
            s_amplitude = self.forward_once(x_mag, y_mag, i)
            s_phase = self.forward_once(x_phase, y_phase, i)
            score.append(s_amplitude + s_phase * self.phase_weight)
        score = sum(score)  # sum over feature stages
        score = score.mean()  # mean over the batch
        return score  # the larger the score, the more the two images differ
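

# A minimal, self-contained sketch of the identity `forward_once` relies on:
# for two 1D empirical distributions with the same number of samples, the
# Wasserstein-1 distance equals the mean absolute difference of the sorted
# samples. `wasserstein_1d` is a hypothetical helper added here purely for
# illustration; it is not part of the FDL API.
def wasserstein_1d(a, b):
    """a, b: 1D tensors with the same number of elements."""
    a_sorted, _ = torch.sort(a)
    b_sorted, _ = torch.sort(b)
    return torch.abs(a_sorted - b_sorted).mean()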


if __name__ == "__main__":
    # Quick smoke test on random inputs
    print("FDL_loss")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    X = torch.randn((1, 3, 128, 128)).to(device)
    Y = torch.randn((1, 3, 128, 128)).to(device) * 2
    loss = FDL_loss().to(device)
    c = loss(X, Y)
    print("loss:", c)
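
    # A minimal usage sketch, assuming a generator `net`, a dataloader that
    # yields (lr, hr) image pairs, and an `optimizer` -- all hypothetical
    # names for illustration, not part of this file:
    #
    #   fdl = FDL_loss(model="VGG", phase_weight=1.0).to(device)
    #   for lr, hr in dataloader:
    #       sr = net(lr.to(device))
    #       loss = fdl(sr, hr.to(device))
    #       optimizer.zero_grad()
    #       loss.backward()
    #       optimizer.step()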