
Commit

Merge branch 'mandarin'
Signed-off-by: begeekmyfriend <[email protected]>
begeekmyfriend committed Dec 10, 2018
2 parents e7e1ee5 + 0fbbb8c commit c4632e9
Showing 8 changed files with 81 additions and 24 deletions.
2 changes: 1 addition & 1 deletion eval.py
@@ -32,7 +32,7 @@ def run_eval(args):
   synth.load(args.checkpoint)
   base_path = get_output_base_path(args.checkpoint)
   for i, text in enumerate(sentences):
-    path = '%s-%d.wav' % (base_path, i)
+    path = '%s-%03d.wav' % (base_path, i)
     print('Synthesizing: %s' % path)
     with open(path, 'wb') as f:
       f.write(synth.synthesize(text))
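Note: switching from %d to %03d zero-pads the clip index so synthesized files sort correctly in directory listings. A quick check in plain Python (the base path here is illustrative):

    >>> '%s-%03d.wav' % ('eval/model.ckpt-100000', 7)
    'eval/model.ckpt-100000-007.wav'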
16 changes: 9 additions & 7 deletions hparams.py
@@ -8,23 +8,25 @@
   cleaners='english_cleaners',

   # Audio:
-  num_mels=80,
+  num_mels=160,
   num_freq=1025,
   sample_rate=24000,
   frame_length_ms=50,
   frame_shift_ms=12.5,
   preemphasis=0.97,
   min_level_db=-100,
   ref_level_db=20,
+  max_frame_num=1000,
+  max_abs_value=4,

   # Model:
   outputs_per_step=5,
-  embed_depth=256,
-  prenet_depths=[256, 128],
+  embed_depth=512,
+  prenet_depths=[256, 256],
   encoder_depth=256,
-  postnet_depth=256,
-  attention_depth=256,
-  decoder_depth=256,
+  postnet_depth=512,
+  attention_depth=128,
+  decoder_depth=1024,

   # Training:
   batch_size=32,

@@ -38,7 +40,7 @@
   # Eval:
   max_iters=300,
   griffin_lim_iters=60,
-  power=1.5,  # Power to raise magnitudes to prior to Griffin-Lim
+  power=1.2,  # Power to raise magnitudes to prior to Griffin-Lim
 )


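Note: the audio settings above imply the following STFT geometry (a quick derivation, assuming the usual n_fft = 2 * (num_freq - 1) convention):

    sample_rate = 24000
    hop_length = int(sample_rate * 12.5 / 1000)  # frame_shift_ms  -> 300 samples
    win_length = int(sample_rate * 50 / 1000)    # frame_length_ms -> 1200 samples
    n_fft = (1025 - 1) * 2                       # num_freq=1025 bins -> 2048-point FFT

The new max_abs_value=4 is consumed by the symmetric mel normalization added in util/audio.py below, which maps spectrograms into [-4, 4] rather than [0, 1].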
40 changes: 38 additions & 2 deletions models/helpers.py
@@ -49,12 +49,14 @@ def next_inputs(self, time, outputs, state, sample_ids, stop_token_preds, name=None):


 class TacoTrainingHelper(Helper):
-  def __init__(self, inputs, targets, output_dim, r):
+  def __init__(self, inputs, targets, output_dim, r, global_step):
     # inputs is [N, T_in], targets is [N, T_out, D]
     with tf.name_scope('TacoTrainingHelper'):
       self._batch_size = tf.shape(inputs)[0]
       self._output_dim = output_dim
       self._reduction_factor = r
+      self._ratio = tf.convert_to_tensor(1.)
+      self.global_step = global_step

       # Feed every r-th target frame as input
       self._targets = targets[:, r-1::r, :]

@@ -80,6 +82,7 @@ def sample_ids_dtype(self):
     return np.int32

   def initialize(self, name=None):
+    self._ratio = _teacher_forcing_ratio_decay(1., self.global_step)
     return (tf.tile([False], [self._batch_size]), _go_frames(self._batch_size, self._output_dim))

   def sample(self, time, outputs, state, name=None):

@@ -88,11 +91,44 @@ def sample(self, time, outputs, state, name=None):
   def next_inputs(self, time, outputs, state, sample_ids, stop_token_preds, name=None):
     with tf.name_scope(name or 'TacoTrainingHelper'):
       finished = (time + 1 >= self._lengths)
-      next_inputs = self._targets[:, time, :]  # Teacher forcing: feed the true frame
+
+      # Pick previous outputs randomly with respect to the teacher forcing ratio
+      next_inputs = tf.cond(
+        tf.less(tf.random_uniform([], minval=0, maxval=1, dtype=tf.float32), self._ratio),
+        lambda: self._targets[:, time, :],       # Teacher forcing: feed the true frame
+        lambda: outputs[:, -self._output_dim:])  # Otherwise feed back the last predicted frame
+
       return (finished, next_inputs, state)


 def _go_frames(batch_size, output_dim):
   '''Returns all-zero <GO> frames for a given batch size and output dimension'''
   return tf.tile([[0.0]], [batch_size, output_dim])

+def _teacher_forcing_ratio_decay(init_tfr, global_step):
+  #################################################################
+  # Narrow cosine decay:
+  #
+  # Phase 1: tfr = init_tfr
+  #   Teacher forcing decay only starts after 20k steps.
+  #
+  # Phase 2: tfr in (0, 1)
+  #   The decay reaches its minimal value at step ~220k.
+  #
+  # Phase 3: tfr = 0
+  #   Clipped at the minimal teacher forcing ratio (step >~ 220k).
+  #################################################################
+  # Compute the natural cosine decay
+  tfr = tf.train.cosine_decay(init_tfr,
+    global_step=global_step - 20000,  # tfr = init_tfr at step 20k
+    decay_steps=200000,               # tfr = 0 at step ~220k
+    alpha=0.,                         # final value is 0% of init_tfr
+    name='tfr_cosine_decay')
+
+  # Force the teacher forcing ratio to its initial value while global_step < the decay start step
+  narrow_tfr = tf.cond(
+    tf.less(global_step, tf.convert_to_tensor(20000)),
+    lambda: tf.convert_to_tensor(init_tfr),
+    lambda: tfr)
+
+  return narrow_tfr
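Note: for intuition about the schedule, here is a NumPy mirror of it (a sketch only; with alpha=0, tf.train.cosine_decay computes 0.5 * (1 + cos(pi * step / decay_steps)) with the step clipped to the decay window):

    import numpy as np

    def tfr_schedule(global_step, init_tfr=1.0, start=20000, decay_steps=200000):
        # Phase 1: constant at init_tfr until the decay window opens
        if global_step < start:
            return init_tfr
        # Phases 2-3: cosine decay from init_tfr down to 0, then clipped at 0
        t = min(global_step - start, decay_steps) / decay_steps
        return init_tfr * 0.5 * (1.0 + np.cos(np.pi * t))

    for step in (0, 20000, 70000, 120000, 170000, 220000):
        print(step, round(tfr_schedule(step), 3))
    # 0 1.0 | 20000 1.0 | 70000 0.854 | 120000 0.5 | 170000 0.146 | 220000 0.0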
15 changes: 9 additions & 6 deletions models/modules.py
@@ -40,7 +40,7 @@ def cbhg(inputs, input_lengths, is_training, scope, K, projections, depth):
   with tf.variable_scope('conv_bank'):
     # Convolution bank: concatenate on the last axis to stack channels from all convolutions
     conv_outputs = tf.concat(
-      [conv1d(inputs, k, 128, tf.nn.relu, is_training, 'conv1d_%d' % k) for k in range(1, K+1)],
+      [conv1d(inputs, k, 128, tf.nn.relu, is_training, 0.5, 'conv1d_%d' % k) for k in range(1, K+1)],
       axis=-1
     )

@@ -52,8 +52,8 @@ def cbhg(inputs, input_lengths, is_training, scope, K, projections, depth):
     padding='same')

   # Two projection layers:
-  proj1_output = conv1d(maxpool_output, 3, projections[0], tf.nn.relu, is_training, 'proj_1')
-  proj2_output = conv1d(proj1_output, 3, projections[1], None, is_training, 'proj_2')
+  proj1_output = conv1d(maxpool_output, 3, projections[0], tf.nn.relu, is_training, 0.5, 'proj_1')
+  proj2_output = conv1d(proj1_output, 3, projections[1], lambda _: _, is_training, 0.5, 'proj_2')

   # Residual connection:
   highway_input = proj2_output + inputs

@@ -96,12 +96,15 @@ def highwaynet(inputs, scope, depth):
   return H * T + inputs * (1.0 - T)


-def conv1d(inputs, kernel_size, channels, activation, is_training, scope):
+def conv1d(inputs, kernel_size, channels, activation, is_training, drop_rate, scope):
+  if not is_training: drop_rate = 0.0  # disable dropout at inference time
   with tf.variable_scope(scope):
     conv1d_output = tf.layers.conv1d(
       inputs,
       filters=channels,
       kernel_size=kernel_size,
-      activation=activation,
+      activation=None,  # the activation is applied after batch normalization below
       padding='same')
-    return tf.layers.batch_normalization(conv1d_output, training=is_training)
+    batched = tf.layers.batch_normalization(conv1d_output, training=is_training)
+    activated = activation(batched)
+    return tf.layers.dropout(activated, rate=drop_rate, training=is_training,
+                             name='dropout_{}'.format(scope))
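Note: call sites now pass an explicit dropout rate before the scope name, and the layer order becomes conv -> batch norm -> activation -> dropout. Passing lambda _: _ for the second projection keeps it linear: the activation is now invoked as a function on the batch-normalized output, so None would raise a TypeError. A usage sketch (the tensor shape and scope name are illustrative):

    # inputs: a [batch, time, in_channels] float32 tensor
    out = conv1d(inputs, kernel_size=5, channels=256, activation=tf.nn.relu,
                 is_training=True, drop_rate=0.5, scope='conv_example')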
4 changes: 2 additions & 2 deletions models/tacotron.py
@@ -15,7 +15,7 @@ def __init__(self, hparams):
     self._hparams = hparams


-  def initialize(self, inputs, input_lengths, mel_targets=None, linear_targets=None, stop_token_targets=None):
+  def initialize(self, inputs, input_lengths, mel_targets=None, linear_targets=None, stop_token_targets=None, global_step=None):
     '''Initializes the model for inference.

     Sets "mel_outputs", "linear_outputs", and "alignments" fields.

@@ -67,7 +67,7 @@ def initialize(self, inputs, input_lengths, mel_targets=None, linear_targets=None, stop_token_targets=None, global_step=None):
       frame_projection, stop_projection)

     if is_training:
-      helper = TacoTrainingHelper(inputs, mel_targets, hp.num_mels, hp.outputs_per_step)
+      helper = TacoTrainingHelper(inputs, mel_targets, hp.num_mels, hp.outputs_per_step, global_step)
     else:
       helper = TacoTestHelper(batch_size, hp.num_mels, hp.outputs_per_step)
1 change: 0 additions & 1 deletion synthesizer.py
@@ -34,7 +34,6 @@ def synthesize(self, text):
     }
     wav = self.session.run(self.wav_output, feed_dict=feed_dict)
     wav = audio.inv_preemphasis(wav)
-    wav = wav[:audio.find_endpoint(wav)]
     out = io.BytesIO()
     audio.save_wav(wav, out)
     return out.getvalue()
2 changes: 1 addition & 1 deletion train.py
@@ -63,7 +63,7 @@ def train(log_dir, args):
   global_step = tf.Variable(0, name='global_step', trainable=False)
   with tf.variable_scope('model') as scope:
     model = create_model(args.model, hparams)
-    model.initialize(feeder.inputs, feeder.input_lengths, feeder.mel_targets, feeder.linear_targets, feeder.stop_token_targets)
+    model.initialize(feeder.inputs, feeder.input_lengths, feeder.mel_targets, feeder.linear_targets, feeder.stop_token_targets, global_step)
     model.add_loss()
     model.add_optimizer(global_step)
     stats = add_stats(model)
25 changes: 21 additions & 4 deletions util/audio.py
@@ -13,10 +13,24 @@ def load_wav(path):


 def save_wav(wav, path):
-  wav *= 32767 / max(0.01, np.max(np.abs(wav)))
+  # Rescale to a unified peak level across all clips
+  wav = wav / np.abs(wav).max() * 0.999
+  # The factor 0.5 guards against int16 overflow
+  f1 = 0.5 * 32767 / max(0.01, np.max(np.abs(wav)))
+  # Sublinear scaling, y ~ x^k with k < 1
+  f2 = np.sign(wav) * np.power(np.abs(wav), 0.667)
+  wav = f1 * f2
+  # Band-pass filter (75-7600 Hz) to reduce noise
+  firwin = signal.firwin(hparams.num_freq, [75, 7600], pass_zero=False, fs=hparams.sample_rate)
+  wav = signal.convolve(wav, firwin)
+
   wavfile.write(path, hparams.sample_rate, wav.astype(np.int16))


+def trim_silence(wav):
+  return librosa.effects.trim(wav, top_db=60, frame_length=512, hop_length=128)[0]
+
+
 def preemphasis(x):
   return signal.lfilter([1, -hparams.preemphasis], [1], x)
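Note: a worked example of the new output scaling (numbers rounded). After the unified rescale the peak is 0.999, so f1 is about 0.5 * 32767 / 0.999, roughly 16400, and the sublinear power law boosts quiet samples relative to linear scaling:

    import numpy as np

    f1 = 0.5 * 32767 / 0.999            # ~16400
    for x in (0.999, 0.5, 0.1):
        print(x, round(f1 * np.power(x, 0.667)))
    # 0.999 -> 16389 | 0.5 -> 10329 | 0.1 -> 3531 (linear scaling would give ~1640)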

@@ -143,10 +157,13 @@ def _db_to_amp_tensorflow(x):
   return tf.pow(tf.ones(tf.shape(x)) * 10.0, x * 0.05)

 def _normalize(S):
-  return np.clip((S - hparams.min_level_db) / -hparams.min_level_db, 0, 1)
+  # Symmetric mels: map [min_level_db, 0] dB onto [-max_abs_value, max_abs_value]
+  return 2 * hparams.max_abs_value * ((S - hparams.min_level_db) / -hparams.min_level_db) - hparams.max_abs_value

 def _denormalize(S):
-  return (np.clip(S, 0, 1) * -hparams.min_level_db) + hparams.min_level_db
+  # Symmetric mels
+  return ((S + hparams.max_abs_value) * -hparams.min_level_db) / (2 * hparams.max_abs_value) + hparams.min_level_db

 def _denormalize_tensorflow(S):
-  return (tf.clip_by_value(S, 0, 1) * -hparams.min_level_db) + hparams.min_level_db
+  # Symmetric mels
+  return ((S + hparams.max_abs_value) * -hparams.min_level_db) / (2 * hparams.max_abs_value) + hparams.min_level_db
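Note: _denormalize is the exact algebraic inverse of the new _normalize. A quick round-trip check (NumPy, with min_level_db=-100 and max_abs_value=4 from hparams.py):

    import numpy as np

    min_db, M = -100.0, 4.0
    S = np.array([-100.0, -50.0, 0.0])             # dB values spanning the full range
    N = 2 * M * ((S - min_db) / -min_db) - M       # _normalize   -> [-4., 0., 4.]
    back = ((N + M) * -min_db) / (2 * M) + min_db  # _denormalize -> [-100., -50., 0.]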
