Skip to content

Commit

Permalink
Added % training
Browse files Browse the repository at this point in the history
  • Loading branch information
gineshidalgo99 committed Mar 15, 2018
1 parent 7dade36 commit 825730b
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 49 deletions.
9 changes: 6 additions & 3 deletions include/caffe/openpose/layers/oPDataLayer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ class OPDataLayer : public BasePrefetchingDataLayer<Dtype> {
uint64_t offset_;

// OpenPose: added
void NextSecond();
bool SkipSecond();
void NextBackground();
void NextSecond();
// Secondary lmdb
uint64_t offsetSecond;
bool secondDb;
Expand All @@ -60,8 +61,10 @@ class OPDataLayer : public BasePrefetchingDataLayer<Dtype> {
// Data augmentation class
shared_ptr<OPDataTransformer<Dtype> > mOPDataTransformer;
// Timer
int sCounter;
double sDuration;
unsigned long long mOnes;
unsigned long long mTwos;
int mCounter;
double mDuration;
// OpenPose: added end
};

Expand Down
88 changes: 47 additions & 41 deletions src/caffe/openpose/layers/oPDataLayer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ OPDataLayer<Dtype>::OPDataLayer(const LayerParameter& param) :
db_->Open(param.data_param().source(), db::READ);
cursor_.reset(db_->NewCursor());
// OpenPose: added
mOnes = 0;
mTwos = 0;
// Set up secondary DB
if (!param.op_transform_param().source_secondary().empty())
{
Expand Down Expand Up @@ -56,8 +58,8 @@ OPDataLayer<Dtype>::OPDataLayer(const LayerParameter& param) :
else
backgroundDb = false;
// Timer
sDuration = 0;
sCounter = 0;
mDuration = 0;
mCounter = 0;
// OpenPose: added end
}

Expand Down Expand Up @@ -145,6 +147,19 @@ bool OPDataLayer<Dtype>::Skip()
return !keep;
}

template<typename Dtype>
void OPDataLayer<Dtype>::Next()
{
cursor_->Next();
if (!cursor_->valid())
{
LOG_IF(INFO, Caffe::root_solver())
<< "Restarting data prefetching from start.";
cursor_->SeekToFirst();
}
offset_++;
}

// OpenPose: added
template <typename Dtype>
bool OPDataLayer<Dtype>::SkipSecond()
Expand All @@ -156,55 +171,33 @@ bool OPDataLayer<Dtype>::SkipSecond()
this->layer_param_.phase() == TEST;
return !keep;
}
// OpenPose: end

template<typename Dtype>
void OPDataLayer<Dtype>::Next()
void OPDataLayer<Dtype>::NextBackground()
{
cursor_->Next();
if (!cursor_->valid())
{
LOG_IF(INFO, Caffe::root_solver())
<< "Restarting data prefetching from start.";
cursor_->SeekToFirst();
}
offset_++;
// OpenPose: added
if (backgroundDb)
{
cursorBackground->Next();
if (!cursor_->valid())
if (!cursorBackground->valid())
{
LOG_IF(INFO, Caffe::root_solver())
<< "Restarting negatives data prefetching from start.";
cursorBackground->SeekToFirst();
}
}
// OpenPose: added ended
}

// OpenPose: added
template<typename Dtype>
void OPDataLayer<Dtype>::NextSecond()
{
cursorSecond->Next();
if (!cursorSecond->valid())
{
LOG_IF(INFO, Caffe::root_solver())
<< "Restarting data prefetching from start.";
<< "Restarting second data prefetching from start.";
cursorSecond->SeekToFirst();
}
offsetSecond++;
if (backgroundDb)
{
cursorBackground->Next();
if (!cursorSecond->valid())
{
LOG_IF(INFO, Caffe::root_solver())
<< "Restarting negatives data prefetching from start.";
cursorBackground->SeekToFirst();
}
}
}
// OpenPose: added ended

Expand All @@ -229,7 +222,7 @@ void OPDataLayer<Dtype>::load_batch(Batch<Dtype>* batch)
Datum datumBackground;
// OpenPose: added
const float dice = static_cast <float> (rand()) / static_cast <float> (RAND_MAX); //[0,1]
const auto desiredDbIs1 = (dice <= (1-secondProbability));
const auto desiredDbIs1 = !secondDb || (dice <= (1-secondProbability));
// OpenPose: added ended
for (int item_id = 0; item_id < batch_size; ++item_id) {
timer.Start();
Expand All @@ -241,23 +234,26 @@ void OPDataLayer<Dtype>::load_batch(Batch<Dtype>* batch)
// OpenPose: commended ended
// OpenPose: added
// If only main DB or if 2 DBs but 1st must go
if (!secondDb || desiredDbIs1)
if (desiredDbIs1)
{
while (Skip()) {
mOnes++;
while (Skip())
Next();
}
datum.ParseFromString(cursor_->value());
}
// If 2 DBs & 2nd one must go
else
{
while (SkipSecond()) {
mTwos++;
while (SkipSecond())
NextSecond();
}
datum.ParseFromString(cursorSecond->value());
}
if (backgroundDb)
{
NextBackground();
datumBackground.ParseFromString(cursorBackground->value());
}
// OpenPose: added ended
read_time += timer.MicroSeconds();

Expand Down Expand Up @@ -303,7 +299,15 @@ void OPDataLayer<Dtype>::load_batch(Batch<Dtype>* batch)
&(this->transformed_label_),
datum);
const auto end = std::chrono::high_resolution_clock::now();
sDuration += std::chrono::duration_cast<std::chrono::nanoseconds>(end-begin).count();
mDuration += std::chrono::duration_cast<std::chrono::nanoseconds>(end-begin).count();

// DB 1
if (desiredDbIs1)
Next();
// DB 2
else
NextSecond();
trans_time += timer.MicroSeconds();
// OpenPose: added ended
// OpenPose: commented
// this->data_transformer_->Transform(datum, &(this->transformed_data_));
Expand All @@ -312,17 +316,19 @@ void OPDataLayer<Dtype>::load_batch(Batch<Dtype>* batch)
// Dtype* topLabel = batch->label_.mutable_cpu_data();
// topLabel[item_id] = datum.label();
// }
// trans_time += timer.MicroSeconds();
// Next();
// OpenPose: commented ended
trans_time += timer.MicroSeconds();
Next();
}
// Timer (every 20 iterations x batch size)
sCounter++;
if (sCounter == 20)
mCounter++;
const auto repeatEveryXVisualizations = 2;
if (mCounter == 20*repeatEveryXVisualizations)
{
std::cout << "Time: " << sDuration * 1e-9 << "s" << std::endl;
sDuration = 0;
sCounter = 0;
std::cout << "Time: " << mDuration/repeatEveryXVisualizations * 1e-9 << "s\t"
<< "Ratio: " << mOnes/float(mOnes+mTwos) << std::endl;
mDuration = 0;
mCounter = 0;
}
timer.Stop();
batch_timer.Stop();
Expand Down
10 changes: 5 additions & 5 deletions src/caffe/openpose/oPDataTransformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -918,20 +918,20 @@ void OPDataTransformer<Dtype>::generateLabelMap(Dtype* transformedLabel, const c
// Background
const auto type = getType(Dtype(0));
const auto backgroundIndex = numberPafChannels + numberBodyParts;
cv::Mat maskMiss(gridY, gridX, type, &transformedLabel[backgroundIndex*channelOffset]);
cv::Mat maskMissTemp(gridY, gridX, type, &transformedLabel[backgroundIndex*channelOffset]);
// If hands
if (numberBodyParts == 59 && mPoseModel != PoseModel::MPII_59)
{
maskHands(maskMiss, metaData.jointsSelf.isVisible, metaData.jointsSelf.points, stride, 0.6f);
maskHands(maskMissTemp, metaData.jointsSelf.isVisible, metaData.jointsSelf.points, stride, 0.6f);
for (const auto& jointsOther : metaData.jointsOthers)
maskHands(maskMiss, jointsOther.isVisible, jointsOther.points, stride, 0.6f);
maskHands(maskMissTemp, jointsOther.isVisible, jointsOther.points, stride, 0.6f);
}
// If foot
if (numberBodyParts == 23)
{
maskFeet(maskMiss, metaData.jointsSelf.isVisible, metaData.jointsSelf.points, stride, 0.6f);
maskFeet(maskMissTemp, metaData.jointsSelf.isVisible, metaData.jointsSelf.points, stride, 0.6f);
for (const auto& jointsOther : metaData.jointsOthers)
maskFeet(maskMiss, jointsOther.isVisible, jointsOther.points, stride, 0.6f);
maskFeet(maskMissTemp, jointsOther.isVisible, jointsOther.points, stride, 0.6f);
}
}

Expand Down

0 comments on commit 825730b

Please sign in to comment.