
Commit

clean up spatial, aggregate pnet and pnetlin, remove some unused funcs
richzhang committed Jul 27, 2019
1 parent 7180b0f commit 7b34113
Showing 5 changed files with 74 additions and 194 deletions.
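
In practical terms, this commit collapses DistModel's two distance paths into one: forward_pair (which returned raw torch distances) and the old forward (which returned numpy output and did the optional spatial resizing itself) are replaced by a single forward that stays in torch, and the uncalibrated 'net' mode now reuses networks.PNetLin with lpips=False instead of the separate networks.PNet class. A rough before/after sketch of a caller, assuming dm is an initialized DistModel and ref, p0 are Nx3xHxW tensors scaled to [-1,1]:

# before this commit
d = dm.forward_pair(ref, p0)             # torch tensor of distances
m = dm.forward(ref, p0, retNumpy=True)   # numpy output; spatial maps resized here

# after this commit
d = dm.forward(ref, p0)                  # torch tensor; spatial handling lives in the network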
31 changes: 5 additions & 26 deletions models/__init__.py
@@ -6,8 +6,8 @@
import numpy as np
from skimage.measure import compare_ssim
import torch
from torch.autograd import Variable

import torch
from models import dist_model

class PerceptualLoss(torch.nn.Module):
@@ -17,6 +17,7 @@ def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False,
print('Setting up Perceptual loss...')
self.use_gpu = use_gpu
self.spatial = spatial
self.gpu_ids = gpu_ids
self.model = dist_model.DistModel()
self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids)
print('...[%s] initialized'%self.model.name())
@@ -37,33 +38,15 @@ def forward(self, pred, target, normalize=False):
pred = 2 * pred - 1

if(self.use_gpu):
target = target.cuda()
pred = pred.cuda()

if(not self.spatial):
return self.model.forward_pair(target, pred)
else:
return self.model.forward(target, pred)


def cos_sim_blob(in0,in1):
in0_norm = normalize_blob(in0)
in1_norm = normalize_blob(in1)
(N,C,X,Y) = in0_norm.shape
target = Variable(target.to(device=self.gpu_ids[0]), requires_grad=True)
pred = Variable(pred.to(device=self.gpu_ids[0]), requires_grad=True)

return np.mean(np.mean(np.sum(in0_norm*in1_norm,axis=1),axis=1),axis=1)
return self.model.forward(target, pred)

def normalize_tensor(in_feat,eps=1e-10):
norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
return in_feat/(norm_factor+eps)

def cos_sim(in0,in1):
in0_norm = normalize_tensor(in0)
in1_norm = normalize_tensor(in1)
N = in0.size()[0]

return torch.mean(torch.mean(torch.sum(in0_norm*in1_norm,dim=1,keepdim=True),dim=2,keepdim=True),dim=3,keepdim=True).view(N)

def l2(p0, p1, range=255.):
return .5*np.mean((p0 / range - p1 / range)**2)

@@ -80,10 +63,6 @@ def rgb2lab(in_img,mean_cent=False):
img_lab[:,:,0] = img_lab[:,:,0]-50
return img_lab

def normalize_blob(in_feat,eps=1e-10):
norm_factor = np.sqrt(np.sum(in_feat**2,axis=1,keepdims=True))
return in_feat/(norm_factor+eps)

def tensor2np(tensor_obj):
# change dimension of a tensor object into a numpy array
return tensor_obj[0].cpu().float().numpy().transpose((1,2,0))
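
For models/__init__.py, the practical effect is that PerceptualLoss.forward always routes through DistModel.forward and keeps the result as a torch tensor (inputs are moved to gpu_ids[0] only when use_gpu is set). A minimal CPU-only usage sketch, assuming this repository is on PYTHONPATH and the bundled linear weights are available locally:

import torch
from models import PerceptualLoss

# two dummy image batches, Nx3xHxW, scaled to [-1,1]
img0 = torch.rand(1, 3, 64, 64) * 2. - 1.
img1 = torch.rand(1, 3, 64, 64) * 2. - 1.

loss_fn = PerceptualLoss(model='net-lin', net='alex', use_gpu=False, spatial=False)
d = loss_fn.forward(img0, img1)   # torch tensor, one distance per image pair

# images in [0,1] can be passed with normalize=True, which rescales them to [-1,1]
d01 = loss_fn.forward((img0 + 1.) / 2., (img1 + 1.) / 2., normalize=True)

With spatial=True, the same call should return a per-pixel distance map produced by the network itself, judging by the resizing code removed from DistModel.forward in this commit.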
1 change: 0 additions & 1 deletion models/base_model.py
@@ -41,7 +41,6 @@ def save_network(self, network, path, network_label, epoch_label):

# helper loading function that can be used by subclasses
def load_network(self, network, network_label, epoch_label):
# embed()
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
save_path = os.path.join(self.save_dir, save_filename)
print('Loading network from %s'%save_path)
78 changes: 12 additions & 66 deletions models/dist_model.py
@@ -2,8 +2,6 @@
from __future__ import absolute_import

import sys
# sys.path.append('..')
# sys.path.append('.')
import numpy as np
import torch
from torch import nn
@@ -28,7 +26,7 @@ def name(self):
return self.model_name

def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None,
use_gpu=True, printNet=False, spatial=False, spatial_shape=None, spatial_order=1, spatial_factor=None,
use_gpu=True, printNet=False, spatial=False,
is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]):
'''
INPUTS
@@ -57,14 +55,12 @@ def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=Fa
self.net = net
self.is_train = is_train
self.spatial = spatial
self.spatial_shape = spatial_shape
self.spatial_order = spatial_order
self.spatial_factor = spatial_factor
self.gpu_ids = gpu_ids

self.model_name = '%s [%s]'%(model,net)
if(self.model == 'net-lin'): # pretrained net + linear layer
self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,
use_dropout=True, spatial=spatial, version=version)
use_dropout=True, spatial=spatial, version=version, lpips=True)
kw = {}
if not use_gpu:
kw['map_location'] = 'cpu'
@@ -77,9 +73,7 @@ def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=Fa
self.net.load_state_dict(torch.load(model_path, **kw), strict=False)

elif(self.model=='net'): # pretrained network
assert not self.spatial, 'spatial argument not supported yet for uncalibrated networks'
self.net = networks.PNet(pnet_type=net)
self.is_fake_net = True
self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False)
elif(self.model in ['L2','l2']):
self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing
self.model_name = 'L2'
@@ -112,61 +106,15 @@ def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=Fa
networks.print_network(self.net)
print('-----------------------------------------------')

def forward_pair(self, in1, in2, retPerLayer=False):
if(retPerLayer):
return self.net.forward(in1, in2, retPerLayer=True)
else:
return self.net.forward(in1, in2)

def forward(self, in0, in1, retNumpy=True):
def forward(self, in0, in1, retPerLayer=False):
''' Function computes the distance between image patches in0 and in1
INPUTS
in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
retNumpy - [False] to return as torch.Tensor, [True] to return as numpy array
OUTPUT
computed distances between in0 and in1
'''

self.input_ref = in0
self.input_p0 = in1

if(self.use_gpu):
self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])

self.var_ref = Variable(self.input_ref,requires_grad=True)
self.var_p0 = Variable(self.input_p0,requires_grad=True)

self.d0 = self.forward_pair(self.var_ref, self.var_p0)
self.loss_total = self.d0

def convert_output(d0):
if(retNumpy):
ans = d0.cpu().data.numpy()
if not self.spatial:
ans = ans.flatten()
else:
assert(ans.shape[0] == 1 and len(ans.shape) == 4)
return ans[0,...].transpose([1, 2, 0]) # Reshape to usual numpy image format: (height, width, channels)
return ans
else:
return d0

if self.spatial:
L = [convert_output(x) for x in self.d0]
spatial_shape = self.spatial_shape
if spatial_shape is None:
if(self.spatial_factor is None):
spatial_shape = (in0.size()[2],in0.size()[3])
else:
spatial_shape = (max([x.shape[0] for x in L])*self.spatial_factor, max([x.shape[1] for x in L])*self.spatial_factor)

L = [skimage.transform.resize(x, spatial_shape, order=self.spatial_order, mode='edge') for x in L]

L = np.mean(np.concatenate(L, 2) * len(L), 2)
return L
else:
return convert_output(self.d0)
return self.net.forward(in0, in1, retPerLayer=retPerLayer)

# ***** TRAINING FUNCTIONS *****
def optimize_parameters(self):
@@ -201,8 +149,8 @@ def forward_train(self): # run forward pass
# print(self.net.module.scaling_layer.shift)
# print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item())

self.d0 = self.forward_pair(self.var_ref, self.var_p0)
self.d1 = self.forward_pair(self.var_ref, self.var_p1)
self.d0 = self.forward(self.var_ref, self.var_p0)
self.d1 = self.forward(self.var_ref, self.var_p1)
self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge)

self.var_judge = Variable(1.*self.input_judge).view(self.d0.size())
@@ -261,8 +209,6 @@ def update_learning_rate(self,nepoch_decay):
print('update lr [%s] decay: %f -> %f' % (type,self.old_lr, lr))
self.old_lr = lr



def score_2afc_dataset(data_loader, func, name=''):
''' Function computes Two Alternative Forced Choice (2AFC) score using
distance function 'func' in dataset 'data_loader'
@@ -287,8 +233,8 @@ def score_2afc_dataset(data_loader, func, name=''):
gts = []

for data in tqdm(data_loader.load_data(), desc=name):
d0s+=func(data['ref'],data['p0']).tolist()
d1s+=func(data['ref'],data['p1']).tolist()
d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist()
d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist()
gts+=data['judge'].cpu().numpy().flatten().tolist()

d0s = np.array(d0s)
@@ -303,7 +249,7 @@ def score_jnd_dataset(data_loader, func, name=''):
INPUTS
data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside
func - callable distance function - calling d=func(in0,in1) should take 2
pytorch tensors with shape Nx3xXxY, and return numpy array of length N
pytorch tensors with shape Nx3xXxY, and return a pytorch tensor of length N
OUTPUTS
[0] - JND score in [0,1], mAP score (area under precision-recall curve)
[1] - dictionary with following elements
@@ -317,7 +263,7 @@ def score_jnd_dataset(data_loader, func, name=''):
gts = []

for data in tqdm(data_loader.load_data(), desc=name):
ds+=func(data['p0'],data['p1']).tolist()
ds+=func(data['p0'],data['p1']).data.cpu().numpy().tolist()
gts+=data['same'].cpu().numpy().flatten().tolist()

sames = np.array(gts)
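
On the evaluation side, score_2afc_dataset and score_jnd_dataset now convert the distances themselves (.data.cpu().numpy()), so the func they are given only needs to return a torch tensor. A sketch of a compatible distance callable, assuming a CPU setup and a data loader variable (twoafc_loader) that is not part of this diff:

from models import dist_model

dm = dist_model.DistModel()
dm.initialize(model='net-lin', net='alex', use_gpu=False)

def dist_fn(in0, in1):
    # returns a torch tensor; the scoring helpers flatten and convert it to numpy
    return dm.forward(in0, in1)

# hypothetical loader; dataset construction is defined elsewhere in the repository
# score, results = dist_model.score_2afc_dataset(twoafc_loader, dist_fn, name='val')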