forked from bfs18/nsynth_wavenet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfig_str.py
132 lines (103 loc) · 3.74 KB
/
config_str.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from wavenet import parallel_wavenet, wavenet, masked
import json
import subprocess
from argparse import Namespace
import time
from auxilaries import reader
def _current_git_branch():
    """Return the checked-out git branch name, or '' if it cannot be determined.

    Runs ``git rev-parse --abbrev-ref HEAD`` in the directory containing this
    file. Any failure (git missing, not a git checkout, no commits yet) yields
    '' instead of raising, because the branch only decorates a name tag.
    """
    import os
    repo_dir = os.path.dirname(os.path.abspath(__file__))
    try:
        out = subprocess.check_output(
            ['git', 'rev-parse', '--abbrev-ref', 'HEAD'], cwd=repo_dir)
    except (OSError, subprocess.CalledProcessError):
        return ''
    return out.decode('utf-8').strip()


def _weight_norm_tag(hparams, model):
    """Return the weight-norm component of the config string.

    'n_WN' when hparams.use_weight_norm is falsy; otherwise 'WN', extended
    with '_DDI' on the 'data_dep_init' git branch and '_mfinit' when
    parallel_wavenet.MANUAL_FINAL_INIT is set for a parallel_wavenet model.
    """
    if not getattr(hparams, 'use_weight_norm', False):
        return 'n_WN'
    wn_tag = 'WN'
    # The branch is only needed here, so it is looked up lazily; configs
    # without weight norm never shell out to git.
    if _current_git_branch() == 'data_dep_init':
        wn_tag += '_DDI'
        # NOTE(review): the scraped source lost its indentation; this check is
        # assumed to be nested under the data_dep_init branch -- confirm.
        if parallel_wavenet.MANUAL_FINAL_INIT and model == 'parallel_wavenet':
            wn_tag += '_mfinit'
    return wn_tag


# parallel_wavenet.SPEC_ENHANCE_FACTOR -> (suffix when NORM_FEAT is set,
#                                          suffix when it is not).
_SPEC_ENHANCE_TAGS = {
    0: ('-NLABS', '-LABS'),
    1: ('-NABS', '-ABS'),
    2: ('-NPOW', '-POW'),
    3: ('-NCOM', '-COM'),
}


def _parallel_wavenet_tags():
    """Return the suffixes encoding the parallel_wavenet module switches.

    Raises:
        ValueError: if parallel_wavenet.SPEC_ENHANCE_FACTOR is not one of the
            keys of _SPEC_ENHANCE_TAGS.
    """
    pw = parallel_wavenet
    if pw.SPEC_ENHANCE_FACTOR not in _SPEC_ENHANCE_TAGS:
        raise ValueError("SPEC_ENHANCE_FACTOR Value Error.")
    norm_tag, plain_tag = _SPEC_ENHANCE_TAGS[pw.SPEC_ENHANCE_FACTOR]
    return [
        '-LOGS' if pw.USE_LOG_SCALE else '-n_LOGS',
        '-CLIP' if pw.CLIP else '-n_CLIP',
        norm_tag if pw.NORM_FEAT else plain_tag,
        '-MEL' if pw.USE_MEL else '-n_MEL',
        '-L1' if pw.USE_L1_LOSS else '-L2',
        '-PFS' if pw.USE_PRIORITY_FREQ else '-n_PFS',
    ]


def get_config_srt(hparams, model, tag=''):
    """Build a short string identifying an experiment configuration.

    The result starts with ``ns_<wn|pwn>-<MU|n_MU>-<weight-norm tag>`` and is
    followed by '-'-prefixed suffixes derived from hparams, from module-level
    switches in ``reader`` and ``parallel_wavenet``, and finally the optional
    loss type and ``tag``. (The "srt" spelling is historical; callers in this
    file use this name.)

    Args:
        hparams: attribute container (e.g. an argparse.Namespace built from a
            JSON config). Missing attributes fall back to their defaults.
        model: 'wavenet' or 'parallel_wavenet'.
        tag: optional free-form suffix appended last when non-empty.

    Returns:
        The configuration string.

    Raises:
        ValueError: for an unsupported ``model`` or an unsupported
            parallel_wavenet.SPEC_ENHANCE_FACTOR.
    """
    prefix = 'ns_'  # nsynth
    if model == 'wavenet':
        model_str = 'wn'
    elif model == 'parallel_wavenet':
        model_str = 'pwn'
    else:
        raise ValueError('unsupported model type {}'.format(model))

    mu_law_tag = 'MU' if getattr(hparams, 'use_mu_law', False) else 'n_MU'
    parts = ['-'.join([prefix + model_str, mu_law_tag,
                       _weight_norm_tag(hparams, model)])]

    if reader.USE_NEW_MEL_EXTRACTOR:
        parts.append('-NM')
    parts.append('-RS' if getattr(hparams, 'use_resize_conv', False) else '-TS')
    parts.append('-IN' if getattr(hparams, 'use_input_noise', False) else '-n_IN')
    parts.append('-DO' if getattr(hparams, 'dropout_inputs', False) else '-n_DO')
    parts.append('-' + getattr(hparams, 'upsample_act', 'tanh'))

    if model == 'parallel_wavenet':
        parts.extend(_parallel_wavenet_tags())
    if model == 'wavenet' and getattr(hparams, 'add_noise', False):
        parts.append('-NOISE')

    # `or ''` also covers an explicit JSON null, which has no .upper().
    loss_type = (getattr(hparams, 'loss_type', None) or '').upper()
    if loss_type:
        parts.append('-{}'.format(loss_type))
    if tag:
        parts.append('-{}'.format(tag))
    return ''.join(parts)
def get_time_str():
    """Return the current local date formatted as 'MM_DD' (e.g. '03_07')."""
    # With no time tuple given, time.strftime formats time.localtime().
    return time.strftime("%m_%d")
def get_config_time_str(hparams, model, tag=''):
    """Return get_config_srt(hparams, model, tag) with '-MM_DD' appended.

    All arguments are forwarded unchanged to get_config_srt.
    """
    config_part = get_config_srt(hparams, model, tag)
    return '{}-{}'.format(config_part, get_time_str())
if __name__ == '__main__':
    # Manual smoke check: print the config strings for two example JSON
    # configs. The paths are relative to the current working directory.
    def _load_hparams(json_path):
        # Read a JSON config and expose its keys as Namespace attributes.
        with open(json_path, 'rt') as fp:
            return Namespace(**json.load(fp))

    wn_hparams = _load_hparams('../config_jsons/wavenet_mol.json')
    print(get_config_srt(wn_hparams, 'wavenet'))

    pwn_hparams = _load_hparams('../config_jsons/parallel_wavenet.json')
    print(get_config_srt(pwn_hparams, 'parallel_wavenet'))
    print(get_config_time_str(pwn_hparams, 'parallel_wavenet'))