Improve memory management in extra msa stack
gahdritz committed Apr 6, 2022
1 parent cdadff3 commit e310bba
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions openfold/model/evoformer.py
@@ -352,20 +352,30 @@ def forward(self,
         chunk_size: Optional[int] = None,
         _chunk_logits: Optional[int] = 1024,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        m = m + self.msa_dropout_layer(
+        def add(m1, m2):
+            # The first operation in a checkpoint can't be in-place, but it's
+            # nice to have in-place addition during inference. Thus...
+            if(torch.is_grad_enabled()):
+                m1 = m1 + m2
+            else:
+                m1 += m2
+
+            return m1
+
+        m = add(m, self.msa_dropout_layer(
             self.msa_att_row(
-                m.clone(),
-                z=z.clone(),
+                m.clone() if torch.is_grad_enabled() else m,
+                z=z.clone() if torch.is_grad_enabled() else z,
                 mask=msa_mask,
                 chunk_size=chunk_size,
                 _chunk_logits=_chunk_logits if torch.is_grad_enabled() else None,
                 _checkpoint_chunks=
                     self.ckpt if torch.is_grad_enabled() else False,
             )
-        )
-
+        ))
+
         def fn(m, z):
-            m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
+            m = add(m, self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size))
             m, z = self.core(
                 m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size
             )
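For context (not part of the commit): PyTorch forbids in-place mutation of a tensor that autograd has saved for the backward pass, and the version-counter check fails at backward time. That is why add() only takes the in-place += branch when torch.is_grad_enabled() is False. A minimal standalone sketch of the failure mode, using torch.exp purely as a stand-in for the residual branch (not OpenFold code):

import torch

x = torch.randn(4, requires_grad=True)
y = torch.exp(x)        # exp() saves its output y for the backward pass
y += 1                  # in-place update bumps y's version counter
try:
    y.sum().backward()  # autograd notices the saved tensor was mutated
except RuntimeError as err:
    print(f"backward failed: {err}")

# With autograd disabled (the inference path), the same in-place update is
# allowed and avoids allocating a second tensor of y's size -- the case the
# else-branch of add() exploits.
with torch.no_grad():
    y = torch.exp(x)
    y += 1

The m.clone()/z.clone() ternaries follow the same logic: the defensive copies are now taken only while gradients are being tracked, so inference no longer pays for two extra MSA- and pair-sized allocations.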
