From d991c8221627002efb64839b6c62a4fbca3e7d52 Mon Sep 17 00:00:00 2001 From: Kenton Lee Date: Tue, 15 Dec 2015 19:56:57 -0800 Subject: [PATCH 1/4] Allows IndexedSlices to be fed and fetched. Change-Id: If19f09cd53db402a16a51b9f5f9f7b7c616e0a4f --- tensorflow/python/client/session.py | 14 ++++++++++++++ tensorflow/python/framework/ops.py | 4 ++++ 2 files changed, 18 insertions(+) diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 74358349b77fd4..5c23baf64c568e 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -221,6 +221,20 @@ def as_default(self): lambda fetched_vals: ops.SparseTensorValue(*fetched_vals)), lambda feed, feed_val: list(zip( [feed.indices, feed.values, feed.shape], feed_val))), + # IndexedSlices are fetched as IndexedSlicesValues or + # IndexedSlicesWithoutDenseShapeValues. They can be fed + # IndexedSlicesValues or IndexedSlicesWithoutDenseShapeValues or normal + # tuples. + (ops.IndexedSlices, + lambda fetch: ( + [fetch.values, fetch.indices] if fetch.dense_shape is None + else [fetch.values, fetch.indices, fetch.dense_shape], + lambda fetched_vals: + ops.IndexedSlicesValue(*fetched_vals) if len(fetched_vals) == 3 + else ops.IndexedSlicesWithoutDenseShapeValue(*fetched_vals)), + lambda feed, feed_val: list(zip( + [feed.values, feed.indices] if feed.dense_shape is None + else [feed.values, feed.indices, feed.dense_shape], feed_val))), # The default catches all types and performs no expansions. (object, lambda fetch: ([fetch], lambda fetched_vals: fetched_vals[0]), diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index d73cf47b713382..18ecf992e9f299 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -751,6 +751,10 @@ def __str__(self): self._indices, self._values, (", dense_shape=%s" % self._dense_shape) if self._dense_shape else "") +IndexedSlicesValue = collections.namedtuple("IndexedSlicesValue", + ["values", "indices", "dense_shape"]) +IndexedSlicesWithoutDenseShapeValue = collections.namedtuple("IndexedSlicesWithoutDenseShapeValue", + ["values", "indices"]) class SparseTensor(object): """Represents a sparse tensor. From c52794939b3a5e1f0568c3cbf38025e4e3964e72 Mon Sep 17 00:00:00 2001 From: Kenton Lee Date: Thu, 17 Dec 2015 00:57:24 -0800 Subject: [PATCH 2/4] Adds test for fetching and feeding IndexedSlices. Change-Id: I524594cb630893f6b23fa8f0d99c87a46b9b1a15 --- tensorflow/python/client/session_test.py | 125 +++++++++++++++++++++++ 1 file changed, 125 insertions(+) diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index 6b1a6bb09f0780..c34e5a250a2269 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -248,6 +248,131 @@ def testFeedSparseTensor(self): self.assertAllEqual(sp2_out.values, values) self.assertAllEqual(sp2_out.shape, shape) + def testFetchIndexedSlices(self): + with session.Session() as s: + indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) + values = np.array([1.0, 2.0]).astype(np.float32) + dense_shape = np.array([7, 9, 2]).astype(np.int64) + ind = ops.IndexedSlices( + constant_op.constant(values), + constant_op.constant(indices), + constant_op.constant(dense_shape)) + # Single fetch, use as tuple + ind_out = s.run(ind) + values_out, indices_out, dense_shape_out = ind_out + self.assertAllEqual(values_out, values) + self.assertAllEqual(indices_out, indices) + self.assertAllEqual(dense_shape_out, dense_shape) + # Single fetch, use as IndexedSlicesValue + ind_out = s.run(ind) + self.assertAllEqual(ind_out.values, values) + self.assertAllEqual(ind_out.indices, indices) + self.assertAllEqual(ind_out.dense_shape, dense_shape) + # Tuple fetch, use as tuple + values_out, indices_out, dense_shape_out = s.run(ind) + self.assertAllEqual(values_out, values) + self.assertAllEqual(indices_out, indices) + self.assertAllEqual(dense_shape_out, dense_shape) + # List fetch, use as tuple + (values_out, indices_out, dense_shape_out), = s.run([ind]) + self.assertAllEqual(values_out, values) + self.assertAllEqual(indices_out, indices) + self.assertAllEqual(dense_shape_out, dense_shape) + # List fetch, use as IndexedSlicesValue + ind_out, = s.run([ind]) + self.assertAllEqual(ind_out.values, values) + self.assertAllEqual(ind_out.indices, indices) + self.assertAllEqual(ind_out.dense_shape, dense_shape) + + def testFeedIndexedSlices(self): + with session.Session() as s: + values = np.array([1.0, 2.0]).astype(np.float32) + indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) + dense_shape = np.array([7, 9, 2]).astype(np.int64) + ind = ops.IndexedSlices( + array_ops.placeholder(dtype=np.float32, shape=(2,)), + array_ops.placeholder(dtype=np.int64, shape=(2, 3)), + array_ops.placeholder(dtype=np.int64, shape=(3,)),) + ind_values = array_ops.identity(ind.values) + ind_indices = array_ops.identity(ind.indices) + ind_dense_shape = array_ops.identity(ind.dense_shape) + ind2 = ops.IndexedSlices(ind_values, ind_indices, ind_dense_shape) + # Feed with tuple + values_out, indices_out, dense_shape_out = s.run( + [ind_values, ind_indices, ind_dense_shape], {ind: (values, indices, dense_shape)}) + self.assertAllEqual(values_out, values) + self.assertAllEqual(indices_out, indices) + self.assertAllEqual(dense_shape_out, dense_shape) + # Feed with IndexedSlicesValue + values_out, indices_out, dense_shape_out = s.run( + [ind_values, ind_indices, ind_dense_shape], + {ind: ops.IndexedSlicesValue(values, indices, dense_shape)}) + self.assertAllEqual(values_out, values) + self.assertAllEqual(indices_out, indices) + self.assertAllEqual(dense_shape_out, dense_shape) + # Feed with IndexedSlicesValue, fetch IndexedSlicesValue + ind2_out = s.run(ind2, {ind: ops.IndexedSlicesValue(values, indices, dense_shape)}) + self.assertAllEqual(ind2_out.values, values) + self.assertAllEqual(ind2_out.indices, indices) + self.assertAllEqual(ind2_out.dense_shape, dense_shape) + + def testFetchIndexedSlicesWithoutDenseShape(self): + with session.Session() as s: + indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) + values = np.array([1.0, 2.0]).astype(np.float32) + ind = ops.IndexedSlices( + constant_op.constant(values), + constant_op.constant(indices), + None) + # Single fetch, use as tuple + ind_out = s.run(ind) + values_out, indices_out = ind_out + self.assertAllEqual(values_out, values) + self.assertAllEqual(indices_out, indices) + # Single fetch, use as IndexedSlicesWithoutDenseShapeValue + ind_out = s.run(ind) + self.assertAllEqual(ind_out.values, values) + self.assertAllEqual(ind_out.indices, indices) + # Tuple fetch, use as tuple + values_out, indices_out = s.run(ind) + self.assertAllEqual(values_out, values) + self.assertAllEqual(indices_out, indices) + # List fetch, use as tuple + (values_out, indices_out), = s.run([ind]) + self.assertAllEqual(values_out, values) + self.assertAllEqual(indices_out, indices) + # List fetch, use as IndexedSlicesWithoutDenseShapeValue + ind_out, = s.run([ind]) + self.assertAllEqual(ind_out.values, values) + self.assertAllEqual(ind_out.indices, indices) + + def testFeedIndexedSlicesWithDenseShape(self): + with session.Session() as s: + values = np.array([1.0, 2.0]).astype(np.float32) + indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) + ind = ops.IndexedSlices( + array_ops.placeholder(dtype=np.float32, shape=(2,)), + array_ops.placeholder(dtype=np.int64, shape=(2, 3)), + None) + ind_values = array_ops.identity(ind.values) + ind_indices = array_ops.identity(ind.indices) + ind2 = ops.IndexedSlices(ind_values, ind_indices) + # Feed with tuple + values_out, indices_out = s.run( + [ind_values, ind_indices], {ind: (values, indices)}) + self.assertAllEqual(values_out, values) + self.assertAllEqual(indices_out, indices) + # Feed with IndexedSlicesWithoutDenseShapeValue + values_out, indices_out = s.run( + [ind_values, ind_indices], + {ind: ops.IndexedSlicesWithoutDenseShapeValue(values, indices)}) + self.assertAllEqual(values_out, values) + self.assertAllEqual(indices_out, indices) + # Feed with IndexedSlicesWithoutDenseShapeValue, fetch IndexedSlicesWithoutDenseShapeValue + ind2_out = s.run(ind2, {ind: ops.IndexedSlicesWithoutDenseShapeValue(values, indices)}) + self.assertAllEqual(ind2_out.values, values) + self.assertAllEqual(ind2_out.indices, indices) + def testExtendWithStatelessOperations(self): with session.Session() as s: a = constant_op.constant(1.0, shape=[1, 2]) From c42490367c403b1836338523ce81eb2f8ac05647 Mon Sep 17 00:00:00 2001 From: Kenton Lee Date: Wed, 6 Jan 2016 05:09:20 -0800 Subject: [PATCH 3/4] Removes the need for IndexedSlicesWithoutDenseShapeValue. Tests for feeding and fetching IndexedSlices without dense shapes now reuse IndexedSlicesValue. Change-Id: I84eb7289d262797275e6eaa08b5e5daf4d99fe3d --- tensorflow/python/client/session.py | 21 +++++++------- tensorflow/python/client/session_test.py | 36 +++++++++++------------- tensorflow/python/framework/ops.py | 2 -- 3 files changed, 27 insertions(+), 32 deletions(-) diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index f3c1136315806c..44d0372a56b7b1 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -52,6 +52,13 @@ def run(self, fetches, feed_dict=None): """Runs operations in the session. See `Session.run()` for details.""" raise NotImplementedError('Run') +def _get_indexed_slices_value_from_fetches(fetched_vals): + return ops.IndexedSlicesValue(fetched_vals[0], fetched_vals[1], fetched_vals[2] if len(fetched_vals) == 3 else None) + +def _get_feeds_for_indexed_slices(feed, feed_val): + return list(zip( + [feed.values, feed.indices] if feed.dense_shape is None + else [feed.values, feed.indices, feed.dense_shape], feed_val)) class BaseSession(SessionInterface): """A class for interacting with a TensorFlow computation. @@ -221,20 +228,14 @@ def as_default(self): lambda fetched_vals: ops.SparseTensorValue(*fetched_vals)), lambda feed, feed_val: list(zip( [feed.indices, feed.values, feed.shape], feed_val))), - # IndexedSlices are fetched as IndexedSlicesValues or - # IndexedSlicesWithoutDenseShapeValues. They can be fed - # IndexedSlicesValues or IndexedSlicesWithoutDenseShapeValues or normal - # tuples. + # IndexedSlices are fetched as IndexedSlicesValues. They can be fed + # IndexedSlicesValues or normal tuples. (ops.IndexedSlices, lambda fetch: ( [fetch.values, fetch.indices] if fetch.dense_shape is None else [fetch.values, fetch.indices, fetch.dense_shape], - lambda fetched_vals: - ops.IndexedSlicesValue(*fetched_vals) if len(fetched_vals) == 3 - else ops.IndexedSlicesWithoutDenseShapeValue(*fetched_vals)), - lambda feed, feed_val: list(zip( - [feed.values, feed.indices] if feed.dense_shape is None - else [feed.values, feed.indices, feed.dense_shape], feed_val))), + _get_indexed_slices_value_from_fetches), + _get_feeds_for_indexed_slices), # The default catches all types and performs no expansions. (object, lambda fetch: ([fetch], lambda fetched_vals: fetched_vals[0]), diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index e821a7d986e23a..1193b2080a2b0b 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -44,7 +44,6 @@ from tensorflow.python.platform import googletest from tensorflow.python.util import compat - # NOTE(mrry): Dummy shape registration for op used in the tests. ops.RegisterShape('ConstructionFails')(None) @@ -320,36 +319,43 @@ def testFetchIndexedSlicesWithoutDenseShape(self): with session.Session() as s: indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) values = np.array([1.0, 2.0]).astype(np.float32) + dense_shape = None ind = ops.IndexedSlices( constant_op.constant(values), constant_op.constant(indices), None) # Single fetch, use as tuple ind_out = s.run(ind) - values_out, indices_out = ind_out + values_out, indices_out, dense_shape_out = ind_out self.assertAllEqual(values_out, values) self.assertAllEqual(indices_out, indices) - # Single fetch, use as IndexedSlicesWithoutDenseShapeValue + self.assertAllEqual(dense_shape_out, dense_shape) + # Single fetch, use as IndexedSlicesValue ind_out = s.run(ind) self.assertAllEqual(ind_out.values, values) self.assertAllEqual(ind_out.indices, indices) + self.assertAllEqual(ind_out.dense_shape, dense_shape) # Tuple fetch, use as tuple - values_out, indices_out = s.run(ind) + values_out, indices_out, dense_shape_out = s.run(ind) self.assertAllEqual(values_out, values) self.assertAllEqual(indices_out, indices) + self.assertAllEqual(dense_shape_out, dense_shape) # List fetch, use as tuple - (values_out, indices_out), = s.run([ind]) + (values_out, indices_out, dense_shape_out), = s.run([ind]) self.assertAllEqual(values_out, values) self.assertAllEqual(indices_out, indices) - # List fetch, use as IndexedSlicesWithoutDenseShapeValue + self.assertAllEqual(dense_shape_out, dense_shape) + # List fetch, use as IndexedSlicesValue ind_out, = s.run([ind]) self.assertAllEqual(ind_out.values, values) self.assertAllEqual(ind_out.indices, indices) + self.assertAllEqual(ind_out.dense_shape, dense_shape) - def testFeedIndexedSlicesWithDenseShape(self): + def testFeedIndexedSlicesWithoutDenseShape(self): with session.Session() as s: values = np.array([1.0, 2.0]).astype(np.float32) indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) + dense_shape = None ind = ops.IndexedSlices( array_ops.placeholder(dtype=np.float32, shape=(2,)), array_ops.placeholder(dtype=np.int64, shape=(2, 3)), @@ -357,21 +363,11 @@ def testFeedIndexedSlicesWithDenseShape(self): ind_values = array_ops.identity(ind.values) ind_indices = array_ops.identity(ind.indices) ind2 = ops.IndexedSlices(ind_values, ind_indices) - # Feed with tuple - values_out, indices_out = s.run( - [ind_values, ind_indices], {ind: (values, indices)}) - self.assertAllEqual(values_out, values) - self.assertAllEqual(indices_out, indices) - # Feed with IndexedSlicesWithoutDenseShapeValue - values_out, indices_out = s.run( - [ind_values, ind_indices], - {ind: ops.IndexedSlicesWithoutDenseShapeValue(values, indices)}) - self.assertAllEqual(values_out, values) - self.assertAllEqual(indices_out, indices) - # Feed with IndexedSlicesWithoutDenseShapeValue, fetch IndexedSlicesWithoutDenseShapeValue - ind2_out = s.run(ind2, {ind: ops.IndexedSlicesWithoutDenseShapeValue(values, indices)}) + # Feed with IndexedSlicesValue, fetch IndexedSlicesValue + ind2_out = s.run(ind2, {ind: ops.IndexedSlicesValue(values, indices, dense_shape)}) self.assertAllEqual(ind2_out.values, values) self.assertAllEqual(ind2_out.indices, indices) + self.assertAllEqual(ind2_out.dense_shape, dense_shape) def testExtendWithStatelessOperations(self): with session.Session() as s: diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index e9303a2afaf466..91c3827219626d 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -782,8 +782,6 @@ def __str__(self): IndexedSlicesValue = collections.namedtuple("IndexedSlicesValue", ["values", "indices", "dense_shape"]) -IndexedSlicesWithoutDenseShapeValue = collections.namedtuple("IndexedSlicesWithoutDenseShapeValue", - ["values", "indices"]) class SparseTensor(object): """Represents a sparse tensor. From 709c9229151cb9a06296e55e9ab71adb9827371a Mon Sep 17 00:00:00 2001 From: Kenton Lee Date: Wed, 6 Jan 2016 05:20:53 -0800 Subject: [PATCH 4/4] Adds back the tests for fetching tuples from IndexedSlices without dense shapes. Change-Id: I95c9eb219b5c6395892f132ae5034a45460e9855 --- tensorflow/python/client/session_test.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index 1193b2080a2b0b..eb146b05e9dcbb 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -44,6 +44,7 @@ from tensorflow.python.platform import googletest from tensorflow.python.util import compat + # NOTE(mrry): Dummy shape registration for op used in the tests. ops.RegisterShape('ConstructionFails')(None) @@ -363,6 +364,17 @@ def testFeedIndexedSlicesWithoutDenseShape(self): ind_values = array_ops.identity(ind.values) ind_indices = array_ops.identity(ind.indices) ind2 = ops.IndexedSlices(ind_values, ind_indices) + # Feed with tuple + values_out, indices_out = s.run( + [ind_values, ind_indices], {ind: (values, indices)}) + self.assertAllEqual(values_out, values) + self.assertAllEqual(indices_out, indices) + # Feed with IndexedSlicesValue + values_out, indices_out = s.run( + [ind_values, ind_indices], + {ind: ops.IndexedSlicesValue(values, indices, dense_shape)}) + self.assertAllEqual(values_out, values) + self.assertAllEqual(indices_out, indices) # Feed with IndexedSlicesValue, fetch IndexedSlicesValue ind2_out = s.run(ind2, {ind: ops.IndexedSlicesValue(values, indices, dense_shape)}) self.assertAllEqual(ind2_out.values, values)