Commit 5cfb0fd

Add ConcatV2 operation which is the same as Concat but with argument order
swapped.
Change: 136763054

tensorflower-gardener committed Oct 20, 2016
1 parent 45e11c3 commit 5cfb0fd

Showing 9 changed files with 241 additions and 51 deletions.
36 changes: 23 additions & 13 deletions tensorflow/core/framework/common_shape_fns.cc
@@ -718,18 +718,22 @@ Status ReductionShapeForReduceJoin(InferenceContext* c) {
   return Status::OK();
 }
 
-Status ConcatShape(InferenceContext* c) {
-  ShapeHandle unused;
-  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+Status ConcatShapeHelper(InferenceContext* c, bool dim_is_last_argument) {
+  const int dim_index = dim_is_last_argument ? c->num_inputs() - 1 : 0;
+  const int start_value_index = dim_is_last_argument ? 0 : 1;
+  const int end_value_index =
+      dim_is_last_argument ? c->num_inputs() - 1 : c->num_inputs();
 
-  const Tensor* concat_dim_t = c->input_tensor(0);
+  ShapeHandle unused;
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(dim_index), 0, &unused));
+  const Tensor* concat_dim_t = c->input_tensor(dim_index);
   if (concat_dim_t == nullptr) {
     // Return an unknown shape with same rank as inputs, or an unknown rank
     // if no input's rank is known.
 
     // Find rank.
     int32 rank = InferenceContext::kUnknownRank;
-    for (int i = 1; i < c->num_inputs(); ++i) {
+    for (int i = start_value_index; i < end_value_index; ++i) {
       if (rank == InferenceContext::kUnknownRank) rank = c->Rank(c->input(i));
       if (rank != InferenceContext::kUnknownRank) {
         TF_RETURN_IF_ERROR(c->WithRank(c->input(i), rank, &unused));
@@ -753,25 +757,23 @@ Status ConcatShape(InferenceContext* c) {
   // Merge all the non-concat dims, and sum the concat dim to make an output
   // shape.
   const int32 concat_dim = concat_dim_t->scalar<int32>()();
-  if (concat_dim < 0) {
-    return errors::InvalidArgument("Expected concat_dim >= 0, but got ",
-                                   concat_dim);
-  }
+  // Minimum required number of dimensions.
+  const int min_rank = concat_dim < 0 ? -concat_dim : concat_dim + 1;
 
   ShapeHandle output_before;
   ShapeHandle output_after;
 
-  ShapeHandle input = c->input(c->num_inputs() - 1);
-  TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, concat_dim + 1, &input));
+  ShapeHandle input = c->input(end_value_index - 1);
+  TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input));
   TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &output_before));
   DimensionHandle output_middle = c->Dim(input, concat_dim);
   TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &output_after));
 
-  for (int i = c->num_inputs() - 2; i > 0; --i) {
+  for (int i = end_value_index - 2; i >= start_value_index; --i) {
     ShapeHandle before;
     ShapeHandle after;
     input = c->input(i);
-    TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, concat_dim + 1, &input));
+    TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input));
     TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &before));
     DimensionHandle middle = c->Dim(input, concat_dim);
     TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &after));
@@ -789,6 +791,14 @@ Status ConcatShape(InferenceContext* c) {
   return Status::OK();
 }
 
+Status ConcatShape(InferenceContext* c) {
+  return ConcatShapeHelper(c, /* dim_is_last_argument */ false);
+}
+
+Status ConcatV2Shape(InferenceContext* c) {
+  return ConcatShapeHelper(c, /* dim_is_last_argument */ true);
+}
+
 Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
   ShapeHandle shape_x = c->input(0);
   ShapeHandle shape_y = c->input(1);
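Editor's note: to make the refactoring above concrete, here is a minimal,
standalone C++ sketch (not TensorFlow code; ConcatInputLayout, MakeLayout,
and MinRank are invented names) of the two decisions ConcatShapeHelper now
makes: where the axis input sits relative to the value inputs for each op,
and the minimum rank a value tensor must have once negative axes are allowed.

#include <cassert>

// Hypothetical stand-ins for the index arithmetic in ConcatShapeHelper.
struct ConcatInputLayout {
  int dim_index;          // which input holds the scalar axis
  int start_value_index;  // first value input
  int end_value_index;    // one past the last value input
};

ConcatInputLayout MakeLayout(int num_inputs, bool dim_is_last_argument) {
  if (dim_is_last_argument) {
    // ConcatV2: values_0, ..., values_{n-1}, axis
    return {num_inputs - 1, 0, num_inputs - 1};
  }
  // Concat: concat_dim, values_0, ..., values_{n-1}
  return {0, 1, num_inputs};
}

// An axis of -k only requires rank >= k, while a non-negative axis k
// requires rank >= k + 1. This replaces the old hard error on negative axes.
int MinRank(int concat_dim) {
  return concat_dim < 0 ? -concat_dim : concat_dim + 1;
}

int main() {
  const ConcatInputLayout v1 = MakeLayout(4, /*dim_is_last_argument=*/false);
  assert(v1.dim_index == 0 && v1.start_value_index == 1 &&
         v1.end_value_index == 4);
  const ConcatInputLayout v2 = MakeLayout(4, /*dim_is_last_argument=*/true);
  assert(v2.dim_index == 3 && v2.start_value_index == 0 &&
         v2.end_value_index == 3);
  assert(MinRank(1) == 2 && MinRank(-2) == 2);
  return 0;
}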
3 changes: 3 additions & 0 deletions tensorflow/core/framework/common_shape_fns.h
@@ -185,6 +185,9 @@ Status ReductionShapeForReduceJoin(shape_inference::InferenceContext* c);
 // Shape function for concat operations.
 Status ConcatShape(shape_inference::InferenceContext* c);
 
+// Shape function for concat operations.
+Status ConcatV2Shape(shape_inference::InferenceContext* c);
+
 // Shape function for binary operators that broadcast their inputs.
 // Tested by ops/math_ops_test.cc.
 Status BroadcastBinaryOpShapeFn(InferenceContext* c);
94 changes: 63 additions & 31 deletions tensorflow/core/kernels/concat_op.cc
@@ -36,37 +36,44 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
 #endif  // GOOGLE_CUDA
 
+enum AxisArgumentName { NAME_IS_AXIS, NAME_IS_CONCAT_DIM };
+
 // --------------------------------------------------------------------------
-template <typename Device, typename T>
-class ConcatOp : public OpKernel {
+template <typename Device, typename T, AxisArgumentName AxisArgName>
+class ConcatBaseOp : public OpKernel {
  public:
   typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
       ConstMatrixVector;
 
-  explicit ConcatOp(OpKernelConstruction* c) : OpKernel(c) {}
+  explicit ConcatBaseOp(OpKernelConstruction* c) : OpKernel(c) {}
 
   void Compute(OpKernelContext* c) override {
     const Tensor* concat_dim_tensor;
-    OP_REQUIRES_OK(c, c->input("concat_dim", &concat_dim_tensor));
-    OP_REQUIRES(
-        c, IsLegacyScalar(concat_dim_tensor->shape()),
-        errors::InvalidArgument(
-            "Concat dim tensor should be a scalar integer, but got shape ",
-            concat_dim_tensor->shape().DebugString()));
+    const char* axis_attribute_name =
+        AxisArgName == NAME_IS_AXIS
+            ? "axis"
+            : AxisArgName == NAME_IS_CONCAT_DIM ? "concat_dim" : "<invalid>";
+    OP_REQUIRES_OK(c, c->input(axis_attribute_name, &concat_dim_tensor));
+    OP_REQUIRES(c, IsLegacyScalar(concat_dim_tensor->shape()),
+                errors::InvalidArgument(
+                    axis_attribute_name,
+                    " tensor should be a scalar integer, but got shape ",
+                    concat_dim_tensor->shape().DebugString()));
     const int32 concat_dim =
         internal::SubtleMustCopy(concat_dim_tensor->scalar<int32>()());
     OpInputList values;
     OP_REQUIRES_OK(c, c->input_list("values", &values));
     const int N = values.size();
     const int input_dims = values[0].dims();
     const TensorShape& input_shape = values[0].shape();
-    OP_REQUIRES(
-        c, FastBoundsCheck(concat_dim, input_dims) ||
-               (allow_legacy_scalars() && concat_dim == 0),
-        errors::InvalidArgument(
-            "ConcatOp : Expected concatenating dimensions in the range [", 0,
-            ", ", input_dims, "), but got ", concat_dim));
 
+    int axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim;
+    OP_REQUIRES(c, (0 <= axis && axis < input_dims) ||
+                       (allow_legacy_scalars() && concat_dim == 0),
+                errors::InvalidArgument(
+                    "ConcatOp : Expected concatenating dimensions in the range "
+                    "[",
+                    -input_dims, ", ", input_dims, "), but got ", concat_dim));
     // Note that we reduce the concat of n-dimensional tensors into a two
     // dimensional concat. Assuming the dimensions of any input/output
     // tensor are {x0, x1,...,xn-1, y0, y1,...,ym-1}, where the concat is along
@@ -75,7 +82,7 @@ class ConcatOp : public OpKernel {
     ConstMatrixVector inputs_flat;
     inputs_flat.reserve(N);
     int64 inputs_flat_dim0 = 1;
-    for (int d = 0; d < concat_dim; ++d) {
+    for (int d = 0; d < axis; ++d) {
       inputs_flat_dim0 *= input_shape.dim_size(d);
     }
     int64 output_concat_dim = 0;
@@ -90,7 +97,7 @@ class ConcatOp : public OpKernel {
             input_shape.DebugString(), " vs. shape[", i, "] = ",
             in.shape().DebugString()));
     for (int j = 0; j < input_dims; ++j) {
-      if (j == concat_dim) {
+      if (j == axis) {
        continue;
       }
       OP_REQUIRES(
@@ -106,15 +113,15 @@ class ConcatOp : public OpKernel {
             in.shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1})));
       }
       // TODO(irving): Remove check once !allow_legacy_scalars().
-      output_concat_dim += in.dims() > 0 ? in.dim_size(concat_dim) : 1;
+      output_concat_dim += in.dims() > 0 ? in.dim_size(axis) : 1;
     }
 
     TensorShape output_shape(input_shape);
     // TODO(irving): Remove rank 0 case once !allow_legacy_scalars().
     if (output_shape.dims() == 0) {
       output_shape.AddDim(output_concat_dim);
     } else {
-      output_shape.set_dim(concat_dim, output_concat_dim);
+      output_shape.set_dim(axis, output_concat_dim);
     }
     Tensor* output = nullptr;
     OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
@@ -132,12 +139,23 @@ class ConcatOp : public OpKernel {
   }
 };
 
-#define REGISTER_CONCAT(type)                            \
-  REGISTER_KERNEL_BUILDER(Name("Concat")                 \
-                              .Device(DEVICE_CPU)        \
-                              .TypeConstraint<type>("T") \
-                              .HostMemory("concat_dim"), \
-                          ConcatOp<CPUDevice, type>)
+template <typename Device, typename T>
+using ConcatOp = ConcatBaseOp<Device, T, NAME_IS_CONCAT_DIM>;
+template <typename Device, typename T>
+using ConcatV2Op = ConcatBaseOp<Device, T, NAME_IS_AXIS>;
+
+#define REGISTER_CONCAT(type)                                \
+  REGISTER_KERNEL_BUILDER(Name("Concat")                     \
+                              .Device(DEVICE_CPU)            \
+                              .TypeConstraint<type>("T")     \
+                              .HostMemory("concat_dim"),     \
+                          ConcatOp<CPUDevice, type>)         \
+  REGISTER_KERNEL_BUILDER(Name("ConcatV2")                   \
+                              .Device(DEVICE_CPU)            \
+                              .TypeConstraint<type>("T")     \
+                              .TypeConstraint<int32>("Tidx") \
+                              .HostMemory("axis"),           \
+                          ConcatV2Op<CPUDevice, type>)
 
 TF_CALL_ALL_TYPES(REGISTER_CONCAT);
 REGISTER_CONCAT(quint8);
@@ -151,12 +169,18 @@ REGISTER_CONCAT(bfloat16);
 
 #if GOOGLE_CUDA
 
-#define REGISTER_GPU(type)                               \
-  REGISTER_KERNEL_BUILDER(Name("Concat")                 \
-                              .Device(DEVICE_GPU)        \
-                              .TypeConstraint<type>("T") \
-                              .HostMemory("concat_dim"), \
-                          ConcatOp<GPUDevice, type>)
+#define REGISTER_GPU(type)                                   \
+  REGISTER_KERNEL_BUILDER(Name("Concat")                     \
+                              .Device(DEVICE_GPU)            \
+                              .TypeConstraint<type>("T")     \
+                              .HostMemory("concat_dim"),     \
+                          ConcatOp<GPUDevice, type>)         \
+  REGISTER_KERNEL_BUILDER(Name("ConcatV2")                   \
+                              .Device(DEVICE_GPU)            \
+                              .TypeConstraint<type>("T")     \
+                              .TypeConstraint<int32>("Tidx") \
+                              .HostMemory("axis"),           \
+                          ConcatV2Op<GPUDevice, type>)
 
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
 REGISTER_GPU(bfloat16);
@@ -172,6 +196,14 @@ REGISTER_KERNEL_BUILDER(Name("Concat")
                             .HostMemory("values")
                             .HostMemory("output"),
                         ConcatOp<CPUDevice, int32>);
+REGISTER_KERNEL_BUILDER(Name("ConcatV2")
+                            .Device(DEVICE_GPU)
+                            .TypeConstraint<int32>("T")
+                            .TypeConstraint<int32>("Tidx")
+                            .HostMemory("values")
+                            .HostMemory("axis")
+                            .HostMemory("output"),
+                        ConcatV2Op<CPUDevice, int32>);
 
 #endif  // GOOGLE_CUDA
 
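Editor's note: the comment inside ConcatBaseOp::Compute describes reducing an
n-dimensional concat to a two-dimensional one. The following standalone C++
sketch (illustrative only; NormalizeAxis and FlattenForConcat are invented
names, not TensorFlow functions) mirrors that step plus the new negative-axis
normalization:

#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

// Normalize a possibly negative axis into [0, input_dims).
int NormalizeAxis(int concat_dim, int input_dims) {
  return concat_dim < 0 ? concat_dim + input_dims : concat_dim;
}

// Collapse an n-D shape into {prod(dims before axis), prod(dims from axis
// onward)}, the 2-D view the kernel concatenates along the second dimension.
std::pair<int64_t, int64_t> FlattenForConcat(const std::vector<int64_t>& shape,
                                             int axis) {
  int64_t dim0 = 1;
  int64_t dim1 = 1;
  for (int d = 0; d < axis; ++d) dim0 *= shape[d];
  for (int d = axis; d < static_cast<int>(shape.size()); ++d) dim1 *= shape[d];
  return {dim0, dim1};
}

int main() {
  // A rank-3 input of shape [2, 3, 4], concatenated along axis -2 (== 1).
  const int axis = NormalizeAxis(-2, 3);
  const auto flat = FlattenForConcat({2, 3, 4}, axis);
  std::printf("axis=%d, flattened to %lld x %lld\n", axis,
              static_cast<long long>(flat.first),
              static_cast<long long>(flat.second));  // axis=1, 2 x 12
  return 0;
}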
20 changes: 20 additions & 0 deletions tensorflow/core/ops/array_ops.cc
@@ -307,6 +307,26 @@ output: A `Tensor` with the concatenation of values stacked along the
   in `concat_dim` where it has the sum of the sizes.
 )doc");
 
+REGISTER_OP("ConcatV2")
+    .Input("values: N * T")
+    .Input("axis: Tidx")
+    .Output("output: T")
+    .Attr("N: int >= 2")
+    .Attr("T: type")
+    .Attr("Tidx: {int32, int64} = DT_INT32")
+    .SetShapeFn(shape_inference::ConcatV2Shape)
+    .Doc(R"doc(
+Concatenates tensors along one dimension.
+
+values: List of `N` Tensors to concatenate. Their ranks and types must match,
+  and their sizes must match in all dimensions except `axis`.
+axis: 0-D. The dimension along which to concatenate. Must be in the
+  range [-rank(values), rank(values)).
+output: A `Tensor` with the concatenation of values stacked along the
+  `axis` dimension. This tensor's shape matches that of `values` except
+  in `axis` where it has the sum of the sizes.
+)doc");
+
 REGISTER_OP("ConcatOffset")
     .Input("concat_dim: int32")
     .Input("shape: N * int32")
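Editor's note: the doc block above states the output-shape rule in prose. The
standalone sketch below (illustrative only; ConcatOutputShape is an invented
name) computes it for fully known input shapes: every dimension is carried
over unchanged except `axis`, where the sizes add.

#include <cstdio>
#include <vector>

// Output shape of a concat over fully known input shapes. All inputs are
// assumed to agree on every dimension except `axis`.
std::vector<long long> ConcatOutputShape(
    const std::vector<std::vector<long long>>& shapes, int axis) {
  std::vector<long long> out = shapes[0];
  out[axis] = 0;
  for (const std::vector<long long>& s : shapes) out[axis] += s[axis];
  return out;
}

int main() {
  // Two inputs of shapes [2, 3] and [2, 5], concatenated along axis 1.
  const auto out = ConcatOutputShape({{2, 3}, {2, 5}}, 1);
  std::printf("[%lld, %lld]\n", out[0], out[1]);  // prints [2, 8]
  return 0;
}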
86 changes: 82 additions & 4 deletions tensorflow/core/ops/array_ops_test.cc
@@ -538,10 +538,6 @@ TEST(ArrayOpsTest, Concat_ShapeFn) {
   op.input_tensors.push_back(&concat_dim_t);
   set_n(2);
 
-  // Invalid concat dim value.
-  concat_dim_t = test::AsScalar(-1);
-  INFER_ERROR("Expected concat_dim >= 0, but got -1", op, "?;?;?");
-
   // Sum dim 0, merge the other two dims.
   concat_dim_t = test::AsScalar(0);
   INFER_OK(op, "[];[100,2,?];[10,?,3]", "[110,d1_1,d2_2]");
@@ -557,16 +553,98 @@ TEST(ArrayOpsTest, Concat_ShapeFn) {
   INFER_OK(op, "[];[1,100];[?,10]", "[d1_0,110]");
   INFER_OK(op, "[];[?,100];[1,10]", "[d2_0,110]");
   // concat_dim is too high.
   INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
               "[];[100];[10,?]");
   INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
               "[];[100,5];[10]");
+  // concat_dim is too low.
+  concat_dim_t = test::AsScalar(-2);
+  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
+              "[];[100];[10,?]");
+  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
+              "[];[100,5];[10]");
 
   // Repeat successful case with several unknown inputs.
   set_n(5);
   concat_dim_t = test::AsScalar(1);
   INFER_OK(op, "[];?;[1,100,?];[?,?,?];[?,10,3];?", "[d2_0,?,d4_2]");
 }
 
+TEST(ArrayOpsTest, ConcatV2_ShapeFn) {
+  ShapeInferenceTestOp op("ConcatV2");
+  auto set_n = [&op](int n) {
+    std::vector<NodeDefBuilder::NodeOut> src_list;
+    for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_FLOAT);
+    TF_ASSERT_OK(NodeDefBuilder("test", "ConcatV2")
+                     .Input(src_list)
+                     .Input({"axis", 0, DT_INT32})
+                     .Attr("n", n)
+                     .Finalize(&op.node_def));
+  };
+
+  // Confirm that the axis input (the last argument) must be a scalar.
+  set_n(2);
+  INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;?;[1]");
+
+  // Test with the input axis tensor not known. This takes the known rank
+  // of the inputs and makes a tensor of that many unknown dims.
+  set_n(7);
+  INFER_OK(op, "?;?;?;?;[1,2,3];?;[3,2,1];?", "[?,?,?]");
+  set_n(4);
+  INFER_OK(op, "?;?;[1,2,3,4];[4,3,2,1];?", "[?,?,?,?]");
+  INFER_OK(op, "?;?;?;?;?", "?");  // output rank unknown
+  INFER_ERROR("Can't concatenate scalars (use tf.pack instead)", op,
+              "?;?;[];[];?");
+  INFER_ERROR("Shape must be rank 2 but is rank 3", op, "?;?;[1,2];[1,2,3];?");
+
+  // Test when the axis tensor is known. The concatenated dimension is
+  // summed across all input tensors, and other dimensions are merged.
+  Tensor concat_dim_t;
+  op.input_tensors.resize(3);
+  op.input_tensors[2] = &concat_dim_t;
+
+  set_n(2);
+
+  // A negative axis is no longer invalid; the old Concat-only check is
+  // kept here, disabled, for reference.
+  // concat_dim_t = test::AsScalar(-1);
+  // INFER_ERROR("Expected concat_dim >= 0, but got -1", op, "?;?;?");
+
+  // Sum dim 0, merge the other two dims.
+  concat_dim_t = test::AsScalar(0);
+  INFER_OK(op, "[100,2,?];[10,?,3];[]", "[110,d0_1,d1_2]");
+  INFER_ERROR("Dimension 1 in both shapes must be equal, but are 5 and 3", op,
+              "[100,2,5];[10,?,3];[]");
+  // concat_dim can't be summed, as one value is unknown.
+  INFER_OK(op, "[100,2,?];[?,?,3];[]", "[?,d0_1,d1_2]");
+  INFER_OK(op, "[?,2,?];[10,?,3];[]", "[?,d0_1,d1_2]");
+
+  // Test with a higher concat_dim.
+  concat_dim_t = test::AsScalar(1);
+  INFER_OK(op, "[1,100,?];[?,10,3];[]", "[d0_0,110,d1_2]");
+  INFER_OK(op, "[1,100];[?,10];[]", "[d0_0,110]");
+  INFER_OK(op, "[?,100];[1,10];[]", "[d1_0,110]");
+  // concat_dim is too high.
+  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
+              "[100];[10,?];[]");
+  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
+              "[100,5];[10];[]");
+  // concat_dim is too low.
+  concat_dim_t = test::AsScalar(-2);
+  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
+              "[100];[10,?];[]");
+  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
+              "[100,5];[10];[]");
+
+  // Repeat successful case with several unknown inputs.
+  op.input_tensors.resize(6);
+  op.input_tensors[3] = nullptr;
+  op.input_tensors[5] = &concat_dim_t;
+  concat_dim_t = test::AsScalar(1);
+
+  set_n(5);
+  INFER_OK(op, "?;[1,100,?];[?,?,?];[?,10,3];?;[]", "[d1_0,?,d3_2]");
+}
+
 TEST(ArrayOpsTest, ConcatOffset_ShapeFn) {
   ShapeInferenceTestOp op("ConcatOffset");
 
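Editor's note on the unknown-axis cases tested above: when the axis value is
not known at graph-construction time, the shape function can still report an
output of known rank (with every dimension unknown) as long as some value
input's rank is known. A standalone sketch of that fallback, with invented
names (kUnknownRank mirrors InferenceContext::kUnknownRank; InferOutputRank
is not a TensorFlow function):

#include <cassert>
#include <vector>

constexpr int kUnknownRank = -1;

// Take the first known rank among the value inputs; the real code also
// verifies that every other known-rank input agrees with it.
int InferOutputRank(const std::vector<int>& value_ranks) {
  for (int r : value_ranks) {
    if (r != kUnknownRank) return r;
  }
  return kUnknownRank;
}

int main() {
  // Mirrors INFER_OK(op, "?;?;[1,2,3,4];[4,3,2,1];?", "[?,?,?,?]") above:
  // two unknown-rank values plus two rank-4 values give a rank-4 output.
  assert(InferOutputRank({kUnknownRank, kUnknownRank, 4, 4}) == 4);
  // Mirrors INFER_OK(op, "?;?;?;?;?", "?"): all ranks unknown.
  assert(InferOutputRank({kUnknownRank, kUnknownRank, kUnknownRank,
                          kUnknownRank}) == kUnknownRank);
  return 0;
}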