Skip to content

Commit

Permalink
recall at k origin
Browse files Browse the repository at this point in the history
  • Loading branch information
bnu-wangxun committed May 23, 2018
1 parent e2fffec commit 852c2bf
Show file tree
Hide file tree
Showing 23 changed files with 2,332 additions and 194 deletions.
79 changes: 61 additions & 18 deletions DataSet/In_shop_clothes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import absolute_import, print_function
"""
In-shop-clothes data-set for Pytorch
"""
Expand All @@ -6,15 +7,15 @@
from PIL import Image

import os
from DataSet import transforms
from torchvision import transforms
from collections import defaultdict


def default_loader(path):
    """Load the image at *path* and return it as an RGB PIL image.

    Uses ``Image.open`` as a context manager so the underlying file
    handle is closed promptly; the original left it open until garbage
    collection, which can exhaust file descriptors in large data-sets.
    ``convert('RGB')`` forces the pixel data to be read before the file
    is closed, so the returned image is fully usable.
    """
    with Image.open(path) as img:
        return img.convert('RGB')


class InShopClothes(data.Dataset):
class MyData(data.Dataset):
def __init__(self, root=None, label_txt=None,
transform=None, loader=default_loader):

Expand Down Expand Up @@ -44,6 +45,8 @@ def __init__(self, root=None, label_txt=None,
labels = []

for img_anon in images_anon:
img_anon = img_anon.replace(' ', '\t')

[img, label] = (img_anon.split('\t'))[:2]
images.append(img)
labels.append(int(label))
Expand Down Expand Up @@ -75,23 +78,63 @@ def __len__(self):
return len(self.images)


class InShopClothes:
    """In-shop-clothes retrieval data-set wrapper.

    Builds the ``train``, ``gallery`` and ``query`` splits as ``MyData``
    datasets.  When *transform* is None a default pair is used:
    index 0 is the augmented training transform, index 1 the
    deterministic evaluation transform shared by gallery and query.
    """

    def __init__(self, root=None, transform=None):
        # ImageNet channel statistics expected by pretrained backbones.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        if transform is None:
            train_transform = transforms.Compose([
                # transforms.CovertBGR(),
                transforms.Resize(256),
                transforms.RandomResizedCrop(scale=(0.16, 1), size=224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
            eval_transform = transforms.Compose([
                # transforms.CovertBGR(),
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])
            transform = [train_transform, eval_transform]

        if root is None:
            root = '/opt/intern/users/xunwang/DataSet/In_shop_clothes_retrieval'

        # One label file per split; eval splits share the eval transform.
        self.train = MyData(root, label_txt=os.path.join(root, 'train.txt'),
                            transform=transform[0])
        self.gallery = MyData(root, label_txt=os.path.join(root, 'gallery.txt'),
                              transform=transform[1])
        self.query = MyData(root, label_txt=os.path.join(root, 'query.txt'),
                            transform=transform[1])


def testIn_Shop_Clothes():
    """Smoke test: build the default In-shop-clothes splits and print sizes.

    The previous version first constructed ``InShopClothes`` with a
    ``label_txt`` keyword the class does not accept and then read a
    nonexistent ``Index`` attribute (a stale copy of the MyData demo that
    was already commented out), raising TypeError before the real check
    below ever ran; that dead code is removed.
    """
    data = InShopClothes()
    print(len(data.gallery))
    print(len(data.query))
    print(len(data.train))


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion DataSet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
'cub': CUB200,
'car': Car196,
'product': Products,
'shop': In_shop_clothes,
'shop': InShopClothes,
}


Expand Down
2 changes: 1 addition & 1 deletion evaluations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@

from .cnn import extract_cnn_feature
from .extract_featrure import extract_features, pairwise_distance, pairwise_similarity
from .recall_at_k import Recall_at_ks, Recall_at_ks_products
from .recall_at_k import Recall_at_ks, Recall_at_ks_products, Recall_at_ks_shop
from .NMI import NMI
# from utils import to_torch
15 changes: 9 additions & 6 deletions evaluations/cnn.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
from collections import OrderedDict

from torch.autograd import Variable
# from torch.autograd import Variable
from utils import to_torch
import torch


def extract_cnn_feature(model, inputs, modules=None):
model.eval()
inputs = to_torch(inputs)
inputs = Variable(inputs, volatile=True).cuda()
if modules is None:
outputs = model(inputs)
outputs = outputs.data.cpu()
return outputs
with torch.no_grad():
inputs = inputs.cuda()
if modules is None:
outputs = model(inputs)
outputs = outputs.data.cpu()
return outputs

# Register forward hook for each module
outputs = OrderedDict()
handles = []
Expand Down
45 changes: 29 additions & 16 deletions evaluations/extract_featrure.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,19 +67,32 @@ def pairwise_distance(features, metric=None):
return dist


def pairwise_similarity(features):
n = len(features)
x = torch.cat(features)
x = x.view(n, -1)
# print(4*'\n', x.size())
similarity = torch.mm(x, x.t()) - 1e5 * torch.eye(n)
return similarity

#
# features = torch.round(2*torch.rand(4, 2))
# print(features)
# distmat = pairwise_similarity(features)
# distmat = to_numpy(distmat)
# indices = np.argsort(distmat, axis=1)
# print(distmat)
# print(indices)
def pairwise_similarity(x, y=None):
    """Similarity matrix between two lists of feature tensors.

    Each list is concatenated, flattened to one row per sample and
    L2-normalised before the dot-product.  When *y* is None the
    self-similarity of *x* is returned with the diagonal pushed down by
    1e5 so a sample never retrieves itself.
    """
    n = len(x)
    x = normalize(torch.cat(x).view(n, -1))

    if y is None:
        # Self-retrieval case: mask out the diagonal.
        return torch.mm(x, x.t()) - 1e5 * torch.eye(n)

    m = len(y)
    y = normalize(torch.cat(y).view(m, -1))
    return torch.mm(x, y.t())



112 changes: 41 additions & 71 deletions evaluations/recall_at_k.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,115 +2,85 @@
from __future__ import absolute_import
import heapq
import numpy as np
import random
from utils import to_numpy
import time
import random


def Recall_at_ks(sim_mat, k_s=None, query_ids=None, gallery_ids=None):
    """Compute Recall@K over a query-by-gallery similarity matrix.

    :param sim_mat: (m, n) similarity matrix, query rows vs gallery columns
    :param k_s: increasing list of K values; defaults to [1, 2, 4, 8]
        (the Proxy-NCA evaluation protocol)
    :param query_ids: labels of the m queries; defaults to *gallery_ids*
        (the symmetric retrieval case where queries == gallery)
    :param gallery_ids: labels of the n gallery items; defaults to arange(n)
    :return: numpy array with Recall@K for each K in ``k_s``
    """
    start_time = time.time()
    print(start_time)

    if k_s is None:
        k_s = [1, 2, 4, 8]

    sim_mat = to_numpy(sim_mat)
    m, n = sim_mat.shape

    # Fill up default values.  The None check must precede np.asarray:
    # previously gallery_ids was converted first, so passing no gallery
    # ids produced a 0-d object array and crashed on indexing below.
    if gallery_ids is None:
        gallery_ids = np.arange(n)
    else:
        gallery_ids = np.asarray(gallery_ids)
    if query_ids is None:
        query_ids = gallery_ids
    else:
        query_ids = np.asarray(query_ids)

    # Sub-sample at most num_max queries to bound evaluation time.
    num_max = int(1e4)
    if m > num_max:
        samples = list(range(m))
        random.shuffle(samples)
        samples = samples[:num_max]
        sim_mat = sim_mat[samples, :]
        query_ids = [query_ids[k] for k in samples]
        m = num_max

    # num_valid[j] counts queries with a correct match within top k_s[j].
    num_valid = np.zeros(len(k_s))
    for i in range(m):
        if i % 1000 == 0:
            print(i)
        x = sim_mat[i]
        # Indices of the k_s[-1] most similar gallery items, best first.
        indice = heapq.nlargest(k_s[-1], range(len(x)), x.take)

        if query_ids[i] == gallery_ids[indice[0]]:
            # Hit at rank 1 counts towards every K.
            num_valid += 1
            continue

        # First K-bucket containing a correct match credits that bucket
        # and every larger one.
        for k in range(len(k_s) - 1):
            if query_ids[i] in gallery_ids[indice[k_s[k]: k_s[k + 1]]]:
                num_valid[(k + 1):] += 1
                break
    print(time.time())
    t = time.time() - start_time
    print(t)
    return num_valid / float(m)


def Recall_at_ks_products(sim_mat, query_ids=None, gallery_ids=None):
    """Recall@K for Stanford Online Products: K in [1, 10, 100]."""
    return Recall_at_ks(sim_mat, k_s=[1, 10, 100],
                        query_ids=query_ids, gallery_ids=gallery_ids)

fast computation via heapq

def Recall_at_ks_shop(sim_mat, query_ids=None, gallery_ids=None):
    """Recall@K for In-shop-clothes: K in [1, 10, 20, 30, 40, 50]."""
    return Recall_at_ks(sim_mat, k_s=[1, 10, 20, 30, 40, 50],
                        query_ids=query_ids, gallery_ids=gallery_ids)


def test():
    """Manual check of the Recall@K implementations on random data.

    Fixes the call convention: ``k_s`` is now the second positional
    parameter of ``Recall_at_ks``, so passing the id lists positionally
    silently bound ``query_ids`` to ``k_s``.  The ids are now passed by
    keyword.
    """
    import torch
    sim_mat = torch.rand(int(7e4), int(7 * 400))
    sim_mat = to_numpy(sim_mat)
    query_ids = int(1e4) * list(range(7))
    gallery_ids = int(1e3) * list(range(7))
    print(Recall_at_ks(sim_mat, query_ids=query_ids, gallery_ids=gallery_ids))
    print(Recall_at_ks_shop(sim_mat, query_ids=query_ids,
                            gallery_ids=gallery_ids))


if __name__ == '__main__':
    test()
22 changes: 12 additions & 10 deletions losses/BinDevianceLoss.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,20 +64,22 @@ def forward(self, inputs, targets):
# print(40*'#')
for i, pos_pair in enumerate(pos_sim):
# print(i)
pos_pair = torch.sort(pos_pair)[0]
neg_pair = neg_sim[i]
neg_pair = torch.masked_select(neg_pair, neg_pair > pos_pair[0] - 0.05)
pos_pair = torch.masked_select(pos_pair, pos_pair < base)
# pos_pair = pos_pair[1:]
if len(neg_pair) < 2:
pos_pair_ = torch.sort(pos_pair)[0]
neg_pair_ = torch.sort(neg_sim[i])[0]
neg_pair = torch.masked_select(neg_pair_, neg_pair_ > pos_pair_[0] - 0.05)
pos_pair = torch.masked_select(pos_pair_, pos_pair_ < base)

# for train stability
if len(neg_pair) < 1:
c += 1
# print(len(pos_pair))
# print(len(neg_pair))
neg_pair = neg_pair_[-1]

if len(pos_pair) < 1:
pos_pair = pos_pair_[0]

pos_loss = torch.mean(torch.log(1 + torch.exp(-2*(pos_pair - self.margin))))
neg_loss = (float(2)/self.alpha) * torch.mean(torch.log(1 + torch.exp(self.alpha*(neg_pair - self.margin))))
loss_ = pos_loss + neg_loss
# print(pos_loss)
# print(neg_loss)
loss = loss + loss_

loss = loss/n
Expand Down
Loading

0 comments on commit 852c2bf

Please sign in to comment.