Monitoring metric #284

Open. Wants to merge 10 commits into base: master.
36 changes: 36 additions & 0 deletions layers.py
@@ -1,9 +1,40 @@
import torch
import numpy as np
from librosa.filters import mel as librosa_mel_fn
from audio_processing import dynamic_range_compression
from audio_processing import dynamic_range_decompression
from stft import STFT

def dct(x, norm=None):
"""
Discrete Cosine Transform, Type II (a.k.a. the DCT)
For the meaning of the parameter `norm`, see:
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
:param x: the input signal
:param norm: the normalization, None or 'ortho'
:return: the DCT-II of the signal over the last dimension
"""
x_shape = x.shape
N = x_shape[-1]
x = x.contiguous().view(-1, N)

v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)

    # torch.rfft was removed in PyTorch 1.8; on newer versions the equivalent
    # is torch.view_as_real(torch.fft.fft(v, dim=1))
    Vc = torch.rfft(v, 1, onesided=False)

k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
W_r = torch.cos(k)
W_i = torch.sin(k)

V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i

if norm == 'ortho':
V[:, 0] /= np.sqrt(N) * 2
V[:, 1:] /= np.sqrt(N / 2) * 2

V = 2 * V.view(*x_shape)

return V
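
As a quick sanity check, this dct should match scipy's DCT-II (a sketch, assuming scipy is available and a PyTorch version that still provides torch.rfft):

import torch
import scipy.fftpack

x = torch.randn(3, 16)
ours = dct(x, norm='ortho')
ref = torch.from_numpy(scipy.fftpack.dct(x.numpy(), axis=-1, norm='ortho')).float()
assert torch.allclose(ours, ref, atol=1e-4)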

class LinearNorm(torch.nn.Module):
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
@@ -60,6 +91,11 @@ def spectral_de_normalize(self, magnitudes):
output = dynamic_range_decompression(magnitudes)
return output

    def cepstrum_from_mel(self, mel):
        # mel: [B, n_mel_channels, T]. Note that dct() transforms the last
        # dimension, i.e. time; a conventional cepstrum would instead apply
        # the DCT along the mel-frequency axis, e.g.
        # dct(mel.transpose(1, 2), 'ortho').transpose(1, 2)
        #magnitudes = self.spectral_de_normalize(mel)
        mcc = dct(mel, 'ortho')
        return mcc

def mel_spectrogram(self, y):
"""Computes mel-spectrograms from a batch of waves
PARAMS
20 changes: 14 additions & 6 deletions logger.py
@@ -9,17 +9,25 @@ class Tacotron2Logger(SummaryWriter):
def __init__(self, logdir):
super(Tacotron2Logger, self).__init__(logdir)

def log_training(self, reduced_loss, grad_norm, learning_rate, duration, diagonality, avg_prob, avg_MCD, avg_f0,
iteration):
self.add_scalar("training.loss", reduced_loss, iteration)
self.add_scalar("grad.norm", grad_norm, iteration)
self.add_scalar("learning.rate", learning_rate, iteration)
self.add_scalar("duration", duration, iteration)
self.add_scalar("training.loss", reduced_loss, iteration)
self.add_scalar("grad.norm", grad_norm, iteration)
self.add_scalar("learning.rate", learning_rate, iteration)
self.add_scalar("duration", duration, iteration)
self.add_scalar("training.attention_alignment_diagonality", diagonality, iteration)
self.add_scalar("training.average_max_attention_weight", avg_prob, iteration)
self.add_scalar("training.log_MCD", avg_MCD, iteration)
self.add_scalar("training.f0(100hz)", avg_f0, iteration)

def log_validation(self, reduced_loss, model, y, y_pred, diagonality, avg_prob, avg_MCD, avg_f0, iteration):
self.add_scalar("validation.loss", reduced_loss, iteration)
_, mel_outputs, gate_outputs, alignments = y_pred
mel_targets, gate_targets = y
self.add_scalar("validation.attention_alignment_diagonality", diagonality, iteration)
self.add_scalar("validation.average_max_attention_weight", avg_prob, iteration)
self.add_scalar("validation.log_MCD", avg_MCD, iteration)
self.add_scalar("validation.f0(100hz)", avg_f0, iteration)

# plot distribution of parameters
for tag, value in model.named_parameters():
115 changes: 115 additions & 0 deletions metric.py
@@ -0,0 +1,115 @@
import torch
import numpy as np
from utils import load_wav_to_torch
from hparams import create_hparams
from layers import TacotronSTFT

def alignment_metric(alignments, input_lengths, output_lengths):
    # alignments: [batch_size, T_out, T_in] attention weights from the decoder
    # input_lengths: [batch_size] encoder (text) lengths, len_x
    # output_lengths: [batch_size] decoder (mel) lengths, len_y

    batch_size = alignments.size(0)
    # length of the ideal straight diagonal through each attention matrix
    optimums = torch.sqrt(input_lengths.double()**2 + output_lengths.double()**2)

    diagonalitys = torch.zeros(batch_size)
    val_sum = torch.zeros(1)
    for i in range(batch_size):
        dist = torch.zeros(1)
        prev_idx = None
        for j in range(output_lengths[i]):
            # attention distribution over encoder steps at decoder step j
            # (the original indexed alignments[i][:][j]; the [:] was a no-op)
            value, cur_idx = torch.max(alignments[i][j], 0)
            val_sum += value
            if j > 0:
                # Euclidean length of one step along the max-attention path
                dist += (1 + (cur_idx - prev_idx).pow(2)).float().pow(0.5)
            prev_idx = cur_idx
        # ratio of the traced path length to the ideal diagonal (~1 is good)
        diagonalitys[i] = dist / optimums[i]
    avg_prob = val_sum / torch.sum(output_lengths).float()
    diagonality = torch.mean(diagonalitys)
    return diagonality, avg_prob
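
As a sanity check (a sketch, not part of the PR): a hard one-hot diagonal alignment should give avg_prob of 1.0 and diagonality of about (T - 1) / T:

B, T = 2, 50
alignments = torch.eye(T).unsqueeze(0).repeat(B, 1, 1)  # [B, T_out, T_in]
lengths = torch.full((B,), T, dtype=torch.long)
diag, prob = alignment_metric(alignments, lengths, lengths)
print(diag.item(), prob.item())  # ~0.98, 1.0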

def evaluation_metrics(stft, source_mels, target_mels):
batch_size = source_mels.size(0)
MCDs = torch.zeros(batch_size)
f0s = None
for i in range(batch_size):
        # clamp both mels to the normalized range assumed by the metric
        # ([-4, 4] here) so outliers do not dominate the distance
        src_mel = source_mels[i].unsqueeze(0)
        src_mel = torch.clamp(src_mel, min=-4.0, max=4.0)
        dst_mel = target_mels[i].unsqueeze(0)
        dst_mel = torch.clamp(dst_mel, min=-4.0, max=4.0)
MCDs[i] = MCD_from_mels(stft, src_mel, dst_mel)
f0 = sqDiffF0_from_mels(stft, src_mel, dst_mel)
f0s = f0 if f0s is None else torch.cat((f0s, f0), 0)

avg_MCD = torch.mean(MCDs)
avg_f0 = torch.mean(f0s)

return avg_MCD, avg_f0

def melCepDist(srcMCC, dstMCC):
    # mel-cepstral distortion, cf.
    # https://dsp.stackexchange.com/questions/56391/mel-cepstral-distortion
    diff = dstMCC - srcMCC
    # sqrt(2 * diff**2) is element-wise sqrt(2) * |diff|; see the note below
    # on how this relates to the conventional MCD formula
    return torch.sum(torch.sqrt(2 * diff**2)) * (10.0 / np.log(10)) / diff.size(1)
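
For reference, the conventional MCD between cepstral sequences $c$ and $\hat{c}$ over $T$ frames and $D$ coefficients is

$$\mathrm{MCD} = \frac{10}{\ln 10}\,\frac{1}{T}\sum_{t=1}^{T}\sqrt{2\sum_{d=1}^{D}\left(c_{t,d}-\hat{c}_{t,d}\right)^{2}}$$

melCepDist above instead accumulates $\sqrt{2}\,\lvert c_{t,d}-\hat{c}_{t,d}\rvert$ over all coefficients before dividing by $T$, which tracks the same quantity but does not reduce to the textbook value in general.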

def f0(MCC):
    # crude pitch proxy: index of the dominant cepstral coefficient per frame;
    # MCC is expected to be [n_coeff, T]
    _, peak_idx = MCC.max(0)
    return peak_idx
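
For shape intuition (illustrative only; the names below are hypothetical):

mcc = torch.randn(80, 120)  # a [n_coeff, T] cepstral matrix
peaks = f0(mcc)             # LongTensor of shape [120], one peak index per frame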

def MCD_from_mels(stft, srcMel, dstMel):
    # keep the first 25 cepstral coefficients per frame
    srcMCC = stft.cepstrum_from_mel(srcMel)[0, :25, :]
    dstMCC = stft.cepstrum_from_mel(dstMel)[0, :25, :]
    MCD = melCepDist(srcMCC, dstMCC)
    log_MCD = torch.log10(torch.clamp(MCD, min=1e-5))
    return log_MCD

def sqDiffF0_from_mels(stft, srcMel, dstMel):
    srcMCC = stft.cepstrum_from_mel(srcMel).squeeze(0)
    dstMCC = stft.cepstrum_from_mel(dstMel).squeeze(0)
    srcF0 = f0(srcMCC)
    dstF0 = f0(dstMCC)
    diff = (dstF0 - srcF0).double()
    return torch.sqrt(diff**2)  # element-wise |diff|, one value per frame

def test_MCD_and_f0():
hparams = create_hparams()
stft = TacotronSTFT(
hparams.filter_length, hparams.hop_length, hparams.win_length,
hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin,
hparams.mel_fmax)
audio_path = 'kakao/1/1_0001.wav'
mel_path = 'kakao/1/1_0001.mel.npy'
srcMel = torch.from_numpy(np.load(mel_path)).unsqueeze(0)
srcMel = torch.clamp(srcMel, -4.0, 4.0)
# print(srcMel.shape, srcMel.max(), srcMel.min())
audio, sr = load_wav_to_torch(audio_path)
# print(audio.shape, audio.max(), audio.min())
audio_norm = audio / hparams.max_wav_value
audio_norm = audio_norm.unsqueeze(0)

# print(audio_norm.shape, audio_norm.max(), audio_norm.min())
dstMel = stft.mel_spectrogram(audio_norm)
# print(dstMel.shape, dstMel.max(), dstMel.min())
# mcc = stft.cepstrum_from_audio(audio_norm)
# print('mcc', mcc.shape, mcc.max(), mcc.min())

log_MCD = MCD_from_mels(stft, srcMel, dstMel)
print(log_MCD.data, 'log')

sqrtDiffF0 = sqDiffF0_from_mels(stft, srcMel, dstMel)
print(sqrtDiffF0)
meanSqrtDiffF0 = torch.mean(sqrtDiffF0)
print(meanSqrtDiffF0.data, '100hz')

#alignment_metric()
if __name__ == "__main__":
test_MCD_and_f0()


87 changes: 87 additions & 0 deletions preprocess_dataset.py
@@ -0,0 +1,87 @@
"""
This code was developed with reference to https://github.com/Rayhane-mamah/Tacotron-2.
"""
from scipy.io.wavfile import write
import librosa
import numpy as np
import argparse

sr = 22050
max_wav_value=32768.0
trim_fft_size = 1024
trim_hop_size = 256

# Control parameters for trimming and skipping
trim_top_db = 23
skip_len = 14848  # ~0.67 s at 22050 Hz; shorter clips are dropped by remove_short_audios

def preprocess_audio(file_list, silence_audio_size, pre_emphasis=False):
    for F in file_list:
        with open(F, encoding='utf-8') as f:
            R = f.readlines()
        print('=' * 5 + F + '=' * 5)

        for i, r in enumerate(R):
            wav_file = r.split('|')[0]
            data, sampling_rate = librosa.core.load(wav_file, sr)
            data = data / np.abs(data).max() * 0.999  # peak-normalize
            data_ = librosa.effects.trim(data, top_db=trim_top_db,
                                         frame_length=trim_fft_size,
                                         hop_length=trim_hop_size)[0]
            if pre_emphasis:
                # pre-emphasis filter: y[n] = x[n] - 0.97 * x[n-1]
                data_ = np.append(data_[0], data_[1:] - 0.97 * data_[:-1])
            data_ = data_ / np.abs(data_).max() * 0.999
            data_ = data_ * max_wav_value
            data_ = np.append(data_, [0.] * silence_audio_size)  # trailing silence pad
            data_ = data_.astype(dtype=np.int16)
            # note: this overwrites the original wav file in place
            write(wav_file, sr, data_)
            if i % 100 == 0:
                print(i)
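
If pre-emphasis is applied here, audio synthesized from the resulting features generally needs the matching de-emphasis at inference time. A minimal sketch (assumes scipy; de_emphasis is a hypothetical helper, not part of this PR):

from scipy import signal

def de_emphasis(y, coeff=0.97):
    # inverse of the pre-emphasis filter y[n] = x[n] - coeff * x[n-1]
    return signal.lfilter([1.0], [1.0, -coeff], y)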

def remove_short_audios(file_name):
    with open(file_name, 'r', encoding='utf-8') as f:
        R = f.readlines()

    L = []
    for i, r in enumerate(R):
        wav_file = r.split('|')[0]
        data, sampling_rate = librosa.core.load(wav_file, sr)
        if len(data) >= skip_len:
            L.append(r)
        if i % 100 == 0:
            print(i)
    # "metadata.csv" -> "metadata_skipped.csv"
    tmp = file_name.split('.')
    tmp.insert(1, '_skipped.')
    skipped_file_name = "".join(tmp)
    with open(skipped_file_name, 'w', encoding='utf-8') as f:
        f.writelines(L)

if __name__ == "__main__":
    """
    usage
    python preprocess_dataset.py -f=metadata.csv -s=5 -t -p -r
    python preprocess_dataset.py -f=metadata.csv
    (note: -t is parsed below, but trimming is currently always applied)
    """
parser = argparse.ArgumentParser()
parser.add_argument('-f', '--file_list', type=str,
help='Metadata file list to preprocess')
    parser.add_argument('-s', '--silence_padding', type=int, default=0,
                        help='Append silence to the end of each audio; its length is trim_hop_size * silence_padding samples')
    parser.add_argument('-p', '--pre_emphasis', action='store_true',
                        help='Apply pre-emphasis')
    parser.add_argument('-t', '--trimming', action='store_true',
                        help='Trim leading/trailing silence (parsed but currently always applied)')
    parser.add_argument('-r', '--remove_short_audios', action='store_true',
                        help='Remove short audios from the metadata file')
args = parser.parse_args()
file_list = args.file_list.split(',')
silence_audio_size = trim_hop_size * args.silence_padding


preprocess_audio(file_list, silence_audio_size, args.pre_emphasis)

    if args.remove_short_audios:
for f in file_list:
remove_short_audios(f)
43 changes: 37 additions & 6 deletions train.py
@@ -15,7 +15,8 @@
from loss_function import Tacotron2Loss
from logger import Tacotron2Logger
from hparams import create_hparams

from metric import alignment_metric, evaluation_metrics
import layers

def reduce_tensor(tensor, n_gpus):
rt = tensor.clone()
@@ -119,7 +120,7 @@ def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):


def validate(model, criterion, valset, iteration, batch_size, n_gpus,
collate_fn, logger, distributed_run, rank, stft):
"""Handles all the validation scoring and printing"""
model.eval()
with torch.no_grad():
@@ -129,21 +130,38 @@ def validate(model, criterion, valset, iteration, batch_size, n_gpus,
pin_memory=False, collate_fn=collate_fn)

val_loss = 0.0
diagonality = torch.zeros(1)
avg_prob = torch.zeros(1)
avg_MCD = torch.zeros(1)
avg_f0 = torch.zeros(1)
for i, batch in enumerate(val_loader):
x, y = model.parse_batch(batch)
y_pred = model(x)
_, input_lengths, mel_padded, _, output_lengths = x
_, mel_outputs_postnet, _, alignments = y_pred
loss = criterion(y_pred, y)
if distributed_run:
reduced_val_loss = reduce_tensor(loss.data, n_gpus).item()
else:
reduced_val_loss = loss.item()
val_loss += reduced_val_loss
rate, prob = alignment_metric(alignments, input_lengths, output_lengths)
MCD, f0 = evaluation_metrics(stft, mel_padded, mel_outputs_postnet)
diagonality += rate
avg_prob += prob
avg_MCD += MCD
avg_f0 += f0
        diagonality = diagonality / (i + 1)
avg_prob = avg_prob / (i + 1)
val_loss = val_loss / (i + 1)
avg_MCD = avg_MCD / (i + 1)
avg_f0 = avg_f0 / (i + 1)

model.train()
if rank == 0:
print("Validation loss {}: {:9f} ".format(iteration, val_loss))
logger.log_validation(val_loss, model, y, y_pred, diagonality, avg_prob, avg_MCD, avg_f0, iteration)



def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
@@ -159,6 +177,11 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
rank (int): rank of current gpu
hparams (object): comma separated list of "name=value" pairs.
"""
stft = layers.TacotronSTFT(
hparams.filter_length, hparams.hop_length, hparams.win_length,
hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin,
hparams.mel_fmax)

if hparams.distributed_run:
init_distributed(hparams, n_gpus, rank, group_name)

@@ -214,11 +237,15 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
x, y = model.parse_batch(batch)
y_pred = model(x)

_, input_lengths, mel_padded, _, output_lengths = x
_, mel_outputs_postnet, _, alignments = y_pred

loss = criterion(y_pred, y)
if hparams.distributed_run:
reduced_loss = reduce_tensor(loss.data, n_gpus).item()
else:
reduced_loss = loss.item()

if hparams.fp16_run:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
@@ -239,13 +266,17 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
duration = time.perf_counter() - start
print("Train loss {} {:.6f} Grad Norm {:.6f} {:.2f}s/it".format(
iteration, reduced_loss, grad_norm, duration))
                if i % max(1, hparams.iters_per_checkpoint // 10) == 0:
                    # '//' keeps the modulus an int; '/' would make it a float
                    with torch.no_grad():
                        diagonality, avg_prob = alignment_metric(alignments, input_lengths, output_lengths)
                        avg_MCD, avg_f0 = evaluation_metrics(stft, mel_padded, mel_outputs_postnet)
                    logger.log_training(
                        reduced_loss, grad_norm, learning_rate, duration,
                        diagonality, avg_prob, avg_MCD, avg_f0, iteration)

if not is_overflow and (iteration % hparams.iters_per_checkpoint == 0):
validate(model, criterion, valset, iteration,
hparams.batch_size, n_gpus, collate_fn, logger,
hparams.distributed_run, rank, stft)
if rank == 0:
checkpoint_path = os.path.join(
output_directory, "checkpoint_{}".format(iteration))