forked from Hanqer/deep-hough-transform
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_sel.py
76 lines (65 loc) · 2.79 KB
/
test_sel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import argparse
import numpy as np
import os
from hungarian_matching import caculate_tp_fp_fn
# Evaluation script: compares predicted lines (one .npy file per image in
# --pred) against ground-truth lines (one .txt file per image in --gt, same
# base name) and reports precision / recall / F-measure over 99 matching
# thresholds (0.01, 0.02, ..., 0.99).  With --align, the companion
# "<name>_align.npy" predictions are evaluated as well.

NUM_THRESHOLDS = 99
ALIGN_SUFFIX = '_align.npy'


def _read_gt_lines(txt_path):
    """Parse a ground-truth text file into a list of 4-integer lines.

    Each non-blank row must hold at least four comma-separated numbers
    ``a, b, c, d``; the row is returned as ``[b, a, d, c]`` (each pair is
    swapped, presumably x/y -> y/x -- confirm against the prediction
    format).  Values are truncated to int via ``int(float(...))``.
    """
    lines = []
    # The context manager closes the file even if parsing fails.
    with open(txt_path) as gt_txt:
        for raw in gt_txt:
            if not raw.strip():
                continue  # tolerate blank / trailing empty rows
            # float() ignores surrounding whitespace, so splitting on ','
            # accepts both "a, b, c, d" and "a,b,c,d".
            c0, c1, c2, c3 = (int(float(v)) for v in raw.split(',')[:4])
            lines.append([c1, c0, c3, c2])
    return lines


def _precision_recall_f(tp, fp, fn):
    """Return per-threshold (precision, recall, f) arrays from count arrays.

    A zero denominator (e.g. no files were evaluated) yields NaN for
    precision/recall, following NumPy division semantics.
    """
    recall = tp / (tp + fn)
    precision = tp / (tp + fp)
    # 1e-6 keeps the F-measure finite when precision + recall is 0.
    f = 2 * recall * precision / (recall + precision + 1e-6)
    return precision, recall, f


parser = argparse.ArgumentParser(description='PyTorch Semantic-Line Training')
parser.add_argument('--pred', type=str, required=True)
parser.add_argument('--gt', type=str, required=True)
parser.add_argument('--align', default=False, action='store_true')
arg = parser.parse_args()

pred_path = arg.pred
gt_path = arg.gt
filenames = sorted(os.listdir(pred_path))

# Counts accumulated over all images; index k corresponds to
# thresh = (k + 1) * 0.01.
total_tp = np.zeros(NUM_THRESHOLDS)
total_fp = np.zeros(NUM_THRESHOLDS)
total_fn = np.zeros(NUM_THRESHOLDS)
total_tp_align = np.zeros(NUM_THRESHOLDS)
total_fp_align = np.zeros(NUM_THRESHOLDS)
total_fn_align = np.zeros(NUM_THRESHOLDS)

for filename in filenames:
    # Only raw prediction files drive the loop; the aligned companions are
    # loaded explicitly below when --align is given.
    if not filename.endswith('.npy') or filename.endswith(ALIGN_SUFFIX):
        continue

    stem = filename.split('.')[0]
    pred = np.load(os.path.join(pred_path, filename))
    if arg.align:
        pred_align = np.load(os.path.join(pred_path, stem + ALIGN_SUFFIX))
    gt = _read_gt_lines(os.path.join(gt_path, stem + '.txt'))

    for i in range(1, NUM_THRESHOLDS + 1):
        thresh = i * 0.01
        # A fresh list is passed on every call, as in the original script,
        # in case the matcher modifies its arguments.
        tp, fp, fn = caculate_tp_fp_fn(pred.tolist(), gt, thresh=thresh)
        total_tp[i - 1] += tp
        total_fp[i - 1] += fp
        total_fn[i - 1] += fn
        if arg.align:
            # `gt` is already a plain list; the previous `gt.tolist()` call
            # raised AttributeError on every --align run.
            tp, fp, fn = caculate_tp_fp_fn(pred_align.tolist(), gt, thresh=thresh)
            total_tp_align[i - 1] += tp
            total_fp_align[i - 1] += fp
            total_fn_align[i - 1] += fn

total_precision, total_recall, f = _precision_recall_f(total_tp, total_fp, total_fn)
if arg.align:
    total_precision_align, total_recall_align, f_align = _precision_recall_f(
        total_tp_align, total_fp_align, total_fn_align)

print('Mean P:', total_precision.mean())
print('Mean R:', total_recall.mean())
print('Mean F:', f.mean())
# Index 94 is the 95th threshold, i.e. thresh = 0.95.
print('F@0.95:', f[94])
np.savetxt('precision.csv', total_precision)
np.savetxt('recall.csv', total_recall)
np.savetxt('fscore.csv', f)

if arg.align:
    print('Mean P_align:', total_precision_align.mean())
    print('Mean R_align:', total_recall_align.mean())
    print('Mean F_align:', f_align.mean())
    print('F@0.95_align:', f_align[94])
    np.savetxt('total_precision_refine.csv', total_precision_align)
    np.savetxt('total_recall_refine.csv', total_recall_align)
    np.savetxt('total_f_refine.csv', f_align)