Merge pull request uw-ipd#2 from sokrypton/main
minor edits
fdimaio authored May 29, 2023
2 parents c19e84e + f36fbda commit 815fa13
Showing 1 changed file with 89 additions and 30 deletions.
network/predict.py
@@ -16,6 +16,7 @@
 from symmetry import symm_subunit_matrix, find_symm_subs, get_symm_map
 from data_loader import merge_a3m_hetero
 import json
+import random
 
 # suppress dgl warning w/ newest pytorch
 import warnings
@@ -37,7 +38,7 @@ def get_args():
     parser.add_argument("-prefix", default="S", type=str, help="Output file prefix [S]")
     parser.add_argument("-symm", default="C1", help="Symmetry group (Cn,Dn,T,O, or I). If provided, 'input' should cover the asymmetric unit. [C1]")
     parser.add_argument("-model", default=default_model, help="Model weights. [weights/RF2_apr23.pt]")
-    parser.add_argument("-n_recycles", default=4, type=int, help="Number of recycles to use [4].")
+    parser.add_argument("-n_recycles", default=3, type=int, help="Number of recycles to use [3].")
     parser.add_argument("-n_models", default=1, type=int, help="Number of models to predict [1].")
     parser.add_argument("-subcrop", default=-1, type=int, help="Subcrop pair-to-pair updates. For very large models (>3000 residues) a subcrop of 800-1600 can improve structure accuracy and reduce runtime. A value of -1 means no subcropping. [-1]")
     parser.add_argument("-nseqs", default=256, type=int, help="The number of MSA sequences to sample in the main 1D track [256].")
@@ -96,21 +97,61 @@ def pae_unbin(pred_pae):
     pred_pae = nn.Softmax(dim=1)(pred_pae)
     return torch.sum(pae_bins[None,:,None,None]*pred_pae, dim=1)
 
-def merge_a3m_homo(msa_orig, ins_orig, nmer):
+def merge_a3m_homo(msa_orig, ins_orig, nmer, mode="default"):
     N, L = msa_orig.shape[:2]
-    msa = torch.full((2*N-1, L*nmer), 20, dtype=msa_orig.dtype, device=msa_orig.device)
-    ins = torch.full((2*N-1, L*nmer), 0, dtype=ins_orig.dtype, device=msa_orig.device)
-
-    msa[:N, :L] = msa_orig
-    ins[:N, :L] = ins_orig
-    start = L
-
-    for i_c in range(1,nmer):
-        msa[0, start:start+L] = msa_orig[0]
-        msa[N:, start:start+L] = msa_orig[1:]
-        ins[0, start:start+L] = ins_orig[0]
-        ins[N:, start:start+L] = ins_orig[1:]
-        start += L
+    if mode == "repeat":
+
+        # AAAAAA
+        # AAAAAA
+
+        msa = torch.tile(msa_orig,(1,nmer))
+        ins = torch.tile(ins_orig,(1,nmer))
+
+    elif mode == "diag":
+
+        # AAAAAA
+        # A-----
+        # -A----
+        # --A---
+        # ---A--
+        # ----A-
+        # -----A
+
+        N = N - 1
+        new_N = 1 + N * nmer
+        new_L = L * nmer
+        msa = torch.full((new_N, new_L), 20, dtype=msa_orig.dtype, device=msa_orig.device)
+        ins = torch.full((new_N, new_L), 0, dtype=ins_orig.dtype, device=msa_orig.device)
+
+        start_L = 0
+        start_N = 1
+        for i_c in range(nmer):
+            msa[0, start_L:start_L+L] = msa_orig[0]
+            msa[start_N:start_N+N, start_L:start_L+L] = msa_orig[1:]
+            ins[0, start_L:start_L+L] = ins_orig[0]
+            ins[start_N:start_N+N, start_L:start_L+L] = ins_orig[1:]
+            start_L += L
+            start_N += N
+    else:
+
+        # AAAAAA
+        # A-----
+        # -AAAAA
+
+        msa = torch.full((2*N-1, L*nmer), 20, dtype=msa_orig.dtype, device=msa_orig.device)
+        ins = torch.full((2*N-1, L*nmer), 0, dtype=ins_orig.dtype, device=msa_orig.device)
+
+        msa[:N, :L] = msa_orig
+        ins[:N, :L] = ins_orig
+        start = L
+
+        for i_c in range(1,nmer):
+            msa[0, start:start+L] = msa_orig[0]
+            msa[N:, start:start+L] = msa_orig[1:]
+            ins[0, start:start+L] = ins_orig[0]
+            ins[N:, start:start+L] = ins_orig[1:]
+            start += L
 
     return msa, ins
 
 class Predictor():
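
The new mode argument selects how a single-chain MSA is tiled across the nmer copies: "repeat" tiles every row across all copies, "diag" keeps the query paired but places each remaining row on exactly one copy, and "default" preserves the previous behavior. A minimal shape check, assuming merge_a3m_homo can be imported from this file (toy tensors, not real alignments):

    import torch
    from predict import merge_a3m_homo  # assumed import path for network/predict.py

    N, L, nmer = 5, 10, 3                        # toy MSA: 5 rows of length 10, homo-trimer
    msa = torch.randint(0, 20, (N, L))           # 20 is the gap/pad token used above
    ins = torch.zeros((N, L), dtype=torch.long)

    for mode in ("default", "repeat", "diag"):
        m, _ = merge_a3m_homo(msa, ins, nmer, mode=mode)
        print(mode, tuple(m.shape))
    # default -> (9, 30)  = (2N-1, L*nmer)
    # repeat  -> (5, 30)  = (N, L*nmer)
    # diag    -> (13, 30) = (1 + (N-1)*nmer, L*nmer)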
@@ -148,11 +189,9 @@ def load_model(self, model_weights):
     def predict(
         self, inputs, out_prefix, symm="C1", ffdb=None,
         n_recycles=4, n_models=1, subcrop=-1, nseqs=256, nseqs_full=2048,
-        n_templ=4
+        n_templ=4, msa_mask=0.0, is_training=False, msa_concat_mode="default"
     ):
         self.xyz_converter = self.xyz_converter.cpu()
-        symmids,symmRs,symmmeta,symmoffset = symm_subunit_matrix(symm)
-        O = symmids.shape[0]
 
         ###
         # pass 1, combined MSA
@@ -179,6 +218,17 @@ def predict(
             msa_orig = merge_a3m_hetero(msa_orig, {'msa':msas[i],'ins':inss[i]}, [sum(Ls_blocked[:i]),Ls_blocked[i]])
         msa_orig, ins_orig = msa_orig['msa'], msa_orig['ins']
 
+        # pseudo symmetry
+        if symm.startswith("X"):
+            Osub = int(symm[1:])
+            if Osub > 1:
+                msa_orig, ins_orig = merge_a3m_homo(msa_orig, ins_orig, Osub, mode=msa_concat_mode)
+                Ls = sum([Ls] * Osub,[])
+            symm = "C1"
+
+        symmids,symmRs,symmmeta,symmoffset = symm_subunit_matrix(symm)
+        O = symmids.shape[0]
+
         ###
         # pass 2, templates
         L = sum(Ls)
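
This block adds an "Xn" pseudo-symmetry mode: the asymmetric-unit MSA is tiled n times via merge_a3m_homo and the job is then run as plain C1, with no symmetry operators applied. A hedged sketch of calling it directly; the constructor arguments and the a3m file name are illustrative assumptions, while inputs, out_prefix, symm, and msa_concat_mode follow the predict signature above:

    import torch

    pred = Predictor("weights/RF2_apr23.pt", torch.device("cuda:0"))
    pred.predict(
        inputs=["subunit.a3m"],     # hypothetical a3m covering one subunit
        out_prefix="trimer",
        symm="X3",                  # tile the MSA 3x, then model as C1
        msa_concat_mode="diag",     # pair only the query across copies
    )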
@@ -236,7 +286,7 @@ def predict(
         # symmetrize msa
         effL = Osub*L
         if (Osub>1):
-            msa_orig, ins_orig = merge_a3m_homo(msa_orig, ins_orig, Osub)
+            msa_orig, ins_orig = merge_a3m_homo(msa_orig, ins_orig, Osub, mode=msa_concat_mode)
 
         # index
         idx_pdb = torch.arange(Osub*L)[None,:]
@@ -256,7 +306,10 @@ def predict(
         t2d = xyz_to_t2d(xyz_t, mask_t_2d)
 
 
-        self.model.eval()
+        if is_training:
+            self.model.train()
+        else:
+            self.model.eval()
         for i_trial in range(n_models):
             #if os.path.exists("%s_%02d_init.pdb"%(out_prefix, i_trial)):
             #    continue
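
With is_training=True the network runs with dropout active, so repeated calls sample different structures instead of one deterministic model. A self-contained illustration of that general PyTorch behavior (toy network, unrelated to this model):

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    net = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5), nn.Linear(8, 1))
    x = torch.randn(1, 8)

    net.eval()                      # dropout disabled: outputs are deterministic
    with torch.no_grad():
        a, b = net(x), net(x)

    net.train()                     # dropout active: outputs vary call to call
    with torch.no_grad():
        c, d = net(x), net(x)

    print(torch.allclose(a, b))     # True
    print(torch.allclose(c, d))     # False (with overwhelming probability)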
@@ -268,26 +321,29 @@ def predict(
                 xyz_prev, mask_prev, same_chain, idx_pdb,
                 symmids, symmsub, symmRs, symmmeta, Ls,
                 n_recycles, nseqs, nseqs_full, subcrop,
-                "%s_%02d"%(out_prefix, i_trial)
+                "%s_%02d"%(out_prefix, i_trial),
+                msa_mask=msa_mask
             )
-            max_mem = torch.cuda.max_memory_allocated()/1e9
-            print ("Memory used:", max_mem, "/ Time: %.2f sec"%(time.time()-start_time))
+            runtime = time.time() - start_time
+            vram = torch.cuda.max_memory_allocated() / 1e9
+            print(f"runtime={runtime:.2f} vram={vram:.2f}")
             torch.cuda.empty_cache()
 
     def run_prediction(
         self, msa_orig, ins_orig,
         t1d, t2d, xyz_t, alpha_t, mask_t,
         xyz_prev, mask_prev, same_chain, idx_pdb,
         symmids, symmsub, symmRs, symmmeta, L_s,
-        n_recycles, nseqs, nseqs_full, subcrop, out_prefix
+        n_recycles, nseqs, nseqs_full, subcrop, out_prefix,
+        msa_mask=0.0,
     ):
         self.xyz_converter = self.xyz_converter.to(self.device)
 
         with torch.no_grad():
             msa = msa_orig.long().to(self.device) # (N, L)
             ins = ins_orig.long().to(self.device)
 
-            print ("Input size", msa.shape[1], msa.shape[0])
+            print(f"N={msa.shape[0]} L={msa.shape[1]}")
             N, L = msa.shape[:2]
             O = symmids.shape[0]
             Osub = symmsub.shape[0]
@@ -324,9 +380,9 @@ def run_prediction(
             best_logit = None
             best_aa = None
             best_pae = None
-            for i_cycle in range(n_recycles):
+            for i_cycle in range(n_recycles + 1):
                 seq, msa_seed_orig, msa_seed, msa_extra, mask_msa = MSAFeaturize(
-                    msa, ins, p_mask=0.0, params={'MAXLAT': nseqs, 'MAXSEQ': nseqs_full, 'MAXCYCLE': 1})
+                    msa, ins, p_mask=msa_mask, params={'MAXLAT': nseqs, 'MAXSEQ': nseqs_full, 'MAXCYCLE': 1})
 
                 seq = seq.unsqueeze(0)
                 msa_seed = msa_seed.unsqueeze(0)
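
Here msa_mask reaches MSAFeaturize as p_mask, the fraction of MSA positions randomly masked on each recycle (BERT-style noising; 0.0 reproduces the old behavior). A standalone toy version of such a masking step; the mask-token id 21 is an assumption for illustration, not necessarily what MSAFeaturize uses:

    import torch

    def mask_msa_tokens(msa: torch.Tensor, p_mask: float, mask_token: int = 21) -> torch.Tensor:
        # Replace a random p_mask fraction of tokens with mask_token
        # (toy stand-in, not the actual MSAFeaturize implementation).
        keep = torch.rand(msa.shape) >= p_mask
        return torch.where(keep, msa, torch.full_like(msa, mask_token))

    msa = torch.randint(0, 20, (4, 12))
    print(mask_msa_tokens(msa, p_mask=0.15))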
@@ -358,7 +414,8 @@ def run_prediction(
                 pred_lddt = nn.Softmax(dim=1)(pred_lddt) * self.lddt_bins[None,:,None]
                 pred_lddt = pred_lddt.sum(dim=1)
                 pae = pae_unbin(logits_pae)
-                print ("RECYCLE", i_cycle, pred_lddt.mean(), pae.mean(), best_lddt.mean())
+                print (f"recycle={i_cycle} plddt={pred_lddt.mean():.3f} pae={pae.mean():.3f}")
+
                 #util.writepdb("%s_cycle_%02d.pdb"%(out_prefix, i_cycle), xyz_prev[0], seq[0], L_s, bfacts=100*pred_lddt[0])
 
                 logit_s = [l.cpu() for l in logit_s]
@@ -408,8 +465,10 @@ def run_prediction(
         util.writepdb("%s_pred.pdb"%(out_prefix), best_xyzfull[0], seq_full[0], L_s, bfacts=100*best_lddtfull[0])
 
         prob_s = [prob.permute(0,2,3,1).detach().cpu().numpy().astype(np.float16) for prob in prob_s]
-        np.savez_compressed("%s.npz"%(out_prefix), dist=prob_s[0].astype(np.float16), \
-            lddt=best_lddt[0].detach().cpu().numpy().astype(np.float16))
+        np.savez_compressed("%s.npz"%(out_prefix),
+            dist=prob_s[0].astype(np.float16),
+            lddt=best_lddt[0].detach().cpu().numpy().astype(np.float16),
+            pae=best_pae[0].detach().cpu().numpy().astype(np.float16))
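
The output archive now stores the predicted aligned error alongside the distogram and per-residue lddt. A hedged reader sketch; the file name assumes the default prefix "S" and trial 00, and lddt is saved in [0,1] with pae in Angstroms, per the code above:

    import numpy as np

    d = np.load("S_00.npz")         # hypothetical name from out_prefix "S_00"
    print(d["dist"].shape, d["lddt"].shape, d["pae"].shape)
    print("mean pLDDT:", 100 * float(d["lddt"].mean()))
    print("mean PAE  :", float(d["pae"].mean()))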


