Skip to content

Commit

Permalink
pycaffe: expose SGDSolver.net
Browse files Browse the repository at this point in the history
  • Loading branch information
longjon committed Apr 5, 2014
1 parent a8a0191 commit 2915f4b
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion python/caffe/_caffe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ struct CaffeNet {
net_->CopyTrainedLayersFrom(pretrained_param_file);
}

CaffeNet(shared_ptr<Net<float> > net)
: net_(net) {}

virtual ~CaffeNet() {}

inline void check_array_against_blob(
Expand Down Expand Up @@ -297,6 +300,8 @@ class CaffeSGDSolver {
solver_.reset(new SGDSolver<float>(param_file));
}

CaffeNet net() { return CaffeNet(solver_->net()); }

protected:
shared_ptr<SGDSolver<float> > solver_;
};
Expand Down Expand Up @@ -335,7 +340,8 @@ BOOST_PYTHON_MODULE(_caffe) {
.add_property("blobs", &CaffeLayer::blobs);

boost::python::class_<CaffeSGDSolver, boost::noncopyable>(
"SGDSolver", boost::python::init<string>());
"SGDSolver", boost::python::init<string>())
.add_property("net", &CaffeSGDSolver::net);

boost::python::class_<vector<CaffeBlob> >("BlobVec")
.def(vector_indexing_suite<vector<CaffeBlob>, true>());
Expand Down

0 comments on commit 2915f4b

Please sign in to comment.