-
Notifications
You must be signed in to change notification settings - Fork 97
/
Copy pathpipeline.py
112 lines (103 loc) · 4.12 KB
/
pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import tomli
import shutil
import os
import argparse
from train import train
from sample import sample
from eval_catboost import train_catboost
from eval_mlp import train_mlp
from eval_simple import train_simple
import pandas as pd
import matplotlib.pyplot as plt
import zero
import lib
import torch
def load_config(path):
    """Parse the TOML file at ``path`` and return the resulting mapping.

    The file is opened in binary mode because ``tomli.load`` expects a
    binary file object.
    """
    with open(path, mode='rb') as stream:
        parsed = tomli.load(stream)
    return parsed
def save_file(parent_dir, config_path):
    """Copy the file at ``config_path`` to the destination path ``parent_dir``.

    Despite its name, ``parent_dir`` is the full destination *file* path
    (callers pass e.g. ``os.path.join(<dir>, 'config.toml')``); the name is
    kept so existing keyword callers keep working. Missing parent directories
    of the destination are created first. Copying a file onto itself is a
    no-op rather than an error.

    Args:
        parent_dir: Destination file path (str or path-like).
        config_path: Path of the source file to copy.
    """
    dst = os.fspath(parent_dir)
    dst_dir = os.path.dirname(dst)
    # A bare filename has no directory component, and os.makedirs('') raises
    # FileNotFoundError, so only create directories when there are some.
    if dst_dir:
        os.makedirs(dst_dir, exist_ok=True)
    try:
        shutil.copyfile(os.path.abspath(config_path), dst)
    except shutil.SameFileError:
        # Source and destination already refer to the same file: nothing to copy.
        pass
def main():
    """Command-line entry point: run the train / sample / eval stages.

    Flags:
        --config FILE   Path to the TOML config (required).
        --train         Call ``train`` with the ``[train]`` / ``diffusion_params``
                        sections of the config.
        --sample        Call ``sample`` using ``<parent_dir>/model.pt``.
        --eval          Call the evaluator named by ``eval.type.eval_model``
                        (``'catboost'``, ``'mlp'`` or ``'simple'``).
        --change_val    Forwarded as ``change_val`` to every stage.

    Regardless of the stage flags, the config file is copied to
    ``<parent_dir>/config.toml`` and ``<real_data_path>/info.json`` is copied
    to ``<parent_dir>/info.json``. The elapsed time measured by
    ``zero.Timer`` is printed at the end.

    Raises:
        ValueError: if ``--eval`` is given and ``eval.type.eval_model`` is not
            one of the supported evaluators.
    """
    parser = argparse.ArgumentParser()
    # Every stage reads the config, so fail fast with a usage message instead
    # of passing None to lib.load_config.
    parser.add_argument('--config', metavar='FILE', required=True)
    parser.add_argument('--train', action='store_true', default=False)
    parser.add_argument('--sample', action='store_true', default=False)
    parser.add_argument('--eval', action='store_true', default=False)
    parser.add_argument('--change_val', action='store_true', default=False)
    args = parser.parse_args()

    raw_config = lib.load_config(args.config)
    # NOTE(review): the fallback is the *second* CUDA device ('cuda:1');
    # set `device` in the config to override it.
    device = torch.device(raw_config.get('device', 'cuda:1'))

    timer = zero.Timer()
    timer.run()

    parent_dir = raw_config['parent_dir']
    save_file(os.path.join(parent_dir, 'config.toml'), args.config)

    if args.train:
        train(
            **raw_config['train']['main'],
            **raw_config['diffusion_params'],
            parent_dir=parent_dir,
            real_data_path=raw_config['real_data_path'],
            model_type=raw_config['model_type'],
            model_params=raw_config['model_params'],
            T_dict=raw_config['train']['T'],
            num_numerical_features=raw_config['num_numerical_features'],
            device=device,
            change_val=args.change_val,
        )

    if args.sample:
        sample_cfg = raw_config['sample']
        sample(
            num_samples=sample_cfg['num_samples'],
            batch_size=sample_cfg['batch_size'],
            disbalance=sample_cfg.get('disbalance', None),
            **raw_config['diffusion_params'],
            parent_dir=parent_dir,
            real_data_path=raw_config['real_data_path'],
            model_path=os.path.join(parent_dir, 'model.pt'),
            model_type=raw_config['model_type'],
            model_params=raw_config['model_params'],
            T_dict=raw_config['train']['T'],
            num_numerical_features=raw_config['num_numerical_features'],
            device=device,
            seed=sample_cfg.get('seed', 0),
            change_val=args.change_val,
        )

    save_file(
        os.path.join(parent_dir, 'info.json'),
        os.path.join(raw_config['real_data_path'], 'info.json'),
    )

    if args.eval:
        eval_cfg = raw_config['eval']
        eval_model = eval_cfg['type']['eval_model']
        evaluators = {
            'catboost': train_catboost,
            'mlp': train_mlp,
            'simple': train_simple,
        }
        if eval_model not in evaluators:
            # Previously an unknown name silently skipped evaluation.
            raise ValueError(
                f"Unknown eval_model {eval_model!r}; "
                f"expected one of {sorted(evaluators)}"
            )
        eval_kwargs = dict(
            parent_dir=parent_dir,
            real_data_path=raw_config['real_data_path'],
            eval_type=eval_cfg['type']['eval_type'],
            T_dict=eval_cfg['T'],
            seed=raw_config['seed'],
            change_val=args.change_val,
        )
        if eval_model == 'mlp':
            # Only train_mlp is given the torch device.
            eval_kwargs['device'] = device
        evaluators[eval_model](**eval_kwargs)

    print(f'Elapsed time: {str(timer)}')
if __name__ == '__main__':
    # Start the pipeline only when this file is executed as a script;
    # importing the module does not run main().
    main()