Skip to content

Commit

Permalink
Remove undirected edge transform
Browse files Browse the repository at this point in the history
  • Loading branch information
chaitjo committed Sep 12, 2023
1 parent 4946a58 commit 59716b9
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 5 deletions.
3 changes: 1 addition & 2 deletions proteinworkshop/features/edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from graphein.protein.tensor.data import Protein, ProteinBatch
from omegaconf import ListConfig
from torch_geometric.data import Batch, Data
from torch_geometric.utils import to_undirected


@beartype
Expand Down Expand Up @@ -59,7 +58,7 @@ def compute_edges(
for edge_type in edge_types:
if edge_type.startswith("knn") or edge_type.startswith("eps"):
edges.append(
to_undirected(edge_fn(x.pos, edge_type))
edge_fn(x.pos, edge_type)
)
elif edge_type == "seq_forward":
edges.append(
Expand Down
3 changes: 0 additions & 3 deletions proteinworkshop/models/graph_encoders/tfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,6 @@ def forward(self, batch: Union[Batch, ProteinBatch]) -> EncoderOutput:
the dimension of the embeddings.
:rtype: EncoderOutput
"""
# from torch_geometric.utils import to_undirected
# edge_index = to_undirected(batch.edge_index)

# Node embedding
h = self.emb_in(batch.x) # (n,) -> (n, d)

Expand Down

0 comments on commit 59716b9

Please sign in to comment.