Skip to content

Commit

Permalink
fixed typing error of anchor_gt_residue
Browse files Browse the repository at this point in the history
  • Loading branch information
dingquanyu authored and jnwei committed May 11, 2024
1 parent 6f1329e commit d968098
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions openfold/utils/multi_chain_permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def split_dim(shape):
return labels


def get_per_asym_residue_index(features: dict) -> Dict[int, list]:
def get_per_asym_residue_index(features: dict) -> Dict[int, torch.Tensor]:
"""
A function that retrieve which residues belong to which asym_id
Expand Down Expand Up @@ -354,15 +354,15 @@ def get_entity_2_asym_list(features: dict) -> Dict[int, list]:


def calculate_input_mask(true_ca_masks: List[torch.Tensor], anchor_gt_idx: torch.Tensor,
anchor_gt_residue: list,
anchor_gt_residue: torch.Tensor,
asym_mask: torch.Tensor, pred_ca_mask: torch.Tensor) -> torch.Tensor:
"""
Calculate an input mask for downstream optimal transformation computation
Args:
true_ca_masks: list of masks from ground truth chains.
anchor_gt_idx (Tensor): a tensor with one integer in it. The index of selected ground truth anchor.
anchor_gt_residue:a list of residue indexes that belongs to the selected ground truth anchor
anchor_gt_residue:a 1D vector tensor of residue indexes that belongs to the selected ground truth anchor
asym_mask (Tensor): Boolean tensor indicating which regions are selected predicted anchor.
pred_ca_mask (Tensor): ca mask from predicted structure.
Expand All @@ -378,7 +378,7 @@ def calculate_input_mask(true_ca_masks: List[torch.Tensor], anchor_gt_idx: torch


def calculate_optimal_transform(true_ca_poses: List[torch.Tensor],
anchor_gt_idx: int, anchor_gt_residue: list,
anchor_gt_idx: int, anchor_gt_residue: torch.Tensor,
true_ca_masks: List[torch.Tensor], pred_ca_mask: torch.Tensor,
asym_mask: torch.Tensor,
pred_ca_pos: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
Expand All @@ -391,7 +391,7 @@ def calculate_optimal_transform(true_ca_poses: List[torch.Tensor],
Args:
true_ca_poses: a list of tensors, corresponding to the c-alpha positions of the ground truth structure. e.g. If there are 5 chains, this list will have a length of 5
anchor_gt_idx (Tensor): a tensor with one integer in it. The index of selected ground truth anchor.
anchor_gt_residue:a list of residue indexes that belongs to the selected ground truth anchor
anchor_gt_residue:a 1D vector tensor of residue indexes that belongs to the selected ground truth anchor
true_ca_masks: list of masks from ground truth chains e.g. it will be length=5 if there are 5 chains in ground truth structure
pred_ca_mask: A boolean tensor corresponds to the mask to mask the predicted features
asym_mask: A boolean tensor that mask out other elements in a tensor if they do not belong to a this asym_id
Expand Down

0 comments on commit d968098

Please sign in to comment.