TF 1.0 Compatibility
awjuliani authored Feb 20, 2017
1 parent 376c935 commit 0b8c0c3
Showing 1 changed file with 16 additions and 16 deletions.
Deep-Recurrent-Q-Network.ipynb (32 changes: 16 additions & 16 deletions)
@@ -138,31 +138,31 @@
" inputs=self.convFlat,cell=rnn_cell,dtype=tf.float32,initial_state=self.state_in,scope=myScope+'_rnn')\n",
" self.rnn = tf.reshape(self.rnn,shape=[-1,h_size])\n",
" #The output from the recurrent layer is then split into separate Value and Advantage streams\n",
" self.streamA,self.streamV = tf.split(1,2,self.rnn)\n",
" self.streamA,self.streamV = tf.split(self.rnn,2,1)\n",
" self.AW = tf.Variable(tf.random_normal([h_size/2,4]))\n",
" self.VW = tf.Variable(tf.random_normal([h_size/2,1]))\n",
" self.Advantage = tf.matmul(self.streamA,self.AW)\n",
" self.Value = tf.matmul(self.streamV,self.VW)\n",
" \n",
" self.salience = tf.gradients(self.Advantage,self.imageIn)\n",
" #Then combine them together to get our final Q-values.\n",
" self.Qout = self.Value + tf.sub(self.Advantage,tf.reduce_mean(self.Advantage,reduction_indices=1,keep_dims=True))\n",
" self.Qout = self.Value + tf.subtract(self.Advantage,tf.reduce_mean(self.Advantage,axis=1,keep_dims=True))\n",
" self.predict = tf.argmax(self.Qout,1)\n",
" \n",
" #Below we obtain the loss by taking the sum of squares difference between the target and prediction Q values.\n",
" self.targetQ = tf.placeholder(shape=[None],dtype=tf.float32)\n",
" self.actions = tf.placeholder(shape=[None],dtype=tf.int32)\n",
" self.actions_onehot = tf.one_hot(self.actions,4,dtype=tf.float32)\n",
" \n",
" self.Q = tf.reduce_sum(tf.mul(self.Qout, self.actions_onehot), reduction_indices=1)\n",
" self.Q = tf.reduce_sum(tf.multiply(self.Qout, self.actions_onehot), axis=1)\n",
" \n",
" self.td_error = tf.square(self.targetQ - self.Q)\n",
" \n",
" #In order to only propagate accurate gradients through the network, we will mask the first\n",
" #half of the losses for each trace as per Lample & Chaplot 2016\n",
" self.maskA = tf.zeros([self.batch_size,self.trainLength/2])\n",
" self.maskB = tf.ones([self.batch_size,self.trainLength/2])\n",
" self.mask = tf.concat(1,[self.maskA,self.maskB])\n",
" self.mask = tf.concat([self.maskA,self.maskB],1)\n",
" self.mask = tf.reshape(self.mask,[-1])\n",
" self.loss = tf.reduce_mean(self.td_error * self.mask)\n",
" \n",
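Note: the changed lines in this hunk are TF 1.0 signature updates. tf.split and tf.concat now take the tensor(s) first and the axis last, tf.sub and tf.mul were renamed to tf.subtract and tf.multiply, and reduction_indices became axis. A minimal standalone sketch of the new calls (Python/TensorFlow 1.x; the shapes and names are illustrative, not taken from the notebook):

    import tensorflow as tf

    h_size = 8
    rnn = tf.placeholder(tf.float32, [None, h_size])

    # tf.split(split_dim, num_split, value) -> tf.split(value, num_or_size_splits, axis)
    streamA, streamV = tf.split(rnn, 2, 1)

    # tf.sub -> tf.subtract, reduction_indices -> axis
    advantage = tf.matmul(streamA, tf.random_normal([h_size // 2, 4]))
    value = tf.matmul(streamV, tf.random_normal([h_size // 2, 1]))
    q_out = value + tf.subtract(
        advantage, tf.reduce_mean(advantage, axis=1, keep_dims=True))

    # tf.mul -> tf.multiply
    q = tf.reduce_sum(tf.multiply(q_out, tf.one_hot([0], 4, dtype=tf.float32)), axis=1)

    # tf.concat(concat_dim, values) -> tf.concat(values, axis)
    mask = tf.concat([tf.zeros([4, 2]), tf.ones([4, 2])], 1)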
@@ -244,7 +244,8 @@
"h_size = 512 #The size of the final convolutional layer before splitting it into Advantage and Value streams.\n",
"max_epLength = 50 #The max allowed length of our episode.\n",
"time_per_step = 1 #Length of each step used in gif creation\n",
"summaryLength = 100 #Number of epidoes to periodically save for analysis"
"summaryLength = 100 #Number of epidoes to periodically save for analysis\n",
"tau = 0.001"
]
},
{
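Note: the added tau hyperparameter controls how quickly the target network tracks the primary network and is passed to updateTargetGraph in the next hunk. The helper itself lives in the repository's helper code, not in this diff; a plausible sketch, assuming a soft update of the form theta_target <- tau*theta_main + (1 - tau)*theta_target:

    import tensorflow as tf

    def update_target_graph(tf_vars, tau):
        # Assumes the first half of tf_vars belongs to the main network and the
        # second half to the target network (i.e. creation order).
        total_vars = len(tf_vars)
        op_holder = []
        for idx, var in enumerate(tf_vars[0:total_vars // 2]):
            target_var = tf_vars[idx + total_vars // 2]
            op_holder.append(target_var.assign(
                (var.value() * tau) + ((1 - tau) * target_var.value())))
        return op_holder

    def update_target(op_holder, sess):
        # Run every assign op once to nudge the target toward the main network.
        for op in op_holder:
            sess.run(op)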
@@ -258,18 +258,18 @@
"source": [
"tf.reset_default_graph()\n",
"#We define the cells for the primary and target q-networks\n",
"cell = tf.nn.rnn_cell.LSTMCell(num_units=h_size,state_is_tuple=True)\n",
"cellT = tf.nn.rnn_cell.LSTMCell(num_units=h_size,state_is_tuple=True)\n",
"cell = tf.contrib.rnn.BasicLSTMCell(num_units=h_size,state_is_tuple=True)\n",
"cellT = tf.contrib.rnn.BasicLSTMCell(num_units=h_size,state_is_tuple=True)\n",
"mainQN = Qnetwork(h_size,cell,'main')\n",
"targetQN = Qnetwork(h_size,cellT,'target')\n",
"\n",
"init = tf.initialize_all_variables()\n",
"init = tf.global_variables_initializer()\n",
"\n",
"saver = tf.train.Saver(max_to_keep=5)\n",
"\n",
"trainables = tf.trainable_variables()\n",
"\n",
"targetOps = updateTargetGraph(trainables)\n",
"targetOps = updateTargetGraph(trainables,tau)\n",
"\n",
"myBuffer = experience_buffer()\n",
"\n",
@@ -299,10 +300,7 @@
" saver.restore(sess,ckpt.model_checkpoint_path)\n",
" sess.run(init)\n",
" \n",
" \n",
" merged = tf.merge_all_summaries()\n",
" train_writer = tf.train.SummaryWriter('./train',\n",
" sess.graph)\n",
" updateTarget(targetOps,sess) #Set the target network to be equal to the primary network.\n",
" for i in range(num_episodes):\n",
" episodeBuffer = []\n",
" #Reset environment and get first new observation\n",
@@ -365,6 +363,8 @@
" break\n",
"\n",
" #Add the episode to the experience buffer\n",
" bufferArray = np.array(episodeBuffer)\n",
" episodeBuffer = zip(bufferArray)\n",
" myBuffer.add(episodeBuffer)\n",
" jList.append(j)\n",
" rList.append(rAll)\n",
@@ -417,12 +417,12 @@
"outputs": [],
"source": [
"tf.reset_default_graph()\n",
"cell = tf.nn.rnn_cell.LSTMCell(num_units=h_size,state_is_tuple=True)\n",
"cellT = tf.nn.rnn_cell.LSTMCell(num_units=h_size,state_is_tuple=True)\n",
"cell = tf.contrib.rnn.BasicLSTMCell(num_units=h_size,state_is_tuple=True)\n",
"cellT = tf.contrib.rnn.BasicLSTMCell(num_units=h_size,state_is_tuple=True)\n",
"mainQN = Qnetwork(h_size,cell,'main')\n",
"targetQN = Qnetwork(h_size,cellT,'target')\n",
"\n",
"init = tf.initialize_all_variables()\n",
"init = tf.global_variables_initializer()\n",
"\n",
"saver = tf.train.Saver(max_to_keep=2)\n",
"\n",
