-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Model] Update source for Conformer model
- Loading branch information
Showing
3 changed files
with
300 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,292 @@ | ||
from typing import Optional, Tuple | ||
|
||
import torch | ||
|
||
__all__ = ["Conformer"] | ||
|
||
|
||
def _lengths_to_padding_mask(lengths: torch.Tensor) -> torch.Tensor: | ||
batch_size = lengths.shape[0] | ||
max_length = int(torch.max(lengths).item()) | ||
padding_mask = torch.arange(max_length, device=lengths.device, dtype=lengths.dtype).expand( | ||
batch_size, max_length | ||
) >= lengths.unsqueeze(1) | ||
return padding_mask | ||
|
||
|
||
class _ConvolutionModule(torch.nn.Module): | ||
r"""Conformer convolution module. | ||
Args: | ||
input_dim (int): input dimension. | ||
num_channels (int): number of depthwise convolution layer input channels. | ||
depthwise_kernel_size (int): kernel size of depthwise convolution layer. | ||
dropout (float, optional): dropout probability. (Default: 0.0) | ||
bias (bool, optional): indicates whether to add bias term to each convolution layer. (Default: ``False``) | ||
use_group_norm (bool, optional): use GroupNorm rather than BatchNorm. (Default: ``False``) | ||
""" | ||
|
||
def __init__( | ||
self, | ||
input_dim: int, | ||
num_channels: int, | ||
depthwise_kernel_size: int, | ||
dropout: float = 0.0, | ||
bias: bool = False, | ||
use_group_norm: bool = False, | ||
) -> None: | ||
super().__init__() | ||
if (depthwise_kernel_size - 1) % 2 != 0: | ||
raise ValueError("depthwise_kernel_size must be odd to achieve 'SAME' padding.") | ||
self.layer_norm = torch.nn.LayerNorm(input_dim) | ||
self.sequential = torch.nn.Sequential( | ||
torch.nn.Conv1d( | ||
input_dim, | ||
2 * num_channels, | ||
1, | ||
stride=1, | ||
padding=0, | ||
bias=bias, | ||
), | ||
torch.nn.GLU(dim=1), | ||
torch.nn.Conv1d( | ||
num_channels, | ||
num_channels, | ||
depthwise_kernel_size, | ||
stride=1, | ||
padding=(depthwise_kernel_size - 1) // 2, | ||
groups=num_channels, | ||
bias=bias, | ||
), | ||
torch.nn.GroupNorm(num_groups=1, num_channels=num_channels) | ||
if use_group_norm | ||
else torch.nn.BatchNorm1d(num_channels), | ||
torch.nn.SiLU(), | ||
torch.nn.Conv1d( | ||
num_channels, | ||
input_dim, | ||
kernel_size=1, | ||
stride=1, | ||
padding=0, | ||
bias=bias, | ||
), | ||
torch.nn.Dropout(dropout), | ||
) | ||
|
||
def forward(self, input: torch.Tensor) -> torch.Tensor: | ||
r""" | ||
Args: | ||
input (torch.Tensor): with shape `(B, T, D)`. | ||
Returns: | ||
torch.Tensor: output, with shape `(B, T, D)`. | ||
""" | ||
x = self.layer_norm(input) | ||
x = x.transpose(1, 2) | ||
x = self.sequential(x) | ||
return x.transpose(1, 2) | ||
|
||
|
||
class _FeedForwardModule(torch.nn.Module): | ||
r"""Positionwise feed forward layer. | ||
Args: | ||
input_dim (int): input dimension. | ||
hidden_dim (int): hidden dimension. | ||
dropout (float, optional): dropout probability. (Default: 0.0) | ||
""" | ||
|
||
def __init__(self, input_dim: int, hidden_dim: int, dropout: float = 0.0) -> None: | ||
super().__init__() | ||
self.sequential = torch.nn.Sequential( | ||
torch.nn.LayerNorm(input_dim), | ||
torch.nn.Linear(input_dim, hidden_dim, bias=True), | ||
torch.nn.SiLU(), | ||
torch.nn.Dropout(dropout), | ||
torch.nn.Linear(hidden_dim, input_dim, bias=True), | ||
torch.nn.Dropout(dropout), | ||
) | ||
|
||
def forward(self, input: torch.Tensor) -> torch.Tensor: | ||
r""" | ||
Args: | ||
input (torch.Tensor): with shape `(*, D)`. | ||
Returns: | ||
torch.Tensor: output, with shape `(*, D)`. | ||
""" | ||
return self.sequential(input) | ||
|
||
|
||
class ConformerLayer(torch.nn.Module): | ||
r"""Conformer layer that constitutes Conformer. | ||
Args: | ||
input_dim (int): input dimension. | ||
ffn_dim (int): hidden layer dimension of feedforward network. | ||
num_attention_heads (int): number of attention heads. | ||
depthwise_conv_kernel_size (int): kernel size of depthwise convolution layer. | ||
dropout (float, optional): dropout probability. (Default: 0.0) | ||
use_group_norm (bool, optional): use ``GroupNorm`` rather than ``BatchNorm1d`` | ||
in the convolution module. (Default: ``False``) | ||
convolution_first (bool, optional): apply the convolution module ahead of | ||
the attention module. (Default: ``False``) | ||
""" | ||
|
||
def __init__( | ||
self, | ||
input_dim: int, | ||
ffn_dim: int, | ||
num_attention_heads: int, | ||
depthwise_conv_kernel_size: int, | ||
dropout: float = 0.0, | ||
use_group_norm: bool = False, | ||
convolution_first: bool = False, | ||
) -> None: | ||
super().__init__() | ||
|
||
self.ffn1 = _FeedForwardModule(input_dim, ffn_dim, dropout=dropout) | ||
|
||
self.self_attn_layer_norm = torch.nn.LayerNorm(input_dim) | ||
self.self_attn = torch.nn.MultiheadAttention(input_dim, num_attention_heads, dropout=dropout) | ||
self.self_attn_dropout = torch.nn.Dropout(dropout) | ||
|
||
self.conv_module = _ConvolutionModule( | ||
input_dim=input_dim, | ||
num_channels=input_dim, | ||
depthwise_kernel_size=depthwise_conv_kernel_size, | ||
dropout=dropout, | ||
bias=True, | ||
use_group_norm=use_group_norm, | ||
) | ||
|
||
self.ffn2 = _FeedForwardModule(input_dim, ffn_dim, dropout=dropout) | ||
self.final_layer_norm = torch.nn.LayerNorm(input_dim) | ||
self.convolution_first = convolution_first | ||
|
||
def _apply_convolution(self, input: torch.Tensor) -> torch.Tensor: | ||
residual = input | ||
input = input.transpose(0, 1) | ||
input = self.conv_module(input) | ||
input = input.transpose(0, 1) | ||
input = residual + input | ||
return input | ||
|
||
def forward(self, input: torch.Tensor, key_padding_mask: Optional[torch.Tensor]) -> torch.Tensor: | ||
r""" | ||
Args: | ||
input (torch.Tensor): input, with shape `(T, B, D)`. | ||
key_padding_mask (torch.Tensor or None): key padding mask to use in self attention layer. | ||
Returns: | ||
torch.Tensor: output, with shape `(T, B, D)`. | ||
""" | ||
residual = input | ||
x = self.ffn1(input) | ||
x = x * 0.5 + residual | ||
|
||
if self.convolution_first: | ||
x = self._apply_convolution(x) | ||
|
||
residual = x | ||
x = self.self_attn_layer_norm(x) | ||
x, _ = self.self_attn( | ||
query=x, | ||
key=x, | ||
value=x, | ||
key_padding_mask=key_padding_mask, | ||
need_weights=False, | ||
) | ||
x = self.self_attn_dropout(x) | ||
x = x + residual | ||
|
||
if not self.convolution_first: | ||
x = self._apply_convolution(x) | ||
|
||
residual = x | ||
x = self.ffn2(x) | ||
x = x * 0.5 + residual | ||
|
||
x = self.final_layer_norm(x) | ||
return x | ||
|
||
|
||
class Conformer(torch.nn.Module): | ||
r"""Conformer architecture introduced in | ||
*Conformer: Convolution-augmented Transformer for Speech Recognition* | ||
:cite:`gulati2020conformer`. | ||
Args: | ||
input_dim (int): input dimension. | ||
num_heads (int): number of attention heads in each Conformer layer. | ||
ffn_dim (int): hidden layer dimension of feedforward networks. | ||
num_layers (int): number of Conformer layers to instantiate. | ||
depthwise_conv_kernel_size (int): kernel size of each Conformer layer's depthwise convolution layer. | ||
dropout (float, optional): dropout probability. (Default: 0.0) | ||
use_group_norm (bool, optional): use ``GroupNorm`` rather than ``BatchNorm1d`` | ||
in the convolution module. (Default: ``False``) | ||
convolution_first (bool, optional): apply the convolution module ahead of | ||
the attention module. (Default: ``False``) | ||
Examples: | ||
>>> conformer = Conformer( | ||
>>> input_dim=80, | ||
>>> num_heads=4, | ||
>>> ffn_dim=128, | ||
>>> num_layers=4, | ||
>>> depthwise_conv_kernel_size=31, | ||
>>> ) | ||
>>> lengths = torch.randint(1, 400, (10,)) # (batch,) | ||
>>> input = torch.rand(10, int(lengths.max()), input_dim) # (batch, num_frames, input_dim) | ||
>>> output = conformer(input, lengths) | ||
""" | ||
|
||
def __init__( | ||
self, | ||
input_dim: int, | ||
num_heads: int, | ||
ffn_dim: int, | ||
num_layers: int, | ||
depthwise_conv_kernel_size: int, | ||
dropout: float = 0.0, | ||
use_group_norm: bool = False, | ||
convolution_first: bool = False, | ||
): | ||
super().__init__() | ||
|
||
self.conformer_layers = torch.nn.ModuleList( | ||
[ | ||
ConformerLayer( | ||
input_dim, | ||
ffn_dim, | ||
num_heads, | ||
depthwise_conv_kernel_size, | ||
dropout=dropout, | ||
use_group_norm=use_group_norm, | ||
convolution_first=convolution_first, | ||
) | ||
for _ in range(num_layers) | ||
] | ||
) | ||
|
||
def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | ||
r""" | ||
Args: | ||
input (torch.Tensor): with shape `(B, T, input_dim)`. | ||
lengths (torch.Tensor): with shape `(B,)` and i-th element representing | ||
number of valid frames for i-th batch element in ``input``. | ||
Returns: | ||
(torch.Tensor, torch.Tensor) | ||
torch.Tensor | ||
output frames, with shape `(B, T, input_dim)` | ||
torch.Tensor | ||
output lengths, with shape `(B,)` and i-th element representing | ||
number of valid frames for i-th batch element in output frames. | ||
""" | ||
encoder_padding_mask = _lengths_to_padding_mask(lengths) | ||
|
||
x = input.transpose(0, 1) | ||
for layer in self.conformer_layers: | ||
x = layer(x, encoder_padding_mask) | ||
return x.transpose(0, 1), lengths |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters