Skip to content

Commit

Permalink
add papers implemented
Browse files Browse the repository at this point in the history
  • Loading branch information
KiddoZhu committed Sep 19, 2022
1 parent 186aef2 commit 09ffe18
Show file tree
Hide file tree
Showing 14 changed files with 77 additions and 23 deletions.
5 changes: 5 additions & 0 deletions doc/source/api/layers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ MutualInformation
.. autoclass:: MutualInformation
:members:

SinusoidalPositionEmbedding
^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: SinusoidalPositionEmbedding
:members:

PairNorm
^^^^^^^^
.. autoclass:: PairNorm
Expand Down
7 changes: 7 additions & 0 deletions doc/source/bibliography.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,13 @@
.. Retrosynthesis
.. _G2Gs: https://arxiv.org/pdf/2003.12725.pdf

.. Protein Representation Learning
.. _TAPE: https://proceedings.neurips.cc/paper/2019/file/37f65c068b7723cd7809ee2d31d7861c-Paper.pdf
.. _ProteinCNN: https://arxiv.org/pdf/2011.03443.pdf
.. _ESM: https://www.biorxiv.org/content/10.1101/622803v1.full.pdf
.. _GearNet: https://arxiv.org/pdf/2203.06125.pdf

.. Knowledge Graph Reasoning
.. _TransE: http://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data.pdf
.. _DistMult: https://arxiv.org/pdf/1412.6575.pdf
.. _ComplEx: http://proceedings.mlr.press/v48/trouillon16.pdf
Expand Down
2 changes: 1 addition & 1 deletion doc/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
#
html_theme = "furo"

html_logo = "../../asset/logo.svg"
html_logo = "../../asset/torchdrug_logo_full.svg"

# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
Expand Down
35 changes: 35 additions & 0 deletions doc/source/paper.rst
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,41 @@ Retrosynthesis
:class:`SynthonCompletion <torchdrug.tasks.SynthonCompletion>`,
:class:`Retrosynthesis <torchdrug.tasks.Retrosynthesis>`

Protein Representation Learning
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

1. `Evaluating Protein Transfer Learning with TAPE <TAPE_>`_

Roshan Rao, Nicholas Bhattacharya, Neil Thomas, Yan Duan, Xi Chen, John Canny, Pieter Abbeel, Yun S Song. NeurIPS 2019.

:class:`SinusoidalPositionEmbedding <torchdrug.layers.SinusoidalPositionEmbedding>`
:class:`SelfAttentionBlock <torchdrug.layers.SelfAttentionBlock>`
:class:`ProteinResNetBlock <torchdrug.layers.ProteinResNetBlock>`
:class:`ProteinBERTBlock <torchdrug.layers.ProteinBERTBlock>`
:class:`ProteinResNet <torchdrug.models.ProteinResNet>`
:class:`ProteinLSTM <torchdrug.models.ProteinLSTM>`
:class:`ProteinBERT <torchdrug.models.ProteinBERT>`

2. `Is Transfer Learning Necessary for Protein Landscape Prediction? <ProteinCNN_>`_

Amir Shanehsazzadeh, David Belanger, David Dohan. arXiv 2020.

:class:`ProteinCNN <torchdrug.models.ProteinCNN>`

3. `Biological Structure and Function Emerge from Scaling Unsupervised Learning to 250 Million Protein Sequences <ESM_>`_

Alexander Rives, Joshua Meier, Tom Sercu, Siddharth Goyal, Zeming Lin, Jason Liu, Demi Guo, Myle Ott, C. Lawrence Zitnick, Jerry Ma, Rob Fergus. PNAS 2021.

:class:`EvolutionaryScaleModeling <torchdrug.models.EvolutionaryScaleModeling>`

4. `Protein Representation Learning by Geometric Structure Pretraining <GearNet_>`_

Zuobai Zhang, Minghao Xu, Arian Jamasb, Vijil Chenthamarakshan, Aurélie Lozano, Payel Das, Jian Tang. arXiv 2022.

:class:`GeometricRelationalGraphConv <torchdrug.layers.GeometricRelationalGraphConv>`
:class:`GeometryAwareRelationalGraphNeuralNetwork <torchdrug.models.GeometryAwareRelationalGraphNeuralNetwork>`
:mod:`torchdrug.layers.geometry`

Knowledge Graph Reasoning
^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
5 changes: 2 additions & 3 deletions torchdrug/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,13 +715,12 @@ def load_lmdbs(self, lmdb_files, sequence_field="primary", target_fields=None, n
self.num_samples = num_samples

@utils.copy_args(data.Protein.from_molecule)
def load_pdbs(self, pdb_files, sanitize=True, transform=None, lazy=False, verbose=0, **kwargs):
def load_pdbs(self, pdb_files, transform=None, lazy=False, verbose=0, **kwargs):
"""
Load the dataset from pdb files.
Parameters:
pdb_files (list of str): pdb file names
sanitize (bool, optional): whether to sanitize the molecule
transform (Callable, optional): protein sequence transformation function
lazy (bool, optional): if lazy mode is used, the proteins are processed in the dataloader.
This may slow down the data loading process, but save a lot of CPU memory and dataset loading time.
Expand All @@ -744,7 +743,7 @@ def load_pdbs(self, pdb_files, sanitize=True, transform=None, lazy=False, verbos
pdb_files = tqdm(pdb_files, "Constructing proteins from pdbs")
for i, pdb_file in enumerate(pdb_files):
if not lazy or i == 0:
mol = Chem.MolFromPDBFile(pdb_file, sanitize=sanitize)
mol = Chem.MolFromPDBFile(pdb_file)
if not mol:
logger.debug("Can't construct molecule from pdb file `%s`. Ignore this sample." % pdb_file)
continue
Expand Down
10 changes: 4 additions & 6 deletions torchdrug/data/protein.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def from_sequence(cls, sequence, atom_feature="default", bond_feature="default",
@classmethod
@utils.deprecated_alias(node_feature="atom_feature", edge_feature="bond_feature", graph_feature="mol_feature")
def from_pdb(cls, pdb_file, atom_feature="default", bond_feature="default", residue_feature="default",
mol_feature=None, kekulize=False, sanitize=False):
mol_feature=None, kekulize=False):
"""
Create a protein from a PDB file.
Expand All @@ -319,11 +319,10 @@ def from_pdb(cls, pdb_file, atom_feature="default", bond_feature="default", resi
Note this only affects the relation in ``edge_list``.
For ``bond_type``, aromatic bonds are always stored explicitly.
By default, aromatic bonds are stored.
sanitize (bool, optional): whether to sanitize the molecule
"""
if not os.path.exists(pdb_file):
raise FileNotFoundError("No such file `%s`" % pdb_file)
mol = Chem.MolFromPDBFile(pdb_file, sanitize=sanitize)
mol = Chem.MolFromPDBFile(pdb_file)
if mol is None:
raise ValueError("RDKit cannot read PDB file `%s`" % pdb_file)
return cls.from_molecule(mol, atom_feature, bond_feature, residue_feature, mol_feature, kekulize)
Expand Down Expand Up @@ -1052,7 +1051,7 @@ def from_sequence(cls, sequences, atom_feature="default", bond_feature="default"
@classmethod
@utils.deprecated_alias(node_feature="atom_feature", edge_feature="bond_feature", graph_feature="mol_feature")
def from_pdb(cls, pdb_files, atom_feature="default", bond_feature="default", residue_feature="default",
mol_feature=None, kekulize=False, sanitize=False):
mol_feature=None, kekulize=False):
"""
Create a protein from a list of PDB files.
Expand All @@ -1066,11 +1065,10 @@ def from_pdb(cls, pdb_files, atom_feature="default", bond_feature="default", res
Note this only affects the relation in ``edge_list``.
For ``bond_type``, aromatic bonds are always stored explicitly.
By default, aromatic bonds are stored.
sanitize (bool, optional): whether to sanitize the molecule
"""
mols = []
for pdb_file in pdb_files:
mol = Chem.MolFromPDBFile(pdb_file, sanitize=sanitize)
mol = Chem.MolFromPDBFile(pdb_file)
mols.append(mol)

return cls.from_molecule(mols, atom_feature, bond_feature, residue_feature, mol_feature, kekulize)
Expand Down
14 changes: 12 additions & 2 deletions torchdrug/layers/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def forward(self, input):
class GaussianSmearing(nn.Module):
r"""
Gaussian smearing from
`SchNet: A continuous-filter convolutional neural network for modeling quantum interactions`_.
`SchNet: A continuous-filter convolutional neural network for modeling quantum interactions`_.``
There are two modes for Gaussian smearing.
Expand Down Expand Up @@ -167,7 +167,7 @@ def forward(self, graph, input):
class InstanceNorm(nn.modules.instancenorm._InstanceNorm):
"""
Instance normalization for graphs. This layer follows the definition in
`GraphNorm: A Principled Approach to Accelerating Graph Neural Network Training`.
`GraphNorm: A Principled Approach to Accelerating Graph Neural Network Training`_.
.. _GraphNorm\: A Principled Approach to Accelerating Graph Neural Network Training:
https://arxiv.org/pdf/2009.03294.pdf
Expand Down Expand Up @@ -325,13 +325,23 @@ def forward(self, *args, **kwargs):


class SinusoidalPositionEmbedding(nn.Module):
"""
Positional embedding based on sine and cosine functions, proposed in `Attention Is All You Need`_.
.. _Attention Is All You Need:
https://arxiv.org/pdf/1706.03762.pdf
Parameters:
output_dim (int): output dimension
"""

def __init__(self, output_dim):
super(SinusoidalPositionEmbedding, self).__init__()
inverse_frequency = 1 / (10000 ** (torch.arange(0.0, output_dim, 2.0) / output_dim))
self.register_buffer("inverse_frequency", inverse_frequency)

def forward(self, input):
""""""
# input: [B, L, ...]
positions = torch.arange(input.shape[1] - 1, -1, -1.0, dtype=input.dtype, device=input.device)
sinusoidal_input = torch.outer(positions, self.inverse_frequency)
Expand Down
8 changes: 4 additions & 4 deletions torchdrug/layers/geometry/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def __init__(self, node_layers=None, edge_layers=None, edge_feature="residue_typ
def edge_residue_type(self, graph, edge_list):
node_in, node_out, _ = edge_list.t()
residue_in, residue_out = graph.atom2residue[node_in], graph.atom2residue[node_out]
in_residue_type = graph.edge_residue_type[residue_in]
out_residue_type = graph.edge_residue_type[residue_out]
in_residue_type = graph.residue_type[residue_in]
out_residue_type = graph.residue_type[residue_out]

return torch.cat([
functional.one_hot(in_residue_type, len(data.Protein.residue2id)),
Expand All @@ -57,8 +57,8 @@ def edge_residue_type(self, graph, edge_list):
def edge_gearnet(self, graph, edge_list, num_relation):
node_in, node_out, r = edge_list.t()
residue_in, residue_out = graph.atom2residue[node_in], graph.atom2residue[node_out]
in_residue_type = graph.edge_residue_type[residue_in]
out_residue_type = graph.edge_residue_type[residue_out]
in_residue_type = graph.residue_type[residue_in]
out_residue_type = graph.residue_type[residue_out]
sequential_dist = torch.abs(residue_in - residue_out)
spatial_dist = (graph.node_position[node_in] - graph.node_position[node_out]).norm(dim=-1)

Expand Down
2 changes: 1 addition & 1 deletion torchdrug/models/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def forward(self, graph, input, all_loss=None, metric=None):
dict with ``residue_feature`` and ``graph_feature`` fields:
residue representations of shape :math:`(|V_{res}|, d)`, graph representations of shape :math:`(n, d)`
"""
input = graph.edge_residue_type
input = graph.residue_type
size_ext = graph.num_residues
# Prepend BOS
bos = torch.ones(graph.batch_size, dtype=torch.long, device=self.device) * self.num_residue_type
Expand Down
2 changes: 1 addition & 1 deletion torchdrug/models/esm.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def forward(self, graph, input, all_loss=None, metric=None):
dict with ``residue_feature`` and ``graph_feature`` fields:
residue representations of shape :math:`(|V_{res}|, d)`, graph representations of shape :math:`(n, d)`
"""
input = graph.edge_residue_type
input = graph.residue_type
input = self.mapping[input]
size = graph.num_residues
if (size > self.max_input_length).any():
Expand Down
2 changes: 1 addition & 1 deletion torchdrug/models/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch import nn
from torch.nn import functional as F

from torchdrug import core, layers
from torchdrug import core
from torchdrug.layers import functional
from torchdrug.core import Registry as R

Expand Down
2 changes: 1 addition & 1 deletion torchdrug/models/physicochemical.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def forward(self, graph, input, all_loss=None, metric=None):
Returns:
dict with ``graph_feature`` field: graph representations of shape :math:`(n, d)`
"""
input = graph.edge_residue_type
input = graph.residue_type

x = self.property[input] # num_residue * 8
x_mean = scatter_mean(x, graph.residue2graph, dim=0, dim_size=graph.batch_size) # batch_size * 8
Expand Down
2 changes: 1 addition & 1 deletion torchdrug/models/statistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def forward(self, graph, input, all_loss=None, metric=None):
Returns:
dict with ``graph_feature`` field: graph representations of shape :math:`(n, d)`
"""
input = graph.edge_residue_type
input = graph.residue_type

index = input[:-1] * self.num_residue_type + input[1:]
index = graph.residue2graph[:-1] * self.input_dim + index
Expand Down
4 changes: 2 additions & 2 deletions torchdrug/tasks/pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,10 @@ def predict_and_target(self, batch, all_loss=None, metric=None):
input = graph.node_feature.float()
input[node_index] = 0
else:
target = graph.edge_residue_type[node_index]
target = graph.residue_type[node_index]
with graph.residue():
graph.residue_feature[node_index] = 0
graph.edge_residue_type[node_index] = 0
graph.residue_type[node_index] = 0
# Generate masked edge features. Any better implementation?
if self.graph_construction_model:
graph = self.graph_construction_model.apply_edge_layer(graph)
Expand Down

0 comments on commit 09ffe18

Please sign in to comment.