Skip to content

Commit

Permalink
Merge pull request tensorflow#541 from kentonl/master
Browse files Browse the repository at this point in the history
Allows IndexedSlices to be fed and fetched.
  • Loading branch information
mrry committed Jan 11, 2016
2 parents cfc73af + 8ddbadb commit 22b6b23
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 0 deletions.
15 changes: 15 additions & 0 deletions tensorflow/python/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ def run(self, fetches, feed_dict=None):
"""Runs operations in the session. See `Session.run()` for details."""
raise NotImplementedError('Run')

def _get_indexed_slices_value_from_fetches(fetched_vals):
return ops.IndexedSlicesValue(fetched_vals[0], fetched_vals[1], fetched_vals[2] if len(fetched_vals) == 3 else None)

def _get_feeds_for_indexed_slices(feed, feed_val):
return list(zip(
[feed.values, feed.indices] if feed.dense_shape is None
else [feed.values, feed.indices, feed.dense_shape], feed_val))

class BaseSession(SessionInterface):
"""A class for interacting with a TensorFlow computation.
Expand Down Expand Up @@ -221,6 +228,14 @@ def as_default(self):
lambda fetched_vals: ops.SparseTensorValue(*fetched_vals)),
lambda feed, feed_val: list(zip(
[feed.indices, feed.values, feed.shape], feed_val))),
# IndexedSlices are fetched as IndexedSlicesValues. They can be fed
# IndexedSlicesValues or normal tuples.
(ops.IndexedSlices,
lambda fetch: (
[fetch.values, fetch.indices] if fetch.dense_shape is None
else [fetch.values, fetch.indices, fetch.dense_shape],
_get_indexed_slices_value_from_fetches),
_get_feeds_for_indexed_slices),
# The default catches all types and performs no expansions.
(object,
lambda fetch: ([fetch], lambda fetched_vals: fetched_vals[0]),
Expand Down
133 changes: 133 additions & 0 deletions tensorflow/python/client/session_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,139 @@ def testFeedSparseTensor(self):
self.assertAllEqual(sp2_out.values, values)
self.assertAllEqual(sp2_out.shape, shape)

def testFetchIndexedSlices(self):
with session.Session() as s:
indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
values = np.array([1.0, 2.0]).astype(np.float32)
dense_shape = np.array([7, 9, 2]).astype(np.int64)
ind = ops.IndexedSlices(
constant_op.constant(values),
constant_op.constant(indices),
constant_op.constant(dense_shape))
# Single fetch, use as tuple
ind_out = s.run(ind)
values_out, indices_out, dense_shape_out = ind_out
self.assertAllEqual(values_out, values)
self.assertAllEqual(indices_out, indices)
self.assertAllEqual(dense_shape_out, dense_shape)
# Single fetch, use as IndexedSlicesValue
ind_out = s.run(ind)
self.assertAllEqual(ind_out.values, values)
self.assertAllEqual(ind_out.indices, indices)
self.assertAllEqual(ind_out.dense_shape, dense_shape)
# Tuple fetch, use as tuple
values_out, indices_out, dense_shape_out = s.run(ind)
self.assertAllEqual(values_out, values)
self.assertAllEqual(indices_out, indices)
self.assertAllEqual(dense_shape_out, dense_shape)
# List fetch, use as tuple
(values_out, indices_out, dense_shape_out), = s.run([ind])
self.assertAllEqual(values_out, values)
self.assertAllEqual(indices_out, indices)
self.assertAllEqual(dense_shape_out, dense_shape)
# List fetch, use as IndexedSlicesValue
ind_out, = s.run([ind])
self.assertAllEqual(ind_out.values, values)
self.assertAllEqual(ind_out.indices, indices)
self.assertAllEqual(ind_out.dense_shape, dense_shape)

def testFeedIndexedSlices(self):
with session.Session() as s:
values = np.array([1.0, 2.0]).astype(np.float32)
indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
dense_shape = np.array([7, 9, 2]).astype(np.int64)
ind = ops.IndexedSlices(
array_ops.placeholder(dtype=np.float32, shape=(2,)),
array_ops.placeholder(dtype=np.int64, shape=(2, 3)),
array_ops.placeholder(dtype=np.int64, shape=(3,)),)
ind_values = array_ops.identity(ind.values)
ind_indices = array_ops.identity(ind.indices)
ind_dense_shape = array_ops.identity(ind.dense_shape)
ind2 = ops.IndexedSlices(ind_values, ind_indices, ind_dense_shape)
# Feed with tuple
values_out, indices_out, dense_shape_out = s.run(
[ind_values, ind_indices, ind_dense_shape], {ind: (values, indices, dense_shape)})
self.assertAllEqual(values_out, values)
self.assertAllEqual(indices_out, indices)
self.assertAllEqual(dense_shape_out, dense_shape)
# Feed with IndexedSlicesValue
values_out, indices_out, dense_shape_out = s.run(
[ind_values, ind_indices, ind_dense_shape],
{ind: ops.IndexedSlicesValue(values, indices, dense_shape)})
self.assertAllEqual(values_out, values)
self.assertAllEqual(indices_out, indices)
self.assertAllEqual(dense_shape_out, dense_shape)
# Feed with IndexedSlicesValue, fetch IndexedSlicesValue
ind2_out = s.run(ind2, {ind: ops.IndexedSlicesValue(values, indices, dense_shape)})
self.assertAllEqual(ind2_out.values, values)
self.assertAllEqual(ind2_out.indices, indices)
self.assertAllEqual(ind2_out.dense_shape, dense_shape)

def testFetchIndexedSlicesWithoutDenseShape(self):
with session.Session() as s:
indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
values = np.array([1.0, 2.0]).astype(np.float32)
dense_shape = None
ind = ops.IndexedSlices(
constant_op.constant(values),
constant_op.constant(indices),
None)
# Single fetch, use as tuple
ind_out = s.run(ind)
values_out, indices_out, dense_shape_out = ind_out
self.assertAllEqual(values_out, values)
self.assertAllEqual(indices_out, indices)
self.assertAllEqual(dense_shape_out, dense_shape)
# Single fetch, use as IndexedSlicesValue
ind_out = s.run(ind)
self.assertAllEqual(ind_out.values, values)
self.assertAllEqual(ind_out.indices, indices)
self.assertAllEqual(ind_out.dense_shape, dense_shape)
# Tuple fetch, use as tuple
values_out, indices_out, dense_shape_out = s.run(ind)
self.assertAllEqual(values_out, values)
self.assertAllEqual(indices_out, indices)
self.assertAllEqual(dense_shape_out, dense_shape)
# List fetch, use as tuple
(values_out, indices_out, dense_shape_out), = s.run([ind])
self.assertAllEqual(values_out, values)
self.assertAllEqual(indices_out, indices)
self.assertAllEqual(dense_shape_out, dense_shape)
# List fetch, use as IndexedSlicesValue
ind_out, = s.run([ind])
self.assertAllEqual(ind_out.values, values)
self.assertAllEqual(ind_out.indices, indices)
self.assertAllEqual(ind_out.dense_shape, dense_shape)

def testFeedIndexedSlicesWithoutDenseShape(self):
with session.Session() as s:
values = np.array([1.0, 2.0]).astype(np.float32)
indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
dense_shape = None
ind = ops.IndexedSlices(
array_ops.placeholder(dtype=np.float32, shape=(2,)),
array_ops.placeholder(dtype=np.int64, shape=(2, 3)),
None)
ind_values = array_ops.identity(ind.values)
ind_indices = array_ops.identity(ind.indices)
ind2 = ops.IndexedSlices(ind_values, ind_indices)
# Feed with tuple
values_out, indices_out = s.run(
[ind_values, ind_indices], {ind: (values, indices)})
self.assertAllEqual(values_out, values)
self.assertAllEqual(indices_out, indices)
# Feed with IndexedSlicesValue
values_out, indices_out = s.run(
[ind_values, ind_indices],
{ind: ops.IndexedSlicesValue(values, indices, dense_shape)})
self.assertAllEqual(values_out, values)
self.assertAllEqual(indices_out, indices)
# Feed with IndexedSlicesValue, fetch IndexedSlicesValue
ind2_out = s.run(ind2, {ind: ops.IndexedSlicesValue(values, indices, dense_shape)})
self.assertAllEqual(ind2_out.values, values)
self.assertAllEqual(ind2_out.indices, indices)
self.assertAllEqual(ind2_out.dense_shape, dense_shape)

def testExtendWithStatelessOperations(self):
with session.Session() as s:
a = constant_op.constant(1.0, shape=[1, 2])
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/python/framework/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,8 @@ def __str__(self):
self._indices, self._values,
(", dense_shape=%s" % self._dense_shape) if self._dense_shape else "")

IndexedSlicesValue = collections.namedtuple("IndexedSlicesValue",
["values", "indices", "dense_shape"])

class SparseTensor(object):
"""Represents a sparse tensor.
Expand Down

0 comments on commit 22b6b23

Please sign in to comment.