diff --git a/src/models/proteus_module.py b/src/models/proteus_module.py
index 8a79d14..f10f17e 100644
--- a/src/models/proteus_module.py
+++ b/src/models/proteus_module.py
@@ -8,7 +8,7 @@
 from torchmetrics import MeanMetric
 from src.utils.exponential_moving_average import ExponentialMovingAverage
 from src.utils.tensor_utils import tensor_tree_map
-from src.utils.loss import AlphaFold3Loss
+from src.utils.loss import ProteusLoss
 from einops import rearrange
 
 
@@ -99,7 +99,7 @@ def __init__(self, config):
         self.train_loss = MeanMetric()
         self.val_loss = MeanMetric()
         self.test_loss = MeanMetric()
-        self.loss_fn = AlphaFold3Loss(config.loss)
+        self.loss_fn = ProteusLoss(config.loss)
 
         # Set matmul precision
         torch.set_float32_matmul_precision(config.matmul_precision)
diff --git a/src/utils/loss.py b/src/utils/loss.py
index c16749b..e2d5913 100644
--- a/src/utils/loss.py
+++ b/src/utils/loss.py
@@ -300,3 +300,48 @@ def forward(self, out, batch, _return_breakdown=False):
         else:
             cumulative_loss, losses = self.loss(out, batch, _return_breakdown)
             return cumulative_loss, losses
+
+
+class ProteusLoss(nn.Module):
+    """Convenience class that includes only the diffusion loss for training the Proteus Module."""
+
+    def __init__(self, config):
+        super(ProteusLoss, self).__init__()
+        self.config = config
+
+    def loss(self, out, batch, _return_breakdown=False):
+        loss_fns = {
+            "diffusion_loss": lambda: diffusion_loss(
+                pred_atoms=out["denoised_atoms"],
+                gt_atoms=out["augmented_gt_atoms"],  # rotated gt atoms from the diffusion module
+                timesteps=out["timesteps"],
+                weights=batch["atom_exists"],
+                atom_is_rna=batch["ref_mask"].new_zeros(batch["ref_mask"].shape),  # (bs, n_atoms)
+                atom_is_dna=batch["ref_mask"].new_zeros(batch["ref_mask"].shape),  # (bs, n_atoms)
+                mask=batch["atom_exists"],
+                **{**self.config.diffusion_loss},
+            )
+        }
+        cumulative_loss = 0.0
+        losses = {}
+        for loss_name, loss_fn in loss_fns.items():
+            weight = self.config[loss_name].weight
+            loss = loss_fn()
+            if torch.isnan(loss) or torch.isinf(loss):
+                logging.warning(f"{loss_name} loss is NaN or Inf. Skipping...")
+                loss = loss.new_tensor(0., requires_grad=True)
+            cumulative_loss = cumulative_loss + weight * loss
+            losses[loss_name] = loss.detach().clone()
+        losses["unscaled_loss"] = cumulative_loss.detach().clone()
+        losses["loss"] = cumulative_loss.detach().clone()
+        if not _return_breakdown:
+            return cumulative_loss
+        return cumulative_loss, losses
+
+    def forward(self, out, batch, _return_breakdown=False):
+        if not _return_breakdown:
+            cumulative_loss = self.loss(out, batch, _return_breakdown)
+            return cumulative_loss
+        else:
+            cumulative_loss, losses = self.loss(out, batch, _return_breakdown)
+            return cumulative_loss, losses