-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun.py
160 lines (132 loc) · 4.87 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import sys
from itertools import product
import argparse
sys.path.insert(0, 'src')
from util import *
from data import Data
from models import *
def parse() -> argparse.Namespace:
    """Parse this script's command-line arguments from ``sys.argv``.

    Returns:
        Namespace with the attributes ``target`` (one of the run-target
        choices), ``data`` (dataset path), ``model`` (list of selected model
        names, empty when ``-m`` was not given), and the boolean flags
        ``stem``, ``output`` and ``plot``.
    """
    parser = argparse.ArgumentParser(
        description="Run script using configuration defined in `config/`"
    )
    # A positional without `nargs` is always required by argparse, so a
    # `default=` here would never take effect; the help text says so instead
    # of advertising a default that cannot be triggered.
    parser.add_argument(
        "target",
        choices=['test', 'experiment', 'exp', 'hyperparameter', 'ht'],
        type=str,
        help="run target (required); if test is selected, ignore all other flags."
    )
    parser.add_argument(
        "-d", "--data", type=str, help="data path", default='nyt/coarse'
    )
    # `extend` + `nargs='+'` lets `-m` be repeated and still yield a flat list.
    parser.add_argument(
        "-m", "--model", type=str, nargs='+', choices=['tfidf', 'w2v'],
        help="models to run", action='extend', default=[]
    )
    parser.add_argument(
        "-s", "--stem", action='store_true', help="only used in experiments"
    )
    # This is a boolean switch, not a path: the result files are fixed
    # locations under `results/`.
    parser.add_argument(
        "-o", "--output", action='store_true',
        help="write results to `results/`, only used in experiments"
    )
    parser.add_argument(
        "-p", "--plot", action='store_true',
        help="visualize document length distribution, only used in experiments"
    )
    return parser.parse_args()
def test() -> None:
    """Run the test target.

    Loads the ``testdata`` dataset from ``test/`` and runs the TF-IDF model
    followed by the Word2Vec model, the latter configured from
    ``config/test_w2v_config.json``.
    """
    print('Running Test Data Target:')
    data_kwargs = {
        'data_dir': 'test/',
        'dataset': 'testdata',
        'stem': False,
        'special_tokens': True,
    }
    test_data = Data(**data_kwargs)

    # TF-IDF runs without extra configuration on the test target.
    Tfidf_Model.run(test_data)

    # Word2Vec takes its keyword arguments from the test config file.
    w2v_kwargs = load_config('config/test_w2v_config.json')
    Word2Vec_Model.run(test_data, **w2v_kwargs)
def experiment(dataset: str, stem, models: list, viz=False, output=None) -> None:
    """Run the experiment target on one dataset.

    Loads ``dataset`` from ``data/`` and runs each requested model with the
    keyword arguments read from its ``config/exp_*_config.json`` file.

    Args:
        dataset: dataset name/path forwarded to ``Data``.
        stem: forwarded to ``Data`` as its ``stem`` argument.
        models: names of the models to run; ``'tfidf'`` and ``'w2v'`` are
            recognised, anything else is ignored.
        viz: when true, call ``plot_distribution()`` on the loaded data
            before running the models.
        output: whether to write each model's result under ``results/``.
            When left as ``None`` the value is taken from the ``output``
            attribute of the module-level ``args`` namespace if the script
            entry point defined one, and is ``False`` otherwise.
    """
    if output is None:
        # The entry point stores the parsed CLI namespace in a module global
        # named `args`. Look it up defensively so that importing this module
        # and calling `experiment` directly does not raise NameError.
        cli_args = globals().get('args')
        output = bool(getattr(cli_args, 'output', False))

    print('Running Experiment Target:')
    d = Data(
        data_dir='data/',
        dataset=dataset,
        stem=stem,
        special_tokens=True
    )
    if viz:
        d.plot_distribution()

    # run TF-IDF Model
    if 'tfidf' in models:
        config = load_config('config/exp_tfidf_config.json')
        tfidf_result = Tfidf_Model.run(d, **config)
        if output:
            write_result('results/tfidf_runs.json', **tfidf_result)

    # run Word2Vec Model
    if 'w2v' in models:
        config = load_config('config/exp_w2v_config.json')
        # Word2Vec_Model.run returns a pair; only the first item (the result
        # mapping) is needed here.
        w2v_result, _ = Word2Vec_Model.run(d, **config)
        if output:
            write_result('results/w2v_runs.json', **w2v_result)
def tune(dataset: str, models: list) -> None:
    """Run the hyperparameter-tuning target on one dataset.

    For each requested model, the matching ``config/ht_*_config.json`` file is
    read as a mapping from parameter name to a list of candidate values, and
    the model is run once per combination of those values. The per-run
    results are flattened, accumulated with ``combine_dict`` and written to
    ``results/``.

    For Word2Vec, the run with the highest ``macro_f1`` is additionally
    exported: the model and its parameters are saved under ``models/new_ht/``.

    Args:
        dataset: dataset name/path forwarded to ``Data``.
        models: names of the models to tune; ``'tfidf'`` and ``'w2v'`` are
            recognised, anything else is ignored.
    """
    # Imported locally so this function does not depend on `json` leaking in
    # through one of the file's star imports.
    import json
    import os

    print('Running Hyparameter Target:')

    # run TF-IDF Model
    if 'tfidf' in models:
        config = load_config('config/ht_tfidf_config.json')
        tfidf_result = dict()
        keys, values = zip(*config.items())
        for bundle in product(*values):
            # load
            one_config = dict(zip(keys, bundle))
            # The data is rebuilt for every combination because `stem` is one
            # of the tuned parameters.
            d = Data(
                data_dir='data/',
                dataset=dataset,
                stem=one_config['stem'],
                special_tokens=True
            )
            result = Tfidf_Model.run(d, **one_config)
            del d  # drop the dataset before the next one is loaded
            one_tfidf_result = flatten_dict(result)
            combine_dict(tfidf_result, one_tfidf_result)
        write_result('results/tfidf_ht.json', **tfidf_result)

    # run Word2Vec Model
    if 'w2v' in models:
        config = load_config('config/ht_w2v_config.json')
        w2v_result = dict()
        # Start below any real score so the first completed run always becomes
        # the initial best, even when its macro-F1 is 0.
        best_model, best_macro, best_config = None, float('-inf'), None
        keys, values = zip(*config.items())
        for bundle in product(*values):
            # load
            one_config = {'model_params': dict(zip(keys, bundle))}
            d = Data(
                data_dir='data/',
                dataset=dataset,
                stem=dataset != '20news/fine',
                special_tokens=True
            )
            one_w2v_result, model = Word2Vec_Model.run(d, **one_config)
            del d  # drop the dataset before the next one is loaded
            one_w2v_result = flatten_dict(one_w2v_result)
            combine_dict(w2v_result, one_w2v_result)
            # check best
            if one_w2v_result['macro_f1'] > best_macro:
                best_model, best_macro, best_config = (
                    model, one_w2v_result['macro_f1'], one_config['model_params']
                )

        if best_model is None:
            # Only reachable when the grid produced no combinations at all.
            print('No Word2Vec run completed; skipping best-model export.')
        else:
            out_dir = 'models/new_ht'
            os.makedirs(out_dir, exist_ok=True)
            best_str = f'best_{dataset.replace("/", "_")}'
            best_model.save(f'{out_dir}/{best_str}.model')
            with open(f'{out_dir}/{best_str}_config.json', 'w') as f:
                json.dump(best_config, f)
        write_result('results/w2v_ht.json', **w2v_result)
if __name__ == "__main__":
    # The parsed namespace is deliberately kept in the module-level name
    # `args`: `experiment` reads it to decide whether to write results.
    args = parse()

    if args.target == 'test':
        # The test target ignores every other flag.
        test()
    else:
        # No -m given means "run everything"; otherwise drop duplicate names.
        models = list(set(args.model)) if args.model else ['tfidf', 'w2v']

        if args.target in ('experiment', 'exp'):
            experiment(args.data, args.stem, models, args.plot)
        elif args.target in ('hyperparameter', 'ht'):
            tune(args.data, models)