Skip to content

Commit

Permalink
[pycaffe] expose solver update to do manual solving
Browse files Browse the repository at this point in the history
a sketch of `solver.step()` done out manually:

1. `solver.net.forward()`
2. `solver.net.backward()`
3. `solver.net.apply_update()`
4. `solver.net.clear_param_diffs()`
  • Loading branch information
mitar authored and shelhamer committed Jun 7, 2018
1 parent a357693 commit cc1c8fb
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 2 deletions.
3 changes: 2 additions & 1 deletion include/caffe/sgd_solvers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ class SGDSolver : public Solver<Dtype> {

const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; }

virtual void ApplyUpdate();

protected:
void PreSolve();
Dtype GetLearningRate();
virtual void ApplyUpdate();
virtual void Normalize(int param_id);
virtual void Regularize(int param_id);
virtual void ComputeUpdateValue(int param_id, Dtype rate);
Expand Down
3 changes: 2 additions & 1 deletion include/caffe/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,10 @@ class Solver {
*/
virtual inline const char* type() const { return ""; }

protected:
// Make and apply the update value for the current iteration.
virtual void ApplyUpdate() = 0;

protected:
string SnapshotFilename(const string extension);
string SnapshotToBinaryProto();
string SnapshotToHDF5();
Expand Down
1 change: 1 addition & 0 deletions python/caffe/_caffe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,7 @@ BOOST_PYTHON_MODULE(_caffe) {
.def("restore", &Solver<Dtype>::Restore)
.def("snapshot", &Solver<Dtype>::Snapshot)
.def("share_weights", &share_weights)
.def("apply_update", &Solver<Dtype>::ApplyUpdate)
.add_property("param", bp::make_function(&Solver<Dtype>::param,
bp::return_value_policy<bp::copy_const_reference>()));
BP_REGISTER_SHARED_PTR_TO_PYTHON(Solver<Dtype>);
Expand Down

0 comments on commit cc1c8fb

Please sign in to comment.