apply helpers to improve word-to-phoneme alignment
keonlee9420 committed Feb 14, 2022
1 parent 0dd9170 commit 814cdda
Showing 13 changed files with 377 additions and 56 deletions.
20 changes: 16 additions & 4 deletions README.md
@@ -89,7 +89,7 @@ for some preparations.

For the forced alignment, [Montreal Forced Aligner](https://montreal-forced-aligner.readthedocs.io/en/latest/) (MFA) is used to obtain the alignments between the utterances and the phoneme sequences.
Pre-extracted alignments for the datasets are provided [here](https://drive.google.com/drive/folders/1fizpyOiQ1lG2UDaMlXnT3Ll4_j6Xwg7K?usp=sharing).
Unzip the files into `preprocessed_data/DATASET/TextGrid/`. Alternatively, you can [run the aligner yourself](https://montreal-forced-aligner.readthedocs.io/en/latest/aligning.html).
Unzip the files into `preprocessed_data/DATASET/TextGrid/`. Alternatively, you can [run the aligner yourself](https://montreal-forced-aligner.readthedocs.io/en/latest/user_guide/workflows/index.html).

After that, run the preprocessing script by
```
@@ -116,15 +116,25 @@ tensorboard --logdir output/log
to serve TensorBoard on your localhost.
The loss curves, synthesized mel-spectrograms, and audios are shown.

![](./img/tensorboard_loss.png)
<!-- ![](./img/tensorboard_loss.png)
![](./img/tensorboard_spec.png)
![](./img/tensorboard_audio.png)
![](./img/tensorboard_audio.png) -->

# Notes

- For vocoder, **HiFi-GAN** and **MelGAN** are supported.
- Speed up the convergence of word-to-phoneme alignment in **LinguisticEncoder** by dividing long words into subwords and sorting the dataset by mel-spectrogram frame length.
- No ReLU activation and LayerNorm in **VariationalGenerator** to avoid mashed output.
- Speed up the convergence of word-to-phoneme alignment in **LinguisticEncoder** by dividing long words into subwords and sorting the dataset by mel-spectrogram frame length.
- There are two kinds of helper loss to improve word-to-phoneme alignment: "ctc" and "dga". You can toggle them as follows:
    ```yaml
    # In the train.yaml
    aligner:
      helper_type: "ctc" # ["ctc", "dga", "none"]
    ```
    - "ctc": [Connectionist Temporal Classification (CTC)](https://dl.acm.org/doi/pdf/10.1145/1143844.1143891) Loss with forward-sum algorithm
    - "dga": [Diagonal Guided Attention (DGA)](https://arxiv.org/abs/1710.08969) Loss
    - The default setting is "ctc". If you set "none", no helper loss will be applied during training.
- Will be extended to a **multi-speaker TTS**.
<!-- - Two options for embedding for the **multi-speaker TTS** setting: training speaker embedder from scratch or using a pre-trained [philipperemy's DeepSpeaker](https://github.com/philipperemy/deep-speaker) model (as [STYLER](https://github.com/keonlee9420/STYLER) did). You can toggle it by setting the config (between `'none'` and `'DeepSpeaker'`).
- DeepSpeaker on VCTK dataset shows clear identification among speakers. The following figure shows the T-SNE plot of extracted speaker embedding.
@@ -141,3 +151,5 @@ Please cite this repository by the "[Cite this repository](https://github.blog/2
- [jaywalnut310's VITS](https://github.com/jaywalnut310/vits)
- [jaywalnut310's Glow-TTS](https://github.com/jaywalnut310/glow-tts)
- [keonlee9420's VAENAR-TTS](https://github.com/keonlee9420/VAENAR-TTS)
- [keonlee9420's Comprehensive-Transformer-TTS](https://github.com/keonlee9420/Comprehensive-Transformer-TTS) (CTC Loss)
- [keonlee9420's Comprehensive-Tacotron2](https://github.com/keonlee9420/Comprehensive-Tacotron2) (DGA Loss)
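
The "ctc" helper in the README notes above is described as a CTC loss with the forward-sum algorithm. The loss code itself is not part of the hunks shown here, so the following is only a simplified sketch of the forward-sum idea: it sums the probability of every monotonic path through a frame-by-token log-probability matrix and returns the negative log of that sum. It omits the CTC blank symbol, and the names `forward_sum_loss` and `logsumexp2` are hypothetical.

```python
import math

NEG_INF = float("-inf")


def logsumexp2(a, b):
    """Stable log(exp(a) + exp(b)) that tolerates -inf inputs."""
    m = max(a, b)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def forward_sum_loss(log_probs):
    """Negative log-likelihood summed over all monotonic alignments.

    log_probs[t][n] is the log-probability that mel frame t attends to
    text token n. A valid path starts on token 0, ends on the last token,
    and at every frame either stays on its token or advances by one.
    Simplification (assumption): no blank symbol is modelled.
    """
    n_frames, n_tokens = len(log_probs), len(log_probs[0])
    alpha = [NEG_INF] * n_tokens
    alpha[0] = log_probs[0][0]
    for t in range(1, n_frames):
        prev = alpha
        alpha = [
            log_probs[t][n]
            + logsumexp2(prev[n], prev[n - 1] if n > 0 else NEG_INF)
            for n in range(n_tokens)
        ]
    return -alpha[-1]


# Three frames, two tokens, uniform attention: two valid paths of
# probability 1/8 each, so the loss is -log(1/4).
uniform = [[math.log(0.5)] * 2 for _ in range(3)]
loss = forward_sum_loss(uniform)
```

Minimizing this quantity pushes probability mass onto alignments that move left to right through the text, which is the property the helper loss is meant to encourage.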
2 changes: 2 additions & 0 deletions config/LJSpeech/preprocess.yaml
@@ -25,3 +25,5 @@ preprocessing:
n_mel_channels: 80
mel_fmin: 0
mel_fmax: 8000 # please set to 8000 for HiFi-GAN vocoder, set to null for MelGAN vocoder
aligner:
beta_binomial_scaling_factor: 1.
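
The new `beta_binomial_scaling_factor` setting suggests that the attention prior saved during preprocessing is a beta-binomial distribution over text tokens for each mel frame. The preprocessing code is not shown in this diff, so the sketch below is an assumption about the general technique rather than the repository's implementation. The function names are hypothetical, and the choice of `n_tokens - 1` trials (so each row sums to 1) is mine.

```python
import math


def log_beta(x, y):
    """Log of the Beta function B(x, y)."""
    return math.lgamma(x) + math.lgamma(y) - math.lgamma(x + y)


def beta_binomial_pmf(k, n, a, b):
    """P(K = k) for a beta-binomial with n trials and shape parameters a, b."""
    log_comb = math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)
    return math.exp(log_comb + log_beta(k + a, n - k + b) - log_beta(a, b))


def beta_binomial_prior(n_tokens, n_frames, scaling_factor=1.0):
    """Return prior[frame][token], a near-diagonal prior over alignments.

    Early frames favour early tokens and late frames favour late tokens.
    A larger scaling_factor concentrates each row more tightly.
    """
    prior = []
    for i in range(1, n_frames + 1):
        a = scaling_factor * i
        b = scaling_factor * (n_frames + 1 - i)
        prior.append(
            [beta_binomial_pmf(k, n_tokens - 1, a, b) for k in range(n_tokens)]
        )
    return prior


prior = beta_binomial_prior(n_tokens=4, n_frames=5)
```

The sketch indexes the result as `[frame][token]`. The model code in this commit transposes the prior before adding it to the attention scores, so the stored arrays may use the opposite orientation.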
20 changes: 14 additions & 6 deletions config/LJSpeech/train.yaml
@@ -4,9 +4,9 @@ dist_config:
dist_url: "tcp://localhost:80000"
world_size: 1
path:
ckpt_path: "./output/ckpt/LJSpeech_fixing"
log_path: "./output/log/LJSpeech_fixing"
result_path: "./output/result/LJSpeech_fixing"
ckpt_path: "./output/ckpt/LJSpeech"
log_path: "./output/log/LJSpeech"
result_path: "./output/result/LJSpeech"
optimizer:
batch_size: 64
betas: [0.9, 0.98]
@@ -15,11 +15,19 @@ optimizer:
grad_clip_thresh: 1.0
grad_acc_step: 1
warm_up_step: 4000
anneal_steps: [175000, 250000, 300000]
anneal_steps: [200000]
anneal_rate: 0.3
step:
total_step: 500000
total_step: 200000
log_step: 100
synth_step: 1000
val_step: 1000
save_step: 25000
save_step: 20000
ctc_step: 1000
aligner:
helper_type: "dga" # ["dga", "ctc", "none"]
ctc_weight_start: 1.0
ctc_weight_end: 1.0
guided_sigma: 0.4
guided_lambda: 1.0
guided_weight: 1.0
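
For orientation on the `guided_sigma` value above: in the guided attention loss from the DGA paper linked in the README, attention far from the diagonal is penalized with the weight `1 - exp(-(n/N - t/T)^2 / (2 g^2))`. The sketch below assumes `guided_sigma` plays the role of `g`. How `guided_lambda` and `guided_weight` are used is not visible in this diff, so they are left out. The function names are hypothetical.

```python
import math


def guided_attention_weights(n_tokens, n_frames, sigma=0.4):
    """Penalty matrix w[n][t]: zero on the diagonal, approaching 1 far from it."""
    return [
        [
            1.0 - math.exp(-((n / n_tokens - t / n_frames) ** 2) / (2.0 * sigma ** 2))
            for t in range(n_frames)
        ]
        for n in range(n_tokens)
    ]


def guided_attention_loss(attn, sigma=0.4):
    """Mean of attn[n][t] * w[n][t] over a single (unpadded) attention map."""
    n_tokens, n_frames = len(attn), len(attn[0])
    w = guided_attention_weights(n_tokens, n_frames, sigma)
    total = sum(
        attn[n][t] * w[n][t] for n in range(n_tokens) for t in range(n_frames)
    )
    return total / (n_tokens * n_frames)


# A perfectly diagonal map costs nothing; an anti-diagonal one is penalized.
diagonal = [[1.0 if n == t else 0.0 for t in range(4)] for n in range(4)]
anti_diagonal = [[1.0 if n + t == 3 else 0.0 for t in range(4)] for n in range(4)]
loss_diagonal = guided_attention_loss(diagonal)
loss_anti = guided_attention_loss(anti_diagonal)
```

A larger `sigma` widens the band around the diagonal that goes unpenalized, so 0.4 is more permissive than a smaller value would be.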
12 changes: 11 additions & 1 deletion dataset.py
@@ -6,7 +6,7 @@
from torch.utils.data import Dataset

from text import text_to_sequence
from utils.tools import pad_1D, pad_2D
from utils.tools import pad_1D, pad_2D, pad_3D


class Dataset(Dataset):
@@ -60,6 +60,12 @@ def __getitem__(self, idx):
"spker_embed",
"{}-spker_embed.npy".format(speaker),
)) if self.load_spker_embed else None
attn_prior_path = os.path.join(
self.preprocessed_path,
"attn_prior",
"{}-attn_prior-{}.npy".format(speaker, basename),
)
attn_prior = np.load(attn_prior_path)

sample = {
"id": basename,
@@ -70,6 +76,7 @@ def __getitem__(self, idx):
"duration": duration,
"word_boundary": phones_per_word,
"spker_embed": spker_embed,
"attn_prior": attn_prior,
}

return sample
@@ -100,6 +107,7 @@ def reprocess(self, data, idxs):
word_boundaries = [data[idx]["word_boundary"] for idx in idxs]
spker_embeds = np.concatenate(np.array([data[idx]["spker_embed"] for idx in idxs]), axis=0) \
if self.load_spker_embed else None
attn_priors = [data[idx]["attn_prior"] for idx in idxs]

text_w_lens = np.array([word_boundary.shape[0]
for word_boundary in word_boundaries])
@@ -111,6 +119,7 @@ def reprocess(self, data, idxs):
mels = pad_2D(mels)
durations = pad_1D(durations)
word_boundaries = pad_1D(word_boundaries)
attn_priors = pad_3D(attn_priors, len(idxs), max(text_lens), max(mel_lens))

return (
ids,
@@ -123,6 +132,7 @@ def reprocess(self, data, idxs):
text_w_lens,
max(text_w_lens),
spker_embeds,
attn_priors,
mels,
mel_lens,
max(mel_lens),
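
The `pad_3D` helper imported from `utils.tools` is not included in the hunks shown here. Judging only from the call `pad_3D(attn_priors, len(idxs), max(text_lens), max(mel_lens))`, it presumably zero-pads each per-utterance prior matrix to a common `[batch, text, mel]` shape. A dependency-free sketch of that behaviour, using nested lists where the real helper most likely returns a NumPy array, and with a hypothetical name `pad_3d`:

```python
def pad_3d(matrices, batch_size, max_rows, max_cols, pad_value=0.0):
    """Pad a list of 2-D matrices into a [batch, max_rows, max_cols] nested list.

    Each input matrix may have fewer rows and columns than the maximum;
    the unused cells keep pad_value.
    """
    assert len(matrices) == batch_size
    padded = [
        [[pad_value] * max_cols for _ in range(max_rows)]
        for _ in range(batch_size)
    ]
    for b, matrix in enumerate(matrices):
        for r, row in enumerate(matrix):
            # Overwrite the leading cells of the padded row in place.
            padded[b][r][: len(row)] = row
    return padded


batch = [
    [[1.0, 2.0]],    # 1 text token x 2 mel frames
    [[3.0], [4.0]],  # 2 text tokens x 1 mel frame
]
out = pad_3d(batch, batch_size=2, max_rows=2, max_cols=2)
```

Padding with zeros is consistent with the model code in this commit, which adds a small epsilon before taking the log of the prior.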
4 changes: 2 additions & 2 deletions evaluate.py
@@ -41,14 +41,14 @@ def evaluate(device, model, step, configs, logger=None, vocoder=None, len_losses
output = model(*(batch[2:]))

# Cal Loss
losses = Loss(batch, output)
losses = Loss(batch, output, step)

for i in range(len(losses)):
loss_sums[i] += losses[i].item() * len(batch[0])

loss_means = [loss_sum / len(dataset) for loss_sum in loss_sums]

message = "Validation Step {}, Total Loss: {:.4f}, Mel Loss: {:.4f}, KL Loss: {:.4f}, PN Loss: {:.4f}, Duration Loss: {:.4f}".format(
message = "Validation Step {}, Total Loss: {:.4f}, Mel Loss: {:.4f}, KL Loss: {:.4f}, PN Loss: {:.4f}, Duration Loss: {:.4f}, Helper Loss: {:.4f}".format(
*([step] + [l for l in loss_means])
)

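
The validation loop above weights each batch's loss terms by the batch size and divides by the dataset size, which gives a per-sample mean even when the last batch is smaller. A minimal standalone sketch of that bookkeeping, with plain floats in place of tensors and a hypothetical function name:

```python
def dataset_mean_losses(batch_losses, batch_sizes):
    """Per-sample mean of each loss term across batches of unequal size.

    batch_losses[j][i] is the batch-mean value of loss term i in batch j.
    """
    n_terms = len(batch_losses[0])
    sums = [0.0] * n_terms
    for losses, size in zip(batch_losses, batch_sizes):
        for i, value in enumerate(losses):
            sums[i] += value * size
    total = sum(batch_sizes)
    return [s / total for s in sums]


# Two loss terms, one batch of 2 samples and one of 1 sample.
means = dataset_mean_losses([[1.0, 0.5], [4.0, 2.0]], [2, 1])
```

With the helper loss added in this commit, the list carries one more term, which is why the log message gains a `Helper Loss` field.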
8 changes: 6 additions & 2 deletions model/PortaSpeech.py
@@ -13,11 +13,11 @@
class PortaSpeech(nn.Module):
""" PortaSpeech """

def __init__(self, preprocess_config, model_config):
def __init__(self, preprocess_config, model_config, train_config):
super(PortaSpeech, self).__init__()
self.model_config = model_config

self.linguistic_encoder = LinguisticEncoder(model_config)
self.linguistic_encoder = LinguisticEncoder(model_config, train_config)
self.variational_generator = VariationalGenerator(
preprocess_config, model_config)
self.postnet = FlowPostNet(preprocess_config, model_config)
@@ -53,6 +53,7 @@ def forward(
src_w_lens,
max_src_w_len,
spker_embeds=None,
attn_priors=None,
mels=None,
mel_lens=None,
max_mel_len=None,
@@ -74,6 +75,7 @@ def forward(
mel_lens,
mel_masks,
alignments,
alignment_logprobs,
) = self.linguistic_encoder(
texts,
src_lens,
@@ -83,6 +85,7 @@ def forward(
src_w_masks,
mel_masks,
max_mel_len,
attn_priors,
d_targets,
d_control,
)
@@ -129,4 +132,5 @@ def forward(
dist_info,
src_w_masks,
residual,
alignment_logprobs,
)
25 changes: 17 additions & 8 deletions model/blocks.py
@@ -555,7 +555,7 @@ def __init__(self, n_head, d_model, d_k, d_v, dropout=0.0):

self.dropout = nn.Dropout(dropout)

def forward(self, q, k, v, key_mask=None, query_mask=None, mapping_mask=None, indivisual_attn=False):
def forward(self, q, k, v, key_mask=None, query_mask=None, mapping_mask=None, indivisual_attn=False, attn_prior=None):

d_k, d_v, n_head = self.d_k, self.d_v, self.n_head

@@ -581,8 +581,10 @@ def forward(self, q, k, v, key_mask=None, query_mask=None, mapping_mask=None, in
query_mask = query_mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
if mapping_mask is not None:
mapping_mask = mapping_mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
output, attn = self.attention(
q, k, v, key_mask=key_mask, query_mask=query_mask, mapping_mask=mapping_mask)
if attn_prior is not None:
attn_prior = attn_prior.repeat(n_head, 1, 1)
output, attns, attn_logprob = self.attention(
q, k, v, key_mask=key_mask, query_mask=query_mask, mapping_mask=mapping_mask, attn_prior=attn_prior)

output = output.view(n_head, sz_b, len_q, d_v)
output = (
@@ -594,30 +596,37 @@ def forward(self, q, k, v, key_mask=None, query_mask=None, mapping_mask=None, in
# output = self.layer_norm(output)

if indivisual_attn:
attn = attn.view(n_head, sz_b, len_q, len_k)
attns = tuple([attn.view(n_head, sz_b, len_q, len_k) for attn in attns])
attn_logprob = attn_logprob.view(n_head, sz_b, 1, len_q, len_k)

return output, attn
return output, attns, attn_logprob


class ScaledDotProductAttention(nn.Module):
def __init__(self, temperature):
super(ScaledDotProductAttention, self).__init__()
self.temperature = temperature
self.softmax = nn.Softmax(dim=2)
self.log_softmax = torch.nn.LogSoftmax(dim=2)

def forward(self, q, k, v, key_mask=None, query_mask=None, mapping_mask=None):
def forward(self, q, k, v, key_mask=None, query_mask=None, mapping_mask=None, attn_prior=None):

attn = torch.bmm(q, k.transpose(1, 2))
attn = attn / self.temperature

if key_mask is not None:
attn = attn.masked_fill(key_mask==0., -np.inf)
attn = attn.masked_fill(key_mask == 0., -np.inf)
if attn_prior is not None:
attn = self.log_softmax(attn) + torch.log(attn_prior.transpose(1, 2) + 1e-8)
attn_logprob = attn.unsqueeze(1).clone()

attn = self.softmax(attn)

if query_mask is not None:
attn = attn * query_mask
attn_raw = attn.clone()
if mapping_mask is not None:
attn = attn * mapping_mask
output = torch.bmm(attn, v)

return output, attn
return output, (attn, attn_raw), attn_logprob
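
The new branch in `ScaledDotProductAttention.forward` adds the log of the prior to the log-softmax of the scores before the final softmax. For a single query row this amounts to multiplying the softmax weights by the prior and renormalizing. The sketch below illustrates that with plain Python lists. The function names are hypothetical and the `1e-8` epsilon mirrors the one in the diff.

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax over one row of scores."""
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_z for s in scores]


def attend_with_prior(scores, prior, eps=1e-8):
    """Return (weights, logprob) for one query row.

    logprob is log_softmax(scores) + log(prior + eps), the quantity the
    diff keeps as attn_logprob; weights is its softmax.
    """
    logprob = [
        ls + math.log(p + eps) for ls, p in zip(log_softmax(scores), prior)
    ]
    weights = [math.exp(v) for v in log_softmax(logprob)]
    return weights, logprob


# With uniform scores the attention weights simply follow the prior.
weights, logprob = attend_with_prior([0.0, 0.0, 0.0], [0.7, 0.2, 0.1])
```

Early in training, when the scores are close to uniform, the prior therefore dominates and steers attention toward a near-diagonal alignment.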
54 changes: 33 additions & 21 deletions model/linguistic_encoder.py
@@ -40,25 +40,26 @@ def get_posi_angle_vec(position):
class LinguisticEncoder(nn.Module):
""" Linguistic Encoder """

def __init__(self, config):
def __init__(self, model_config, train_config):
super(LinguisticEncoder, self).__init__()

n_position = config["max_seq_len"] + 1
n_position = model_config["max_seq_len"] + 1
n_src_vocab = len(symbols) + 1
d_word_vec = config["transformer"]["encoder_hidden"]
n_layers = config["transformer"]["encoder_layer"]
n_head = config["transformer"]["encoder_head"]
d_word_vec = model_config["transformer"]["encoder_hidden"]
n_layers = model_config["transformer"]["encoder_layer"]
n_head = model_config["transformer"]["encoder_head"]
d_k = d_v = (
config["transformer"]["encoder_hidden"]
// config["transformer"]["encoder_head"]
model_config["transformer"]["encoder_hidden"]
// model_config["transformer"]["encoder_head"]
)
d_model = config["transformer"]["encoder_hidden"]
d_inner = config["transformer"]["conv_filter_size"]
kernel_size = config["transformer"]["conv_kernel_size"]
# dropout = config["transformer"]["encoder_dropout"]
window_size = config["transformer"]["encoder_window_size"]

self.max_seq_len = config["max_seq_len"]
d_model = model_config["transformer"]["encoder_hidden"]
d_inner = model_config["transformer"]["conv_filter_size"]
kernel_size = model_config["transformer"]["conv_kernel_size"]
# dropout = model_config["transformer"]["encoder_dropout"]
window_size = model_config["transformer"]["encoder_window_size"]
self.helper_type = train_config["aligner"]["helper_type"]

self.max_seq_len = model_config["max_seq_len"]
self.d_model = d_model
self.n_head = n_head

@@ -97,7 +98,7 @@ def __init__(self, config):
window_size=window_size,
)
self.length_regulator = LengthRegulator()
self.duration_predictor = VariancePredictor(config)
self.duration_predictor = VariancePredictor(model_config)

self.w2p_attn = WordToPhonemeAttention(
n_head, d_model, d_k, d_v # , dropout=dropout
@@ -154,7 +155,7 @@ def get_rel_coef(self, dur, dur_len, mask):
idx_b += list(range(d_i))
idx.append(torch.tensor(idx_b).to(device))
# assert L[-1].shape == idx[-1].shape
return torch.div(pad(idx).to(device), pad(L).masked_fill(mask==0., 1.).to(device))
return torch.div(pad(idx).to(device), pad(L).masked_fill(mask == 0., 1.).to(device))

def forward(
self,
@@ -166,6 +167,7 @@ def forward(
src_w_mask,
mel_mask=None,
max_len=None,
attn_prior=None,
duration_target=None,
duration_control=1.0,
):
@@ -183,7 +185,8 @@ def forward(
1, 2), src_w_mask.unsqueeze(1)).transpose(1, 2)

# Phoneme-level Duration Prediction
log_duration_p_prediction = self.duration_predictor(enc_p_out, src_p_mask)
log_duration_p_prediction = self.duration_predictor(
enc_p_out, src_p_mask)

# Word-level Pooling (in log scale)
log_duration_w_prediction = word_level_pooling(
@@ -211,8 +214,9 @@ def forward(
src_mask_ = src_p_mask.unsqueeze(1).expand(-1, mel_mask.shape[1], -1)
# [batch, mel_len, seq_len]
mel_mask_ = mel_mask.unsqueeze(-1).expand(-1, -1, src_p_mask.shape[1])
# [batch, mel_len, seq_len]
mapping_mask = self.get_mapping_mask(
x, enc_p_out, duration_w_rounded, word_boundary, src_w_len) # [batch, mel_len, seq_len]
x, enc_p_out, duration_w_rounded, word_boundary, src_w_len)

q = self.add_position_enc(x, position_enc=self.q_position_enc, coef=self.get_rel_coef(
duration_w_rounded, src_w_len, mel_mask))
@@ -223,8 +227,15 @@ def forward(
# q = self.add_position_enc(x)
# k = self.add_position_enc(enc_p_out)
# v = self.add_position_enc(enc_p_out)
x, alignment = self.w2p_attn(
q, k, v, key_mask=src_mask_, query_mask=mel_mask_, mapping_mask=mapping_mask, indivisual_attn=True
x, attns, attn_logprob = self.w2p_attn(
q=q,
k=k,
v=v,
key_mask=src_mask_,
query_mask=mel_mask_,
mapping_mask=mapping_mask,
indivisual_attn=True,
attn_prior=attn_prior if self.helper_type == "ctc" else None,
)

return (
@@ -233,7 +244,8 @@ def forward(
duration_w_rounded,
mel_len,
mel_mask,
alignment,
attns,
attn_logprob,
)

