From 44843fb5fb2ad53690763fd2c7b66c82f6268a52 Mon Sep 17 00:00:00 2001 From: Xiaoqiang Zheng Date: Tue, 18 Oct 2016 10:14:24 -0800 Subject: [PATCH] Adding a CPU kernel for adjust_contrast that is 45 times faster. Add more shapes and a numpy reference implementation for testing. Before this CL: BM_AdjustContrast_cpu_1_299_299 3772927 28001819 5000 9.134M items/s After this CL: BM_AdjustContrast_cpu_1_299_299 285677 596725 5000 428.637M items/s TESTED: - opensource_build passed https://ci.tensorflow.org/job/tensorflow-cl-presubmit-multijob/7138/ - unit tests passed Change: 136495531 --- tensorflow/core/kernels/adjust_contrast_op.cc | 238 +++++++++++++++++- tensorflow/python/ops/image_ops_test.py | 27 ++ 2 files changed, 257 insertions(+), 8 deletions(-) diff --git a/tensorflow/core/kernels/adjust_contrast_op.cc b/tensorflow/core/kernels/adjust_contrast_op.cc index 6e810e18fe9ec1..c8f12f91a6cb74 100644 --- a/tensorflow/core/kernels/adjust_contrast_op.cc +++ b/tensorflow/core/kernels/adjust_contrast_op.cc @@ -135,12 +135,21 @@ REGISTER_GPU_KERNEL(double); #endif // GOOGLE_CUDA -template -class AdjustContrastOpv2 : public OpKernel { - public: - explicit AdjustContrastOpv2(OpKernelConstruction* context) +class AdjustContrastOpV2Base : public OpKernel { + protected: + explicit AdjustContrastOpV2Base(OpKernelConstruction* context) : OpKernel(context) {} + struct ComputeOptions { + const Tensor* input = nullptr; + const Tensor* factor = nullptr; + Tensor* output = nullptr; + int64 batch = 0; + int64 height = 0; + int64 width = 0; + int64 channels = 0; + }; + void Compute(OpKernelContext* context) override { const Tensor& input = context->input(0); const Tensor& factor = context->input(1); @@ -161,10 +170,206 @@ class AdjustContrastOpv2 : public OpKernel { if (input.NumElements() > 0) { const int64 batch = input.NumElements() / (height * width * channels); - const int64 shape[4] = {batch, height, width, channels}; - functor::AdjustContrastv2()( - context->eigen_device(), 
input.shaped(shape), - factor.scalar(), output->shaped(shape)); + ComputeOptions options; + options.input = &input; + options.factor = &factor; + options.output = output; + options.batch = batch; + options.height = height; + options.width = width; + options.channels = channels; + DoCompute(context, options); + } + } + + virtual void DoCompute(OpKernelContext* context, + const ComputeOptions& options) = 0; +}; + +template +class AdjustContrastOpv2; + +template <> +class AdjustContrastOpv2 : public AdjustContrastOpV2Base { + public: + explicit AdjustContrastOpv2(OpKernelConstruction* context) + : AdjustContrastOpV2Base(context) {} + + void DoCompute(OpKernelContext* context, + const ComputeOptions& options) override { + const int64 batch = options.batch; + const int64 height = options.height; + const int64 width = options.width; + const int64 channels = options.channels; + const int64 image_size = height * width; + const Tensor* input = options.input; + const Tensor* factor = options.factor; + Tensor* output = options.output; + Tensor mean_values; + OP_REQUIRES_OK(context, context->allocate_temp( + DataTypeToEnum::value, + TensorShape({batch, channels}), &mean_values)); + // TODO(zhengxq): for multiple batches, shard them into different batches. + auto input_data = input->shaped({batch, image_size, channels}); + auto mean_data = mean_values.tensor(); + auto output_data = output->shaped({batch, image_size, channels}); + + // Calculate the mean of the inputs. + ReduceMeanAcrossImage(input_data, mean_data, output_data); + // Broadcast the mean into the outputs. + BroadcastAcrossImage(mean_data, output_data); + // Increment the outputs with the scaled difference through their flat + // structure. + IncrementWithScaling(input_data, factor->scalar(), output_data); + } + + private: + // Reduce the mean of the inputs along the image dimension, i.e. dim_1, in a + // 3D tensor. Effectively means(i, k) = inputs(i, :, k).mean(). 
+ void ReduceMeanAcrossImage(typename TTypes::ConstTensor input, + typename TTypes::Tensor mean, + typename TTypes::Tensor scratch) { + const int64 batch = input.dimension(0); + const int64 image_size = input.dimension(1); + const int64 channels = input.dimension(2); + TTypes::ConstTensor input_flat(&input(0, 0, 0), input.size()); + TTypes::Tensor mean_flat(&mean(0, 0), mean.size()); + TTypes::Tensor summation_scratch(&scratch(0, 0, 0), + scratch.size()); + typedef Eigen::array Index; + const int64 plane_size = image_size * channels; + // Since the number of channels in the early layers is often small, a + // straightforward loop for summing cannot utilize vectorization. + // This algorithm repeatedly folds each image plane in half, until + // only one set of channels remains. + for (int64 i = 0; i < batch; i++) { + auto input_plane = + input_flat.slice(Index(i * plane_size), Index(plane_size)); + auto summation_plane = + summation_scratch.slice(Index(i * plane_size), Index(plane_size)); + int64 remaining_size = image_size; + int round = 0; + // Sum the input(i, :, k) into mean(i, k). Repeatedly split the input + // array in half and sum the two halves, until only one set of channels + // is left, which holds the sum. Since each half is large enough, this + // leads to much better vectorization across components. An example of + // how this works: + // + // x = float[4096, 3] + // round 0 + // y[:2048, :] = x[:2048, :] + x[2048:, :] + // round 1 + // y[:1024, :] += y[1024:2048, :] + // round 2 + // y[:512, :] += y[512:1024, :] + // ... + // round 11 + // y[:1, :] += y[1:2, :] + // At this point y[0, :] holds the sum of all x[:, :] + // + // The algorithm itself can handle sizes that are not powers of two. Note + // that in each round we sum up elements that are contiguous. So we can + // use their flattened structure to gain vectorization efficiency. 
+ do { + int64 right_size = remaining_size / 2; + int64 left_size = remaining_size - right_size; + DCHECK(left_size == right_size || left_size == right_size + 1); + if (round == 0) { + // In the first round, sum the left side and right side of the input + // array into the summation area. + summation_plane.slice(Index(0), Index(right_size * channels)) = + input_plane.slice(Index(left_size * channels), + Index(right_size * channels)) + + input_plane.slice(Index(0), Index(right_size * channels)); + if (left_size > right_size) { + DCHECK_EQ(left_size - right_size, 1); + // Copy over the remaining column if the remaining_size is odd. + // This also handles the case where image_size == 1. + summation_plane.slice(Index(right_size * channels), + Index(channels)) = + input_plane.slice(Index(right_size * channels), + Index(channels)); + } + } else { + // For all the remaining rounds, add the second half of the partial + // sums into the first half of the partial sums. With the flat structure + // and large size, this utilizes vectorization across components. + summation_plane.slice(Index(0), Index(right_size * channels)) += + summation_plane.slice(Index(left_size * channels), + Index(right_size * channels)); + } + remaining_size = left_size; + round++; + } while (remaining_size > 1); + const float mean_scaling = 1.0f / image_size; + // The first `channels` elements of summation_plane now hold the sums. + // Scale them by 1 / image_size and copy them over to the means. + auto mean_plane = mean_flat.slice(Index(i * channels), Index(channels)); + mean_plane = + summation_plane.slice(Index(0), Index(channels)) * mean_scaling; + } + } + + // Broadcast the 2D inputs into the 3D outputs across the image dimension, + // i.e., dim-1. 
+ void BroadcastAcrossImage(typename TTypes::Tensor inputs, + typename TTypes::Tensor outputs) { + int64 batch = outputs.dimension(0); + int64 image_size = outputs.dimension(1); + int64 channels = outputs.dimension(2); + // Similar to the reduction case, a straightforward implementation of this + // does not utilize vectorization well because of the small channel size. + // This algorithm repeatedly increases the area to be copied, and leads to + // much better vectorization in the copy. + for (int64 i = 0; i < batch; i++) { + // Copy over the inputs into outputs in this batch. Effectively: + // outputs(i, :, k) = inputs(i, k). An example of how this algorithm works: + // + // x = float[1, 3], y = float[2048, 3] + // round 0 + // y[:1, :] = x[:, :] + // round 1 + // y[1:2, :] = y[:1, :] + // round 2 + // y[2:4, :] = y[:2, :] + // round 3 + // y[4:8, :] = y[:4, :] + // ... + // round 11 + // y[1024:2048, :] = y[:1024, :] + // At this point y[:, k] == x[k] + // + // The algorithm works for sizes that are not powers of two. For each round, + // the elements that are copied are contiguous, so it benefits from the + // vectorized copy via memcpy. + const float* mean_p = &inputs(i, 0); + // Copy the first set of channels. + float* output_p = &outputs(i, 0, 0); + memcpy(output_p, mean_p, sizeof(float) * channels); + int64 copied = 1; + while (copied < image_size) { + // Repeatedly increase the number of elements to copy so they have + // better vectorization. However, the source of the copy has to be + // not too large to stay in the cache. + const int64 kMaxToCopy = 1024; + int64 to_copy = std::min({copied, image_size - copied, kMaxToCopy}); + memcpy(output_p + channels * copied, output_p, + to_copy * channels * sizeof(float)); + copied += to_copy; + } + } + } + + // Increment the outputs with the scaled difference between inputs and + // outputs. Effectively: outputs += factor * (inputs - outputs). 
+ void IncrementWithScaling(typename TTypes::ConstTensor input, + typename TTypes::ConstScalar factor, + typename TTypes::Tensor output) { + const float factor_value = factor(); + float* p = output.data(); + const float* q = input.data(); + for (int64 n = 0; n < input.size(); ++n) { + p[n] += factor_value * (q[n] - p[n]); } } }; @@ -184,6 +389,23 @@ void AdjustContrastv2::operator()( extern template struct AdjustContrastv2; } // namespace functor +template <> +class AdjustContrastOpv2 : public AdjustContrastOpV2Base { + public: + explicit AdjustContrastOpv2(OpKernelConstruction* context) + : AdjustContrastOpV2Base(context) {} + + void DoCompute(OpKernelContext* context, + const ComputeOptions& options) override { + const int64 shape[4] = {options.batch, options.height, options.width, + options.channels}; + functor::AdjustContrastv2()( + context->eigen_device(), + options.input->shaped(shape), options.factor->scalar(), + options.output->shaped(shape)); + } +}; + REGISTER_KERNEL_BUILDER(Name("AdjustContrastv2").Device(DEVICE_GPU), AdjustContrastOpv2); #endif // GOOGLE_CUDA diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py index 80ad28563583c5..f1738fd779a35a 100644 --- a/tensorflow/python/ops/image_ops_test.py +++ b/tensorflow/python/ops/image_ops_test.py @@ -418,6 +418,33 @@ def testBatchDoubleContrast(self): self._testContrast(x_np, y_np, contrast_factor=2.0) + def _adjustContrastNp(self, x_np, contrast_factor): + mean = np.mean(x_np, (1, 2), keepdims=True) + y_np = mean + contrast_factor * (x_np - mean) + return y_np + + def _adjustContrastTf(self, x_np, contrast_factor): + with self.test_session(use_gpu=True): + x = constant_op.constant(x_np) + y = image_ops.adjust_contrast(x, contrast_factor) + y_tf = y.eval() + return y_tf + + def testRandomContrast(self): + x_shapes = [ + [1, 2, 2, 3], + [2, 1, 2, 3], + [1, 2, 2, 3], + [2, 5, 5, 3], + [2, 1, 1, 3], + ] + for x_shape in x_shapes: + x_np = np.random.rand(*x_shape) * 
255. + contrast_factor = np.random.rand() * 2.0 + 0.1 + y_np = self._adjustContrastNp(x_np, contrast_factor) + y_tf = self._adjustContrastTf(x_np, contrast_factor) + self.assertAllClose(y_tf, y_np, rtol=1e-5, atol=1e-5) + class AdjustBrightnessTest(test_util.TensorFlowTestCase):