Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahalamdari committed Sep 12, 2023
1 parent 22b6c41 commit 776b5cd
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 180 deletions.
59 changes: 1 addition & 58 deletions evodiff/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import torch.nn.functional as F
import numpy as np
from torch.utils.checkpoint import checkpoint
from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoder, TransformerDecoderLayer
from sequence_models.layers import PositionFeedForward, DoubleEmbedding
from sequence_models.convolutional import ByteNetBlock
from sequence_models.constants import MSA_PAD, MASK, MSA_ALPHABET
Expand Down Expand Up @@ -253,60 +252,4 @@ def forward(self, tokens, timesteps):
x = self.emb_layer_norm_after(x)
x = x.permute(2, 0, 1, 3) # R x C x B x D -> B x R x C x D
x = self.lm_head(x)
return x


class TransformerTime(nn.Module):
"""
"""
def __init__(self, n_tokens, d_embedding, d_model, n_layers, n_head, d_feedforward, padding_idx=None,
max_positions=1024, bidirectional=True, dropout=0.0, activation='relu',
norm_first=False, timesteps=None):
"""
"""
super().__init__()
self.d_model = d_model
self.bidirectional = bidirectional
self.embedder = nn.Embedding(n_tokens, d_embedding, padding_idx=padding_idx)
self.pos_encoding = PositionalEncoding(d_embedding, max_positions)
self.timesteps = timesteps
if self.timesteps is not None:
self.time_encoding = PositionalEncoding1D(d_embedding, timesteps) # Timestep encoding
self.up_embedder = PositionFeedForward(d_embedding, d_model)
if bidirectional: # for oa autoregressive model, d3pm models
encoder_layers = TransformerEncoderLayer(d_model, n_head, dim_feedforward=d_feedforward, dropout=dropout,
activation=activation, batch_first=True, norm_first=norm_first)
self.transformer = TransformerEncoder(encoder_layers, n_layers)
else: # for single-order autoregressive model
decoder_layers = TransformerDecoderLayer(d_model, n_head, dim_feedforward=d_feedforward, dropout=dropout,
activation=activation, batch_first=True, norm_first=norm_first)
self.transformer = TransformerDecoder(decoder_layers, n_layers)
self.decoder = nn.Linear(d_model, n_tokens)

# self.init_weights()

# def init_weights(self):
# initrange = 0.1
# self.embedder.weight.data.uniform_(-initrange, initrange)
# self.decoder.bias.data.zero_()
# self.decoder.weight.data.uniform_(-initrange, initrange)

def forward(self, src, tgt, t, input_mask=None):
src = self.embedder(src) * np.sqrt(self.d_model)
src = self.pos_encoding(src.reshape(src.shape[1], src.shape[0], src.shape[2]))
tgt = self.embedder(tgt) * np.sqrt(self.d_model)
tgt = self.pos_encoding(tgt.reshape(tgt.shape[1], tgt.shape[0], tgt.shape[2]))

if self.timesteps is not None:
t = self.time_encoding(t).unsqueeze(1)
t = t.expand(src.shape[0], src.shape[1], src.shape[2])
src += t

src = self.up_embedder(src)
tgt = self.up_embedder(tgt)

if self.bidirectional:
out = self.transformer(src, src_key_padding_mask=input_mask)
else:
out = self.transformer(tgt, src, tgt_key_padding_mask=input_mask)
return self.decoder(out)
return x
42 changes: 0 additions & 42 deletions train-msa.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
import os
from datetime import datetime, timedelta
import pathlib
import esm
import numpy as np
# import mlflow
import torch
from torch.cuda.amp import GradScaler
import torch.multiprocessing as mp
Expand All @@ -21,12 +19,10 @@
from evodiff.model import MSATransformerTime
from sequence_models.esm import MSATransformer
from sequence_models.constants import MSA_ALPHABET
#from sequence_models.datasets import TRRMSADataset #, A3MMSADataset # TODO move datasets back to sequence_models
from evodiff.data import TRRMSADataset, A3MMSADataset
from sequence_models.collaters import MSAAbsorbingCollater
from sequence_models.samplers import SortishSampler, ApproxBatchSampler
from sequence_models.losses import MaskedCrossEntropyLossMSA
#from sequence_models.metrics import MaskedAccuracy # TODO move this back to sequence_models?
from evodiff.metrics import MaskedAccuracyMSA
from torch.utils.data import Subset
from sequence_models.utils import warmup, transformer_lr
Expand Down Expand Up @@ -90,7 +86,6 @@ def train(gpu, args):
with open(args.config_fpath, 'r') as f:
config = json.load(f)

#selection_type = config['selection_type']
selection_type = args.selection_type
d_embed = config['d_embed']
d_hidden = config['d_hidden']
Expand Down Expand Up @@ -121,10 +116,8 @@ def train(gpu, args):
ptjob = True
except:
data_top_dir = 'data/'
#print(data_top_dir)
data_dir = data_top_dir
data_dir += config['dataset'] + '/'
#print(data_dir)
ptjob = False

# build datasets, samplers, and loaders
Expand All @@ -140,7 +133,6 @@ def train(gpu, args):
if args.mask == 'blosum':
Q_prod, Q_t = tokenizer.q_blosum_schedule(timesteps=diffusion_timesteps)
collater = D3PMCollaterMSA(tokenizer=tokenizer, num_timesteps=diffusion_timesteps, Q=Q_t, Q_bar=Q_prod)
#Q_prod = Q_prod.to(device)
else:
print("mask must be: 'autoreg', 'blosum', or 'random'")

Expand Down Expand Up @@ -195,7 +187,6 @@ def train(gpu, args):
if config['dataset'] == 'trrosetta':
dl_valid = DataLoader(dataset=ds_valid,
batch_size=4,
# batch_sampler=valid_sampler,
collate_fn=collater,
num_workers=8)
elif config['dataset'] == 'openfold':
Expand Down Expand Up @@ -407,19 +398,6 @@ def epoch(model, e, split, current_step=0, current_tokens=0):
str(current_step), str(e)]))
f.write('\n')
print('Validation complete in ' + str(datetime.now() - start_time))

# if split == 'test':
# with open(args.out_fpath + 'metrics_test.csv', 'a') as f:
# f.write(','.join(
# [str(rloss_ardm), str(rloss_nll), str(raccu), str(int(current_tokens)),
# str(current_step)]))
# f.write('\n')

# print('Testing complete in ' + str(datetime.now() - start_time))

# elif rank == 0:
# if not ptjob:
# print()
print('Epoch complete in ' + str(datetime.now() - start_time))
return i, tokens_trained

Expand All @@ -429,16 +407,13 @@ def step(model, batch, split):
src_one_hot = src_one_hot.to(device)
tgt_one_hot = tgt_one_hot.to(device)
q = q.to(device)
#q_minus1 = q_minus1.to(device)
Q = Q.to(device)
Q_prod = Q_prod.to(device)
timestep = timestep.to(device)
else:
src, tgt, mask = batch
mask = mask.to(device)
# print('z', rank, device)
src = src.to(device)
# print('y', rank)
tgt = tgt.to(device)
input_mask = (src != masking_idx).float()
nonpad_mask = (src != padding_idx).float()
Expand All @@ -453,7 +428,6 @@ def step(model, batch, split):
if split == 'train':
optimizer.zero_grad()

#with torch.cuda.amp.autocast(): # TODO enable debug
if args.mask == 'blosum' or args.mask == 'random':
outputs = model(src, timestep)
lvb_loss = loss_func1(src_one_hot, q, outputs, tgt, tgt_one_hot, nonpad_mask, timestep, Q, Q_prod)
Expand All @@ -464,10 +438,7 @@ def step(model, batch, split):
accu = accu_func(outputs, tgt, nonpad_mask) * n_tokens
loss = (lvb_loss + _lambda * ce_loss) * n_tokens
elif args.mask == 'autoreg':
#print(src.shape)
#import pdb; pdb.set_trace()
outputs = model(src)
#print(outputs.shape)
ce_loss, nll_loss = loss_func(outputs, tgt, mask, nonpad_mask)
loss = ce_loss
accu = accu_func(outputs, tgt, mask) * n_tokens
Expand All @@ -481,32 +452,19 @@ def step(model, batch, split):
skip_scheduler = (scale > scaler.get_scale())
if not skip_scheduler:
scheduler.step()
# # remove mixed precision for debugging TODO delete
# loss.backward()
# _ = clip_grad_norm_(model.parameters(), clip)
# optimizer.step()
# scheduler.step()

n_seqs = torch.tensor(len(src), device=device)
return loss, nll_loss, accu, n_tokens, n_seqs, n_processed

n_parameters = sum(p.numel() for p in model.parameters())
if rank == 0:
print('%d model parameters' % n_parameters)
#print('%d training sequences' % len(len_train))
#print('%d validation sequences' % len(len_valid))
for e in range(initial_epoch, epochs):
print("epoch: ", e + 1, rank)
#train_sortish_sampler.set_epoch(e + 1)
s, t = epoch(model, e, split='train', current_step=total_steps, current_tokens=total_tokens)
total_steps += s
total_tokens += t

# writer.flush()
# writer.close()

# _, _ = epoch(model, e, split='test', current_step=total_steps, current_tokens=total_tokens)


if __name__ == '__main__':
main()
Loading

0 comments on commit 776b5cd

Please sign in to comment.