diff --git a/torchdrug/transforms/__init__.py b/torchdrug/transforms/__init__.py
index 36da4fa..f8109df 100644
--- a/torchdrug/transforms/__init__.py
+++ b/torchdrug/transforms/__init__.py
@@ -1,6 +1,7 @@
-from .transform import TargetNormalize, RemapAtomType, RandomBFSOrder, Shuffle, VirtualNode, VirtualAtom, Compose
+from .transform import NormalizeTarget, RemapAtomType, RandomBFSOrder, Shuffle, VirtualNode, \
+    VirtualAtom, TruncateProtein, ProteinView, Compose
 
 __all__ = [
-    "TargetNormalize", "RemapAtomType", "RandomBFSOrder", "Shuffle",
-    "VirtualNode", "VirtualAtom", "Compose",
+    "NormalizeTarget", "RemapAtomType", "RandomBFSOrder", "Shuffle",
+    "VirtualNode", "VirtualAtom", "TruncateProtein", "ProteinView", "Compose",
 ]
diff --git a/torchdrug/transforms/transform.py b/torchdrug/transforms/transform.py
index 3245ca9..e8d1606 100644
--- a/torchdrug/transforms/transform.py
+++ b/torchdrug/transforms/transform.py
@@ -1,14 +1,19 @@
 import copy
 import logging
 from collections import deque
+from random import randint
 
 import torch
 
+from torchdrug import core
+from torchdrug.core import Registry as R
+
 logger = logging.getLogger(__name__)
 
 
-class TargetNormalize(object):
+@R.register("transforms.NormalizeTarget")
+class NormalizeTarget(core.Configurable):
     """
     Normalize the target values in a sample.
 
@@ -30,9 +35,11 @@ def __call__(self, item):
         return item
 
 
-class RemapAtomType(object):
+@R.register("transforms.RemapAtomType")
+class RemapAtomType(core.Configurable):
     """
     Map atom types to their index in a vocabulary.
     Atom types that don't present in the vocabulary are mapped to -1.
+
     Parameters:
         atom_types (array_like): vocabulary of atom types
     """
@@ -51,7 +58,8 @@ def __call__(self, item):
         return item
 
 
-class RandomBFSOrder(object):
+@R.register("transforms.RandomBFSOrder")
+class RandomBFSOrder(core.Configurable):
     """
     Order the nodes in a graph according to a random BFS order.
     """
@@ -81,9 +89,11 @@ def __call__(self, item):
         return item
 
 
-class Shuffle(object):
+@R.register("transforms.Shuffle")
+class Shuffle(core.Configurable):
     """
     Shuffle the order of nodes and edges in a graph.
+
     Parameters:
         shuffle_node (bool, optional): shuffle node order or not
         shuffle_edge (bool, optional): shuffle edge order or not
@@ -125,7 +135,8 @@ def transform_data(self, data, meta):
         return new_data
 
 
-class VirtualNode(object):
+@R.register("transforms.VirtualNode")
+class VirtualNode(core.Configurable):
     """
     Add a virtual node and connect it with every node in the graph.
 
@@ -199,9 +210,11 @@ def __call__(self, item):
         return item
 
 
-class VirtualAtom(VirtualNode):
+@R.register("transforms.VirtualAtom")
+class VirtualAtom(VirtualNode, core.Configurable):
     """
     Add a virtual atom and connect it with every atom in the molecule.
+
     Parameters:
         atom_type (int, optional): type of the virtual atom
         bond_type (int, optional): type of the virtual bonds
@@ -215,9 +228,73 @@ def __init__(self, atom_type=None, bond_type=None, node_feature=None, edge_featu
                                           edge_feature=edge_feature, atom_type=atom_type, **kwargs)
 
 
-class Compose(object):
+@R.register("transforms.TruncateProtein")
+class TruncateProtein(core.Configurable):
+    """
+    Truncate over long protein sequences into a fixed length.
+
+    Parameters:
+        max_length (int, optional): maximal length of the sequence. Truncate the sequence if it exceeds this limit.
+        random (bool, optional): truncate the sequence at a random position.
+            If not, truncate the suffix of the sequence.
+        keys (str or list of str, optional): keys for the items that require truncation in a sample
+    """
+
+    def __init__(self, max_length=None, random=False, keys="graph"):
+        self.truncate_length = max_length
+        self.random = random
+        if isinstance(keys, str):
+            keys = [keys]
+        self.keys = keys
+
+    def __call__(self, item):
+        new_item = item.copy()
+        for key in self.keys:
+            graph = item[key]
+            if graph.num_residue > self.truncate_length:
+                if self.random:
+                    start = randint(0, graph.num_residue - self.truncate_length)
+                else:
+                    start = 0
+                end = start + self.truncate_length
+                mask = torch.zeros(graph.num_residue, dtype=torch.bool, device=graph.device)
+                mask[start:end] = True
+                graph = graph.subresidue(mask)
+
+            new_item[key] = graph
+        return new_item
+
+
+@R.register("transforms.ProteinView")
+class ProteinView(core.Configurable):
+    """
+    Convert proteins to a specific view.
+
+    Parameters:
+        view (str): protein view. Can be ``atom`` or ``residue``.
+        keys (str or list of str, optional): keys for the items that require view change in a sample
+    """
+
+    def __init__(self, view, keys="graph"):
+        self.view = view
+        if isinstance(keys, str):
+            keys = [keys]
+        self.keys = keys
+
+    def __call__(self, item):
+        item = item.copy()
+        for key in self.keys:
+            graph = copy.copy(item[key])
+            graph.view = self.view
+            item[key] = graph
+        return item
+
+
+@R.register("transforms.Compose")
+class Compose(core.Configurable):
     """
     Compose a list of transforms into one.
+
     Parameters:
         transforms (list of callable): list of transforms
     """
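A minimal usage sketch for the new protein transforms (not part of the diff; the sequence, parameter values, and the data.Protein.from_sequence construction are illustrative assumptions):

    from torchdrug import data, transforms

    # Truncate proteins longer than 200 residues at a random offset,
    # then switch the protein to its residue-level view.
    # Both transforms read the "graph" key of each sample by default.
    transform = transforms.Compose([
        transforms.TruncateProtein(max_length=200, random=True),
        transforms.ProteinView(view="residue"),
    ])

    protein = data.Protein.from_sequence("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ")  # illustrative sequence
    item = {"graph": protein}
    item = transform(item)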