From a8ec123c00723df0d0ad897e1eea32a29201c81b Mon Sep 17 00:00:00 2001 From: Evan Shelhamer Date: Fri, 9 Sep 2016 16:49:31 -0700 Subject: [PATCH] batch norm: auto-upgrade old layer definitions w/ param messages automatically strip old batch norm layer definitions including `param` messages. the batch norm layer used to require manually masking its state from the solver by setting `param { lr_mult: 0 }` messages for each of its statistics. this is now handled automatically by the layer. --- include/caffe/util/upgrade_proto.hpp | 6 +++++ src/caffe/util/upgrade_proto.cpp | 34 +++++++++++++++++++++++++++- 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/include/caffe/util/upgrade_proto.hpp b/include/caffe/util/upgrade_proto.hpp index 14e1936a8c2..b145822af32 100644 --- a/include/caffe/util/upgrade_proto.hpp +++ b/include/caffe/util/upgrade_proto.hpp @@ -65,6 +65,12 @@ bool NetNeedsInputUpgrade(const NetParameter& net_param); // Perform all necessary transformations to upgrade input fields into layers. void UpgradeNetInput(NetParameter* net_param); +// Return true iff the Net contains batch norm layers with manual local LRs. +bool NetNeedsBatchNormUpgrade(const NetParameter& net_param); + +// Perform all necessary transformations to upgrade batch norm layers. +void UpgradeNetBatchNorm(NetParameter* net_param); + // Return true iff the solver contains any old solver_type specified as enums bool SolverNeedsTypeUpgrade(const SolverParameter& solver_param); diff --git a/src/caffe/util/upgrade_proto.cpp b/src/caffe/util/upgrade_proto.cpp index 9e186915b43..a0aacbe92f8 100644 --- a/src/caffe/util/upgrade_proto.cpp +++ b/src/caffe/util/upgrade_proto.cpp @@ -14,7 +14,8 @@ namespace caffe { bool NetNeedsUpgrade(const NetParameter& net_param) { return NetNeedsV0ToV1Upgrade(net_param) || NetNeedsV1ToV2Upgrade(net_param) - || NetNeedsDataUpgrade(net_param) || NetNeedsInputUpgrade(net_param); + || NetNeedsDataUpgrade(net_param) || NetNeedsInputUpgrade(net_param) + || NetNeedsBatchNormUpgrade(net_param); } bool UpgradeNetAsNeeded(const string& param_file, NetParameter* param) { @@ -71,6 +72,14 @@ bool UpgradeNetAsNeeded(const string& param_file, NetParameter* param) { LOG(WARNING) << "Note that future Caffe releases will only support " << "input layers and not input fields."; } + // NetParameter uses old style batch norm layers; try to upgrade it. + if (NetNeedsBatchNormUpgrade(*param)) { + LOG(INFO) << "Attempting to upgrade batch norm layers using deprecated " + << "params: " << param_file; + UpgradeNetBatchNorm(param); + LOG(INFO) << "Successfully upgraded batch norm layers using deprecated " + << "params."; + } return success; } @@ -991,6 +1000,29 @@ void UpgradeNetInput(NetParameter* net_param) { net_param->clear_input_dim(); } +bool NetNeedsBatchNormUpgrade(const NetParameter& net_param) { + for (int i = 0; i < net_param.layer_size(); ++i) { + // Check if BatchNorm layers declare three parameters, as required by + // the previous BatchNorm layer definition. + if (net_param.layer(i).type() == "BatchNorm" + && net_param.layer(i).param_size() == 3) { + return true; + } + } + return false; +} + +void UpgradeNetBatchNorm(NetParameter* net_param) { + for (int i = 0; i < net_param->layer_size(); ++i) { + // Check if BatchNorm layers declare three parameters, as required by + // the previous BatchNorm layer definition. + if (net_param->layer(i).type() == "BatchNorm" + && net_param->layer(i).param_size() == 3) { + net_param->mutable_layer(i)->clear_param(); + } + } +} + // Return true iff the solver contains any old solver_type specified as enums bool SolverNeedsTypeUpgrade(const SolverParameter& solver_param) { if (solver_param.has_solver_type()) {