Skip to content

Commit

Permalink
Add test for attribute "phase" in python layer
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhouYzzz committed May 4, 2016
1 parent 1c49130 commit c2dba92
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions python/caffe/test/test_python_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,18 @@ def forward(self, bottom, top):
def backward(self, top, propagate_down, bottom):
self.blobs[0].diff[0] = 1

class PhaseLayer(caffe.Layer):
"""A layer for checking attribute `phase`"""

def setup(self, bottom, top):
pass

def reshape(self, bootom, top):
top[0].reshape()

def forward(self, bottom, top):
top[0].data[()] = self.phase

def python_net_file():
with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
f.write("""name: 'pythonnet' force_backward: true
Expand Down Expand Up @@ -76,6 +88,14 @@ def parameter_net_file():
""")
return f.name

def phase_net_file():
with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
f.write("""name: 'pythonnet' force_backward: true
layer { type: 'Python' name: 'layer' top: 'phase'
python_param { module: 'test_python_layer' layer: 'PhaseLayer' } }
""")
return f.name


@unittest.skipIf('Python' not in caffe.layer_type_list(),
'Caffe built without Python layer support')
Expand Down Expand Up @@ -140,3 +160,9 @@ def test_parameter(self):
self.assertEqual(layer.blobs[0].data[0], 1)

os.remove(net_file)

def test_phase(self):
net_file = phase_net_file()
for phase in caffe.TRAIN, caffe.TEST:
net = caffe.Net(net_file, phase)
self.assertEqual(net.forward()['phase'], phase)

0 comments on commit c2dba92

Please sign in to comment.