Skip to content

Commit

Permalink
[Bugfix] MX utest traversal memory corruption (dmlc#312)
Browse files Browse the repository at this point in the history
* WIP

* temp fix mx traversal memory crash bug
  • Loading branch information
jermainewang authored and zheng-da committed Dec 17, 2018
1 parent 3d44630 commit 0d0f443
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 14 deletions.
16 changes: 8 additions & 8 deletions python/dgl/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def bfs_nodes_generator(graph, source, reversed=False):
[tensor([0]), tensor([1]), tensor([2, 3]), tensor([4, 5])]
"""
ghandle = graph._graph._handle
source = utils.toindex(source).todgltensor()
ret = _CAPI_DGLBFSNodes(ghandle, source, reversed)
source = utils.toindex(source)
ret = _CAPI_DGLBFSNodes(ghandle, source.todgltensor(), reversed)
all_nodes = utils.toindex(ret(0)).tousertensor()
# TODO(minjie): how to support directly creating python list
sections = utils.toindex(ret(1)).tonumpy().tolist()
Expand Down Expand Up @@ -80,8 +80,8 @@ def bfs_edges_generator(graph, source, reversed=False):
[tensor([0]), tensor([1, 2]), tensor([4, 5])]
"""
ghandle = graph._graph._handle
source = utils.toindex(source).todgltensor()
ret = _CAPI_DGLBFSEdges(ghandle, source, reversed)
source = utils.toindex(source)
ret = _CAPI_DGLBFSEdges(ghandle, source.todgltensor(), reversed)
all_edges = utils.toindex(ret(0)).tousertensor()
# TODO(minjie): how to support directly creating python list
sections = utils.toindex(ret(1)).tonumpy().tolist()
Expand Down Expand Up @@ -161,8 +161,8 @@ def dfs_edges_generator(graph, source, reversed=False):
[tensor([0]), tensor([1]), tensor([3]), tensor([5]), tensor([4])]
"""
ghandle = graph._graph._handle
source = utils.toindex(source).todgltensor()
ret = _CAPI_DGLDFSEdges(ghandle, source, reversed)
source = utils.toindex(source)
ret = _CAPI_DGLDFSEdges(ghandle, source.todgltensor(), reversed)
all_edges = utils.toindex(ret(0)).tousertensor()
# TODO(minjie): how to support directly creating python list
sections = utils.toindex(ret(1)).tonumpy().tolist()
Expand Down Expand Up @@ -232,10 +232,10 @@ def dfs_labeled_edges_generator(
(tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([2]))
"""
ghandle = graph._graph._handle
source = utils.toindex(source).todgltensor()
source = utils.toindex(source)
ret = _CAPI_DGLDFSLabeledEdges(
ghandle,
source,
source.todgltensor(),
reversed,
has_reverse_edge,
has_nontree_edge,
Expand Down
5 changes: 2 additions & 3 deletions tests/mxnet/test_propagate.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_prop_nodes_bfs():
assert np.allclose(g.ndata['x'].asnumpy(),
np.array([[2., 2.], [4., 4.], [6., 6.], [8., 8.], [9., 9.]]))

def _test_prop_edges_dfs():
def test_prop_edges_dfs():
g = dgl.DGLGraph(nx.path_graph(5))
g.register_message_func(mfunc)
g.register_reduce_func(rfunc)
Expand Down Expand Up @@ -69,6 +69,5 @@ def test_prop_nodes_topo():

if __name__ == '__main__':
test_prop_nodes_bfs()
#TODO(zhengda): the test leads to segfault in MXNet on Ubuntu 16.04.
#_test_prop_edges_dfs()
test_prop_edges_dfs()
test_prop_nodes_topo()
5 changes: 2 additions & 3 deletions tests/mxnet/test_traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def tensor_topo_traverse():
assert all(toset(x) == toset(y) for x, y in zip(layers_dgl, layers_spmv))

DFS_LABEL_NAMES = ['forward', 'reverse', 'nontree']
def _test_dfs_labeled_edges(n=1000, example=False):
def test_dfs_labeled_edges(n=1000, example=False):
dgl_g = dgl.DGLGraph()
dgl_g.add_nodes(6)
dgl_g.add_edges([0, 1, 0, 3, 3], [1, 2, 2, 4, 5])
Expand Down Expand Up @@ -123,5 +123,4 @@ def combine_frontiers(sol):
if __name__ == '__main__':
test_bfs()
test_topological_nodes()
#TODO(zhengda): the test leads to segfault in MXNet on Ubuntu 16.04.
#_test_dfs_labeled_edges()
test_dfs_labeled_edges()

0 comments on commit 0d0f443

Please sign in to comment.