
Commit 9c26c20: fix spatial

Richard Zhang authored and committed Oct 4, 2020
1 parent af38ce5
Showing 1 changed file with 15 additions and 5 deletions.
lpips/lpips.py: 20 changes (15 additions & 5 deletions)
@@ -16,9 +16,7 @@ def spatial_average(in_tens, keepdim=True):
 
 def upsample(in_tens, out_HW=(64,64)): # assumes scale factor is same for H and W
     in_H, in_W = in_tens.shape[2], in_tens.shape[3]
-    scale_factor_H, scale_factor_W = 1.*out_HW[0]/in_H, 1.*out_HW[1]/in_W
-
-    return nn.Upsample(scale_factor=(scale_factor_H, scale_factor_W), mode='bilinear', align_corners=False)(in_tens)
+    return nn.Upsample(size=out_HW, mode='bilinear', align_corners=False)(in_tens)
 
 # Learned perceptual metric
 class LPIPS(nn.Module):
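The likely motivation for the one-line fix (my reading, not stated in the commit): with scale_factor=, PyTorch derives the output size by multiplying the input size and flooring, so a non-integer ratio can land a pixel off the intended out_HW, while size=out_HW pins the output shape exactly. A minimal sketch with made-up shapes:

# Sketch only: size= guarantees the target shape, scale_factor= merely approximates it.
import torch
import torch.nn as nn

in_tens = torch.randn(1, 3, 7, 7)   # stand-in feature map
out_HW = (64, 64)                   # target resolution

by_scale = nn.Upsample(scale_factor=(out_HW[0]/7., out_HW[1]/7.),
                       mode='bilinear', align_corners=False)(in_tens)
by_size = nn.Upsample(size=out_HW, mode='bilinear', align_corners=False)(in_tens)

print(by_scale.shape)  # may drift off (1, 3, 64, 64) when the ratio is not exact
print(by_size.shape)   # always torch.Size([1, 3, 64, 64])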
@@ -95,9 +93,9 @@ def forward(self, in0, in1, retPerLayer=False, normalize=False):
 
         if(self.lpips):
             if(self.spatial):
-                res = [upsample(self.lins[kk].model(diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)]
+                res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)]
             else:
-                res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
+                res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]
         else:
             if(self.spatial):
                 res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)]
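For reference, the two branches only differ in how each weighted per-layer map is reduced: spatial=True upsamples it back to the input resolution (a full distance map per layer), spatial=False averages it over H and W (one value per layer). A rough sketch of the shapes, using an arbitrary 64-channel 32x32 stand-in rather than real backbone features:

# Sketch only: how the spatial and averaged reductions differ in shape.
import torch
import torch.nn as nn
import torch.nn.functional as F

diff = torch.randn(1, 64, 32, 32)       # stand-in for diffs[kk]
lin = nn.Conv2d(64, 1, 1, bias=False)   # stand-in for self.lins[kk]
weighted = lin(diff)                    # (1, 1, 32, 32)

spatial_map = F.interpolate(weighted, size=(256, 256),           # spatial=True branch
                            mode='bilinear', align_corners=False)
scalar = weighted.mean([2, 3], keepdim=True)                      # spatial=False branch

print(spatial_map.shape)  # torch.Size([1, 1, 256, 256])
print(scalar.shape)       # torch.Size([1, 1, 1, 1])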
@@ -107,6 +105,16 @@ def forward(self, in0, in1, retPerLayer=False, normalize=False):
         val = res[0]
         for l in range(1,self.L):
             val += res[l]
 
+        # a = spatial_average(self.lins[kk](diffs[kk]), keepdim=True)
+        # b = torch.max(self.lins[kk](feats0[kk]**2))
+        # for kk in range(self.L):
+        #     a += spatial_average(self.lins[kk](diffs[kk]), keepdim=True)
+        #     b = torch.max(b,torch.max(self.lins[kk](feats0[kk]**2)))
+        # a = a/self.L
+        # from IPython import embed
+        # embed()
+        # return 10*torch.log10(b/a)
+
         if(retPerLayer):
             return (val, res)
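Read together, the commented-out scaffolding (note the IPython embed) would roughly report the distance on a PSNR-like decibel scale: a averages the weighted squared feature differences over layers, b tracks the peak weighted feature energy of in0, and 10*log10(b/a) grows as the two inputs get closer. A toy reading of the formula with invented numbers:

# Toy numbers only: peak feature energy b versus mean weighted distance a.
import torch
b = torch.tensor(1.0)
a = torch.tensor(0.01)
print(10 * torch.log10(b / a))  # tensor(20.), i.e. ~20 dB on this scale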
@@ -133,6 +141,8 @@ def __init__(self, chn_in, chn_out=1, use_dropout=False):
         layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]
         self.model = nn.Sequential(*layers)
 
+    def forward(self, x):
+        return self.model(x)
 
 class Dist2LogitLayer(nn.Module):
     ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
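Adding forward() is what makes the self.lins[kk](diffs[kk]) calls above legal: the layer can now be called like any nn.Module, which routes through __call__ so forward hooks and subclass overrides are honored, instead of reaching into its .model attribute. A self-contained sketch (TinyLin is a stand-in, not the repo's class):

# Stand-in for NetLinLayer after this commit: calling the module and calling
# .model give the same tensor, but only the former fires registered hooks.
import torch
import torch.nn as nn

class TinyLin(nn.Module):
    def __init__(self, chn_in, chn_out=1):
        super().__init__()
        self.model = nn.Sequential(nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False))

    def forward(self, x):
        return self.model(x)

lin = TinyLin(64)
lin.register_forward_hook(lambda mod, inp, out: print('hook fired:', tuple(out.shape)))

diff = torch.randn(1, 64, 32, 32)
_ = lin(diff)        # prints "hook fired: (1, 1, 32, 32)"
_ = lin.model(diff)  # same values, but the hook registered on lin does not fire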
