distributional rl
shixiaowen03 committed Dec 21, 2018
1 parent c659ae9 commit 7d03dd0
Showing 3 changed files with 19 additions and 10 deletions.
16 changes: 7 additions & 9 deletions .idea/workspace.xml

Diff not rendered (generated file).

11 changes: 10 additions & 1 deletion RL/Basic-DisRL-Demo/Categorical_DQN.py
@@ -15,11 +15,12 @@ def __init__(self,env,config):
         self.v_min = self.config.v_min
         self.atoms = self.config.atoms
 
-        self.time_step = 0
         self.epsilon = self.config.INITIAL_EPSILON
         self.state_shape = env.observation_space.shape
         self.action_dim = env.action_space.n
 
+        self.time_step = 0
+
         target_state_shape = [1]
         target_state_shape.extend(self.state_shape)
 
@@ -82,10 +83,16 @@ def build_cate_dqn_net(self):
 
         self.optimizer = tf.train.AdamOptimizer(self.config.LEARNING_RATE).minimize(self.cross_entropy_loss)
 
+        eval_params = tf.get_collection("eval_net_params")
+        target_params = tf.get_collection('target_net_params')
+
+        self.update_target_net = [tf.assign(t, e) for t, e in zip(target_params, eval_params)]
+
 
 
 
     def train(self,s,r,action,s_,gamma):
+        self.time_step += 1
         list_q_ = [self.sess.run(self.q_target,feed_dict={self.state_input:[s_],self.action_input:[[a]]}) for a in range(self.action_dim)]
         a_ = tf.argmax(list_q_).eval()
         m = np.zeros(self.atoms)
@@ -103,6 +110,8 @@ def train(self,s,r,action,s_,gamma):
         self.sess.run(self.optimizer,feed_dict={self.state_input:[s] , self.action_input:[action], self.m_input: m })
 
 
+        if self.time_step % self.config.UPDATE_TARGET_NET == 0:
+            self.sess.run(self.update_target_net)
 
     def save_model(self):
         print("Model saved in : ", self.saver.save(self.sess, self.config.MODEL_PATH))
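
Note: the @@ -103,6 +110,8 @@ hunk enters train() only after the projected target distribution m has been filled in (those lines are collapsed above; the diff shows only its initialisation, m = np.zeros(self.atoms)). As a point of reference, the following is a minimal standalone NumPy sketch of the C51 projection step that such a vector is typically built with; the function and argument names are illustrative and are not taken from this commit.

import numpy as np

# Hedged sketch, not from this repository: the standard C51 projection of the
# distributional Bellman target r + gamma * z onto a fixed support of `atoms`
# points in [v_min, v_max]. `next_probs` is the predicted probability vector
# for the greedy next action, analogous to the q_target output used above.
def project_distribution(reward, gamma, next_probs, v_min, v_max, atoms):
    delta_z = (v_max - v_min) / (atoms - 1)
    z = np.linspace(v_min, v_max, atoms)       # fixed support z_0 ... z_{N-1}
    m = np.zeros(atoms)                        # projected target distribution
    for j in range(atoms):
        # Bellman-shift atom j and clip it back onto the support
        tz_j = np.clip(reward + gamma * z[j], v_min, v_max)
        b_j = (tz_j - v_min) / delta_z         # fractional atom index of tz_j
        l, u = int(np.floor(b_j)), int(np.ceil(b_j))
        if l == u:                             # tz_j landed exactly on an atom
            m[l] += next_probs[j]
        else:                                  # split the mass between neighbours
            m[l] += next_probs[j] * (u - b_j)
            m[u] += next_probs[j] * (b_j - l)
    return m

A vector built this way plays the role of the m_input that the sess.run call above feeds into the cross-entropy loss.
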
2 changes: 2 additions & 0 deletions RL/Basic-DisRL-Demo/Config.py
@@ -20,3 +20,5 @@ class Categorical_DQN_Config():
     replay_buffer_size = 2000
     iteration = 5
     episode = 300 # 300 games per iteration
+
+
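
The new target-network code also reads self.config.UPDATE_TARGET_NET, which does not appear in the hunk shown here. Purely as an illustration of the fields the updated Categorical_DQN class touches, a config of this shape would satisfy them; every value marked as a placeholder is an assumption, and only the last three attributes are taken from this diff.

class Categorical_DQN_Config():
    v_min = -10                # placeholder value, not shown in this commit
    v_max = 10                 # placeholder value
    atoms = 51                 # placeholder; C51 conventionally uses 51 atoms
    INITIAL_EPSILON = 0.5      # placeholder value
    LEARNING_RATE = 0.001      # placeholder value
    UPDATE_TARGET_NET = 200    # placeholder; sync period checked in train()
    MODEL_PATH = './model/'    # placeholder value
    replay_buffer_size = 2000
    iteration = 5
    episode = 300              # 300 games per iteration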