Commit 2ad88da: match the model size
keonlee9420 committed Oct 8, 2021 (1 parent: 81be46d)

Showing 3 changed files with 28 additions and 17 deletions.
README.md: 4 changes (2 additions, 2 deletions)
@@ -9,8 +9,8 @@ PyTorch Implementation of [PortaSpeech: Portable and High-Quality Generative Tex
## Model Size
| Module | Normal | Small | Normal (paper) | Small (paper) |
| :----- | :-----: | :-----: | :-----: | :-----: |
- | *Total* | 34.3M | 9.6M | 21.8M | 6.7M
- | *LinguisticEncoder* | 14M | 3.5M | - | -
+ | *Total* | 24M | 7.6M | 21.8M | 6.7M
+ | *LinguisticEncoder* | 3.7M | 1.4M | - | -
| *VariationalGenerator* | 11M | 2.8M | - | -
| *FlowPostNet* | 9.3M | 3.4M | - | -

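The totals in this table can be re-derived with the usual PyTorch parameter count. A minimal sketch, assuming standard nn.Module submodules; the attribute names in the commented usage are illustrative, not taken from the repo:

```python
import torch.nn as nn

def count_params_m(module: nn.Module) -> float:
    """Trainable parameter count of a module, in millions."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad) / 1e6

# Hypothetical usage -- submodule names are assumptions for illustration:
# model = PortaSpeech(preprocess_config, model_config)
# for name in ("linguistic_encoder", "variational_generator", "postnet"):
#     print(f"{name}: {count_params_m(getattr(model, name)):.1f}M")
```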
model/blocks.py: 19 changes (8 additions, 11 deletions)
@@ -315,7 +315,7 @@ def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_s
window_size=window_size, p_dropout=p_dropout, block_length=block_length))
self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append(FFN(
- hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
+ hidden_channels, hidden_channels, kernel_size, p_dropout=p_dropout))
self.norm_layers_2.append(LayerNorm(hidden_channels))

def forward(self, x, x_mask):
@@ -511,29 +511,25 @@ def forward(self, x):


class FFN(nn.Module):
- def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None):
+ def __init__(self, in_channels, out_channels, kernel_size, p_dropout=0., activation=None):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
- self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.activation = activation

- self.conv_1 = nn.Conv1d(
- in_channels, filter_channels, kernel_size, padding=kernel_size//2)
- self.conv_2 = nn.Conv1d(
- filter_channels, out_channels, kernel_size, padding=kernel_size//2)
+ self.conv = nn.Conv1d(
+ in_channels, out_channels, kernel_size, padding=kernel_size//2)
self.drop = nn.Dropout(p_dropout)

def forward(self, x, x_mask):
- x = self.conv_1(x * x_mask)
+ x = self.conv(x * x_mask)
if self.activation == "gelu":
x = x * torch.sigmoid(1.702 * x)
else:
x = torch.relu(x)
x = self.drop(x)
- x = self.conv_2(x * x_mask)
return x * x_mask
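Reassembled from the added lines above, the FFN after this commit reduces to a single masked convolution followed by an activation and dropout; the intermediate filter_channels bottleneck and second convolution are gone. A sketch of the resulting module (the attribute bookkeeping such as self.in_channels is omitted for brevity):

```python
import torch
import torch.nn as nn

class FFN(nn.Module):
    """Position-wise feed-forward block after this commit: one masked Conv1d,
    activation, dropout, and a final mask."""
    def __init__(self, in_channels, out_channels, kernel_size, p_dropout=0., activation=None):
        super().__init__()
        self.activation = activation
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        # x: [B, C, T], x_mask: broadcastable mask over the time axis
        x = self.conv(x * x_mask)
        if self.activation == "gelu":
            x = x * torch.sigmoid(1.702 * x)  # fast GELU approximation
        else:
            x = torch.relu(x)
        x = self.drop(x)
        return x * x_mask
```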


@@ -553,7 +549,7 @@ def __init__(self, n_head, d_model, d_k, d_v, dropout=0.0):

self.attention = ScaledDotProductAttention(
temperature=np.power(d_k, 0.5))
- self.layer_norm = nn.LayerNorm(d_model)
+ # self.layer_norm = nn.LayerNorm(d_model)

self.fc = LinearNorm(n_head * d_v, d_model)

@@ -594,7 +590,8 @@ def forward(self, q, k, v, mask_1=None, mask_2=None, mapping_mask=None, indivisu
) # b x lq x (n*dv)

output = self.dropout(self.fc(output))
- output = self.layer_norm(output + residual)
+ output = output + residual
+ # output = self.layer_norm(output)

if indivisual_attn:
attn = attn.view(n_head, sz_b, len_q, len_k)
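The net effect on the attention output path: with the post-attention LayerNorm commented out, MultiHeadAttention now ends in a plain residual connection. A tiny runnable sketch of that tail, where nn.Linear stands in for the repo's LinearNorm and the sizes are placeholders rather than the project's config values:

```python
import torch
import torch.nn as nn

n_head, d_v, d_model = 2, 128, 256           # placeholder sizes
fc = nn.Linear(n_head * d_v, d_model)        # stand-in for LinearNorm(n_head * d_v, d_model)
dropout = nn.Dropout(0.0)
# layer_norm = nn.LayerNorm(d_model)         # disabled by this commit

residual = torch.randn(2, 37, d_model)       # block input (q), saved before attention
attn_out = torch.randn(2, 37, n_head * d_v)  # concatenated per-head outputs
output = dropout(fc(attn_out))
output = output + residual                   # no post-LayerNorm anymore
```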
model/variational_generator.py: 22 changes (18 additions, 4 deletions)
@@ -58,11 +58,13 @@ def __init__(self, preprocess_config, model_config):
self.cond_layer_e = torch.nn.utils.weight_norm(
torch.nn.Conv1d(
d_model,
- 2*encoder_decoder_hidden*encoder_layer,
+ 2*encoder_decoder_hidden*encoder_layer, # d_model
kernel_size=conv_kernel_size,
stride=conv_stride_size,
padding=self.padding_size,
dilation=dilation), name='weight')
+ # self.cond_layer_e_prj = LinearNorm(
+ # d_model, 2*encoder_decoder_hidden*encoder_layer)
self.enc_wn = NonCausalWaveNet(
encoder_decoder_hidden,
conv_kernel_size,
@@ -92,11 +94,13 @@ def __init__(self, preprocess_config, model_config):
self.cond_layer_d = torch.nn.utils.weight_norm(
torch.nn.Conv1d(
d_model,
- 2*encoder_decoder_hidden*decoder_layer,
+ 2*encoder_decoder_hidden*decoder_layer, # d_model
kernel_size=conv_kernel_size,
stride=conv_stride_size,
padding=self.padding_size,
dilation=dilation), name='weight')
+ # self.cond_layer_d_prj = LinearNorm(
+ # d_model, 2*encoder_decoder_hidden*decoder_layer)
self.dec_wn = NonCausalWaveNet(
encoder_decoder_hidden,
conv_kernel_size,
@@ -168,6 +172,11 @@ def forward(self, mel, mel_len, mel_mask, h_text):
mel = self.pad_input(mel)

# Prepare Conditioner
+ # h_text_f = self.cond_layer_f(h_text.transpose(1, 2)) # [B, H, L']
+ # h_text_e = self.cond_layer_e_prj(self.cond_layer_e(
+ # h_text.transpose(1, 2)).transpose(1, 2)).transpose(1, 2) # [B, H, L']
+ # h_text_d = self.cond_layer_d_prj(self.cond_layer_d(
+ # h_text.transpose(1, 2)).transpose(1, 2)).transpose(1, 2) # [B, H, L']
h_text = h_text.transpose(1, 2)
h_text_f = self.cond_layer_f(h_text) # [B, H, L']
h_text_e = self.cond_layer_e(h_text) # [B, H, L']
@@ -197,7 +206,8 @@ def forward(self, mel, mel_len, mel_mask, h_text):
x = self.dec_wn(x, g=h_text_d) * mel_mask_conv.transpose(1, 2)
x = x.contiguous().transpose(1, 2)
mel_res = self.dec_conv(x)
- mel_res = self.trim_output(mel_res, mel_mask.shape[1]) * mel_mask.unsqueeze(-1)
+ mel_res = self.trim_output(
+ mel_res, mel_mask.shape[1]) * mel_mask.unsqueeze(-1)
residual = self.residual_layer(mel_res) * mel_mask.unsqueeze(-1)

return mel_res, residual, (z_p, logs_q.transpose(1, 2), mel_mask_conv.transpose(1, 2))
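For context on the conditioner hunks above: the linguistic-encoder output is mapped by a weight-normalized Conv1d to 2 * hidden * n_layers channels, which the non-causal WaveNet consumes as its conditioning tensor g (conventionally split per layer into filter and gate halves). A standalone sketch with placeholder hyperparameters; the real values live in the repo's model_config:

```python
import torch
import torch.nn as nn

# Placeholder hyperparameters -- assumptions for illustration only.
d_model = 256
encoder_decoder_hidden = 192
encoder_layer = 8
kernel_size, stride, padding, dilation = 5, 1, 2, 1

cond_layer_e = nn.utils.weight_norm(
    nn.Conv1d(
        d_model,
        2 * encoder_decoder_hidden * encoder_layer,  # one (filter, gate) pair per WaveNet layer
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation),
    name="weight")

h_text = torch.randn(4, 37, d_model)       # [B, L', d_model] text hidden states
g = cond_layer_e(h_text.transpose(1, 2))   # [B, 2*hidden*layers, L'] conditioning tensor
print(g.shape)                             # torch.Size([4, 3072, 37])
```

The commented-out cond_layer_*_prj lines in the diff sketch an alternative where the Conv1d keeps d_model channels and a LinearNorm projects up afterwards; this commit leaves that path disabled.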
@@ -212,6 +222,9 @@ def inference(self, mel_len, mel_mask, h_text):
h_text = self.pad_input(h_text)

# Prepare Conditioner
+ # h_text_f = self.cond_layer_f(h_text.transpose(1, 2)) # [B, H, L']
+ # h_text_d = self.cond_layer_d_prj(self.cond_layer_d(
+ # h_text.transpose(1, 2)).transpose(1, 2)).transpose(1, 2) # [B, H, L']
h_text = h_text.transpose(1, 2)
h_text_f = self.cond_layer_f(h_text) # [B, H, L']
h_text_d = self.cond_layer_d(h_text) # [B, H, L']
@@ -230,7 +243,8 @@ def inference(self, mel_len, mel_mask, h_text):
x = self.dec_wn(x, g=h_text_d) * mel_mask_conv.transpose(1, 2)
x = x.contiguous().transpose(1, 2)
mel_res = self.dec_conv(x)
- mel_res = self.trim_output(mel_res, mel_mask.shape[1]) * mel_mask.unsqueeze(-1)
+ mel_res = self.trim_output(
+ mel_res, mel_mask.shape[1]) * mel_mask.unsqueeze(-1)
residual = self.residual_layer(mel_res) * mel_mask.unsqueeze(-1)

return mel_res, residual, None
