Skip to content

Commit

Permalink
Merge pull request tensorflow#5127 from Mistobaan/feature/add-tf-sele…
Browse files Browse the repository at this point in the history
…ct-scalar-3945

extend tf.select to broadcast a scalar condition tensorflow#3945
  • Loading branch information
benoitsteiner authored Oct 22, 2016
2 parents 41ba1e0 + ec6e818 commit c9c7e23
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 17 deletions.
10 changes: 10 additions & 0 deletions tensorflow/core/framework/tensor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,16 @@ TEST(Tensor_Float, Reshape) {
}

TEST(Tensor_Scalar, Basics) {
{
Tensor t(DT_BOOL, TensorShape({}));
EXPECT_EQ(1, t.NumElements());
auto Tt = t.scalar<bool>();
EXPECT_EQ(1, Tt.size());
EXPECT_EQ(0, Tt.rank());
EXPECT_FALSE(Tt());
t.scalar<bool>()() = true;
EXPECT_TRUE(Tt());
}
{
Tensor t(DT_FLOAT, TensorShape({}));
EXPECT_EQ(1, t.NumElements());
Expand Down
20 changes: 20 additions & 0 deletions tensorflow/core/kernels/cwise_op_gpu_select.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#if GOOGLE_CUDA

#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

namespace tensorflow {
namespace functor {
Expand All @@ -31,6 +32,24 @@ struct SelectFunctor<GPUDevice, T> {
}
};

template <typename T>
struct SelectScalarFunctor<GPUDevice, T> {
void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
typename TTypes<bool>::ConstScalar cond,
typename TTypes<T>::ConstFlat then_flat,
typename TTypes<T>::ConstFlat else_flat) {

Eigen::IndexList<Eigen::type2index<1>> rank1;
const int size = then_flat.dimension(0);
Eigen::array<int, 1> broadcast_dims{size};

To32Bit(out).device(d) = cond.reshape(rank1)
.broadcast(broadcast_dims)
.select(then_flat, else_flat);

}
};

template <typename T>
struct BatchSelectFunctor<GPUDevice, T> {
void operator()(const GPUDevice& d,
Expand Down Expand Up @@ -68,6 +87,7 @@ struct BatchSelectFunctor<GPUDevice, T> {

#define SELECT_FUNCTOR(T) \
template struct SelectFunctor<GPUDevice, T>; \
template struct SelectScalarFunctor<GPUDevice, T>; \
template struct BatchSelectFunctor<GPUDevice, T>;

SELECT_FUNCTOR(Eigen::half);
Expand Down
35 changes: 35 additions & 0 deletions tensorflow/core/kernels/cwise_op_select.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ class SelectOp : public OpKernel {
OP_REQUIRES_OK(ctx, ctx->input("t", &then));
OP_REQUIRES_OK(ctx, ctx->input("e", &else_));

if (TensorShapeUtils::IsScalar(cond->shape())){
ComputeScalar(ctx, cond, then, else_);
return;
}

bool broadcasting = (TensorShapeUtils::IsVector(cond->shape()) &&
!TensorShapeUtils::IsVector(then->shape()));

Expand Down Expand Up @@ -108,6 +113,25 @@ class SelectOp : public OpKernel {
}
}

void ComputeScalar(OpKernelContext* ctx, const Tensor* cond,
const Tensor* then, const Tensor* else_) {
OP_REQUIRES(
ctx, then->shape().IsSameSize(else_->shape()),
errors::InvalidArgument(
"'then' and 'else' must have the same size. but received: ",
then->shape().DebugString(), " vs. ",
else_->shape().DebugString()));

Tensor* output = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, then->shape(), &output));

if (output->NumElements() > 0) {
functor::SelectScalarFunctor<Device, T> func;
TTypes<bool>::ConstScalar cond_scalar = cond->scalar<bool>();
func(ctx->eigen_device<Device>(), output->flat<T>(), cond_scalar,
then->flat<T>(), else_->flat<T>());
}
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(SelectOp);
};
Expand Down Expand Up @@ -152,6 +176,17 @@ struct SelectFunctor<CPUDevice, T> {
}
};

// CPU Specializations of Select functors with scalar
template <typename T>
struct SelectScalarFunctor<CPUDevice, T> {
void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
TTypes<bool>::ConstScalar cond,
typename TTypes<T>::ConstFlat then_flat,
typename TTypes<T>::ConstFlat else_flat) {
out.device(d) = cond() ? then_flat : else_flat;
}
};

template <typename T>
struct BatchSelectFunctor<CPUDevice, T> {
void operator()(const CPUDevice& d,
Expand Down
8 changes: 8 additions & 0 deletions tensorflow/core/kernels/cwise_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,14 @@ struct SelectFunctor {
typename TTypes<T>::ConstFlat else_flat);
};

template <typename Device, typename T>
struct SelectScalarFunctor {
void operator()(const Device& d, typename TTypes<T>::Flat out,
typename TTypes<bool>::ConstScalar cond,
typename TTypes<T>::ConstFlat then_flat,
typename TTypes<T>::ConstFlat else_flat);
};

template <typename Device, typename T>
struct BatchSelectFunctor {
void operator()(const Device& d,
Expand Down
50 changes: 35 additions & 15 deletions tensorflow/core/ops/math_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -913,7 +913,8 @@ REGISTER_OP("Select")
.SetShapeFn([](InferenceContext* c) {
// The inputs 'then' and 'else' must have the same shape.
ShapeHandle data = c->input(1);
TF_RETURN_IF_ERROR(c->Merge(data, c->input(2), &data));
ShapeHandle other = c->input(2);
TF_RETURN_IF_ERROR(c->Merge(data, other, &data));

// The input 'cond' must either have the same shape as 'then' and
// 'else', or be a vector if 'then' and 'else' are at least vectors.
Expand All @@ -929,30 +930,49 @@ REGISTER_OP("Select")
const int32 cond_rank = c->Rank(cond);
const int32 data_rank = c->Rank(data);

if (cond_rank == 0){
// The rank of 'cond' is a scalar.
// t and e can have any shape.
c->set_output(0, data);
return Status::OK();
}

if (cond_rank != 1) {
// If the rank of 'cond' is != 1, the shape must match 'then' and 'else'
// If 'cond' is not a vector, and not a scalar,
// then shape must match 'then' and 'else'
TF_RETURN_IF_ERROR(c->Merge(data, cond, &data));
c->set_output(0, data);
return Status::OK();
}
if (data_rank != 0) {
// If then and else are not scalars, then cond must be at least
// a vector, and its first value must match that of 'else'
TF_RETURN_IF_ERROR(c->WithRankAtLeast(cond, 1, &cond));
if (cond_rank == 1) {
TF_RETURN_IF_ERROR(c->Merge(cond, c->Vector(c->Dim(data, 0)), &cond));
}

if (data_rank == 0) {
// if 'then' and 'else' are scalar also the cond must be
TF_RETURN_IF_ERROR(c->Merge(data, cond, &data));
c->set_output(0, data);
return Status::OK();
}

if (cond_rank == 1) {
// if the cond is a vector and the 'then' is not a scalar,
// the first dimension of 'then' and 'else'
TF_RETURN_IF_ERROR(c->Merge(cond, c->Vector(c->Dim(data, 0)), &cond));
c->set_output(0, data);
return Status::OK();
}

c->set_output(0, data);
return Status::OK();
})
})
.Doc(R"doc(
Selects elements from `t` or `e`, depending on `condition`.
The `t`, and `e` tensors must all have the same shape,
and the output will also have that shape. The `condition` tensor
must be a scalar if `t` and `e` are scalars. If `t` and `e` are vectors
or higher rank, then `condition` must be either a vector with size
matching the first dimension of `t`, or must have the same shape as `t`.
The `t`, and `e` tensors must all have the same shape, and the
output will also have that shape.
The `condition` tensor must be a scalar if `t` and `e` are scalars.
If `t` and `e` are vectors or higher rank, then `condition` must be either a
scalar, a vector with size matching the first dimension of `t`, or must have
the same shape as `t`.
The `condition` tensor acts as a mask that chooses, based on the value at each
element, whether the corresponding element / row in the output should be
Expand Down
7 changes: 5 additions & 2 deletions tensorflow/core/ops/math_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,10 @@ TEST(MathOpsTest, Select_ShapeFn) {
ShapeInferenceTestOp op("Select");
INFER_OK(op, "?;?;?", "in1|in2");

// scalar case
INFER_OK(op, "[];[1];?", "in1");
INFER_OK(op, "[];?;?", "in1|in2");

INFER_OK(op, "[1];?;?",
"in1|in2"); // When cond is vector, t/e may not match it.
INFER_OK(op, "[1,2];?;?", "in1|in2?");
Expand All @@ -200,8 +203,8 @@ TEST(MathOpsTest, Select_ShapeFn) {
INFER_OK(op, "?;[1,2];?", "in1");
INFER_OK(op, "?;?;[1,2]", "in2");

INFER_OK(op, "[1];[];?", "in1");
INFER_ERROR("Shapes must be equal rank, but are 1 and 0", op, "[];[1];?");
INFER_ERROR("Shapes must be equal rank, but are 0 and 1", op, "[1];[];?");
INFER_ERROR("Shapes must be equal rank, but are 1 and 2", op, "[];[1];[1,2]");
INFER_ERROR("Shapes must be equal rank, but are 1 and 2", op, "[1,2];[1];?");
INFER_OK(op, "[2];[?];[?]", "in1|in2");

Expand Down
12 changes: 12 additions & 0 deletions tensorflow/python/kernel_tests/cwise_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1356,6 +1356,18 @@ def _compareGradientY(self, c, x, y, numeric_gradient_type=None):
elif x.dtype == np.float64:
self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)

def testScalar(self):
c = True
x = np.random.rand(1, 3, 2) * 100
y = np.random.rand(1, 3, 2) * 100
for t in [np.float16, np.float32, np.float64, np.int32, np.int64,
np.complex64, np.complex128]:
xt = x.astype(t)
yt = y.astype(t)
self._compare(c, xt, yt, use_gpu=False)
if t in [np.float16, np.float32, np.float64]:
self._compare(c, xt, yt, use_gpu=True)

def testBasic(self):
c = np.random.randint(0, 2, 6).astype(np.bool).reshape(1, 3, 2)
x = np.random.rand(1, 3, 2) * 100
Expand Down

0 comments on commit c9c7e23

Please sign in to comment.