diff --git a/example/dgl/common_config.py b/example/dgl/common_config.py
new file mode 100644
index 00000000..eabd26c5
--- /dev/null
+++ b/example/dgl/common_config.py
@@ -0,0 +1,15 @@
+import os
+
+def get_default_timeout():
+    # In seconds
+    return 300
+
+def wait_and_join(processes):
+    ret = os.waitpid(-1, 0)
+    if os.WEXITSTATUS(ret[1]) != 0:
+        print("Detect pid {:} error exit".format(ret[0]))
+        for p in processes:
+            p.kill()
+
+    for p in processes:
+        p.join()
\ No newline at end of file
diff --git a/example/dgl/multi_gpu/common_config.py b/example/dgl/multi_gpu/common_config.py
new file mode 120000
index 00000000..e98e6dd2
--- /dev/null
+++ b/example/dgl/multi_gpu/common_config.py
@@ -0,0 +1 @@
+../common_config.py
\ No newline at end of file
diff --git a/example/dgl/multi_gpu/train_gcn.py b/example/dgl/multi_gpu/train_gcn.py
index d4ce89b8..e0ca60ac 100644
--- a/example/dgl/multi_gpu/train_gcn.py
+++ b/example/dgl/multi_gpu/train_gcn.py
@@ -6,6 +6,7 @@
 - Code: https://github.com/tkipf/gcn
 """
 import argparse
+import datetime
 import torch
 import torch.nn as nn
 import torch.optim as optim
@@ -19,7 +20,7 @@
 from torch.nn.parallel import DistributedDataParallel
 import math
 import sys
-
+from common_config import *
 
 class GCN(nn.Module):
     def __init__(self,
@@ -225,7 +226,8 @@ def run(worker_id, run_config):
         torch.distributed.init_process_group(backend="nccl",
                                              init_method=dist_init_method,
                                              world_size=world_size,
-                                             rank=worker_id)
+                                             rank=worker_id,
+                                             timeout=datetime.timedelta(seconds=get_default_timeout()))
 
     dataset = run_config['dataset']
     g = run_config['g'].to(sample_device)
@@ -317,6 +319,9 @@ def run(worker_id, run_config):
             if not run_config['pipelining']:
                 sync_device()
 
+            if num_worker > 1:
+                torch.distributed.barrier()
+
             num_samples.append(sum([block.num_edges() for block in blocks]))
             num_nodes.append(blocks[0].num_src_nodes())
 
@@ -325,9 +330,6 @@ def run(worker_id, run_config):
             batch_inputs = None
             batch_labels = None
             blocks = None
 
-            if num_worker > 1:
-                torch.distributed.barrier()
-
             t4 = time.time()
             sample_times.append(t1 - t0)
@@ -394,5 +396,6 @@ def run(worker_id, run_config):
         p = mp.Process(target=run, args=(worker_id, run_config))
         p.start()
         workers.append(p)
-    for p in workers:
-        p.join()
+
+    wait_and_join(workers)
+
diff --git a/example/dgl/multi_gpu/train_graphsage.py b/example/dgl/multi_gpu/train_graphsage.py
index 1a0cdeb2..435630ff 100644
--- a/example/dgl/multi_gpu/train_graphsage.py
+++ b/example/dgl/multi_gpu/train_graphsage.py
@@ -1,4 +1,5 @@
 import argparse
+import datetime
 import dgl
 import torch
 import dgl.nn.pytorch as dglnn
@@ -12,6 +13,7 @@
 import numpy as np
 import math
 import sys
 
+from common_config import *
 
 class SAGE(nn.Module):
@@ -207,7 +209,8 @@ def run(worker_id, run_config):
         torch.distributed.init_process_group(backend="nccl",
                                              init_method=dist_init_method,
                                              world_size=world_size,
-                                             rank=worker_id)
+                                             rank=worker_id,
+                                             timeout=datetime.timedelta(seconds=get_default_timeout()))
 
     dataset = run_config['dataset']
     g = run_config['g'].to(sample_device)
@@ -299,6 +302,8 @@ def run(worker_id, run_config):
             if not run_config['pipelining']:
                 sync_device()
 
+            if num_worker > 1:
+                torch.distributed.barrier()
             num_samples.append(sum([block.num_edges() for block in blocks]))
             num_nodes.append(blocks[0].num_src_nodes())
 
@@ -307,9 +312,6 @@ def run(worker_id, run_config):
             batch_inputs = None
             batch_labels = None
             blocks = None
 
-            if num_worker > 1:
-                torch.distributed.barrier()
-
             t4 = time.time()
             sample_times.append(t1 - t0)
@@ -375,5 +377,5 @@ def run(worker_id, run_config):
         p = mp.Process(target=run, args=(worker_id, run_config))
         p.start()
         workers.append(p)
-    for p in workers:
-        p.join()
+
+    wait_and_join(workers)
diff --git a/example/dgl/multi_gpu/train_pinsage.py b/example/dgl/multi_gpu/train_pinsage.py
index 20e8a386..67743592 100644
--- a/example/dgl/multi_gpu/train_pinsage.py
+++ b/example/dgl/multi_gpu/train_pinsage.py
@@ -1,4 +1,5 @@
 import argparse
+import datetime
 import dgl
 import torch
 import torch.optim as optim
@@ -12,6 +13,7 @@
 import numpy as np
 import math
 import sys
 
+from common_config import *
 """
 We have made the following modification(or say, simplification) on PinSAGE,
@@ -256,7 +258,8 @@ def run(worker_id, run_config):
         torch.distributed.init_process_group(backend="nccl",
                                              init_method=dist_init_method,
                                              world_size=world_size,
-                                             rank=worker_id)
+                                             rank=worker_id,
+                                             timeout=datetime.timedelta(seconds=get_default_timeout()))
 
     dataset = run_config['dataset']
     g = run_config['g']
@@ -342,6 +345,8 @@ def run(worker_id, run_config):
             if not run_config['pipelining']:
                 sync_device()
 
+            if num_worker > 1:
+                torch.distributed.barrier()
             num_samples.append(sum([block.num_edges() for block in blocks]))
             num_nodes.append(blocks[0].num_src_nodes())
 
@@ -350,8 +355,6 @@ def run(worker_id, run_config):
             batch_inputs = None
             batch_labels = None
             blocks = None
 
-            if num_worker > 1:
-                torch.distributed.barrier()
             t4 = time.time()
             sample_times.append(t1 - t0)
@@ -418,5 +421,5 @@ def run(worker_id, run_config):
         p = mp.Process(target=run, args=(worker_id, run_config))
         p.start()
         workers.append(p)
-    for p in workers:
-        p.join()
+
+    wait_and_join(workers)