Skip to content

Commit

Permalink
Correct the estimation of cnn output lengths in convtransformer (face…
Browse files Browse the repository at this point in the history
…bookresearch#1636)

Summary: Pull Request resolved: fairinternal/fairseq-py#1636

Reviewed By: xutaima

Differential Revision: D26562816

Pulled By: jmp84

fbshipit-source-id: 4e6efd0b4236d7187bd365d790f260bd5297aed5
  • Loading branch information
xutaima authored and facebook-github-bot committed Feb 20, 2021
1 parent c6b5c00 commit ae22da6
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions fairseq/models/speech_to_text/convtransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ class ConvTransformerModel(FairseqEncoderDecoderModel):
Transformer-based Speech translation model from ESPNet-ST
https://arxiv.org/abs/2004.10234
"""

def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)

Expand Down Expand Up @@ -307,7 +306,11 @@ def forward(self, src_tokens, src_lengths):

subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5)

input_lengths = (src_lengths.float() / subsampling_factor).ceil().long()
input_lengths = torch.min(
(src_lengths.float() / subsampling_factor).ceil().long(),
x.size(0) * src_lengths.new_ones([src_lengths.size(0)]).long()
)

encoder_padding_mask, _ = lengths_to_encoder_padding_mask(
input_lengths, batch_first=True
)
Expand Down

0 comments on commit ae22da6

Please sign in to comment.