evaluate_gpu.py
import scipy.io
import torch
import numpy as np
#import time
import os
#######################################################################
# Evaluate
def evaluate(qf,ql,qc,gf,gl,gc):
    query = qf.view(-1,1)  # reshape the query feature into a single column
    # print(query.shape)
    score = torch.mm(gf,query)  # matrix multiplication: cosine similarity to every gallery feature
    score = score.squeeze(1).cpu()
    score = score.numpy()
    # predict index
    index = np.argsort(score)  # sort in ascending order
    index = index[::-1]        # reverse: rank gallery images by similarity, best match first
    # index = index[0:2000]
    # good index
    query_index = np.argwhere(gl==ql)
    camera_index = np.argwhere(gc==qc)
    good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)
    # Note that two kinds of gallery images are not counted as true matches:
    #
    # junk_index1
    #   mis-detected images, mainly containing body parts (label -1).
    # junk_index2
    #   the same person under the same camera; by the re-ID definition,
    #   such images should not be retrieved.
    junk_index1 = np.argwhere(gl==-1)
    junk_index2 = np.intersect1d(query_index, camera_index)
    junk_index = np.append(junk_index2, junk_index1) #.flatten())
    CMC_tmp = compute_mAP(index, good_index, junk_index)
    return CMC_tmp

def compute_mAP(index, good_index, junk_index):
    ap = 0
    cmc = torch.IntTensor(len(index)).zero_()
    if good_index.size==0:   # if empty
        cmc[0] = -1
        return ap,cmc

    # remove junk_index
    mask = np.in1d(index, junk_index, invert=True)
    index = index[mask]

    # find good_index index
    ngood = len(good_index)
    mask = np.in1d(index, good_index)
    rows_good = np.argwhere(mask==True)
    rows_good = rows_good.flatten()

    cmc[rows_good[0]:] = 1
    for i in range(ngood):
        d_recall = 1.0/ngood
        precision = (i+1)*1.0/(rows_good[i]+1)
        if rows_good[i]!=0:
            old_precision = i*1.0/rows_good[i]
        else:
            old_precision = 1.0
        ap = ap + d_recall*(old_precision + precision)/2

    return ap, cmc
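
# Sanity check (not part of the original script): a minimal synthetic example.
# The ranked gallery order is [3, 1, 0, 2, 4]; items 1 and 2 are true matches
# and item 4 is junk. After dropping the junk item the matches land at ranks 2
# and 4, so AP = 0.5*(0 + 1/2)/2 + 0.5*(1/3 + 1/2)/2 = 1/3, and the CMC curve
# is 0 at rank 1 and 1 from rank 2 onward.
_demo_ap, _demo_cmc = compute_mAP(np.array([3, 1, 0, 2, 4]),
                                  np.array([1, 2]),
                                  np.array([4]))
assert abs(_demo_ap - 1.0/3) < 1e-6
assert int(_demo_cmc[0]) == 0 and int(_demo_cmc[1]) == 1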
######################################################################
result = scipy.io.loadmat('pytorch_result.mat')
query_feature = torch.FloatTensor(result['query_f'])
query_cam = result['query_cam'][0]
query_label = result['query_label'][0]
gallery_feature = torch.FloatTensor(result['gallery_f'])
gallery_cam = result['gallery_cam'][0]
gallery_label = result['gallery_label'][0]
multi = os.path.isfile('multi_query.mat')
if multi:
    m_result = scipy.io.loadmat('multi_query.mat')
    mquery_feature = torch.FloatTensor(m_result['mquery_f'])
    mquery_cam = m_result['mquery_cam'][0]
    mquery_label = m_result['mquery_label'][0]
    mquery_feature = mquery_feature.cuda()

query_feature = query_feature.cuda()
gallery_feature = gallery_feature.cuda()
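# Note (not part of the original script): the .cuda() calls above assume a
# CUDA-capable GPU is available. A minimal, hedged fallback sketch would pick
# the device at runtime instead, e.g.:
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   query_feature = query_feature.to(device)
#   gallery_feature = gallery_feature.to(device)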
# len(gallery_label) = 19732: the Market-1501 gallery contains 19732 images in total
CMC = torch.IntTensor(len(gallery_label)).zero_()
print(CMC.shape)
ap = 0.0
#print(query_label)
for i in range(len(query_label)):
    ap_tmp, CMC_tmp = evaluate(query_feature[i],query_label[i],query_cam[i],gallery_feature,gallery_label,gallery_cam)
    if CMC_tmp[0]==-1:
        continue
    CMC = CMC + CMC_tmp
    ap += ap_tmp
    #print(i, CMC_tmp[0])
CMC = CMC.float()
CMC = CMC/len(query_label) #average CMC
print('Rank@1:%f Rank@5:%f Rank@10:%f mAP:%f'%(CMC[0],CMC[4],CMC[9],ap/len(query_label)))
# # multiple-query
# CMC = torch.IntTensor(len(gallery_label)).zero_()
# ap = 0.0
# if multi:
#     for i in range(len(query_label)):
#         mquery_index1 = np.argwhere(mquery_label==query_label[i])
#         mquery_index2 = np.argwhere(mquery_cam==query_cam[i])
#         mquery_index = np.intersect1d(mquery_index1, mquery_index2)
#         mq = torch.mean(mquery_feature[mquery_index,:], dim=0)
#         ap_tmp, CMC_tmp = evaluate(mq,query_label[i],query_cam[i],gallery_feature,gallery_label,gallery_cam)
#         if CMC_tmp[0]==-1:
#             continue
#         CMC = CMC + CMC_tmp
#         ap += ap_tmp
#         #print(i, CMC_tmp[0])
#     CMC = CMC.float()
#     CMC = CMC/len(query_label) #average CMC
#     print('multi Rank@1:%f Rank@5:%f Rank@10:%f mAP:%f'%(CMC[0],CMC[4],CMC[9],ap/len(query_label)))