Skip to content

Commit

Permalink
update format of RNNLM state
Browse files Browse the repository at this point in the history
  • Loading branch information
mn5k committed Oct 5, 2018
1 parent 871f1f8 commit 09b5d4f
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/nets/e2e_asr_th.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,14 @@ def make_pad_mask(lengths):
return mask


def mask_by_length(xs, length, fill=0):
assert xs.size(0) == len(length)
ret = xs.data.new(*xs.size()).fill_(fill)
for i, l in enumerate(length):
ret[i, :l] = xs[i, :l]
return ret


def th_accuracy(pad_outputs, pad_targets, ignore_label):
"""Function to calculate accuracy
Expand Down Expand Up @@ -1744,7 +1752,7 @@ def index_select_lm_state(rnnlm_state, dim, vidx):
if isinstance(rnnlm_state, dict):
new_state = {}
for k, v in rnnlm_state.items():
new_state[k] = torch.index_select(v, dim, vidx)
new_state[k] = [torch.index_select(vi, dim, vidx) for vi in v]
elif isinstance(rnnlm_state, list):
new_state = []
for i in vidx:
Expand Down Expand Up @@ -2077,6 +2085,7 @@ def recognize_beam(self, h, lpz, recog_args, char_list, rnnlm=None):
def recognize_beam_batch(self, h, hlens, lpz, recog_args, char_list, rnnlm=None,
normalize_score=True):
logging.info('input lengths: ' + str(h.size(1)))
h = mask_by_length(h, hlens, 0.0)

# search params
batch = len(hlens)
Expand Down

0 comments on commit 09b5d4f

Please sign in to comment.