Skip to content

Commit

Permalink
Merge pull request jolibrain#561 from fantes/model_stats_status
Browse files Browse the repository at this point in the history
output some stats in status json
  • Loading branch information
beniz authored Apr 5, 2019
2 parents 0ecc5a8 + 30b4678 commit ec8788e
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 10 deletions.
18 changes: 12 additions & 6 deletions src/backends/caffe/caffelib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,6 @@ namespace dd
if (this->_loss == "dice")
// dice loss!!
{
if (_nclasses > 2)
update_protofiles_one_hot(net_param);

if (net_param.name().compare("deeplab_vgg16")==0
|| net_param.name().compare("pspnet_vgg16")==0
|| net_param.name().compare("pspnet_50")==0
Expand Down Expand Up @@ -877,7 +874,7 @@ namespace dd
}
try
{
model_complexity(_flops,_params);
model_complexity(this->_flops,this->_params);
}
catch(std::exception &e)
{
Expand Down Expand Up @@ -1056,7 +1053,7 @@ namespace dd
int user_batch_size, batch_size, test_batch_size, test_iter;
update_in_memory_net_and_solver(solver_param,cad,inputc,has_mean_file,user_batch_size,batch_size,test_batch_size,test_iter);
//caffe::ReadProtoFromTextFile(this->_mlmodel._solver,&solver_param);

// parameters
#if !defined(CPU_ONLY) && !defined(USE_CAFFE_CPU_ONLY)
bool gpu = _gpu;
Expand Down Expand Up @@ -1292,6 +1289,10 @@ namespace dd
_syncs[i]->StartInternalThread();
}

this->_mem_used_train = solver->net_->memory_used();
for (caffe::shared_ptr<caffe::Net<float > > n: solver->test_nets())
this->_mem_used_test+= n->memory_used();

const int start_iter = solver->iter_;
int average_loss = solver->param_.average_loss();
std::vector<float> losses;
Expand Down Expand Up @@ -2198,7 +2199,9 @@ namespace dd
_net = nullptr;
throw;
}


this->_mem_used_test = _net->memory_used();

float loss = 0.0;
if (extract_layer.empty() || inputc._segmentation) // supervised or segmentation
{
Expand Down Expand Up @@ -3442,6 +3445,9 @@ namespace dd
shrink_param->add_include();
caffe::NetStateRule *nsr = shrink_param->mutable_include(0);
nsr->set_phase(caffe::TRAIN);
caffe::InterpParameter *ip = shrink_param->mutable_interp_param();
ip->set_mode(caffe::InterpParameter::NEAREST);


int softml_pos = find_index_layer_by_type(net_param,"SoftmaxWithLoss");
std::string logits = net_param.layer(softml_pos).bottom(0);
Expand Down
4 changes: 1 addition & 3 deletions src/backends/caffe/caffelib.h
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ namespace dd
void set_gpuid(const APIData &ad);

void model_complexity(long int &flops,
long int &params);
long int &params);

void model_type(caffe::Net<float> *net,
std::string &mltype);
Expand Down Expand Up @@ -292,8 +292,6 @@ namespace dd
// std::vector<int> _targets; /**< id number of classification or regression targets. */
bool _autoencoder = false; /**< whether an autoencoder. */
std::mutex _net_mutex; /**< mutex around net, e.g. no concurrent predict calls as net is not re-instantiated. Use batches instead. */
long int _flops = 0; /**< model flops. */
long int _params = 0; /**< number of parameters in the model. */
int _crop_size = -1; /**< cropping is part of Caffe transforms in input layers, storing here. */
float _scale = 1.0; /**< scale is part of Caffe transforms in input layers, storing here. */

Expand Down
7 changes: 6 additions & 1 deletion src/mllibstrategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,12 @@ namespace dd
When not, prediction calls are rejected while training is running. */

std::shared_ptr<spdlog::logger> _logger; /**< mllib logger. */


long int _flops = 0; /**< model flops. */
long int _params = 0; /**< number of parameters in the model. */
long int _mem_used_train = 0; /**< amount of memory used. */
long int _mem_used_test = 0; /**< amount of memory used. */

protected:
mutable std::mutex _meas_per_iter_mutex; /**< mutex over measures history. */
mutable std::mutex _meas_mutex; /** mutex around current measures. */
Expand Down
6 changes: 6 additions & 0 deletions src/mlservice.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,12 @@ namespace dd
vad.push_back(jad);
++hit;
}
APIData stats;
stats.add("flops",this->_flops);
stats.add("params",this->_params);
stats.add("data_mem_train",this->_mem_used_train * sizeof(float));
stats.add("data_mem_test",this->_mem_used_test * sizeof(float));
ad.add("stats", stats);
ad.add("jobs",vad);
ad.add("parameters",_init_parameters);
ad.add("repository",this->_inputc._model_repo);
Expand Down

0 comments on commit ec8788e

Please sign in to comment.