Skip to content

Commit

Permalink
Merge pull request BVLC#3863 from lukeyeager/bvlc/expose-all-netstate…
Browse files Browse the repository at this point in the history
…-options

Expose all netstate options (for all-in-one nets)
  • Loading branch information
shelhamer authored Jul 11, 2016
2 parents f28f5ae + 19adc7a commit 3e94c0e
Show file tree
Hide file tree
Showing 6 changed files with 439 additions and 12 deletions.
1 change: 1 addition & 0 deletions include/caffe/net.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class Net {
public:
explicit Net(const NetParameter& param, const Net* root_net = NULL);
explicit Net(const string& param_file, Phase phase,
const int level = 0, const vector<string>* stages = NULL,
const Net* root_net = NULL);
virtual ~Net() {}

Expand Down
44 changes: 36 additions & 8 deletions python/caffe/_caffe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,19 +86,42 @@ void CheckContiguousArray(PyArrayObject* arr, string name,
}
}

// Net constructor for passing phase as int
shared_ptr<Net<Dtype> > Net_Init(
string param_file, int phase) {
CheckFile(param_file);
// Net constructor
shared_ptr<Net<Dtype> > Net_Init(string network_file, int phase,
const int level, const bp::object& stages,
const bp::object& weights) {
CheckFile(network_file);

// Convert stages from list to vector
vector<string> stages_vector;
if (!stages.is_none()) {
for (int i = 0; i < len(stages); i++) {
stages_vector.push_back(bp::extract<string>(stages[i]));
}
}

// Initialize net
shared_ptr<Net<Dtype> > net(new Net<Dtype>(network_file,
static_cast<Phase>(phase), level, &stages_vector));

// Load weights
if (!weights.is_none()) {
std::string weights_file_str = bp::extract<std::string>(weights);
CheckFile(weights_file_str);
net->CopyTrainedLayersFrom(weights_file_str);
}

shared_ptr<Net<Dtype> > net(new Net<Dtype>(param_file,
static_cast<Phase>(phase)));
return net;
}

// Net construct-and-load convenience constructor
// Legacy Net construct-and-load convenience constructor
shared_ptr<Net<Dtype> > Net_Init_Load(
string param_file, string pretrained_param_file, int phase) {
LOG(WARNING) << "DEPRECATION WARNING - deprecated use of Python interface";
LOG(WARNING) << "Use this instead (with the named \"weights\""
<< " parameter):";
LOG(WARNING) << "Net('" << param_file << "', " << phase
<< ", weights='" << pretrained_param_file << "')";
CheckFile(param_file);
CheckFile(pretrained_param_file);

Expand Down Expand Up @@ -266,7 +289,12 @@ BOOST_PYTHON_MODULE(_caffe) {

bp::class_<Net<Dtype>, shared_ptr<Net<Dtype> >, boost::noncopyable >("Net",
bp::no_init)
.def("__init__", bp::make_constructor(&Net_Init))
// Constructor
.def("__init__", bp::make_constructor(&Net_Init,
bp::default_call_policies(), (bp::arg("network_file"), "phase",
bp::arg("level")=0, bp::arg("stages")=bp::object(),
bp::arg("weights")=bp::object())))
// Legacy constructor
.def("__init__", bp::make_constructor(&Net_Init_Load))
.def("_forward", &Net<Dtype>::ForwardFromTo)
.def("_backward", &Net<Dtype>::BackwardFromTo)
Expand Down
228 changes: 227 additions & 1 deletion python/caffe/test/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,11 @@ def test_save_and_read(self):
f.close()
self.net.save(f.name)
net_file = simple_net_file(self.num_output)
net2 = caffe.Net(net_file, f.name, caffe.TRAIN)
# Test legacy constructor
# should print deprecation warning
caffe.Net(net_file, f.name, caffe.TRAIN)
# Test named constructor
net2 = caffe.Net(net_file, caffe.TRAIN, weights=f.name)
os.remove(net_file)
os.remove(f.name)
for name in self.net.params:
Expand All @@ -93,3 +97,225 @@ def test_save_hdf5(self):
for i in range(len(self.net.params[name])):
self.assertEqual(abs(self.net.params[name][i].data
- net2.params[name][i].data).sum(), 0)

class TestLevels(unittest.TestCase):

TEST_NET = """
layer {
name: "data"
type: "DummyData"
top: "data"
dummy_data_param { shape { dim: 1 dim: 1 dim: 10 dim: 10 } }
}
layer {
name: "NoLevel"
type: "InnerProduct"
bottom: "data"
top: "NoLevel"
inner_product_param { num_output: 1 }
}
layer {
name: "Level0Only"
type: "InnerProduct"
bottom: "data"
top: "Level0Only"
include { min_level: 0 max_level: 0 }
inner_product_param { num_output: 1 }
}
layer {
name: "Level1Only"
type: "InnerProduct"
bottom: "data"
top: "Level1Only"
include { min_level: 1 max_level: 1 }
inner_product_param { num_output: 1 }
}
layer {
name: "Level>=0"
type: "InnerProduct"
bottom: "data"
top: "Level>=0"
include { min_level: 0 }
inner_product_param { num_output: 1 }
}
layer {
name: "Level>=1"
type: "InnerProduct"
bottom: "data"
top: "Level>=1"
include { min_level: 1 }
inner_product_param { num_output: 1 }
}
"""

def setUp(self):
self.f = tempfile.NamedTemporaryFile(mode='w+')
self.f.write(self.TEST_NET)
self.f.flush()

def tearDown(self):
self.f.close()

def check_net(self, net, blobs):
net_blobs = [b for b in net.blobs.keys() if 'data' not in b]
self.assertEqual(net_blobs, blobs)

def test_0(self):
net = caffe.Net(self.f.name, caffe.TEST)
self.check_net(net, ['NoLevel', 'Level0Only', 'Level>=0'])

def test_1(self):
net = caffe.Net(self.f.name, caffe.TEST, level=1)
self.check_net(net, ['NoLevel', 'Level1Only', 'Level>=0', 'Level>=1'])


class TestStages(unittest.TestCase):

TEST_NET = """
layer {
name: "data"
type: "DummyData"
top: "data"
dummy_data_param { shape { dim: 1 dim: 1 dim: 10 dim: 10 } }
}
layer {
name: "A"
type: "InnerProduct"
bottom: "data"
top: "A"
include { stage: "A" }
inner_product_param { num_output: 1 }
}
layer {
name: "B"
type: "InnerProduct"
bottom: "data"
top: "B"
include { stage: "B" }
inner_product_param { num_output: 1 }
}
layer {
name: "AorB"
type: "InnerProduct"
bottom: "data"
top: "AorB"
include { stage: "A" }
include { stage: "B" }
inner_product_param { num_output: 1 }
}
layer {
name: "AandB"
type: "InnerProduct"
bottom: "data"
top: "AandB"
include { stage: "A" stage: "B" }
inner_product_param { num_output: 1 }
}
"""

def setUp(self):
self.f = tempfile.NamedTemporaryFile(mode='w+')
self.f.write(self.TEST_NET)
self.f.flush()

def tearDown(self):
self.f.close()

def check_net(self, net, blobs):
net_blobs = [b for b in net.blobs.keys() if 'data' not in b]
self.assertEqual(net_blobs, blobs)

def test_A(self):
net = caffe.Net(self.f.name, caffe.TEST, stages=['A'])
self.check_net(net, ['A', 'AorB'])

def test_B(self):
net = caffe.Net(self.f.name, caffe.TEST, stages=['B'])
self.check_net(net, ['B', 'AorB'])

def test_AandB(self):
net = caffe.Net(self.f.name, caffe.TEST, stages=['A', 'B'])
self.check_net(net, ['A', 'B', 'AorB', 'AandB'])


class TestAllInOne(unittest.TestCase):

TEST_NET = """
layer {
name: "train_data"
type: "DummyData"
top: "data"
top: "label"
dummy_data_param {
shape { dim: 1 dim: 1 dim: 10 dim: 10 }
shape { dim: 1 dim: 1 dim: 1 dim: 1 }
}
include { phase: TRAIN stage: "train" }
}
layer {
name: "val_data"
type: "DummyData"
top: "data"
top: "label"
dummy_data_param {
shape { dim: 1 dim: 1 dim: 10 dim: 10 }
shape { dim: 1 dim: 1 dim: 1 dim: 1 }
}
include { phase: TEST stage: "val" }
}
layer {
name: "deploy_data"
type: "Input"
top: "data"
input_param { shape { dim: 1 dim: 1 dim: 10 dim: 10 } }
include { phase: TEST stage: "deploy" }
}
layer {
name: "ip"
type: "InnerProduct"
bottom: "data"
top: "ip"
inner_product_param { num_output: 2 }
}
layer {
name: "loss"
type: "SoftmaxWithLoss"
bottom: "ip"
bottom: "label"
top: "loss"
include: { phase: TRAIN stage: "train" }
include: { phase: TEST stage: "val" }
}
layer {
name: "pred"
type: "Softmax"
bottom: "ip"
top: "pred"
include: { phase: TEST stage: "deploy" }
}
"""

def setUp(self):
self.f = tempfile.NamedTemporaryFile(mode='w+')
self.f.write(self.TEST_NET)
self.f.flush()

def tearDown(self):
self.f.close()

def check_net(self, net, outputs):
self.assertEqual(list(net.blobs['data'].shape), [1,1,10,10])
self.assertEqual(net.outputs, outputs)

def test_train(self):
net = caffe.Net(self.f.name, caffe.TRAIN, stages=['train'])
self.check_net(net, ['loss'])

def test_val(self):
net = caffe.Net(self.f.name, caffe.TEST, stages=['val'])
self.check_net(net, ['loss'])

def test_deploy(self):
net = caffe.Net(self.f.name, caffe.TEST, stages=['deploy'])
self.check_net(net, ['pred'])

11 changes: 10 additions & 1 deletion src/caffe/net.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,20 @@ Net<Dtype>::Net(const NetParameter& param, const Net* root_net)
}

template <typename Dtype>
Net<Dtype>::Net(const string& param_file, Phase phase, const Net* root_net)
Net<Dtype>::Net(const string& param_file, Phase phase,
const int level, const vector<string>* stages,
const Net* root_net)
: root_net_(root_net) {
NetParameter param;
ReadNetParamsFromTextFileOrDie(param_file, &param);
// Set phase, stages and level
param.mutable_state()->set_phase(phase);
if (stages != NULL) {
for (int i = 0; i < stages->size(); i++) {
param.mutable_state()->add_stage((*stages)[i]);
}
}
param.mutable_state()->set_level(level);
Init(param);
}

Expand Down
Loading

0 comments on commit 3e94c0e

Please sign in to comment.