added parameter file
nirajverma288 committed May 28, 2019
1 parent 047bd4a commit 9cf8558
Showing 4 changed files with 146 additions and 50 deletions.
28 changes: 20 additions & 8 deletions denovo.py
@@ -10,11 +10,25 @@


class environ_grid:
def __init__(self, pdb, name, RENDER = 0, test = 0):
def __init__(self, pdb,
name,
RENDER = 0,
test = 0,
track = 5,
fcounts = 6,
bcount = -1):
self.name = name
self.test = test
self.RENDER = RENDER
self.SYNC_TARGET_FRAMES = 100

# how far to look into the future
self.fcounts = fcounts

# bcount is how far to look backward for the reward (-1 denotes all)
self.bcount = bcount #5

self.res_track = track # how many residue coordinates from the generated sequence to include in the state (-1 denotes all)
if 'proteins' not in os.listdir('.'):
raise Exception('No folder named proteins found !')

@@ -65,14 +79,10 @@ def initialize(self):
if not self.test:
np.save('models/res_d.npy', self.res_d)

# how much to look in future
self.fcounts = 6

# bcount is how much to go backward for reward
self.bcount = -1#5

#print ('i', self.igrid)
# initial grid
print (self.res_d)
print ('\nUnique residues :',self.res_d)

def make_ohe(self):
l = self.nres
@@ -127,6 +137,8 @@ def init_args(self):
for i in range (len(self.pdb_files)):
if len(self.fcords[i]) != len(self.res_arrs[i]):
raise Exception('Multiple chains detected in protein : '+self.pdb_files[i])
if self.res_track == -1:
self.res_track = max([len(i) for i in self.res_arrs])
state = self.reset()
l = state.shape[0]
self.obs_size = l
@@ -265,7 +277,7 @@ def state(self):
lis = np.concatenate((np.array(l_temp), lis))

# last n coordinates of residues
n = 5
n = self.res_track
n_tem = np.zeros(n*4)

for i in range (n):
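With the constructor extended, the look-ahead, look-back, and tracking windows are set per run instead of being hardcoded in initialize(). A minimal instantiation sketch (values mirror the new defaults; '1k43.pdb' is the example protein used elsewhere in this commit):

from denovo import environ_grid

# a 'proteins' folder must exist in the working directory (checked in __init__)
env = environ_grid('1k43.pdb', '1k43',
                   RENDER = 0,    # no visualization
                   test = 0,      # training mode; the residue dictionary is saved
                   track = 5,     # residue coordinates kept in the state; -1 denotes all
                   fcounts = 6,   # how far to look into the future
                   bcount = -1)   # how far to look backward for the reward; -1 denotes all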
70 changes: 51 additions & 19 deletions idqn.py
@@ -7,6 +7,7 @@
import torch.optim as optim

from denovo import environ_grid
import sys

class DQN(nn.Module):
def __init__(self, obs_size, hidden_size, n_actions):
@@ -122,35 +123,62 @@ def sample(self, batch_size):
return np.array(states), np.array(actions), np.array(rewards, dtype=np.float32), \
np.array(dones, dtype=np.uint8), np.array(next_states)


DEFAULT_ENV_NAME = "1k43"
MEAN_REWARD_BOUND = -3.0

env = environ_grid('1k43.pdb', DEFAULT_ENV_NAME, 0)

GAMMA = 0.99
BATCH_SIZE = 32
REPLAY_SIZE = 10000
LEARNING_RATE = 1e-4
SYNC_TARGET_FRAMES = 1000
REPLAY_START_SIZE = 10000

EPSILON_DECAY_LAST_FRAME = 10**6
EPSILON_START = 1.0
EPSILON_FINAL = 0.05
def read_inp():
inp = 'inp'
if len(sys.argv) > 1:
inp = sys.argv[1]
f=open(inp,'r')
lines=f.readlines()
f.close()
dic={}
for line in lines:
if '#' in line or len(line.strip().split())==0:
continue
a,b=line.strip().split()
dic[a]=b
return dic

params = read_inp()
for i in params:
print (i, params[i])

DEFAULT_ENV_NAME = params['DEFAULT_ENV_NAME']#"1k43"
MEAN_REWARD_BOUND = eval(params['MEAN_REWARD_BOUND'])#-3.0
RENDER = eval(params['RENDER'])#0

FCOUNTS = eval(params['FCOUNTS'])#10
BCOUNT = eval(params['BCOUNT'])#-1
TRACK = eval(params['TRACK'])#5 # how many residue coordinates from the generated sequence to include in the state

env = environ_grid('1k43.pdb', DEFAULT_ENV_NAME, RENDER, 0, TRACK, FCOUNTS, BCOUNT)

GAMMA = eval(params['GAMMA'])#0.99
BATCH_SIZE = eval(params['BATCH_SIZE'])#32
REPLAY_SIZE = eval(params['REPLAY_SIZE'])#10000
LEARNING_RATE = eval(params['LEARNING_RATE'])#1e-4
SYNC_TARGET_FRAMES = eval(params['SYNC_TARGET_FRAMES'])#1000
REPLAY_START_SIZE = eval(params['REPLAY_START_SIZE'])#10000

EPSILON_DECAY_LAST_FRAME = eval(params['EPSILON_DECAY_LAST_FRAME'])#10**6
EPSILON_START = eval(params['EPSILON_START'])#1.0
EPSILON_FINAL = eval(params['EPSILON_FINAL'])#0.05

MAX_ITER = eval(params['MAX_ITER'])#10**9


Experience = collections.namedtuple('Experience',
field_names=['state', 'action', 'reward', 'done', 'new_state'])


device = "cpu"
device = params['device']#'cpu'

HIDDEN_SIZE = eval(params['HIDDEN_SIZE'])#256

#env = make_env(DEFAULT_ENV_NAME)

net = DQN(env.obs_size, 256, env.n_actions).to(device)
net = DQN(env.obs_size, HIDDEN_SIZE, env.n_actions).to(device)
#net.load_state_dict(torch.load("models/" +DEFAULT_ENV_NAME + "-best.dat", map_location=lambda storage, loc: storage))
tgt_net = DQN(env.obs_size, 256, env.n_actions).to(device)
tgt_net = DQN(env.obs_size, HIDDEN_SIZE, env.n_actions).to(device)
print(net)

buffer = ExperienceBuffer(REPLAY_SIZE)
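read_inp returns every value as a raw string, and the assignments above convert them with eval, which executes whatever expression appears in the parameter file. A hardened alternative (a sketch, not part of this commit) restricts conversion to plain literals:

import ast

def parse_value(s):
    # handle '10**6'-style entries, which are not Python literals
    if '**' in s:
        base, exp = s.split('**')
        return int(base) ** int(exp)
    try:
        # covers ints, floats such as '1e-4', and negatives such as '-3.0'
        return ast.literal_eval(s)
    except (ValueError, SyntaxError):
        return s  # non-literals such as 'cpu' or 'PfRL' stay strings

With this helper, GAMMA = parse_value(params['GAMMA']) behaves like the eval version for every key in the bundled inp file.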
@@ -193,6 +221,10 @@ def sample(self, batch_size):
print("Solved in %d frames!" % frame_idx)
break

if frame_idx >= MAX_ITER:
print ('Maximum iteration reached')
break

if len(buffer) < REPLAY_START_SIZE:
continue

40 changes: 40 additions & 0 deletions inp
@@ -0,0 +1,40 @@
DEFAULT_ENV_NAME PfRL

RENDER 0

MEAN_REWARD_BOUND -3.0

GAMMA 0.99

BATCH_SIZE 32

REPLAY_SIZE 10000

LEARNING_RATE 1e-4

SYNC_TARGET_FRAMES 1000

REPLAY_START_SIZE 10000

EPSILON_DECAY_LAST_FRAME 10**6

MAX_ITER 10**9

EPSILON_START 1.0

EPSILON_FINAL 0.05

device cpu

HIDDEN_SIZE 256

# how many future residues to consider
FCOUNTS 10

# how far to look backward for the reward
# -1 denotes all
BCOUNT -1

# how many residue coordinates from the generated sequence to include in the state
# -1 denotes all
TRACK 5
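Each non-blank line is a whitespace-separated KEY VALUE pair; read_inp skips blank lines and any line containing a '#', so comments must not share a line with a value. A sketch of how idqn.py consumes the file:

params = read_inp()                    # {'DEFAULT_ENV_NAME': 'PfRL', 'GAMMA': '0.99', ...}
GAMMA = eval(params['GAMMA'])          # '0.99'  -> 0.99
MAX_ITER = eval(params['MAX_ITER'])    # '10**9' -> 1000000000
device = params['device']              # kept as the string 'cpu'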
58 changes: 35 additions & 23 deletions test_model.py
@@ -5,6 +5,7 @@
import torch
import torch.nn as nn
import torch.optim as optim
from denovo import environ_grid

class DQN(nn.Module):
def __init__(self, obs_size, hidden_size, n_actions):
@@ -28,12 +29,40 @@ def init_weights(m):
def forward(self, x):
return self.net(x)

from denovo import environ_grid
def read_inp():
inp = 'inp'
if len(sys.argv) > 1:
inp = sys.argv[1]
f=open(inp,'r')
lines=f.readlines()
f.close()
dic={}
for line in lines:
if '#' in line or len(line.strip().split())==0:
continue
a,b=line.strip().split()
dic[a]=b
return dic

params = read_inp()


DEFAULT_ENV_NAME = '1k43'
pdb = '1k43.pdb'
if len(sys.argv) > 1:
pdb = sys.argv[1]
env = environ_grid(pdb,'test', 1, 1)
if len(sys.argv) > 2:
pdb = sys.argv[2]

RENDER = 1
test = 1

DEFAULT_ENV_NAME = params['DEFAULT_ENV_NAME']
FCOUNTS = eval(params['FCOUNTS'])#10
BCOUNT = eval(params['BCOUNT'])#-1
TRACK = eval(params['TRACK'])#5

HIDDEN_SIZE = eval(params['HIDDEN_SIZE'])

env = environ_grid(pdb, DEFAULT_ENV_NAME, RENDER, test, TRACK, FCOUNTS, BCOUNT)

state = env.reset()
total_reward = 0.0
@@ -42,7 +71,7 @@ def forward(self, x):
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

test_net = DQN(env.obs_size, 256, env.n_actions)
test_net = DQN(env.obs_size, HIDDEN_SIZE, env.n_actions)
test_net.load_state_dict(torch.load("models/" +DEFAULT_ENV_NAME + "-best.dat", map_location=lambda storage, loc: storage))


@@ -57,24 +86,7 @@ def forward(self, x):
total_reward += reward
if done:
break
print("Total reward: %.2f" % total_reward)
#print("Total reward: %.2f" % total_reward)
print("Action counts:", c)

'''
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
l1 = np.array(env.current_status)
l2 = np.array(env.fcords[env.current_index])
x1, y1, z1 = l1[:,0],l1[:,1],l1[:,2]
x2, y2, z2 = l2[:,0],l2[:,1],l2[:,2]
lines1 = ax.scatter(x1, y1, z1, c = 'r', s = 100)
lines2 = ax.plot(x1, y1, z1, c = 'r')
lines3 = ax.scatter(x2, y2, z2, c = 'g', s = 100)
lines4 = ax.plot(x2, y2, z2, c = 'g')

plt.show(block = True)
'''
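The test script now reads the same parameter file as training: sys.argv[1] selects the parameter file (default 'inp') and sys.argv[2] the PDB (default '1k43.pdb'). A usage sketch with illustrative file names:

python test_model.py                    # defaults: 'inp' and 1k43.pdb
python test_model.py my.inp 1bdd.pdb    # custom parameter file and protein

Note that DEFAULT_ENV_NAME from the parameter file fixes the checkpoint path models/<DEFAULT_ENV_NAME>-best.dat, so testing must use the same name the model was trained under.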
