Skip to content

Commit

Permalink
add protein transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
Oxer11 committed Sep 16, 2022
1 parent 921f3f3 commit cf1a62c
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 10 deletions.
7 changes: 4 additions & 3 deletions torchdrug/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .transform import TargetNormalize, RemapAtomType, RandomBFSOrder, Shuffle, VirtualNode, VirtualAtom, Compose
from .transform import NormalizeTarget, RemapAtomType, RandomBFSOrder, Shuffle, VirtualNode, \
VirtualAtom, TruncateProtein, ProteinView, Compose

__all__ = [
"TargetNormalize", "RemapAtomType", "RandomBFSOrder", "Shuffle",
"VirtualNode", "VirtualAtom", "Compose",
"NormalizeTarget", "RemapAtomType", "RandomBFSOrder", "Shuffle",
"VirtualNode", "VirtualAtom", "TruncateProtein", "ProteinView", "Compose",
]
91 changes: 84 additions & 7 deletions torchdrug/transforms/transform.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
import copy
import logging
from collections import deque
from random import randint

import torch

from torchdrug import core
from torchdrug.core import Registry as R


logger = logging.getLogger(__name__)


class TargetNormalize(object):
@R.register("transforms.NormalizeTarget")
class NormalizeTarget(core.Configurable):
"""
Normalize the target values in a sample.
Expand All @@ -30,9 +35,11 @@ def __call__(self, item):
return item


class RemapAtomType(object):
@R.register("transforms.RemapAtomType")
class RemapAtomType(core.Configurable):
"""
Map atom types to their index in a vocabulary. Atom types that don't present in the vocabulary are mapped to -1.
Parameters:
atom_types (array_like): vocabulary of atom types
"""
Expand All @@ -51,7 +58,8 @@ def __call__(self, item):
return item


class RandomBFSOrder(object):
@R.register("transforms.RandomBFSOrder")
class RandomBFSOrder(core.Configurable):
"""
Order the nodes in a graph according to a random BFS order.
"""
Expand Down Expand Up @@ -81,9 +89,11 @@ def __call__(self, item):
return item


class Shuffle(object):
@R.register("transforms.Shuffle")
class Shuffle(core.Configurable):
"""
Shuffle the order of nodes and edges in a graph.
Parameters:
shuffle_node (bool, optional): shuffle node order or not
shuffle_edge (bool, optional): shuffle edge order or not
Expand Down Expand Up @@ -125,7 +135,8 @@ def transform_data(self, data, meta):
return new_data


class VirtualNode(object):
@R.register("transforms.VirtualNode")
class VirtualNode(core.Configurable):
"""
Add a virtual node and connect it with every node in the graph.
Expand Down Expand Up @@ -199,9 +210,11 @@ def __call__(self, item):
return item


class VirtualAtom(VirtualNode):
@R.register("transforms.VirtualAtom")
class VirtualAtom(VirtualNode, core.Configurable):
"""
Add a virtual atom and connect it with every atom in the molecule.
Parameters:
atom_type (int, optional): type of the virtual atom
bond_type (int, optional): type of the virtual bonds
Expand All @@ -215,9 +228,73 @@ def __init__(self, atom_type=None, bond_type=None, node_feature=None, edge_featu
edge_feature=edge_feature, atom_type=atom_type, **kwargs)


class Compose(object):
@R.register("transforms.TruncateProtein")
class TruncateProtein(core.Configurable):
"""
Truncate over long protein sequences into a fixed length.
Parameters:
max_length (int, optional): maximal length of the sequence. Truncate the sequence if it exceeds this limit.
random (bool, optional): truncate the sequence at a random position.
If not, truncate the suffix of the sequence.
keys (str or list of str, optional): keys for the items that require truncation in a sample
"""

def __init__(self, max_length=None, random=False, keys="graph"):
self.truncate_length = max_length
self.random = random
if isinstance(keys, str):
keys = [keys]
self.keys = keys

def __call__(self, item):
new_item = item.copy()
for key in self.keys:
graph = item[key]
if graph.num_residue > self.truncate_length:
if self.random:
start = randint(0, graph.num_residue - self.truncate_length)
else:
start = 0
end = start + self.truncate_length
mask = torch.zeros(graph.num_residue, dtype=torch.bool, device=graph.device)
mask[start:end] = True
graph = graph.subresidue(mask)

new_item[key] = graph
return new_item


@R.register("transforms.ProteinView")
class ProteinView(core.Configurable):
"""
Convert proteins to a specific view.
Parameters:
view (str): protein view. Can be ``atom`` or ``residue``.
keys (str or list of str, optional): keys for the items that require view change in a sample
"""

def __init__(self, view, keys="graph"):
self.view = view
if isinstance(keys, str):
keys = [keys]
self.keys = keys

def __call__(self, item):
item = item.copy()
for key in self.keys:
graph = copy.copy(item[key])
graph.view = self.view
item[key] = graph
return item


@R.register("transforms.Compose")
class Compose(core.Configurable):
"""
Compose a list of transforms into one.
Parameters:
transforms (list of callable): list of transforms
"""
Expand Down

0 comments on commit cf1a62c

Please sign in to comment.