forked from aymericdamien/TensorFlow-Examples
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Added basic models examples (kmeans, random forest, ...) * Added API examples (layers, estimator, ...) * Added other examples (Multi-GPU, build a dataset, ...) * Notebook refactoring with new header and more details
- Loading branch information
1 parent
4e829a6
commit 90bb4de
Showing
56 changed files
with
6,451 additions
and
1,524 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
""" K-Means. | ||
Implement K-Means algorithm with TensorFlow, and apply it to classify | ||
handwritten digit images. This example is using the MNIST database of | ||
handwritten digits as training samples (http://yann.lecun.com/exdb/mnist/). | ||
Note: This example requires TensorFlow v1.1.0 or over. | ||
Author: Aymeric Damien | ||
Project: https://github.com/aymericdamien/TensorFlow-Examples/ | ||
""" | ||
|
||
from __future__ import print_function | ||
|
||
import numpy as np | ||
import tensorflow as tf | ||
from tensorflow.contrib.factorization import KMeans | ||
|
||
# Ignore all GPUs, tf random forest does not benefit from it. | ||
import os | ||
os.environ["CUDA_VISIBLE_DEVICES"] = "" | ||
|
||
# Import MNIST data | ||
from tensorflow.examples.tutorials.mnist import input_data | ||
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True) | ||
full_data_x = mnist.train.images | ||
|
||
# Parameters | ||
num_steps = 50 # Total steps to train | ||
batch_size = 1024 # The number of samples per batch | ||
k = 25 # The number of clusters | ||
num_classes = 10 # The 10 digits | ||
num_features = 784 # Each image is 28x28 pixels | ||
|
||
# Input images | ||
X = tf.placeholder(tf.float32, shape=[None, num_features]) | ||
# Labels (for assigning a label to a centroid and testing) | ||
Y = tf.placeholder(tf.float32, shape=[None, num_classes]) | ||
|
||
# K-Means Parameters | ||
kmeans = KMeans(inputs=X, num_clusters=k, distance_metric='cosine', | ||
use_mini_batch=True) | ||
|
||
# Build KMeans graph | ||
(all_scores, cluster_idx, scores, cluster_centers_initialized, init_op, | ||
train_op) = kmeans.training_graph() | ||
cluster_idx = cluster_idx[0] # fix for cluster_idx being a tuple | ||
avg_distance = tf.reduce_mean(scores) | ||
|
||
# Initialize the variables (i.e. assign their default value) | ||
init_vars = tf.global_variables_initializer() | ||
|
||
# Start TensorFlow session | ||
sess = tf.Session() | ||
|
||
# Run the initializer | ||
sess.run(init_vars, feed_dict={X: full_data_x}) | ||
sess.run(init_op, feed_dict={X: full_data_x}) | ||
|
||
# Training | ||
for i in range(1, num_steps + 1): | ||
_, d, idx = sess.run([train_op, avg_distance, cluster_idx], | ||
feed_dict={X: full_data_x}) | ||
if i % 10 == 0 or i == 1: | ||
print("Step %i, Avg Distance: %f" % (i, d)) | ||
|
||
# Assign a label to each centroid | ||
# Count total number of labels per centroid, using the label of each training | ||
# sample to their closest centroid (given by 'idx') | ||
counts = np.zeros(shape=(k, num_classes)) | ||
for i in range(len(idx)): | ||
counts[idx[i]] += mnist.train.labels[i] | ||
# Assign the most frequent label to the centroid | ||
labels_map = [np.argmax(c) for c in counts] | ||
labels_map = tf.convert_to_tensor(labels_map) | ||
|
||
# Evaluation ops | ||
# Lookup: centroid_id -> label | ||
cluster_label = tf.nn.embedding_lookup(labels_map, cluster_idx) | ||
# Compute accuracy | ||
correct_prediction = tf.equal(cluster_label, tf.cast(tf.argmax(Y, 1), tf.int32)) | ||
accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) | ||
|
||
# Test Model | ||
test_x, test_y = mnist.test.images, mnist.test.labels | ||
print("Test Accuracy:", sess.run(accuracy_op, feed_dict={X: test_x, Y: test_y})) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
""" Random Forest. | ||
Implement Random Forest algorithm with TensorFlow, and apply it to classify | ||
handwritten digit images. This example is using the MNIST database of | ||
handwritten digits as training samples (http://yann.lecun.com/exdb/mnist/). | ||
Author: Aymeric Damien | ||
Project: https://github.com/aymericdamien/TensorFlow-Examples/ | ||
""" | ||
|
||
from __future__ import print_function | ||
|
||
import tensorflow as tf | ||
from tensorflow.contrib.tensor_forest.python import tensor_forest | ||
|
||
# Ignore all GPUs, tf random forest does not benefit from it. | ||
import os | ||
os.environ["CUDA_VISIBLE_DEVICES"] = "" | ||
|
||
# Import MNIST data | ||
from tensorflow.examples.tutorials.mnist import input_data | ||
mnist = input_data.read_data_sets("/tmp/data/", one_hot=False) | ||
|
||
# Parameters | ||
num_steps = 500 # Total steps to train | ||
batch_size = 1024 # The number of samples per batch | ||
num_classes = 10 # The 10 digits | ||
num_features = 784 # Each image is 28x28 pixels | ||
num_trees = 10 | ||
max_nodes = 1000 | ||
|
||
# Input and Target data | ||
X = tf.placeholder(tf.float32, shape=[None, num_features]) | ||
# For random forest, labels must be integers (the class id) | ||
Y = tf.placeholder(tf.int32, shape=[None]) | ||
|
||
# Random Forest Parameters | ||
hparams = tensor_forest.ForestHParams(num_classes=num_classes, | ||
num_features=num_features, | ||
num_trees=num_trees, | ||
max_nodes=max_nodes).fill() | ||
|
||
# Build the Random Forest | ||
forest_graph = tensor_forest.RandomForestGraphs(hparams) | ||
# Get training graph and loss | ||
train_op = forest_graph.training_graph(X, Y) | ||
loss_op = forest_graph.training_loss(X, Y) | ||
|
||
# Measure the accuracy | ||
infer_op = forest_graph.inference_graph(X) | ||
correct_prediction = tf.equal(tf.argmax(infer_op, 1), tf.cast(Y, tf.int64)) | ||
accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) | ||
|
||
# Initialize the variables (i.e. assign their default value) | ||
init_vars = tf.global_variables_initializer() | ||
|
||
# Start TensorFlow session | ||
sess = tf.Session() | ||
|
||
# Run the initializer | ||
sess.run(init_vars) | ||
|
||
# Training | ||
for i in range(1, num_steps + 1): | ||
# Prepare Data | ||
# Get the next batch of MNIST data (only images are needed, not labels) | ||
batch_x, batch_y = mnist.train.next_batch(batch_size) | ||
_, l = sess.run([train_op, loss_op], feed_dict={X: batch_x, Y: batch_y}) | ||
if i % 50 == 0 or i == 1: | ||
acc = sess.run(accuracy_op, feed_dict={X: batch_x, Y: batch_y}) | ||
print('Step %i, Loss: %f, Acc: %f' % (i, l, acc)) | ||
|
||
# Test Model | ||
test_x, test_y = mnist.test.images, mnist.test.labels | ||
print("Test Accuracy:", sess.run(accuracy_op, feed_dict={X: test_x, Y: test_y})) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.