-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfig.py
64 lines (51 loc) · 1.85 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from yacs.config import CfgNode as CN
# ---------------------------------------------------------------------------
# Default configuration tree (yacs CfgNode).
# ---------------------------------------------------------------------------
_C = CN()

# --- General switches ------------------------------------------------------
_C.SEED = 0
_C.SAVE_MODEL = False
_C.SAVE_STAT = False
_C.UNCERTAINTY = 'variance'  # options: variance / entropy

# --- Data loading ----------------------------------------------------------
_C.DATALOADER = CN({
    'BATCH_SIZE': 32,
    'PIN_MEMORY': True,
    'NUM_WORKERS': 4,
})

# --- Loss ------------------------------------------------------------------
# LOSS.LAMBDA_EU is not set here; it is filled in by the dataset-specific
# step further down because its value depends on the dataset.
_C.LOSS = CN({
    'FUNCTION': 'nll',    # options: nll / ce / sos
    'ACTIVATION': 'exp',  # options: relu / exp / softplus
    'LAMBDA_AU': 0.05,
})

# --- Uncertainty sampling --------------------------------------------------
# EPOCHS and RATIO both hold five entries.
# NOTE(review): presumably one ratio per listed epoch -- confirm with the
# code that consumes them.
_C.UNCERTAINTY_SAMPLING = CN({
    'EPOCHS': [10, 12, 14, 16, 18],
    'ORDER': 'EU AU',
    'RATIO': [0.01, 0.01, 0.01, 0.01, 0.01],
    'KAPPA': 10,
})

# --- Certainty sampling ----------------------------------------------------
_C.CERTAINTY_SAMPLING = CN({
    'ORDER': 'EU',
    'RATIO': [0.01, 0.02, 0.03, 0.04, 0.05],
})

# --- Dataset / paths / trainer ---------------------------------------------
# PATHS and TRAINER start empty; their entries depend on the dataset name.
_C.DATASET = CN({
    'NAME': 'Office-Home',  # options: Office-Home / Visda-2017
})
_C.PATHS = CN()
_C.TRAINER = CN()


def _apply_dataset_defaults(node):
    """Fill in the dataset-dependent entries of ``node`` in place.

    Reads ``node.DATASET.NAME`` and ``node.UNCERTAINTY`` and writes the
    matching values into ``node.PATHS``, ``node.DATASET``, ``node.TRAINER``
    and ``node.LOSS.LAMBDA_EU``.

    Raises:
        NotImplementedError: if ``node.DATASET.NAME`` is neither
            ``'Office-Home'`` nor ``'Visda-2017'``.
    """
    dataset = node.DATASET.NAME

    # Guard clause: reject anything other than the two known datasets.
    if dataset not in ('Office-Home', 'Visda-2017'):
        raise NotImplementedError(f'Dataset not implemented: {dataset}')

    # LAMBDA_EU takes one value for 'entropy' and another for anything else.
    entropy_based = node.UNCERTAINTY == 'entropy'

    if dataset == 'Office-Home':
        node.PATHS.DATA_DIR = 'C:/Users/kerry/OneDrive/Desktop/Projects/Datasets/OfficeHomeDataset_10072016'
        node.PATHS.OUTPUT_DIR = 'outputs'

        node.DATASET.NUM_CLASSES = 65
        node.DATASET.SOURCE_DOMAINS = ['Art', 'Clipart', 'Product', 'Real World']
        node.DATASET.TARGET_DOMAINS = ['Art', 'Clipart', 'Product', 'Real World']

        node.TRAINER.LR = 4e-3
        node.TRAINER.MAX_EPOCHS = 50
        node.TRAINER.EVAL_INTERVAL = 5

        if entropy_based:
            node.LOSS.LAMBDA_EU = 1.0
        else:
            node.LOSS.LAMBDA_EU = 50.0
        return

    # Only 'Visda-2017' reaches this point.
    # NOTE(review): the two paths below are Ellipsis placeholders, presumably
    # meant to be replaced with real path strings before this dataset is
    # selected -- confirm.
    node.PATHS.DATA_DIR = ...
    node.PATHS.OUTPUT_DIR = ...

    node.DATASET.NUM_CLASSES = 12
    node.DATASET.SOURCE_DOMAINS = ['train']
    node.DATASET.TARGET_DOMAINS = ['validation']

    node.TRAINER.LR = 1e-3
    node.TRAINER.MAX_EPOCHS = 40
    node.TRAINER.EVAL_INTERVAL = 2

    if entropy_based:
        node.LOSS.LAMBDA_EU = 1.0
    else:
        node.LOSS.LAMBDA_EU = 10.0


# The dataset-specific values are resolved once, here, from the DATASET.NAME
# and UNCERTAINTY values assigned above.
# NOTE(review): as far as this file shows, changing either of those keys later
# does not re-run this selection -- verify against callers.
_apply_dataset_defaults(_C)
def get_cfg_defaults():
    """Return a clone of the module-level default config ``_C``.

    The result comes from ``_C.clone()``, so the caller receives its own
    object instead of a reference to ``_C`` itself.
    """
    defaults = _C.clone()
    return defaults
# Module-level alias: `cfg` is the very same object as `_C`, not a clone.
cfg = _C