Skip to content

Commit

Permalink
Merge pull request BVLC#3995 from ZhouYzzz/python-phase
Browse files Browse the repository at this point in the history
Allow the python layer have attribute "phase"
  • Loading branch information
longjon committed May 4, 2016
2 parents f467ead + c2dba92 commit de8ac32
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 0 deletions.
1 change: 1 addition & 0 deletions include/caffe/layers/python_layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class PythonLayer : public Layer<Dtype> {
}
self_.attr("param_str") = bp::str(
this->layer_param_.python_param().param_str());
self_.attr("phase") = static_cast<int>(this->phase_);
self_.attr("setup")(bottom, top);
}
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
Expand Down
26 changes: 26 additions & 0 deletions python/caffe/test/test_python_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,18 @@ def forward(self, bottom, top):
def backward(self, top, propagate_down, bottom):
self.blobs[0].diff[0] = 1

class PhaseLayer(caffe.Layer):
"""A layer for checking attribute `phase`"""

def setup(self, bottom, top):
pass

def reshape(self, bootom, top):
top[0].reshape()

def forward(self, bottom, top):
top[0].data[()] = self.phase

def python_net_file():
with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
f.write("""name: 'pythonnet' force_backward: true
Expand Down Expand Up @@ -76,6 +88,14 @@ def parameter_net_file():
""")
return f.name

def phase_net_file():
with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
f.write("""name: 'pythonnet' force_backward: true
layer { type: 'Python' name: 'layer' top: 'phase'
python_param { module: 'test_python_layer' layer: 'PhaseLayer' } }
""")
return f.name


@unittest.skipIf('Python' not in caffe.layer_type_list(),
'Caffe built without Python layer support')
Expand Down Expand Up @@ -140,3 +160,9 @@ def test_parameter(self):
self.assertEqual(layer.blobs[0].data[0], 1)

os.remove(net_file)

def test_phase(self):
net_file = phase_net_file()
for phase in caffe.TRAIN, caffe.TEST:
net = caffe.Net(net_file, phase)
self.assertEqual(net.forward()['phase'], phase)

0 comments on commit de8ac32

Please sign in to comment.