forked from facebook/FAI-PEP
-
Notifications
You must be signed in to change notification settings - Fork 0
/
classification_compare.py
executable file
·156 lines (137 loc) · 5.79 KB
/
classification_compare.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
#!/usr/bin/env python
##############################################################################
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
##############################################################################
# This script compares the benchmark output against the golden labels for
# image classification tasks: a sample counts as correct when its golden
# label index is among the top-k highest-scoring classes of the benchmark.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import argparse
import json
import numpy as np
import os
# Command-line interface; parsed by OutputCompare.__init__.
parser = argparse.ArgumentParser(description="Output compare")
parser.add_argument("--benchmark-output", required=True,
                    help="The output of the benchmark.")
parser.add_argument("--labels", required=True,
                    help="The golden labels file, one 'index,label,path' "
                         "line per sample.")
parser.add_argument("--metric-keyword",
                    help="The keyword prefixed to each printed metric line "
                         "so that the harness can parse it.")
parser.add_argument("--result-file",
                    help="Write the prediction result to a file.")
parser.add_argument("--top", type=int, default=1,
                    help="Number k of highest-scoring classes the golden "
                         "label must appear in to count as correct "
                         "(e.g. 1 for top-1, 5 for top-5).")
parser.add_argument("--name", required=True,
                    help="Specify the type of the metric.")
class OutputCompare(object):
    """Score image-classification benchmark output against golden labels.

    Construction parses the command line (module-level ``parser``);
    ``compare()`` then reads both input files, marks every labeled sample as
    a correct or incorrect top-k prediction, and reports the totals.
    """

    def __init__(self):
        """Parse command-line arguments and check that both input files exist.

        Raises:
            AssertionError: if the benchmark output or the labels path is not
                an existing file.
        """
        self.args = parser.parse_args()
        assert os.path.isfile(self.args.benchmark_output), \
            "Benchmark output file {} doesn't exist".format(
                self.args.benchmark_output)
        assert os.path.isfile(self.args.labels), \
            "Labels file {} doesn't exist".format(
                self.args.labels)

    def getData(self, filename):
        """Read a benchmark output file into a 2-D list of scores.

        The file is a sequence of records. Each record is a comma-separated
        line of integer dimensions (identical for every record) followed by
        a comma-separated line of float values. Blank lines between or after
        records are ignored.

        Args:
            filename: path of the benchmark output file.

        Returns:
            A tuple ``(rows, dims_list)``. ``rows`` is all values reshaped to
            ``(-1, last_dim)`` as nested lists. ``dims_list`` is
            ``[num_records] + record_dims``.

        Raises:
            AssertionError: if the file holds no records, or if the dimension
                lines differ between records.
        """
        num_entries = 0
        content_list = []
        dim_str = None
        dims_list = []
        with open(filename, "r") as f:
            line = f.readline()
            while line != "":
                if not line.strip():
                    # Ignore blank lines (e.g. a trailing newline at EOF)
                    # instead of misreporting them as a dimension mismatch.
                    line = f.readline()
                    continue
                if dim_str is None:
                    # The first dimension line fixes the shape of every record.
                    dim_str = line.strip()
                    dims_list = [int(dim.strip())
                                 for dim in dim_str.split(',')]
                assert line.strip() == dim_str, \
                    "The dimensions do not match"
                num_entries += 1
                line = f.readline().strip()
                content_list.extend([float(entry.strip())
                                     for entry in line.split(',')])
                line = f.readline()
        assert num_entries > 0, \
            "Benchmark output file {} contains no data".format(filename)
        dims_list.insert(0, num_entries)
        data = np.reshape(np.asarray(content_list), dims_list)
        # Collapse all leading dimensions so each row holds one vector of
        # scores, i.e. shape (-1, last_dim).
        benchmark_data = data.reshape((-1, data.shape[-1]))
        return benchmark_data.tolist(), dims_list

    def writeOneResult(self, values, data, metric, unit):
        """Print one metric as a JSON line and return it as a dict.

        Every percentile and the mean in the summary are set to the single
        aggregate ``data``.

        Args:
            values: per-sample values; ``num_runs`` is their count.
            data: the aggregate value of the metric.
            metric: metric name stored under ``"metric"``.
            unit: unit name stored under ``"unit"``.

        Returns:
            The entry dict that was printed.
        """
        entry = {
            "type": self.args.name,
            "values": values,
            "summary": {
                "num_runs": len(values),
                "p0": data,
                "p10": data,
                "p50": data,
                "p90": data,
                "p100": data,
                "mean": data,
            },
            "unit": unit,
            "metric": metric,
        }
        s = json.dumps(entry, sort_keys=True)
        if self.args.metric_keyword:
            # Prefix the line so the harness can pick it out of stdout.
            s = self.args.metric_keyword + " " + s
        print(s)
        return entry

    def writeResult(self, results):
        """Report the number and the percentage of correct top-k predictions.

        Each metric is printed via ``writeOneResult``. When ``--result-file``
        is given, both metrics are also written to it as one JSON object,
        keyed by ``"<name>_<metric>"``.

        Args:
            results: golden entries, each carrying a 0/1 ``"predict"`` flag.
        """
        top = "top{}".format(self.args.top)
        values = [item["predict"] for item in results]
        num_corrects = sum(values)
        percent = num_corrects * 100. / len(values)
        output = {}
        for data, metric, unit in (
                (num_corrects, "number_of_{}_corrects".format(top), "number"),
                (percent, "percent_of_{}_corrects".format(top), "percent")):
            res = self.writeOneResult(values, data, metric, unit)
            output[res["type"] + "_" + res["metric"]] = res
        if self.args.result_file:
            with open(self.args.result_file, "w") as f:
                f.write(json.dumps(output, sort_keys=True, indent=2))

    def compare(self):
        """Mark each labeled sample as a correct or incorrect top-k prediction.

        The labels file has one ``index,label,path`` line per sample. A sample
        is correct (``"predict"`` = 1) when its golden index is among the
        ``--top`` highest-scoring classes of its score row. Among equal
        scores, the lower class index ranks first. Results are reported
        through ``writeResult``.

        Raises:
            AssertionError: if the score rows cannot be matched one-to-one
                with the labels.
        """
        benchmark_data, dims_list = self.getData(self.args.benchmark_output)
        with open(self.args.labels, "r") as f:
            content = f.read()
        golden_lines = [item.strip().split(',')
                        for item in content.strip().split('\n')]
        golden_data = [{"index": int(item[0]),
                        "label": item[1],
                        "path": item[2]} for item in golden_lines]
        num_golden = len(golden_data)
        if len(benchmark_data) != num_golden:
            # The samples may sit in an inner dimension. Regroup the scores
            # as (dims[idx], dims[idx + 1]) around the first non-last
            # dimension equal to the label count, but only when the element
            # count fits. Otherwise leave the data alone so the assertion
            # below reports the mismatch clearly.
            total = np.size(benchmark_data)
            for idx in range(len(dims_list) - 1):
                if dims_list[idx] == num_golden:
                    if dims_list[idx] * dims_list[idx + 1] == total:
                        benchmark_data = np.reshape(
                            benchmark_data,
                            (dims_list[idx], dims_list[idx + 1]))
                    break
        assert len(benchmark_data) == num_golden, \
            "Benchmark data has {} entries, ".format(len(benchmark_data)) + \
            "but golden data has {} entries".format(num_golden)
        top = self.args.top
        for scores, golden in zip(benchmark_data, golden_data):
            # Class indices by descending score. sorted() stays stable with
            # reverse=True, so tied scores keep ascending index order.
            ranked = sorted(range(len(scores)),
                            key=lambda j: scores[j], reverse=True)
            golden["predict"] = 1 if golden["index"] in ranked[:top] else 0
        self.writeResult(golden_data)
if __name__ == "__main__":
    # Parse the command line, then score the benchmark output.
    OutputCompare().compare()