Skip to content

Commit

Permalink
Implement ctc prefix beam search decode for TextRecognitionModel.
Browse files Browse the repository at this point in the history
The algorithm is based on Hannun's paper: First-Pass Large Vocabulary
Continuous Speech Recognition using Bi-Directional Recurrent DNNs
  • Loading branch information
yichenj committed Aug 12, 2021
1 parent ea068dc commit 955cf35
Show file tree
Hide file tree
Showing 5 changed files with 332 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ Before recognition, you should `setVocabulary` and `setDecodeType`.
- `T` is the sequence length
- `B` is the batch size (only support `B=1` in inference)
- and `Dim` is the length of vocabulary +1('Blank' of CTC is at the index=0 of Dim).
- "CTC-prefix-beam-search", the output of the text recognition model should be a probability matrix same with "CTC-greedy".
- The algorithm is proposed at Hannun's [paper](https://arxiv.org/abs/1408.2873).
- `setDecodeOptsCTCPrefixBeamSearch` could be used to control the beam size in search step.
- To futher optimize for big vocabulary, a new option `vocPruneSize` is introduced to avoid iterate the whole vocbulary
but only the number of `vocPruneSize` tokens with top probabilty.

@ref cv::dnn::TextRecognitionModel::recognize() is the main function for text recognition.
- The input image should be a cropped text image or an image with `roiRects`
Expand Down
13 changes: 12 additions & 1 deletion modules/dnn/include/opencv2/dnn/dnn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1373,7 +1373,9 @@ class CV_EXPORTS_W_SIMPLE TextRecognitionModel : public Model

/**
* @brief Set the decoding method of translating the network output into string
* @param[in] decodeType The decoding method of translating the network output into string: {'CTC-greedy': greedy decoding for the output of CTC-based methods}
* @param[in] decodeType The decoding method of translating the network output into string, currently supported type:
* - `"CTC-greedy"` greedy decoding for the output of CTC-based methods
* - `"CTC-prefix-beam-search"` Prefix beam search decoding for the output of CTC-based methods
*/
CV_WRAP
TextRecognitionModel& setDecodeType(const std::string& decodeType);
Expand All @@ -1385,6 +1387,15 @@ class CV_EXPORTS_W_SIMPLE TextRecognitionModel : public Model
CV_WRAP
const std::string& getDecodeType() const;

/**
* @brief Set the decoding method options for `"CTC-prefix-beam-search"` decode usage
* @param[in] beamSize Beam size for search
* @param[in] vocPruneSize Parameter to optimize big vocabulary search,
* only take top @p vocPruneSize tokens in each search step, @p vocPruneSize <= 0 stands for disable this prune.
*/
CV_WRAP
TextRecognitionModel& setDecodeOptsCTCPrefixBeamSearch(int beamSize, int vocPruneSize = 0);

/**
* @brief Set the vocabulary for recognition.
* @param[in] vocabulary the associated vocabulary of the network.
Expand Down
83 changes: 83 additions & 0 deletions modules/dnn/src/math_utils.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

// Code is borrowed from https://github.com/kaldi-asr/kaldi/blob/master/src/base/kaldi-math.h

// base/kaldi-math.h

// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; Yanmin Qian;
// Jan Silovsky; Saarland University
//
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef __OPENCV_DNN_MATH_UTILS_HPP__
#define __OPENCV_DNN_MATH_UTILS_HPP__

#ifdef OS_QNX
#include <math.h>
#else
#include <cmath>
#endif

#include <limits>

#ifndef FLT_EPSILON
#define FLT_EPSILON 1.19209290e-7f
#endif

namespace cv { namespace dnn {

const float kNegativeInfinity = -std::numeric_limits<float>::infinity();

const float kMinLogDiffFloat = std::log(FLT_EPSILON);

#if !defined(_MSC_VER) || (_MSC_VER >= 1700)
inline float Log1p(float x) { return log1pf(x); }
#else
inline float Log1p(float x) {
const float cutoff = 1.0e-07;
if (x < cutoff)
return x - 2 * x * x;
else
return Log(1.0 + x);
}
#endif

inline float Exp(float x) { return expf(x); }

inline float LogAdd(float x, float y) {
float diff;
if (x < y) {
diff = x - y;
x = y;
} else {
diff = y - x;
}
// diff is negative. x is now the larger one.

if (diff >= kMinLogDiffFloat) {
float res;
res = x + Log1p(Exp(diff));
return res;
} else {
return x; // return the larger one.
}
}

}} // namespace

#endif // __OPENCV_DNN_MATH_UTILS_HPP__
Loading

0 comments on commit 955cf35

Please sign in to comment.