From b80b89527e789d85dde9bb92417a2be22fbbcd45 Mon Sep 17 00:00:00 2001 From: WhereIsMyHead Date: Sat, 26 Jan 2019 22:09:42 +0800 Subject: [PATCH] Fix on torch.argsort only supporting for >PyTorch1.0.0 (#24) * Fix version incompatibility of PyTorch * Update dynamic_rnn.py --- README.md | 2 +- layers/dynamic_rnn.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index f94bd70..29f3f45 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ ## Requirement -* PyTorch 0.4.0 +* PyTorch >= 0.4.0 * NumPy 1.13.3 * tensorboardX 1.2 * Python 3.6 diff --git a/layers/dynamic_rnn.py b/layers/dynamic_rnn.py index 8da0f82..7490f99 100644 --- a/layers/dynamic_rnn.py +++ b/layers/dynamic_rnn.py @@ -57,10 +57,10 @@ def forward(self, x, x_len): :return: """ """sort""" - x_sort_idx = torch.argsort(-x_len) - x_unsort_idx = torch.argsort(x_sort_idx).long() + x_sort_idx = torch.sort(-x_len)[1].long() + x_unsort_idx = torch.sort(x_sort_idx)[1].long() x_len = x_len[x_sort_idx] - x = x[x_sort_idx.long()] + x = x[x_sort_idx] """pack""" x_emb_p = torch.nn.utils.rnn.pack_padded_sequence(x, x_len, batch_first=self.batch_first)