// solver.hpp (forked from initialneil/caffe-vs2013)
#ifndef CAFFE_OPTIMIZATION_SOLVER_HPP_
#define CAFFE_OPTIMIZATION_SOLVER_HPP_
#include <string>
#include <vector>
#include "caffe/net.hpp"
namespace caffe {
/**
 * @brief An interface for classes that perform optimization on Net%s.
 *
 * Requires implementation of ComputeUpdateValue to compute a parameter update
 * given the current state of the Net parameters. Subclasses must also
 * implement SnapshotSolverState and RestoreSolverState, which are pure
 * virtual here.
 */
template <typename Dtype>
class Solver {
 public:
  // Construct from an in-memory SolverParameter, or from a string that
  // identifies a parameter file.
  explicit Solver(const SolverParameter& param);
  explicit Solver(const string& param_file);

  void Init(const SolverParameter& param);
  void InitTrainNet();
  void InitTestNets();

  // The main entry of the solver function. resume_file defaults to NULL.
  // NOTE(review): a non-NULL resume_file presumably names a previously
  // snapshotted solver state to resume from (see Restore) -- confirm against
  // the implementation file.
  virtual void Solve(const char* resume_file = NULL);
  // Convenience overload that forwards to the const char* version. The string
  // is taken by const reference so that no copy is made just to read c_str().
  inline void Solve(const string& resume_file) { Solve(resume_file.c_str()); }

  virtual ~Solver() {}

  // Accessors for the net held in net_ and the nets held in test_nets_.
  inline shared_ptr<Net<Dtype> > net() { return net_; }
  inline const vector<shared_ptr<Net<Dtype> > >& test_nets() {
    return test_nets_;
  }

 protected:
  // PreSolve is run before any solving iteration starts, allowing one to
  // put up some scaffold. The default implementation does nothing.
  virtual void PreSolve() {}
  // Get the update value for the current iteration.
  virtual void ComputeUpdateValue() = 0;
  // The Solver::Snapshot function implements the basic snapshotting utility
  // that stores the learned net. You should implement the SnapshotSolverState()
  // function that produces a SolverState protocol buffer that needs to be
  // written to disk together with the learned net.
  void Snapshot();
  // The test routine
  void TestAll();
  void Test(const int test_net_id = 0);
  virtual void SnapshotSolverState(SolverState* state) = 0;
  // The Restore function implements how one should restore the solver to a
  // previously snapshotted state. You should implement the RestoreSolverState()
  // function that restores the state from a SolverState protocol buffer.
  void Restore(const char* resume_file);
  virtual void RestoreSolverState(const SolverState& state) = 0;
  void DisplayOutputBlobs(const int net_id);

  // Solver configuration.
  SolverParameter param_;
  // Iteration counter (presumably the current training iteration -- confirm
  // in the implementation file).
  int iter_;
  // NOTE(review): looks like a step index used by a learning-rate schedule;
  // its exact meaning is not visible in this header.
  int current_step_;
  // The net returned by net().
  shared_ptr<Net<Dtype> > net_;
  // The nets returned by test_nets().
  vector<shared_ptr<Net<Dtype> > > test_nets_;

  DISABLE_COPY_AND_ASSIGN(Solver);
};
/**
 * @brief Optimizes the parameters of a Net using
 *        stochastic gradient descent (SGD) with momentum.
 */
template <typename Dtype>
class SGDSolver : public Solver<Dtype> {
 public:
  // Both constructors forward their argument unchanged to the Solver base.
  explicit SGDSolver(const SolverParameter& param) : Solver<Dtype>(param) {}
  explicit SGDSolver(const string& param_file) : Solver<Dtype>(param_file) {}

  // Read-only view of the blobs stored in history_.
  const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; }

 protected:
  virtual void PreSolve();
  Dtype GetLearningRate();
  virtual void ComputeUpdateValue();
  virtual void SnapshotSolverState(SolverState* state);
  virtual void RestoreSolverState(const SolverState& state);

  // Historical momentum data.
  vector<shared_ptr<Blob<Dtype> > > history_;
  // Update-related data; not needed in snapshots.
  vector<shared_ptr<Blob<Dtype> > > update_;
  // Other information that might be needed in the computation of
  // gradients/updates; not needed in snapshots.
  vector<shared_ptr<Blob<Dtype> > > temp_;

  DISABLE_COPY_AND_ASSIGN(SGDSolver);
};
/**
 * @brief An SGDSolver subclass that supplies its own ComputeUpdateValue.
 *
 * NOTE(review): judging by the class name this is the Nesterov momentum
 * variant; the update rule itself lives in the implementation file.
 */
template <typename Dtype>
class NesterovSolver : public SGDSolver<Dtype> {
 public:
  // Both constructors forward their argument unchanged to SGDSolver.
  explicit NesterovSolver(const SolverParameter& param)
      : SGDSolver<Dtype>(param) {
  }
  explicit NesterovSolver(const string& param_file)
      : SGDSolver<Dtype>(param_file) {
  }

 protected:
  // Replaces the update computation inherited from SGDSolver.
  virtual void ComputeUpdateValue();

  DISABLE_COPY_AND_ASSIGN(NesterovSolver);
};
/**
 * @brief An SGDSolver subclass that supplies its own ComputeUpdateValue and
 *        refuses to be constructed with a non-zero momentum.
 */
template <typename Dtype>
class AdaGradSolver : public SGDSolver<Dtype> {
 public:
  // Each constructor forwards to SGDSolver and then validates the parameters.
  explicit AdaGradSolver(const SolverParameter& param)
      : SGDSolver<Dtype>(param) {
    constructor_sanity_check();
  }
  explicit AdaGradSolver(const string& param_file)
      : SGDSolver<Dtype>(param_file) {
    constructor_sanity_check();
  }

 protected:
  // Replaces the update computation inherited from SGDSolver.
  virtual void ComputeUpdateValue();

  // Checks that the configured momentum equals zero; the CHECK fails with the
  // message below otherwise.
  void constructor_sanity_check() {
    CHECK_EQ(0, this->param_.momentum())
        << "Momentum cannot be used with AdaGrad.";
  }

  DISABLE_COPY_AND_ASSIGN(AdaGradSolver);
};
/**
 * @brief Creates the solver selected by param.solver_type().
 *
 * The returned object is allocated with new; no owner is recorded here, so
 * the caller is responsible for deleting it. An unrecognized solver type is
 * reported through LOG(FATAL); if that call returns, the result is NULL.
 *
 * @param param Solver configuration, passed on to the chosen constructor.
 * @return Pointer to the newly allocated solver, or NULL as described above.
 */
template <typename Dtype>
Solver<Dtype>* GetSolver(const SolverParameter& param) {
  Solver<Dtype>* solver = NULL;
  const SolverParameter_SolverType solver_kind = param.solver_type();
  switch (solver_kind) {
    case SolverParameter_SolverType_SGD:
      solver = new SGDSolver<Dtype>(param);
      break;
    case SolverParameter_SolverType_NESTEROV:
      solver = new NesterovSolver<Dtype>(param);
      break;
    case SolverParameter_SolverType_ADAGRAD:
      solver = new AdaGradSolver<Dtype>(param);
      break;
    default:
      LOG(FATAL) << "Unknown SolverType: " << solver_kind;
  }
  return solver;
}
} // namespace caffe
#endif // CAFFE_OPTIMIZATION_SOLVER_HPP_