-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnlvr_fig6.py
58 lines (48 loc) · 1.48 KB
/
nlvr_fig6.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import argparse
from typing import Dict, Any
from evaluation.common import aggregate_without_voting, aggregate_with_voting, build_figure
from evaluation.nlvr import compute_stats
def process_results(results_file: str) -> Dict[str, Any]:
    """Compute the NLVR stats for one results file and aggregate them two ways.

    The stats returned by ``compute_stats`` are reduced first with
    ``aggregate_without_voting`` and then with ``aggregate_with_voting``.

    Args:
        results_file: Passed unchanged to ``compute_stats``.

    Returns:
        A dict with key ``'without_voting'`` (result of
        ``aggregate_without_voting``) followed by key ``'with_voting'``
        (result of ``aggregate_with_voting``).
    """
    stats = compute_stats(results_file)
    # Dict-literal values are evaluated left to right, so the two aggregations
    # still run in the original order: without voting first, then with voting.
    return {
        'without_voting': aggregate_without_voting(stats),
        'with_voting': aggregate_with_voting(stats),
    }
def main():
    """Command-line entry point: build the figure 6 bar plot and save it.

    Positional arguments, in order: the results files for the in-context
    sizes 2, 4, 8 and 12, followed by the output figure path. Each results
    file is run through ``process_results``. The combined mapping, keyed by
    the size as a string ('2', '4', '8', '12' in that order), is passed to
    ``build_figure``, and the returned object's ``savefig`` is called with
    the figure path.
    """
    parser = argparse.ArgumentParser(
        description='Generate the figure 6 bar plot for the given NLVR evaluation results',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # One source of truth for the in-context sizes. The loop declares the
    # positionals in the same order and with the same dest names as the
    # original hand-written declarations, so existing command lines still work.
    in_context_sizes = ('2', '4', '8', '12')
    for size in in_context_sizes:
        parser.add_argument(
            'in_context_{}_results_file'.format(size),
            type=str,
            help='NLVR evaluation results file for the in-context size {} setting'.format(size),
        )
    parser.add_argument(
        'figure_file',
        type=str,
        help='path the generated figure is saved to',
    )
    args = parser.parse_args()
    # Keep insertion order 2, 4, 8, 12 in case build_figure relies on it
    # (e.g. for bar ordering) -- NOTE(review): confirm against build_figure.
    results = {
        size: process_results(getattr(args, 'in_context_{}_results_file'.format(size)))
        for size in in_context_sizes
    }
    fig = build_figure(results)
    fig.savefig(args.figure_file)


if __name__ == '__main__':
    main()