Skip to content

Commit

Permalink
Change: FlowX_shap
Browse files Browse the repository at this point in the history
  • Loading branch information
CM-BF committed Jul 24, 2023
1 parent 820090b commit 2067dbc
Show file tree
Hide file tree
Showing 11 changed files with 712 additions and 12 deletions.
313 changes: 312 additions & 1 deletion xgraph/data/dataset_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,72 @@
from xgraph.definitions import ROOT_DIR
import pickle as pkl

import os
import re
import torch
import numpy as np
from torch_geometric.utils import dense_to_sparse
from torch_geometric.data import Data, InMemoryDataset, \
download_url

try:
from rdkit import Chem
except ImportError:
Chem = None
import os.path as osp
import zipfile
import gzip

x_map = {
'atomic_num':
list(range(0, 119)),
'chirality': [
'CHI_UNSPECIFIED',
'CHI_TETRAHEDRAL_CW',
'CHI_TETRAHEDRAL_CCW',
'CHI_OTHER',
],
'degree':
list(range(0, 11)),
'formal_charge':
list(range(-5, 7)),
'num_hs':
list(range(0, 9)),
'num_radical_electrons':
list(range(0, 5)),
'hybridization': [
'UNSPECIFIED',
'S',
'SP',
'SP2',
'SP3',
'SP3D',
'SP3D2',
'OTHER',
],
'is_aromatic': [False, True],
'is_in_ring': [False, True],
}

e_map = {
'bond_type': [
'misc',
'SINGLE',
'DOUBLE',
'TRIPLE',
'AROMATIC',
],
'stereo': [
'STEREONONE',
'STEREOZ',
'STEREOE',
'STEREOCIS',
'STEREOTRANS',
'STEREOANY',
],
'is_conjugated': [False, True],
}


def index_to_mask(index, size):
mask = torch.zeros(size, dtype=torch.bool, device=index.device)
Expand Down Expand Up @@ -634,4 +700,249 @@ def process(self):
data_list.append(self.gen_data(1))

data, slices = self.collate(data_list)
torch.save((data, slices), self.processed_paths[0])
torch.save((data, slices), self.processed_paths[0])


def maybe_log(path, log=True):
if log:
print('Extracting', path)


def extract_gz(path, folder, log=True):
maybe_log(path, log)
with gzip.open(path, 'r') as r:
with open(osp.join(folder, '.'.join(os.path.basename(path).split('.')[:-1])), 'wb') as w:
w.write(r.read())


def extract_zip(path, folder, log=True):
r"""Extracts a zip archive to a specific folder.
Args:
path (string): The path to the tar archive.
folder (string): The folder.
log (bool, optional): If :obj:`False`, will not print anything to the
console. (default: :obj:`True`)
"""
maybe_log(path, log)
with zipfile.ZipFile(path, 'r') as f:
f.extractall(folder)


class MoleculeDataset(InMemoryDataset):
r"""
The extension of MoleculeNet with `MUTAG <https://pubs.acs.org/doi/10.1021/jm00106a046>`_.
The `MoleculeNet benchmark collection <http://moleculenet.ai/datasets-1>`_ from the
`MoleculeNet: A Benchmark for Molecular Machine Learning <https://arxiv.org/abs/1703.00564>`_
paper, containing datasets from physical chemistry, biophysics and physiology.
The MoleculeNet datasets come with the additional node and edge features introduced by
the `Open Graph Benchmark <https://ogb.stanford.edu/docs/graphprop/>`_, and the node features
in MUTAG dataset are one hot features denoting the atom types.
Args:
root (string): Root directory where the dataset should be saved.
name (string): The name of the dataset (:obj:`"MUTAG"`, :obj:`"ESOL"`,
:obj:`"FreeSolv"`, :obj:`"Lipo"`, :obj:`"PCBA"`, :obj:`"MUV"`,
:obj:`"HIV"`, :obj:`"BACE"`, :obj:`"BBPB"`, :obj:`"Tox21"`,
:obj:`"ToxCast"`, :obj:`"SIDER"`, :obj:`"ClinTox"`).
transform (callable, optional): A function/transform that takes in an
:obj:`torch_geometric.data.Data` object and returns a transformed
version. The data object will be transformed before every access.
(default: :obj:`None`)
pre_transform (callable, optional): A function/transform that takes in
an :obj:`torch_geometric.data.Data` object and returns a
transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
pre_filter (callable, optional): A function that takes in an
:obj:`torch_geometric.data.Data` object and returns a boolean
value, indicating whether the data object should be included in the
final dataset. (default: :obj:`None`)
"""

url = 'https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/{}'
mutag_url = 'https://github.com/divelab/DIG_storage/raw/main/xgraph/datasets/{}'

# Format: name: [display_name, url_name, filename, smiles_idx, y_idx]
names = {
'mutag': ['MUTAG', 'MUTAG.zip', None, None],
'esol': ['ESOL', 'delaney-processed.csv', 'delaney-processed.csv', -1, -2],
'freesolv': ['FreeSolv', 'SAMPL.csv', 'SAMPL.csv', 1, 2],
'lipo': ['Lipophilicity', 'Lipophilicity.csv', 'Lipophilicity.csv', 2, 1],
'pcba': ['PCBA', 'pcba.csv.gz', 'pcba.csv', -1,
slice(0, 128)],
'muv': ['MUV', 'muv.csv.gz', 'muv.csv', -1,
slice(0, 17)],
'hiv': ['HIV', 'HIV.csv', 'HIV.csv', 0, -1],
'bace': ['BACE', 'bace.csv', 'bace.csv', 0, 2],
'bbbp': ['BBBP', 'BBBP.csv', 'BBBP.csv', -1, -2],
'tox21': ['Tox21', 'tox21.csv.gz', 'tox21.csv', -1,
slice(0, 12)],
'toxcast':
['ToxCast', 'toxcast_data.csv.gz', 'toxcast_data.csv', 0,
slice(1, 618)],
'sider': ['SIDER', 'sider.csv.gz', 'sider.csv', 0,
slice(1, 28)],
'clintox': ['ClinTox', 'clintox.csv.gz', 'clintox.csv', 0,
slice(1, 3)],
}

def __init__(self, root, name, transform=None, pre_transform=None,
pre_filter=None):

if Chem is None:
raise ImportError('`MoleculeNet` requires `rdkit`.')

self.name = name.lower()
assert self.name in self.names.keys()
super(MoleculeDataset, self).__init__(root, transform, pre_transform,
pre_filter)
self.data, self.slices = torch.load(self.processed_paths[0])

@property
def raw_dir(self):
return osp.join(self.root, self.name, 'raw')

@property
def processed_dir(self):
return osp.join(self.root, self.name, 'processed')

@property
def raw_file_names(self):
if self.name.lower() == 'MUTAG'.lower():
return ['MUTAG_A.txt', 'MUTAG_graph_labels.txt', 'MUTAG_graph_indicator.txt',
'MUTAG_node_labels.txt', 'README.txt']
else:
return self.names[self.name][2]

@property
def processed_file_names(self):
return 'data.pt'

def download(self):
if self.name.lower() == 'MUTAG'.lower():
url = self.mutag_url.format(self.names[self.name][1])
else:
url = self.url.format(self.names[self.name][1])
path = download_url(url, self.raw_dir)
if self.names[self.name][1][-2:] == 'gz':
extract_gz(path, self.raw_dir)
os.unlink(path)
elif self.names[self.name][1][-3:] == 'zip':
extract_zip(path, self.raw_dir)
os.unlink(path)

def process(self):
if self.name.lower() == 'MUTAG'.lower():
with open(os.path.join(self.raw_dir, 'MUTAG_node_labels.txt'), 'r') as f:
nodes_all_temp = f.read().splitlines()
nodes_all = [int(i) for i in nodes_all_temp]

adj_all = np.zeros((len(nodes_all), len(nodes_all)))
with open(os.path.join(self.raw_dir, 'MUTAG_A.txt'), 'r') as f:
adj_list = f.read().splitlines()
for item in adj_list:
lr = item.split(', ')
l = int(lr[0])
r = int(lr[1])
adj_all[l - 1, r - 1] = 1

with open(os.path.join(self.raw_dir, 'MUTAG_graph_indicator.txt'), 'r') as f:
graph_indicator_temp = f.read().splitlines()
graph_indicator = [int(i) for i in graph_indicator_temp]
graph_indicator = np.array(graph_indicator)

with open(os.path.join(self.raw_dir, 'MUTAG_graph_labels.txt'), 'r') as f:
graph_labels_temp = f.read().splitlines()
graph_labels = [int(i) for i in graph_labels_temp]

data_list = []
for i in range(1, 189):
idx = np.where(graph_indicator == i)
graph_len = len(idx[0])
adj = adj_all[idx[0][0]:idx[0][0] + graph_len, idx[0][0]:idx[0][0] + graph_len]
label = int(graph_labels[i - 1] == 1)
feature = nodes_all[idx[0][0]:idx[0][0] + graph_len]
nb_clss = 7
targets = np.array(feature).reshape(-1)
one_hot_feature = np.eye(nb_clss)[targets]
data_example = Data(x=torch.from_numpy(one_hot_feature).float(),
edge_index=dense_to_sparse(torch.from_numpy(adj))[0],
y=label)
data_list.append(data_example)
else:
with open(self.raw_paths[0], 'r') as f:
dataset = f.read().split('\n')[1:-1]
dataset = [x for x in dataset if len(x) > 0] # Filter empty lines.

data_list = []
for line in dataset:
line = re.sub(r'\".*\"', '', line) # Replace ".*" strings.
line = line.split(',')

smiles = line[self.names[self.name][3]]
ys = line[self.names[self.name][4]]
ys = ys if isinstance(ys, list) else [ys]

ys = [float(y) if len(y) > 0 else float('NaN') for y in ys]
y = torch.tensor(ys, dtype=torch.float).view(1, -1)

mol = Chem.MolFromSmiles(smiles)
if mol is None:
continue

xs = []
for atom in mol.GetAtoms():
x = []
x.append(x_map['atomic_num'].index(atom.GetAtomicNum()))
x.append(x_map['chirality'].index(str(atom.GetChiralTag())))
x.append(x_map['degree'].index(atom.GetTotalDegree()))
x.append(x_map['formal_charge'].index(atom.GetFormalCharge()))
x.append(x_map['num_hs'].index(atom.GetTotalNumHs()))
x.append(x_map['num_radical_electrons'].index(
atom.GetNumRadicalElectrons()))
x.append(x_map['hybridization'].index(
str(atom.GetHybridization())))
x.append(x_map['is_aromatic'].index(atom.GetIsAromatic()))
x.append(x_map['is_in_ring'].index(atom.IsInRing()))
xs.append(x)

x = torch.tensor(xs, dtype=torch.long).view(-1, 9)

edge_indices, edge_attrs = [], []
for bond in mol.GetBonds():
i = bond.GetBeginAtomIdx()
j = bond.GetEndAtomIdx()

e = []
e.append(e_map['bond_type'].index(str(bond.GetBondType())))
e.append(e_map['stereo'].index(str(bond.GetStereo())))
e.append(e_map['is_conjugated'].index(bond.GetIsConjugated()))

edge_indices += [[i, j], [j, i]]
edge_attrs += [e, e]

edge_index = torch.tensor(edge_indices)
edge_index = edge_index.t().to(torch.long).view(2, -1)
edge_attr = torch.tensor(edge_attrs, dtype=torch.long).view(-1, 3)

# Sort indices.
if edge_index.numel() > 0:
perm = (edge_index[0] * x.size(0) + edge_index[1]).argsort()
edge_index, edge_attr = edge_index[:, perm], edge_attr[perm]

data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y,
smiles=smiles)

if self.pre_filter is not None and not self.pre_filter(data):
continue

if self.pre_transform is not None:
data = self.pre_transform(data)

data_list.append(data)
torch.save(self.collate(data_list), self.processed_paths[0])

def __repr__(self):
return '{}({})'.format(self.names[self.name][0], len(self))
3 changes: 2 additions & 1 deletion xgraph/data/dataset_loaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@
from .ba_lrp import ba_lrp
from .ba_shapes import ba_shapes
from .graph_sst2 import graph_sst2
from .ba_traffic import ba_traffic
from .ba_traffic import ba_traffic
from .mutag import mutag
51 changes: 51 additions & 0 deletions xgraph/data/dataset_loaders/mutag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""
FileName: molecule_datasets.py
Description:
Time: 2021/11/12 16:18
Project: GNN_benchmark
Author: Shurui Gui
"""
from torch_geometric.datasets import MoleculeNet
import torch
from torch.utils.data import random_split
from xgraph import data_args
from xgraph.definitions import ROOT_DIR
from xgraph.kernel.utils import Metric
import os
from xgraph.data.dataset_gen import MoleculeDataset


def mutag(name):
data_args.dataset_type = 'mol'
data_args.model_level = 'graph'

dataset = MoleculeDataset(root=os.path.join(ROOT_DIR, 'xgraph', 'datasets'), name=name)
dataset.data.x = dataset.data.x.to(torch.float32)
dataset.data.y = dataset.data.y.unsqueeze(1).float()
data_args.dim_node = dataset.num_node_features
data_args.dim_edge = dataset.num_edge_features
data_args.num_targets = dataset.num_classes # This so-called classes are actually targets.

# Define models' output shape.
if Metric.cur_task == 'bcs':
data_args.num_classes = 2
elif Metric.cur_task == 'reg':
data_args.num_classes = 1

assert data_args.target_idx != -1, 'Explaining on multi tasks is meaningless.'
assert data_args.target_idx <= dataset.data.y.shape[1], 'No such target in the dataset.'

dataset.data.y = dataset.data.y[:, data_args.target_idx]
# --- clear buffer dataset._data_list ---
dataset._data_list = None
data_args.num_targets = 1

dataset_len = len(dataset)
dataset_split = [int(dataset_len * data_args.dataset_split[0]),
int(dataset_len * data_args.dataset_split[1]),
0]
dataset_split[2] = dataset_len - dataset_split[0] - dataset_split[1]
train_set, val_set, test_set = \
random_split(dataset, dataset_split)

return {'train': train_set, 'val': val_set, 'test': test_set}
2 changes: 2 additions & 0 deletions xgraph/data/dataset_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def load_dataset(name: str) -> dir:
# To Do: use transform to argument data
if name in molecule_set:
return benchmark_datasets.molecule_datasets(name)
elif name == 'mutag':
return benchmark_datasets.mutag(name)
elif name == 'ba_lrp':
return benchmark_datasets.ba_lrp(name)
elif name == 'ba_infe':
Expand Down
Loading

0 comments on commit 2067dbc

Please sign in to comment.