Skip to content

Commit

Permalink
allowing people to manually define how sharp a cascade classifier mod…
Browse files Browse the repository at this point in the history
…el should be trained
  • Loading branch information
StevenPuttemans committed Apr 3, 2015
1 parent 4e87dea commit 7e35f76
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 5 deletions.
13 changes: 10 additions & 3 deletions apps/traincascade/cascadeclassifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ bool CvCascadeClassifier::train( const string _cascadeDirName,
const CvCascadeParams& _cascadeParams,
const CvFeatureParams& _featureParams,
const CvCascadeBoostParams& _stageParams,
bool baseFormatSave )
bool baseFormatSave,
double acceptanceRatioBreakValue )
{
// Start recording clock ticks for training time output
const clock_t begin_time = clock();
Expand Down Expand Up @@ -185,6 +186,7 @@ bool CvCascadeClassifier::train( const string _cascadeDirName,
cout << "numStages: " << numStages << endl;
cout << "precalcValBufSize[Mb] : " << _precalcValBufSize << endl;
cout << "precalcIdxBufSize[Mb] : " << _precalcIdxBufSize << endl;
cout << "acceptanceRatioBreakValue : " << acceptanceRatioBreakValue << endl;
cascadeParams.printAttrs();
stageParams->printAttrs();
featureParams->printAttrs();
Expand All @@ -207,13 +209,18 @@ bool CvCascadeClassifier::train( const string _cascadeDirName,
if ( !updateTrainingSet( tempLeafFARate ) )
{
cout << "Train dataset for temp stage can not be filled. "
"Branch training terminated." << endl;
"Branch training terminated." << endl;
break;
}
if( tempLeafFARate <= requiredLeafFARate )
{
cout << "Required leaf false alarm rate achieved. "
"Branch training terminated." << endl;
"Branch training terminated." << endl;
break;
}
if( (tempLeafFARate <= acceptanceRatioBreakValue) && (acceptanceRatioBreakValue < 0) ){
cout << "The required acceptanceRatio for the model has been reached to avoid overfitting of trainingdata. "
"Branch training terminated." << endl;
break;
}

Expand Down
3 changes: 2 additions & 1 deletion apps/traincascade/cascadeclassifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ class CvCascadeClassifier
const CvCascadeParams& _cascadeParams,
const CvFeatureParams& _featureParams,
const CvCascadeBoostParams& _stageParams,
bool baseFormatSave = false );
bool baseFormatSave = false,
double acceptanceRatioBreakValue = -1.0 );
private:
int predict( int sampleIdx );
void save( const std::string cascadeDirName, bool baseFormat = false );
Expand Down
9 changes: 8 additions & 1 deletion apps/traincascade/traincascade.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ int main( int argc, char* argv[] )
int precalcValBufSize = 256,
precalcIdxBufSize = 256;
bool baseFormatSave = false;
double acceptanceRatioBreakValue = -1.0;

CvCascadeParams cascadeParams;
CvCascadeBoostParams stageParams;
Expand All @@ -36,6 +37,7 @@ int main( int argc, char* argv[] )
cout << " [-precalcIdxBufSize <precalculated_idxs_buffer_size_in_Mb = " << precalcIdxBufSize << ">]" << endl;
cout << " [-baseFormatSave]" << endl;
cout << " [-numThreads <max_number_of_threads = " << numThreads << ">]" << endl;
cout << " [-acceptanceRatioBreakValue <value> = " << acceptanceRatioBreakValue << ">]" << endl;
cascadeParams.printDefaults();
stageParams.printDefaults();
for( int fi = 0; fi < fc; fi++ )
Expand Down Expand Up @@ -86,6 +88,10 @@ int main( int argc, char* argv[] )
{
numThreads = atoi(argv[++i]);
}
else if( !strcmp( argv[i], "-acceptanceRatioBreakValue" ) )
{
acceptanceRatioBreakValue = atof(argv[++i]);
}
else if ( cascadeParams.scanAttr( argv[i], argv[i+1] ) ) { i++; }
else if ( stageParams.scanAttr( argv[i], argv[i+1] ) ) { i++; }
else if ( !set )
Expand All @@ -112,6 +118,7 @@ int main( int argc, char* argv[] )
cascadeParams,
*featureParams[cascadeParams.featureType],
stageParams,
baseFormatSave );
baseFormatSave,
acceptanceRatioBreakValue );
return 0;
}
6 changes: 6 additions & 0 deletions doc/user_guide/ug_traincascade.markdown
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,12 @@ Command line arguments of opencv_traincascade application grouped by purposes:
Maximum number of threads to use during training. Notice that the actual number of used
threads may be lower, depending on your machine and compilation options.

- -acceptanceRatioBreakValue \<break_value\>

This argument is used to determine how precise your model should keep learning and when to stop.
A good guideline is to train not further than 10e-5, to ensure the model does not overtrain on your training data.
By default this value is set to -1 to disable this feature.

-# Cascade parameters:

- -stageType \<BOOST(default)\>
Expand Down

0 comments on commit 7e35f76

Please sign in to comment.