Commit: separated Proteus loss
ardagoreci committed Aug 17, 2024
1 parent 7470239 commit 496dfd9
Showing 2 changed files with 47 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/models/proteus_module.py
@@ -8,7 +8,7 @@
 from torchmetrics import MeanMetric
 from src.utils.exponential_moving_average import ExponentialMovingAverage
 from src.utils.tensor_utils import tensor_tree_map
-from src.utils.loss import AlphaFold3Loss
+from src.utils.loss import ProteusLoss
 from einops import rearrange

@@ -99,7 +99,7 @@ def __init__(self, config):
         self.train_loss = MeanMetric()
         self.val_loss = MeanMetric()
         self.test_loss = MeanMetric()
-        self.loss_fn = AlphaFold3Loss(config.loss)
+        self.loss_fn = ProteusLoss(config.loss)

         # Set matmul precision
         torch.set_float32_matmul_precision(config.matmul_precision)
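For context, a minimal sketch (not part of this commit) of how the swapped-in loss would typically be consumed in the LightningModule's training_step; the self.model call signature and the logging key are assumptions, while self.loss_fn and self.train_loss come from the __init__ shown above:

    def training_step(self, batch, batch_idx):
        out = self.model(batch)  # forward pass; exact signature is an assumption
        # ProteusLoss returns (cumulative_loss, breakdown) when asked for a breakdown
        loss, breakdown = self.loss_fn(out, batch, _return_breakdown=True)
        self.train_loss(loss)  # MeanMetric accumulator defined in __init__
        self.log("train/loss", self.train_loss, on_step=True, on_epoch=True)
        return loss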
45 changes: 45 additions & 0 deletions src/utils/loss.py
@@ -300,3 +300,48 @@ def forward(self, out, batch, _return_breakdown=False):
         else:
             cumulative_loss, losses = self.loss(out, batch, _return_breakdown)
             return cumulative_loss, losses
+
+
+class ProteusLoss(nn.Module):
+    """Convenience class that just includes the diffusion loss for training the Proteus Module."""
+
+    def __init__(self, config):
+        super(ProteusLoss, self).__init__()
+        self.config = config
+
+    def loss(self, out, batch, _return_breakdown=False):
+        loss_fns = {
+            "diffusion_loss": lambda: diffusion_loss(
+                pred_atoms=out["denoised_atoms"],
+                gt_atoms=out["augmented_gt_atoms"],  # rotated gt atoms from diffusion module
+                timesteps=out["timesteps"],
+                weights=batch["atom_exists"],
+                atom_is_rna=batch["ref_mask"].new_zeros(batch["ref_mask"].shape),  # (bs, n_atoms)
+                atom_is_dna=batch["ref_mask"].new_zeros(batch["ref_mask"].shape),  # (bs, n_atoms)
+                mask=batch["atom_exists"],
+                **{**self.config.diffusion_loss},
+            )
+        }
+        cumulative_loss = 0.0
+        losses = {}
+        for loss_name, loss_fn in loss_fns.items():
+            weight = self.config[loss_name].weight
+            loss = loss_fn()
+            if torch.isnan(loss) or torch.isinf(loss):
+                logging.warning(f"{loss_name} is NaN or Inf. Skipping...")
+                loss = loss.new_tensor(0., requires_grad=True)
+            cumulative_loss = cumulative_loss + weight * loss
+            losses[loss_name] = loss.detach().clone()
+        losses["unscaled_loss"] = cumulative_loss.detach().clone()
+        losses["loss"] = cumulative_loss.detach().clone()
+        if not _return_breakdown:
+            return cumulative_loss
+        return cumulative_loss, losses
+
+    def forward(self, out, batch, _return_breakdown=False):
+        if not _return_breakdown:
+            cumulative_loss = self.loss(out, batch, _return_breakdown)
+            return cumulative_loss
+        else:
+            cumulative_loss, losses = self.loss(out, batch, _return_breakdown)
+            return cumulative_loss, losses
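As a sanity check, a minimal smoke-test sketch of the new class under stated assumptions: the tensor shapes, the timesteps layout, and the config contents are guesses (the commit does not show diffusion_loss's signature), an ml_collections-style config is assumed, and the "weight" key is forwarded to diffusion_loss via the ** expansion, so diffusion_loss is assumed to tolerate extra kwargs:

    import torch
    from ml_collections import ConfigDict  # assumed config class; repo's actual one isn't shown
    from src.utils.loss import ProteusLoss

    bs, n_atoms = 2, 64
    # "weight" is required by ProteusLoss.loss; any further keys here are forwarded
    # straight to diffusion_loss, whose expected kwargs this commit does not show.
    config = ConfigDict({"diffusion_loss": {"weight": 4.0}})

    out = {
        "denoised_atoms": torch.randn(bs, n_atoms, 3),      # assumed (bs, n_atoms, 3)
        "augmented_gt_atoms": torch.randn(bs, n_atoms, 3),  # rotated ground truth
        "timesteps": torch.rand(bs, 1),                     # assumed layout
    }
    batch = {
        "atom_exists": torch.ones(bs, n_atoms),
        "ref_mask": torch.ones(bs, n_atoms),  # used only to build the all-zero RNA/DNA flags
    }

    loss_fn = ProteusLoss(config)
    loss, breakdown = loss_fn(out, batch, _return_breakdown=True)
    print(breakdown["diffusion_loss"], breakdown["loss"])

Note that because the loss_fns dict contains only "diffusion_loss", the cumulative loss is just that single weighted term; "unscaled_loss" and "loss" in the breakdown are identical here.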
