-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathconfig.py
68 lines (51 loc) · 1.74 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from pprint import pprint
import numpy as np
class Config:
    """Container for the project's settings, stored as class attributes.

    Every attribute whose name does not start with an underscore is treated
    as a setting and is reported by ``_state_dict`` / ``_print_cfg``.
    """

    # lr
    lr = 1e-4

    # warmup setting
    warmup_action = True  # False: not using warmup; True: using warmup
    warmup_lr = 1e-5
    warmup_epoch = 2

    # setting
    image_size = 448  # input image resolution
    keep_ratio = False
    classes = ('ship',)  # a single class still needs the trailing ','
    data_path = r''  # absolute data root path
    output_path = r''  # absolute model output path
    inshore_data_path = r''  # absolute Inshore data path
    offshore_data_path = r''  # absolute Offshore data path

    # training setting
    # NOTE(review): presumably the epoch from which evaluation on the
    # val / train split begins -- confirm against the training loop.
    Evaluate_val_start = 1
    Evaluate_train_start = 1
    save_interval = 2  # save weight file
    val_interval = 2  # check the val result on validation set

    # label assignment (Max IoU Assigner)
    pos_iou_thr = 0.5
    neg_iou_thr = 0.4
    min_pos_iou = 0
    low_quality_match = True  # Low quality match

    # anchor setting
    base_size = 4
    ratios = np.array([0.5, 1., 2.])
    # ratios = np.array([1])
    scales = np.array([2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)])
    # scales = np.array([2 ** 0])

    # reg loss func
    # support method = ['giou', 'smooth_l1', 'l1_loss']
    loss_func = 'smooth_l1'

    # nms setting
    # NMS threshold [nms_thr = 0.5]; presumably paired with an mAP@0.5
    # evaluation -- TODO confirm.
    nms_thr = 0.5
    score_thr = 0.05  # follow mmdet RetinaNet settings [score_thr = 0.05]

    def _state_dict(self):
        """Return the current settings as a ``{name: value}`` dict.

        Names are collected from every class in the instance's MRO, base
        classes first, so settings declared on a subclass are included as
        well. Names starting with an underscore (helpers, dunders) are
        skipped. Values are read through the instance, so a per-instance or
        subclass override is what gets reported.
        """
        # A dict is used as an insertion-ordered set: each name keeps the
        # position of its first (base-most) declaration.
        names = {}
        for klass in reversed(type(self).__mro__):
            for name in vars(klass):
                if not name.startswith('_'):
                    names.setdefault(name, None)
        return {name: getattr(self, name) for name in names}

    def _print_cfg(self):
        """Pretty-print the settings returned by ``_state_dict``."""
        pprint(self._state_dict())
# Module-level settings instance, created when this module is imported.
cfg = Config()

if __name__ == '__main__':
    # Running the file directly dumps every setting, then lists the classes.
    cfg._print_cfg()

    # Check
    class_names = cfg.classes
    print(f'the number of the classes: {len(class_names)}')
    for class_name in class_names:
        print(class_name)