Skip to content

Commit

Permalink
[Windows] fix compilation issues on vs2015 (dmlc#1405)
Browse files Browse the repository at this point in the history
* [Windows] fix compilation issues on vs2015

* fix test
  • Loading branch information
BarclayII authored Mar 30, 2020
1 parent e9440ac commit e4cc818
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 22 deletions.
4 changes: 1 addition & 3 deletions conda/dgl/bld.bat
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@ git submodule init
git submodule update --recursive
md build
cd build
cmake -DUSE_CUDA=%USE_CUDA% -DUSE_OPENMP=ON -DCUDA_ARCH_NAME=All -DCMAKE_CXX_FLAGS="/DDGL_EXPORTS" -DCMAKE_CONFIGURATION_TYPES="Release" -DDMLC_FORCE_SHARED_CRT=ON .. -G "Visual Studio 15 2017 Win64" || EXIT /B 1
msbuild dgl.sln || EXIT /B 1
COPY Release\dgl.dll .
COPY %TEMP%\dgl.dll .
cd ..\python
"%PYTHON%" setup.py install --single-version-externally-managed --record=record.txt || EXIT /B 1
EXIT /B
20 changes: 13 additions & 7 deletions python/dgl/sampling/neighbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False):
fanout_array = [None] * len(g.etypes)
for etype, value in fanout.items():
fanout_array[g.get_etype_id(etype)] = value
fanout_array = utils.toindex(fanout_array).todgltensor()

if prob is None:
prob_arrays = [nd.array([], ctx=nd.cpu())] * len(g.etypes)
Expand Down Expand Up @@ -100,7 +101,7 @@ def select_topk(g, k, weight, nodes=None, edge_dir='in', ascending=False):
----------
g : DGLHeteroGraph
Full graph structure.
k : int
k : int or dict[etype, int]
The K value.
weight : str
Feature name of the weights associated with each edge. Its shape should be
Expand Down Expand Up @@ -138,11 +139,16 @@ def select_topk(g, k, weight, nodes=None, edge_dir='in', ascending=False):
else:
nodes_all_types.append(nd.array([], ctx=nd.cpu()))

if not isinstance(k, list):
k = [int(k)] * len(g.etypes)
if len(k) != len(g.etypes):
raise DGLError('K value must be specified for each edge type '
'if a list is provided.')
if not isinstance(k, dict):
k_array = [int(k)] * len(g.etypes)
else:
if len(k) != len(g.etypes):
raise DGLError('K value must be specified for each edge type '
'if a dict is provided.')
k_array = [None] * len(g.etypes)
for etype, value in k.items():
k_array[g.get_etype_id(etype)] = value
k_array = utils.toindex(k_array).todgltensor()

weight_arrays = []
for etype in g.canonical_etypes:
Expand All @@ -153,7 +159,7 @@ def select_topk(g, k, weight, nodes=None, edge_dir='in', ascending=False):
weight, etype))

subgidx = _CAPI_DGLSampleNeighborsTopk(
g._graph, nodes_all_types, k, edge_dir, weight_arrays, bool(ascending))
g._graph, nodes_all_types, k_array, edge_dir, weight_arrays, bool(ascending))
induced_edges = subgidx.induced_edges
ret = DGLHeteroGraph(subgidx.graph, g.ntypes, g.etypes)
for i, etype in enumerate(ret.canonical_etypes):
Expand Down
19 changes: 13 additions & 6 deletions src/graph/metis_partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,18 @@
* \brief Call Metis partitioning
*/

#include <metis.h>
#include <dgl/graph_op.h>
#include <dgl/packed_func_ext.h>
#include "../c_api_common.h"

#if !defined(_WIN32)

#include <metis.h>
#include <dgl/graph_op.h>

using namespace dgl::runtime;

namespace dgl {

#if !defined(_WIN32)

IdArray GraphOp::MetisPartition(GraphPtr g, int k) {
// The index type of Metis needs to be compatible with DGL index type.
CHECK_EQ(sizeof(idx_t), sizeof(dgl_id_t));
Expand Down Expand Up @@ -71,7 +72,13 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLMetisPartition")
*rv = GraphOp::MetisPartition(g.sptr(), k);
});

#else
} // namespace dgl

#else // defined(_WIN32)

using namespace dgl::runtime;

namespace dgl {

DGL_REGISTER_GLOBAL("transform._CAPI_DGLMetisPartition")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
Expand All @@ -81,6 +88,6 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLMetisPartition")
*rv = aten::NullArray();
});

#endif // !defined(_WIN32)

} // namespace dgl
#endif // !defined(_WIN32)
12 changes: 7 additions & 5 deletions src/graph/sampling/neighbor/neighbor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighbors")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef hg = args[0];
const auto& nodes = ListValueToVector<IdArray>(args[1]);
const auto& fanouts = ListValueToVector<int64_t>(args[2]);
IdArray fanouts_array = args[2];
const auto& fanouts = fanouts_array.ToVector<int64_t>();
const std::string dir_str = args[3];
const auto& prob = ListValueToVector<FloatArray>(args[4]);
const bool replace = args[5];
Expand All @@ -192,14 +193,15 @@ DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsTopk")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef hg = args[0];
const auto& nodes = ListValueToVector<IdArray>(args[1]);
const auto& k = ListValueToVector<int64_t>(args[2]);
IdArray k_array = args[2];
const auto& k = k_array.ToVector<int64_t>();
const std::string dir_str = args[3];
const auto& weight = ListValueToVector<FloatArray>(args[4]);
const bool ascending = args[5];

CHECK(dir_str == "in" || dir_str == "out")
<< "Invalid edge direction. Must be \"in\" or \"out\".";
EdgeDir dir = (dir_str == "in")? EdgeDir::kIn : EdgeDir::kOut;
CHECK(dir_str == "in" || dir_str == "out")
<< "Invalid edge direction. Must be \"in\" or \"out\".";
EdgeDir dir = (dir_str == "in")? EdgeDir::kIn : EdgeDir::kOut;

std::shared_ptr<HeteroSubgraph> subg(new HeteroSubgraph);
*subg = sampling::SampleNeighborsTopk(
Expand Down
3 changes: 2 additions & 1 deletion tests/compute/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,8 @@ def _test3():
_test3()

# test different k for different relations
subg = dgl.sampling.select_topk(hg, [1, 2, 0, 2], 'weight', {'user' : [0,1], 'game' : 0})
subg = dgl.sampling.select_topk(
hg, {'follow': 1, 'play': 2, 'liked-by': 0, 'flips': 2}, 'weight', {'user' : [0,1], 'game' : 0})
assert len(subg.ntypes) == 3
assert len(subg.etypes) == 4
assert subg['follow'].number_of_edges() == 2
Expand Down

0 comments on commit e4cc818

Please sign in to comment.