Skip to content

Commit

Permalink
Add int64 out_idx support for listdiff/list_diff/setdiff1d` (t…
Browse files Browse the repository at this point in the history
…ensorflow#13839)

* Add `int64` out_idx` support for `listdiff`/`list_diff`/`setdiff1d`

This fix tries to add `int64` `out_idx` support for `listdiff`/`list_diff`/`setdiff1d`.
As was specified in docs (`tf.setdiff1d.__doc__`), it is possible to specify
`tf.int32` or `tf.int64` for the type of the output idx. However,
the `tf.int64` kernel has not been registered. As a consequence,
an error will be thrown out if `tf.int64` is used.

This fix adds `int64` out_idx` support for `listdiff`/`list_diff`/`setdiff1d`

Signed-off-by: Yong Tang <[email protected]>

* Add template for signature matching of ListDiff kernel.

Signed-off-by: Yong Tang <[email protected]>

* Add test cases for `int64` out_idx support for `tf.listdiff`/`setdiff1d`

Signed-off-by: Yong Tang <[email protected]>

* Add test case for int32 (missed in the last commit)

Signed-off-by: Yong Tang <[email protected]>
  • Loading branch information
yongtang authored and Vijay Vasudevan committed Oct 20, 2017
1 parent 7a1ddf2 commit 513f7df
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 14 deletions.
16 changes: 11 additions & 5 deletions tensorflow/core/kernels/listdiff_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
template <typename T>
template <typename T, typename Tidx>
class ListDiffOp : public OpKernel {
public:
explicit ListDiffOp(OpKernelConstruction* context) : OpKernel(context) {
const DataType dt = DataTypeToEnum<T>::v();
OP_REQUIRES_OK(context, context->MatchSignature({dt, dt}, {dt, DT_INT32}));
const DataType dtidx = DataTypeToEnum<Tidx>::v();
OP_REQUIRES_OK(context, context->MatchSignature({dt, dt}, {dt, dtidx}));
}

void Compute(OpKernelContext* context) override {
Expand Down Expand Up @@ -72,9 +73,9 @@ class ListDiffOp : public OpKernel {

Tensor* indices = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(1, {out_size}, &indices));
auto Tindices = indices->vec<int32>();
auto Tindices = indices->vec<Tidx>();

for (int i = 0, p = 0; i < static_cast<int32>(x_size); ++i) {
for (Tidx i = 0, p = 0; i < static_cast<Tidx>(x_size); ++i) {
if (y_set.count(Tx(i)) == 0) {
OP_REQUIRES(context, p < out_size,
errors::InvalidArgument(
Expand All @@ -95,7 +96,12 @@ class ListDiffOp : public OpKernel {
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("out_idx"), \
ListDiffOp<type>)
ListDiffOp<type, int32>) \
REGISTER_KERNEL_BUILDER(Name("ListDiff") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("out_idx"), \
ListDiffOp<type, int64>)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_LISTDIFF);
REGISTER_LISTDIFF(string);
Expand Down
20 changes: 11 additions & 9 deletions tensorflow/python/kernel_tests/listdiff_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,17 @@ def _testListDiff(self, x, y, out, idx):
y = [compat.as_bytes(str(a)) for a in y]
out = [compat.as_bytes(str(a)) for a in out]
for diff_func in [array_ops.setdiff1d]:
with self.test_session() as sess:
x_tensor = ops.convert_to_tensor(x, dtype=dtype)
y_tensor = ops.convert_to_tensor(y, dtype=dtype)
out_tensor, idx_tensor = diff_func(x_tensor, y_tensor)
tf_out, tf_idx = sess.run([out_tensor, idx_tensor])
self.assertAllEqual(tf_out, out)
self.assertAllEqual(tf_idx, idx)
self.assertEqual(1, out_tensor.get_shape().ndims)
self.assertEqual(1, idx_tensor.get_shape().ndims)
for index_dtype in [dtypes.int32, dtypes.int64]:
with self.test_session() as sess:
x_tensor = ops.convert_to_tensor(x, dtype=dtype)
y_tensor = ops.convert_to_tensor(y, dtype=dtype)
out_tensor, idx_tensor = diff_func(x_tensor, y_tensor,
index_dtype=index_dtype)
tf_out, tf_idx = sess.run([out_tensor, idx_tensor])
self.assertAllEqual(tf_out, out)
self.assertAllEqual(tf_idx, idx)
self.assertEqual(1, out_tensor.get_shape().ndims)
self.assertEqual(1, idx_tensor.get_shape().ndims)

def testBasic1(self):
x = [1, 2, 3, 4]
Expand Down

0 comments on commit 513f7df

Please sign in to comment.