Add tests for dataset iterators
batzner committed Mar 18, 2018
1 parent 263b51c commit bf7addd
Showing 2 changed files with 157 additions and 7 deletions.
18 changes: 11 additions & 7 deletions examples/sequential_mnist.py
@@ -154,7 +154,7 @@ def main():
                 np.mean(accuracies)))
 
 
-def add_noise(inputs, labels):
+def add_gaussian_noise(inputs, labels):
     # Values taken https://github.com/cooijmanstim/recurrent-batch-normalization
     inputs = inputs + tf.random_normal(inputs.shape, mean=0.0, stddev=0.1)
     return inputs, labels
@@ -167,18 +167,22 @@ def preprocess_data(inputs, labels):
     return inputs, labels
 
 
-def get_iterators(handle, inputs_ph, labels_ph):
+def get_iterators(handle, inputs_ph, labels_ph, add_noise=True,
+                  batch_size=BATCH_SIZE, shuffle=True):
     training_dataset = tf.data.Dataset.from_tensor_slices((inputs_ph, labels_ph))
-    training_dataset = training_dataset.shuffle(buffer_size=1000)
-    # Apply random perturbations to the training data
-    training_dataset = training_dataset.map(add_noise).map(preprocess_data)
-    training_dataset = training_dataset.repeat().batch(BATCH_SIZE)
+    if shuffle:
+        training_dataset = training_dataset.shuffle(buffer_size=1000)
+    if add_noise:
+        # Apply random perturbations to the training data
+        training_dataset = training_dataset.map(add_gaussian_noise)
+    training_dataset = training_dataset.map(preprocess_data)
+    training_dataset = training_dataset.repeat().batch(batch_size)
 
     # Create the validation dataset
     validation_dataset = tf.data.Dataset.from_tensor_slices(
         (inputs_ph, labels_ph))
     validation_dataset = validation_dataset.map(preprocess_data)
-    validation_dataset = validation_dataset.batch(BATCH_SIZE)
+    validation_dataset = validation_dataset.batch(batch_size)
 
     # Create an iterator for switching between datasets
     iterator = tf.data.Iterator.from_string_handle(
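The hunk above ends at the call to tf.data.Iterator.from_string_handle(. The remaining lines of get_iterators are not part of this diff, so the following is only a rough sketch of how a feedable iterator of this kind is typically completed in TF 1.x; it is consistent with how the tests below consume the function's three return values, but it is not copied from sequential_mnist.py:

    iterator = tf.data.Iterator.from_string_handle(
        handle,                         # tf.string placeholder fed at session.run time
        training_dataset.output_types,  # shared by the training and validation datasets
        training_dataset.output_shapes)
    training_iterator = training_dataset.make_initializable_iterator()
    validation_iterator = validation_dataset.make_initializable_iterator()
    return iterator, training_iterator, validation_iterator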
146 changes: 146 additions & 0 deletions examples/sequential_mnist_test.py
@@ -0,0 +1,146 @@
import tensorflow as tf
from tensorflow.python.platform import test
import numpy as np

from examples.sequential_mnist import get_iterators


class TestSequentialMnist(test.TestCase):
    def testTrainingOutputs(self):
        batch_size = 2
        train_inputs = np.array([[1, 2], [3, 4], [5, 6]])
        train_labels = np.array([11, 12, 13])

        expected_input_batches = [[[1, 2], [3, 4]], [[5, 6], [1, 2]]]
        expected_input_batches = np.array(expected_input_batches).reshape(
            (2, 2, 2, 1))
        expected_input_labels = [[11, 12], [13, 11]]

        data_handle = tf.placeholder(tf.string, shape=[])
        all_inputs_ph = tf.placeholder(tf.float32, [None, 2])
        all_labels_ph = tf.placeholder(tf.int32, [None])

        main_iter, train_iter, _ = get_iterators(data_handle,
                                                 all_inputs_ph,
                                                 all_labels_ph,
                                                 add_noise=False,
                                                 batch_size=batch_size,
                                                 shuffle=False)
        sess = tf.Session()
        sess.run(train_iter.initializer, feed_dict={
            all_inputs_ph: train_inputs,
            all_labels_ph: train_labels})

        train_handle = sess.run(train_iter.string_handle())
        inputs_op, labels_op = main_iter.get_next()
        # Generate the first batch
        inputs, labels = sess.run([inputs_op, labels_op],
                                  feed_dict={data_handle: train_handle})
        self.assertAllEqual(inputs, expected_input_batches[0])
        self.assertAllEqual(labels, expected_input_labels[0])

        # Generate the second batch
        inputs, labels = sess.run([inputs_op, labels_op],
                                  feed_dict={data_handle: train_handle})
        self.assertAllEqual(inputs, expected_input_batches[1])
        self.assertAllEqual(labels, expected_input_labels[1])

    def testValidationOutputs(self):
        batch_size = 2
        train_inputs = np.random.rand(10, 2)
        train_labels = np.random.rand(10)
        valid_inputs = np.array([[1, 2], [3, 4], [5, 6]])
        valid_labels = np.array([11, 12, 13])

        expected_input_batches = [[[[1], [2]], [[3], [4]]], [[[5], [6]]]]
        expected_input_labels = [[11, 12], [13]]

        data_handle = tf.placeholder(tf.string, shape=[])
        all_inputs_ph = tf.placeholder(tf.float32, [None, 2])
        all_labels_ph = tf.placeholder(tf.int32, [None])

        main_iter, train_iter, valid_iter = get_iterators(data_handle,
                                                          all_inputs_ph,
                                                          all_labels_ph,
                                                          batch_size=batch_size)
        sess = tf.Session()
        sess.run(train_iter.initializer, feed_dict={
            all_inputs_ph: train_inputs,
            all_labels_ph: train_labels})
        sess.run(valid_iter.initializer, feed_dict={
            all_inputs_ph: valid_inputs,
            all_labels_ph: valid_labels})

        # Generate handles for each iterator
        train_handle = sess.run(train_iter.string_handle())
        valid_handle = sess.run(valid_iter.string_handle())
        inputs_op, labels_op = main_iter.get_next()

        # Generate some train labels first
        sess.run([inputs_op, labels_op], feed_dict={data_handle: train_handle})

        # Generate the first batch
        inputs, labels = sess.run([inputs_op, labels_op],
                                  feed_dict={data_handle: valid_handle})
        self.assertAllEqual(inputs, expected_input_batches[0])
        self.assertAllEqual(labels, expected_input_labels[0])

        # Generate the second batch
        inputs, labels = sess.run([inputs_op, labels_op],
                                  feed_dict={data_handle: valid_handle})
        self.assertAllEqual(inputs, expected_input_batches[1])
        self.assertAllEqual(labels, expected_input_labels[1])

    def testTrainingValidationMix(self):
        batch_size = 2
        train_inputs = np.array([[1, 2], [3, 4], [5, 6]])
        train_labels = np.array([11, 12, 13])
        valid_inputs = np.random.rand(10, 2)
        valid_labels = np.random.rand(10)

        expected_input_batches = [[[1, 2], [3, 4]], [[5, 6], [1, 2]]]
        expected_input_batches = np.array(expected_input_batches).reshape(
            (2, 2, 2, 1))
        expected_input_labels = [[11, 12], [13, 11]]

        data_handle = tf.placeholder(tf.string, shape=[])
        all_inputs_ph = tf.placeholder(tf.float32, [None, 2])
        all_labels_ph = tf.placeholder(tf.int32, [None])

        main_iter, train_iter, valid_iter = get_iterators(data_handle,
                                                          all_inputs_ph,
                                                          all_labels_ph,
                                                          batch_size=batch_size,
                                                          add_noise=False,
                                                          shuffle=False)
        sess = tf.Session()
        sess.run(train_iter.initializer, feed_dict={
            all_inputs_ph: train_inputs,
            all_labels_ph: train_labels})

        # Generate handles for each iterator
        train_handle = sess.run(train_iter.string_handle())
        valid_handle = sess.run(valid_iter.string_handle())
        inputs_op, labels_op = main_iter.get_next()

        # Generate the first batch
        inputs, labels = sess.run([inputs_op, labels_op],
                                  feed_dict={data_handle: train_handle})
        self.assertAllEqual(inputs, expected_input_batches[0])
        self.assertAllEqual(labels, expected_input_labels[0])

        # Iterate through the validation set
        sess.run(valid_iter.initializer, feed_dict={
            all_inputs_ph: valid_inputs,
            all_labels_ph: valid_labels})
        while True:
            try:
                sess.run([inputs_op, labels_op], feed_dict={data_handle: valid_handle})
            except tf.errors.OutOfRangeError:
                break

        # Generate the second batch
        inputs, labels = sess.run([inputs_op, labels_op],
                                  feed_dict={data_handle: train_handle})
        self.assertAllEqual(inputs, expected_input_batches[1])
        self.assertAllEqual(labels, expected_input_labels[1])
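
The new test module relies on an external test runner; no __main__ block appears in this diff. Assuming the usual TensorFlow test setup (an assumption, not part of this commit), the cases can be made directly runnable by appending the standard entry point:

if __name__ == '__main__':
    # Discovers and runs the TestSequentialMnist cases above
    test.main()

Alternatively, a runner such as pytest can collect the module as examples/sequential_mnist_test.py without any entry point, since tf.test.TestCase subclasses unittest.TestCase.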
