Skip to content

Commit

Permalink
Fix potential bug and add more docs
Browse files Browse the repository at this point in the history
  • Loading branch information
pkufool committed Aug 7, 2022
1 parent 15a3d1c commit 6afe995
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions fast_rnnt/python/fast_rnnt/rnnt_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,9 +557,13 @@ def get_rnnt_prune_ranges(
s_range >= 2
), "Pruning range for standard RNN-T should be equal to or greater than 2, or no valid paths could survive pruning."

(B_stride, S_stride, T_stride) = py_grad.stride()
blk_grad = torch.as_strided(
py_grad, (B, S1 - s_range + 1, s_range, T), (S1 * T, T, T, 1)
py_grad,
(B, S1 - s_range + 1, s_range, T),
(B_stride, S_stride, S_stride, T_stride),
)

# (B, S1 - s_range + 1, T)
blk_sum_grad = torch.sum(blk_grad, axis=2)

Expand All @@ -572,13 +576,17 @@ def get_rnnt_prune_ranges(

# (B, T)
s_begin = torch.argmax(final_grad, axis=1)
s_begin = s_begin[:, :T]

# Handle the values of s_begin in padding positions.
# -1 here means we fill the position of the last frame of real data with
# -1 here means we fill the position of the last frame (before padding) with
# padding value which is `len(symbols) - s_range + 1`.
# This is to guarantee that we reach the last symbol at the last frame of
# real data.
# This is to guarantee that we reach the last symbol at the last frame
# (before padding).
# The shape of the mask is (B, T), for example, we have a batch containing
# 3 sequences, their lengths are 3, 5, 6 (i.e. B = 3, T = 6), so the mask is
# [[True, True, False, False, False, False],
# [True, True, True, True, False, False],
# [True, True, True, True, True, False]]
mask = torch.arange(0, T, device=px_grad.device).reshape(1, T).expand(B, T)
mask = mask < boundary[:, 3].reshape(B, 1) - 1

Expand All @@ -589,7 +597,7 @@ def get_rnnt_prune_ranges(
s_begin = torch.where(mask, s_begin, s_begin_padding)

# adjusting lower bound to make it satisfy some constraints, see docs in
# `adjust_pruning_lower_bound` for more details of these constraints.
# `_adjust_pruning_lower_bound` for more details of these constraints.
# T1 == T here means we are using the modified version of transducer,
# the third constraint becomes `s_begin[i + 1] - s_begin[i] < 2`, because
# it only emits one symbol per frame.
Expand Down

0 comments on commit 6afe995

Please sign in to comment.