Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
KiddoZhu committed Sep 18, 2022
1 parent cf1a62c commit cd45518
Show file tree
Hide file tree
Showing 19 changed files with 95 additions and 98 deletions.
2 changes: 1 addition & 1 deletion doc/source/quick_start.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ utilization of hardware. They can also be transferred between CPUs and GPUs usin

.. code:: bash
PackedMolecule(batch_size=4, num_nodes=[12, 6, 14, 9], num_edges=[22, 10, 30, 18],
PackedMolecule(batch_size=4, num_atoms=[12, 6, 14, 9], num_bonds=[22, 10, 30, 18],
device='cuda:0')
Just like original PyTorch tensors, graphs support a wide range of indexing
Expand Down
43 changes: 25 additions & 18 deletions torchdrug/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,13 @@ def load_csv(self, csv_file, smiles_field="smiles", target_fields=None, verbose=
self.load_smiles(smiles, targets, verbose=verbose, **kwargs)

def load_pickle(self, pkl_file, verbose=0):
"""
Load the dataset from a pickle file.
Parameters:
pkl_file (str): file name
verbose (int, optional): output verbose level
"""
with utils.smart_open(pkl_file, "rb") as fin:
num_sample, tasks = pickle.load(fin)

Expand All @@ -133,6 +140,13 @@ def load_pickle(self, pkl_file, verbose=0):
self.targets[task] = value

def save_pickle(self, pkl_file, verbose=0):
"""
Save the dataset to a pickle file.
Parameters:
pkl_file (str): file name
verbose (int, optional): output verbose level
"""
with utils.smart_open(pkl_file, "wb") as fout:
num_sample = len(self.data)
tasks = self.targets.keys()
Expand Down Expand Up @@ -659,16 +673,16 @@ def load_sequence(self, sequences, targets, attributes=None, transform=None, laz
self.targets[field].append(targets[field][i])

@utils.copy_args(load_sequence)
def load_lmdbs(self, lmdb_files, number_field="num_examples", sequence_field="primary", target_fields=None,
def load_lmdbs(self, lmdb_files, sequence_field="primary", target_fields=None, number_field="num_examples",
transform=None, lazy=False, verbose=0, **kwargs):
"""
Load the dataset from lmdb files.
Parameters:
lmdb_files (list of str): list of lmdb files
number_field (str, optional): name of the field of sample count in lmdb files
sequence_field (str, optional): name of the field of protein sequence in lmdb files
target_fields (list of str, optional): name of target fields in lmdb files
number_field (str, optional): name of the field of sample count in lmdb files
transform (Callable, optional): protein sequence transformation function
lazy (bool, optional): if lazy mode is used, the proteins are processed in the dataloader.
This may slow down the data loading process, but save a lot of CPU memory and dataset loading time.
Expand Down Expand Up @@ -701,12 +715,13 @@ def load_lmdbs(self, lmdb_files, number_field="num_examples", sequence_field="pr
self.num_samples = num_samples

@utils.copy_args(data.Protein.from_molecule)
def load_pdbs(self, pdb_files, transform=None, lazy=False, verbose=0, **kwargs):
def load_pdbs(self, pdb_files, sanitize=True, transform=None, lazy=False, verbose=0, **kwargs):
"""
Load the dataset from pdb files.
Parameters:
pdb_files (list of str): pdb file names
sanitize (bool, optional): whether to sanitize the molecule
transform (Callable, optional): protein sequence transformation function
lazy (bool, optional): if lazy mode is used, the proteins are processed in the dataloader.
This may slow down the data loading process, but save a lot of CPU memory and dataset loading time.
Expand All @@ -729,7 +744,6 @@ def load_pdbs(self, pdb_files, transform=None, lazy=False, verbose=0, **kwargs):
pdb_files = tqdm(pdb_files, "Constructing proteins from pdbs")
for i, pdb_file in enumerate(pdb_files):
if not lazy or i == 0:
sanitize = kwargs.pop("sanitize", True)
mol = Chem.MolFromPDBFile(pdb_file, sanitize=sanitize)
if not mol:
logger.debug("Can't construct molecule from pdb file `%s`. Ignore this sample." % pdb_file)
Expand Down Expand Up @@ -779,10 +793,10 @@ def load_fasta(self, fasta_file, verbose=0, **kwargs):
@utils.copy_args(data.Protein.from_molecule)
def load_pickle(self, pkl_file, transform=None, lazy=False, verbose=0, **kwargs):
"""
Load the dataset from pickle files.
Load the dataset from a pickle file.
Parameters:
pkl_file (str): pickle file name
pkl_file (str): file name
transform (Callable, optional): protein sequence transformation function
lazy (bool, optional): if lazy mode is used, the proteins are processed in the dataloader.
This may slow down the data loading process, but save a lot of CPU memory and dataset loading time.
Expand All @@ -808,13 +822,6 @@ def load_pickle(self, pkl_file, transform=None, lazy=False, verbose=0, **kwargs)
self.data.append(protein)

def save_pickle(self, pkl_file, verbose=0):
"""
Save the dataset to pickle files.
Parameters:
pkl_file (str): pickle file name
verbose (int, optional): output verbose level
"""
with utils.smart_open(pkl_file, "wb") as fout:
num_sample = len(self.data)
pickle.dump(num_sample, fout)
Expand Down Expand Up @@ -890,16 +897,16 @@ def load_sequence(self, sequences, targets, attributes=None, transform=None, laz
self.targets[field].append(targets[field][i])

@utils.copy_args(load_sequence)
def load_lmdbs(self, lmdb_files, number_field="num_examples", sequence_field="primary", target_fields=None,
def load_lmdbs(self, lmdb_files, sequence_field="primary", target_fields=None, number_field="num_examples",
transform=None, lazy=False, verbose=0, **kwargs):
"""
Load the dataset from lmdb files.
Parameters:
lmdb_files (list of str): file names
number_field (str, optional): name of the field of sample count in lmdb files
sequence_field (str or list of str, optional): names of the fields of protein sequence in lmdb files
target_fields (list of str, optional): name of target fields in lmdb files
number_field (str, optional): name of the field of sample count in lmdb files
transform (Callable, optional): protein sequence transformation function
lazy (bool, optional): if lazy mode is used, the protein pairs are processed in the dataloader.
This may slow down the data loading process, but save a lot of CPU memory and dataset loading time.
Expand Down Expand Up @@ -1022,17 +1029,17 @@ def load_sequence(self, sequences, smiles, targets, num_samples, attributes=None
return num_samples

@utils.copy_args(load_sequence)
def load_lmdbs(self, lmdb_files, number_field="num_examples", sequence_field="target", smiles_field="drug",
target_fields=None, transform=None, lazy=False, verbose=0, **kwargs):
def load_lmdbs(self, lmdb_files, sequence_field="target", smiles_field="drug", target_fields=None,
number_field="num_examples", transform=None, lazy=False, verbose=0, **kwargs):
"""
Load the dataset from lmdb files.
Parameters:
lmdb_files (list of str): file names
number_field (str, optional): name of the field of sample count in lmdb files
sequence_field (str, optional): name of the field of protein sequence in lmdb files
smiles_field (str, optional): name of the field of ligand SMILES string in lmdb files
target_fields (list of str, optional): name of target fields in lmdb files
number_field (str, optional): name of the field of sample count in lmdb files
transform (Callable, optional): protein sequence transformation function
lazy (bool, optional): if lazy mode is used, the protein-ligand pairs are processed in the dataloader.
This may slow down the data loading process, but save a lot of CPU memory and dataset loading time.
Expand Down
2 changes: 1 addition & 1 deletion torchdrug/data/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def residue_symbol(residue):

@R.register("features.residue.default")
def residue_default(residue):
"""Default atom feature.
"""Default residue feature.
Features:
GetResidueName(): one-hot embedding for the residue symbol
Expand Down
53 changes: 25 additions & 28 deletions torchdrug/data/protein.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,13 +252,7 @@ def from_molecule(cls, mol, atom_feature="default", bond_feature="default", resi
meta_dict=protein.meta_dict, **protein.data_dict)

@classmethod
def from_sequence_fast(cls, sequence):
"""
A faster version of creating a protein from a sequence.
Parameters:
sequence (str): string
"""
def _residue_from_sequence(cls, sequence):
residue_type = []
residue_feature = []
sequence = sequence + "G"
Expand All @@ -278,10 +272,16 @@ def from_sequence_fast(cls, sequence):
@classmethod
@utils.deprecated_alias(node_feature="atom_feature", edge_feature="bond_feature", graph_feature="mol_feature")
def from_sequence(cls, sequence, atom_feature="default", bond_feature="default", residue_feature="default",
mol_feature=None, kekulize=False, residue_only=False):
mol_feature=None, kekulize=False):
"""
Create a protein from a sequence.
.. note::
It takes considerable time to construct proteins with a large number of atoms and bonds.
If you only need residue information, you may speed up the construction by setting
``atom_feature`` and ``bond_feature`` to ``None``.
Parameters:
sequence (str): protein sequence
atom_feature (str or list of str, optional): atom features to extract
Expand All @@ -292,14 +292,9 @@ def from_sequence(cls, sequence, atom_feature="default", bond_feature="default",
Note this only affects the relation in ``edge_list``.
For ``bond_type``, aromatic bonds are always stored explicitly.
By default, aromatic bonds are stored.
residue_only (bool, optional): only store residue information without atom information.
This can speed up the processing.
"""
if residue_only:
if residue_feature != "default":
raise ValueError("`residue_only` only supports the default residue feature, "
"but found `%s` for `residue_feature`" % residue_feature)
return cls.from_sequence_fast(sequence)
if atom_feature is None and bond_feature is None and residue_feature == "default":
return cls._residue_from_sequence(sequence)

mol = Chem.MolFromSequence(sequence)
if mol is None:
Expand All @@ -324,7 +319,7 @@ def from_pdb(cls, pdb_file, atom_feature="default", bond_feature="default", resi
Note this only affects the relation in ``edge_list``.
For ``bond_type``, aromatic bonds are always stored explicitly.
By default, aromatic bonds are stored.
sanitize (bool, optional): whether to sanitize the molecule.
sanitize (bool, optional): whether to sanitize the molecule
"""
if not os.path.exists(pdb_file):
raise FileNotFoundError("No such file `%s`" % pdb_file)
Expand Down Expand Up @@ -524,7 +519,7 @@ def repeat(self, count):
num_relation=num_relation, meta_dict=self.meta_dict, **data_dict)

def residue2atom(self, residue_index):
"""Map residue id to atom ids."""
"""Map residue ids to atom ids."""
residue_index = self._standarize_index(residue_index, self.num_residue)
if not hasattr(self, "node_inverted_index"):
self.node_inverted_index = self._build_node_inverted_index()
Expand Down Expand Up @@ -992,7 +987,7 @@ def from_molecule(cls, mols, atom_feature="default", bond_feature="default", res
offsets=protein._offsets, meta_dict=protein.meta_dict, **protein.data_dict)

@classmethod
def from_sequence_fast(cls, sequences):
def _residue_from_sequence(cls, sequences):
num_residues = []
residue_type = []
residue_feature = []
Expand Down Expand Up @@ -1021,10 +1016,16 @@ def from_sequence_fast(cls, sequences):
@classmethod
@utils.deprecated_alias(node_feature="atom_feature", edge_feature="bond_feature", graph_feature="mol_feature")
def from_sequence(cls, sequences, atom_feature="default", bond_feature="default", residue_feature="default",
mol_feature=None, kekulize=False, residue_only=False):
mol_feature=None, kekulize=False):
"""
Create a packed protein from a list of sequences.
.. note::
It takes considerable time to construct proteins with a large number of atoms and bonds.
If you only need residue information, you may speed up the construction by setting
``atom_feature`` and ``bond_feature`` to ``None``.
Parameters:
sequences (list of str): list of protein sequences
atom_feature (str or list of str, optional): atom features to extract
Expand All @@ -1035,14 +1036,9 @@ def from_sequence(cls, sequences, atom_feature="default", bond_feature="default"
Note this only affects the relation in ``edge_list``.
For ``bond_type``, aromatic bonds are always stored explicitly.
By default, aromatic bonds are stored.
residue_only (bool, optional): only store residue information without atom information.
This can speed up the processing.
"""
if residue_only:
if residue_feature != "default":
raise ValueError("`residue_only` only supports the default residue feature, "
"but found `%s` for `residue_feature`" % residue_feature)
return cls.from_sequence_fast(sequences)
if atom_feature is None and bond_feature is None and residue_feature == "default":
return cls._residue_from_sequence(sequences)

mols = []
for sequence in sequences:
Expand All @@ -1056,7 +1052,7 @@ def from_sequence(cls, sequences, atom_feature="default", bond_feature="default"
@classmethod
@utils.deprecated_alias(node_feature="atom_feature", edge_feature="bond_feature", graph_feature="mol_feature")
def from_pdb(cls, pdb_files, atom_feature="default", bond_feature="default", residue_feature="default",
mol_feature=None, kekulize=False):
mol_feature=None, kekulize=False, sanitize=False):
"""
Create a protein from a list of PDB files.
Expand All @@ -1070,10 +1066,11 @@ def from_pdb(cls, pdb_files, atom_feature="default", bond_feature="default", res
Note this only affects the relation in ``edge_list``.
For ``bond_type``, aromatic bonds are always stored explicitly.
By default, aromatic bonds are stored.
sanitize (bool, optional): whether to sanitize the molecule
"""
mols = []
for pdb_file in pdb_files:
mol = Chem.MolFromPDBFile(pdb_file)
mol = Chem.MolFromPDBFile(pdb_file, sanitize=sanitize)
mols.append(mol)

return cls.from_molecule(mols, atom_feature, bond_feature, residue_feature, mol_feature, kekulize)
Expand Down
2 changes: 1 addition & 1 deletion torchdrug/datasets/alphafolddb.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


@R.register("datasets.AlphaFoldDB")
@utils.copy_args(data.ProteinDataset.load_pdbs, ignore=("filtered_pdb",))
@utils.copy_args(data.ProteinDataset.load_pdbs)
class AlphaFoldDB(data.ProteinDataset):
"""
3D protein structures predicted by AlphaFold.
Expand Down
18 changes: 10 additions & 8 deletions torchdrug/datasets/enzyme_commission.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


@R.register("datasets.EnzymeCommission")
@utils.copy_args(data.ProteinDataset.load_pdbs, ignore=("filtered_pdb",))
@utils.copy_args(data.ProteinDataset.load_pdbs)
class EnzymeCommission(data.ProteinDataset):
"""
A set of proteins with their 3D structures and EC numbers, which describes their
Expand All @@ -23,7 +23,7 @@ class EnzymeCommission(data.ProteinDataset):
Parameters:
path (str): the path to store the dataset
test_cutoff (float): the test cutoff used to split the dataset
test_cutoff (float, optional): the test cutoff used to split the dataset
verbose (int, optional): output verbose level
**kwargs
"""
Expand All @@ -47,13 +47,14 @@ def __init__(self, path, test_cutoff=0.95, verbose=1, **kwargs):
pkl_file = os.path.join(path, self.processed_file)

csv_file = os.path.join(path, "nrPDB-EC_test.csv")
filtered_pdb = set()
pdb_ids = []
with open(csv_file, "r") as fin:
reader = csv.reader(fin, delimiter=",")
idx = self.test_cutoffs.index(test_cutoff) + 1
_ = next(reader)
for line in reader:
if line[idx] == "0": filtered_pdb.add(line[0])
if line[idx] == "0":
pdb_ids.append(line[0])

if os.path.exists(pkl_file):
self.load_pickle(pkl_file, verbose=verbose, **kwargs)
Expand All @@ -64,8 +65,8 @@ def __init__(self, path, test_cutoff=0.95, verbose=1, **kwargs):
pdb_files += sorted(glob.glob(os.path.join(split_path, split, "*.pdb")))
self.load_pdbs(pdb_files, verbose=verbose, **kwargs)
self.save_pickle(pkl_file, verbose=verbose)
if len(filtered_pdb) > 0:
self.filter_pdb(filtered_pdb)
if len(pdb_ids) > 0:
self.filter_pdb(pdb_ids)

tsv_file = os.path.join(path, "nrPDB-EC_annot.tsv")
pdb_ids = [os.path.basename(pdb_file).split("_")[0] for pdb_file in self.pdb_files]
Expand All @@ -74,12 +75,13 @@ def __init__(self, path, test_cutoff=0.95, verbose=1, **kwargs):
splits = [os.path.basename(os.path.dirname(pdb_file)) for pdb_file in self.pdb_files]
self.num_samples = [splits.count("train"), splits.count("valid"), splits.count("test")]

def filter_pdb(self, filtered_pdb):
def filter_pdb(self, pdb_ids):
pdb_ids = set(pdb_ids)
sequences = []
pdb_files = []
data = []
for sequence, pdb_file, protein in zip(self.sequences, self.pdb_files, self.data):
if os.path.basename(pdb_file).split("_")[0] in filtered_pdb:
if os.path.basename(pdb_file).split("_")[0] in pdb_ids:
continue
sequences.append(sequence)
pdb_files.append(pdb_file)
Expand Down
Loading

0 comments on commit cd45518

Please sign in to comment.