Skip to content

Commit

Permalink
fix random forest TF 1.4 compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
aymericdamien committed Dec 12, 2017
1 parent d3f3c83 commit 0c4e666
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions examples/2_BasicModels/random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,15 @@
loss_op = forest_graph.training_loss(X, Y)

# Measure the accuracy
infer_op = forest_graph.inference_graph(X)
infer_op, _, _ = forest_graph.inference_graph(X)
correct_prediction = tf.equal(tf.argmax(infer_op, 1), tf.cast(Y, tf.int64))
accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Initialize the variables (i.e. assign their default value)
init_vars = tf.global_variables_initializer()

# Start TensorFlow session
sess = tf.Session()
sess = tf.train.MonitoredSession()

# Run the initializer
sess.run(init_vars)
Expand Down
4 changes: 2 additions & 2 deletions notebooks/2_BasicModels/random_forest.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@
"loss_op = forest_graph.training_loss(X, Y)\n",
"\n",
"# Measure the accuracy\n",
"infer_op = forest_graph.inference_graph(X)\n",
"infer_op, _, _ = forest_graph.inference_graph(X)\n",
"correct_prediction = tf.equal(tf.argmax(infer_op, 1), tf.cast(Y, tf.int64))\n",
"accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
"\n",
Expand Down Expand Up @@ -158,7 +158,7 @@
],
"source": [
"# Start TensorFlow session\n",
"sess = tf.Session()\n",
"sess = tf.train.MonitoredSession()\n",
"\n",
"# Run the initializer\n",
"sess.run(init_vars)\n",
Expand Down

0 comments on commit 0c4e666

Please sign in to comment.