handling case where all molecules are invalid in property optimization (DeepGraphLearning#125)

* fix: handling all-invalid cases

* change invalid metrics to NaN

Co-authored-by: Zhaocheng Zhu <[email protected]>
jannisborn and KiddoZhu authored Oct 13, 2022
1 parent cc2c1ca commit 28a677d
Showing 1 changed file with 31 additions and 5 deletions.
torchdrug/tasks/generation.py
@@ -141,9 +141,22 @@ def reinforce_forward(self, batch):
 
         # generation takes less time when early_stop=True
         graph = self.generate(len(batch["graph"]), off_policy=True, early_stop=True)
-        if graph.num_nodes.max() == 1:
-            raise ValueError("Generation results collapse to singleton molecules")
+        if len(graph) == 0 or graph.num_nodes.max() == 1:
+            logger.error("Generation results collapse to singleton molecules")
+            all_loss.requires_grad_()
+            nan = torch.tensor(float("nan"), device=self.device)
+            for task in self.task:
+                if task == "plogp":
+                    metric["Penalized logP"] = nan
+                    metric["Penalized logP (max)"] = nan
+                elif task == "qed":
+                    metric["QED"] = nan
+                    metric["QED (max)"] = nan
+            metric["node PPO objective"] = nan
+            metric["edge PPO objective"] = nan
+
+            return all_loss, metric
 
         reward = torch.zeros(len(graph), device=self.device)
         for task in self.task:
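A side note on this hunk (not part of the commit): the early return hands the training loop a zero loss that never touched the model, so the all_loss.requires_grad_() call is what keeps the caller's backward() from crashing. A minimal standalone sketch of that behavior, assuming the training loop calls loss.backward() as usual:

    import torch

    # Mirrors the early-return path: a plain zero loss with no autograd graph.
    all_loss = torch.tensor(0.0)

    # Without this line, all_loss.backward() raises
    # "element 0 of tensors does not require grad and does not have a grad_fn".
    all_loss.requires_grad_()

    all_loss.backward()   # now succeeds; no model parameter receives a gradient
    print(all_loss.grad)  # tensor(1.) -- the gradient of the loss w.r.t. itself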
@@ -804,8 +817,21 @@ def reinforce_forward(self, batch):
 
         # generation takes less time when early_stop=True
         graph = self.generate(len(batch["graph"]), max_resample=20, off_policy=True, max_step=40 * 2, verbose=1)
-        if graph.num_nodes.max() == 1:
-            raise ValueError("Generation results collapse to singleton molecules")
+        if len(graph) == 0 or graph.num_nodes.max() == 1:
+            logger.error("Generation results collapse to singleton molecules")
+            all_loss.requires_grad_()
+            nan = torch.tensor(float("nan"), device=self.device)
+            for task in self.task:
+                if task == "plogp":
+                    metric["Penalized logP"] = nan
+                    metric["Penalized logP (max)"] = nan
+                elif task == "qed":
+                    metric["QED"] = nan
+                    metric["QED (max)"] = nan
+            metric["PPO objective"] = nan
+
+            return all_loss, metric
 
         reward = torch.zeros(len(graph), device=self.device)
         for task in self.task:
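Since both paths now report NaN instead of raising, a run that hits an all-invalid batch keeps training and the bad batch stays visible in the logged metrics. One way a downstream consumer might average such metrics without letting NaN poison a running mean; nan_safe_mean is a hypothetical helper for illustration, not part of torchdrug:

    import torch

    def nan_safe_mean(batch_values):
        # Average per-batch metric values, skipping the NaN entries that mark
        # all-invalid batches. Hypothetical helper, not torchdrug API.
        stacked = torch.stack(batch_values)
        valid = ~torch.isnan(stacked)
        if not valid.any():
            return torch.tensor(float("nan"))  # every batch was invalid
        return stacked[valid].mean()

    # One all-invalid batch (NaN) among two valid QED scores:
    scores = [torch.tensor(0.62), torch.tensor(float("nan")), torch.tensor(0.58)]
    print(nan_safe_mean(scores))  # tensor(0.6000)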