-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathppi_ka_s.py
30 lines (23 loc) · 1.03 KB
/
ppi_ka_s.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import json
import numpy as np
import networkx as nx
from networkx.readwrite import json_graph
from dgl.data.ppi import PPIDataset
class StudentPPIDataset(PPIDataset):
    """Customized PPI dataset for the student GNN.

    Subclass of ``dgl.data.ppi.PPIDataset``. Each sample additionally carries
    the node features of its graph, selected from ``self.features`` through
    the split's per-graph node mask, instead of only the graph and its labels.
    """

    def __getitem__(self, item):
        """Return the ``item``-th sample of the split selected by ``self.mode``.

        Args:
            item (int): index of the sample within the current split.

        Returns:
            tuple: ``(graph, features, labels)`` where ``graph`` is the split's
            ``item``-th graph, ``features`` is ``self.features`` indexed by the
            split's ``item``-th mask entry, and ``labels`` is the split's
            ``item``-th label entry.

        Raises:
            ValueError: if ``self.mode`` is not one of ``'train'``, ``'valid'``
                or ``'test'``. Previously such a mode silently produced ``None``.
        """
        # Pick the (graphs, node-mask list, labels) triple for the active split;
        # only the attributes of that split are touched, as before.
        if self.mode == 'train':
            graphs, masks, labels = self.train_graphs, self.train_mask_list, self.train_labels
        elif self.mode == 'valid':
            graphs, masks, labels = self.valid_graphs, self.valid_mask_list, self.valid_labels
        elif self.mode == 'test':
            graphs, masks, labels = self.test_graphs, self.test_mask_list, self.test_labels
        else:
            # Fail loudly instead of returning None, which would otherwise
            # surface later as an obscure unpacking/collation error.
            raise ValueError(
                f"Unknown mode {self.mode!r}; expected 'train', 'valid' or 'test'"
            )
        return graphs[item], self.features[masks[item]], labels[item]