Skip to content

Commit

Permalink
Fix bug in training_step: target_Q_values must be a column vector
Browse files Browse the repository at this point in the history
  • Loading branch information
ageron committed Mar 12, 2020
1 parent 0c2c80d commit 49715d4
Showing 1 changed file with 17 additions and 6 deletions.
23 changes: 17 additions & 6 deletions 18_reinforcement_learning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
"from tensorflow import keras\n",
"assert tf.__version__ >= \"2.0\"\n",
"\n",
"if not tf.test.is_gpu_available():\n",
"if not tf.config.list_physical_devices('GPU'):\n",
" print(\"No GPU was detected. CNNs can be very slow without a GPU.\")\n",
" if IS_COLAB:\n",
" print(\"Go to Runtime > Change runtime and select a GPU hardware accelerator.\")\n",
Expand Down Expand Up @@ -574,6 +574,7 @@
"metadata": {},
"outputs": [],
"source": [
"keras.backend.clear_session()\n",
"tf.random.set_seed(42)\n",
"np.random.seed(42)\n",
"\n",
Expand Down Expand Up @@ -638,7 +639,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -882,6 +883,7 @@
"metadata": {},
"outputs": [],
"source": [
"keras.backend.clear_session()\n",
"np.random.seed(42)\n",
"tf.random.set_seed(42)\n",
"\n",
Expand Down Expand Up @@ -1274,6 +1276,7 @@
"metadata": {},
"outputs": [],
"source": [
"keras.backend.clear_session()\n",
"tf.random.set_seed(42)\n",
"np.random.seed(42)\n",
"\n",
Expand Down Expand Up @@ -1392,7 +1395,9 @@
" states, actions, rewards, next_states, dones = experiences\n",
" next_Q_values = model.predict(next_states)\n",
" max_next_Q_values = np.max(next_Q_values, axis=1)\n",
" target_Q_values = rewards + (1 - dones) * discount_rate * max_next_Q_values\n",
" target_Q_values = (rewards +\n",
" (1 - dones) * discount_rate * max_next_Q_values)\n",
" target_Q_values = target_Q_values.reshape(-1, 1)\n",
" mask = tf.one_hot(actions, n_outputs)\n",
" with tf.GradientTape() as tape:\n",
" all_Q_values = model(states)\n",
Expand Down Expand Up @@ -1505,6 +1510,7 @@
"metadata": {},
"outputs": [],
"source": [
"keras.backend.clear_session()\n",
"tf.random.set_seed(42)\n",
"np.random.seed(42)\n",
"\n",
Expand Down Expand Up @@ -1536,7 +1542,9 @@
" best_next_actions = np.argmax(next_Q_values, axis=1)\n",
" next_mask = tf.one_hot(best_next_actions, n_outputs).numpy()\n",
" next_best_Q_values = (target.predict(next_states) * next_mask).sum(axis=1)\n",
" target_Q_values = rewards + (1 - dones) * discount_rate * next_best_Q_values\n",
" target_Q_values = (rewards + \n",
" (1 - dones) * discount_rate * next_best_Q_values)\n",
" target_Q_values = target_Q_values.reshape(-1, 1)\n",
" mask = tf.one_hot(actions, n_outputs)\n",
" with tf.GradientTape() as tape:\n",
" all_Q_values = model(states)\n",
Expand Down Expand Up @@ -1646,6 +1654,7 @@
"metadata": {},
"outputs": [],
"source": [
"keras.backend.clear_session()\n",
"tf.random.set_seed(42)\n",
"np.random.seed(42)\n",
"\n",
Expand Down Expand Up @@ -1681,7 +1690,9 @@
" best_next_actions = np.argmax(next_Q_values, axis=1)\n",
" next_mask = tf.one_hot(best_next_actions, n_outputs).numpy()\n",
" next_best_Q_values = (target.predict(next_states) * next_mask).sum(axis=1)\n",
" target_Q_values = rewards + (1 - dones) * discount_rate * next_best_Q_values\n",
" target_Q_values = (rewards + \n",
" (1 - dones) * discount_rate * next_best_Q_values)\n",
" target_Q_values = target_Q_values.reshape(-1, 1)\n",
" mask = tf.one_hot(actions, n_outputs)\n",
" with tf.GradientTape() as tape:\n",
" all_Q_values = model(states)\n",
Expand Down Expand Up @@ -2777,7 +2788,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
"version": "3.7.6"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 49715d4

Please sign in to comment.