Skip to content

Commit

Permalink
Merge pull request BVLC#5009 from shelhamer/solver-type-check
Browse files Browse the repository at this point in the history
Solver: check and set type to reconcile class and proto type
  • Loading branch information
shelhamer authored Nov 22, 2016
2 parents a60c4a4 + e52451d commit a6c6533
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 0 deletions.
2 changes: 2 additions & 0 deletions include/caffe/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ class Solver {
virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0;
void DisplayOutputBlobs(const int net_id);
void UpdateSmoothedLoss(Dtype loss, int start_iter, int average_loss);
/// Harmonize solver class type with configured proto type.
void CheckType(SolverParameter* param);

SolverParameter param_;
int iter_;
Expand Down
12 changes: 12 additions & 0 deletions src/caffe/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,21 @@ Solver<Dtype>::Solver(const string& param_file, const Solver* root_solver)
requested_early_exit_(false) {
SolverParameter param;
ReadSolverParamsFromTextFileOrDie(param_file, &param);
CheckType(&param);
Init(param);
}

template <typename Dtype>
void Solver<Dtype>::CheckType(SolverParameter* param) {
// Harmonize solver class type with configured type to avoid confusion.
if (param->has_type()) {
CHECK_EQ(param->type(), this->type())
<< "Solver type must agree with instantiated solver class.";
} else {
param->set_type(this->type());
}
}

template <typename Dtype>
void Solver<Dtype>::Init(const SolverParameter& param) {
CHECK(Caffe::root_solver() || root_solver_)
Expand Down
5 changes: 5 additions & 0 deletions src/caffe/test/test_gradient_based_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,11 @@ TYPED_TEST(SGDSolverTest, TestSnapshotShare) {
}
}

TYPED_TEST(SGDSolverTest, TestSolverType) {
this->TestLeastSquaresUpdate();
EXPECT_NE(this->solver_->type(), string(""));
EXPECT_EQ(this->solver_->type(), this->solver_->param().type());
}

template <typename TypeParam>
class AdaGradSolverTest : public GradientBasedSolverTest<TypeParam> {
Expand Down

0 comments on commit a6c6533

Please sign in to comment.