Skip to content

Commit

Permalink
Fix tree lstm (dmlc#2052)
Browse files Browse the repository at this point in the history
Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
classicsong and Ubuntu authored Aug 19, 2020
1 parent c8a44c7 commit 4edde00
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 0 deletions.
3 changes: 3 additions & 0 deletions examples/mxnet/tree_lstm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,6 @@ DGLBACKEND=mxnet python3 train.py --gpu 0
## Speed Test

See https://docs.google.com/spreadsheets/d/1eCQrVn7g0uWriz63EbEDdes2ksMdKdlbWMyT8PSU4rc .

## Note
The code can work with MXNet 1.5.1
2 changes: 2 additions & 0 deletions examples/mxnet/tree_lstm/tree_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def __init__(self,
self.linear = gluon.nn.Dense(num_classes)
cell = TreeLSTMCell if cell_type == 'nary' else ChildSumTreeLSTMCell
self.cell = cell(x_size, h_size)
self.ctx = ctx

def forward(self, batch, h, c):
"""Compute tree-lstm prediction given a batch.
Expand All @@ -113,6 +114,7 @@ def forward(self, batch, h, c):
The prediction of each node.
"""
g = batch.graph
g = g.to(self.ctx)
# feed embedding
embeds = self.embedding(batch.wordid * batch.mask)
wiou = self.cell.W_iou(self.dropout(embeds))
Expand Down
1 change: 1 addition & 0 deletions python/dgl/data/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def load(self):

self._trees = load_graphs(graph_path)[0]
self._vocab = load_info(vocab_path)['vocab']
self._pretrained_emb = None
if os.path.exists(emb_path):
self._pretrained_emb = load_info(emb_path)['embed']

Expand Down

0 comments on commit 4edde00

Please sign in to comment.