forked from jennyzhang0215/STAR-GCN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfig.py
106 lines (90 loc) · 3.81 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import numpy as np
import os
import yaml
import logging
from collections import OrderedDict, namedtuple
from mxgraph.helpers.ordered_easydict import OrderedEasyDict as edict
def _merge_two_config(user_cfg, default_cfg):
    """Recursively merge ``user_cfg`` into ``default_cfg`` in place.

    Every value in ``user_cfg`` clobbers the entry under the same key in
    ``default_cfg``. Nested ``edict`` values are merged recursively instead of
    replacing the whole sub-dictionary.

    Type checking applies when the default value is not ``None`` and its type
    differs from the user value:
      * ndarray defaults: the user value is converted with ``np.array`` using
        the default's dtype;
      * int/float mixes: accepted unchanged (note ``bool`` is a subclass of
        ``int``, so bool/int mixes are accepted too);
      * anything else: ``ValueError``.

    Args:
        user_cfg: overriding config. If it is not exactly an ``edict`` the call
            is a no-op.
        default_cfg: mapping that is updated in place.

    Raises:
        KeyError: if ``user_cfg`` contains a key absent from ``default_cfg``.
        ValueError: on a type mismatch as described above.
    """
    if type(user_cfg) is not edict:
        return
    for key, val in user_cfg.items():
        # user_cfg may only override keys that already exist in default_cfg.
        if key not in default_cfg:
            raise KeyError('{} is not a valid config key'.format(key))
        default_val = default_cfg[key]
        if default_val is not None and type(default_val) is not type(val):
            if isinstance(default_val, np.ndarray):
                val = np.array(val, dtype=default_val.dtype)
            elif isinstance(default_val, (int, float)) and isinstance(val, (int, float)):
                pass
            else:
                raise ValueError(
                    'Type mismatch ({} vs. {}) '
                    'for config key: {}'.format(type(default_val),
                                                type(val), key))
        # Recurse only when there is a default sub-config to merge into; a
        # None default is simply replaced (recursing into None would fail).
        if type(val) is edict and default_val is not None:
            try:
                _merge_two_config(val, default_val)
            except Exception:
                # Report which nested key failed, then propagate unchanged.
                print('Error under config key: {}'.format(key))
                raise
        else:
            default_cfg[key] = val
def cfg_from_file(file_name, target):
    """Load a YAML config file and merge it into ``target`` in place.

    Args:
        file_name: path of the YAML file to read.
        target: config dictionary that receives the file's values (see
            ``_merge_two_config`` for the merge and type-checking rules).
    """
    with open(file_name, 'r') as f:
        print('Loading YAML config file from %s' % file_name)
        # safe_load: config files are plain data. yaml.load without an explicit
        # Loader is unsafe on untrusted input and raises TypeError on PyYAML>=6.
        # An empty file parses to None; treat it as "no overrides".
        yaml_cfg = edict(yaml.safe_load(f) or {})
    _merge_two_config(yaml_cfg, target)
def ordered_dump(data, stream=None, Dumper=yaml.SafeDumper, **kwds):
    """Serialize ``data`` to YAML while preserving mapping key order.

    A throwaway subclass of ``Dumper`` is created per call, so the
    representers registered here do not leak into ``Dumper`` itself.
    ``OrderedDict`` and ``edict`` instances are emitted as block-style
    mappings in their iteration order; numpy arrays are emitted as lists.

    Returns:
        Whatever ``yaml.dump`` returns for the given ``stream``.
    """
    class OrderedDumper(Dumper):
        pass

    def _as_mapping(dumper, mapping):
        # Feed items() directly so the dumper keeps insertion order.
        return dumper.represent_mapping(
            yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG,
            mapping.items(),
            flow_style=False,
        )

    def _as_list(dumper, array):
        return dumper.represent_list(array.tolist())

    representers = (
        (OrderedDict, _as_mapping),
        (edict, _as_mapping),
        (np.ndarray, _as_list),
    )
    for data_type, representer in representers:
        OrderedDumper.add_representer(data_type, representer)
    return yaml.dump(data, stream, OrderedDumper, **kwds)
def save_cfg_file(file_path, source):
    """Write a copy of ``source`` to ``file_path`` as YAML.

    The keys ``DATASET_PATH`` and ``ROOT_DIR`` are removed from the copy
    (both the mapping entry and the attribute) before dumping; the object
    passed in is not modified. These keys presumably hold machine-local
    paths — confirm against the config template.
    """
    stripped = source.copy()
    for masked in ('DATASET_PATH', 'ROOT_DIR'):
        if masked not in stripped:
            continue
        # The edict keeps the value both as an item and as an attribute,
        # so both need removing.
        del stripped[masked]
        delattr(stripped, masked)
    with open(file_path, 'w') as out:
        logging.info("Save YAML config file to %s" % file_path)
        ordered_dump(stripped, out, yaml.SafeDumper, default_flow_style=None)
def save_cfg_dir(dir_path, source):
    """Save ``source`` as ``cfg<N>.yml`` inside ``dir_path``.

    ``N`` is the smallest non-negative integer for which that file does not
    exist yet. The existence check and the write are separate steps, so this
    is not atomic with respect to other writers.

    Returns:
        The index ``N`` that was used.
    """
    def _candidate(index):
        return os.path.join(dir_path, 'cfg%d.yml' % index)

    cfg_count = 0
    while os.path.exists(_candidate(cfg_count)):
        cfg_count += 1
    save_cfg_file(_candidate(cfg_count), source)
    return cfg_count
def load_latest_cfg(dir_path, target):
    """Merge the highest-numbered ``cfg<N>.yml`` in ``dir_path`` into ``target``.

    Only file names that are exactly ``cfg<digits>.yml`` are considered, so
    names like ``mycfg3.yml`` or ``cfg3.yml.bak`` are ignored.

    Args:
        dir_path: directory to scan (non-recursively).
        target: config dictionary updated in place via ``cfg_from_file``.

    Raises:
        FileNotFoundError: if no ``cfg<N>.yml`` file exists in ``dir_path``.
    """
    import re
    cfg_name = re.compile(r'cfg(\d+)\.yml')
    cfg_count = None
    source_cfg_path = None
    for fname in os.listdir(dir_path):
        ret = cfg_name.fullmatch(fname)
        if ret is None:
            continue
        index = int(ret.group(1))
        if cfg_count is None or index > cfg_count:
            cfg_count = index
            # Join the actual listed name, not a substring of it.
            source_cfg_path = os.path.join(dir_path, fname)
    if source_cfg_path is None:
        raise FileNotFoundError(
            'No cfg<N>.yml config file found in {}'.format(dir_path))
    cfg_from_file(file_name=source_cfg_path, target=target)
# save_f_name = os.path.join("..", "experiments", "baselines", "our_implementation", "cfg_template","ml_100k.yml")
# save_cfg_file(save_f_name)