fix ret per layer bug
Richard Zhang authored and Richard Zhang committed Aug 25, 2021
1 parent 8db312a commit 1c4cf7c
Showing 3 changed files with 40 additions and 44 deletions.
29 changes: 0 additions & 29 deletions lpips/__init__.py
@@ -10,35 +10,6 @@
from lpips.trainer import *
from lpips.lpips import *

-# class PerceptualLoss(torch.nn.Module):
-#     def __init__(self, model='lpips', net='alex', spatial=False, use_gpu=False, gpu_ids=[0], version='0.1'): # VGG using our perceptually-learned weights (LPIPS metric)
-#     # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
-#         super(PerceptualLoss, self).__init__()
-#         print('Setting up Perceptual loss...')
-#         self.use_gpu = use_gpu
-#         self.spatial = spatial
-#         self.gpu_ids = gpu_ids
-#         self.model = dist_model.DistModel()
-#         self.model.initialize(model=model, net=net, use_gpu=use_gpu, spatial=self.spatial, gpu_ids=gpu_ids, version=version)
-#         print('...[%s] initialized'%self.model.name())
-#         print('...Done')
-
-#     def forward(self, pred, target, normalize=False):
-#         """
-#         Pred and target are Variables.
-#         If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
-#         If normalize is False, assumes the images are already between [-1,+1]
-
-#         Inputs pred and target are Nx3xHxW
-#         Output pytorch Variable N long
-#         """
-
-#         if normalize:
-#             target = 2 * target - 1
-#             pred = 2 * pred - 1
-
-#         return self.model.forward(target, pred)

def normalize_tensor(in_feat,eps=1e-10):
    norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
    return in_feat/(norm_factor+eps)
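Side note: `normalize_tensor` above (unchanged context in this diff) rescales every feature vector to unit L2 norm along the channel dimension. A quick sanity check, with an illustrative NxCxHxW tensor (not part of this commit):

    import torch

    x = torch.randn(1, 64, 8, 8)      # fake NxCxHxW feature map
    y = normalize_tensor(x)
    print(y.pow(2).sum(dim=1).max())  # ~1.0: each spatial location has unit channel norm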
50 changes: 36 additions & 14 deletions lpips/lpips.py
@@ -22,8 +22,40 @@ def upsample(in_tens, out_HW=(64,64)): # assumes scale factor is same for H and
class LPIPS(nn.Module):
    def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False,
        pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=True):
-        # lpips - [True] means with linear calibration on top of base network
-        # pretrained - [True] means load linear weights
""" Initializes a perceptual loss torch.nn.Module
Parameters (default listed first)
---------------------------------
lpips : bool
[True] use linear layers on top of base/trunk network
[False] means no linear layers; each layer is averaged together
pretrained : bool
This flag controls the linear layers, which are only in effect when lpips=True above
[True] means linear layers are calibrated with human perceptual judgments
[False] means linear layers are randomly initialized
pnet_rand : bool
[False] means trunk loaded with ImageNet classification weights
[True] means randomly initialized trunk
net : str
['alex','vgg','squeeze'] are the base/trunk networks available
version : str
['v0.1'] is the default and latest
['v0.0'] contained a normalization bug; corresponds to old arxiv v1 (https://arxiv.org/abs/1801.03924v1)
model_path : 'str'
[None] is default and loads the pretrained weights from paper https://arxiv.org/abs/1801.03924v1
The following parameters should only be changed if training the network
eval_mode : bool
[True] is for test mode (default)
[False] is for training mode
pnet_tune
[False] tune the base/trunk network
[True] keep base/trunk frozen
use_dropout : bool
[True] to use dropout when training linear layers
[False] for no dropout when training linear layers
"""

        super(LPIPS, self).__init__()
        if(verbose):
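For reference, a minimal usage sketch of the constructor documented above (variable names are illustrative; `lpips.LPIPS` is re-exported by lpips/__init__.py as shown earlier):

    import lpips

    # default: AlexNet trunk with calibrated linear layers (v0.1 weights)
    loss_fn = lpips.LPIPS(net='alex')

    # trunk features only; no linear layers, each layer averaged together
    loss_fn_nolin = lpips.LPIPS(net='vgg', lpips=False)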
@@ -102,19 +102,9 @@ def forward(self, in0, in1, retPerLayer=False, normalize=False):
        else:
            res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]

-        val = res[0]
-        for l in range(1,self.L):
+        val = 0
+        for l in range(self.L):
            val += res[l]

-        # a = spatial_average(self.lins[kk](diffs[kk]), keepdim=True)
-        # b = torch.max(self.lins[kk](feats0[kk]**2))
-        # for kk in range(self.L):
-        #     a += spatial_average(self.lins[kk](diffs[kk]), keepdim=True)
-        #     b = torch.max(b,torch.max(self.lins[kk](feats0[kk]**2)))
-        # a = a/self.L
-        # from IPython import embed
-        # embed()
-        # return 10*torch.log10(b/a)

        if(retPerLayer):
            return (val, res)
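Note on the hunk above, which is the actual "ret per layer" fix (a plausible reading; the commit itself only shows the change): `val = res[0]` made `val` an alias of the first per-layer tensor, and PyTorch's in-place `+=` then accumulated into `res[0]` itself, so with `retPerLayer=True` the first entry of the returned list held the running total instead of layer 0's distance. Starting from the Python int `0` forces the first `+=` to allocate a fresh tensor, leaving `res` intact. A standalone sketch of the aliasing (shapes are illustrative):

    import torch

    # buggy pattern: val aliases res[0], and tensor += tensor is in-place
    res = [torch.full((1,), float(k)) for k in range(3)]
    val = res[0]
    for l in range(1, 3):
        val += res[l]
    print(res[0].item())  # 3.0 -- res[0] now holds the sum, not layer 0's value

    # fixed pattern: int + tensor allocates a new tensor; res stays intact
    res = [torch.full((1,), float(k)) for k in range(3)]
    val = 0
    for l in range(3):
        val += res[l]
    print(res[0].item())  # 0.0 -- per-layer value preserved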
5 changes: 4 additions & 1 deletion lpips_2imgs.py
@@ -24,5 +24,8 @@
img1 = img1.cuda()

# Compute distance
-dist01 = loss_fn.forward(img0,img1)
+dist01 = loss_fn.forward(img0, img1)
print('Distance: %.3f'%dist01)

+dist01 = loss_fn.forward(img0, img1, retPerLayer=True)
+print(dist01)
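Since `forward` now returns a well-formed `(val, res)` tuple when `retPerLayer=True`, the script could also unpack it explicitly (a sketch reusing `loss_fn`, `img0`, `img1` from the same script):

    dist01, per_layer = loss_fn.forward(img0, img1, retPerLayer=True)
    print('Total: %.3f'%dist01)
    for l, d in enumerate(per_layer):
        print('Layer %d: %.3f'%(l, d))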
