
Commit bfb8740

Merge pull request jolibrain#406 from jolibrain/ocr_train

OCR training with CTC-CNN+LSTM

beniz authored May 18, 2018
2 parents 92cc5ef + ad2d626 commit bfb8740
Showing 13 changed files with 1,379 additions and 98 deletions.
8 changes: 5 additions & 3 deletions CMakeLists.txt
@@ -11,9 +11,6 @@ include(ExternalProject)

set (deepdetect_VERSION_MAJOR 0)
set (deepdetect_VERSION_MINOR 1)
#set (CAFFE_INC_DIR /home/infantes/caffe/include /home/infantes/caffe/build/src)
set (HDF5_LIB /usr/lib/x86_64-linux-gnu/hdf5/serial)
#set (CAFFE_LIB_DIR /home/infantes/caffe/build/lib /home/infantes/caffe/build/src ${HDF5_LIB})

# options
OPTION(BUILD_TESTS "Should the tests be built")
@@ -62,6 +59,11 @@ set(eigen_archive_hash "50812b426b7c")

include_directories("${EIGEN3_INCLUDE_DIR}")

# hdf5
set (HDF5_LIB /usr/lib/x86_64-linux-gnu/hdf5/serial)
set (HDF5_INCLUDE /usr/include/hdf5/serial)
include_directories(${HDF5_INCLUDE})

# dependency on Boost
find_package(Boost 1.54 REQUIRED COMPONENTS filesystem thread system iostreams)

12 changes: 6 additions & 6 deletions README.md
@@ -13,12 +13,12 @@ DeepDetect relies on external machine learning libraries through a very generic

#### Machine Learning functionalities per library (current):

| | Training | Prediction | Classification | Object Detection | Segmentation | Regression | Autoencoder |
|------------|----------|------------|----------------|-----------|-----------|------------|-------------|
| Caffe | Y | Y | Y | Y | Y | Y | Y |
| XGBoost | Y | Y | Y | N | N | Y | N/A |
| Tensorflow | N | Y | Y | N | N | N | N |
| T-SNE | Y | N/A | N/A | N/A | N/A | N/A | N/A |
| | Training | Prediction | Classification | Object Detection | Segmentation | Regression | Autoencoder | OCR / Seq2Seq |
|------------|----------|------------|----------------|-----------|-----------|------------|-------------|-------------|
| Caffe | Y | Y | Y | Y | Y | Y | Y | Y |
| XGBoost | Y | Y | Y | N | N | Y | N/A | N |
| Tensorflow | N | Y | Y | N | N | N | N | N |
| T-SNE | Y | N/A | N/A | N/A | N/A | N/A | N/A | N |


#### GPU support per library
2 changes: 1 addition & 1 deletion main/CMakeLists.txt
@@ -28,4 +28,4 @@ if (USE_TSNE)
endif()

add_executable (dede dede.cc)
target_link_libraries (dede ddetect ${CUDA_LIB_DEPS} glog gflags ${OpenCV_LIBS} cppnetlib-uri curlpp curl crypto ssl ${Boost_LIBRARIES} ${CAFFE_LIB_DEPS} ${XGBOOST_LIB_DEPS} ${TF_LIB_DEPS} ${TSNE_LIB_DEPS})
target_link_libraries (dede ddetect ${CUDA_LIB_DEPS} glog gflags ${OpenCV_LIBS} cppnetlib-uri curlpp curl crypto ssl hdf5_cpp ${Boost_LIBRARIES} ${CAFFE_LIB_DEPS} ${XGBOOST_LIB_DEPS} ${TF_LIB_DEPS} ${TSNE_LIB_DEPS})
257 changes: 257 additions & 0 deletions src/caffeinputconns.cc
@@ -25,7 +25,10 @@

#include "caffeinputconns.h"
#include "utils/utils.hpp"
#include <boost/multi_array.hpp>
#include <H5Cpp.h>
#include <memory>
#include "utf8.h"

using namespace caffe;

@@ -376,6 +379,260 @@ namespace dd
}
}

// fixed-size in-memory arrays written to disk in one shot
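// builds one or more hdf5 chunk files for the training set (and optionally
// the test set), along with a corresp.txt file mapping every character of
// the CTC alphabet to an integer label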
void ImgCaffeInputFileConn::images_to_hdf5(const std::vector<std::string> &img_lists,
const std::string &dbfullname,
const std::string &test_dbfullname)
{

// test whether dbs already exist
if (fileops::file_exists(dbfullname + "_0.h5"))
{
if (!fileops::file_exists(_model_repo + "/" + _correspname))
throw InputConnectorBadParamException("found h5 db but no corresp.txt file; erase the h5 files to rebuild them");
std::ifstream in(_model_repo + "/" + _correspname);
if (!in.is_open())
throw InputConnectorBadParamException("failed opening corresp.txt file");
int nlines = 0;
std::string line;
while(getline(in,line))
++nlines;
_alphabet_size = nlines;

if (!fileops::file_exists(_model_repo + "/testing.txt"))
_logger->info("no hdf5 test db list found, no test set");
else
{
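// recover the test set size by summing, over every chunk listed in
// testing.txt, the first dimension of its 'label' dataset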
std::string tfilename;
int tsize = 0;
in = std::ifstream(_model_repo + "/testing.txt");
while(getline(in,tfilename))
{
H5::H5File tfile(tfilename, H5F_ACC_RDONLY);
H5::DataSet dataset = tfile.openDataSet("label");
//H5::FloatType datatype = dataset.getFloatType();
//tsize += datatype.getSize();

H5::DataSpace dataspace = dataset.getSpace();
hsize_t dims[2];
dataspace.getSimpleExtentDims(dims,NULL);
tsize += dims[0];
}
_db_testbatchsize = tsize;
_logger->info("hdf5 test set size={}",tsize);
}
return;
}

//TODO: read / shuffle / split list of images

std::unordered_map<uint32_t,int> alphabet;
alphabet[0] = 0; // space character
int max_ocr_length = -1;
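// label 0 is reserved: it encodes both the inter-word space and the
// trailing blank used to pad every target string to max_ocr_length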

std::string train_list = _model_repo + "/training.txt";
write_images_to_hdf5(img_lists.at(0), dbfullname, train_list, alphabet, max_ocr_length, true);
_logger->info("ctc alphabet training size={}",alphabet.size());

if (img_lists.size() > 1)
{
std::string test_list = _model_repo + "/testing.txt";
write_images_to_hdf5(img_lists.at(1), test_dbfullname, test_list, alphabet, max_ocr_length, false);
}

// save the alphabet as corresp file
std::ofstream correspf(_model_repo + "/" + _correspname,std::ios::binary);
auto hit = alphabet.begin();
while(hit!=alphabet.end())
{
correspf << (*hit).second << " " << std::to_string((*hit).first) << std::endl;
++hit;
}
correspf.close();
_alphabet_size = alphabet.size();
}

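// converts a list file with lines of the form 'image_path word1 word2 ...'
// into fixed-size hdf5 chunks, filling 'alphabet' with character -> label
// ids and writing the chunk file paths to dblistfilename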
void ImgCaffeInputFileConn::write_images_to_hdf5(const std::string &inputfilename,
const std::string &dbfullname,
const std::string &dblistfilename,
std::unordered_map<uint32_t,int> &alphabet,
int &max_ocr_length,
const bool &train_db)
{
std::ifstream train_file(inputfilename);
std::string line;
std::unordered_map<uint32_t,int>::iterator ait;

// count file lines, we're using fixed-size in-memory array due to
// complexity of hdf5 handling of incremental datasets
int clines = 0;
while(std::getline(train_file, line))
{
std::vector<std::string> elts = dd_utils::split(line,' ');
if (train_db)
{
int ocr_size = 0;
for (size_t k=1;k<elts.size();k++)
{
ocr_size += elts.at(k).size();
if (k != elts.size()-1)
++ocr_size; // space between words
}
max_ocr_length = std::max(max_ocr_length,ocr_size);
}
++clines;
}
if (train_db)
{
_logger->info("ctc/ocr dataset training size={}",clines);
_db_batchsize = clines;
}
else
{
_logger->info("ctc/ocr dataset testing size={}",clines);
_db_testbatchsize = clines;
}
_logger->info("ctc output string max size={}",max_ocr_length);
train_file.clear();
train_file.seekg(0, std::ios::beg);

int cn = (_bw ? 1 : 3);
int max_lines = std::pow(10,9) / (_height*_width*3*4);
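// i.e. a budget of roughly 1GB of float pixels per chunk; e.g. with
// 224x224 color inputs, 10^9 / (224*224*3*4) ~= 1660 images per chunk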
_logger->info("hdf5 using max number of lines={}",max_lines);

cv::Size size(_width,_height);
int chunks = std::ceil(clines / static_cast<double>(max_lines));
_logger->info("proceeding with {} hdf5 chunks",chunks);
std::vector<std::string> dbchunks;
for (int ch=0;ch<chunks;ch++)
{
int tlines = (ch == chunks-1 && clines % max_lines != 0) ? clines % max_lines : max_lines;
if (tlines == 0)
break;
boost::multi_array<float,4> img_data(boost::extents[tlines][cn][_height][_width]);
boost::multi_array<float,2> ocr_data(boost::extents[tlines][max_ocr_length]);
int nline = 0;
while (std::getline(train_file, line))
{
std::vector<std::string> elts = dd_utils::split(line,' ');

// first elt is the image path
std::string img_path = elts.at(0);
cv::Mat img = cv::imread(img_path, _bw ? CV_LOAD_IMAGE_GRAYSCALE : CV_LOAD_IMAGE_COLOR);
if (_align && img.rows > img.cols) // rotate so that width is longest axis
{
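// transpose followed by a horizontal flip amounts to a 90-degree clockwise rotation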
cv::Mat timg;
cv::transpose(img,timg);
cv::flip(timg,img,1);
}
cv::Mat rimg;
try
{
cv::resize(img,rimg,size,0,0,CV_INTER_CUBIC);
}
catch(std::exception &e)
{
_logger->error("failed resizing image {}: {}",img_path,e.what());
continue;
}
img = rimg;
for(int r=0;r<img.rows;r++)
{
for(int c=0;c<img.cols;c++)
{
if (cn == 1) // grayscale image, single channel
img_data[nline][0][r][c] = img.at<uint8_t>(r,c);
else
{
cv::Point3_<uint8_t> pixel = img.at<cv::Point3_<uint8_t>>(r,c);
img_data[nline][0][r][c] = pixel.x; // B
img_data[nline][1][r][c] = pixel.y; // G
img_data[nline][2][r][c] = pixel.z; // R
}
}
}

// then comes the CTC/OCR target string (one or more words)
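// each utf-8 codepoint is looked up in the alphabet map: unseen
// characters get a fresh label at training time, and fall back to 0
// (space/blank) at test time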
int cpos = 0;
for (size_t i=1;i<elts.size();i++)
{
std::string ostr = elts.at(i);
char *ostr_c = (char*)ostr.c_str();
char *ostr_ci = ostr_c;
char *end = ostr_c + strlen(ostr_c);
while(ostr_ci<end && cpos < max_ocr_length)
{
// check / add / get id from alphabet
uint32_t c = utf8::next(ostr_ci,end);
int nc = -1;
if ((ait=alphabet.find(c))==alphabet.end())
{
if (train_db)
{
nc = alphabet.size();
alphabet.insert(std::pair<uint32_t,int>(c,nc));
}
else
{
_logger->warn("character {} in test set not found in training set",c);
nc = 0; // space, blank
}
}
else nc = (*ait).second;
ocr_data[nline][cpos] = static_cast<float>(nc);
++cpos;
}
// add a space only if more words follow
if (i != elts.size()-1)
{
ocr_data[nline][cpos] = 0.0;
++cpos;
}
}
// complete string with blank label
while (cpos < max_ocr_length)
{
ocr_data[nline][cpos] = 0.0;
++cpos;
}
if (nline == tlines-1)
break;
++nline;
}

std::string dbchunkname = dbfullname + "_" + std::to_string(ch) + ".h5";
dbchunks.push_back(dbchunkname);
H5::H5File hdffile(dbchunkname, H5F_ACC_TRUNC);
_logger->info("created hdf5 {} dataset for chunk {}",train_db ? "train" : "test",ch);

// create datasets
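// boost::multi_array stores its elements contiguously in row-major
// order, so data() can be handed to the hdf5 writers below directly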
// image data
hsize_t img_dims[4];
img_dims[0] = tlines;
img_dims[1] = cn;
img_dims[2] = _height;
img_dims[3] = _width;
H5::DataSpace dataspace(4, img_dims);
H5::FloatType datatype(H5::PredType::NATIVE_FLOAT);
//datatype.setOrder( H5T_ORDER_LE );
H5::DataSet dataset = hdffile.createDataSet("data",datatype,dataspace);
dataset.write(img_data.data(), H5::PredType::NATIVE_FLOAT);

// ocr data
hsize_t ocr_dims[2];
ocr_dims[0] = tlines;
ocr_dims[1] = max_ocr_length;
H5::DataSpace dataspace2(2, ocr_dims);
H5::FloatType datatype2(H5::PredType::NATIVE_FLOAT);
//datatype.setOrder( H5T_ORDER_LE );
H5::DataSet dataset2 = hdffile.createDataSet("label",datatype2,dataspace2);
dataset2.write(ocr_data.data(), H5::PredType::NATIVE_FLOAT);
}

//TODO: save the alphabet (vocab.dat)

// generate the list of hdf5 db files
std::ofstream tlist(dblistfilename.c_str());
for (auto s: dbchunks)
tlist << s << std::endl;
tlist.close();
}

int ImgCaffeInputFileConn::compute_images_mean(const std::string &dbname,
const std::string &meanfile,
const std::string &backend)
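As a sanity check on the chunk format above, the same H5Cpp calls the commit uses for writing can read a chunk back. A minimal sketch, assuming a chunk file named ocr_train_0.h5 produced by write_images_to_hdf5 (the file name is illustrative; the [n][channels][height][width] data layout and the [n][max_ocr_length] label layout follow from the code above):

#include <H5Cpp.h>
#include <iostream>
#include <vector>

int main()
{
  // open a chunk produced by write_images_to_hdf5 (name is illustrative)
  H5::H5File file("ocr_train_0.h5", H5F_ACC_RDONLY);

  // 'data' holds the images as [n][channels][height][width] floats
  H5::DataSet data = file.openDataSet("data");
  H5::DataSpace dspace = data.getSpace();
  hsize_t ddims[4];
  dspace.getSimpleExtentDims(ddims, NULL);
  std::cout << "images: " << ddims[0] << "x" << ddims[1] << "x"
            << ddims[2] << "x" << ddims[3] << std::endl;

  // 'label' holds the padded CTC target strings as [n][max_ocr_length] floats
  H5::DataSet label = file.openDataSet("label");
  H5::DataSpace lspace = label.getSpace();
  hsize_t ldims[2];
  lspace.getSimpleExtentDims(ldims, NULL);

  // read all labels and print the first target string
  std::vector<float> labels(ldims[0] * ldims[1]);
  label.read(labels.data(), H5::PredType::NATIVE_FLOAT);
  if (ldims[0] > 0)
    {
      for (hsize_t j = 0; j < ldims[1]; j++)
        std::cout << labels[j] << " ";
      std::cout << std::endl;
    }
  return 0;
}

The printed label values are the integer ids saved to corresp.txt, with 0 standing for both the inter-word space and the trailing blank padding.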
