Skip to content

Commit

Permalink
batch norm: auto-upgrade old layer definitions w/ param messages
Browse files Browse the repository at this point in the history
Automatically strip the `param` messages from old batch norm layer
definitions. The batch norm layer used to require manually masking its
state from the solver by setting `param { lr_mult: 0 }` messages for
each of its statistics. This is now handled automatically by the layer.
  • Loading branch information
shelhamer committed Sep 13, 2016
1 parent c8f446f commit a8ec123
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
6 changes: 6 additions & 0 deletions include/caffe/util/upgrade_proto.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ bool NetNeedsInputUpgrade(const NetParameter& net_param);
// Perform all necessary transformations to upgrade input fields into layers.
void UpgradeNetInput(NetParameter* net_param);

// Return true iff the Net contains batch norm layers with manual local LRs,
// i.e. at least one layer of type "BatchNorm" that declares exactly three
// `param` messages (the old-style definition).
bool NetNeedsBatchNormUpgrade(const NetParameter& net_param);

// Perform all necessary transformations to upgrade batch norm layers: the
// `param` messages of every layer matched by the check above are cleared
// in place; all other layers are left untouched.
void UpgradeNetBatchNorm(NetParameter* net_param);

// Return true iff the solver contains any old solver_type specified as enums
bool SolverNeedsTypeUpgrade(const SolverParameter& solver_param);

Expand Down
34 changes: 33 additions & 1 deletion src/caffe/util/upgrade_proto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ namespace caffe {

// Return true iff the net is written in any deprecated format that one of
// the upgrade passes knows how to convert. The individual checks are
// evaluated left to right and short-circuit on the first one that matches:
// V0 -> V1, V1 -> V2, data transformation fields, input fields, and finally
// old-style batch norm layers.
bool NetNeedsUpgrade(const NetParameter& net_param) {
  return NetNeedsV0ToV1Upgrade(net_param) || NetNeedsV1ToV2Upgrade(net_param)
      || NetNeedsDataUpgrade(net_param) || NetNeedsInputUpgrade(net_param)
      || NetNeedsBatchNormUpgrade(net_param);
}

bool UpgradeNetAsNeeded(const string& param_file, NetParameter* param) {
Expand Down Expand Up @@ -71,6 +72,14 @@ bool UpgradeNetAsNeeded(const string& param_file, NetParameter* param) {
LOG(WARNING) << "Note that future Caffe releases will only support "
<< "input layers and not input fields.";
}
// NetParameter uses old style batch norm layers; try to upgrade it.
if (NetNeedsBatchNormUpgrade(*param)) {
LOG(INFO) << "Attempting to upgrade batch norm layers using deprecated "
<< "params: " << param_file;
UpgradeNetBatchNorm(param);
LOG(INFO) << "Successfully upgraded batch norm layers using deprecated "
<< "params.";
}
return success;
}

Expand Down Expand Up @@ -991,6 +1000,29 @@ void UpgradeNetInput(NetParameter* net_param) {
net_param->clear_input_dim();
}

// Scan the net for old-style BatchNorm layers. Returns true as soon as one
// layer of type "BatchNorm" with exactly three `param` messages is found,
// and false when no layer matches.
bool NetNeedsBatchNormUpgrade(const NetParameter& net_param) {
  const int num_layers = net_param.layer_size();
  for (int layer_id = 0; layer_id < num_layers; ++layer_id) {
    // Only BatchNorm layers are candidates for this upgrade.
    if (net_param.layer(layer_id).type() != "BatchNorm") {
      continue;
    }
    // The previous BatchNorm layer definition required three parameters
    // to be declared; that count is the marker of an old definition.
    if (net_param.layer(layer_id).param_size() == 3) {
      return true;
    }
  }
  return false;
}

// Rewrite old-style BatchNorm layers in place: every layer of type
// "BatchNorm" that declares exactly three `param` messages has those
// messages cleared. Layers that do not match are not modified.
void UpgradeNetBatchNorm(NetParameter* net_param) {
  for (int layer_id = 0; layer_id < net_param->layer_size(); ++layer_id) {
    // Skip anything that is not a BatchNorm layer.
    if (net_param->layer(layer_id).type() != "BatchNorm") {
      continue;
    }
    // Three declared parameters identify the previous BatchNorm layer
    // definition; any other count is left as is.
    if (net_param->layer(layer_id).param_size() != 3) {
      continue;
    }
    net_param->mutable_layer(layer_id)->clear_param();
  }
}

// Return true iff the solver contains any old solver_type specified as enums
bool SolverNeedsTypeUpgrade(const SolverParameter& solver_param) {
if (solver_param.has_solver_type()) {
Expand Down

0 comments on commit a8ec123

Please sign in to comment.