Commit

[Sampling] Implement dgl.compact_graphs() for the GPU (dmlc#3423)
* gpu compact graph template

* cuda compact graph draft

* fix typo

* compact graphs

* pass unit test but fail in training

* example using EdgeDataLoader on the GPU

* refactor cuda_compact_graph and cuda_to_block

* update training scripts

* fix linting

* fix linting

* fix exclude_edges for the GPU

* add --data-cpu & fix copyright
yaox12 authored Oct 21, 2021
1 parent 308e52a commit a8c8101
Showing 11 changed files with 676 additions and 295 deletions.
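In Python terms, the headline change is that dgl.compact_graphs() now accepts graphs residing on a CUDA device. A minimal sketch, assuming a CUDA-enabled build of DGL with the PyTorch backend (the toy graphs below are illustrative, not taken from the commit):

    import dgl
    import torch

    # Two graphs over the same 5-node set; nodes 3 and 4 touch no edge in either graph.
    g1 = dgl.graph(([0, 1], [1, 2]), num_nodes=5, device='cuda')
    g2 = dgl.graph(([2, 1], [0, 0]), num_nodes=5, device='cuda')

    # Before this commit the call below required CPU graphs; now it also runs on the GPU.
    new_g1, new_g2 = dgl.compact_graphs([g1, g2])
    print(new_g1.device)          # cuda:0
    print(new_g1.ndata[dgl.NID])  # original IDs of the kept nodes, e.g. tensor([0, 1, 2], device='cuda:0')
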
19 changes: 17 additions & 2 deletions examples/pytorch/graphsage/train_sampling_unsupervised.py
@@ -67,6 +67,9 @@ def run(proc_id, n_gpus, args, devices, data):
train_mask, val_mask, test_mask, n_classes, g = data
nfeat = g.ndata.pop('feat')
labels = g.ndata.pop('label')
if not args.data_cpu:
nfeat = nfeat.to(device)
labels = labels.to(device)
in_feats = nfeat.shape[1]

train_nid = th.LongTensor(np.nonzero(train_mask)).squeeze()
@@ -77,6 +80,11 @@ def run(proc_id, n_gpus, args, devices, data):
n_edges = g.num_edges()
train_seeds = th.arange(n_edges)

if args.sample_gpu:
assert n_gpus > 0, "Must have GPUs to enable GPU sampling"
train_seeds = train_seeds.to(device)
g = g.to(device)

# Create sampler
sampler = dgl.dataloading.MultiLayerNeighborSampler(
[int(fanout) for fanout in args.fan_out.split(',')])
@@ -121,11 +129,11 @@ def run(proc_id, n_gpus, args, devices, data):
tic_step = time.time()
for step, (input_nodes, pos_graph, neg_graph, blocks) in enumerate(dataloader):
batch_inputs = nfeat[input_nodes].to(device)
d_step = time.time()

pos_graph = pos_graph.to(device)
neg_graph = neg_graph.to(device)
blocks = [block.int().to(device) for block in blocks]
d_step = time.time()

# Compute loss and prediction
batch_pred = model(blocks, batch_inputs)
loss = loss_fcn(batch_pred, pos_graph, neg_graph)
@@ -213,6 +221,13 @@ def main(args, devices):
argparser.add_argument('--dropout', type=float, default=0.5)
argparser.add_argument('--num-workers', type=int, default=0,
help="Number of sampling processes. Use 0 for no extra process.")
argparser.add_argument('--sample-gpu', action='store_true',
help="Perform the sampling process on the GPU. Must have 0 workers.")
argparser.add_argument('--data-cpu', action='store_true',
help="By default the script puts all node features and labels "
"on GPU when using it to save time for data copy. This may "
"be undesired if they cannot fit in GPU memory at once. "
"This flag disables that.")
args = argparser.parse_args()

devices = list(map(int, args.gpu.split(',')))
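The two new flags amount to this: --data-cpu keeps node features and labels on the host, while --sample-gpu moves the graph and the seed edges to the device so that EdgeDataLoader samples there (which is why it requires num_workers=0). A rough, hedged sketch of the GPU-sampling setup, with illustrative graph sizes and fanouts rather than the script's exact values:

    import dgl
    import torch

    device = torch.device('cuda:0')
    g = dgl.rand_graph(10000, 200000).to(device)       # stand-in for the training graph; must fit in GPU memory
    train_seeds = torch.arange(g.num_edges(), device=device)

    sampler = dgl.dataloading.MultiLayerNeighborSampler([10, 25])
    dataloader = dgl.dataloading.EdgeDataLoader(
        g, train_seeds, sampler,
        negative_sampler=dgl.dataloading.negative_sampler.Uniform(1),
        device=device,
        batch_size=1000, shuffle=True, drop_last=False,
        num_workers=0)                                  # GPU sampling needs 0 worker processes

    for input_nodes, pos_graph, neg_graph, blocks in dataloader:
        pass  # with GPU sampling, pos_graph, neg_graph and blocks already live on the device
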
2 changes: 1 addition & 1 deletion python/dgl/dataloading/dataloader.py
@@ -368,7 +368,7 @@ def sample_blocks(self, g, seed_nodes, exclude_eids=None):
else:
seed_nodes_in = seed_nodes_in.to(graph_device)

if self.exclude_edges_in_frontier:
if self.exclude_edges_in_frontier(g):
frontier = self.sample_frontier(
block_id, g, seed_nodes_in, exclude_eids=exclude_eids)
else:
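The change above fixes a subtle bug: exclude_edges_in_frontier is a method, and the old condition tested the bound method object itself, which is always truthy, instead of calling it with the graph. A generic Python illustration (FakeSampler is hypothetical, not DGL code):

    class FakeSampler:
        def exclude_edges_in_frontier(self, g):
            # stand-in for the real predicate, which decides per graph
            # whether edges must be excluded while building the frontier
            return False

    s = FakeSampler()
    print(bool(s.exclude_edges_in_frontier))        # True  -- a bound method object is always truthy
    print(bool(s.exclude_edges_in_frontier(None)))  # False -- the value the condition should actually test
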
7 changes: 3 additions & 4 deletions python/dgl/transform.py
@@ -1913,7 +1913,7 @@ def compact_graphs(graphs, always_preserve=None, copy_ndata=True, copy_edata=True):
graphs : DGLGraph or list[DGLGraph]
The graph, or list of graphs.
All graphs must be on CPU.
All graphs must be on the same devices.
All graphs must have the same set of nodes.
always_preserve : Tensor or dict[str, Tensor], optional
@@ -2013,7 +2013,6 @@ def compact_graphs(graphs, always_preserve=None, copy_ndata=True, copy_edata=True):
return []
if graphs[0].is_block:
raise DGLError('Compacting a block graph is not allowed.')
assert all(g.device == F.cpu() for g in graphs), 'all the graphs must be on CPU'

# Ensure the node types are ordered the same.
# TODO(BarclayII): we ideally need to remove this constraint.
@@ -2026,8 +2025,8 @@ def compact_graphs(graphs, always_preserve=None, copy_ndata=True, copy_edata=True):
ntypes, g.ntypes)
assert idtype == g.idtype, "Expect graph data type to be {}, but got {}".format(
idtype, g.idtype)
assert device == g.device, "Expect graph device to be {}, but got {}".format(
device, g.device)
assert device == g.device, "All graphs must be on the same devices." \
"Expect graph device to be {}, but got {}".format(device, g.device)

# Process the dictionary or tensor of "always preserve" nodes
if always_preserve is None:
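With the CPU-only assertion removed, the remaining constraint is that all input graphs share one device, enforced by the per-graph assertion kept above. A hedged sketch of what this means for callers (error text paraphrased):

    import dgl

    g_cpu = dgl.graph(([0], [1]), num_nodes=3)                  # CPU graph
    g_gpu = dgl.graph(([1], [2]), num_nodes=3, device='cuda')   # GPU graph

    dgl.compact_graphs([g_gpu, g_gpu.clone()])   # fine: both graphs on cuda:0
    dgl.compact_graphs([g_cpu, g_gpu])           # AssertionError: graphs must be on the same device
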
52 changes: 41 additions & 11 deletions src/graph/transform/compact.cc
@@ -1,13 +1,30 @@
/*!
* Copyright (c) 2019 by Contributors
* Copyright 2019-2021 Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* \file graph/transform/compact.cc
* \brief Compact graph implementation
*/

#include "compact.h"

#include <dgl/base_heterograph.h>
#include <dgl/transform.h>
#include <dgl/array.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/container.h>
#include <vector>
#include <utility>
#include "../../c_api_common.h"
@@ -27,7 +44,7 @@ namespace {

template<typename IdType>
std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>>
CompactGraphs(
CompactGraphsCPU(
const std::vector<HeteroGraphPtr> &graphs,
const std::vector<IdArray> &always_preserve) {
// TODO(BarclayII): check whether the node space and metagraph of each graph is the same.
@@ -121,17 +138,20 @@ CompactGraphs(

}; // namespace

template<>
std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>>
CompactGraphs(
CompactGraphs<kDLCPU, int32_t>(
const std::vector<HeteroGraphPtr> &graphs,
const std::vector<IdArray> &always_preserve) {
std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>> result;
// TODO(BarclayII): check for all IdArrays
CHECK(graphs[0]->DataType() == always_preserve[0]->dtype) << "data type mismatch.";
ATEN_ID_TYPE_SWITCH(graphs[0]->DataType(), IdType, {
result = CompactGraphs<IdType>(graphs, always_preserve);
});
return result;
return CompactGraphsCPU<int32_t>(graphs, always_preserve);
}

template<>
std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>>
CompactGraphs<kDLCPU, int64_t>(
const std::vector<HeteroGraphPtr> &graphs,
const std::vector<IdArray> &always_preserve) {
return CompactGraphsCPU<int64_t>(graphs, always_preserve);
}

DGL_REGISTER_GLOBAL("transform._CAPI_DGLCompactGraphs")
@@ -146,7 +166,17 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLCompactGraphs")
for (Value array : always_preserve_refs)
always_preserve.push_back(array->data);

const auto &result_pair = CompactGraphs(graphs, always_preserve);
// TODO(BarclayII): check for all IdArrays
CHECK(graphs[0]->DataType() == always_preserve[0]->dtype) << "data type mismatch.";

std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>> result_pair;

ATEN_XPU_SWITCH_CUDA(graphs[0]->Context().device_type, XPU, "CompactGraphs", {
ATEN_ID_TYPE_SWITCH(graphs[0]->DataType(), IdType, {
result_pair = CompactGraphs<XPU, IdType>(
graphs, always_preserve);
});
});

List<HeteroGraphRef> compacted_graph_refs;
List<Value> induced_nodes;
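From Python, the new ATEN_XPU_SWITCH_CUDA / ATEN_ID_TYPE_SWITCH dispatch is invisible: the same C API call routes to the CPU or CUDA implementation and to the int32 or int64 instantiation based on the input graphs. A small sanity check one might run on a CUDA build (illustrative, not part of the commit's tests):

    import dgl
    import torch

    g32 = dgl.graph(([0, 1], [1, 2]), num_nodes=4, idtype=torch.int32, device='cuda')
    out = dgl.compact_graphs(g32)
    assert out.device == g32.device     # dispatched to the CUDA kernel, result stays on the GPU
    assert out.idtype == torch.int32    # int32 instantiation used, id type preserved
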
53 changes: 53 additions & 0 deletions src/graph/transform/compact.h
@@ -0,0 +1,53 @@
/*!
* Copyright 2021 Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* \file graph/transform/compact.h
* \brief Functions to find and eliminate the common isolated nodes across
* all given graphs with the same set of nodes.
*/

#ifndef DGL_GRAPH_TRANSFORM_COMPACT_H_
#define DGL_GRAPH_TRANSFORM_COMPACT_H_

#include <dgl/array.h>
#include <dgl/base_heterograph.h>

#include <vector>
#include <utility>

namespace dgl {
namespace transform {

/**
* @brief Given a list of graphs with the same set of nodes, find and eliminate
* the common isolated nodes across all graphs.
*
* @tparam XPU The type of device to operate on.
* @tparam IdType The type to use as an index.
* @param graphs The list of graphs to be compacted.
* @param always_preserve The vector of nodes to be preserved.
*
* @return The vector of compacted graphs and the vector of induced nodes.
*/
template<DLDeviceType XPU, typename IdType>
std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>>
CompactGraphs(
const std::vector<HeteroGraphPtr> &graphs,
const std::vector<IdArray> &always_preserve);

} // namespace transform
} // namespace dgl

#endif // DGL_GRAPH_TRANSFORM_COMPACT_H_
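The contract declared here, drop nodes that are isolated in every input graph except those listed in always_preserve, is easiest to see from the Python API. A hedged toy example (node numbering chosen for illustration):

    import dgl
    import torch

    # Node 3 is isolated and gets dropped; node 4 is isolated too but explicitly preserved.
    g = dgl.graph(([0, 1], [1, 2]), num_nodes=5, device='cuda')
    cg = dgl.compact_graphs(g, always_preserve=torch.tensor([4], device='cuda'))
    print(cg.num_nodes())      # 4: nodes 0, 1, 2 plus the preserved node 4
    print(cg.ndata[dgl.NID])   # e.g. tensor([0, 1, 2, 4], device='cuda:0')
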