From 6d37153afe09f7684381ce56e8366675e22833e9 Mon Sep 17 00:00:00 2001
From: Wengong Jin
Date: Wed, 21 Apr 2021 12:07:20 -0400
Subject: [PATCH] molecule generation code ready

---
 README.md           |  51 +++++++++++++++-----
 cond_decode.py      |  64 -------------------------
 decode.py           |  71 ----------------------------
 generate.py         |  88 ++++++++--------------------
 get_vocab.py        |  23 ++++-----
 hgraph/__init__.py  |   3 +-
 hgraph/dataset.py   |   3 ++
 train_generator.py  | 113 ++++++++++++++++++++++++++++++++++++++++++++
 train_translator.py | 110 ++++++++++++++++++++++++++++++++++++++++++
 translate.py        |  97 ++++++++++++-------------------------
 10 files changed, 326 insertions(+), 297 deletions(-)
 delete mode 100755 cond_decode.py
 delete mode 100755 decode.py
 create mode 100755 train_generator.py
 create mode 100755 train_translator.py

diff --git a/README.md b/README.md
index 73d3267..81fd5ff 100644
--- a/README.md
+++ b/README.md
@@ -6,41 +6,66 @@ Our paper is at https://arxiv.org/pdf/2002.03230.pdf
 First install the dependencies via conda:
  * PyTorch >= 1.0.0
  * networkx
- * RDKit
+ * RDKit >= 2019.03
  * numpy
  * Python >= 3.6
 And then run `pip install .`
 
-## Molecule Generation
-The molecule generation code is in the `generation/` folder.
+## Data Format
+* For graph generation, each line of a training file is a SMILES string of a molecule.
+* For graph translation, each line of a training file is a pair of molecules (molA, molB) that are similar to each other, where molB has better chemical properties. Please see `data/qed/train_pairs.txt`. The test file is a list of molecules to be optimized. Please see `data/qed/test.txt`.
 
-## Graph translation Data Format
-* The training file should contain pairs of molecules (molA, molB) that are similar to each other but molB has better chemical properties. Please see `data/qed/train_pairs.txt`.
-* The test file is a list of molecules to be optimized. Please see `data/qed/test.txt`.
+## Graph generation training procedure
+1. Extract substructure vocabulary from a given set of molecules:
+```
+python get_vocab.py --ncpu 16 < data/qed/mols.txt > vocab.txt
+```
+
+2. Preprocess training data:
+```
+python preprocess.py --train data/qed/mols.txt --vocab data/qed/vocab.txt --ncpu 16 --mode single
+mkdir train_processed
+mv tensor* train_processed/
+```
+
+3. Train the graph generation model:
+```
+mkdir -p ckpt/generation
+python train_generator.py --train train_processed/ --vocab data/qed/vocab.txt --save_dir ckpt/generation
+```
+
+4. Sample molecules from a model checkpoint:
+```
+python generate.py --vocab data/qed/vocab.txt --model ckpt/generation/model.5 --nsample 1000
+```
 
 ## Graph translation training procedure
 1. Extract substructure vocabulary from a given set of molecules:
 ```
-python get_vocab.py < data/qed/mols.txt > vocab.txt
+python get_vocab.py --ncpu 16 < data/qed/mols.txt > vocab.txt
 ```
-Please replace `data/qed/mols.txt` with your molecules data file.
+Please replace `data/qed/mols.txt` with your own molecule file.
 
 2. Preprocess training data:
 ```
-python preprocess.py --train data/qed/train_pairs.txt --vocab data/qed/vocab.txt --ncpu 16 < data/qed/train_pairs.txt
+python preprocess.py --train data/qed/train_pairs.txt --vocab data/qed/vocab.txt --ncpu 16
 mkdir train_processed
 mv tensor* train_processed/
 ```
-Please replace `--train` and `--vocab` with training and vocab file.
 
 3. Train the model:
 ```
-mkdir models/
-python gnn_train.py --train train_processed/ --vocab data/qed/vocab.txt --save_dir models/
+mkdir -p ckpt/translation
+python train_translator.py --train train_processed/ --vocab data/qed/vocab.txt --save_dir ckpt/translation
 ```
 
 4. Make prediction on your lead compounds (you can use any model checkpoint, here we use model.5 for illustration)
 ```
-python decode.py --test data/qed/valid.txt --vocab data/qed/vocab.txt --model models/model.5 --num_decode 20 > results.csv
+python translate.py --test data/qed/valid.txt --vocab data/qed/vocab.txt --model ckpt/translation/model.5 --num_decode 20 > results.csv
 ```
+
+## Polymer generation
+The polymer generation code is in the `polymer/` folder. It is similar to `train_generator.py`, but its substructures are tailored for polymers.
+For generating regular drug-like molecules, we recommend using `train_generator.py` in the root directory.
+
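The Data Format section above expects each graph-translation training line to hold two whitespace-separated SMILES strings (molA molB); `get_vocab.py` simply reads the first two fields of every line. A minimal sanity check for such a pair file is sketched below; it is not part of this patch, and the script name and file path are only illustrative.
```
# check_pairs.py -- hypothetical helper, not part of this patch.
# Verifies that every line of a translation pair file holds two parsable SMILES
# strings, matching the format described in the README above.
import sys
from rdkit import Chem

def count_bad_lines(path):
    bad = 0
    with open(path) as f:
        for lineno, line in enumerate(f, 1):
            fields = line.split()
            # A line is malformed if it has fewer than two fields or either SMILES fails to parse.
            if len(fields) < 2 or any(Chem.MolFromSmiles(s) is None for s in fields[:2]):
                print(f"malformed line {lineno}: {line.strip()}", file=sys.stderr)
                bad += 1
    return bad

if __name__ == "__main__":
    sys.exit(1 if count_bad_lines(sys.argv[1]) else 0)
```
Running something like `python check_pairs.py data/qed/train_pairs.txt` before `preprocess.py` catches malformed lines early.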
diff --git a/cond_decode.py b/cond_decode.py deleted file mode 100755 index 2f2d8a2..0000000 --- a/cond_decode.py +++ /dev/null @@ -1,64 +0,0 @@ -import torch -import torch.nn as nn -import torch.optim as optim -import torch.optim.lr_scheduler as lr_scheduler -from torch.utils.data import DataLoader - -import math, random, sys -import numpy as np -import argparse - -from hgraph import * -import rdkit - -lg = rdkit.RDLogger.logger() -lg.setLevel(rdkit.RDLogger.CRITICAL) - -parser = argparse.ArgumentParser() -parser.add_argument('--test', required=True) -parser.add_argument('--vocab', required=True) -parser.add_argument('--atom_vocab', default=common_atom_vocab) -parser.add_argument('--model', required=True) - -parser.add_argument('--num_decode', type=int, default=20) -parser.add_argument('--enum_root', action='store_true') -parser.add_argument('--cond', type=str, default="1,0,1,0") -parser.add_argument('--seed', type=int, default=1) - -parser.add_argument('--rnn_type', type=str, default='LSTM') -parser.add_argument('--hidden_size', type=int, default=300) -parser.add_argument('--embed_size', type=int, default=300) -parser.add_argument('--batch_size', type=int, default=1) -parser.add_argument('--latent_size', type=int, default=4) -parser.add_argument('--depthT', type=int, default=20) -parser.add_argument('--depthG', type=int, default=20) -parser.add_argument('--diterT', type=int, default=1) -parser.add_argument('--diterG', type=int, default=3) -parser.add_argument('--dropout', type=float, default=0.0) - -args = parser.parse_args() - -args.test = [line.strip("\r\n ") for line in open(args.test)] -vocab = [x.strip("\r\n ").split() for x in open(args.vocab)] -args.vocab = PairVocab(vocab) - -assert args.cond in ['1,0,1,0', '0,1,1,0', '1,0,0,1'] -cond = map(float, args.cond.split(',')) -args.cond_size = len(cond) -cond = torch.tensor(cond).cuda() - -model = HierCondVGNN(args).cuda() -model.load_state_dict(torch.load(args.model)) -model.eval() - -dataset = MolEnumRootDataset(args.test, args.vocab, args.atom_vocab) -loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=lambda x:x[0]) - -torch.manual_seed(args.seed) -with torch.no_grad(): - for i,batch in enumerate(loader): - new_mols = model.translate(batch[1], cond, args.num_decode, args.enum_root) - smiles = args.test[i] - for k in xrange(args.num_decode): - print smiles, new_mols[k] - diff --git a/decode.py b/decode.py deleted file mode 100755 index 5aac9d9..0000000 --- a/decode.py +++ /dev/null @@ -1,71 +0,0 @@ -import torch -import torch.nn as nn -import torch.optim as optim
-import torch.optim.lr_scheduler as lr_scheduler -from torch.utils.data import DataLoader - -import math, random, sys -import numpy as np -import argparse - -from hgraph import * -import rdkit - -lg = rdkit.RDLogger.logger() -lg.setLevel(rdkit.RDLogger.CRITICAL) - -parser = argparse.ArgumentParser() -parser.add_argument('--test', required=True) -parser.add_argument('--vocab', required=True) -parser.add_argument('--atom_vocab', default=common_atom_vocab) -parser.add_argument('--model', required=True) - -parser.add_argument('--num_decode', type=int, default=20) -parser.add_argument('--sample', action='store_true') -parser.add_argument('--novi', action='store_true') -parser.add_argument('--seed', type=int, default=1) - -parser.add_argument('--rnn_type', type=str, default='LSTM') -parser.add_argument('--hidden_size', type=int, default=270) -parser.add_argument('--embed_size', type=int, default=270) -parser.add_argument('--batch_size', type=int, default=1) -parser.add_argument('--latent_size', type=int, default=4) -parser.add_argument('--depthT', type=int, default=20) -parser.add_argument('--depthG', type=int, default=20) -parser.add_argument('--diterT', type=int, default=1) -parser.add_argument('--diterG', type=int, default=3) -parser.add_argument('--dropout', type=float, default=0.0) - -args = parser.parse_args() -args.enum_root = True -args.greedy = not args.sample - -args.test = [line.strip("\r\n ") for line in open(args.test)] -vocab = [x.strip("\r\n ").split() for x in open(args.vocab)] -args.vocab = PairVocab(vocab) - -if args.novi: - model = HierGNN(args).cuda() -else: - model = HierVGNN(args).cuda() - -model.load_state_dict(torch.load(args.model)) -model.eval() - -dataset = MolEnumRootDataset(args.test, args.vocab, args.atom_vocab) -loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=lambda x:x[0]) - -torch.manual_seed(args.seed) -random.seed(args.seed) - -with torch.no_grad(): - for i,batch in enumerate(loader): - smiles = args.test[i] - if batch is None: - for k in range(args.num_decode): - print(smiles, smiles) - else: - new_mols = model.translate(batch[1], args.num_decode, args.enum_root, args.greedy) - for k in range(args.num_decode): - print(smiles, new_mols[k]) - diff --git a/generate.py b/generate.py index 2b28c09..88e5670 100755 --- a/generate.py +++ b/generate.py @@ -7,6 +7,7 @@ import math, random, sys import numpy as np import argparse +from tqdm import tqdm from hgraph import * import rdkit @@ -15,89 +16,40 @@ lg.setLevel(rdkit.RDLogger.CRITICAL) parser = argparse.ArgumentParser() -parser.add_argument('--train', required=True) parser.add_argument('--vocab', required=True) parser.add_argument('--atom_vocab', default=common_atom_vocab) -parser.add_argument('--save_dir', required=True) -parser.add_argument('--load_epoch', type=int, default=-1) +parser.add_argument('--model', required=True) + +parser.add_argument('--seed', type=int, default=7) +parser.add_argument('--nsample', type=int, default=10000) parser.add_argument('--rnn_type', type=str, default='LSTM') -parser.add_argument('--hidden_size', type=int, default=260) -parser.add_argument('--embed_size', type=int, default=260) -parser.add_argument('--batch_size', type=int, default=20) -parser.add_argument('--latent_size', type=int, default=24) -parser.add_argument('--depthT', type=int, default=20) -parser.add_argument('--depthG', type=int, default=20) +parser.add_argument('--hidden_size', type=int, default=250) +parser.add_argument('--embed_size', type=int, default=250) 
+parser.add_argument('--batch_size', type=int, default=50) +parser.add_argument('--latent_size', type=int, default=32) +parser.add_argument('--depthT', type=int, default=15) +parser.add_argument('--depthG', type=int, default=15) parser.add_argument('--diterT', type=int, default=1) parser.add_argument('--diterG', type=int, default=3) parser.add_argument('--dropout', type=float, default=0.0) -parser.add_argument('--lr', type=float, default=1e-3) -parser.add_argument('--clip_norm', type=float, default=20.0) -parser.add_argument('--beta', type=float, default=0.3) - -parser.add_argument('--epoch', type=int, default=20) -parser.add_argument('--anneal_rate', type=float, default=0.9) -parser.add_argument('--print_iter', type=int, default=50) -parser.add_argument('--save_iter', type=int, default=-1) - args = parser.parse_args() -print(args) vocab = [x.strip("\r\n ").split() for x in open(args.vocab)] args.vocab = PairVocab(vocab) model = HierVAE(args).cuda() -for param in model.parameters(): - if param.dim() == 1: - nn.init.constant_(param, 0) - else: - nn.init.xavier_normal_(param) - -if args.load_epoch >= 0: - model.load_state_dict(torch.load(args.save_dir + "/model." + str(args.load_epoch))) - -print("Model #Params: %dK" % (sum([x.nelement() for x in model.parameters()]) / 1000,)) - -optimizer = optim.Adam(model.parameters(), lr=args.lr) -scheduler = lr_scheduler.ExponentialLR(optimizer, args.anneal_rate) - -param_norm = lambda m: math.sqrt(sum([p.norm().item() ** 2 for p in m.parameters()])) -grad_norm = lambda m: math.sqrt(sum([p.grad.norm().item() ** 2 for p in m.parameters() if p.grad is not None])) - -total_step = 0 -beta = args.beta -meters = np.zeros(6) - -for epoch in range(args.load_epoch + 1, args.epoch): - dataset = DataFolder(args.train, args.batch_size) - - for batch in dataset: - total_step += 1 - model.zero_grad() - loss, kl_div, wacc, iacc, tacc, sacc = model(*batch, beta=beta) - - loss.backward() - nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm) - optimizer.step() +model.load_state_dict(torch.load(args.model)) +model.eval() - meters = meters + np.array([kl_div, loss.item(), wacc * 100, iacc * 100, tacc * 100, sacc * 100]) +torch.manual_seed(args.seed) +random.seed(args.seed) - if total_step % args.print_iter == 0: - meters /= args.print_iter - print("[%d] Beta: %.3f, KL: %.2f, loss: %.3f, Word: %.2f, %.2f, Topo: %.2f, Assm: %.2f, PNorm: %.2f, GNorm: %.2f" % (total_step, beta, meters[0], meters[1], meters[2], meters[3], meters[4], meters[5], param_norm(model), grad_norm(model))) - sys.stdout.flush() - meters *= 0 - - if args.save_iter >= 0 and total_step % args.save_iter == 0: - n_iter = total_step // args.save_iter - 1 - torch.save(model.state_dict(), args.save_dir + "/model." + str(n_iter)) - scheduler.step() - print("learning rate: %.6f" % scheduler.get_lr()[0]) +with torch.no_grad(): + for _ in tqdm(range(args.nsample // args.batch_size)): + smiles_list = model.sample(args.batch_size) + for _,smiles in enumerate(smiles_list): + print(smiles) - del dataset - if args.save_iter == -1: - torch.save(model.state_dict(), args.save_dir + "/model." 
+ str(epoch)) - scheduler.step() - print("learning rate: %.6f" % scheduler.get_lr()[0]) diff --git a/get_vocab.py b/get_vocab.py index b49d7a9..82a5928 100755 --- a/get_vocab.py +++ b/get_vocab.py @@ -1,8 +1,8 @@ import sys +import argparse from hgraph import * from rdkit import Chem from multiprocessing import Pool -from collections import Counter def process(data): vocab = set() @@ -11,26 +11,27 @@ def process(data): hmol = MolGraph(s) for node,attr in hmol.mol_tree.nodes(data=True): smiles = attr['smiles'] - vocab[attr['label']] += 1 + vocab.add( attr['label'] ) for i,s in attr['inter_label']: - vocab[(smiles, s)] += 1 + vocab.add( (smiles, s) ) return vocab if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--ncpu', type=int, default=1) + args = parser.parse_args() + data = [mol for line in sys.stdin for mol in line.split()[:2]] data = list(set(data)) - ncpu = 15 - batch_size = len(data) // ncpu + 1 + batch_size = len(data) // args.ncpu + 1 batches = [data[i : i + batch_size] for i in range(0, len(data), batch_size)] - pool = Pool(ncpu) + pool = Pool(args.ncpu) vocab_list = pool.map(process, batches) + vocab = [(x,y) for vocab in vocab_list for x,y in vocab] + vocab = list(set(vocab)) - vocab = Counter() - for c in vocab_list: - vocab |= c - - for (x,y),c in vocab: + for x,y in sorted(vocab): print(x, y) diff --git a/hgraph/__init__.py b/hgraph/__init__.py index fb2af21..5bb18bd 100644 --- a/hgraph/__init__.py +++ b/hgraph/__init__.py @@ -2,6 +2,5 @@ from hgraph.encoder import HierMPNEncoder from hgraph.decoder import HierMPNDecoder from hgraph.vocab import Vocab, PairVocab, common_atom_vocab -from hgraph.hgnn import HierGNN, HierVGNN, HierCondVGNN +from hgraph.hgnn import HierVAE, HierVGNN, HierCondVGNN from hgraph.dataset import MoleculeDataset, MolPairDataset, DataFolder, MolEnumRootDataset -from hgraph.stereo import restore_stereo diff --git a/hgraph/dataset.py b/hgraph/dataset.py index c0474b7..98f5a07 100644 --- a/hgraph/dataset.py +++ b/hgraph/dataset.py @@ -74,6 +74,9 @@ def __init__(self, data_folder, batch_size, shuffle=True): self.batch_size = batch_size self.shuffle = shuffle + def __len__(self): + return len(self.data_files) * 1000 + def __iter__(self): for fn in self.data_files: fn = os.path.join(self.data_folder, fn) diff --git a/train_generator.py b/train_generator.py new file mode 100755 index 0000000..bbb3e2f --- /dev/null +++ b/train_generator.py @@ -0,0 +1,113 @@ +import torch +import torch.nn as nn +import torch.optim as optim +import torch.optim.lr_scheduler as lr_scheduler +from torch.utils.data import DataLoader + +import rdkit +import math, random, sys +import numpy as np +import argparse +import os +from tqdm.auto import tqdm + +from hgraph import * + +lg = rdkit.RDLogger.logger() +lg.setLevel(rdkit.RDLogger.CRITICAL) + +parser = argparse.ArgumentParser() +parser.add_argument('--train', required=True) +parser.add_argument('--vocab', required=True) +parser.add_argument('--atom_vocab', default=common_atom_vocab) +parser.add_argument('--save_dir', required=True) +parser.add_argument('--load_model', default=None) +parser.add_argument('--seed', type=int, default=7) + +parser.add_argument('--rnn_type', type=str, default='LSTM') +parser.add_argument('--hidden_size', type=int, default=250) +parser.add_argument('--embed_size', type=int, default=250) +parser.add_argument('--batch_size', type=int, default=50) +parser.add_argument('--latent_size', type=int, default=32) +parser.add_argument('--depthT', type=int, default=15) 
+parser.add_argument('--depthG', type=int, default=15) +parser.add_argument('--diterT', type=int, default=1) +parser.add_argument('--diterG', type=int, default=3) +parser.add_argument('--dropout', type=float, default=0.0) + +parser.add_argument('--lr', type=float, default=1e-3) +parser.add_argument('--clip_norm', type=float, default=5.0) +parser.add_argument('--step_beta', type=float, default=0.001) +parser.add_argument('--max_beta', type=float, default=0.5) +parser.add_argument('--warmup', type=int, default=10000) +parser.add_argument('--kl_anneal_iter', type=int, default=2000) + +parser.add_argument('--epoch', type=int, default=20) +parser.add_argument('--anneal_rate', type=float, default=0.9) +parser.add_argument('--anneal_iter', type=int, default=25000) +parser.add_argument('--print_iter', type=int, default=50) +parser.add_argument('--save_iter', type=int, default=5000) + +args = parser.parse_args() +print(args) + +torch.manual_seed(args.seed) +random.seed(args.seed) + +vocab = [x.strip("\r\n ").split() for x in open(args.vocab)] +args.vocab = PairVocab(vocab) + +model = HierVAE(args).cuda() +print("Model #Params: %dK" % (sum([x.nelement() for x in model.parameters()]) / 1000,)) + +for param in model.parameters(): + if param.dim() == 1: + nn.init.constant_(param, 0) + else: + nn.init.xavier_normal_(param) + +optimizer = optim.Adam(model.parameters(), lr=args.lr) +scheduler = lr_scheduler.ExponentialLR(optimizer, args.anneal_rate) + +if args.load_model: + print('continuing from checkpoint ' + args.load_model) + model_state, optimizer_state, total_step, beta = torch.load(args.load_model) + model.load_state_dict(model_state) + optimizer.load_state_dict(optimizer_state) +else: + total_step = beta = 0 + +param_norm = lambda m: math.sqrt(sum([p.norm().item() ** 2 for p in m.parameters()])) +grad_norm = lambda m: math.sqrt(sum([p.grad.norm().item() ** 2 for p in m.parameters() if p.grad is not None])) + +meters = np.zeros(6) +for epoch in range(args.epoch): + dataset = DataFolder(args.train, args.batch_size) + + for batch in tqdm(dataset): + total_step += 1 + model.zero_grad() + loss, kl_div, wacc, iacc, tacc, sacc = model(*batch, beta=beta) + + loss.backward() + nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm) + optimizer.step() + + meters = meters + np.array([kl_div, loss.item(), wacc * 100, iacc * 100, tacc * 100, sacc * 100]) + + if total_step % args.print_iter == 0: + meters /= args.print_iter + print("[%d] Beta: %.3f, KL: %.2f, loss: %.3f, Word: %.2f, %.2f, Topo: %.2f, Assm: %.2f, PNorm: %.2f, GNorm: %.2f" % (total_step, beta, meters[0], meters[1], meters[2], meters[3], meters[4], meters[5], param_norm(model), grad_norm(model))) + sys.stdout.flush() + meters *= 0 + + if total_step % args.save_iter == 0: + ckpt = (model.state_dict(), optimizer.state_dict(), total_step, beta) + torch.save(ckpt, os.path.join(args.save_dir, f"model.ckpt.{total_step}")) + + if total_step % args.anneal_iter == 0: + scheduler.step() + print("learning rate: %.6f" % scheduler.get_lr()[0]) + + if total_step >= args.warmup and total_step % args.kl_anneal_iter == 0: + beta = min(args.max_beta, beta + args.step_beta) diff --git a/train_translator.py b/train_translator.py new file mode 100755 index 0000000..0da8c27 --- /dev/null +++ b/train_translator.py @@ -0,0 +1,110 @@ +import torch +import torch.nn as nn +import torch.optim as optim +import torch.optim.lr_scheduler as lr_scheduler +from torch.utils.data import DataLoader + +import math, random, sys +import numpy as np +import argparse + +from hgraph 
import * +import rdkit + +lg = rdkit.RDLogger.logger() +lg.setLevel(rdkit.RDLogger.CRITICAL) + +parser = argparse.ArgumentParser() +parser.add_argument('--train', required=True) +parser.add_argument('--vocab', required=True) +parser.add_argument('--atom_vocab', default=common_atom_vocab) +parser.add_argument('--save_dir', required=True) +parser.add_argument('--load_epoch', type=int, default=-1) + +parser.add_argument('--conditional', action='store_true') +parser.add_argument('--cond_size', type=int, default=4) + +parser.add_argument('--rnn_type', type=str, default='LSTM') +parser.add_argument('--hidden_size', type=int, default=270) +parser.add_argument('--embed_size', type=int, default=270) +parser.add_argument('--batch_size', type=int, default=32) +parser.add_argument('--latent_size', type=int, default=4) +parser.add_argument('--depthT', type=int, default=20) +parser.add_argument('--depthG', type=int, default=20) +parser.add_argument('--diterT', type=int, default=1) +parser.add_argument('--diterG', type=int, default=3) +parser.add_argument('--dropout', type=float, default=0.0) + +parser.add_argument('--lr', type=float, default=1e-3) +parser.add_argument('--clip_norm', type=float, default=20.0) +parser.add_argument('--beta', type=float, default=0.3) + +parser.add_argument('--epoch', type=int, default=12) +parser.add_argument('--anneal_rate', type=float, default=0.9) +parser.add_argument('--print_iter', type=int, default=50) +parser.add_argument('--save_iter', type=int, default=-1) + +args = parser.parse_args() +print(args) + +vocab = [x.strip("\r\n ").split() for x in open(args.vocab)] +args.vocab = PairVocab(vocab) + +if args.conditional: + model = HierCondVGNN(args).cuda() +else: + model = HierVGNN(args).cuda() + +for param in model.parameters(): + if param.dim() == 1: + nn.init.constant_(param, 0) + else: + nn.init.xavier_normal_(param) + +if args.load_epoch >= 0: + model.load_state_dict(torch.load(args.save_dir + "/model." + str(args.load_epoch))) + +print("Model #Params: %dK" % (sum([x.nelement() for x in model.parameters()]) / 1000,)) + +optimizer = optim.Adam(model.parameters(), lr=args.lr) +scheduler = lr_scheduler.ExponentialLR(optimizer, args.anneal_rate) + +param_norm = lambda m: math.sqrt(sum([p.norm().item() ** 2 for p in m.parameters()])) +grad_norm = lambda m: math.sqrt(sum([p.grad.norm().item() ** 2 for p in m.parameters() if p.grad is not None])) + +total_step = 0 +beta = args.beta +meters = np.zeros(6) + +for epoch in range(args.load_epoch + 1, args.epoch): + dataset = DataFolder(args.train, args.batch_size) + + for batch in dataset: + total_step += 1 + batch = batch + (beta,) + model.zero_grad() + loss, kl_div, wacc, iacc, tacc, sacc = model(*batch) + + loss.backward() + nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm) + optimizer.step() + + meters = meters + np.array([kl_div, loss.item(), wacc * 100, iacc * 100, tacc * 100, sacc * 100]) + + if total_step % args.print_iter == 0: + meters /= args.print_iter + print("[%d] Beta: %.3f, KL: %.2f, loss: %.3f, Word: %.2f, %.2f, Topo: %.2f, Assm: %.2f, PNorm: %.2f, GNorm: %.2f" % (total_step, beta, meters[0], meters[1], meters[2], meters[3], meters[4], meters[5], param_norm(model), grad_norm(model))) + sys.stdout.flush() + meters *= 0 + + if args.save_iter >= 0 and total_step % args.save_iter == 0: + n_iter = total_step // args.save_iter - 1 + torch.save(model.state_dict(), args.save_dir + "/model." 
+ str(n_iter)) + scheduler.step() + print("learning rate: %.6f" % scheduler.get_lr()[0]) + + del dataset + if args.save_iter == -1: + torch.save(model.state_dict(), args.save_dir + "/model." + str(epoch)) + scheduler.step() + print("learning rate: %.6f" % scheduler.get_lr()[0]) diff --git a/translate.py b/translate.py index 0da8c27..5aac9d9 100755 --- a/translate.py +++ b/translate.py @@ -15,19 +15,20 @@ lg.setLevel(rdkit.RDLogger.CRITICAL) parser = argparse.ArgumentParser() -parser.add_argument('--train', required=True) +parser.add_argument('--test', required=True) parser.add_argument('--vocab', required=True) parser.add_argument('--atom_vocab', default=common_atom_vocab) -parser.add_argument('--save_dir', required=True) -parser.add_argument('--load_epoch', type=int, default=-1) +parser.add_argument('--model', required=True) -parser.add_argument('--conditional', action='store_true') -parser.add_argument('--cond_size', type=int, default=4) +parser.add_argument('--num_decode', type=int, default=20) +parser.add_argument('--sample', action='store_true') +parser.add_argument('--novi', action='store_true') +parser.add_argument('--seed', type=int, default=1) parser.add_argument('--rnn_type', type=str, default='LSTM') parser.add_argument('--hidden_size', type=int, default=270) parser.add_argument('--embed_size', type=int, default=270) -parser.add_argument('--batch_size', type=int, default=32) +parser.add_argument('--batch_size', type=int, default=1) parser.add_argument('--latent_size', type=int, default=4) parser.add_argument('--depthT', type=int, default=20) parser.add_argument('--depthG', type=int, default=20) @@ -35,76 +36,36 @@ parser.add_argument('--diterG', type=int, default=3) parser.add_argument('--dropout', type=float, default=0.0) -parser.add_argument('--lr', type=float, default=1e-3) -parser.add_argument('--clip_norm', type=float, default=20.0) -parser.add_argument('--beta', type=float, default=0.3) - -parser.add_argument('--epoch', type=int, default=12) -parser.add_argument('--anneal_rate', type=float, default=0.9) -parser.add_argument('--print_iter', type=int, default=50) -parser.add_argument('--save_iter', type=int, default=-1) - args = parser.parse_args() -print(args) +args.enum_root = True +args.greedy = not args.sample +args.test = [line.strip("\r\n ") for line in open(args.test)] vocab = [x.strip("\r\n ").split() for x in open(args.vocab)] -args.vocab = PairVocab(vocab) +args.vocab = PairVocab(vocab) -if args.conditional: - model = HierCondVGNN(args).cuda() +if args.novi: + model = HierGNN(args).cuda() else: model = HierVGNN(args).cuda() -for param in model.parameters(): - if param.dim() == 1: - nn.init.constant_(param, 0) - else: - nn.init.xavier_normal_(param) - -if args.load_epoch >= 0: - model.load_state_dict(torch.load(args.save_dir + "/model." 
+ str(args.load_epoch))) - -print("Model #Params: %dK" % (sum([x.nelement() for x in model.parameters()]) / 1000,)) - -optimizer = optim.Adam(model.parameters(), lr=args.lr) -scheduler = lr_scheduler.ExponentialLR(optimizer, args.anneal_rate) - -param_norm = lambda m: math.sqrt(sum([p.norm().item() ** 2 for p in m.parameters()])) -grad_norm = lambda m: math.sqrt(sum([p.grad.norm().item() ** 2 for p in m.parameters() if p.grad is not None])) - -total_step = 0 -beta = args.beta -meters = np.zeros(6) - -for epoch in range(args.load_epoch + 1, args.epoch): - dataset = DataFolder(args.train, args.batch_size) - - for batch in dataset: - total_step += 1 - batch = batch + (beta,) - model.zero_grad() - loss, kl_div, wacc, iacc, tacc, sacc = model(*batch) +model.load_state_dict(torch.load(args.model)) +model.eval() - loss.backward() - nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm) - optimizer.step() +dataset = MolEnumRootDataset(args.test, args.vocab, args.atom_vocab) +loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=lambda x:x[0]) - meters = meters + np.array([kl_div, loss.item(), wacc * 100, iacc * 100, tacc * 100, sacc * 100]) +torch.manual_seed(args.seed) +random.seed(args.seed) - if total_step % args.print_iter == 0: - meters /= args.print_iter - print("[%d] Beta: %.3f, KL: %.2f, loss: %.3f, Word: %.2f, %.2f, Topo: %.2f, Assm: %.2f, PNorm: %.2f, GNorm: %.2f" % (total_step, beta, meters[0], meters[1], meters[2], meters[3], meters[4], meters[5], param_norm(model), grad_norm(model))) - sys.stdout.flush() - meters *= 0 - - if args.save_iter >= 0 and total_step % args.save_iter == 0: - n_iter = total_step // args.save_iter - 1 - torch.save(model.state_dict(), args.save_dir + "/model." + str(n_iter)) - scheduler.step() - print("learning rate: %.6f" % scheduler.get_lr()[0]) +with torch.no_grad(): + for i,batch in enumerate(loader): + smiles = args.test[i] + if batch is None: + for k in range(args.num_decode): + print(smiles, smiles) + else: + new_mols = model.translate(batch[1], args.num_decode, args.enum_root, args.greedy) + for k in range(args.num_decode): + print(smiles, new_mols[k]) - del dataset - if args.save_iter == -1: - torch.save(model.state_dict(), args.save_dir + "/model." + str(epoch)) - scheduler.step() - print("learning rate: %.6f" % scheduler.get_lr()[0])
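One wiring detail to keep in mind when using the scripts in this patch together: `train_generator.py` saves its periodic checkpoints as a tuple `(model_state, optimizer_state, total_step, beta)` under names like `model.ckpt.<step>`, while `generate.py` passes the result of `torch.load(args.model)` straight to `load_state_dict`, which expects a bare state dict. A small loader that accepts either form is sketched below; the helper name is ours and it is not part of this patch.
```
# load_vae_state.py -- hypothetical helper, not part of this patch.
# Accepts either a bare state dict (what generate.py currently loads) or the
# (model_state, optimizer_state, total_step, beta) tuple written by train_generator.py.
import torch

def load_vae_state(model, path, device="cuda"):
    ckpt = torch.load(path, map_location=device)
    # Tuple checkpoints store the model state dict in the first slot.
    state = ckpt[0] if isinstance(ckpt, tuple) else ckpt
    model.load_state_dict(state)
    model.eval()
    return model
```
Unpacking the tuple this way (or re-saving just `model.state_dict()`) lets the README's sampling command run against checkpoints produced by `train_generator.py`.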