Skip to content

Commit

Permalink
fix feature padding in G2Gs
Browse files Browse the repository at this point in the history
  • Loading branch information
KiddoZhu committed Aug 14, 2022
1 parent 5e07dd3 commit 342c87e
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions torchdrug/tasks/retrosynthesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def predict_synthon(self, batch, k=1):
center_topk_shifted = torch.cat([-torch.ones(1, dtype=torch.long, device=self.device),
center_topk[:-1]])
product_id_shifted = torch.cat([-torch.ones(1, dtype=torch.long, device=self.device),
graph.product_id[:-1]])
graph.product_id[:-1]])
is_duplicate = (center_topk == center_topk_shifted) & (graph.product_id == product_id_shifted)
node_index = node_index[~is_edge]
edge_index = edge_index[is_edge]
Expand Down Expand Up @@ -847,11 +847,11 @@ def _apply_action(self, graph, action, logp):
data_dict.pop(key)
# pad 0 for node / edge attributes
for k, v in data_dict.items():
if meta_dict[k] == "node":
if "node" in meta_dict[k]:
shape = (len(new_atom_type), *v.shape[1:])
new_data = torch.zeros(shape, dtype=v.dtype, device=self.device)
data_dict[k] = functional._extend(v, graph.num_nodes, new_data, has_new_node)[0]
if meta_dict[k] == "edge":
if "edge" in meta_dict[k]:
shape = (len(new_edge_list) * 2, *v.shape[1:])
new_data = torch.zeros(shape, dtype=v.dtype, device=self.device)
data_dict[k] = functional._extend(v, graph.num_edges, new_data, has_new_edge * 2)[0]
Expand Down

0 comments on commit 342c87e

Please sign in to comment.