
Add clip layer
harm-nedap authored and Noiredd committed Aug 17, 2018
1 parent 24b0905 commit 7f4f5d2
Showing 6 changed files with 233 additions and 1 deletion.
75 changes: 75 additions & 0 deletions include/caffe/layers/clip_layer.hpp
@@ -0,0 +1,75 @@
#ifndef CAFFE_CLIP_LAYER_HPP_
#define CAFFE_CLIP_LAYER_HPP_

#include <vector>

#include "caffe/blob.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"

#include "caffe/layers/neuron_layer.hpp"

namespace caffe {

/**
* @brief Clip: @f$ y = \max(min, \min(max, x)) @f$.
*/
template <typename Dtype>
class ClipLayer : public NeuronLayer<Dtype> {
public:
/**
* @param param provides ClipParameter clip_param,
* with ClipLayer options:
* - min
* - max
*/
explicit ClipLayer(const LayerParameter& param)
: NeuronLayer<Dtype>(param) {}

virtual inline const char* type() const { return "Clip"; }

protected:
/**
* @param bottom input Blob vector (length 1)
* -# @f$ (N \times C \times H \times W) @f$
* the inputs @f$ x @f$
* @param top output Blob vector (length 1)
* -# @f$ (N \times C \times H \times W) @f$
* the computed outputs @f$
* y = \max(min, \min(max, x))
* @f$
*/
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);

/**
* @brief Computes the error gradient w.r.t. the clipped inputs.
*
* @param top output Blob vector (length 1), providing the error gradient with
* respect to the outputs
* -# @f$ (N \times C \times H \times W) @f$
* containing error gradients @f$ \frac{\partial E}{\partial y} @f$
* with respect to computed outputs @f$ y @f$
* @param propagate_down see Layer::Backward.
* @param bottom input Blob vector (length 1)
* -# @f$ (N \times C \times H \times W) @f$
* the inputs @f$ x @f$; Backward fills their diff with
* gradients @f$
* \frac{\partial E}{\partial x} = \left\{
* \begin{array}{lr}
* 0 & \mathrm{if} \; x < min \vee x > max \\
* \frac{\partial E}{\partial y} & \mathrm{if} \; x \ge min \wedge x \le max
* \end{array} \right.
* @f$
*/
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
};

} // namespace caffe

#endif // CAFFE_CLIP_LAYER_HPP_
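
For reference, a minimal sketch (not part of this commit) of how the new layer can be exercised from C++, following the same pattern as the unit tests added below; the blob shape and values are placeholders:

#include <google/protobuf/text_format.h>
#include <vector>
#include "caffe/blob.hpp"
#include "caffe/layers/clip_layer.hpp"

using caffe::Blob;
using caffe::ClipLayer;
using caffe::LayerParameter;

int main() {
  // Build the layer definition from prototxt text, as the unit tests do.
  LayerParameter layer_param;
  google::protobuf::TextFormat::ParseFromString(
      "clip_param { min: -1 max: 2 }", &layer_param);

  // A single input blob of shape 1x1x1x4 with values straddling the range.
  Blob<float> bottom(1, 1, 1, 4);
  float* data = bottom.mutable_cpu_data();
  data[0] = -3; data[1] = 0; data[2] = 1.5; data[3] = 5;
  Blob<float> top;
  std::vector<Blob<float>*> bottom_vec(1, &bottom), top_vec(1, &top);

  ClipLayer<float> layer(layer_param);
  layer.SetUp(bottom_vec, top_vec);    // reshapes top to match bottom
  layer.Forward(bottom_vec, top_vec);  // top now holds {-1, 0, 1.5, 2}
  return 0;
}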
1 change: 1 addition & 0 deletions src/caffe/layer_factory.cpp
@@ -7,6 +7,7 @@

#include "caffe/layer.hpp"
#include "caffe/layer_factory.hpp"
#include "caffe/layers/clip_layer.hpp"
#include "caffe/layers/conv_layer.hpp"
#include "caffe/layers/deconv_layer.hpp"
#include "caffe/layers/lrn_layer.hpp"
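
The include above only makes the new header visible to the factory translation unit; the registration itself happens via REGISTER_LAYER_CLASS(Clip) at the bottom of clip_layer.cpp. As a rough sketch (reproduced from memory of layer_factory.hpp, not part of this diff), that macro expands inside namespace caffe to a creator function plus a registry entry keyed by the type string:

// Approximate expansion of REGISTER_LAYER_CLASS(Clip): define a creator
// that news up the layer, then register it under the type string "Clip"
// so net construction can look it up by the layer's `type` field.
template <typename Dtype>
shared_ptr<Layer<Dtype> > Creator_ClipLayer(const LayerParameter& param) {
  return shared_ptr<Layer<Dtype> >(new ClipLayer<Dtype>(param));
}
REGISTER_LAYER_CREATOR(Clip, Creator_ClipLayer);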
50 changes: 50 additions & 0 deletions src/caffe/layers/clip_layer.cpp
@@ -0,0 +1,50 @@
#include <algorithm>
#include <vector>
#include "caffe/layers/clip_layer.hpp"

namespace caffe {

template <typename Dtype>
void ClipLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = top[0]->mutable_cpu_data();
const int count = bottom[0]->count();

Dtype min = this->layer_param_.clip_param().min();
Dtype max = this->layer_param_.clip_param().max();

for (int i = 0; i < count; ++i) {
top_data[i] = std::max(min, std::min(bottom_data[i], max));
}
}

template <typename Dtype>
void ClipLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down,
const vector<Blob<Dtype>*>& bottom) {
if (propagate_down[0]) {
const Dtype* bottom_data = bottom[0]->cpu_data();
const Dtype* top_diff = top[0]->cpu_diff();
Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
const int count = bottom[0]->count();

Dtype min = this->layer_param_.clip_param().min();
Dtype max = this->layer_param_.clip_param().max();

for (int i = 0; i < count; ++i) {
bottom_diff[i] = top_diff[i] * (
bottom_data[i] >= min && bottom_data[i] <= max);
}
}
}


#ifdef CPU_ONLY
STUB_GPU(ClipLayer);
#endif

INSTANTIATE_CLASS(ClipLayer);
REGISTER_LAYER_CLASS(Clip);

} // namespace caffe
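
The backward pass avoids a branch by multiplying the incoming gradient by the boolean result of the range test, which converts to Dtype 0 or 1. A standalone sketch (not part of the commit) of the same idiom:

#include <cstdio>

int main() {
  const float top_diff = 0.5f, min = -1.0f, max = 2.0f;
  const float x[3] = {-2.0f, 0.0f, 3.0f};  // below, inside, above the range
  for (int i = 0; i < 3; ++i) {
    // The comparison yields true (1) inside [min, max] and false (0)
    // outside, so the gradient passes through unchanged or is zeroed.
    float bottom_diff = top_diff * (x[i] >= min && x[i] <= max);
    std::printf("x = %4.1f  ->  dE/dx = %g\n", x[i], bottom_diff);
  }
  return 0;
}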
66 changes: 66 additions & 0 deletions src/caffe/layers/clip_layer.cu
@@ -0,0 +1,66 @@
#include <vector>
#include "caffe/layers/clip_layer.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {

__global__ void ClipForward(const int n, const float* in, float* out,
float p_min, float p_max) {
CUDA_KERNEL_LOOP(index, n) {
out[index] = fmaxf(p_min, fminf(in[index], p_max));
}
}

__global__ void ClipForward(const int n, const double* in, double* out,
double p_min, double p_max) {
CUDA_KERNEL_LOOP(index, n) {
out[index] = fmax(p_min, fmin(in[index], p_max));
}
}

template <typename Dtype>
void ClipLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = top[0]->mutable_gpu_data();
const int count = bottom[0]->count();
Dtype p_min = this->layer_param_.clip_param().min();
Dtype p_max = this->layer_param_.clip_param().max();
// NOLINT_NEXT_LINE(whitespace/operators)
ClipForward<<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, bottom_data, top_data, p_min, p_max);
CUDA_POST_KERNEL_CHECK;
}

template <typename Dtype>
__global__ void ClipBackward(const int n, const Dtype* in_diff,
const Dtype* in_data, Dtype* out_diff, Dtype p_min, Dtype p_max) {
CUDA_KERNEL_LOOP(index, n) {
out_diff[index] = in_diff[index] * (
in_data[index] >= p_min && in_data[index] <= p_max);
}
}

template <typename Dtype>
void ClipLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down,
const vector<Blob<Dtype>*>& bottom) {
if (propagate_down[0]) {
const Dtype* bottom_data = bottom[0]->gpu_data();
const Dtype* top_diff = top[0]->gpu_diff();
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
const int count = bottom[0]->count();
Dtype p_min = this->layer_param_.clip_param().min();
Dtype p_max = this->layer_param_.clip_param().max();
// NOLINT_NEXT_LINE(whitespace/operators)
ClipBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, top_diff, bottom_data, bottom_diff, p_min, p_max);
CUDA_POST_KERNEL_CHECK;
}
}


INSTANTIATE_LAYER_GPU_FUNCS(ClipLayer);


} // namespace caffe
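
The separate float and double ClipForward overloads exist because CUDA's fminf/fmaxf operate on float while fmin/fmax operate on double; overload resolution picks the right kernel at the launch site. The launch helpers follow Caffe's usual one-thread-per-element pattern; roughly (as defined in caffe/util/device_alternate.hpp, reproduced here from memory as an assumption, not part of this diff):

// Grid-stride loop over n elements.
#define CUDA_KERNEL_LOOP(i, n) \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
       i < (n); \
       i += blockDim.x * gridDim.x)

// Threads per block.
const int CAFFE_CUDA_NUM_THREADS = 512;

// Number of blocks needed to cover N elements, rounded up.
inline int CAFFE_GET_BLOCKS(const int N) {
  return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
}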
9 changes: 8 additions & 1 deletion src/caffe/proto/caffe.proto
@@ -322,7 +322,7 @@ message ParamSpec {
// NOTE
// Update the next available ID when you add a new LayerParameter field.
//
// LayerParameter next available layer-specific ID: 148 (last added: swish_param)
// LayerParameter next available layer-specific ID: 149 (last added: clip_param)
message LayerParameter {
optional string name = 1; // the layer name
optional string type = 2; // the layer type
@@ -378,6 +378,7 @@ message LayerParameter {
optional ArgMaxParameter argmax_param = 103;
optional BatchNormParameter batch_norm_param = 139;
optional BiasParameter bias_param = 141;
optional ClipParameter clip_param = 148;
optional ConcatParameter concat_param = 104;
optional ContrastiveLossParameter contrastive_loss_param = 105;
optional ConvolutionParameter convolution_param = 106;
@@ -505,6 +506,12 @@ message ArgMaxParameter {
optional int32 axis = 3;
}

// Message that stores parameters used by ClipLayer
message ClipParameter {
required float min = 1;
required float max = 2;
}

message ConcatParameter {
// The axis along which to concatenate -- may be negative to index from the
// end (e.g., -1 for the last axis). Other axes must have the
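
Because both fields of ClipParameter are declared required, a Clip layer definition that omits either min or max fails to parse rather than silently falling back to a default. A sketch of that failure mode (based on standard proto2 semantics, not code from this commit):

#include <google/protobuf/text_format.h>
#include "caffe/proto/caffe.pb.h"

int main() {
  caffe::LayerParameter p;
  // Fails: ClipParameter.max is required but missing, so the message is
  // not fully initialized and ParseFromString returns false.
  bool ok = google::protobuf::TextFormat::ParseFromString(
      "clip_param { min: -1 }", &p);
  return ok ? 1 : 0;  // misconfigured nets are rejected at parse time
}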
33 changes: 33 additions & 0 deletions src/caffe/test/test_neuron_layer.cpp
@@ -10,6 +10,7 @@

#include "caffe/layers/absval_layer.hpp"
#include "caffe/layers/bnll_layer.hpp"
#include "caffe/layers/clip_layer.hpp"
#include "caffe/layers/dropout_layer.hpp"
#include "caffe/layers/elu_layer.hpp"
#include "caffe/layers/exp_layer.hpp"
@@ -206,6 +207,38 @@ TYPED_TEST(NeuronLayerTest, TestAbsGradient) {
this->blob_top_vec_);
}

TYPED_TEST(NeuronLayerTest, TestClip) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
CHECK(google::protobuf::TextFormat::ParseFromString(
"clip_param { min: -1, max: 2 }", &layer_param));
ClipLayer<Dtype> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
// Now, check values
const Dtype* bottom_data = this->blob_bottom_->cpu_data();
const Dtype* top_data = this->blob_top_->cpu_data();
for (int i = 0; i < this->blob_bottom_->count(); ++i) {
EXPECT_GE(top_data[i], -1);
EXPECT_LE(top_data[i], 2);
EXPECT_TRUE(bottom_data[i] > -1 || top_data[i] == -1);
EXPECT_TRUE(bottom_data[i] < 2 || top_data[i] == 2);
EXPECT_TRUE(!(bottom_data[i] >= -1 && bottom_data[i] <= 2)
|| top_data[i] == bottom_data[i]);
}
}

TYPED_TEST(NeuronLayerTest, TestClipGradient) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
CHECK(google::protobuf::TextFormat::ParseFromString(
"clip_param { min: -1, max: 2 }", &layer_param));
ClipLayer<Dtype> layer(layer_param);
GradientChecker<Dtype> checker(1e-2, 1e-3);
checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_,
this->blob_top_vec_);
}

TYPED_TEST(NeuronLayerTest, TestReLU) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
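
TestClipGradient relies on GradientChecker's central-difference estimate (step 1e-2, tolerance 1e-3) matching the analytic Backward pass elementwise. A sketch of the numeric side of that comparison (illustrative, not code from the test):

#include <algorithm>
#include <cstdio>

// Clip with the same bounds as the test: min = -1, max = 2.
float clip(float x) { return std::max(-1.0f, std::min(x, 2.0f)); }

int main() {
  const float h = 1e-2f;  // GradientChecker's step size
  const float x = 0.5f;   // strictly inside the range, so the true slope is 1
  float numeric = (clip(x + h) - clip(x - h)) / (2.0f * h);
  std::printf("numeric dy/dx at x = %g: %g\n", x, numeric);  // prints 1
  return 0;
}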
