Fix DGL
dynamicheart committed Oct 7, 2021
1 parent b8b3bf3 commit 460de59
Showing 5 changed files with 42 additions and 18 deletions.
15 changes: 15 additions & 0 deletions example/dgl/common_config.py
@@ -0,0 +1,15 @@
import os

def get_default_timeout():
    # In seconds
    return 300

def wait_and_join(processes):
    ret = os.waitpid(-1, 0)
    if os.WEXITSTATUS(ret[1]) != 0:
        print("Detect pid {:} error exit".format(ret[0]))
        for p in processes:
            p.kill()

    for p in processes:
        p.join()
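The helper above is meant to be driven from the multi-GPU launchers changed below. A minimal, hypothetical usage sketch (POSIX-only, since wait_and_join relies on os.waitpid; the worker body and the process count are placeholders):

import multiprocessing as mp

from common_config import wait_and_join


def worker(rank):
    # Placeholder body; the real scripts run their training loop here.
    print("worker", rank, "done")


if __name__ == "__main__":
    procs = [mp.Process(target=worker, args=(i,)) for i in range(2)]
    for p in procs:
        p.start()

    # os.waitpid(-1, 0) blocks until any child exits; a non-zero exit status
    # makes the helper kill() the remaining children before joining them all.
    wait_and_join(procs)
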
1 change: 1 addition & 0 deletions example/dgl/multi_gpu/common_config.py
17 changes: 10 additions & 7 deletions example/dgl/multi_gpu/train_gcn.py
@@ -6,6 +6,7 @@
- Code: https://github.com/tkipf/gcn
"""
import argparse
import datetime
import torch
import torch.nn as nn
import torch.optim as optim
@@ -19,7 +20,7 @@
from torch.nn.parallel import DistributedDataParallel
import math
import sys

from common_config import *

class GCN(nn.Module):
def __init__(self,
@@ -225,7 +226,8 @@ def run(worker_id, run_config):
torch.distributed.init_process_group(backend="nccl",
init_method=dist_init_method,
world_size=world_size,
rank=worker_id)
rank=worker_id,
timeout=datetime.timedelta(seconds=get_default_timeout()))

dataset = run_config['dataset']
g = run_config['g'].to(sample_device)
@@ -317,6 +319,9 @@ def run(worker_id, run_config):

if not run_config['pipelining']:
sync_device()
if num_worker > 1:
torch.distributed.barrier()


num_samples.append(sum([block.num_edges() for block in blocks]))
num_nodes.append(blocks[0].num_src_nodes())
@@ -325,9 +330,6 @@ def run(worker_id, run_config):
batch_labels = None
blocks = None

if num_worker > 1:
torch.distributed.barrier()

t4 = time.time()

sample_times.append(t1 - t0)
@@ -394,5 +396,6 @@ def run(worker_id, run_config):
p = mp.Process(target=run, args=(worker_id, run_config))
p.start()
workers.append(p)
for p in workers:
p.join()

wait_and_join(workers)
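The per-script changes above follow one pattern, repeated in train_graphsage.py and train_pinsage.py below: the NCCL process group gets an explicit timeout, the cross-worker barrier runs only in the non-pipelining branch (right after the device sync), and the launcher supervises its workers through wait_and_join. A condensed, hypothetical sketch of that pattern; the gloo backend, the loopback init address, and the dummy training step are assumptions made to keep it self-contained:

import datetime
import torch
import torch.multiprocessing as mp

from common_config import get_default_timeout, wait_and_join


def run(worker_id, world_size, pipelining=False):
    # The training scripts use backend="nccl" and build dist_init_method
    # themselves; "gloo" plus a loopback address keeps this sketch CPU-only.
    torch.distributed.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29500",
        world_size=world_size,
        rank=worker_id,
        timeout=datetime.timedelta(seconds=get_default_timeout()))

    for step in range(3):
        _ = torch.rand(1)  # stand-in for the real sample/extract/train step

        if not pipelining:
            # Barrier only when pipelining is disabled, mirroring the scripts.
            if world_size > 1:
                torch.distributed.barrier()

    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    workers = []
    for worker_id in range(world_size):
        p = mp.Process(target=run, args=(worker_id, world_size))
        p.start()
        workers.append(p)

    wait_and_join(workers)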

14 changes: 8 additions & 6 deletions example/dgl/multi_gpu/train_graphsage.py
@@ -1,4 +1,5 @@
import argparse
import datetime
import dgl
import torch
import dgl.nn.pytorch as dglnn
@@ -12,6 +13,7 @@
import numpy as np
import math
import sys
from common_config import *


class SAGE(nn.Module):
@@ -207,7 +209,8 @@ def run(worker_id, run_config):
torch.distributed.init_process_group(backend="nccl",
init_method=dist_init_method,
world_size=world_size,
rank=worker_id)
rank=worker_id,
timeout=datetime.timedelta(seconds=get_default_timeout()))

dataset = run_config['dataset']
g = run_config['g'].to(sample_device)
@@ -299,6 +302,8 @@ def run(worker_id, run_config):

if not run_config['pipelining']:
sync_device()
if num_worker > 1:
torch.distributed.barrier()

num_samples.append(sum([block.num_edges() for block in blocks]))
num_nodes.append(blocks[0].num_src_nodes())
@@ -307,9 +312,6 @@ def run(worker_id, run_config):
batch_labels = None
blocks = None

if num_worker > 1:
torch.distributed.barrier()

t4 = time.time()

sample_times.append(t1 - t0)
@@ -375,5 +377,5 @@ def run(worker_id, run_config):
p = mp.Process(target=run, args=(worker_id, run_config))
p.start()
workers.append(p)
for p in workers:
p.join()

wait_and_join(workers)
13 changes: 8 additions & 5 deletions example/dgl/multi_gpu/train_pinsage.py
@@ -1,4 +1,5 @@
import argparse
import datetime
import dgl
import torch
import torch.optim as optim
@@ -12,6 +13,7 @@
import numpy as np
import math
import sys
from common_config import *

"""
We have made the following modification(or say, simplification) on PinSAGE,
@@ -256,7 +258,8 @@ def run(worker_id, run_config):
torch.distributed.init_process_group(backend="nccl",
init_method=dist_init_method,
world_size=world_size,
rank=worker_id)
rank=worker_id,
timeout=datetime.timedelta(seconds=get_default_timeout()))

dataset = run_config['dataset']
g = run_config['g']
@@ -342,6 +345,8 @@ def run(worker_id, run_config):

if not run_config['pipelining']:
sync_device()
if num_worker > 1:
torch.distributed.barrier()

num_samples.append(sum([block.num_edges() for block in blocks]))
num_nodes.append(blocks[0].num_src_nodes())
@@ -350,8 +355,6 @@ def run(worker_id, run_config):
batch_labels = None
blocks = None

if num_worker > 1:
torch.distributed.barrier()
t4 = time.time()

sample_times.append(t1 - t0)
@@ -418,5 +421,5 @@ def run(worker_id, run_config):
p = mp.Process(target=run, args=(worker_id, run_config))
p.start()
workers.append(p)
for p in workers:
p.join()

wait_and_join(workers)
