forked from zhaoweicai/cascade-rcnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
solver_factory.hpp
137 lines (118 loc) · 4.32 KB
/
solver_factory.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
/**
* @brief A solver factory that allows one to register solvers, similar to
* the layer factory. At runtime, a registered solver can be created by passing
* a SolverParameter protocol buffer to the CreateSolver function:
*
* SolverRegistry<Dtype>::CreateSolver(param);
*
* There are two ways to register a solver. Assuming that we have a solver like:
*
* template <typename Dtype>
* class MyAwesomeSolver : public Solver<Dtype> {
* // your implementations
* };
*
* and its type is its C++ class name, but without the "Solver" at the end
* ("MyAwesomeSolver" -> "MyAwesome").
*
* If the solver is going to be created simply by its constructor, in your C++
* file, add the following line:
*
* REGISTER_SOLVER_CLASS(MyAwesome);
*
* Or, if the solver is going to be created by another creator function, in the
* format of:
*
* template <typename Dtype>
* Solver<Dtype>* GetMyAwesomeSolver(const SolverParameter& param) {
* // your implementation
* }
*
* then you can register the creator function instead, like
*
* REGISTER_SOLVER_CREATOR(MyAwesome, GetMyAwesomeSolver);
*
* Note that each solver type should only be registered once.
*/
#ifndef CAFFE_SOLVER_FACTORY_H_
#define CAFFE_SOLVER_FACTORY_H_
#include <map>
#include <string>
#include <vector>
#include "caffe/common.hpp"
#include "caffe/proto/caffe.pb.h"
namespace caffe {
// Forward declaration of the solver class template. The declarations in this
// header only name Solver<Dtype> through pointers (creator signatures and
// return types), so the full definition is not required at this point.
template <typename Dtype>
class Solver;
// Per-Dtype registry that maps solver type names to creator functions.
// Everything is exposed through static members; the constructor is private so
// the class itself is never instantiated.
template <typename Dtype>
class SolverRegistry {
 public:
  // Pointer to a factory function that builds a Solver<Dtype> from a
  // SolverParameter and returns it as a raw pointer.
  typedef Solver<Dtype>* (*Creator)(const SolverParameter&);
  // Solver type name -> creator function.
  typedef std::map<string, Creator> CreatorRegistry;

  // Returns the single registry map for this Dtype. The map is allocated with
  // `new` the first time this function runs and is not deleted by this class.
  static CreatorRegistry& Registry() {
    static CreatorRegistry* const instance = new CreatorRegistry();
    return *instance;
  }

  // Registers `creator` under the name `type`. The CHECK_EQ below fails if a
  // creator is already stored under that name.
  static void AddCreator(const string& type, Creator creator) {
    CreatorRegistry& registry = Registry();
    CHECK_EQ(registry.count(type), 0)
        << "Solver type " << type << " already registered.";
    registry[type] = creator;
  }

  // Looks up the creator registered for param.type() and returns whatever it
  // produces for `param`. The CHECK_EQ below fails, listing every registered
  // type name, when no creator is registered under that name.
  static Solver<Dtype>* CreateSolver(const SolverParameter& param) {
    const string& type = param.type();
    CreatorRegistry& registry = Registry();
    CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
        << " (known types: " << SolverTypeListString() << ")";
    const Creator make_solver = registry[type];
    return make_solver(param);
  }

  // Returns the registered solver type names, in the map's key order.
  static vector<string> SolverTypeList() {
    const CreatorRegistry& known = Registry();
    vector<string> names;
    typename CreatorRegistry::const_iterator entry = known.begin();
    while (entry != known.end()) {
      names.push_back(entry->first);
      ++entry;
    }
    return names;
  }

 private:
  // Solver registry should never be instantiated - everything is done with
  // its static members.
  SolverRegistry() {}

  // Joins the registered type names with ", " for use in error messages.
  static string SolverTypeListString() {
    const vector<string> names = SolverTypeList();
    string joined;
    for (size_t i = 0; i < names.size(); ++i) {
      if (i > 0) {
        joined += ", ";
      }
      joined += names[i];
    }
    return joined;
  }
};
template <typename Dtype>
class SolverRegisterer {
public:
SolverRegisterer(const string& type,
Solver<Dtype>* (*creator)(const SolverParameter&)) {
// LOG(INFO) << "Registering solver type: " << type;
SolverRegistry<Dtype>::AddCreator(type, creator);
}
};
// Registers the function template `creator` under the solver name #type for
// both float and double. Expands to two file-static SolverRegisterer objects
// named g_creator_f_<type> and g_creator_d_<type>; their constructors perform
// the registration.
//
// The expansion deliberately ends without a semicolon, so invoke it as
//   REGISTER_SOLVER_CREATOR(MyAwesome, GetMyAwesomeSolver);
//
// The final line of the macro must NOT end in a line-continuation backslash:
// a trailing backslash splices whatever line follows into the macro body,
// which swallows the next directive when no blank line separates them.
#define REGISTER_SOLVER_CREATOR(type, creator)                                 \
  static SolverRegisterer<float> g_creator_f_##type(#type, creator<float>);    \
  static SolverRegisterer<double> g_creator_d_##type(#type, creator<double>)
// Convenience registration for solvers that are built directly by their
// constructor. For a name such as MyAwesome this expands to
//   1. a function template Creator_MyAwesomeSolver<Dtype>(param) that returns
//      `new MyAwesomeSolver<Dtype>(param)`, and
//   2. REGISTER_SOLVER_CREATOR(MyAwesome, Creator_MyAwesomeSolver), which
//      registers that function for the supported Dtypes.
// The <type>Solver class template must be fully defined where this is used.
#define REGISTER_SOLVER_CLASS(type)                                            \
  template <typename Dtype>                                                    \
  Solver<Dtype>* Creator_##type##Solver(const SolverParameter& param) {        \
    return new type##Solver<Dtype>(param);                                     \
  }                                                                            \
  REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver)
} // namespace caffe
#endif // CAFFE_SOLVER_FACTORY_H_