Refactor some functions into tools
Sergey Arkhangelskiy authored and committed Mar 5, 2017
1 parent 1c96d0a commit f37e889
Showing 3 changed files with 74 additions and 51 deletions.
61 changes: 10 additions & 51 deletions dqn.py
@@ -4,11 +4,9 @@
import shutil
import time

import gym
import numpy as np
import tensorflow as tf

import atari_wrappers
import tools

tf.app.flags.DEFINE_string('base_dir', '', 'Base directory to save summaries and checkpoints')
@@ -18,7 +16,7 @@
tf.app.flags.DEFINE_float('clip_grad', 10., 'Gradients norm to clip gradients to')
tf.app.flags.DEFINE_integer('steps', 10 * 10 ** 6, 'Number of steps to run learning for')
tf.app.flags.DEFINE_integer('steps_per_action', 5, 'How many NN updates per one env action')
tf.app.flags.DEFINE_boolean('restart', True,
tf.app.flags.DEFINE_boolean('restart', False,
'If true, starts over, otherwise, starts from the last checkpoint')
tf.app.flags.DEFINE_integer('batch_size', 64, 'Batch size')

@@ -50,30 +48,6 @@

FLAGS = tf.app.flags.FLAGS

def EnvFactory(env_name):
parts = env_name.split(':')
if len(parts) > 2:
raise ValueError('Incorrect environment name %s' % env_name)

env = gym.make(parts[0])
if len(parts) == 2:
for letter in parts[1]:
if letter == 'L':
env = atari_wrappers.EpisodicLifeEnv(env)
elif letter == 'N':
env = atari_wrappers.NoopResetEnv(env, noop_max=30)
elif letter == 'S':
env = atari_wrappers.MaxAndSkipEnv(env, skip=4)
elif letter == 'F':
env = atari_wrappers.FireResetEnv(env)
elif letter == 'C':
env = atari_wrappers.ClippedRewardsWrapper(env)
elif letter == 'P':
env = atari_wrappers.ProcessFrame84(env)
else:
raise ValueError('Unexpected code of wrapper %s' % letter)
return env


def ConvQNetwork(state, num_actions, unused_is_training):
with tf.variable_scope("convnet"):
@@ -106,6 +80,7 @@ def ConvQNetwork(state, num_actions, unused_is_training):


def CartPoleQNetwork(state, num_actions, unused_is_training):
state = tf.contrib.layers.flatten(state)
hidden = tf.contrib.layers.fully_connected(
state, 256,
activation_fn=tf.nn.elu,
@@ -253,26 +228,6 @@ def GenerateExperience(env, policy, rollout_len, gamma, step_callback, stats_cal
return


def InitSession(sess, folder):
"""If folder has checkpoint, reinitializes session with it"""
last_step = -1
if not FLAGS.restart:
for fname in os.listdir(folder):
m = re.match(r'model.ckpt-(\d+).meta', fname)
if m:
step = int(m.group(1))
if step > last_step:
last_step = step

saver = tf.train.Saver()
if last_step > 0:
saver.restore(sess, os.path.join(folder, 'model.ckpt-%d' % last_step))
else:
if os.path.exists(folder):
shutil.rmtree(folder)
sess.run(tf.global_variables_initializer())

return saver

def main(argv):
folder = os.path.join(FLAGS.base_dir, FLAGS.env, 'lr-%.1E' % FLAGS.lr,
@@ -281,7 +236,7 @@ def main(argv):
'bs-%d' % FLAGS.batch_size,
FLAGS.exploration)

env = EnvFactory(FLAGS.env)
env = tools.EnvFactory(FLAGS.env)

rollout_len = int(math.floor(FLAGS.experience))
gamma_exp = FLAGS.experience - rollout_len
@@ -296,8 +251,8 @@ def FillBuffer(*args):
rollout_len, gamma_exp,
FillBuffer, lambda *args: None)

# state2q = CartPoleQNetwork
state2q = ConvQNetwork
state2q = CartPoleQNetwork
# state2q = ConvQNetwork
print buf.state_shape

state = tf.placeholder(tf.float32, shape=[None] + list(buf.state_shape), name='state')
@@ -307,6 +262,7 @@ def FillBuffer(*args):
gamma = tf.placeholder(tf.float32, shape=[None], name='gamma')
is_weights = tf.placeholder(tf.float32, shape=[None], name='is_weights')
is_training = tf.placeholder(tf.bool, shape=None, name='is_training')
tf.add_to_collection('placeholders', state)

with tf.variable_scope('model', reuse=False):
qvalues = state2q(state, env.action_space.n, is_training)
@@ -329,6 +285,9 @@ def FillBuffer(*args):
global_step = tf.Variable(0, name='global_step', trainable=False)
policy = PolicyFactory(FLAGS.exploration, qvalues, global_step)

q_policy = tf.argmax(qvalues, axis=1, name='q_policy')
tf.add_to_collection('q_policy', q_policy)

delta = target_q - q
td_err_weight = tf.abs(delta)
loss = tf.reduce_mean(tools.HuberLoss(delta, 5) * is_weights)
@@ -361,7 +320,7 @@ def FillBuffer(*args):
summary_op = tf.summary.merge_all()

with tf.Session() as sess:
saver = InitSession(sess, folder)
saver = tools.InitSession(sess, folder, FLAGS.restart)
writer = tf.summary.FileWriter(folder)
writer.add_graph(tf.get_default_graph())

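The two tf.add_to_collection calls added above let a later script recover the input placeholder and the greedy-policy op from a saved meta graph without rebuilding the network. A minimal evaluation sketch along those lines (not part of this commit; the folder, checkpoint step, and environment id are illustrative, and it relies on is_training being unused by both Q-networks):

import tensorflow as tf

import tools

folder = '/tmp/dqn/CartPole-v0/example'  # illustrative checkpoint directory
ckpt = 'model.ckpt-10000'                # illustrative checkpoint step

with tf.Session() as sess:
    # Rebuild the graph from the saved .meta file and restore the weights.
    saver = tf.train.import_meta_graph('%s/%s.meta' % (folder, ckpt))
    saver.restore(sess, '%s/%s' % (folder, ckpt))

    # Ops stored via tf.add_to_collection in dqn.py.
    state = tf.get_collection('placeholders')[0]
    q_policy = tf.get_collection('q_policy')[0]

    env = tools.EnvFactory('CartPole-v0')
    obs, done = env.reset(), False
    while not done:
        action = int(sess.run(q_policy, feed_dict={state: obs[None]})[0])
        obs, _, done, _ = env.step(action)
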
1 change: 1 addition & 0 deletions game2048.py
@@ -95,3 +95,4 @@ def _render(self, mode='human', close=False):

sys.stdout.write('|'.join([str(int(v)).center(3) for v in self.state[i, :]]))
sys.stdout.write('\n')
sys.stdout.write('\n')
63 changes: 63 additions & 0 deletions tools.py
@@ -1,10 +1,15 @@
import os
import random
import re
import time
import gym

import numpy as np
import tensorflow as tf

import atari_wrappers
import game2048


class ExperienceBuffer(object):
"""Simple experience buffer"""
@@ -181,3 +186,61 @@ def Select(value, index):
tf.expand_dims(index, 1)], 1)
return tf.gather_nd(value, ind)


def EnvFactory(env_name):
parts = env_name.split(':')
if len(parts) > 2:
raise ValueError('Incorrect environment name %s' % env_name)

if parts[0] == '2048':
env = game2048.Game2048()
else:
env = gym.make(parts[0])

if len(parts) == 2:
for letter in parts[1]:
if letter == 'L':
env = atari_wrappers.EpisodicLifeEnv(env)
elif letter == 'N':
env = atari_wrappers.NoopResetEnv(env, noop_max=30)
elif letter == 'S':
env = atari_wrappers.MaxAndSkipEnv(env, skip=4)
elif letter == 'F':
env = atari_wrappers.FireResetEnv(env)
elif letter == 'C':
env = atari_wrappers.ClippedRewardsWrapper(env)
elif letter == 'P':
env = atari_wrappers.ProcessFrame84(env)
else:
raise ValueError('Unexpected code of wrapper %s' % letter)
return env


def GetLastCheckpoint(folder):
last_step = None
for fname in os.listdir(folder):
m = re.match(r'model.ckpt-(\d+).meta', fname)
if m:
step = int(m.group(1))
if step > last_step:
last_step = step
if last_step is not None:
return 'model.ckpt-%d' % last_step
return None


def InitSession(sess, folder, restart):
"""If folder has checkpoint, reinitializes session with it"""
ckpt = None
if not restart:
ckpt = GetLastCheckpoint(folder)

saver = tf.train.Saver()
if ckpt is not None:
saver.restore(sess, os.path.join(folder, ckpt))
else:
if os.path.exists(folder):
shutil.rmtree(folder)
sess.run(tf.global_variables_initializer())

return saver
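
Taken together, a minimal driver sketch for the helpers that now live in tools.py might look as follows (not part of this commit; the environment ids, wrapper string, and directory are illustrative, and the restart path assumes tools.py also imports shutil, which InitSession's rmtree call needs):

import tensorflow as tf

import tools

# '2048' maps to game2048.Game2048(); any other id goes through gym.make().
# The optional ':' suffix selects atari_wrappers by letter, e.g. N(oop reset),
# S(kip frames), F(ire on reset), C(lip rewards), P(rocess frames to 84x84).
env = tools.EnvFactory('PongNoFrameskip-v3:NSFCP')  # illustrative Atari id
board = tools.EnvFactory('2048')                    # the bundled 2048 env

# At least one variable must exist before InitSession builds tf.train.Saver().
global_step = tf.Variable(0, name='global_step', trainable=False)

folder = '/tmp/dqn/example'  # illustrative checkpoint directory
with tf.Session() as sess:
    # restart=True clears the folder and reinitializes all variables;
    # restart=False restores the newest model.ckpt-<step> via GetLastCheckpoint.
    saver = tools.InitSession(sess, folder, restart=True)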
