Skip to content

Commit

Permalink
fix bug in tf.1.0 & add recording output
Browse files Browse the repository at this point in the history
  • Loading branch information
chenghuzi committed Mar 28, 2017
1 parent 92e5e70 commit 1b4abe0
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 1 deletion.
2 changes: 1 addition & 1 deletion DQN.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(self, params):
self.rmsprop = tf.train.RMSPropOptimizer(self.params['lr'],self.params['rms_decay'],0.0,self.params['rms_eps']).minimize(self.cost,global_step=self.global_step)
self.saver = tf.train.Saver(max_to_keep=0)

self.sess.run(tf.initialize_all_variables())
self.sess.run(tf.global_variables_initializer())

if self.params['load_file'] is not None:
print('Loading checkpoint...')
Expand Down
1 change: 1 addition & 0 deletions logs/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Logs
6 changes: 6 additions & 0 deletions pacmanDQN_Agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def __init__(self, args):
self.sess = tf.Session(config = tf.ConfigProto(gpu_options = gpu_options))
self.qnet = DQN(self.params)

# time started
self.general_record_time = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
# Q and cost
self.Q_global = []
self.cost_disp = 0
Expand Down Expand Up @@ -196,6 +198,10 @@ def final(self, state):
self.observation_step(state)

# Print stats
log_file = open('./logs/'+str(self.general_record_time)+'-l-'+str(self.params['width'])+'-m-'+str(self.params['height'])+'-x-'+str(self.params['num_training'])+'.log','a')
log_file.write("# %4d | steps: %5d | steps_t: %5d | t: %4f | r: %12f | e: %10f " %
(self.numeps,self.local_cnt, self.cnt, time.time()-self.s, self.ep_rew, self.params['eps']))
log_file.write("| Q: %10f | won: %r \n" % ((max(self.Q_global, default=float('nan')), self.won)))
sys.stdout.write("# %4d | steps: %5d | steps_t: %5d | t: %4f | r: %12f | e: %10f " %
(self.numeps,self.local_cnt, self.cnt, time.time()-self.s, self.ep_rew, self.params['eps']))
sys.stdout.write("| Q: %10f | won: %r \n" % ((max(self.Q_global, default=float('nan')), self.won)))
Expand Down

0 comments on commit 1b4abe0

Please sign in to comment.