Skip to content

Commit

Permalink
fix bugs in SubsequenceNode & GraphConstruction
Browse files Browse the repository at this point in the history
  • Loading branch information
Oxer11 committed Nov 22, 2022
1 parent 4ea99a0 commit 4e76c30
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
3 changes: 1 addition & 2 deletions torchdrug/layers/geometry/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,7 @@ def forward(self, graph):
starts = starts + graph.num_cum_residues - graph.num_residues
ends = ends + graph.num_cum_residues - graph.num_residues

node_mask = functional.multi_slice_mask(starts, ends, graph.num_residue)
residue_mask = node_mask[graph.atom2residue]
residue_mask = functional.multi_slice_mask(starts, ends, graph.num_residue)
graph = graph.subresidue(residue_mask)

return graph
Expand Down
8 changes: 6 additions & 2 deletions torchdrug/layers/geometry/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,14 @@ def apply_edge_layer(self, graph):
edge_feature = self.edge_residue_type(graph, edge_list)
elif self.edge_feature == "gearnet":
edge_feature = self.edge_gearnet(graph, edge_list, num_relation)
elif self.edge_feature is None:
edge_feature = None
else:
raise ValueError("Unknown edge feature `%s`" % self.edge_feature)
data_dict, meta_dict = graph.data_by_meta(include=("node", "residue", "node reference", "residue reference"))

data_dict, meta_dict = graph.data_by_meta(include=(
"node", "residue", "node reference", "residue reference", "graph"
))

if isinstance(graph, data.PackedProtein):
data_dict["num_residues"] = graph.num_residues
if isinstance(graph, data.PackedMolecule):
Expand Down

0 comments on commit 4e76c30

Please sign in to comment.