Skip to content

Commit

Permalink
fix pyg
Browse files Browse the repository at this point in the history
  • Loading branch information
dynamicheart committed Oct 8, 2021
1 parent 1be5d07 commit f320cd0
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 28 deletions.
19 changes: 9 additions & 10 deletions example/pyg/multi_gpu/train_gcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def get_run_config():
default_run_config['num_sampling_worker'] = 0

# In PyG, the fanout list is ordered from root layer to leaf layer (front to back)
default_run_config['fanout'] = [5, 10, 15]
default_run_config['fanout'] = [15, 10, 5]
default_run_config['num_epoch'] = 10
default_run_config['num_hidden'] = 256
default_run_config['batch_size'] = 8000
Expand Down Expand Up @@ -188,14 +188,13 @@ def run(worker_id, run_config):
dev_id = run_config['devices'][worker_id]
num_worker = run_config['num_worker']

if num_worker > 1:
dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
master_ip='127.0.0.1', master_port='12345')
torch.distributed.init_process_group(backend="nccl",
init_method=dist_init_method,
world_size=num_worker,
rank=worker_id,
timeout=datetime.timedelta(seconds=get_default_timeout()))
dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
master_ip='127.0.0.1', master_port='12345')
torch.distributed.init_process_group(backend="nccl",
init_method=dist_init_method,
world_size=num_worker,
rank=worker_id,
timeout=datetime.timedelta(seconds=get_default_timeout()))

dataset = run_config['dataset']
g = run_config['g']
Expand All @@ -211,7 +210,7 @@ def run(worker_id, run_config):
dataloader = MyNeighborSampler(g, sizes=run_config['fanout'],
batch_size=run_config['batch_size'],
node_idx=train_nids,
shuffle=False,
shuffle=True,
return_e_id=False,
drop_last=False,
num_workers=run_config['num_sampling_worker'],
Expand Down
19 changes: 9 additions & 10 deletions example/pyg/multi_gpu/train_graphsage.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def get_run_config():
default_run_config['num_sampling_worker'] = 0

# In PyG, the fanout list is ordered from root layer to leaf layer (front to back)
default_run_config['fanout'] = [25, 10]
default_run_config['fanout'] = [10, 25]
default_run_config['num_epoch'] = 10
default_run_config['num_hidden'] = 256
default_run_config['batch_size'] = 8000
Expand Down Expand Up @@ -154,14 +154,13 @@ def run(worker_id, run_config):
dev_id = run_config['devices'][worker_id]
num_worker = run_config['num_worker']

if num_worker > 1:
dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
master_ip='127.0.0.1', master_port='12345')
torch.distributed.init_process_group(backend="nccl",
init_method=dist_init_method,
world_size=num_worker,
rank=worker_id,
timeout=datetime.timedelta(seconds=get_default_timeout()))
dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
master_ip='127.0.0.1', master_port='12345')
torch.distributed.init_process_group(backend="nccl",
init_method=dist_init_method,
world_size=num_worker,
rank=worker_id,
timeout=datetime.timedelta(seconds=get_default_timeout()))

dataset = run_config['dataset']
g = run_config['g']
Expand All @@ -176,7 +175,7 @@ def run(worker_id, run_config):
dataloader = MyNeighborSampler(g, sizes=run_config['fanout'],
batch_size=run_config['batch_size'],
node_idx=train_nids,
shuffle=False,
shuffle=True,
return_e_id=False,
drop_last=False,
num_workers=run_config['num_sampling_worker'],
Expand Down
9 changes: 5 additions & 4 deletions example/pyg/train_gcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ def parse_args(default_run_config):
default=default_run_config['batch_size'])
argparser.add_argument(
'--lr', type=float, default=default_run_config['lr'])
argparser.add_argument('--dropout', type=float,
default=default_run_config['dropout'])
argparser.add_argument('--weight-decay', type=float,
default=default_run_config['weight_decay'])

Expand All @@ -115,7 +117,7 @@ def get_run_config():
default_run_config['num_sampling_worker'] = 0

# In PyG, the fanout list is ordered from root layer to leaf layer (front to back)
default_run_config['fanout'] = [5, 10, 15]
default_run_config['fanout'] = [15, 10, 5]
default_run_config['num_epoch'] = 10
default_run_config['num_hidden'] = 256
default_run_config['batch_size'] = 8000
Expand Down Expand Up @@ -178,9 +180,8 @@ def run():
run_config = get_run_config()
device = torch.device(run_config['device'])

dataset = fastgraph.dataset(
run_config['dataset'], run_config['root_path'], force_load64=True)
g = dataset.to_pyg_graph()
dataset = run_config['dataset']
g = run_config['g']
feat = dataset.feat
label = dataset.label
train_nids = dataset.train_set
Expand Down
7 changes: 3 additions & 4 deletions example/pyg/train_graphsage.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def get_run_config():
default_run_config['num_sampling_worker'] = 0

# In PyG, the fanout list is ordered from root layer to leaf layer (front to back)
default_run_config['fanout'] = [25, 10]
default_run_config['fanout'] = [10, 25]
default_run_config['num_epoch'] = 10
default_run_config['num_hidden'] = 256
default_run_config['batch_size'] = 8000
Expand Down Expand Up @@ -150,9 +150,8 @@ def run():
run_config = get_run_config()
device = torch.device(run_config['device'])

dataset = fastgraph.dataset(
run_config['dataset'], run_config['root_path'], force_load64=True)
g = dataset.to_pyg_graph()
dataset = run_config['dataset']
g = run_config['g']
feat = dataset.feat
label = dataset.label
train_nids = dataset.train_set
Expand Down

0 comments on commit f320cd0

Please sign in to comment.