Skip to content

Commit

Permalink
switched from rmsprop to adam
Browse files Browse the repository at this point in the history
  • Loading branch information
tychovdo committed Oct 25, 2017
1 parent 85073ce commit a474575
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 6 deletions.
5 changes: 3 additions & 2 deletions DQN.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def __init__(self, params):
else:
self.global_step = tf.Variable(0, name='global_step', trainable=False)

self.rmsprop = tf.train.RMSPropOptimizer(self.params['lr'],self.params['rms_decay'],0.0,self.params['rms_eps']).minimize(self.cost,global_step=self.global_step)
# self.optim = tf.train.RMSPropOptimizer(self.params['lr'],self.params['rms_decay'],0.0,self.params['rms_eps']).minimize(self.cost,global_step=self.global_step)
self.optim = tf.train.AdamOptimizer(self.params['lr']).minimize(self.cost, global_step=self.global_step)
self.saver = tf.train.Saver(max_to_keep=0)

self.sess.run(tf.global_variables_initializer())
Expand All @@ -72,7 +73,7 @@ def train(self,bat_s,bat_a,bat_t,bat_n,bat_r):
q_t = self.sess.run(self.y,feed_dict=feed_dict)
q_t = np.amax(q_t, axis=1)
feed_dict={self.x: bat_s, self.q_t: q_t, self.actions: bat_a, self.terminals:bat_t, self.rewards: bat_r}
_,cnt,cost = self.sess.run([self.rmsprop,self.global_step,self.cost],feed_dict=feed_dict)
_,cnt,cost = self.sess.run([self.optim, self.global_step,self.cost],feed_dict=feed_dict)
return cnt, cost

def save_ckpt(self,filename):
Expand Down
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ Size of replay memory batch size: `batch_size` <br />
Amount of experience tuples in replay memory: `mem_size` <br />
Discount rate (gamma value): `discount` <br />
Learning rate: `lr` <br />
RMS Prop decay rate: `rms_decay` <br />
RMS Prop epsilon value: `rms_eps` <br />
<br />
Exploration/Exploitation (ε-greedy): <br />
Epsilon start value: `eps` <br />
Expand Down
4 changes: 2 additions & 2 deletions pacmanDQN_Agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@

'discount': 0.95, # Discount rate (gamma value)
'lr': .0002, # Learning reate
'rms_decay': 0.99, # RMS Prop decay
'rms_eps': 1e-6, # RMS Prop epsilon
# 'rms_decay': 0.99, # RMS Prop decay (switched to adam)
# 'rms_eps': 1e-6, # RMS Prop epsilon (switched to adam)

# Epsilon value (epsilon-greedy)
'eps': 1.0, # Epsilon start value
Expand Down

0 comments on commit a474575

Please sign in to comment.