-
Notifications
You must be signed in to change notification settings - Fork 10
/
train.py
144 lines (121 loc) · 4.01 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import shutil
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import argparse
import os.path
import datetime
import yaml
from core.trainer import Trainer
def _abort(message='Check the log dir.'):
    """Print *message* and terminate the script with exit status 1.

    Replaces the original ``quit()`` calls: ``quit`` is injected by the
    ``site`` module (absent under ``python -S`` and in some frozen builds)
    and exits with status 0, which hides the failure from calling scripts.
    """
    print(message)
    raise SystemExit(1)


def _confirm(prompt):
    """Print *prompt*, read one line from stdin, and return True iff it is 'y'/'Y'."""
    print(prompt)
    return input() in ('Y', 'y')


def _load_yaml(path, what):
    """Load the YAML file at *path* with ``yaml.safe_load`` and return the result.

    On an I/O or YAML parse error, print the error followed by
    "Error opening <what> yaml file." and abort.
    """
    try:
        # `with` guarantees the handle is closed; the original never closed it.
        with open(path, 'r') as stream:
            return yaml.safe_load(stream)
    except (OSError, yaml.YAMLError) as e:
        print(e)
        _abort(f"Error opening {what} yaml file.")


def _prepare_log_dir(args):
    """Ensure ``args.log`` names an existing directory to write logs into.

    * ``args.log is None``: offer ``<script dir>/log/<timestamp>/`` and, on
      confirmation, create it and store it in ``args.log``. Abort if declined
      or if that directory already exists (same-minute rerun).
    * ``args.log`` is a non-empty directory: on confirmation its contents are
      DELETED (``shutil.rmtree``) and the directory is recreated.
    * ``args.log`` is an empty directory: used as-is.
    * ``args.log`` does not exist: on confirmation it is created, including
      any missing parent directories.
    * ``args.log`` exists but is not a directory: abort.
    """
    if args.log is None:
        root = os.path.dirname(os.path.abspath(__file__))
        # '%m' / '-' instead of '%-m' / ':' — '%-m' is a glibc-only strftime
        # extension (ValueError on Windows) and ':' is illegal in Windows paths.
        stamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M')
        # Trailing '' keeps the trailing separator the original path had.
        default = os.path.join(root, 'log', stamp, '')
        print(default)
        if not _confirm(f'The log dir [{args.log}] doesn\'t exist, Do you want to use '
                        f'[{default}] as default? [y/n]'):
            _abort()
        if os.path.exists(default):
            # A run started within the same minute already owns this dir.
            _abort()
        os.makedirs(default)
        args.log = default
        return

    if os.path.isdir(args.log):
        if os.listdir(args.log):
            if not _confirm(f'Log dir [{args.log}] is not empty, Do you want to proceed? [y/n]'):
                _abort()
            # Destructive: wipes everything previously stored in the log dir.
            shutil.rmtree(args.log)
            os.mkdir(args.log)
        return

    if os.path.exists(args.log):
        # Previously this fell through to os.mkdir and crashed with FileExistsError.
        _abort(f'Log path [{args.log}] exists but is not a directory.')
    if not _confirm(f'Using the dir [{args.log}] to contains the log? [y/n]'):
        _abort()
    os.makedirs(args.log)


def main(args):
    """Validate inputs, prepare the log directory, then build and run a Trainer.

    Args:
        args: parsed command-line namespace. Reads ``dataset``, ``log``,
            ``model_cfg``, ``data_cfg`` and ``checkpoint``. ``args.log`` may be
            replaced with an auto-generated directory path by
            ``_prepare_log_dir``. The whole namespace is passed to ``Trainer``.

    Raises:
        FileNotFoundError: if the dataset dir or either config file is missing.
        SystemExit: if a config file cannot be loaded or the user declines a
            log-directory prompt.
    """
    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    required = (
        ('dataset dir', args.dataset),
        ('model config', args.model_cfg),
        ('dataset config', args.data_cfg),
    )
    for what, path in required:
        if not os.path.exists(path):
            raise FileNotFoundError(f"The {what} [{path}] doesn't exist.")

    print('-----------------')
    print(f'dataset dir: {args.dataset}')
    print(f'log dir: {args.log}')
    print(f'model config: {args.model_cfg}')
    print(f'dataset config: {args.data_cfg}')
    print(f'checkpoint: {args.checkpoint}')
    print('-----------------')

    data_cfg = _load_yaml(args.data_cfg, 'data')
    model_cfg = _load_yaml(args.model_cfg, 'model')

    _prepare_log_dir(args)

    if args.checkpoint is not None:
        if os.path.isfile(args.checkpoint) and args.checkpoint.endswith('.ckpt'):
            print(f'Using the pretrained model:[{args.checkpoint}]')
        else:
            # The original silently ignored a bad checkpoint path; at least say so.
            print(f'Warning: checkpoint [{args.checkpoint}] is not an existing .ckpt file; '
                  'it is still passed to the Trainer unchanged.')

    trainer = Trainer(args=args, model_cfg=model_cfg, data_cfg=data_cfg)
    trainer.train()
def build_arg_parser():
    """Return the command-line parser for this training script.

    Default config paths are resolved relative to the directory containing
    this file (``config/model.yaml`` and ``config/semantic-kitti.yaml``).
    """
    root = os.path.dirname(os.path.abspath(__file__))
    # `description=` is keyword-only in intent: the first positional argument of
    # ArgumentParser is `prog`, which previously replaced the usage-line name.
    parser = argparse.ArgumentParser(description='Training model.')
    parser.add_argument(
        '--dataset', '-d',
        type=str, required=True,
        help='the root dir of datasets',
    )
    parser.add_argument(
        '--checkpoint', '-ckpt',
        type=str,
        default=None,
        help='the path for loading checkpoint',
    )
    parser.add_argument(
        '--log', '-l',
        type=str,
        default=None,
        help='the dir to save log',
    )
    parser.add_argument(
        '--model_cfg', '-m',
        type=str,
        default=os.path.join(root, 'config', 'model.yaml'),
        help='the config of model',
    )
    parser.add_argument(
        '--data_cfg', '-dc',
        type=str,
        default=os.path.join(root, 'config', 'semantic-kitti.yaml'),
        help='the config of dataset',
    )
    parser.add_argument(
        '--freeze_layers',
        action='store_true',
        default=False,
    )
    parser.add_argument(
        '--device',
        default='cpu',
        help='device id (i.e. 0 or 0,1 or cuda)',
    )
    return parser


if __name__ == '__main__':
    args, unparsed = build_arg_parser().parse_known_args()
    if unparsed:
        # parse_known_args drops unknown flags silently; surface them so a typo
        # such as '--modle_cfg' does not quietly fall back to a default.
        print(f'Warning: ignoring unrecognized arguments: {unparsed}')
    main(args)