Skip to content

Commit

Permalink
fixed gram matrix distance to include linear weighting. refer to rich…
Browse files Browse the repository at this point in the history
  • Loading branch information
yohan-pg committed Mar 4, 2021
1 parent bd17fc5 commit d25baa7
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions metrics/lpips.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@ def __init__(self, in_channels, out_channels=1):
def forward(self, x):
return self.main(x)

def sqrt(self):
import copy

clone = copy.deepcopy(self)
clone.main[-1].weight.copy_(clone.main[-1].weight.sqrt())

return clone


class LPIPS(nn.Module):
def __init__(self):
Expand Down Expand Up @@ -79,10 +87,13 @@ def forward(self, x, y, gram=False):
x_fmap = normalize(x_fmap)
y_fmap = normalize(y_fmap)
if gram:
x_fmap = x_fmap.reshape(*x_fmap.shape[:2], -1)
y_fmap = y_fmap.reshape(*y_fmap.shape[:2], -1)
sqrt_conv1x1 = conv1x1.sqrt()
x_fmap = x_fmap.reshape(*sqrt_conv1x1(x_fmap).shape[:2], -1)
y_fmap = y_fmap.reshape(*sqrt_conv1x1(y_fmap).shape[:2], -1)

x_gram = x_fmap.bmm(x_fmap.transpose(1, 2))
y_gram = y_fmap.bmm(y_fmap.transpose(1, 2))

lpips_value += torch.mean(((x_gram - y_gram)**2))
else:
lpips_value += torch.mean(conv1x1((x_fmap - y_fmap)**2))
Expand Down

0 comments on commit d25baa7

Please sign in to comment.