diff --git a/.gitignore b/.gitignore index 6262738..a166590 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,9 @@ __pycache__/ inference/ inference_results/ output/ +e2e_results/ +rec_results/ +det_results/ train_data/ log/ *.DS_Store diff --git a/README.md b/README.md index d0f19c8..fdfc64d 100644 --- a/README.md +++ b/README.md @@ -1,55 +1,110 @@ -# OpenOCR +
-We aim to establishing a unified benchmark for training and evaluating models for scene text detection and recognition. Based on this benchmark, we introduce an accurate and efficient general OCR system, OpenOCR. Additionally, this repository will serve as the official codebase for the OCR team from the [FVL](https://fvl.fudan.edu.cn) Laboratory, Fudan University. +

OpenOCR: A general OCR system with accuracy and efficiency

+ +
If you find this project useful, please give us a star🌟.
+ +license + + + + + + + +PyPI + + 🚀 Quick Start | English | [简体中文](./README_ch.md) + +
+ +______________________________________________________________________ + +We aim to establish a unified benchmark for training and evaluating models in scene text detection and recognition. Building on this benchmark, we introduce a general OCR system with accuracy and efficiency, **OpenOCR**. This repository also serves as the official codebase of the OCR team from the [FVL Laboratory](https://fvl.fudan.edu.cn), Fudan University. We sincerely welcome the researcher to recommend OCR or relevant algorithms and point out any potential factual errors or bugs. Upon receiving the suggestions, we will promptly evaluate and critically reproduce them. We look forward to collaborating with you to advance the development of OpenOCR and continuously contribute to the OCR community! ## Features -- 🔥**OpenOCR: A general OCR system for accuracy and efficiency** - - ⚡\[[Quick Start](#quick-start)\] \[[Demo](<>)(TODO)\] +- 🔥**OpenOCR: A general OCR system with accuracy and efficiency** + - ⚡\[[Quick Start](#quick-start)\] \[[Model](https://github.com/Topdu/OpenOCR/releases/tag/develop0.0.1)\] \[[ModelScope Demo](https://modelscope.cn/studios/topdktu/OpenOCR-Demo)\] \[[Hugging Face Demo](https://huggingface.co/spaces/topdu/OpenOCR-Demo)\] \[[Local Demo](#local-demo)\] \[[PaddleOCR Implementation](https://paddlepaddle.github.io/PaddleOCR/latest/algorithm/text_recognition/algorithm_rec_svtrv2.html)\] - [Introduction](./docs/openocr.md) - - A practical version of the model builds on SVTRv2. - - Outperforming [PP-OCRv4](<>) released by [PaddleOCR](<>) by 4.5% on the [OCR competition leaderboard](<>). - - [x] Supporting Chinese and English text detection and recognition. - - [x] Providing server model and mobile model. - - [ ] Fine-tuning OpenOCR on a custom dataset - - [ ] Export to ONNX engine + - A practical OCR system building on SVTRv2. + - Outperforms [PP-OCRv4](https://paddlepaddle.github.io/PaddleOCR/latest/ppocr/model_list.html) baseline by 4.5% on the [OCR competition leaderboard](https://aistudio.baidu.com/competition/detail/1131/0/leaderboard) in terms of accuracy, while preserving quite similar inference speed. + - [x] Supports Chinese and English text detection and recognition. + - [x] Provides server model and mobile model. + - [x] Fine-tunes OpenOCR on a custom dataset: [Fine-tuning Det](./docs/finetune_det.md), [Fine-tuning Rec](./docs/finetune_rec.md). + - [x] [ONNX model export for wider compatibility](#export-onnx-model). - 🔥**SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition** - - \[[Paper](../configs/rec/svtrv2/SVTRv2.pdf)\] \[[Model](./configs/rec/svtrv2/readme.md#11-models-and-results)\] \[[Config, Training and Inference](./configs/rec/svtrv2/readme.md#3-model-training--evaluation)\] + - \[[Paper](https://arxiv.org/abs/2411.15858)\] \[[Doc](./configs/rec/svtrv2/)\] \[[Model](./configs/rec/svtrv2/readme.md#11-models-and-results)\] \[[Datasets](./docs/svtrv2.md#downloading-datasets)\] \[[Config, Training and Inference](./configs/rec/svtrv2/readme.md#3-model-training--evaluation)\] \[[Benchmark](./docs/svtrv2.md#results-benchmark--configs--checkpoints)\] - [Introduction](./docs/svtrv2.md) - - Developing a unified training and evaluation benchmark for Scene Text Recognition - - Supporting for 24 Scene Text Recognition methods trained from scratch on large-scale real datasets, and will continue to add the latest methods. - - Improving results by 20-30% compared to training on synthetic datasets. 
+ - A unified training and evaluation benchmark (on top of [Union14M](https://github.com/Mountchicken/Union14M?tab=readme-ov-file#3-union14m-dataset)) for Scene Text Recognition + - Supports 24 Scene Text Recognition methods trained from scratch on the large-scale real dataset [Union14M-L-Filter](./docs/svtrv2.md#dataset-details), and will continue to add the latest methods. + - Improves accuracy by 20-30% compared to models trained based on synthetic datasets. - Towards Arbitrary-Shaped Text Recognition and Language modeling with a Single Visual Model. - - Surpasses Attention-based Decoder Methods across challenging scenarios in terms of accuracy and speed - - [Get Started](./docs/svtrv2.md#get-started-with-training-a-sota-scene-text-recognition-model-from-scratch) with training a SoTA Scene Text Recognition model from scratch. + - Surpasses Attention-based Encoder-Decoder Methods across challenging scenarios in terms of accuracy and speed + - [Get Started](./docs/svtrv2.md#get-started-with-training-a-sota-scene-text-recognition-model-from-scratch) with training a SOTA Scene Text Recognition model from scratch. ## Ours STR algorithms -- [**DPTR**](<>) (*Shuai Zhao, Yongkun Du, Zhineng Chen\*, Yu-Gang Jiang. Decoder Pre-Training with only Text for Scene Text Recognition,* ACM MM 2024. [paper](https://arxiv.org/abs/2408.05706)) -- [**IGTR**](./configs/rec/igtr/) (*Yongkun Du, Zhineng Chen\*, Yuchen Su, Caiyan Jia, Yu-Gang Jiang. Instruction-Guided Scene Text Recognition,* Under TPAMI minor revison 2024. [Doc](./configs/rec/igtr/readme.md), [paper](https://arxiv.org/abs/2401.17851)) -- [**SVTRv2**](./configs/rec/svtrv2) (*Yongkun Du, Zhineng Chen\*, Hongtao Xie, Caiyan Jia, Yu-Gang Jiang. SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition,* 2024. [paper](./configs/rec/svtrv2/SVTRv2.pdf)) -- [**SMTR&FocalSVTR**](./configs/rec/smtr/) (*Yongkun Du, Zhineng Chen\*, Caiyan Jia, Xieping Gao, Yu-Gang Jiang. Out of Length Text Recognition with Sub-String Matching,* 2024. [paper](https://arxiv.org/abs/2407.12317)) -- [**CDistNet**](./configs/rec/cdistnet/) (*Tianlun Zheng, Zhineng Chen\*, Shancheng Fang, Hongtao Xie, Yu-Gang Jiang. CDistNet: Perceiving Multi-Domain Character Distance for Robust Text Recognition,* IJCV 2024. [paper](https://link.springer.com/article/10.1007/s11263-023-01880-0)) -- **MRN** (*Tianlun Zheng, Zhineng Chen\*, Bingchen Huang, Wei Zhang, Yu-Gang Jiang. MRN: Multiplexed routing network for incremental multilingual text recognition,* ICCV 2023. [paper](https://openaccess.thecvf.com/content/ICCV2023/html/Zheng_MRN_Multiplexed_Routing_Network_for_Incremental_Multilingual_Text_Recognition_ICCV_2023_paper.html)) -- **TPS++** (*Tianlun Zheng, Zhineng Chen\*, Jinfeng Bai, Hongtao Xie, Yu-Gang Jiang. TPS++: Attention-Enhanced Thin-Plate Spline for Scene Text Recognition,* IJCAI 2023. [paper](https://arxiv.org/abs/2305.05322)) -- [**CPPD**](./configs/rec/cppd/) (*Yongkun Du, Zhineng Chen\*, Caiyan Jia, Xiaoting Yin, Chenxia Li, Yuning Du, Yu-Gang Jiang. Context Perception Parallel Decoder for Scene Text Recognition,* Under TPAMI minor revision 2023. [PaddleOCR Doc](https://github.com/Topdu/PaddleOCR/blob/main/doc/doc_ch/algorithm_rec_cppd.md), [paper](https://arxiv.org/abs/2307.12270)) -- [**SVTR**](./configs/rec/svtr/) (*Yongkun Du, Zhineng Chen\*, Caiyan Jia, Xiaoting Yin, Tianlun Zheng, Chenxia Li, Yuning Du, Yu-Gang Jiang. SVTR: Scene Text Recognition with a Single Visual Model,* IJCAI 2022 (Long). 
[PaddleOCR Doc](https://github.com/Topdu/PaddleOCR/blob/main/doc/doc_ch/algorithm_rec_svtr.md), [paper](https://www.ijcai.org/proceedings/2022/124)) -- [**NRTR**](./configs/rec/nrtr/) (*Fenfen Sheng, Zhineng Chen\*, Bo Xu. NRTR: A No-Recurrence Sequence-to-Sequence Model For Scene Text Recognition,* ICDAR 2019. [paper](https://arxiv.org/abs/1806.00926)) +- [**SMTR&FocalSVTR**](./configs/rec/smtr/) (*Yongkun Du, Zhineng Chen\*, Caiyan Jia, Xieping Gao, Yu-Gang Jiang. Out of Length Text Recognition with Sub-String Matching,* AAAI 2025. [Doc](./configs/rec/smtr/), [Paper](https://arxiv.org/abs/2407.12317)) +- [**DPTR**](./configs/rec/dptr/) (*Shuai Zhao, Yongkun Du, Zhineng Chen\*, Yu-Gang Jiang. Decoder Pre-Training with only Text for Scene Text Recognition,* ACM MM 2024. [Paper](https://arxiv.org/abs/2408.05706)) +- [**IGTR**](./configs/rec/igtr/) (*Yongkun Du, Zhineng Chen\*, Yuchen Su, Caiyan Jia, Yu-Gang Jiang. Instruction-Guided Scene Text Recognition,* TPAMI 2025. [Doc](./configs/rec/igtr), [Paper](https://doi.ieeecomputersociety.org/10.1109/TPAMI.2025.3525526)) +- [**SVTRv2**](./configs/rec/svtrv2) (*Yongkun Du, Zhineng Chen\*, Hongtao Xie, Caiyan Jia, Yu-Gang Jiang. SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition,* 2024. [Doc](./configs/rec/svtrv2/), [Paper](https://arxiv.org/abs/2411.15858)) +- [**CDistNet**](./configs/rec/cdistnet/) (*Tianlun Zheng, Zhineng Chen\*, Shancheng Fang, Hongtao Xie, Yu-Gang Jiang. CDistNet: Perceiving Multi-Domain Character Distance for Robust Text Recognition,* IJCV 2024. [Paper](https://link.springer.com/article/10.1007/s11263-023-01880-0)) +- **MRN** (*Tianlun Zheng, Zhineng Chen\*, Bingchen Huang, Wei Zhang, Yu-Gang Jiang. MRN: Multiplexed Routing Network for Incremental Multilingual Text Recognition,* ICCV 2023. [Paper](https://openaccess.thecvf.com/content/ICCV2023/html/Zheng_MRN_Multiplexed_Routing_Network_for_Incremental_Multilingual_Text_Recognition_ICCV_2023_paper.html), [Code](https://github.com/simplify23/MRN)) +- **TPS++** (*Tianlun Zheng, Zhineng Chen\*, Jinfeng Bai, Hongtao Xie, Yu-Gang Jiang. TPS++: Attention-Enhanced Thin-Plate Spline for Scene Text Recognition,* IJCAI 2023. [Paper](https://arxiv.org/abs/2305.05322), [Code](https://github.com/simplify23/TPS_PP)) +- [**CPPD**](./configs/rec/cppd/) (*Yongkun Du, Zhineng Chen\*, Caiyan Jia, Xiaoting Yin, Chenxia Li, Yuning Du, Yu-Gang Jiang. Context Perception Parallel Decoder for Scene Text Recognition,* TPAMI (accepted). [PaddleOCR Doc](https://github.com/PaddlePaddle/PaddleOCR/blob/main/docs/algorithm/text_recognition/algorithm_rec_cppd.en.md), [Paper](https://doi.ieeecomputersociety.org/10.1109/TPAMI.2025.3545453)) +- [**SVTR**](./configs/rec/svtr/) (*Yongkun Du, Zhineng Chen\*, Caiyan Jia, Xiaoting Yin, Tianlun Zheng, Chenxia Li, Yuning Du, Yu-Gang Jiang. SVTR: Scene Text Recognition with a Single Visual Model,* IJCAI 2022 (Long). [PaddleOCR Doc](https://github.com/Topdu/PaddleOCR/blob/main/doc/doc_ch/algorithm_rec_svtr.md), [Paper](https://www.ijcai.org/proceedings/2022/124)) +- [**NRTR**](./configs/rec/nrtr/) (*Fenfen Sheng, Zhineng Chen, Bo Xu. NRTR: A No-Recurrence Sequence-to-Sequence Model For Scene Text Recognition,* ICDAR 2019. 
[Paper](https://arxiv.org/abs/1806.00926)) ## Recent Updates +- **2025.03.24**: 🔥 Releasing the feature of fine-tuning OpenOCR on a custom dataset: [Fine-tuning Det](./docs/finetune_det.md), [Fine-tuning Rec](./docs/finetune_rec.md) + +- **2025.03.23**: 🔥 Releasing the feature of [ONNX model export for wider compatibility](#export-onnx-model). + +- **2025.02.22**: Our paper [CPPD](https://doi.ieeecomputersociety.org/10.1109/TPAMI.2025.3545453) is accepted by TPAMI. Accessible in [Doc](./configs/rec/cppd/) and [PaddleOCR Doc](https://github.com/PaddlePaddle/PaddleOCR/blob/main/docs/algorithm/text_recognition/algorithm_rec_cppd.en.md). + +- **2024.12.31**: Our paper [IGTR](https://doi.ieeecomputersociety.org/10.1109/TPAMI.2025.3525526) is accepted by TPAMI. Accessible in [Doc](./configs/rec/igtr/). + +- **2024.12.16**: Our paper [SMTR](https://arxiv.org/abs/2407.12317) is accepted by AAAI 2025. Accessible in [Doc](./configs/rec/smtr/). + +- **2024.12.03**: The pre-training code for [DPTR](https://arxiv.org/abs/2408.05706) is merged. + - **🔥 2024.11.23 release notes**: - - **OpenOCR: A general OCR system for accuracy and efficiency** - - ⚡\[[Quick Start](#quick-start)\] \[[Demo](<>)(TODO)\] + + - **OpenOCR: A general OCR system with accuracy and efficiency** + - ⚡\[[Quick Start](#quick-start)\] \[[Model](https://github.com/Topdu/OpenOCR/releases/tag/develop0.0.1)\] \[[ModelScope Demo](https://modelscope.cn/studios/topdktu/OpenOCR-Demo)\] \[[Hugging Face Demo](https://huggingface.co/spaces/topdu/OpenOCR-Demo)\] \[[Local Demo](#local-demo)\] \[[PaddleOCR Implementation](https://paddlepaddle.github.io/PaddleOCR/latest/algorithm/text_recognition/algorithm_rec_svtrv2.html)\] - [Introduction](./docs/openocr.md) - **SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition** - - \[[Paper](../configs/rec/svtrv2/SVTRv2.pdf)\] \[[Model](./configs/rec/svtrv2/readme.md#11-models-and-results)\] \[[Config, Training and Inference](./configs/rec/svtrv2/readme.md#3-model-training--evaluation)\] + - \[[Paper](https://arxiv.org/abs/2411.15858)\] \[[Doc](./configs/rec/svtrv2/)\] \[[Model](./configs/rec/svtrv2/readme.md#11-models-and-results)\] \[[Datasets](./docs/svtrv2.md#downloading-datasets)\] \[[Config, Training and Inference](./configs/rec/svtrv2/readme.md#3-model-training--evaluation)\] \[[Benchmark](./docs/svtrv2.md#results--configs--checkpoints)\] - [Introduction](./docs/svtrv2.md) - - [Get Started](./docs/svtrv2.md#get-started-with-training-a-sota-scene-text-recognition-model-from-scratch) with training a SoTA Scene Text Recognition model from scratch. + - [Get Started](./docs/svtrv2.md#get-started-with-training-a-sota-scene-text-recognition-model-from-scratch) with training a SOTA Scene Text Recognition model from scratch. + +## Quick Start + +**Note**: OpenOCR supports inference using both the ONNX and Torch frameworks, with the dependency environments for the two frameworks being isolated. When using ONNX for inference, there is no need to install Torch, and vice versa. + +### 1. ONNX Inference + +#### Install OpenOCR and Dependencies: + +```shell +pip install openocr-python +pip install onnxruntime +``` + +#### Usage: + +```python +from openocr import OpenOCR +onnx_engine = OpenOCR(backend='onnx', device='cpu') +img_path = '/path/img_path or /path/img_file' +result, elapse = onnx_engine(img_path) +``` -## ⚡[Quick Start](./docs/openocr.md#quick-start) +### 2. 
Pytorch inference #### Dependencies: @@ -59,12 +114,17 @@ We sincerely welcome the researcher to recommend OCR or relevant algorithms and ```shell conda create -n openocr python==3.8 conda activate openocr +# install gpu version torch conda install pytorch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 pytorch-cuda=11.8 -c pytorch -c nvidia +# or cpu version +conda install pytorch torchvision torchaudio cpuonly -c pytorch ``` After installing dependencies, the following two installation methods are available. Either one can be chosen. -#### 1. Python Modules +#### 2.1. Python Modules + +**Install OpenOCR**: ```shell pip install openocr-python @@ -74,24 +134,24 @@ pip install openocr-python ```python from openocr import OpenOCR - engine = OpenOCR() - img_path = '/path/img_path or /path/img_file' result, elapse = engine(img_path) -print(result) -print(elapse) # Server mode -engine = OpenOCR(mode='server') +# engine = OpenOCR(mode='server') ``` -#### 2. Clone this repository: +#### 2.2. Clone this repository: ```shell git clone https://github.com/Topdu/OpenOCR.git cd OpenOCR pip install -r requirements.txt +wget https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_det_repvit_ch.pth +wget https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_repsvtr_ch.pth +# Rec Server model +# wget https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_svtrv2_ch.pth ``` **Usage**: @@ -99,14 +159,42 @@ pip install -r requirements.txt ```shell # OpenOCR system: Det + Rec model python tools/infer_e2e.py --img_path=/path/img_fold or /path/img_file - # Det model python tools/infer_det.py --c ./configs/det/dbnet/repvit_db.yml --o Global.infer_img=/path/img_fold or /path/img_file - # Rec model python tools/infer_rec.py --c ./configs/rec/svtrv2/repsvtr_ch.yml --o Global.infer_img=/path/img_fold or /path/img_file ``` +##### Export ONNX model + +```shell +pip install onnx +python tools/toonnx.py --c configs/rec/svtrv2/repsvtr_ch.yml --o Global.device=cpu +python tools/toonnx.py --c configs/det/dbnet/repvit_db.yml --o Global.device=cpu +``` + +##### Inference with ONNXRuntime + +```shell +pip install onnxruntime +# OpenOCR system: Det + Rec model +python tools/infer_e2e.py --img_path=/path/img_fold or /path/img_file --backend=onnx --device=cpu +# Det model +python tools/infer_det.py --c ./configs/det/dbnet/repvit_db.yml --o Global.backend=onnx Global.device=cpu Global.infer_img=/path/img_fold or /path/img_file +# Rec model +python tools/infer_rec.py --c ./configs/rec/svtrv2/repsvtr_ch.yml --o Global.backend=onnx Global.device=cpu Global.infer_img=/path/img_fold or /path/img_file +``` + +#### Local Demo + +```shell +pip install gradio==4.20.0 +wget https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/OCR_e2e_img.tar +tar xf OCR_e2e_img.tar +# start demo +python demo_gradio.py +``` + ## Reproduction schedule: ### Scene Text Recognition @@ -117,7 +205,7 @@ python tools/infer_rec.py --c ./configs/rec/svtrv2/repsvtr_ch.yml --o Global.inf | [ASTER](./configs/rec/aster/) | [TPAMI 2019](https://ieeexplore.ieee.org/document/8395027) | ✅ | ✅ | [pretto0](https://github.com/pretto0) | | [NRTR](./configs/rec/nrtr/) | [ICDAR 2019](https://arxiv.org/abs/1806.00926) | ✅ | ✅ | | | [SAR](./configs/rec/sar/) | [AAAI 2019](https://aaai.org/papers/08610-show-attend-and-read-a-simple-and-strong-baseline-for-irregular-text-recognition/) | ✅ | ✅ | [pretto0](https://github.com/pretto0) | -| [MORAN](./configs/rec/moran/) | [PR 
2019](https://www.sciencedirect.com/science/article/abs/pii/S0031320319300263) | ✅ | ✅ | Debug | +| [MORAN](./configs/rec/moran/) | [PR 2019](https://www.sciencedirect.com/science/article/abs/pii/S0031320319300263) | ✅ | ✅ | | | [DAN](./configs/rec/dan/) | [AAAI 2020](https://arxiv.org/pdf/1912.10205) | ✅ | ✅ | | | [RobustScanner](./configs/rec/robustscanner/) | [ECCV 2020](https://www.ecva.net/papers/eccv_2020/papers_ECCV/html/3160_ECCV_2020_paper.php) | ✅ | ✅ | [pretto0](https://github.com/pretto0) | | [AutoSTR](./configs/rec/autostr/) | [ECCV 2020](https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123690732.pdf) | ✅ | ✅ | | @@ -125,11 +213,11 @@ python tools/infer_rec.py --c ./configs/rec/svtrv2/repsvtr_ch.yml --o Global.inf | [SEED](./configs/rec/seed/) | [CVPR 2020](https://openaccess.thecvf.com/content_CVPR_2020/html/Qiao_SEED_Semantics_Enhanced_Encoder-Decoder_Framework_for_Scene_Text_Recognition_CVPR_2020_paper.html) | ✅ | ✅ | | | [ABINet](./configs/rec/abinet/) | [CVPR 2021](https://openaccess.thecvf.com//content/CVPR2021/html/Fang_Read_Like_Humans_Autonomous_Bidirectional_and_Iterative_Language_Modeling_for_CVPR_2021_paper.html) | ✅ | ✅ | [YesianRohn](https://github.com/YesianRohn) | | [VisionLAN](./configs/rec/visionlan/) | [ICCV 2021](https://openaccess.thecvf.com/content/ICCV2021/html/Wang_From_Two_to_One_A_New_Scene_Text_Recognizer_With_ICCV_2021_paper.html) | ✅ | ✅ | [YesianRohn](https://github.com/YesianRohn) | +| PIMNet | [ACM MM 2021](https://dl.acm.org/doi/10.1145/3474085.3475238) | | | TODO | | [SVTR](./configs/rec/svtrs/) | [IJCAI 2022](https://www.ijcai.org/proceedings/2022/124) | ✅ | ✅ | | | [PARSeq](./configs/rec/parseq/) | [ECCV 2022](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136880177.pdf) | ✅ | ✅ | | | [MATRN](./configs/rec/matrn/) | [ECCV 2022](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136880442.pdf) | ✅ | ✅ | | | [MGP-STR](./configs/rec/mgpstr/) | [ECCV 2022](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136880336.pdf) | ✅ | ✅ | | -| [CPPD](./configs/rec/cppd/) | [2023](https://arxiv.org/abs/2307.12270) | ✅ | ✅ | | | [LPV](./configs/rec/lpv/) | [IJCAI 2023](https://www.ijcai.org/proceedings/2023/0189.pdf) | ✅ | ✅ | | | [MAERec](./configs/rec/maerec/)(Union14M) | [ICCV 2023](https://openaccess.thecvf.com/content/ICCV2023/papers/Jiang_Revisiting_Scene_Text_Recognition_A_Data_Perspective_ICCV_2023_paper.pdf) | ✅ | ✅ | | | [LISTER](./configs/rec/lister/) | [ICCV 2023](https://openaccess.thecvf.com/content/ICCV2023/papers/Cheng_LISTER_Neighbor_Decoding_for_Length-Insensitive_Scene_Text_Recognition_ICCV_2023_paper.pdf) | ✅ | ✅ | | @@ -137,14 +225,15 @@ python tools/infer_rec.py --c ./configs/rec/svtrv2/repsvtr_ch.yml --o Global.inf | [BUSNet](./configs/rec/busnet/) | [AAAI 2024](https://ojs.aaai.org/index.php/AAAI/article/view/28402) | ✅ | ✅ | | | DCTC | [AAAI 2024](https://ojs.aaai.org/index.php/AAAI/article/view/28575) | | | TODO | | [CAM](./configs/rec/cam/) | [PR 2024](https://arxiv.org/abs/2402.13643) | ✅ | ✅ | | -| [OTE](./configs/rec/ote/) | [CVPR 2024](https://openaccess.thecvf.com/content/CVPR2024/papers/Xu_OTE_Exploring_Accurate_Scene_Text_Recognition_Using_One_Token_CVPR_2024_paper.pdf) | ✅ | ✅ | | +| [OTE](./configs/rec/ote/) | [CVPR 2024](https://openaccess.thecvf.com/content/CVPR2024/html/Xu_OTE_Exploring_Accurate_Scene_Text_Recognition_Using_One_Token_CVPR_2024_paper.html) | ✅ | ✅ | | | CFF | [IJCAI 2024](https://arxiv.org/abs/2407.05562) | | | TODO | -| DPTR | [ACM MM 
2024](https://arxiv.org/abs/2408.05706) | | | TODO | +| [DPTR](./configs/rec/dptr/) | [ACM MM 2024](https://arxiv.org/abs/2408.05706) | | | [fd-zs](https://github.com/fd-zs) | | VIPTR | [ACM CIKM 2024](https://arxiv.org/abs/2401.10110) | | | TODO | -| [IGTR](./configs/rec/igtr/) | [2024](https://arxiv.org/abs/2401.17851) | ✅ | ✅ | | -| [SMTR](./configs/rec/smtr/) | [2024](https://arxiv.org/abs/2407.12317) | ✅ | ✅ | | +| [IGTR](./configs/rec/igtr/) | [TPAMI 2025](https://doi.ieeecomputersociety.org/10.1109/TPAMI.2025.3525526) | ✅ | ✅ | | +| [SMTR](./configs/rec/smtr/) | [AAAI 2025](https://arxiv.org/abs/2407.12317) | ✅ | ✅ | | +| [CPPD](./configs/rec/cppd/) | [TPAMI Online Access](https://doi.ieeecomputersociety.org/10.1109/TPAMI.2025.3545453) | ✅ | ✅ | | | [FocalSVTR-CTC](./configs/rec/svtrs/) | [2024](https://arxiv.org/abs/2407.12317) | ✅ | ✅ | | -| [SVTRv2](./configs/rec/svtrv2/) | [2024](./configs/rec/svtrv2/SVTRv2.pdf) | ✅ | ✅ | | +| [SVTRv2](./configs/rec/svtrv2/) | [2024](https://arxiv.org/abs/2411.15858) | ✅ | ✅ | | | [ResNet+Trans-CTC](./configs/rec/svtrs/) | | ✅ | ✅ | | | [ViT-CTC](./configs/rec/svtrs/) | | ✅ | ✅ | | @@ -152,7 +241,7 @@ python tools/infer_rec.py --c ./configs/rec/svtrv2/repsvtr_ch.yml --o Global.inf ______________________________________________________________________ -Yiming Lei ([pretto0](https://github.com/pretto0)) and Xingsong Ye ([YesianRohn](https://github.com/YesianRohn)) from the [FVL](https://fvl.fudan.edu.cn) Laboratory, Fudan University, under the guidance of Professor Zhineng Chen, completed the majority of the algorithm reproduction work. Grateful for their outstanding contributions. +Yiming Lei ([pretto0](https://github.com/pretto0)), Xingsong Ye ([YesianRohn](https://github.com/YesianRohn)), and Shuai Zhao ([fd-zs](https://github.com/fd-zs)) from the [FVL Laboratory](https://fvl.fudan.edu.cn), Fudan University, with guidance from Dr. Zhineng Chen ([Homepage](https://zhinchenfd.github.io/)), completed the majority work of the algorithm reproduction. Grateful for their outstanding contributions. ### Scene Text Detection (STD) @@ -164,6 +253,23 @@ TODO ______________________________________________________________________ +## Citation + +If you find our method useful for your reserach, please cite: + +```bibtex +@article{Du2024SVTRv2, + title={SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition}, + author={Yongkun Du and Zhineng Chen and Hongtao Xie and Caiyan Jia and Yu-Gang Jiang}, + journal={CoRR}, + volume={abs/2411.15858}, + eprinttype={arXiv}, + year={2024}, + primaryClass={cs.CV}, + url={https://arxiv.org/abs/2411.15858} +} +``` + # Acknowledgement This codebase is built based on the [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR), [PytorchOCR](https://github.com/WenmuZhou/PytorchOCR), and [MMOCR](https://github.com/open-mmlab/mmocr). Thanks for their awesome work! diff --git a/README_ch.md b/README_ch.md new file mode 100644 index 0000000..fb899cc --- /dev/null +++ b/README_ch.md @@ -0,0 +1,269 @@ +
+ +

OpenOCR: A general OCR system with accuracy and efficiency

+ +
如果您觉得本项目有帮助,请为我们点亮Star🌟
+ +license + + + + + + + +PyPI + + 🚀 快速开始 | 简体中文 | [English](./README.md) + +
+ +______________________________________________________________________ + +我们致力于构建场景文本检测与识别模型的统一训练评估基准。基于此基准,我们推出了兼顾精度与效率的通用OCR系统——**OpenOCR**。本仓库同时作为复旦大学[FVL实验室](https://fvl.fudan.edu.cn)OCR团队的官方代码库。 + +我们诚挚欢迎研究者推荐OCR相关算法,并指出潜在的事实性错误或代码缺陷。收到建议后,我们将及时评估并严谨复现。期待与您携手推进OpenOCR发展,持续为OCR社区贡献力量! + +## 核心特性 + +- 🔥**OpenOCR: A general OCR system with accuracy and efficiency** + + - ⚡\[[快速开始](#快速开始)\] \[[模型下载](https://github.com/Topdu/OpenOCR/releases/tag/develop0.0.1)\] \[[ModelScope Demo](https://modelscope.cn/studios/topdktu/OpenOCR-Demo)\] \[[Hugging Face Demo](https://huggingface.co/spaces/topdu/OpenOCR-Demo)\] \[[本地Demo](#本地Demo)\] \[[PaddleOCR实现](https://paddlepaddle.github.io/PaddleOCR/latest/algorithm/text_recognition/algorithm_rec_svtrv2.html)\] + - [技术文档](./docs/openocr.md) + - 基于SVTRv2构建的实用OCR系统 + - 在[OCR竞赛榜单](https://aistudio.baidu.com/competition/detail/1131/0/leaderboard)上,精度超越[PP-OCRv4](https://paddlepaddle.github.io/PaddleOCR/latest/ppocr/model_list.html)基线4.5%,推理速度保持相近 + - [x] 支持中英文文本检测与识别 + - [x] 提供服务器端(Server)与移动端(mobile)模型 + - [x] 支持自定义数据集微调: [检测模型微调](./docs/finetune_det.md), [识别模型微调](./docs/finetune_rec.md) + - [x] [支持导出ONNX模型](#导出onnx模型) + +- 🔥**SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition** + + - \[[论文](https://arxiv.org/abs/2411.15858)\] \[[文档](./configs/rec/svtrv2/)\] \[[模型](./configs/rec/svtrv2/readme.md#11-models-and-results)\] \[[数据集](./docs/svtrv2.md#downloading-datasets)\] \[[配置/训练/推理](./configs/rec/svtrv2/readme.md#3-model-training--evaluation)\] \[[基准测试](./docs/svtrv2.md#results-benchmark--configs--checkpoints)\] + - [技术文档](./docs/svtrv2.md) + - 基于[Union14M](https://github.com/Mountchicken/Union14M)构建的场景文本识别统一训练评估基准 + - 支持24种场景文本识别方法在大规模真实数据集[Union14M-L-Filter](./docs/svtrv2.md#数据集详情)上的训练,将持续集成前沿方法 + - 相比基于合成数据训练的模型,精度提升20-30% + - 单一视觉模型实现任意形状文本识别与语言建模 + - 在精度与速度上全面超越基于Attention的编解码模型 + - [从零训练SOTA模型指南](./docs/svtrv2.md#get-started-with-training-a-sota-scene-text-recognition-model-from-scratch) + +## 自研STR算法 + +- [**SMTR&FocalSVTR**](./configs/rec/smtr/) (*Yongkun Du, Zhineng Chen\*, Caiyan Jia, Xieping Gao, Yu-Gang Jiang. Out of Length Text Recognition with Sub-String Matching,* AAAI 2025. [Doc](./configs/rec/smtr/), [Paper](https://arxiv.org/abs/2407.12317)) +- [**DPTR**](./configs/rec/dptr/) (*Shuai Zhao, Yongkun Du, Zhineng Chen\*, Yu-Gang Jiang. Decoder Pre-Training with only Text for Scene Text Recognition,* ACM MM 2024. [Paper](https://arxiv.org/abs/2408.05706)) +- [**IGTR**](./configs/rec/igtr/) (*Yongkun Du, Zhineng Chen\*, Yuchen Su, Caiyan Jia, Yu-Gang Jiang. Instruction-Guided Scene Text Recognition,* TPAMI 2025. [Doc](./configs/rec/igtr), [Paper](https://doi.ieeecomputersociety.org/10.1109/TPAMI.2025.3525526)) +- [**SVTRv2**](./configs/rec/svtrv2) (*Yongkun Du, Zhineng Chen\*, Hongtao Xie, Caiyan Jia, Yu-Gang Jiang. SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition,* 2024. [Doc](./configs/rec/svtrv2/), [Paper](https://arxiv.org/abs/2411.15858)) +- [**CDistNet**](./configs/rec/cdistnet/) (*Tianlun Zheng, Zhineng Chen\*, Shancheng Fang, Hongtao Xie, Yu-Gang Jiang. CDistNet: Perceiving Multi-Domain Character Distance for Robust Text Recognition,* IJCV 2024. [Paper](https://link.springer.com/article/10.1007/s11263-023-01880-0)) +- **MRN** (*Tianlun Zheng, Zhineng Chen\*, Bingchen Huang, Wei Zhang, Yu-Gang Jiang. MRN: Multiplexed Routing Network for Incremental Multilingual Text Recognition,* ICCV 2023. 
[Paper](https://openaccess.thecvf.com/content/ICCV2023/html/Zheng_MRN_Multiplexed_Routing_Network_for_Incremental_Multilingual_Text_Recognition_ICCV_2023_paper.html), [Code](https://github.com/simplify23/MRN)) +- **TPS++** (*Tianlun Zheng, Zhineng Chen\*, Jinfeng Bai, Hongtao Xie, Yu-Gang Jiang. TPS++: Attention-Enhanced Thin-Plate Spline for Scene Text Recognition,* IJCAI 2023. [Paper](https://arxiv.org/abs/2305.05322), [Code](https://github.com/simplify23/TPS_PP)) +- [**CPPD**](./configs/rec/cppd/) (*Yongkun Du, Zhineng Chen\*, Caiyan Jia, Xiaoting Yin, Chenxia Li, Yuning Du, Yu-Gang Jiang. Context Perception Parallel Decoder for Scene Text Recognition,* TPAMI (accepted). [PaddleOCR Doc](https://github.com/PaddlePaddle/PaddleOCR/blob/main/docs/algorithm/text_recognition/algorithm_rec_cppd.en.md), [Paper](https://doi.ieeecomputersociety.org/10.1109/TPAMI.2025.3545453)) +- [**SVTR**](./configs/rec/svtr/) (*Yongkun Du, Zhineng Chen\*, Caiyan Jia, Xiaoting Yin, Tianlun Zheng, Chenxia Li, Yuning Du, Yu-Gang Jiang. SVTR: Scene Text Recognition with a Single Visual Model,* IJCAI 2022 (Long). [PaddleOCR Doc](https://github.com/Topdu/PaddleOCR/blob/main/doc/doc_ch/algorithm_rec_svtr.md), [Paper](https://www.ijcai.org/proceedings/2022/124)) +- [**NRTR**](./configs/rec/nrtr/) (*Fenfen Sheng, Zhineng Chen, Bo Xu. NRTR: A No-Recurrence Sequence-to-Sequence Model For Scene Text Recognition,* ICDAR 2019. [Paper](https://arxiv.org/abs/1806.00926)) + +## 近期更新 + +- **2025.03.24**: 🔥 发布自定义数据集微调功能: [检测模型微调](./docs/finetune_det.md), [识别模型微调](./docs/finetune_rec.md) +- **2025.03.23**: 🔥 新增[ONNX模型导出功能](#导出onnx模型) +- **2025.02.22**: [CPPD](https://doi.ieeecomputersociety.org/10.1109/TPAMI.2025.3545453)论文被TPAMI录用,详见[文档](./configs/rec/cppd/)与[PaddleOCR文档](https://github.com/PaddlePaddle/PaddleOCR/blob/main/docs/algorithm/text_recognition/algorithm_rec_cppd.en.md) +- **2024.12.31**: [IGTR](https://doi.ieeecomputersociety.org/10.1109/TPAMI.2025.3525526)论文被TPAMI录用,详见[文档](./configs/rec/igtr/) +- **2024.12.16**: [SMTR](https://arxiv.org/abs/2407.12317)论文被AAAI 2025录用,详见[文档](./configs/rec/smtr/) +- **2024.12.03**: [DPTR](https://arxiv.org/abs/2408.05706)预训练代码合并 +- **🔥 2024.11.23 重大更新**: + - **OpenOCR通用OCR系统发布** + - ⚡\[[快速开始](#快速开始)\] \[[模型下载](https://github.com/Topdu/OpenOCR/releases/tag/develop0.0.1)\] \[[ModelScopeDemo](https://modelscope.cn/studios/topdktu/OpenOCR-Demo)\] \[[Hugging FaceDemo](https://huggingface.co/spaces/topdu/OpenOCR-Demo)\] \[[本地Demo](#本地Demo)\] \[[PaddleOCR实现](https://paddlepaddle.github.io/PaddleOCR/latest/algorithm/text_recognition/algorithm_rec_svtrv2.html)\] + - [技术文档](./docs/openocr.md) + - **SVTRv2论文发布** + - \[[论文](https://arxiv.org/abs/2411.15858)\] \[[文档](./configs/rec/svtrv2/)\] \[[模型](./configs/rec/svtrv2/readme.md#11-models-and-results)\] \[[数据集](./docs/svtrv2.md#downloading-datasets)\] \[[配置/训练/推理](./configs/rec/svtrv2/readme.md#3-model-training--evaluation)\] \[[基准测试](./docs/svtrv2.md#results-benchmark--configs--checkpoints)\] + - [技术文档](./docs/svtrv2.md) + - [从零训练SOTA模型指南](./docs/svtrv2.md#get-started-with-training-a-sota-scene-text-recognition-model-from-scratch) + +## 快速开始 + +**注意**: OpenOCR支持ONNX和PyTorch双框架推理,环境相互独立。使用ONNX推理时无需安装PyTorch,反之亦然。 + +### 1. ONNX推理 + +#### 安装OpenOCR及依赖: + +```shell +pip install openocr-python +pip install onnxruntime +``` + +#### 使用示例: + +```python +from openocr import OpenOCR +onnx_engine = OpenOCR(backend='onnx', device='cpu') +img_path = '/path/img_path or /path/img_file' +result, elapse = onnx_engine(img_path) +``` + +### 2. 
PyTorch推理 + +#### 环境依赖: + +- [PyTorch](http://pytorch.org/) >= 1.13.0 +- Python >= 3.7 + +```shell +conda create -n openocr python==3.8 +conda activate openocr +# 安装GPU版本 +conda install pytorch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 pytorch-cuda=11.8 -c pytorch -c nvidia +# 或CPU版本 +conda install pytorch torchvision torchaudio cpuonly -c pytorch +``` + +#### 2.1 Python包安装 + +**安装OpenOCR**: + +```shell +pip install openocr-python +``` + +**使用示例**: + +```python +from openocr import OpenOCR +engine = OpenOCR() +img_path = '/path/img_path or /path/img_file' +result, elapse = engine(img_path) + +# Server模式 +# engine = OpenOCR(mode='server') +``` + +#### 2.2 源码安装 + +```shell +git clone https://github.com/Topdu/OpenOCR.git +cd OpenOCR +pip install -r requirements.txt +wget https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_det_repvit_ch.pth +wget https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_repsvtr_ch.pth +# Server识别模型 +# wget https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_svtrv2_ch.pth +``` + +**使用命令**: + +```shell +# 端到端OCR系统: 检测+识别 +python tools/infer_e2e.py --img_path=/path/img_path or /path/img_file +# 单独检测模型 +python tools/infer_det.py --c ./configs/det/dbnet/repvit_db.yml --o Global.infer_img=/path/img_path or /path/img_file +# 单独识别模型 +python tools/infer_rec.py --c ./configs/rec/svtrv2/repsvtr_ch.yml --o Global.infer_img=/path/img_path or /path/img_file +``` + +##### 导出ONNX模型 + +```shell +pip install onnx +python tools/toonnx.py --c configs/rec/svtrv2/repsvtr_ch.yml --o Global.device=cpu +python tools/toonnx.py --c configs/det/dbnet/repvit_db.yml --o Global.device=cpu +``` + +##### ONNXRuntime推理 + +```shell +pip install onnxruntime +# 端到端OCR系统 +python tools/infer_e2e.py --img_path=/path/img_path or /path/img_file --backend=onnx --device=cpu +# 检测模型 +python tools/infer_det.py --c ./configs/det/dbnet/repvit_db.yml --o Global.backend=onnx Global.device=cpu Global.infer_img=/path/img_path or /path/img_file +# 识别模型 +python tools/infer_rec.py --c ./configs/rec/svtrv2/repsvtr_ch.yml --o Global.backend=onnx Global.device=cpu Global.infer_img=/path/img_path or /path/img_file +``` + +#### 本地Demo + +```shell +pip install gradio==4.20.0 +wget https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/OCR_e2e_img.tar +tar xf OCR_e2e_img.tar +# 启动Demo +python demo_gradio.py +``` + +## 算法复现计划 + +### 场景文本识别(STR) + +| 方法 | 会议/期刊 | 训练支持 | 评估支持 | 贡献者 | +| --------------------------------------------- | ------------------------------------------------------------------------------------------------ | -------- | -------- | ------------------------------------------- | +| [CRNN](./configs/rec/svtrs/) | [TPAMI 2016](https://arxiv.org/abs/1507.05717) | ✅ | ✅ | | +| [ASTER](./configs/rec/aster/) | [TPAMI 2019](https://ieeexplore.ieee.org/document/8395027) | ✅ | ✅ | [pretto0](https://github.com/pretto0) | +| [NRTR](./configs/rec/nrtr/) | [ICDAR 2019](https://arxiv.org/abs/1806.00926) | ✅ | ✅ | | +| [SAR](./configs/rec/sar/) | [AAAI 2019](https://aaai.org/papers/08610-show-attend-and-read-a-simple-and-strong-baseline-for-irregular-text-recognition/) | ✅ | ✅ | [pretto0](https://github.com/pretto0) | +| [MORAN](./configs/rec/moran/) | [PR 2019](https://www.sciencedirect.com/science/article/abs/pii/S0031320319300263) | ✅ | ✅ | | +| [DAN](./configs/rec/dan/) | [AAAI 2020](https://arxiv.org/pdf/1912.10205) | ✅ | ✅ | | +| [RobustScanner](./configs/rec/robustscanner/) | [ECCV 
2020](https://www.ecva.net/papers/eccv_2020/papers_ECCV/html/3160_ECCV_2020_paper.php) | ✅ | ✅ | [pretto0](https://github.com/pretto0) | +| [AutoSTR](./configs/rec/autostr/) | [ECCV 2020](https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123690732.pdf) | ✅ | ✅ | | +| [SRN](./configs/rec/srn/) | [CVPR 2020](https://openaccess.thecvf.com/content_CVPR_2020/html/Yu_Towards_Accurate_Scene_Text_Recognition_With_Semantic_Reasoning_Networks_CVPR_2020_paper.html) | ✅ | ✅ | [pretto0](https://github.com/pretto0) | +| [SEED](./configs/rec/seed/) | [CVPR 2020](https://openaccess.thecvf.com/content_CVPR_2020/html/Qiao_SEED_Semantics_Enhanced_Encoder-Decoder_Framework_for_Scene_Text_Recognition_CVPR_2020_paper.html) | ✅ | ✅ | | +| [ABINet](./configs/rec/abinet/) | [CVPR 2021](https://openaccess.thecvf.com//content/CVPR2021/html/Fang_Read_Like_Humans_Autonomous_Bidirectional_and_Iterative_Language_Modeling_for_CVPR_2021_paper.html) | ✅ | ✅ | [YesianRohn](https://github.com/YesianRohn) | +| [VisionLAN](./configs/rec/visionlan/) | [ICCV 2021](https://openaccess.thecvf.com/content/ICCV2021/html/Wang_From_Two_to_One_A_New_Scene_Text_Recognizer_With_ICCV_2021_paper.html) | ✅ | ✅ | [YesianRohn](https://github.com/YesianRohn) | +| PIMNet | [ACM MM 2021](https://dl.acm.org/doi/10.1145/3474085.3475238) | | | TODO | +| [SVTR](./configs/rec/svtrs/) | [IJCAI 2022](https://www.ijcai.org/proceedings/2022/124) | ✅ | ✅ | | +| [PARSeq](./configs/rec/parseq/) | [ECCV 2022](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136880177.pdf) | ✅ | ✅ | | +| [MATRN](./configs/rec/matrn/) | [ECCV 2022](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136880442.pdf) | ✅ | ✅ | | +| [MGP-STR](./configs/rec/mgpstr/) | [ECCV 2022](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136880336.pdf) | ✅ | ✅ | | +| [LPV](./configs/rec/lpv/) | [IJCAI 2023](https://www.ijcai.org/proceedings/2023/0189.pdf) | ✅ | ✅ | | +| [MAERec](./configs/rec/maerec/)(Union14M) | [ICCV 2023](https://openaccess.thecvf.com/content/ICCV2023/papers/Jiang_Revisiting_Scene_Text_Recognition_A_Data_Perspective_ICCV_2023_paper.pdf) | ✅ | ✅ | | +| [LISTER](./configs/rec/lister/) | [ICCV 2023](https://openaccess.thecvf.com/content/ICCV2023/papers/Cheng_LISTER_Neighbor_Decoding_for_Length-Insensitive_Scene_Text_Recognition_ICCV_2023_paper.pdf) | ✅ | ✅ | | +| [CDistNet](./configs/rec/cdistnet/) | [IJCV 2024](https://link.springer.com/article/10.1007/s11263-023-01880-0) | ✅ | ✅ | [YesianRohn](https://github.com/YesianRohn) | +| [BUSNet](./configs/rec/busnet/) | [AAAI 2024](https://ojs.aaai.org/index.php/AAAI/article/view/28402) | ✅ | ✅ | | +| DCTC | [AAAI 2024](https://ojs.aaai.org/index.php/AAAI/article/view/28575) | | | TODO | +| [CAM](./configs/rec/cam/) | [PR 2024](https://arxiv.org/abs/2402.13643) | ✅ | ✅ | | +| [OTE](./configs/rec/ote/) | [CVPR 2024](https://openaccess.thecvf.com/content/CVPR2024/html/Xu_OTE_Exploring_Accurate_Scene_Text_Recognition_Using_One_Token_CVPR_2024_paper.html) | ✅ | ✅ | | +| CFF | [IJCAI 2024](https://arxiv.org/abs/2407.05562) | | | TODO | +| [DPTR](./configs/rec/dptr/) | [ACM MM 2024](https://arxiv.org/abs/2408.05706) | | | [fd-zs](https://github.com/fd-zs) | +| VIPTR | [ACM CIKM 2024](https://arxiv.org/abs/2401.10110) | | | TODO | +| [IGTR](./configs/rec/igtr/) | [TPAMI 2025](https://doi.ieeecomputersociety.org/10.1109/TPAMI.2025.3525526) | ✅ | ✅ | | +| [SMTR](./configs/rec/smtr/) | [AAAI 2025](https://arxiv.org/abs/2407.12317) | ✅ | ✅ | | +| [CPPD](./configs/rec/cppd/) | [TPAMI Online 
Access](https://doi.ieeecomputersociety.org/10.1109/TPAMI.2025.3545453) | ✅ | ✅ | | +| [FocalSVTR-CTC](./configs/rec/svtrs/) | [2024](https://arxiv.org/abs/2407.12317) | ✅ | ✅ | | +| [SVTRv2](./configs/rec/svtrv2/) | [2024](https://arxiv.org/abs/2411.15858) | ✅ | ✅ | | +| [ResNet+Trans-CTC](./configs/rec/svtrs/) | | ✅ | ✅ | | +| [ViT-CTC](./configs/rec/svtrs/) | | ✅ | ✅ | | + +#### 核心贡献者 + +______________________________________________________________________ + +复旦大学[FVL实验室](https://fvl.fudan.edu.cn)的Yiming Lei ([pretto0](https://github.com/pretto0)), Xingsong Ye ([YesianRohn](https://github.com/YesianRohn)), and Shuai Zhao ([fd-zs](https://github.com/fd-zs))在Zhineng Chen老师([个人主页](https://zhinchenfd.github.io/))指导下完成了主要算法复现工作,感谢他们的贡献。 + +### 场景文本检测(STD) + +开发中 + +### 端到端文本识别(Text Spotting) + +开发中 + +______________________________________________________________________ + +## 引用 + +如果我们的工作对您的研究有所帮助,请引用: + +```bibtex +@article{Du2024SVTRv2, + title={SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition}, + author={Yongkun Du and Zhineng Chen and Hongtao Xie and Caiyan Jia and Yu-Gang Jiang}, + journal={CoRR}, + volume={abs/2411.15858}, + eprinttype={arXiv}, + year={2024}, + primaryClass={cs.CV}, + url={https://arxiv.org/abs/2411.15858} +} +``` + +## 致谢 + +本代码库基于[PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)、[PytorchOCR](https://github.com/WenmuZhou/PytorchOCR)和[MMOCR](https://github.com/open-mmlab/mmocr)构建,感谢他们的出色工作! diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..1fa1819 --- /dev/null +++ b/__init__.py @@ -0,0 +1,11 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) + +from tools.infer_e2e import OpenOCR, OpenDetector, OpenRecognizer diff --git a/configs/dataset/rec/evaluation.yaml b/configs/dataset/rec/evaluation.yaml new file mode 100644 index 0000000..7c10194 --- /dev/null +++ b/configs/dataset/rec/evaluation.yaml @@ -0,0 +1,41 @@ +root: ../evaluation +task: str +download_links: + # IC15_1811 + - https://drive.usercontent.google.com/download?id=1eGY0kXNV1qVxeUpoGzs-ioUO-ky7msH6&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1BWv7aLoLAT7avY326gXP3GJF48UZpuBC&authuser=0&confirm=t + # SVT + - https://drive.usercontent.google.com/download?id=1ecEZ4cJ7dIbTCZRltE0s5KzUotQWagH-&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1OygBP7i9R-3Pwi6WodCcW31J8CUMugOJ&authuser=0&confirm=t + # IIIT5k + - https://drive.usercontent.google.com/download?id=1PJ9_IvIGZTS5hHdGLnpKuYKZcCO8jE0E&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=10P3MixSBt1v8k8_6aFfziC33Z5IlM6Uf&authuser=0&confirm=t + # IC13_857 + - https://drive.usercontent.google.com/download?id=1-wMHOFBXJaOaY-UD00nDn6qw2s_8R4Vd&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1J1QCFtOFxFKiLJIgTqZ6eRo9Y5QGqHpA&authuser=0&confirm=t + # SVTP + - https://drive.usercontent.google.com/download?id=1kckwfZkdaHG8k_FW5IIJKUaYZkF21Hza&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1x61lm_ea7lvIdxNPMG-jy-5W0MxtdH0N&authuser=0&confirm=t + # CUTE80 + - https://drive.usercontent.google.com/download?id=1Zv_91c81tinLy5Je89HPr-5wUSnqXKIB&authuser=0&confirm=t + - 
https://drive.usercontent.google.com/download?id=1OuJ6QoJ9AlyNHIM9j2WedAPxTnac7kyY&authuser=0&confirm=t +filenames: + # IC15_1811 + - ../evaluation/IC15_1811/data.mdb + - ../evaluation/IC15_1811/lock.mdb + # SVT + - ../evaluation/SVT/data.mdb + - ../evaluation/SVT/lock.mdb + # IIIT5k + - ../evaluation/IIIT5k/data.mdb + - ../evaluation/IIIT5k/lock.mdb + # IC13_857 + - ../evaluation/IC13_857/data.mdb + - ../evaluation/IC13_857/lock.mdb + # SVTP + - ../evaluation/SVTP/data.mdb + - ../evaluation/SVTP/lock.mdb + # CUTE80 + - ../evaluation/CUTE80/data.mdb + - ../evaluation/CUTE80/lock.mdb +check_validity: true diff --git a/configs/dataset/rec/ltb.yaml b/configs/dataset/rec/ltb.yaml new file mode 100644 index 0000000..a0b03a5 --- /dev/null +++ b/configs/dataset/rec/ltb.yaml @@ -0,0 +1,9 @@ +root: ../ltb +task: str +download_links: + - https://drive.usercontent.google.com/download?id=16AEA1YGTsyVB44uEjKi4ZUV1snjCYBr4&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1xU4OStrOaI23bPG4flWAPWn2YrQe2bmY&authuser=0&confirm=t +filenames: + - ../ltb/data.mdb + - ../ltb/lock.mdb +check_validity: true diff --git a/configs/dataset/rec/mjsynth.yaml b/configs/dataset/rec/mjsynth.yaml new file mode 100644 index 0000000..7a9fc75 --- /dev/null +++ b/configs/dataset/rec/mjsynth.yaml @@ -0,0 +1,11 @@ +root: ../synth +task: str +download_links: + - https://drive.usercontent.google.com/download?id=1FIoplSFZ-BKQoRDHDXsVMKa844e-K8PD&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1eckTvaeRtlTZvbO2orrVz-cIuIk6i87K&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1PBXTf-2PnmEvJBsqzJqxxRwzhAZGTiMG&authuser=0&confirm=t +filenames: + - ../synth/MJ_train.zip + - ../synth/MJ_val.zip + - ../synth/MJ_test.zip +check_validity: true \ No newline at end of file diff --git a/configs/dataset/rec/openvino.yaml b/configs/dataset/rec/openvino.yaml new file mode 100644 index 0000000..f60e448 --- /dev/null +++ b/configs/dataset/rec/openvino.yaml @@ -0,0 +1,25 @@ +root: ../OpenVINO +task: str +download_links: + # train_1 + - https://drive.usercontent.google.com/download?id=1q23QAIRTyG0t-bBm4aAwRwiqB6VUfphw&authuser=0&confirm= + # train_2 + - https://drive.usercontent.google.com/download?id=1AtbaJljM68cbZqi5lcM92d9VkQUCbSqI&authuser=0&confirm= + # train_5 + - https://drive.usercontent.google.com/download?id=1dejstYnJ8_sESuO_uvwi__jT1B8gPxf3&authuser=0&confirm=t + # train_f + - https://drive.usercontent.google.com/download?id=1C4akchTc7-yi1OS_sJ3KP693UKcnecke&authuser=0&confirm=t + # validation + - https://drive.usercontent.google.com/download?id=17TRzSQhuK_juAxAv3KmX0y13pQP2cz6R&authuser=0&confirm=t +filenames: + # train_1 + - ../OpenVINO/train_1.zip + # train_2 + - ../OpenVINO/train_2.zip + # train_5 + - ../OpenVINO/train_5.zip + # train_f + - ../OpenVINO/train_f.zip + # validation + - ../OpenVINO/validation.zip +check_validity: true diff --git a/configs/dataset/rec/ost.yaml b/configs/dataset/rec/ost.yaml new file mode 100644 index 0000000..6690f12 --- /dev/null +++ b/configs/dataset/rec/ost.yaml @@ -0,0 +1,17 @@ +root: ../OST +task: str +download_links: + # OST heavy + - https://drive.usercontent.google.com/download?id=1RGpIFbD_SRlrzZFBoVF_LGvetNx1-5pg&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1Th4MfDf44k0EBpIqCLqVoGRu6G-FP1hq&authuser=0&confirm=t + # OST weak + - https://drive.usercontent.google.com/download?id=1z5CTDJucUnvALG12Q4UXk1DDKJDd8WJn&authuser=0&confirm=t + - 
https://drive.usercontent.google.com/download?id=1V17TTkX3sjpV7v0km_F2SDCK0tL3k_ls&authuser=0&confirm=t +filenames: + # OST heavy + - ../OST/heavy/data.mdb + - ../OST/heavy/lock.mdb + # OST weak + - ../OST/weak/data.mdb + - ../OST/weak/lock.mdb +check_validity: true diff --git a/configs/dataset/rec/synthtext.yaml b/configs/dataset/rec/synthtext.yaml new file mode 100644 index 0000000..4b84088 --- /dev/null +++ b/configs/dataset/rec/synthtext.yaml @@ -0,0 +1,7 @@ +root: ../synth +task: str +download_links: + - https://drive.usercontent.google.com/download?id=1T-enqkq6_l2HqrsV3da_h0oJ7CUKu_oc&authuser=0&confirm=t +filenames: + - ../synth/ST.zip +check_validity: true diff --git a/configs/dataset/rec/test.yaml b/configs/dataset/rec/test.yaml new file mode 100644 index 0000000..b043ba4 --- /dev/null +++ b/configs/dataset/rec/test.yaml @@ -0,0 +1,77 @@ +root: ../test +task: str +download_links: + # IC13_857 + - https://drive.usercontent.google.com/download?id=1PZSCbe6_DI8MlCqCRWXGT2PP92_frIXq&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1qkN7NDg0zUHxUiZHAeEatDTqlsgpFWp3&authuser=0&confirm=t + # IC15_2077 + - https://drive.usercontent.google.com/download?id=1dFkY3DNbr-Mepn3TWBiA9COEJ63fGFcp&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1UvVwLNZ3tS1YdTBa8MulPzjeVezKaDro&authuser=0&confirm=t + # SVTP + - https://drive.usercontent.google.com/download?id=1aofeerilxJ7J3S7QxuCEXbmXTpz8Xshx&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1rJ1KoO4K_VUxEAUN_bMgBGzK8_JZAAno&authuser=0&confirm=t + # IIIT5k + - https://drive.usercontent.google.com/download?id=1XFO2M1Kbgwv3-iTNTmhQXAEjNmKYOeoT&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1stwK2hFsyaV7HHsEG9EYgnUQebNb2_nG&authuser=0&confirm=t + # COCOv1.4 + - https://drive.usercontent.google.com/download?id=1Se2QSGS19xx7Gfy-SUdX9mlAOr2eYsfA&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1xvekFi389QfkH7yS0JIVV0QzjhUspjDv&authuser=0&confirm=t + # IC15_1811 + - https://drive.usercontent.google.com/download?id=1pHsw8wrThD9EGEE6AusQLZozefSj4iyR&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1TXZ1qHuKAksaAlvd3qMv4IHKnN-IJW9a&authuser=0&confirm=t + # Uber + - https://drive.usercontent.google.com/download?id=1L2j6BZeLTGQ1FIl8HB_D3AFiWLltGV5r&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=12DUj28yzLWxFO_gfMfSjTkRujYD5MNEE&authuser=0&confirm=t + # IC13_1095 + - https://drive.usercontent.google.com/download?id=1fu8onMt3Z6fDLNAiHcm-sQ2qCXduE-FU&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1OQAZtLj8U2Cl4L0ErGFsz6vGIVTTWasD&authuser=0&confirm=t + # IC13_1015 + - https://drive.usercontent.google.com/download?id=1mbsfuvWB282HYfn9tbqcj1nUDkLXcSNB&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1QGogU_hV-oN7iY2POutdD2LDcmK6plnV&authuser=0&confirm=t + # ArT + - https://drive.usercontent.google.com/download?id=1-53knSy-uTSngCG7wyBngVyTuTCmdnWl&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=172EsSaf7BVaB1ORtohi-Jc_8SuUKZGGf&authuser=0&confirm=t + # SVT + - https://drive.usercontent.google.com/download?id=1p7aVUr9Yr7c4X4YUBvk2-YP28rraHjn9&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1ALmhvSleZ0yf-lcdbQPP3M9Zc3oqnXij&authuser=0&confirm=t + # CUTE80 + - https://drive.usercontent.google.com/download?id=1Ujr4axHKnu54P2rIGUhkjdM6XlhDYrI_&authuser=0&confirm=t + - 
https://drive.usercontent.google.com/download?id=1DvZi9L3MqjO2zRUyCg3YvP4qMAt2bsme&authuser=0&confirm=t +filenames: + # IC13_857 + - ../test/IC13_857/data.mdb + - ../test/IC13_857/lock.mdb + # IC15_2077 + - ../test/IC15_2077/data.mdb + - ../test/IC15_2077/lock.mdb + # SVTP + - ../test/SVTP/data.mdb + - ../test/SVTP/lock.mdb + # IIIT5k + - ../test/IIIT5k/data.mdb + - ../test/IIIT5k/lock.mdb + # COCOv1.4 + - ../test/COCOv1.4/data.mdb + - ../test/COCOv1.4/lock.mdb + # IC15_1811 + - ../test/IC15_1811/data.mdb + - ../test/IC15_1811/lock.mdb + # Uber + - ../test/Uber/data.mdb + - ../test/Uber/lock.mdb + # IC13_1095 + - ../test/IC13_1095/data.mdb + - ../test/IC13_1095/lock.mdb + # IC13_1015 + - ../test/IC13_1015/data.mdb + - ../test/IC13_1015/lock.mdb + # ArT + - ../test/ArT/data.mdb + - ../test/ArT/lock.mdb + # SVT + - ../test/SVT/data.mdb + - ../test/SVT/lock.mdb + # CUTE80 + - ../test/CUTE80/data.mdb + - ../test/CUTE80/lock.mdb +check_validity: true diff --git a/configs/dataset/rec/textocr.yaml b/configs/dataset/rec/textocr.yaml new file mode 100644 index 0000000..abfb4b7 --- /dev/null +++ b/configs/dataset/rec/textocr.yaml @@ -0,0 +1,13 @@ +root: ../TextOCR +task: str +download_links: + # train + - https://drive.usercontent.google.com/download?id=1jVjJFno4pnsU0Cp_kn4MIXQrChmELy92&authuser=0&confirm= + # val + - https://drive.usercontent.google.com/download?id=1ubIRu01MXIek6OvInu-XjaIbw6277-vw&authuser=0&confirm=t +filenames: + # train + - ../TextOCR/train.zip + # val + - ../TextOCR/val.zip +check_validity: true diff --git a/configs/dataset/rec/textocr_horizontal.yaml b/configs/dataset/rec/textocr_horizontal.yaml new file mode 100644 index 0000000..1ccec04 --- /dev/null +++ b/configs/dataset/rec/textocr_horizontal.yaml @@ -0,0 +1,13 @@ +root: ../TextOCR_horizontal +task: str +download_links: + # train + - https://drive.usercontent.google.com/download?id=1sWH6J11xbjQb8SH7fdG_8mIKVI81ZQy5&authuser=0&confirm= + # val + - https://drive.usercontent.google.com/download?id=1gIE-AU2o-5hvg288-bjphO6UkI5AEQ2d&authuser=0&confirm=t +filenames: + # train + - ../TextOCR_horizontal/train.zip + # val + - ../TextOCR_horizontal/val.zip +check_validity: true diff --git a/configs/dataset/rec/union14m_b.yaml b/configs/dataset/rec/union14m_b.yaml new file mode 100644 index 0000000..7479207 --- /dev/null +++ b/configs/dataset/rec/union14m_b.yaml @@ -0,0 +1,47 @@ +root: ../u14m +task: str +download_links: + # artistic + - https://drive.usercontent.google.com/download?id=1Je2DTuFHnkXDI99yDnm9Anl5naWaCQwd&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1xtT_Q0juBJUIvAG55qBxoVNNTECd2usZ&authuser=0&confirm=t + # contextless + - https://drive.usercontent.google.com/download?id=1_0OzyzWhZOmGrHkayFTVrzhrQrNRDRPR&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1PPgC42y3xoM9bR0HQFbDYbcT3PzMdD_y&authuser=0&confirm=t + # salient + - https://drive.usercontent.google.com/download?id=1tHLMYBmTqRnxvFOTT3dfLfQiundqFWfd&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=13NQgpAtCK0kh9M5E2pAUmKKEp6Qu5Xwj&authuser=0&confirm=t + # multi_words + - https://drive.usercontent.google.com/download?id=1IlnDKX3V_Vp9gsDGFB0xoqsVLH1vtxUI&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1mFFjC7C0CwevvkwFU9YeVbZBdps_3Qpb&authuser=0&confirm=t + # curve + - https://drive.usercontent.google.com/download?id=1MxhMd85cmhUtI2lmtXhZQuFk7lav0_fw&authuser=0&confirm=t + - 
https://drive.usercontent.google.com/download?id=1N03g-4e-kJG2mRvlM0c5TrwWAkd-iG-Q&authuser=0&confirm=t + # general + - https://drive.usercontent.google.com/download?id=1Oqt7OaycP466NWoDmoJ3FqS8YP3YRgvu&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1K0MrX5eYNt8IIGFHXCwg0_oI5OF5PPFO&authuser=0&confirm=t + # multi_oriented + - https://drive.usercontent.google.com/download?id=1TKZFcZPVk0ThqfF-AGhJk_OCLg0ykKbv&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1PAoLMUWuR7O2-7XRoKkNzQcSiznErQzD&authuser=0&confirm=t +filenames: + # artistic + - ../u14m/artistic/data.mdb + - ../u14m/artistic/lock.mdb + # contextless + - ../u14m/contextless/data.mdb + - ../u14m/contextless/lock.mdb + # salient + - ../u14m/salient/data.mdb + - ../u14m/salient/lock.mdb + # multi_words + - ../u14m/multi_words/data.mdb + - ../u14m/multi_words/lock.mdb + # curve + - ../u14m/curve/data.mdb + - ../u14m/curve/lock.mdb + # general + - ../u14m/general/data.mdb + - ../u14m/general/lock.mdb + # multi_oriented + - ../u14m/multi_oriented/data.mdb + - ../u14m/multi_oriented/lock.mdb +check_validity: true diff --git a/configs/dataset/rec/union14m_l_filtered.yaml b/configs/dataset/rec/union14m_l_filtered.yaml new file mode 100644 index 0000000..86ab60e --- /dev/null +++ b/configs/dataset/rec/union14m_l_filtered.yaml @@ -0,0 +1,35 @@ +root: ../Union14M-L-LMDB-Filtered +task: str +download_links: + # train_challenging + - https://drive.usercontent.google.com/download?id=1etwzBgGHjsFsb0sygsaRnKbanW2PMe07&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1ly6FJfPjItwGlVQ-ifTrzzM3rVu3Ezhr&authuser=0&confirm=t + # train_easy + - https://drive.usercontent.google.com/download?id=1_zeNluTnywIaa5h3PN-Ah9tKyByypot7&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1caYLeQHDidXgVBDi9IWXbO1gg__DYq9a&authuser=0&confirm=t + # train_hard + - https://drive.usercontent.google.com/download?id=1eP6s2xyYPZX9gykvWA4VSOc3Fqul_UB_&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1-ZlCvocX8P5uVRclUXp_5DNGLDzd16EO&authuser=0&confirm=t + # train_medium + - https://drive.usercontent.google.com/download?id=1s_CoaLNJEr-UxHYiqZ5jOcliMCFiRUUy&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1Wpj6WVpZ5Ily77kVwfQ18CiZBzkgmEnF&authuser=0&confirm=t + # train_normal + - https://drive.usercontent.google.com/download?id=1jPt44arlAswl9cXZjzmVcdpptdTPpJ3I&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1Rfc5kE03AzOUv7B_eYcBhUV8KMQ2MZ1m&authuser=0&confirm=t +filenames: + # train_challenging + - ../Union14M-L-LMDB-Filtered/train_challenging/data.mdb + - ../Union14M-L-LMDB-Filtered/train_challenging/lock.mdb + # train_easy + - ../Union14M-L-LMDB-Filtered/train_easy/data.mdb + - ../Union14M-L-LMDB-Filtered/train_easy/lock.mdb + # train_hard + - ../Union14M-L-LMDB-Filtered/train_hard/data.mdb + - ../Union14M-L-LMDB-Filtered/train_hard/lock.mdb + # train_medium + - ../Union14M-L-LMDB-Filtered/train_medium/data.mdb + - ../Union14M-L-LMDB-Filtered/train_medium/lock.mdb + # train_normal + - ../Union14M-L-LMDB-Filtered/train_normal/data.mdb + - ../Union14M-L-LMDB-Filtered/train_normal/lock.mdb +check_validity: true diff --git a/configs/det/dbnet/repvit_db.yml b/configs/det/dbnet/repvit_db.yml index c9b1bc1..f66368f 100644 --- a/configs/det/dbnet/repvit_db.yml +++ b/configs/det/dbnet/repvit_db.yml @@ -3,8 +3,8 @@ Global: epoch_num: &epoch_num 500 log_smooth_window: 20 print_batch_step: 100 - save_model_dir: 
./output/det_repsvtr_db - save_epoch_step: 10 + output_dir: ./output/det_repsvtr_db + save_epoch_step: [400, 25] eval_batch_step: - 0 - 1000 @@ -12,14 +12,14 @@ Global: checkpoints: pretrained_model: openocr_det_repvit_ch.pth save_inference_dir: null - use_visualdl: false - infer_img: ./testA + use_tensorboard: false + infer_img: save_res_path: ./checkpoints/det_db/predicts_db.txt distributed: true model_type: det Architecture: - algorithm: DB + algorithm: DB_mobile Backbone: name: RepSVTR_det Neck: @@ -30,112 +30,110 @@ Architecture: name: DBHead k: 50 -# Loss: -# name: DBLoss -# balance_loss: true -# main_loss_type: DiceLoss -# alpha: 5 -# beta: 10 -# ohem_ratio: 3 +Loss: + name: DBLoss + balance_loss: true + main_loss_type: DiceLoss + alpha: 5 + beta: 10 + ohem_ratio: 3 -# Optimizer: -# name: Adam -# beta1: 0.9 -# beta2: 0.999 -# lr: -# name: Cosine -# learning_rate: 0.001 #(8*8c) -# warmup_epoch: 2 -# regularizer: -# name: L2 -# factor: 5.0e-05 +Optimizer: + name: Adam + lr: 0.001 + weight_decay: 5.0e-05 + filter_bias_and_bn: False + +LRScheduler: + name: CosineAnnealingLR + warmup_epoch: 2 PostProcess: name: DBPostProcess thresh: 0.3 - box_thresh: 0.4 + box_thresh: 0.6 max_candidates: 1000 unclip_ratio: 1.5 score_mode: 'slow' -# Metric: -# name: DetMetric -# main_indicator: hmean +Metric: + name: DetMetric + main_indicator: hmean -# Train: -# dataset: -# name: SimpleDataSet -# data_dir: ./train_data/icdar2015/text_localization/ -# label_file_list: -# - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt -# ratio_list: [1.0] -# transforms: -# - DecodeImage: -# img_mode: BGR -# channel_first: false -# - DetLabelEncode: null -# - CopyPaste: null -# - IaaAugment: -# augmenter_args: -# - type: Fliplr -# args: -# p: 0.5 -# - type: Affine -# args: -# rotate: -# - -10 -# - 10 -# - type: Resize -# args: -# size: -# - 0.5 -# - 3 -# - EastRandomCropData: -# size: -# - 640 -# - 640 -# max_tries: 50 -# keep_ratio: true -# - MakeBorderMap: -# shrink_ratio: 0.4 -# thresh_min: 0.3 -# thresh_max: 0.7 -# total_epoch: *epoch_num -# - MakeShrinkMap: -# shrink_ratio: 0.4 -# min_text_size: 8 -# total_epoch: *epoch_num -# - NormalizeImage: -# scale: 1./255. -# mean: -# - 0.485 -# - 0.456 -# - 0.406 -# std: -# - 0.229 -# - 0.224 -# - 0.225 -# order: hwc -# - ToCHWImage: null -# - KeepKeys: -# keep_keys: -# - image -# - threshold_map -# - threshold_mask -# - shrink_map -# - shrink_mask -# loader: -# shuffle: true -# drop_last: false -# batch_size_per_card: 8 -# num_workers: 8 +Train: + dataset: + name: SimpleDataSet + data_dir: ../icdar2015/text_localization/ + label_file_list: + - ../icdar2015/text_localization/train_icdar2015_label.txt + ratio_list: [1.0] + transforms: + - DecodeImage: + img_mode: BGR + channel_first: false + - DetLabelEncode: null + - CopyPaste: null + - IaaAugment: + augmenter_args: + - type: Fliplr + args: + p: 0.5 + - type: Affine + args: + rotate: + - -10 + - 10 + - type: Resize + args: + size: + - 0.5 + - 3 + - EastRandomCropData: + size: + - 640 + - 640 + max_tries: 50 + keep_ratio: true + - MakeBorderMap: + shrink_ratio: 0.4 + thresh_min: 0.3 + thresh_max: 0.7 + total_epoch: *epoch_num + - MakeShrinkMap: + shrink_ratio: 0.4 + min_text_size: 8 + total_epoch: *epoch_num + - NormalizeImage: + scale: 1./255. 
+ mean: + - 0.485 + - 0.456 + - 0.406 + std: + - 0.229 + - 0.224 + - 0.225 + order: hwc + - ToCHWImage: null + - KeepKeys: + keep_keys: + - image + - threshold_map + - threshold_mask + - shrink_map + - shrink_mask + loader: + shuffle: true + drop_last: false + batch_size_per_card: 8 + num_workers: 8 Eval: dataset: name: SimpleDataSet - data_dir: ./train_data/icdar2015/text_localization/ + data_dir: ../icdar2015/text_localization/ label_file_list: - - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt + - ../icdar2015/text_localization/test_icdar2015_label.txt transforms: - DecodeImage: img_mode: BGR diff --git a/configs/rec/cppd/svtr_base_cppd.yml b/configs/rec/cppd/svtr_base_cppd.yml index f49548a..fa35782 100644 --- a/configs/rec/cppd/svtr_base_cppd.yml +++ b/configs/rec/cppd/svtr_base_cppd.yml @@ -4,7 +4,7 @@ Global: log_smooth_window: 20 print_batch_step: 10 output_dir: ./output/rec/u14m_filter/svtr_base_cppd/ - save_epoch_step: 1 + save_epoch_step: [15, 1] # evaluation is run every 2000 iterations eval_batch_step: [0, 500] eval_epoch_step: [0, 1] diff --git a/configs/rec/cppd/svtr_base_cppd_ch.yml b/configs/rec/cppd/svtr_base_cppd_ch.yml index 4476cfd..1d20d82 100644 --- a/configs/rec/cppd/svtr_base_cppd_ch.yml +++ b/configs/rec/cppd/svtr_base_cppd_ch.yml @@ -4,7 +4,7 @@ Global: log_smooth_window: 20 print_batch_step: 10 output_dir: ./output/rec/ch/svtr_base_cppd/ - save_epoch_step: 1 + save_epoch_step: [15, 1] # evaluation is run every 2000 iterations eval_batch_step: [0, 2000] eval_epoch_step: [0, 1] diff --git a/configs/rec/cppd/svtr_base_cppd_h8.yml b/configs/rec/cppd/svtr_base_cppd_h8.yml index 8a5f71a..b94c5f2 100644 --- a/configs/rec/cppd/svtr_base_cppd_h8.yml +++ b/configs/rec/cppd/svtr_base_cppd_h8.yml @@ -4,7 +4,7 @@ Global: log_smooth_window: 20 print_batch_step: 10 output_dir: ./output/rec/u14m_filter/svtr_base_h8_cppd/ - save_epoch_step: 1 + save_epoch_step: [15, 1] # evaluation is run every 2000 iterations eval_batch_step: [0, 500] eval_epoch_step: [0, 1] diff --git a/configs/rec/cppd/svtr_base_cppd_syn.yml b/configs/rec/cppd/svtr_base_cppd_syn.yml index 4a3fc96..2b761fb 100644 --- a/configs/rec/cppd/svtr_base_cppd_syn.yml +++ b/configs/rec/cppd/svtr_base_cppd_syn.yml @@ -4,7 +4,7 @@ Global: log_smooth_window: 20 print_batch_step: 10 output_dir: ./output/rec/syn/svtr_base_cppd/ - save_epoch_step: 1 + save_epoch_step: [15, 1] # evaluation is run every 2000 iterations eval_batch_step: [0, 500] eval_epoch_step: [0, 1] diff --git a/configs/rec/cppd/svtrv2_cppd.yml b/configs/rec/cppd/svtrv2_cppd.yml index a369cfe..8a163e2 100644 --- a/configs/rec/cppd/svtrv2_cppd.yml +++ b/configs/rec/cppd/svtrv2_cppd.yml @@ -4,7 +4,7 @@ Global: log_smooth_window: 20 print_batch_step: 10 output_dir: ./output/rec/u14m_filter/svtrv2_cppd/ - save_epoch_step: 1 + save_epoch_step: [15, 1] # evaluation is run every 2000 iterations eval_batch_step: [0, 500] eval_epoch_step: [0, 1] diff --git a/configs/rec/dptr/dptr_parseq_pretrain.yml b/configs/rec/dptr/dptr_parseq_pretrain.yml new file mode 100644 index 0000000..32b8adf --- /dev/null +++ b/configs/rec/dptr/dptr_parseq_pretrain.yml @@ -0,0 +1,88 @@ +Global: + device: gpu + epoch_num: 20 + log_smooth_window: 20 + print_batch_step: 10 + output_dir: /share/ckpt/zhaoshuai/openocr/dptr_parseq/ + eval_epoch_step: [0, 1] + eval_batch_step: [0, 500] + cal_metric_during_train: True + pretrained_model: + checkpoints: + use_tensorboard: false + infer_img: + # for data or label process + character_dict_path: &character_dict_path 
./tools/utils/EN_symbol_dict.txt + max_text_length: &max_text_length 25 + use_space_char: &use_space_char False + use_amp: True + save_res_path: /share/ckpt/zhaoshuai/openocr/dptr_parseq/predicts_dptr_parseq.txt + grad_clip_val: 20 + +Optimizer: + name: AdamW + lr: 0.001485 # 2gpus 384bs/gpu + weight_decay: 0. + filter_bias_and_bn: False + +LRScheduler: + name: OneCycleLR + warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep + cycle_momentum: False + +Architecture: + model_type: rec + algorithm: DPTR + Decoder: + name: DptrParseq + decode_ar: True + refine_iters: 1 + is_pretrain: True + ORP_path: /share/ckpt/zhaoshuai/parseq/clip_background.pth + +Loss: + name: PARSeqLoss + +PostProcess: + name: ARLabelDecode + character_dict_path: *character_dict_path + use_space_char: *use_space_char + +Metric: + name: RecMetric + main_indicator: acc + is_filter: True + +Train: + dataset: + name: TextLMDBDataSet + data_dir: /share/test/zhaoshuai/parseq-data/data/train/real/ArT + transforms: + - DPTRLabelEncode: # Class handling label + character_dict_path: *character_dict_path + use_space_char: *use_space_char + max_text_length: *max_text_length + - KeepKeys: + keep_keys: ['clip_label', 'label'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 256 + drop_last: True + num_workers: 4 + +Eval: + dataset: + name: TextLMDBDataSet + data_dir: /share/test/zhaoshuai/parseq-data/data/val + transforms: + - DPTRLabelEncode: # Class handling label + character_dict_path: *character_dict_path + use_space_char: *use_space_char + max_text_length: *max_text_length + - KeepKeys: + keep_keys: ['clip_label', 'label'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 256 + num_workers: 2 diff --git a/configs/rec/gtc/svtrv2_lnconv_nrtr_gtc.yml b/configs/rec/gtc/svtrv2_lnconv_nrtr_gtc.yml index 57e97fc..692f2ea 100644 --- a/configs/rec/gtc/svtrv2_lnconv_nrtr_gtc.yml +++ b/configs/rec/gtc/svtrv2_lnconv_nrtr_gtc.yml @@ -4,7 +4,7 @@ Global: log_smooth_window: 20 print_batch_step: 10 output_dir: ./output/rec/svtrv2_lnconv_nrtr_gtc - save_epoch_step: 1 + save_epoch_step: [15, 1] # evaluation is run every 2000 iterations eval_batch_step: [0, 500] eval_epoch_step: [0, 1] diff --git a/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_long_infer.yml b/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_long_infer.yml index 1b21362..977384f 100644 --- a/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_long_infer.yml +++ b/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_long_infer.yml @@ -4,7 +4,7 @@ Global: log_smooth_window: 20 print_batch_step: 10 output_dir: ./output/rec/svtrv2_lnconv_smtr_gtc_long_infer - save_epoch_step: 1 + save_epoch_step: [15, 1] # evaluation is run every 2000 iterations eval_batch_step: [0, 1000] eval_epoch_step: [0, 1] diff --git a/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_smtr_long.yml b/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_smtr_long.yml index 66f557b..a8ef1a2 100644 --- a/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_smtr_long.yml +++ b/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_smtr_long.yml @@ -4,7 +4,7 @@ Global: log_smooth_window: 20 print_batch_step: 10 output_dir: ./output/rec/svtrv2_lnconv_smtr_gtc_nodetach_smtr_long_infer - save_epoch_step: 1 + save_epoch_step: [15, 1] # evaluation is run every 2000 iterations eval_batch_step: [0, 1000] eval_epoch_step: [0, 1] diff --git a/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_stream.yml b/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_stream.yml index d00af78..22cccf6 100644 --- 
a/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_stream.yml +++ b/configs/rec/gtc/svtrv2_lnconv_smtr_gtc_stream.yml @@ -4,7 +4,7 @@ Global: log_smooth_window: 20 print_batch_step: 10 output_dir: ./output/rec/svtrv2_lnconv_smtr_gtc_stream - save_epoch_step: 1 + save_epoch_step: [15, 1] # evaluation is run every 2000 iterations eval_batch_step: [0, 500] eval_epoch_step: [0, 1] diff --git a/configs/rec/igtr/readme.md b/configs/rec/igtr/readme.md index 61c7d33..dc420f7 100644 --- a/configs/rec/igtr/readme.md +++ b/configs/rec/igtr/readme.md @@ -13,11 +13,12 @@ Paper: -> [Instruction-Guided Scene Text Recognition](https://arxiv.org/abs/2401.17851) -> Yongkun Du, Zhineng Chen, Yuchen Su, Caiyan Jia, Yu-Gang Jiang +> [Instruction-Guided Scene Text Recognition](https://arxiv.org/abs/2401.17851), +> Yongkun Du, Zhineng Chen, Yuchen Su, Caiyan Jia, Yu-Gang Jiang, +> TPAMI -Multi-modal models show appealing performance in visual recognition tasks recently, as free-form text-guided training evokes the ability to understand fine-grained visual content. However, current models are either inefficient or cannot be trivially upgraded to scene text recognition (STR) due to the composition difference between natural and text images. We propose a novel instruction-guided scene text recognition (IGTR) paradigm that formulates STR as an instruction learning problem and understands text images by predicting character attributes, e.g., character frequency, position, etc. IGTR first devises $\\left \\langle condition,question,answer\\right \\rangle$ instruction triplets, providing rich and diverse descriptions of character attributes. To effectively learn these attributes through question-answering, IGTR develops lightweight instruction encoder, cross-modal feature fusion module and multi-task answer head, which guides nuanced text image understanding. Furthermore, IGTR realizes different recognition pipelines simply by using different instructions, enabling a character-understanding-based text reasoning paradigm that considerably differs from current methods. Experiments on English and Chinese benchmarks show that IGTR outperforms existing models by significant margins, while maintaining a small model size and efficient inference speed. Moreover, by adjusting the sampling of instructions, IGTR offers an elegant way to tackle the recognition of both rarely appearing and morphologically similar characters, which were previous challenges. +Multi-modal models have shown appealing performance in visual recognition tasks, as free-form text-guided training evokes the ability to understand fine-grained visual content. However, current models cannot be trivially applied to scene text recognition (STR) due to the compositional difference between natural and text images. We propose a novel instruction-guided scene text recognition (IGTR) paradigm that formulates STR as an instruction learning problem and understands text images by predicting character attributes, e.g., character frequency, position, etc. IGTR first devises $\\left \\langle condition,question,answer\\right \\rangle$ instruction triplets, providing rich and diverse descriptions of character attributes. To effectively learn these attributes through question-answering, IGTR develops a lightweight instruction encoder, a cross-modal feature fusion module and a multi-task answer head, which guides nuanced text image understanding. 
Furthermore, IGTR realizes different recognition pipelines simply by using different instructions, enabling a character-understanding-based text reasoning paradigm that differs from current methods considerably. Experiments on English and Chinese benchmarks show that IGTR outperforms existing models by significant margins, while maintaining a small model size and fast inference speed. Moreover, by adjusting the sampling of instructions, IGTR offers an elegant way to tackle the recognition of rarely appearing and morphologically similar characters, which were previous challenges. The accuracy (%) and model files of IGTR on the public dataset of scene text recognition are as follows: @@ -88,11 +89,11 @@ pip install -r requirements.txt #### Dataset Preparation -[English dataset download](https://github.com/baudm/parseq) +- [English dataset download](https://github.com/baudm/parseq) -[Union14M-L download](https://github.com/Mountchicken/Union14M) +- [Union14M-L-LMDB-Filtered download](https://drive.google.com/drive/folders/1OlDWJZgvd6s4S09S3IGeAI90jI0i7AB_?usp=sharing) -[Chinese dataset download](https://github.com/fudanvi/benchmarking-chinese-text-recognition#download) +- [Chinese dataset download](https://github.com/fudanvi/benchmarking-chinese-text-recognition#download) The expected filesystem structure is as follows: @@ -143,7 +144,7 @@ u14m # lmdb format ├── multi_oriented ├── multi_words └── salient -Union14M-LMDB-L # lmdb format +Union14M-L-LMDB-Filtered # lmdb format ├── train_challenging ├── train_easy ├── train_hard @@ -168,13 +169,15 @@ Evaluation: ```shell # The configuration file is available from the link provided in the table above. # en -python tools/eval_rec_all_ratio.py --c PATH/svtr_base_igtr_syn.yml +python tools/eval_rec_all_en.py --c PATH/svtr_base_igtr_syn.yml # ch python tools/eval_rec_all_ch.py --c PATH/svtr_base_igtr_ch_aug.yml ``` ## Citation +If you find our method useful for your reserach, please cite: + ```bibtex @article{Du2024IGTR, title = {Instruction-Guided Scene Text Recognition}, diff --git a/configs/rec/igtr/svtr_base_ds_igtr.yml b/configs/rec/igtr/svtr_base_ds_igtr.yml index df29c39..c8ff306 100644 --- a/configs/rec/igtr/svtr_base_ds_igtr.yml +++ b/configs/rec/igtr/svtr_base_ds_igtr.yml @@ -4,7 +4,7 @@ Global: log_smooth_window: 20 print_batch_step: 10 output_dir: ./output/rec/u14m_filter/svtr_base_igtr - save_epoch_step: 1 + save_epoch_step: [15, 1] # evaluation is run every 2000 iterations eval_batch_step: [0, 500] eval_epoch_step: [0, 1] diff --git a/configs/rec/lpv/svtr_base_lpv.yml b/configs/rec/lpv/svtr_base_lpv.yml index 01c5509..49ae070 100644 --- a/configs/rec/lpv/svtr_base_lpv.yml +++ b/configs/rec/lpv/svtr_base_lpv.yml @@ -4,7 +4,7 @@ Global: log_smooth_window: 20 print_batch_step: 10 output_dir: ./output/rec/u14m_filter/svtr_base_lpv/ - save_epoch_step: 1 + save_epoch_step: [15, 1] # evaluation is run every 2000 iterations eval_batch_step: [0, 500] eval_epoch_step: [0, 1] diff --git a/configs/rec/lpv/svtr_base_lpv_wo_glrm.yml b/configs/rec/lpv/svtr_base_lpv_wo_glrm.yml index ede4bdf..04e7d36 100644 --- a/configs/rec/lpv/svtr_base_lpv_wo_glrm.yml +++ b/configs/rec/lpv/svtr_base_lpv_wo_glrm.yml @@ -4,7 +4,7 @@ Global: log_smooth_window: 20 print_batch_step: 10 output_dir: ./output/rec/u14m_filter/svtr_base_lpv_wo_glrm/ - save_epoch_step: 1 + save_epoch_step: [15, 1] # evaluation is run every 2000 iterations eval_batch_step: [0, 500] eval_epoch_step: [0, 1] diff --git a/configs/rec/lpv/svtrv2_lpv.yml b/configs/rec/lpv/svtrv2_lpv.yml index 
8b61431..1aef1e2 100644 --- a/configs/rec/lpv/svtrv2_lpv.yml +++ b/configs/rec/lpv/svtrv2_lpv.yml @@ -4,7 +4,7 @@ Global: log_smooth_window: 20 print_batch_step: 10 output_dir: ./output/rec/u14m_filter/svtrv2_lpv/ - save_epoch_step: 1 + save_epoch_step: [15, 1] # evaluation is run every 2000 iterations eval_batch_step: [0, 500] eval_epoch_step: [0, 1] diff --git a/configs/rec/lpv/svtrv2_lpv_wo_glrm.yml b/configs/rec/lpv/svtrv2_lpv_wo_glrm.yml index 607c85a..ffc91f1 100644 --- a/configs/rec/lpv/svtrv2_lpv_wo_glrm.yml +++ b/configs/rec/lpv/svtrv2_lpv_wo_glrm.yml @@ -4,7 +4,7 @@ Global: log_smooth_window: 20 print_batch_step: 10 output_dir: ./output/rec/u14m_filter/svtrv2_lpv_wo_glrm/ - save_epoch_step: 1 + save_epoch_step: [15, 1] # evaluation is run every 2000 iterations eval_batch_step: [0, 500] eval_epoch_step: [0, 1] diff --git a/configs/rec/maerec/vit_nrtr.yml b/configs/rec/maerec/vit_nrtr.yml index 837b5ab..664dbd5 100644 --- a/configs/rec/maerec/vit_nrtr.yml +++ b/configs/rec/maerec/vit_nrtr.yml @@ -4,7 +4,7 @@ Global: log_smooth_window: 20 print_batch_step: 10 output_dir: ./output/rec/u14m_filter/vit_nrtr_ft_mae/ - save_epoch_step: 1 + save_epoch_step: [15, 1] # evaluation is run every 2000 iterations eval_batch_step: [0, 500] eval_epoch_step: [0, 1] diff --git a/configs/rec/nrtr/focalsvtr_nrtr_maxraio12.yml b/configs/rec/nrtr/focalsvtr_nrtr_maxraio12.yml index ab0d799..c95594b 100644 --- a/configs/rec/nrtr/focalsvtr_nrtr_maxraio12.yml +++ b/configs/rec/nrtr/focalsvtr_nrtr_maxraio12.yml @@ -4,7 +4,7 @@ Global: log_smooth_window: 20 print_batch_step: 10 output_dir: ./output/rec/u14m_filter/focalsvtr_nrtr_maxrtio12 - save_epoch_step: 1 + save_epoch_step: [15, 1] # evaluation is run every 2000 iterations eval_batch_step: [0, 500] eval_epoch_step: [0, 1] diff --git a/configs/rec/nrtr/nrtr.yml b/configs/rec/nrtr/nrtr.yml index 9a4d738..ed1ce44 100644 --- a/configs/rec/nrtr/nrtr.yml +++ b/configs/rec/nrtr/nrtr.yml @@ -4,7 +4,7 @@ Global: log_smooth_window: 20 print_batch_step: 10 output_dir: ./output/rec/u14m_filter/nrtr/ - save_epoch_step: 1 + save_epoch_step: [15, 1] # evaluation is run every 2000 iterations eval_batch_step: [0, 500] eval_epoch_step: [0, 1] diff --git a/configs/rec/nrtr/svtr_base_nrtr.yml b/configs/rec/nrtr/svtr_base_nrtr.yml index d62abbe..7a0ce03 100644 --- a/configs/rec/nrtr/svtr_base_nrtr.yml +++ b/configs/rec/nrtr/svtr_base_nrtr.yml @@ -4,7 +4,7 @@ Global: log_smooth_window: 20 print_batch_step: 10 output_dir: ./output/rec/u14m_filter/svtr_base_nrtr/ - save_epoch_step: 1 + save_epoch_step: [15, 1] # evaluation is run every 2000 iterations eval_batch_step: [0, 500] eval_epoch_step: [0, 1] diff --git a/configs/rec/nrtr/svtr_base_nrtr_syn.yml b/configs/rec/nrtr/svtr_base_nrtr_syn.yml index bfddad5..e3876a9 100644 --- a/configs/rec/nrtr/svtr_base_nrtr_syn.yml +++ b/configs/rec/nrtr/svtr_base_nrtr_syn.yml @@ -4,7 +4,7 @@ Global: log_smooth_window: 20 print_batch_step: 10 output_dir: ./output/rec/syn/svtr_base_nrtr/ - save_epoch_step: 1 + save_epoch_step: [15, 1] # evaluation is run every 2000 iterations eval_batch_step: [0, 500] eval_epoch_step: [0, 1] diff --git a/configs/rec/nrtr/svtrv2_nrtr.yml b/configs/rec/nrtr/svtrv2_nrtr.yml index de74c86..e1a448c 100644 --- a/configs/rec/nrtr/svtrv2_nrtr.yml +++ b/configs/rec/nrtr/svtrv2_nrtr.yml @@ -4,7 +4,7 @@ Global: log_smooth_window: 20 print_batch_step: 10 output_dir: ./output/rec/u14m_filter/svtrv2_nrtr/ - save_epoch_step: 1 + save_epoch_step: [15, 1] # evaluation is run every 2000 iterations eval_batch_step: [0, 500] 
eval_epoch_step: [0, 1] diff --git a/configs/rec/ote/svtr_base_h8_ote.yml b/configs/rec/ote/svtr_base_h8_ote.yml index 571ee71..74e051f 100644 --- a/configs/rec/ote/svtr_base_h8_ote.yml +++ b/configs/rec/ote/svtr_base_h8_ote.yml @@ -4,7 +4,7 @@ Global: log_smooth_window: 20 print_batch_step: 10 output_dir: ./output/rec/u14m_filter/svtr_base_h8_ote/ - save_epoch_step: 1 + save_epoch_step: [15, 1] # evaluation is run every 2000 iterations eval_batch_step: [0, 500] eval_epoch_step: [0, 1] diff --git a/configs/rec/ote/svtr_base_ote.yml b/configs/rec/ote/svtr_base_ote.yml index 8f97c27..18896d7 100644 --- a/configs/rec/ote/svtr_base_ote.yml +++ b/configs/rec/ote/svtr_base_ote.yml @@ -4,7 +4,7 @@ Global: log_smooth_window: 20 print_batch_step: 10 output_dir: ./output/rec/u14m_filter/svtr_base_ote/ - save_epoch_step: 1 + save_epoch_step: [15, 1] # evaluation is run every 2000 iterations eval_batch_step: [0, 500] eval_epoch_step: [0, 1] diff --git a/configs/rec/parseq/focalsvtr_parseq_maxratio12.yml b/configs/rec/parseq/focalsvtr_parseq_maxratio12.yml index b2601ea..70c975e 100644 --- a/configs/rec/parseq/focalsvtr_parseq_maxratio12.yml +++ b/configs/rec/parseq/focalsvtr_parseq_maxratio12.yml @@ -4,7 +4,7 @@ Global: log_smooth_window: 20 print_batch_step: 10 output_dir: ./output/rec/u14m_filter/focalsvtr_parseq_maxratio12 - save_epoch_step: 1 + save_epoch_step: [15, 1] # evaluation is run every 2000 iterations eval_batch_step: [0, 500] eval_epoch_step: [0, 1] diff --git a/configs/rec/smtr/focalsvtr_smtr.yml b/configs/rec/smtr/focalsvtr_smtr.yml index 2324767..190fb42 100644 --- a/configs/rec/smtr/focalsvtr_smtr.yml +++ b/configs/rec/smtr/focalsvtr_smtr.yml @@ -4,7 +4,7 @@ Global: log_smooth_window: 20 print_batch_step: 10 output_dir: ./output/rec/u14m_filter/focalsvtr_smtr_maxratio12 - save_epoch_step: 1 + save_epoch_step: [15, 1] # evaluation is run every 2000 iterations eval_batch_step: [0, 500] eval_epoch_step: [0, 1] diff --git a/configs/rec/smtr/focalsvtr_smtr_long.yml b/configs/rec/smtr/focalsvtr_smtr_long.yml index d0b03d1..c140f3b 100644 --- a/configs/rec/smtr/focalsvtr_smtr_long.yml +++ b/configs/rec/smtr/focalsvtr_smtr_long.yml @@ -4,7 +4,7 @@ Global: log_smooth_window: 20 print_batch_step: 10 output_dir: ./output/rec/u14m_filter/focalsvtr_smtr_long - save_epoch_step: 1 + save_epoch_step: [15, 1] # evaluation is run every 2000 iterations eval_batch_step: [0, 500] eval_epoch_step: [0, 1] diff --git a/configs/rec/smtr/readme.md b/configs/rec/smtr/readme.md index ddda487..52bbedd 100644 --- a/configs/rec/smtr/readme.md +++ b/configs/rec/smtr/readme.md @@ -13,8 +13,9 @@ Paper: -> [Out of Length Text Recognition with Sub-String Matching](https://arxiv.org/abs/2407.12317) -> Yongkun Du, Zhineng Chen\*, Caiyan Jia, Xieping Gao, Yu-Gang Jiang +> [Out of Length Text Recognition with Sub-String Matching](https://arxiv.org/abs/2407.12317). +> Yongkun Du, Zhineng Chen\*, Caiyan Jia, Xieping Gao, Yu-Gang Jiang. +> AAAI 2025 Scene Text Recognition (STR) methods have demonstrated robust performance in word-level text recognition. However, in applications the text image is sometimes long due to detected with multiple horizontal words. It triggers the requirement to build long text recognition models from readily available short word-level text datasets, which has been less studied previously. In this paper, we term this the Out of Length (OOL) text recognition. We establish the first Long Text Benchmark (LTB) to facilitate the assessment of different methods in long text recognition. 
Meanwhile, we propose a novel method called OOL Text Recognition with sub-String Matching (SMTR). SMTR comprises two cross-attention-based modules: one encodes a sub-string containing multiple characters into next and previous queries, and the other employs the queries to attend to the image features, matching the sub-string and simultaneously recognizing its next and previous character. SMTR can recognize text of arbitrary length by iterating the process above. To avoid being trapped in recognizing highly similar sub-strings, we introduce a regularization training to compel SMTR to effectively discover subtle differences between similar sub-strings for precise matching. In addition, we propose an inference augmentation to alleviate confusion caused by identical sub-strings and improve the overall recognition efficiency. Extensive experimental results reveal that SMTR, even when trained exclusively on short text, outperforms existing methods in public short text benchmarks and exhibits a clear advantage on LTB. @@ -23,31 +24,31 @@ The accuracy (%) and model files of SMTR on the public dataset of scene text rec - Syn: Synth dataset(MJ+ST) from [PARSeq](https://github.com/baudm/parseq) -- U14M: Union14M-L from [Union14M](https://github.com/Mountchicken/Union14M/) +- Union14M-L-LMDB-Filtered: A filtered version of [Union14M](https://github.com/Mountchicken/Union14M/) - Test on Long Text Benchmark ([Download LTB](https://drive.google.com/drive/folders/1NChdlw7ustbXtlFBmh_0xnHvRkffb9Ge?usp=sharing)): -| Model | Training Data | LTB | Config&Model&Log | -| :-------: | :-----------: | :--: | :---------------------------------------------------------------------------------------------: | -| SMTR | Syn | 39.6 | [link](https://drive.google.com/drive/folders/11SplakPPOFDMhPixv7ABNgjeTg4jKyfU?usp=sharing) | -| SMTR | U14M | 51.0 | [link](https://drive.google.com/drive/folders/1-K5O0d0q9fhY5fJvU6nn5fFFtSMnbE_-?usp=drive_link) | -| FocalSVTR | U14M | 42.1 | [link](https://drive.google.com/drive/folders/100xF5wFr7xSCVBYM1h_0d_8xv5Qeqobp?usp=sharing) | +| Model | Training Data | LTB | Config&Model&Log | +| :-------: | :----------------------: | :--: | :---------------------------------------------------------------------------------------------: | +| SMTR | Syn | 39.6 | [link](https://drive.google.com/drive/folders/11SplakPPOFDMhPixv7ABNgjeTg4jKyfU?usp=sharing) | +| SMTR | Union14M-L-LMDB-Filtered | 51.0 | [link](https://drive.google.com/drive/folders/1-K5O0d0q9fhY5fJvU6nn5fFFtSMnbE_-?usp=drive_link) | +| FocalSVTR | Union14M-L-LMDB-Filtered | 42.1 | [link](https://drive.google.com/drive/folders/100xF5wFr7xSCVBYM1h_0d_8xv5Qeqobp?usp=sharing) | - Test on Common Benchmarks from [PARSeq](https://github.com/baudm/parseq): -| Model | Training Data | IC13
857 | SVT | IIIT5k
3000 | IC15
1811 | SVTP | CUTE80 | Avg | Config&Model&Log | -| :-------: | :-----------: | :----------: | :--: | :-------------: | :-----------: | :--: | :----: | :---: | :---------------------: | -| SMTR | Syn | 97.4 | 94.9 | 97.4 | 88.4 | 89.9 | 96.2 | 94.02 | Same as the above table | -| SMTR | U14M | 98.3 | 97.4 | 99.0 | 90.1 | 92.7 | 97.9 | 95.90 | Same as the above table | -| FocalSVTR | U14M | 97.3 | 96.3 | 98.2 | 87.4 | 88.4 | 96.2 | 93.97 | Same as the above table | +| Model | Training Data | IC13
857 | SVT | IIIT5k
3000 | IC15
1811 | SVTP | CUTE80 | Avg | Config&Model&Log | +| :-------: | :----------------------: | :----------: | :--: | :-------------: | :-----------: | :--: | :----: | :---: | :---------------------: | +| SMTR | Syn | 97.4 | 94.9 | 97.4 | 88.4 | 89.9 | 96.2 | 94.02 | Same as the above table | +| SMTR | Union14M-L-LMDB-Filtered | 98.3 | 97.4 | 99.0 | 90.1 | 92.7 | 97.9 | 95.90 | Same as the above table | +| FocalSVTR | Union14M-L-LMDB-Filtered | 97.3 | 96.3 | 98.2 | 87.4 | 88.4 | 96.2 | 93.97 | Same as the above table | - Test on Union14M-L benchmark from [Union14M](https://github.com/Mountchicken/Union14M/). -| Model | Traing Data | Curve | Multi-
Oriented | Artistic | Contextless | Salient | Multi-
word | General | Avg | Config&Model&Log | -| :-------: | :---------: | :---: | :-----------------: | :------: | :---------: | :-----: | :-------------: | :-----: | :---: | :---------------------: | -| SMTR | Syn | 74.2 | 30.6 | 58.5 | 67.6 | 79.6 | 75.1 | 67.9 | 64.79 | Same as the above table | -| SMTR | U14M | 89.1 | 87.7 | 76.8 | 83.9 | 84.6 | 89.3 | 83.7 | 85.00 | Same as the above table | -| FocalSVTR | U14M | 77.7 | 62.4 | 65.7 | 78.6 | 71.6 | 81.3 | 79.2 | 73.80 | Same as the above table | +| Model | Training Data | Curve | Multi-
Oriented | Artistic | Contextless | Salient | Multi-
word | General | Avg | Config&Model&Log | +| :-------: | :----------------------: | :---: | :-----------------: | :------: | :---------: | :-----: | :-------------: | :-----: | :---: | :---------------------: | +| SMTR | Syn | 74.2 | 30.6 | 58.5 | 67.6 | 79.6 | 75.1 | 67.9 | 64.79 | Same as the above table | +| SMTR | Union14M-L-LMDB-Filtered | 89.1 | 87.7 | 76.8 | 83.9 | 84.6 | 89.3 | 83.7 | 85.00 | Same as the above table | +| FocalSVTR | Union14M-L-LMDB-Filtered | 77.7 | 62.4 | 65.7 | 78.6 | 71.6 | 81.3 | 79.2 | 73.80 | Same as the above table | - Training and test on Chinese dataset, from [Chinese Benckmark](https://github.com/FudanVI/benchmarking-chinese-text-recognition). @@ -79,7 +80,7 @@ pip install -r requirements.txt - [English dataset download](https://github.com/baudm/parseq) -- [Union14M-L download](https://github.com/Mountchicken/Union14M) +- [Union14M-L-LMDB-Filtered download](https://drive.google.com/drive/folders/1OlDWJZgvd6s4S09S3IGeAI90jI0i7AB_?usp=sharing) - [Chinese dataset download](https://github.com/fudanvi/benchmarking-chinese-text-recognition#download) @@ -135,7 +136,7 @@ u14m # lmdb format ├── multi_words └── salient ltb # download link: https://drive.google.com/drive/folders/1NChdlw7ustbXtlFBmh_0xnHvRkffb9Ge?usp=sharing -Union14M-LMDB-L # lmdb format +Union14M-L-LMDB-Filtered # lmdb format ├── train_challenging ├── train_easy ├── train_hard @@ -160,7 +161,7 @@ Evaluation: ```shell # en -python tools/eval_rec_all_ratio.py --c configs/rec/smtr/focalsvtr_smtr.yml +python tools/eval_rec_all_en.py --c configs/rec/smtr/focalsvtr_smtr.yml # long text python tools/eval_rec_all_long_simple.py --c configs/rec/smtr/focalsvtr_smtr_long.yml # ch @@ -169,6 +170,8 @@ python tools/eval_rec_all_ch.py --c configs/rec/smtr/focalsvtr_smtr_ch.yml ## Citation +If you find our method useful for your reserach, please cite: + ```bibtex @article{Du2024SMTR, title = {Out of Length Text Recognition with Sub-String Matching}, diff --git a/configs/rec/smtr/svtrv2_smtr.yml b/configs/rec/smtr/svtrv2_smtr.yml index e0c57ce..013fda0 100644 --- a/configs/rec/smtr/svtrv2_smtr.yml +++ b/configs/rec/smtr/svtrv2_smtr.yml @@ -4,7 +4,7 @@ Global: log_smooth_window: 20 print_batch_step: 10 output_dir: ./output/rec/u14m_filter/svtrv2_lnconv_smtr_maxratio12 - save_epoch_step: 1 + save_epoch_step: [15, 1] # evaluation is run every 2000 iterations eval_batch_step: [0, 500] eval_epoch_step: [0, 1] diff --git a/configs/rec/smtr/svtrv2_smtr_bi.yml b/configs/rec/smtr/svtrv2_smtr_bi.yml index 25bc7d3..54c070b 100644 --- a/configs/rec/smtr/svtrv2_smtr_bi.yml +++ b/configs/rec/smtr/svtrv2_smtr_bi.yml @@ -4,7 +4,7 @@ Global: log_smooth_window: 20 print_batch_step: 10 output_dir: ./output/rec/u14m_filter/svtrv2_lnconv_smtr_bi - save_epoch_step: 1 + save_epoch_step: [15, 1] # evaluation is run every 2000 iterations eval_batch_step: [0, 500] eval_epoch_step: [0, 1] diff --git a/configs/rec/svtrv2/SVTRv2.pdf b/configs/rec/svtrv2/SVTRv2.pdf deleted file mode 100644 index 0be5128..0000000 Binary files a/configs/rec/svtrv2/SVTRv2.pdf and /dev/null differ diff --git a/configs/rec/svtrv2/readme.md b/configs/rec/svtrv2/readme.md index aade0b5..31abbf6 100644 --- a/configs/rec/svtrv2/readme.md +++ b/configs/rec/svtrv2/readme.md @@ -18,7 +18,7 @@ Paper: -> [SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition](./SVTRv2.pdf) +> [SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition](https://arxiv.org/abs/2411.15858) > Yongkun Du, Zhineng Chen\*, Hongtao Xie, Caiyan Jia, Yu-Gang Jiang @@ -30,11 +30,11 
@@ The accuracy (%) and model files of SVTRv2 on the public dataset of scene text r Download all Configs, Models, and Logs from [Google Drive](https://drive.google.com/drive/folders/1i2EZVT-oxfDIDdhwQRm9E6Fk8s6qD3C1?usp=sharing). -| Model | Model size | FPS | -| :------: | :--------: | :-: | -| SVTRv2-T | 5.13 | 5.0 | -| SVTRv2-S | 11.25 | 5.3 | -| SVTRv2-B | 19.76 | 7.0 | +| Model | Model size | Latency | +| :------: | :--------: | :-----: | +| SVTRv2-T | 5.13 | 5.0 | +| SVTRv2-S | 11.25 | 5.3 | +| SVTRv2-B | 19.76 | 7.0 | - Test on Common Benchmarks from [PARSeq](https://github.com/baudm/parseq): @@ -102,10 +102,10 @@ Referring to [Downloading Datasets](../../../docs/svtrv2.md#downloading-datasets CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 tools/train_rec.py --c configs/rec/svtrv2/svtrv2_rctc.yml # Second stage -CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 tools/train_rec.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml --o Global.pretrained_model=./output/rec/u14m_filter/svtrv2_rctc/best.pth +CUDA_VISIBLE_DEVICES=4,5,6,7 python -m torch.distributed.launch --master_port=23332 --nproc_per_node=4 tools/train_rec.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml --o Global.pretrained_model=./output/rec/u14m_filter/svtrv2_rctc/best.pth # For Multi RTX 4090 -NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 tools/train_rec.py --c configs/rec/svtrv2/svtrv2_rctc.yml +NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --master_port=23333 --nproc_per_node=4 tools/train_rec.py --c configs/rec/svtrv2/svtrv2_rctc.yml # 20epoch runs for about 6 hours ``` @@ -113,9 +113,10 @@ NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.laun ```shell # short text: Common, Union14M-Benchmark, OST -python tools/eval_rec_all_ratio.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml -# long text -python tools/eval_rec_all_long_simple.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml +python tools/eval_rec_all_en.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_infer.yml + +# long text: LTB +python tools/eval_rec_all_long.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_infer.yml --o Eval.loader.max_ratio=20 ``` After a successful run, the results are saved in a csv file in `output_dir` in the config file. @@ -123,7 +124,7 @@ After a successful run, the results are saved in a csv file in `output_dir` in t ### Inference ```shell -python tools/infer_rec.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml --o Global.infer_img=/path/img_fold or /path/img_file +python tools/infer_rec.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_infer.yml --o Global.infer_img=/path/img_fold or /path/img_file ``` ### Latency Measurement @@ -131,15 +132,22 @@ python tools/infer_rec.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml --o Gl Firstly, downloading the IIIT5K images from [Google Drive](https://drive.google.com/drive/folders/1Po1LSBQb87DxGJuAgLNxhsJ-pdXxpIfS?usp=drive_link). 
Then, running the following command: ```shell -python tools/infer_rec.py --c configs/rec/SVTRv2/svtrv2_rctc.yml --o Global.infer_img=../iiit5k_test_image +python tools/infer_rec.py --c configs/rec/SVTRv2/svtrv2_smtr_gtc_rctc_infer.yml --o Global.infer_img=../iiit5k_test_image ``` ## Citation +If you find our method useful for your reserach, please cite: + ```bibtex -@article{Du2024SVTRv4, - title = {SVTRv2: Scene Text Recognition with a Single Visual Model}, - author = {Yongkun Du, Zhineng Chen\*, Hongtao Xie, Caiyan Jia, Yu-Gang Jiang}, - year = {2024} +@article{Du2024SVTRv2, + title={SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition}, + author={Yongkun Du and Zhineng Chen and Hongtao Xie and Caiyan Jia and Yu-Gang Jiang}, + journal={CoRR}, + volume={abs/2411.15858}, + eprinttype={arXiv}, + year={2024}, + primaryClass={cs.CV}, + url={https://arxiv.org/abs/2411.15858} } ``` diff --git a/configs/rec/svtrv2/repsvtr_ch.yml b/configs/rec/svtrv2/repsvtr_ch.yml index 033d6e7..41a1992 100644 --- a/configs/rec/svtrv2/repsvtr_ch.yml +++ b/configs/rec/svtrv2/repsvtr_ch.yml @@ -1,13 +1,13 @@ Global: device: gpu - epoch_num: 20 + epoch_num: 100 log_smooth_window: 20 print_batch_step: 10 - output_dir: ./output/rec/repsvtr_ch/ - save_epoch_step: 1 + output_dir: ./output/rec/ch/repsvtr_ch/ + save_epoch_step: [150, 10] # evaluation is run every 2000 iterations eval_epoch_step: [0, 1] - eval_batch_step: [0, 500] + eval_batch_step: [0, 2000] cal_metric_during_train: True pretrained_model: ./openocr_repsvtr_ch.pth checkpoints: @@ -22,19 +22,18 @@ Global: project_name: resvtr_ctc_nosgm_ds Optimizer: - name: AdamW - lr: 0.00065 # for 4gpus bs256/gpu - weight_decay: 0.05 - filter_bias_and_bn: True + name: Adam + lr: 0.0001 + weight_decay: 3.0e-05 + filter_bias_and_bn: False LRScheduler: - name: OneCycleLR - warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep - cycle_momentum: False + name: CosineAnnealingLR + warmup_epoch: 5 Architecture: model_type: rec - algorithm: SVTRv2 + algorithm: SVTRv2_mobile Transform: Encoder: name: RepSVTREncoder @@ -53,6 +52,7 @@ Loss: PostProcess: name: CTCLabelDecode + character_dict_path: *character_dict_path Metric: name: RecMetric @@ -62,16 +62,10 @@ Metric: Train: dataset: - name: RatioDataSetTVResize - ds_width: True - padding: False - base_shape: [[32, 32], [64, 32], [96, 32], [128, 32]] - data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging', - '../Union14M-L-LMDB-Filtered/filter_train_hard', - '../Union14M-L-LMDB-Filtered/filter_train_medium', - '../Union14M-L-LMDB-Filtered/filter_train_normal', - '../Union14M-L-LMDB-Filtered/filter_train_easy', - ] + name: SimpleDataSet + data_dir: ../ic15_data/ + label_file_list: + - ../ic15_data/rec_gt_train.txt transforms: - DecodeImagePIL: # load image img_mode: RGB @@ -80,27 +74,23 @@ Train: character_dict_path: *character_dict_path use_space_char: *use_space_char max_text_length: *max_text_length + - RecTVResize: + image_shape: [48, 320] + padding: True - KeepKeys: keep_keys: ['image', 'label', 'length'] - sampler: - name: RatioSampler - scales: [[128, 32]] # w, h - # divide_factor: to ensure the width and height dimensions can be devided by downsampling multiple - first_bs: &bs 256 - fix_bs: false - divided_factor: [4, 16] # w, h - is_training: True loader: shuffle: True - batch_size_per_card: *bs + batch_size_per_card: 256 drop_last: True - max_ratio: 4 num_workers: 4 Eval: dataset: - name: LMDBDataSet - data_dir: ../evaluation + name: SimpleDataSet + data_dir: ../ic15_data/ + label_file_list: + - 
../ic15_data/rec_gt_test.txt transforms: - DecodeImage: # load image img_mode: BGR @@ -111,6 +101,9 @@ Eval: - RecDynamicResize: image_shape: [48, 320] padding: False + # - SVTRResize: + # image_shape: [3, 48, 320] + # padding: True - KeepKeys: keep_keys: ['image', 'label', 'length'] loader: diff --git a/configs/rec/svtrv2/svtrv2_ch.yml b/configs/rec/svtrv2/svtrv2_ch.yml index a6538df..a3b346e 100644 --- a/configs/rec/svtrv2/svtrv2_ch.yml +++ b/configs/rec/svtrv2/svtrv2_ch.yml @@ -1,13 +1,13 @@ Global: device: gpu - epoch_num: 20 + epoch_num: 100 log_smooth_window: 20 print_batch_step: 10 - output_dir: ./output/rec/u14m_filter/svtrv2_ctc_u14m_two33_tvresize/ - save_epoch_step: 1 + output_dir: ./output/rec/ch/svtrv2_ctc_ch/ + save_epoch_step: [150, 10] # evaluation is run every 2000 iterations eval_epoch_step: [0, 1] - eval_batch_step: [0, 500] + eval_batch_step: [0, 2000] cal_metric_during_train: True pretrained_model: ./openocr_svtrv2_ch.pth checkpoints: @@ -19,22 +19,21 @@ Global: use_space_char: &use_space_char True save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_ctc.txt use_amp: True - project_name: svtrv2_ctc_nosgm_ds + project_name: svtrv2_ctc_ch Optimizer: name: AdamW - lr: 0.00065 # for 4gpus bs256/gpu + lr: 0.0001 # for 4gpus bs256/gpu weight_decay: 0.05 filter_bias_and_bn: True LRScheduler: - name: OneCycleLR - warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep - cycle_momentum: False + name: CosineAnnealingLR + warmup_epoch: 5 Architecture: model_type: rec - algorithm: SVTRv2 + algorithm: SVTRv2_server Transform: Encoder: name: SVTRv2LNConvTwo33 @@ -65,6 +64,7 @@ Loss: PostProcess: name: CTCLabelDecode + character_dict_path: *character_dict_path Metric: name: RecMetric @@ -74,16 +74,10 @@ Metric: Train: dataset: - name: RatioDataSetTVResize - ds_width: True - padding: False - base_shape: [[32, 32], [64, 32], [96, 32], [128, 32]] - data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging', - '../Union14M-L-LMDB-Filtered/filter_train_hard', - '../Union14M-L-LMDB-Filtered/filter_train_medium', - '../Union14M-L-LMDB-Filtered/filter_train_normal', - '../Union14M-L-LMDB-Filtered/filter_train_easy', - ] + name: SimpleDataSet + data_dir: ../ic15_data/ + label_file_list: + - ../ic15_data/rec_gt_train.txt transforms: - DecodeImagePIL: # load image img_mode: RGB @@ -92,27 +86,23 @@ Train: character_dict_path: *character_dict_path use_space_char: *use_space_char max_text_length: *max_text_length + - RecTVResize: + image_shape: [48, 320] + padding: True - KeepKeys: keep_keys: ['image', 'label', 'length'] - sampler: - name: RatioSampler - scales: [[128, 32]] # w, h - # divide_factor: to ensure the width and height dimensions can be devided by downsampling multiple - first_bs: &bs 256 - fix_bs: false - divided_factor: [4, 16] # w, h - is_training: True loader: shuffle: True - batch_size_per_card: *bs + batch_size_per_card: 256 drop_last: True - max_ratio: 4 num_workers: 4 Eval: dataset: - name: LMDBDataSet - data_dir: ../evaluation + name: SimpleDataSet + data_dir: ../ic15_data/ + label_file_list: + - ../ic15_data/rec_gt_test.txt transforms: - DecodeImage: # load image img_mode: BGR @@ -123,10 +113,13 @@ Eval: - RecDynamicResize: image_shape: [48, 320] padding: False + # - RecTVResize: + # image_shape: [48, 320] + # padding: True - KeepKeys: keep_keys: ['image', 'label', 'length'] loader: shuffle: False drop_last: False - batch_size_per_card: 256 + batch_size_per_card: 1 num_workers: 4 diff --git a/configs/rec/svtrv2/svtrv2_ctc.yml b/configs/rec/svtrv2/svtrv2_ctc.yml index 
da7b413..d446612 100644 --- a/configs/rec/svtrv2/svtrv2_ctc.yml +++ b/configs/rec/svtrv2/svtrv2_ctc.yml @@ -4,7 +4,7 @@ Global: log_smooth_window: 20 print_batch_step: 10 output_dir: ./output/rec/u14m_filter/svtrv2_ctc/ - save_epoch_step: 1 + save_epoch_step: [15, 1] # evaluation is run every 2000 iterations eval_epoch_step: [0, 1] eval_batch_step: [0, 500] diff --git a/configs/rec/svtrv2/svtrv2_rctc.yml b/configs/rec/svtrv2/svtrv2_rctc.yml index f5d9ff3..5ac9ae2 100644 --- a/configs/rec/svtrv2/svtrv2_rctc.yml +++ b/configs/rec/svtrv2/svtrv2_rctc.yml @@ -4,7 +4,7 @@ Global: log_smooth_window: 20 print_batch_step: 10 output_dir: ./output/rec/u14m_filter/svtrv2_rctc/ - save_epoch_step: 1 + save_epoch_step: [15, 1] # evaluation is run every 2000 iterations eval_epoch_step: [0, 1] eval_batch_step: [0, 500] diff --git a/configs/rec/svtrv2/svtrv2_small_rctc.yml b/configs/rec/svtrv2/svtrv2_small_rctc.yml index 95329cd..b36c3d6 100644 --- a/configs/rec/svtrv2/svtrv2_small_rctc.yml +++ b/configs/rec/svtrv2/svtrv2_small_rctc.yml @@ -4,7 +4,7 @@ Global: log_smooth_window: 20 print_batch_step: 10 output_dir: ./output/rec/u14m_filter/svtrv2_small_rctc/ - save_epoch_step: 1 + save_epoch_step: [15, 1] # evaluation is run every 2000 iterations eval_epoch_step: [0, 1] eval_batch_step: [0, 500] diff --git a/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml b/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml index aa36492..4e61370 100644 --- a/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml +++ b/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml @@ -3,8 +3,8 @@ Global: epoch_num: 20 log_smooth_window: 20 print_batch_step: 10 - output_dir: ./output/rec/u14m_filter/svtrv2_smtr_gtc_rctc_maxratio12 - save_epoch_step: 1 + output_dir: ./output/rec/u14m_filter/svtrv2_smtr_gtc_rctc + save_epoch_step: [15, 1] # evaluation is run every 2000 iterations eval_batch_step: [0, 500] eval_epoch_step: [0, 1] diff --git a/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_ch.yml b/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_ch.yml index 8d3c3d7..f296dfc 100644 --- a/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_ch.yml +++ b/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_ch.yml @@ -4,7 +4,7 @@ Global: log_smooth_window: 20 print_batch_step: 10 output_dir: ./output/rec/ch/svtrv2_smtr_gtc_rctc_ch - save_epoch_step: 1 + save_epoch_step: [15, 1] # evaluation is run every 2000 iterations eval_batch_step: [0, 2000] eval_epoch_step: [0, 1] diff --git a/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_infer.yml b/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_infer.yml new file mode 100644 index 0000000..5cd0499 --- /dev/null +++ b/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_infer.yml @@ -0,0 +1,151 @@ +Global: + device: gpu + epoch_num: 20 + log_smooth_window: 20 + print_batch_step: 10 + output_dir: ./output/rec/u14m_filter/svtrv2_smtr_gtc_rctc + save_epoch_step: [15, 1] + # evaluation is run every 2000 iterations + eval_batch_step: [0, 500] + eval_epoch_step: [0, 1] + cal_metric_during_train: True + pretrained_model: + # ./output/rec/u14m_filter/svtrv2_rctc/best.pth + checkpoints: + use_tensorboard: false + infer_img: + # for data or label process + character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en + # ./tools/utils/ppocr_keys_v1.txt # ch + max_text_length: &max_text_length 25 + use_space_char: &use_space_char False + save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_smtr_gtc_rctc.txt + use_amp: True + +Optimizer: + name: AdamW + lr: 0.000325 # for 4gpus bs128/gpu + weight_decay: 0.05 + filter_bias_and_bn: True + +LRScheduler: + name: OneCycleLR + warmup_epoch: 1.5 
# pct_start 0.075*20 = 1.5ep + cycle_momentum: False + +Architecture: + model_type: rec + algorithm: SVTRv2 + in_channels: 3 + Transform: + Encoder: + name: SVTRv2LNConvTwo33 + use_pos_embed: False + dims: [128, 256, 384] + depths: [6, 6, 6] + num_heads: [4, 8, 12] + mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']] + local_k: [[5, 5], [5, 5], [-1, -1]] + sub_k: [[1, 1], [2, 1], [-1, -1]] + last_stage: false + feat2d: True + Decoder: + name: GTCDecoder + infer_gtc: False + detach: False + gtc_decoder: + name: SMTRDecoder + num_layer: 1 + ds: True + max_len: *max_text_length + next_mode: &next True + sub_str_len: &subsl 5 + ctc_decoder: + name: RCTCDecoder + +Loss: + name: CTCLoss + zero_infinity: True + +PostProcess: + name: CTCLabelDecode + character_dict_path: *character_dict_path + use_space_char: *use_space_char + +Metric: + name: RecMetric + main_indicator: acc + is_filter: True + +Train: + dataset: + name: RatioDataSetTVResize + ds_width: True + padding: false + data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging', + '../Union14M-L-LMDB-Filtered/filter_train_hard', + '../Union14M-L-LMDB-Filtered/filter_train_medium', + '../Union14M-L-LMDB-Filtered/filter_train_normal', + '../Union14M-L-LMDB-Filtered/filter_train_easy', + ] + transforms: + - DecodeImagePIL: # load image + img_mode: RGB + - PARSeqAugPIL: + - CTCLabelEncode: # Class handling label + character_dict_path: *character_dict_path + use_space_char: *use_space_char + max_text_length: *max_text_length + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + sampler: + name: RatioSampler + scales: [[128, 32]] # w, h + # divide_factor: to ensure the width and height dimensions can be devided by downsampling multiple + first_bs: &bs 128 + fix_bs: false + divided_factor: [4, 16] # w, h + is_training: True + loader: + shuffle: True + batch_size_per_card: *bs + drop_last: True + max_ratio: &max_ratio 12 + num_workers: 4 + +Eval: + dataset: + name: RatioDataSetTVResize + ds_width: True + padding: False + data_dir_list: [ + '../evaluation/CUTE80', + '../evaluation/IC13_857', + '../evaluation/IC15_1811', + '../evaluation/IIIT5k', + '../evaluation/SVT', + '../evaluation/SVTP', + ] + transforms: + - DecodeImagePIL: # load image + img_mode: RGB + - CTCLabelEncode: # Class handling label + character_dict_path: *character_dict_path + use_space_char: *use_space_char + max_text_length: *max_text_length + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + sampler: + name: RatioSampler + scales: [[128, 32]] # w, h + # divide_factor: to ensure the width and height dimensions can be devided by downsampling multiple + first_bs: *bs + fix_bs: false + divided_factor: [4, 16] # w, h + is_training: False + loader: + shuffle: False + drop_last: False + batch_size_per_card: *bs + max_ratio: *max_ratio + num_workers: 4 diff --git a/configs/rec/svtrv2/svtrv2_tiny_rctc.yml b/configs/rec/svtrv2/svtrv2_tiny_rctc.yml index 121163e..6e98534 100644 --- a/configs/rec/svtrv2/svtrv2_tiny_rctc.yml +++ b/configs/rec/svtrv2/svtrv2_tiny_rctc.yml @@ -4,7 +4,7 @@ Global: log_smooth_window: 20 print_batch_step: 10 output_dir: ./output/rec/u14m_filter/svtrv2_tiny_rctc/ - save_epoch_step: 1 + save_epoch_step: [15, 1] # evaluation is run every 2000 iterations eval_epoch_step: [0, 1] eval_batch_step: [0, 500] diff --git a/demo_gradio.py b/demo_gradio.py new 
file mode 100644 index 0000000..17fe8c2 --- /dev/null +++ b/demo_gradio.py @@ -0,0 +1,207 @@ +# @Author: OpenOCR +# @Contact: 784990967@qq.com +import os +import gradio as gr # gradio==4.20.0 + +os.environ['FLAGS_allocator_strategy'] = 'auto_growth' +import cv2 +import numpy as np +import json +import time +from PIL import Image +from tools.infer_e2e import OpenOCR, check_and_download_font, draw_ocr_box_txt + + +def initialize_ocr(model_type, drop_score): + return OpenOCR(mode=model_type, drop_score=drop_score) + + +# Default model type +model_type = 'mobile' +drop_score = 0.4 +text_sys = initialize_ocr(model_type, drop_score) + +# warm up 5 times +if True: + img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8) + for i in range(5): + res = text_sys(img_numpy=img) + +font_path = './simfang.ttf' +font_path = check_and_download_font(font_path) + + +def main(input_image, + model_type_select, + det_input_size_textbox=960, + rec_drop_score=0.4, + mask_thresh=0.3, + box_thresh=0.6, + unclip_ratio=1.5, + det_score_mode='slow'): + global text_sys, model_type + + # Update OCR model if the model type changes + if model_type_select != model_type: + model_type = model_type_select + text_sys = initialize_ocr(model_type, rec_drop_score) + + img = input_image[:, :, ::-1] + starttime = time.time() + results, time_dict, mask = text_sys( + img_numpy=img, + return_mask=True, + det_input_size=int(det_input_size_textbox), + thresh=mask_thresh, + box_thresh=box_thresh, + unclip_ratio=unclip_ratio, + score_mode=det_score_mode) + elapse = time.time() - starttime + save_pred = json.dumps(results[0], ensure_ascii=False) + image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) + boxes = [res['points'] for res in results[0]] + txts = [res['transcription'] for res in results[0]] + scores = [res['score'] for res in results[0]] + draw_img = draw_ocr_box_txt( + image, + boxes, + txts, + scores, + drop_score=rec_drop_score, + font_path=font_path, + ) + mask = mask[0, 0, :, :] > mask_thresh + return save_pred, elapse, draw_img, mask.astype('uint8') * 255 + + +def get_all_file_names_including_subdirs(dir_path): + all_file_names = [] + + for root, dirs, files in os.walk(dir_path): + for file_name in files: + all_file_names.append(os.path.join(root, file_name)) + + file_names_only = [os.path.basename(file) for file in all_file_names] + return file_names_only + + +def list_image_paths(directory): + image_extensions = ('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff') + + image_paths = [] + + for root, dirs, files in os.walk(directory): + for file in files: + if file.lower().endswith(image_extensions): + relative_path = os.path.relpath(os.path.join(root, file), + directory) + full_path = os.path.join(directory, relative_path) + image_paths.append(full_path) + image_paths = sorted(image_paths) + return image_paths + + +def find_file_in_current_dir_and_subdirs(file_name): + for root, dirs, files in os.walk('.'): + if file_name in files: + relative_path = os.path.join(root, file_name) + return relative_path + + +e2e_img_example = list_image_paths('./OCR_e2e_img') + +if __name__ == '__main__': + css = '.image-container img { width: 100%; max-height: 320px;}' + + with gr.Blocks(css=css) as demo: + gr.HTML(""" +

OpenOCR

+

准确高效的通用 OCR 系统 (由FVL实验室 OCR Team 创建) [本地快速部署]

""" + ) + with gr.Row(): + with gr.Column(scale=1): + input_image = gr.Image(label='Input image', + elem_classes=['image-container']) + + examples = gr.Examples(examples=e2e_img_example, + inputs=input_image, + label='Examples') + downstream = gr.Button('Run') + + # 添加参数调节组件 + with gr.Column(): + with gr.Row(): + det_input_size_textbox = gr.Number( + label='Detection Input Size', + value=960, + info='检测网络输入尺寸的最长边,默认为960。') + det_score_mode_dropdown = gr.Dropdown( + ['slow', 'fast'], + value='slow', + label='Detection Score Mode', + info='文本框的置信度计算模式,默认为 slow。slow 模式计算速度较慢,但准确度较高。fast 模式计算速度较快,但准确度较低。' + ) + with gr.Row(): + rec_drop_score_slider = gr.Slider( + 0.0, + 1.0, + value=0.4, + step=0.01, + label='Recognition Drop Score', + info='识别置信度阈值,默认值为0.4。低于该阈值的识别结果和对应的文本框被丢弃。') + mask_thresh_slider = gr.Slider( + 0.0, + 1.0, + value=0.3, + step=0.01, + label='Mask Threshold', + info='Mask 阈值,用于二值化 mask,默认值为0.3。如果存在文本截断时,请调低该值。') + with gr.Row(): + box_thresh_slider = gr.Slider( + 0.0, + 1.0, + value=0.6, + step=0.01, + label='Box Threshold', + info='文本框置信度阈值,默认值为0.6。如果存在文本被漏检时,请调低该值。') + unclip_ratio_slider = gr.Slider( + 1.5, + 2.0, + value=1.5, + step=0.05, + label='Unclip Ratio', + info='文本框解析时的膨胀系数,默认值为1.5。值越大文本框越大。') + + # 模型选择组件 + model_type_dropdown = gr.Dropdown( + ['mobile', 'server'], + value='mobile', + label='Model Type', + info='选择 OCR 模型类型:高效率模型mobile,高精度模型server。') + + with gr.Column(scale=1): + img_mask = gr.Image(label='mask', + interactive=False, + elem_classes=['image-container']) + img_output = gr.Image(label=' ', + interactive=False, + elem_classes=['image-container']) + + output = gr.Textbox(label='Result') + confidence = gr.Textbox(label='Latency') + + downstream.click(fn=main, + inputs=[ + input_image, model_type_dropdown, + det_input_size_textbox, rec_drop_score_slider, + mask_thresh_slider, box_thresh_slider, + unclip_ratio_slider, det_score_mode_dropdown + ], + outputs=[ + output, + confidence, + img_output, + img_mask, + ]) + + demo.launch(share=True) diff --git a/demo_rec.py b/demo_rec.py deleted file mode 100644 index a1c604a..0000000 --- a/demo_rec.py +++ /dev/null @@ -1,131 +0,0 @@ -import os - -import gradio as gr -import numpy as np -import torch - -from openrec.modeling import build_model -from openrec.postprocess import build_post_process -from openrec.preprocess import create_operators, transform -from tools.engine import Config -from tools.utils.ckpt import load_ckpt - - -def build_rec_process(cfg): - transforms = [] - for op in cfg['Eval']['dataset']['transforms']: - op_name = list(op)[0] - if 'Label' in op_name: - continue - # TODO - elif op_name in ['DecodeImage']: - op[op_name]['gradio_infer_mode'] = True - - elif op_name in ['RecResizeImg']: - op[op_name]['infer_mode'] = True - elif op_name == 'KeepKeys': - if cfg['Architecture']['algorithm'] == 'SRN': - op[op_name]['keep_keys'] = [ - 'image', - 'encoder_word_pos', - 'gsrm_word_pos', - 'gsrm_slf_attn_bias1', - 'gsrm_slf_attn_bias2', - ] - elif cfg['Architecture']['algorithm'] == 'SAR': - op[op_name]['keep_keys'] = ['image', 'valid_ratio'] - elif cfg['Architecture']['algorithm'] == 'RobustScanner': - op[op_name]['keep_keys'] = [ - 'image', 'valid_ratio', 'word_positons' - ] - else: - op[op_name]['keep_keys'] = ['image'] - transforms.append(op) - return transforms - - -def get_all_file_names_including_subdirs(dir_path): - all_file_names = [] - - for root, dirs, files in os.walk(dir_path): - for file_name in files: - all_file_names.append(os.path.join(root, file_name)) - - file_names_only = 
[os.path.basename(file) for file in all_file_names] - return file_names_only - - -root_directory = './configs/rec' -yml_Config = get_all_file_names_including_subdirs(root_directory) - - -def find_file_in_current_dir_and_subdirs(file_name): - for root, dirs, files in os.walk('.'): - if file_name in files: - relative_path = os.path.join(root, file_name) - return relative_path - - -def predict(input_image, Model_type, OCR_type): - - path = find_file_in_current_dir_and_subdirs(Model_type) - - cfg = Config(path).cfg - post_process_class = build_post_process(cfg['PostProcess']) - global_config = cfg['Global'] - char_num = len(getattr(post_process_class, 'character')) - cfg['Architecture']['Decoder']['out_channels'] = char_num - model = build_model(cfg['Architecture']) - load_ckpt(model, cfg) - model.eval() - - transforms = build_rec_process(cfg) - global_config['infer_mode'] = True - ops = create_operators(transforms, global_config) - data = {'image': input_image} - batch = transform(data, ops) - others = None - images = np.expand_dims(batch[0], axis=0) - images = torch.from_numpy(images) - with torch.no_grad(): - preds = model(images, others) - post_result = post_process_class(preds) - return post_result[0][0], post_result[0][1] - - -if __name__ == '__main__': - - with gr.Blocks() as demo: - with gr.Row(): - with gr.Column(scale=1): - input_image = gr.Image(label='Input Image') - - # TODO - OCR_type = gr.Radio(['STR', 'STD', 'E2E'], label='模型类别') - - Model_type = gr.Dropdown(choices=yml_Config, label='现有模型配置文件') - - downstream = gr.Button('识别结果') - - with gr.Column(scale=1): - - # TODO - img_output = gr.Image(label='图片识别结果') - - output = gr.Textbox(label='文字识别结果') - confidence = gr.Textbox(label='置信度') - - downstream.click( - fn=predict, - inputs=[ - input_image, - Model_type, - OCR_type, - ], - outputs=[ - output, - confidence, - # TODO img_output, - ]) - - demo.launch(debug=True) diff --git a/docs/finetune_det.md b/docs/finetune_det.md new file mode 100644 index 0000000..2cacffd --- /dev/null +++ b/docs/finetune_det.md @@ -0,0 +1,159 @@ +# Fine-tuning Text Detection Model of OpenOCR System + +1. [Data and Weights Preparation](#1-data-and-weights-preparation) + - [1.1 Data Preparation](#11-data-preparation) + - [1.2 Download Pre-trained Model](#12-download-pre-trained-model) +2. [Training](#2-training) + - [2.1 Start Training](#21-start-training) + - [2.2 Load Trained Model and Continue Training](#22-load-trained-model-and-continue-training) +3. [Evaluation and Test](#3-evaluation-and-test) + - [3.1 Evaluation](#31-evaluation) + - [3.2 Test](#32-test) +4. [ONNX Inference](#4-onnx-inference) + +______________________________________________________________________ + +## Installation + +#### Dependencies: + +- [PyTorch](http://pytorch.org/) version >= 1.13.0 +- Python version >= 3.7 + +```shell +conda create -n openocr python==3.8 +conda activate openocr +# install gpu version torch +conda install pytorch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 pytorch-cuda=11.8 -c pytorch -c nvidia +# or cpu version +conda install pytorch torchvision torchaudio cpuonly -c pytorch +``` + +#### Clone this repository: + +```shell +git clone https://github.com/Topdu/OpenOCR.git +cd OpenOCR +pip install albumentations +pip install -r requirements.txt +``` + +This section uses the icdar2015 dataset as an example to introduce the training, evaluation, and testing of the detection model in OpenOCR. + +## 1. 
Data and Weights Preparation + +### 1.1 Data Preparation + +**Note:** If you want to use your own dataset, please follow the format of the [icdar2015 dataset](https://aistudio.baidu.com/datasetdetail/46088). + +Download the dataset from the [icdar2015 dataset](https://aistudio.baidu.com/datasetdetail/46088) or [Google Drive](https://drive.google.com/file/d/1nfsYj-JzAqVouZPBDqmuP0Rkj6J6XFUJ/view?usp=sharing). + +#### File Directory + +``` +OpenOCR/ +icdar2015/text_localization/ + └─ icdar_c4_train_imgs/ Training data of the icdar dataset + └─ ch4_test_images/ Testing data of the icdar dataset + └─ train_icdar2015_label.txt Training annotations of the icdar dataset + └─ test_icdar2015_label.txt Testing annotations of the icdar dataset +``` + +The provided annotation file format is as follows, where the fields are separated by "\\t": + +``` +"Image file name json.dumps encoded image annotation information" +ch4_test_images/img_61.jpg [{"transcription": "MASA", "points": [[310, 104], [416, 141], [418, 216], [312, 179]], ...}] +``` + +Before being encoded with `json.dumps`, the image annotation information is a list containing multiple dictionaries. In each dictionary, the field `points` represents the coordinates (x, y) of the four corners of the text bounding box, arranged in a clockwise order starting from the top-left corner. The field `transcription` indicates the text content within the current bounding box. + +Modify the training and evaluation dataset paths in the configuration file `./configs/det/dbnet/repvit_db.yml` to your own dataset paths, for example: + +```yaml +Train: + dataset: + name: SimpleDataSet + data_dir: ../icdar2015/text_localization/ # Root directory of the training dataset + label_file_list: ["../icdar2015/text_localization/train_icdar2015_label.txt"] # Path to the training label file + ...... +Eval: + dataset: + name: SimpleDataSet + data_dir: ../icdar2015/text_localization/ # Root directory of the evaluation dataset + label_file_list: ["../icdar2015/text_localization/test_icdar2015_label.txt"] # Path to the evaluation label file +``` + +### 1.2 Download Pre-trained Model + +First download the pre-trained model. + +```bash +cd OpenOCR/ +wget https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_det_repvit_ch.pth +``` + +______________________________________________________________________ + +## 2. Training + +### 2.1 Start Training + +```bash +# multi-GPU training +CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 tools/train_det.py --c configs/det/dbnet/repvit_db.yml --o Global.pretrained_model=./openocr_det_repvit_ch.pth +# single GPU training +CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node=1 tools/train_det.py --c configs/det/dbnet/repvit_db.yml --o Global.pretrained_model=./openocr_det_repvit_ch.pth +``` + +### 2.2 Load Trained Model and Continue Training + +If you want to load a trained model and continue training, you can set the parameter `Global.checkpoints` to the path of the model to be loaded. + +For example: + +```bash +CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 tools/train_det.py --c configs/det/dbnet/repvit_db.yml --o Global.checkpoints=./your/trained/model +``` + +**Note**: The priority of `Global.checkpoints` is higher than that of `Global.pretrained_model`, that is, when both parameters are specified at the same time, the model specified by `Global.checkpoints` will be loaded first.
If the model path specified by `Global.checkpoints` is wrong, the one specified by `Global.pretrained_model` will be loaded. + +______________________________________________________________________ + +## 3. Evaluation and Test + +### 3.1 Evaluation + +OpenOCR calculates three indicators for evaluating performance of OCR detection task: Precision, Recall, and Hmean(F-Score). + +```bash +python tools/eval_det.py --c configs/det/dbnet/repvit_db.yml --o Global.pretrained_model="{path/to/weights}/best.pth" +``` + +### 3.2 Test + +Test the detection result on all images in the folder or a single image: + +```bash +python tools/infer_det.py --c ./configs/det/dbnet/repvit_db.yml --o Global.infer_img=/path/img_fold or /path/img_file Global.pretrained_model={path/to/weights}/best.pth +``` + +______________________________________________________________________ + +## 4. ONNX Inference + +Firstly, we can convert Detection model to onnx model: + +```bash +pip install onnx +python tools/toonnx.py --c ./configs/det/dbnet/repvit_db.yml --o Global.device=cpu Global.pretrained_model={path/to/weights}/best.pth +``` + +The onnx model is saved in `./output/det_repsvtr_db/export_det/det_model.onnx`. + +The detection onnx model inference: + +```bash +pip install onnxruntime +python tools/infer_det.py --c ./configs/det/dbnet/repvit_db.yml --o Global.backend=onnx Global.device=cpu Global.infer_img=/path/img_fold or /path/img_file Global.onnx_model_path=./output/det_repsvtr_db/export_det/det_model.onnx +``` diff --git a/docs/finetune_rec.md b/docs/finetune_rec.md new file mode 100644 index 0000000..27c162e --- /dev/null +++ b/docs/finetune_rec.md @@ -0,0 +1,150 @@ +# Fine-tuning Text Recognition Model of OpenOCR system + +1. [Data and Weights Preparation](#1-data-and-weights-preparation) + - [1.1 Data Preparation](#11-data-preparation) + - [1.2 Download Pre-trained Model](#12-download-pre-trained-model) +2. [Training](#2-training) + - [2.1 Start Training](#21-start-training) + - [2.2 Load Trained Model and Continue Training](#22-load-trained-model-and-continue-training) +3. [Evaluation and Test](#3-evaluation-and-test) + - [3.1 Evaluation](#31-evaluation) + - [3.2 Test](#32-test) +4. [ONNX Inference](#4-onnx-inference) + +## Installation + +#### Dependencies: + +- [PyTorch](http://pytorch.org/) version >= 1.13.0 +- Python version >= 3.7 + +```shell +conda create -n openocr python==3.8 +conda activate openocr +# install gpu version torch +conda install pytorch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 pytorch-cuda=11.8 -c pytorch -c nvidia +# or cpu version +conda install pytorch torchvision torchaudio cpuonly -c pytorch +``` + +#### Clone this repository: + +```shell +git clone https://github.com/Topdu/OpenOCR.git +cd OpenOCR +pip install -r requirements.txt +``` + +This section uses the icdar2015 recognition dataset as an example to introduce the training, evaluation, and testing of the recognition model in OpenOCR. + +## 1. Data and Weights Preparation + +### 1.1 Data Preparation + +**Note:** If you want to use your own dataset, please following the following data format. + +Downloading datasets from [icdar2015 recognition dataset](https://aistudio.baidu.com/datasetdetail/75418)/[Google Drive](https://drive.google.com/file/d/1YviGN_f7xrRrMOSR4OGwv7uhKFjnxuUP/view?usp=sharing). 
+ +#### File Directory + +``` +OpenOCR/ +ic15_data/ + └─ test/ Training data of the icdar dataset + └─ train/ Testing data of the icdar dataset + └─ rec_gt_test.txt Training annotations of the icdar dataset + └─ rec_gt_train.txt Testing annotations of the icdar dataset +``` + +The provided annotation file format is as follows, where the fields are separated by "\\t": + +``` +"Image file name label" +test/word_2077.png Underpass +``` + +To modify the training and evaluation dataset paths in the configuration file `./configs/rec/svtrv2/repsvtr_ch.yml` to your own dataset paths, for example: + +```yaml +Train: + dataset: + name: SimpleDataSet + data_dir: ../ic15_data/ # Root directory of the training dataset + label_file_list: ["../ic15_data/rec_gt_train.txt"] # Path to the training label file + ...... +Eval: + dataset: + name: SimpleDataSet + data_dir: ../ic15_data # Root directory of the evaluation dataset + label_file_list: ["../ic15_data/rec_gt_test.txt"] # Path to the evaluation label file +``` + +### 1.2 Download Pre-trained Model + +First download the pre-trained model. + +```bash +cd OpenOCR/ +wget https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_repsvtr_ch.pth +# Rec Server model +# wget https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_svtrv2_ch.pth +``` + +## 2. Training + +### 2.1 Start Training + +```bash +# multi-GPU training +CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 tools/train_rec.py --c ./configs/rec/svtrv2/repsvtr_ch.yml --o Global.pretrained_model=./openocr_repsvtr_ch.pth +# single GPU training +CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node=1 tools/train_rec.py --c ./configs/rec/svtrv2/repsvtr_ch.yml --o Global.pretrained_model=./openocr_repsvtr_ch.pth +``` + +### 2.2 Load Trained Model and Continue Training + +If you expect to load trained model and continue the training again, you can specify the parameter `Global.checkpoints` as the model path to be loaded. + +For example: + +```bash +CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 tools/train_rec.py --c ./configs/rec/svtrv2/repsvtr_ch.yml --o Global.checkpoints=./your/trained/model +``` + +**Note**: The priority of `Global.checkpoints` is higher than that of `Global.pretrained_model`, that is, when two parameters are specified at the same time, the model specified by `Global.checkpoints` will be loaded first. If the model path specified by `Global.checkpoints` is wrong, the one specified by `Global.pretrained_model` will be loaded. + +## 3. Evaluation and Test + +### 3.1 Evaluation + +OpenOCR calculates the word accuracy for evaluating performance of OCR recognition task. + +```bash +python tools/eval_rec.py --c ./configs/rec/svtrv2/repsvtr_ch.yml --o Global.pretrained_model="{path/to/weights}/best.pth" +``` + +### 3.2 Test + +Test the recognition result on all images in the folder or a single image: + +```bash +python tools/infer_rec.py --c ./configs/rec/svtrv2/repsvtr_ch.yml --o Global.infer_img=/path/img_fold or /path/img_file Global.pretrained_model={path/to/weights}/best.pth +``` + +## 4. ONNX Inference + +Firstly, we can convert recognition model to onnx model: + +```bash +pip install onnx +python tools/toonnx.py --c ./configs/rec/svtrv2/repsvtr_ch.yml --o Global.device=cpu Global.pretrained_model={path/to/weights}/best.pth +``` + +The onnx model is saved in `./output/rec/repsvtr_ch/export_rec/rec_model.onnx`. 
+ +The recognition onnx model infernce: + +```bash +pip install onnxruntime +python tools/infer_rec.py --c ./configs/rec/svtrv2/repsvtr_ch.yml --o Global.backend=onnx Global.device=cpu Global.infer_img=/path/img_fold or /path/img_file Global.onnx_model_path=./output/rec/repsvtr_ch/export_rec/rec_model.onnx +``` diff --git a/docs/openocr.md b/docs/openocr.md index 4129747..137198b 100644 --- a/docs/openocr.md +++ b/docs/openocr.md @@ -1,29 +1,58 @@ -# OpenOCR: A general OCR system for accuracy and efficiency +# OpenOCR: A general OCR system with accuracy and efficiency -We proposed strategies to comprehensively enhance CTC-based STR models and developed a novel CTC-based method, [SVTRv2](../configs/rec/svtrv2/). SVTRv2 can outperform previous attention-based STR methods in terms of accuracy while maintaining the advantages of CTC, such as fast inference and robust recognition of long text sequences. These features make SVTRv2 particularly well-suited for commercial applications. To this end, building on SVTRv2, we develop a practical version of the model from scratch on publicly available Chinese and English datasets. Combined with a detection model, this forms an accurate and efficient general OCR system, OpenOCR. Comparing with PP-OCRv4 released by PaddleOCR, OpenOCR achieve a 4.5% improvement on the [OCR competition leaderboard](https://aistudio.baidu.com/competition/detail/1131/0/leaderboard). +⚡\[[Quick Start](#quick-start)\] \[[Model](https://github.com/Topdu/OpenOCR/releases/tag/develop0.0.1)\] \[[ModelScope Demo](https://modelscope.cn/studios/topdktu/OpenOCR-Demo)\] \[[Hugging Face Demo](https://huggingface.co/spaces/topdu/OpenOCR-Demo)\] \[[Local Demo](#local-demo)\] \[[PaddleOCR Implementation](https://paddlepaddle.github.io/PaddleOCR/latest/algorithm/text_recognition/algorithm_rec_svtrv2.html)\] + +We proposed strategies to comprehensively enhance CTC-based STR models and developed a novel CTC-based method, [SVTRv2](../configs/rec/svtrv2/). SVTRv2 can outperform previous attention-based STR methods in terms of accuracy while maintaining the advantages of CTC, such as fast inference and robust recognition of long text. These features make SVTRv2 particularly well-suited for practical applications. To this end, building on SVTRv2, we develop a practical version of the model from scratch on publicly available Chinese and English datasets. Combined with a detection model, this forms a general OCR system with accuracy and efficiency, **OpenOCR**. Comparing with [PP-OCRv4](https://paddlepaddle.github.io/PaddleOCR/latest/ppocr/model_list.html) baseline in the [OCR competition leaderboard](https://aistudio.baidu.com/competition/detail/1131/0/leaderboard), OpenOCR (mobile) achieve a 4.5% improvement in terms of accuracy, while preserving quite similar inference speed on NVIDIA 1080Ti GPU. | Model | Config | E2E Metric | Downloading | | ------------------- | ----------------------------------------------------------------------------------- | ---------- | ---------------------------------------------------------------------------------------- | | PP-OCRv4 | | 62.77% | [PaddleOCR Model List](../../ppocr/model_list.md) | -| SVTRv2 (Rec Server) | [configs/rec/svtrv2/svtrv2_ch.yml](../configs/rec/svtrv2/svtrv2_ch.yml) | 68.81% | [Google Dirve ](https://drive.google.com/file/d/13LXbIVEyx2Aat3X_vVte4JQgQ7yJWdxH/view?usp=drive_link) | -| RepSVTR (Mobile) | [Rec: configs/rec/svtrv2/repsvtr_ch.yml](../configs/rec/svtrv2/repsvtr_ch.yml)
[Det: configs/det/dbnet/repvit_db.yml](../configs/det/dbnet/repvit_db.yml) | 67.22% | [Rec: Google Drive](https://drive.google.com/file/d/1DNfarP_UmTqZnENjmmQHCexqzVmrIfLF/view?usp=drive_link)
[Det: Google Drive](https://drive.google.com/file/d/1eR6k5NitCvFEiGlYx1lAArVupIszfEmM/view?usp=drive_link) |
+| SVTRv2 (Rec Server) | [configs/rec/svtrv2/svtrv2_ch.yml](../configs/rec/svtrv2/svtrv2_ch.yml) | 68.81% | [Google Drive](https://drive.google.com/file/d/13LXbIVEyx2Aat3X_vVte4JQgQ7yJWdxH/view?usp=drive_link), [Github Released](https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_svtrv2_ch.pth) |
+| RepSVTR (Mobile) | [Rec: configs/rec/svtrv2/repsvtr_ch.yml](../configs/rec/svtrv2/repsvtr_ch.yml)
[Det: configs/det/dbnet/repvit_db.yml](../configs/det/dbnet/repvit_db.yml) | 67.22% | [Rec: Google Drive](https://drive.google.com/file/d/1DNfarP_UmTqZnENjmmQHCexqzVmrIfLF/view?usp=drive_link), [Github Released](https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_repsvtr_ch.pth)
[Det: Google Drive](https://drive.google.com/file/d/1eR6k5NitCvFEiGlYx1lAArVupIszfEmM/view?usp=drive_link), [Github Released](https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_det_repvit_ch.pth) | ## Quick Start +**Note**: OpenOCR supports inference using both the ONNX and Torch frameworks, with the dependency environments for the two frameworks being isolated. When using ONNX for inference, there is no need to install Torch, and vice versa. + +### 1. ONNX Inference + +#### Install OpenOCR and Dependencies: + +```shell +pip install openocr-python +pip install onnxruntime +``` + +#### Usage: + +```python +from openocr import OpenOCR +onnx_engine = OpenOCR(backend='onnx', device='cpu') +img_path = '/path/img_path or /path/img_file' +result, elapse = onnx_engine(img_path) +``` + +### 2. Pytorch inference + #### Dependencies: - [PyTorch](http://pytorch.org/) version >= 1.13.0 - Python version >= 3.7 ```shell -conda create -n openocre2e python==3.8 -conda activate openocre2e +conda create -n openocr python==3.8 +conda activate openocr +# install gpu version torch conda install pytorch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 pytorch-cuda=11.8 -c pytorch -c nvidia +# or cpu version +conda install pytorch torchvision torchaudio cpuonly -c pytorch ``` After installing dependencies, the following two installation methods are available. Either one can be chosen. -#### 1. Python Modules +#### 2.1. Python Modules + +**Install OpenOCR**: ```shell pip install openocr-python @@ -33,24 +62,24 @@ pip install openocr-python ```python from openocr import OpenOCR - engine = OpenOCR() - img_path = '/path/img_path or /path/img_file' result, elapse = engine(img_path) -print(result) -print(elapse) # Server mode -engine = OpenOCR(mode='server') +# engine = OpenOCR(mode='server') ``` -#### 2. Clone this repository: +#### 2.2. Clone this repository: ```shell git clone https://github.com/Topdu/OpenOCR.git cd OpenOCR pip install -r requirements.txt +wget https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_det_repvit_ch.pth +wget https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_repsvtr_ch.pth +# Rec Server model +# wget https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_svtrv2_ch.pth ``` **Usage**: @@ -58,21 +87,50 @@ pip install -r requirements.txt ```shell # OpenOCR system: Det + Rec model python tools/infer_e2e.py --img_path=/path/img_fold or /path/img_file - # Det model python tools/infer_det.py --c ./configs/det/dbnet/repvit_db.yml --o Global.infer_img=/path/img_fold or /path/img_file - # Rec model python tools/infer_rec.py --c ./configs/rec/svtrv2/repsvtr_ch.yml --o Global.infer_img=/path/img_fold or /path/img_file ``` +#### Local Demo + +```shell +pip install gradio==4.20.0 +wget https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/OCR_e2e_img.tar +tar xf OCR_e2e_img.tar +# start demo +python demo_gradio.py +``` + ## Fine-tuning on a Custom dataset -TODO +Referring to [Finetuning Det](./finetune_det.md) and [Finetuning Rec](./finetune_rec.md). ## Exporting to ONNX Engine -TODO +### Export ONNX model + +```shell +pip install onnx +python tools/toonnx.py --c configs/rec/svtrv2/repsvtr_ch.yml --o Global.device=cpu +python tools/toonnx.py --c configs/det/dbnet/repvit_db.yml --o Global.device=cpu +``` + +The det onnx model is saved in `./output/det_repsvtr_db/export_det/det_model.onnx`. +The rec onnx model is saved in `./output/rec/repsvtr_ch/export_rec/rec_model.onnx`. 
+ +### Inference with ONNXRuntime + +```shell +pip install onnxruntime +# OpenOCR system: Det + Rec model +python tools/infer_e2e.py --img_path=/path/img_fold or /path/img_file --backend=onnx --device=cpu --onnx_det_model_path=./output/det_repsvtr_db/export_det/det_model.onnx --onnx_rec_model_path=output/rec/repsvtr_ch/export_rec/rec_model.onnx +# Det model +python tools/infer_det.py --c ./configs/det/dbnet/repvit_db.yml --o Global.backend=onnx Global.device=cpu Global.infer_img=/path/img_fold or /path/img_file Global.onnx_model_path=./output/det_repsvtr_db/export_det/det_model.onnx +# Rec model +python tools/infer_rec.py --c ./configs/rec/svtrv2/repsvtr_ch.yml --o Global.backend=onnx Global.device=cpu Global.infer_img=/path/img_fold or /path/img_file Global.onnx_model_path=./output/rec/repsvtr_ch/export_rec/rec_model.onnx +``` ## Results Showcase @@ -88,6 +146,50 @@ TODO -The results show that OpenOCR’s detection model outperforms PP-OCRv4 in generating more complete and accurate text boundaries, effectively capturing entire text instances. This reflects its larger receptive field and its ability to avoid common issues like merging separate text instances or splitting a single instance into multiple fragments. +### Det + Rec System results + +
+*(Det + Rec system result showcase images)*
+ +### **Detection Model Performance** + +In the examples provided, OpenOCR's detection model generates bounding boxes that are generally more comprehensive and better aligned with the boundaries of text instances compared to PP-OCRv4. In addition, OpenOCR excels in distinguishing separate text instances, avoiding errors such as merging two distinct text instances into one or splitting a single instance into multiple parts. This indicates superior handling of **semantic completeness and spatial understanding**, making it particularly effective for complex layouts. + +### **Recognition Model Generalization** + +OpenOCR's recognition model demonstrates enhanced generalization capabilities when compared to PP-OCRv4. It performs exceptionally well in recognizing text under difficult conditions, such as: + +- Artistic or stylized fonts. +- Handwritten text. +- Blurry or low-resolution images. +- Incomplete or occluded text. -In terms of recognition, OpenOCR demonstrates superior adaptability to challenging scenarios, such as artistic fonts, handwriting, blur, low resolution, incomplete text, and occlusion. Notably, the OpenOCR mobile model performs at a level comparable to PP-OCRv4's larger server-side model, showcasing its efficiency and robustness. +Remarkably, the **OpenOCR mobile recognition model** delivers results comparable to the larger and more resource-intensive **PP-OCRv4 server model**. This highlights OpenOCR's efficiency and accuracy, making it a versatile solution across different hardware platforms. + +### **System used in Real-World Scenarios** + +As shown in Det + Rec System results, OpenOCR demonstrates outstanding performance in practical scenarios, including documents, tables, invoices, and similar contexts. This underscores its potential as a **general-purpose OCR system**. It is capable of adapting to diverse use cases with high accuracy and reliability. 
+ +## Citation + +If you find our method useful for your reserach, please cite: + +```bibtex +@article{Du2024SVTRv2, + title={SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition}, + author={Yongkun Du and Zhineng Chen and Hongtao Xie and Caiyan Jia and Yu-Gang Jiang}, + journal={CoRR}, + volume={abs/2411.15858}, + eprinttype={arXiv}, + year={2024}, + primaryClass={cs.CV}, + url={https://arxiv.org/abs/2411.15858} +} +``` diff --git a/docs/svtrv2.md b/docs/svtrv2.md index bb2be79..a0d6c4b 100644 --- a/docs/svtrv2.md +++ b/docs/svtrv2.md @@ -1,6 +1,6 @@ # SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition -\[[Paper](../configs/rec/svtrv2/SVTRv2.pdf)\] \[[Model](../configs/rec/svtrv2/readme.md#11-models-and-results)\] \[[Config, Training and Inference](../configs/rec/svtrv2/readme.md#3-model-training--evaluation)\] +\[[Paper](https://arxiv.org/abs/2411.15858)\] \[[Doc](../configs/rec/svtrv2/)\] \[[Model](../configs/rec/svtrv2/readme.md#11-models-and-results)\] \[[Datasets](#downloading-datasets)\] \[[Config, Training and Inference](../configs/rec/svtrv2/readme.md#3-model-training--evaluation)\] \[[Benchmark](#results-benchmark--configs--checkpoints)\] ## Introduction @@ -190,7 +190,7 @@ Union14M-L-LMDB-Filtered # lmdb format Union14M-L-Filtered | Union14M-L-Filter | [LMDB archives](https://drive.google.com/drive/folders/1OlDWJZgvd6s4S09S3IGeAI90jI0i7AB_?usp=sharing) | | | Evaluation | [LMDB archives](https://drive.google.com/drive/folders/1EW0_YvmRSdpVOkR2guTQFrGz7KNqNc66?usp=drive_link) | | -If you have downloaded Union14M-L, you can use [the filtered list of images](https://drive.google.com/drive/folders/1x1LC8C_W-Frl3sGV9i9_i_OD-bqNdodJ?usp=drive_link) to create an LMDB of the training set Union14M-L-Filter. +If you have downloaded Union14M-L, you can use [the filtered list of images](https://drive.google.com/drive/folders/1x1LC8C_W-Frl3sGV9i9_i_OD-bqNdodJ?usp=drive_link) to create an LMDB of the training set Union14M-L-Filter (detailed in [create_lmdb_dataset.py](../tools/create_lmdb_dataset.py)). #### **Test Set** @@ -211,9 +211,9 @@ If you have downloaded Union14M-L, you can use [the filtered list of images](htt ```shell # Multi GPU training -CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 tools/train_rec.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_maxratio12.yml +CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 tools/train_rec.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml # For Multi RTX 4090 -NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 tools/train_rec.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_maxratio12.yml +NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 tools/train_rec.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml # 20epoch runs for about 6 hours ``` @@ -221,9 +221,7 @@ NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.laun ```shell # short text: Common, Union14M-Benchmark, OST -python tools/eval_rec_all_ratio.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml -# long text -python tools/eval_rec_all_long_simple.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml +python tools/eval_rec_all_en.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml ``` After a successful run, the results are saved in a csv file in `output_dir` in the config file. 
@@ -242,23 +240,9 @@ Firstly, downloading the IIIT5K images from [Google Drive](https://drive.google. python tools/infer_rec.py --c configs/rec/svtrv2/svtrv2_rctc.yml --o Global.infer_img=../iiit5k_test_image ``` -## Results & Configs & Checkpoints: +## Results (Benchmark) & Configs & Checkpoints: -Downloading all model checkpoints from [Google Drive](<>) and [Baidu Yun](<>). - - +(TODO) Downloading all model checkpoints from [Google Drive](<>) and [Baidu Yun](<>). @@ -994,25 +978,10 @@ Downloading all model checkpoints from [Google Drive](<>) and [Baidu Yun](<>).
-**Note**: TF$\_n$ denotes the $n$-layer Transformer block. $Size$ denotes the number of parameters ($M$). $Latency$ is measured on one NVIDIA 1080Ti GPU with Pytorch Dynamic mode. +**Note**: TF$\_n$ denotes the $n$-layer Transformer block. $Size$ denotes the number of parameters ($M$). $Latency$ is measured on one NVIDIA 1080Ti GPU with Pytorch dynamic graph mode. ## Results when trained on synthetic datasets ($ST$ + $MJ$). - - @@ -1678,10 +1647,17 @@ Downloading all model checkpoints from [Google Drive](<>) and [Baidu Yun](<>). ## Citation +If you find our method useful for your reserach, please cite: + ```bibtex -@article{Du2024SVTRv4, - title = {SVTRv2: Scene Text Recognition with a Single Visual Model}, - author = {Yongkun Du, Zhineng Chen\*, Hongtao Xie, Caiyan Jia, Yu-Gang Jiang}, - year = {2024} +@article{Du2024SVTRv2, + title={SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition}, + author={Yongkun Du and Zhineng Chen and Hongtao Xie and Caiyan Jia and Yu-Gang Jiang}, + journal={CoRR}, + volume={abs/2411.15858}, + eprinttype={arXiv}, + year={2024}, + primaryClass={cs.CV}, + url={https://arxiv.org/abs/2411.15858} } ``` diff --git a/opendet/losses/__init__.py b/opendet/losses/__init__.py new file mode 100644 index 0000000..f3dd72c --- /dev/null +++ b/opendet/losses/__init__.py @@ -0,0 +1,20 @@ +import copy +from importlib import import_module + +name_to_module = { + 'DBLoss': '.db_loss', +} + + +def build_loss(config): + config = copy.deepcopy(config) + module_name = config.pop('name') + assert module_name in name_to_module, Exception( + '{} is not supported. The losses in {} are supportes'.format( + module_name, list(name_to_module.keys()))) + + module_path = name_to_module[module_name] + module = import_module(module_path, package=__package__) + module_class = getattr(module, module_name) + + return module_class(**config) diff --git a/opendet/losses/db_loss.py b/opendet/losses/db_loss.py new file mode 100644 index 0000000..7563dd6 --- /dev/null +++ b/opendet/losses/db_loss.py @@ -0,0 +1,87 @@ +# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +This code is refer from: +https://github.com/WenmuZhou/DBNet.pytorch/blob/master/models/losses/DB_loss.py +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import torch +from torch import nn + +from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss + + +class DBLoss(nn.Module): + """ + Differentiable Binarization (DB) Loss Function + args: + param (dict): the super paramter for DB Loss + """ + + def __init__(self, + balance_loss=True, + main_loss_type='DiceLoss', + alpha=5, + beta=10, + ohem_ratio=3, + eps=1e-6, + **kwargs): + super(DBLoss, self).__init__() + self.alpha = alpha + self.beta = beta + self.dice_loss = DiceLoss(eps=eps) + self.l1_loss = MaskL1Loss(eps=eps) + self.bce_loss = BalanceLoss(balance_loss=balance_loss, + main_loss_type=main_loss_type, + negative_ratio=ohem_ratio) + + def forward(self, predicts, labels): + predict_maps = predicts['maps'] + label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = labels[ + 1:] + shrink_maps = predict_maps[:, 0, :, :] + threshold_maps = predict_maps[:, 1, :, :] + binary_maps = predict_maps[:, 2, :, :] + + loss_shrink_maps = self.bce_loss(shrink_maps, label_shrink_map, + label_shrink_mask) + loss_threshold_maps = self.l1_loss(threshold_maps, label_threshold_map, + label_threshold_mask) + loss_binary_maps = self.dice_loss(binary_maps, label_shrink_map, + label_shrink_mask) + loss_shrink_maps = self.alpha * loss_shrink_maps + loss_threshold_maps = self.beta * loss_threshold_maps + # CBN loss + if 'distance_maps' in predicts.keys(): + # distance_maps = predicts['distance_maps'] + cbn_maps = predicts['cbn_maps'] + cbn_loss = self.bce_loss(cbn_maps[:, 0, :, :], label_shrink_map, + label_shrink_mask) + else: + # dis_loss = torch.tensor([0.]) + cbn_loss = torch.tensor([0.], device=predict_maps.device) + + loss_all = loss_shrink_maps + loss_threshold_maps + loss_binary_maps + losses = { + 'loss': loss_all + cbn_loss, + 'loss_shrink_maps': loss_shrink_maps, + 'loss_threshold_maps': loss_threshold_maps, + 'loss_binary_maps': loss_binary_maps, + 'loss_cbn': cbn_loss + } + return losses diff --git a/opendet/losses/det_basic_loss.py b/opendet/losses/det_basic_loss.py new file mode 100644 index 0000000..66a2f6f --- /dev/null +++ b/opendet/losses/det_basic_loss.py @@ -0,0 +1,159 @@ +# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +This code is refer from: +https://github.com/WenmuZhou/DBNet.pytorch/blob/master/models/losses/basic_loss.py +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import torch +from torch import nn +import torch.nn.functional as F + + +class BalanceLoss(nn.Module): + + def __init__( + self, + balance_loss=True, + main_loss_type='DiceLoss', + negative_ratio=3, + return_origin=False, + eps=1e-6, + **kwargs, + ): + """ + The BalanceLoss for Differentiable Binarization text detection + args: + balance_loss (bool): whether balance loss or not, default is True + main_loss_type (str): can only be one of ['CrossEntropy','DiceLoss', + 'Euclidean','BCELoss', 'MaskL1Loss'], default is 'DiceLoss'. + negative_ratio (int|float): float, default is 3. + return_origin (bool): whether return unbalanced loss or not, default is False. + eps (float): default is 1e-6. + """ + super(BalanceLoss, self).__init__() + self.balance_loss = balance_loss + self.main_loss_type = main_loss_type + self.negative_ratio = negative_ratio + self.return_origin = return_origin + self.eps = eps + + if self.main_loss_type == 'CrossEntropy': + self.loss = nn.CrossEntropyLoss() + elif self.main_loss_type == 'Euclidean': + self.loss = nn.MSELoss() + elif self.main_loss_type == 'DiceLoss': + self.loss = DiceLoss(self.eps) + elif self.main_loss_type == 'BCELoss': + self.loss = BCELoss(reduction='none') + elif self.main_loss_type == 'MaskL1Loss': + self.loss = MaskL1Loss(self.eps) + else: + loss_type = [ + 'CrossEntropy', + 'DiceLoss', + 'Euclidean', + 'BCELoss', + 'MaskL1Loss', + ] + raise Exception( + 'main_loss_type in BalanceLoss() can only be one of {}'.format( + loss_type)) + + def forward(self, pred, gt, mask=None): + """ + The BalanceLoss for Differentiable Binarization text detection + args: + pred (variable): predicted feature maps. + gt (variable): ground truth feature maps. + mask (variable): masked maps. + return: (variable) balanced loss + """ + positive = gt * mask + negative = (1 - gt) * mask + + positive_count = int(positive.sum()) + negative_count = int( + min(negative.sum(), positive_count * self.negative_ratio)) + loss = self.loss(pred, gt, mask=mask) + + if not self.balance_loss: + return loss + + positive_loss = positive * loss + negative_loss = negative * loss + negative_loss = negative_loss.reshape(-1) + if negative_count > 0: + negative_loss, _ = negative_loss.topk(negative_count) + balance_loss = (positive_loss.sum() + negative_loss.sum()) / ( + positive_count + negative_count + self.eps) + else: + balance_loss = positive_loss.sum() / (positive_count + self.eps) + if self.return_origin: + return balance_loss, loss + + return balance_loss + + +class DiceLoss(nn.Module): + + def __init__(self, eps=1e-6): + super(DiceLoss, self).__init__() + self.eps = eps + + def forward(self, pred, gt, mask, weights=None): + """ + DiceLoss function. 
+ """ + + assert pred.shape == gt.shape + assert pred.shape == mask.shape + if weights is not None: + assert weights.shape == mask.shape + mask = weights * mask + intersection = torch.sum(pred * gt * mask) + + union = torch.sum(pred * mask) + torch.sum(gt * mask) + self.eps + loss = 1 - 2.0 * intersection / union + assert loss <= 1 + return loss + + +class MaskL1Loss(nn.Module): + + def __init__(self, eps=1e-6): + super(MaskL1Loss, self).__init__() + self.eps = eps + + def forward(self, pred, gt, mask): + """ + Mask L1 Loss + """ + loss = (torch.abs(pred - gt) * mask).sum() / (mask.sum() + self.eps) + loss = torch.mean(loss) + return loss + + +class BCELoss(nn.Module): + + def __init__(self, reduction='mean'): + super(BCELoss, self).__init__() + self.reduction = reduction + + def forward(self, input, label, mask=None, weight=None, name=None): + loss = F.binary_cross_entropy(input, label, reduction=self.reduction) + return loss diff --git a/opendet/metrics/__init__.py b/opendet/metrics/__init__.py new file mode 100644 index 0000000..6a09ae5 --- /dev/null +++ b/opendet/metrics/__init__.py @@ -0,0 +1,16 @@ +import copy + +__all__ = ['build_metric'] + +from .det_metric import DetMetric + +support_dict = ['DetMetric'] + + +def build_metric(config): + config = copy.deepcopy(config) + module_name = config.pop('name') + assert module_name in support_dict, Exception( + 'metric only support {}'.format(support_dict)) + module_class = eval(module_name)(**config) + return module_class diff --git a/opendet/metrics/det_metric.py b/opendet/metrics/det_metric.py new file mode 100644 index 0000000..276fb8c --- /dev/null +++ b/opendet/metrics/det_metric.py @@ -0,0 +1,156 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +__all__ = ['DetMetric', 'DetFCEMetric'] + +from .eval_det_iou import DetectionIoUEvaluator + + +class DetMetric(object): + + def __init__(self, main_indicator='hmean', **kwargs): + self.evaluator = DetectionIoUEvaluator() + self.main_indicator = main_indicator + self.reset() + + def __call__(self, preds, batch, **kwargs): + """ + batch: a list produced by dataloaders. + image: np.ndarray of shape (N, C, H, W). + ratio_list: np.ndarray of shape(N,2) + polygons: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions. + ignore_tags: np.ndarray of shape (N, K), indicates whether a region is ignorable or not. + preds: a list of dict produced by post process + points: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions. 
+ """ + gt_polyons_batch = batch[2] + ignore_tags_batch = batch[3] + for pred, gt_polyons, ignore_tags in zip(preds, gt_polyons_batch, + ignore_tags_batch): + # prepare gt + gt_info_list = [{ + 'points': gt_polyon, + 'text': '', + 'ignore': ignore_tag + } for gt_polyon, ignore_tag in zip(gt_polyons, ignore_tags)] + # prepare det + det_info_list = [{ + 'points': det_polyon, + 'text': '' + } for det_polyon in pred['points']] + result = self.evaluator.evaluate_image(gt_info_list, det_info_list) + self.results.append(result) + + def get_metric(self): + """ + return metrics { + 'precision': 0, + 'recall': 0, + 'hmean': 0 + } + """ + + metrics = self.evaluator.combine_results(self.results) + self.reset() + return metrics + + def reset(self): + self.results = [] # clear results + + +class DetFCEMetric(object): + + def __init__(self, main_indicator='hmean', **kwargs): + self.evaluator = DetectionIoUEvaluator() + self.main_indicator = main_indicator + self.reset() + + def __call__(self, preds, batch, **kwargs): + """ + batch: a list produced by dataloaders. + image: np.ndarray of shape (N, C, H, W). + ratio_list: np.ndarray of shape(N,2) + polygons: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions. + ignore_tags: np.ndarray of shape (N, K), indicates whether a region is ignorable or not. + preds: a list of dict produced by post process + points: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions. + """ + gt_polyons_batch = batch[2] + ignore_tags_batch = batch[3] + + for pred, gt_polyons, ignore_tags in zip(preds, gt_polyons_batch, + ignore_tags_batch): + # prepare gt + gt_info_list = [{ + 'points': gt_polyon, + 'text': '', + 'ignore': ignore_tag + } for gt_polyon, ignore_tag in zip(gt_polyons, ignore_tags)] + # prepare det + det_info_list = [{ + 'points': det_polyon, + 'text': '', + 'score': score + } for det_polyon, score in zip(pred['points'], pred['scores'])] + + for score_thr in self.results.keys(): + det_info_list_thr = [ + det_info for det_info in det_info_list + if det_info['score'] >= score_thr + ] + result = self.evaluator.evaluate_image(gt_info_list, + det_info_list_thr) + self.results[score_thr].append(result) + + def get_metric(self): + """ + return metrics {'heman':0, + 'thr 0.3':'precision: 0 recall: 0 hmean: 0', + 'thr 0.4':'precision: 0 recall: 0 hmean: 0', + 'thr 0.5':'precision: 0 recall: 0 hmean: 0', + 'thr 0.6':'precision: 0 recall: 0 hmean: 0', + 'thr 0.7':'precision: 0 recall: 0 hmean: 0', + 'thr 0.8':'precision: 0 recall: 0 hmean: 0', + 'thr 0.9':'precision: 0 recall: 0 hmean: 0', + } + """ + metrics = {} + hmean = 0 + for score_thr in self.results.keys(): + metric = self.evaluator.combine_results(self.results[score_thr]) + # for key, value in metric.items(): + # metrics['{}_{}'.format(key, score_thr)] = value + metric_str = 'precision:{:.5f} recall:{:.5f} hmean:{:.5f}'.format( + metric['precision'], metric['recall'], metric['hmean']) + metrics['thr {}'.format(score_thr)] = metric_str + hmean = max(hmean, metric['hmean']) + metrics['hmean'] = hmean + + self.reset() + return metrics + + def reset(self): + self.results = { + 0.3: [], + 0.4: [], + 0.5: [], + 0.6: [], + 0.7: [], + 0.8: [], + 0.9: [], + } # clear results diff --git a/opendet/metrics/eval_det_iou.py b/opendet/metrics/eval_det_iou.py new file mode 100644 index 0000000..8cfb8c1 --- /dev/null +++ b/opendet/metrics/eval_det_iou.py @@ -0,0 +1,211 @@ +#!/usr/bin/env python +import numpy as np +from shapely.geometry import Polygon +""" +reference from : 
+https://github.com/MhLiao/DB/blob/3c32b808d4412680310d3d28eeb6a2d5bf1566c5/concern/icdar2015_eval/detection/iou.py#L8 +""" + + +class DetectionIoUEvaluator(object): + + def __init__(self, iou_constraint=0.5, area_precision_constraint=0.5): + self.iou_constraint = iou_constraint + self.area_precision_constraint = area_precision_constraint + + def evaluate_image(self, gt, pred): + + def get_union(pD, pG): + return Polygon(pD).union(Polygon(pG)).area + + def get_intersection_over_union(pD, pG): + return get_intersection(pD, pG) / get_union(pD, pG) + + def get_intersection(pD, pG): + return Polygon(pD).intersection(Polygon(pG)).area + + def compute_ap(confList, matchList, numGtCare): + correct = 0 + AP = 0 + if len(confList) > 0: + confList = np.array(confList) + matchList = np.array(matchList) + sorted_ind = np.argsort(-confList) + confList = confList[sorted_ind] + matchList = matchList[sorted_ind] + for n in range(len(confList)): + match = matchList[n] + if match: + correct += 1 + AP += float(correct) / (n + 1) + + if numGtCare > 0: + AP /= numGtCare + + return AP + + perSampleMetrics = {} + + matchedSum = 0 + + numGlobalCareGt = 0 + numGlobalCareDet = 0 + + precision = 0 + + detMatched = 0 + + iouMat = np.empty([1, 1]) + + gtPols = [] + detPols = [] + + gtPolPoints = [] + detPolPoints = [] + + # Array of Ground Truth Polygons' keys marked as don't Care + gtDontCarePolsNum = [] + # Array of Detected Polygons' matched with a don't Care GT + detDontCarePolsNum = [] + + pairs = [] + detMatchedNums = [] + + evaluationLog = '' + + for n in range(len(gt)): + points = gt[n]['points'] + dontCare = gt[n]['ignore'] + if not Polygon(points).is_valid: + continue + + gtPol = points + gtPols.append(gtPol) + gtPolPoints.append(points) + if dontCare: + gtDontCarePolsNum.append(len(gtPols) - 1) + + evaluationLog += ( + 'GT polygons: ' + str(len(gtPols)) + + (' (' + str(len(gtDontCarePolsNum)) + + " don't care)\n" if len(gtDontCarePolsNum) > 0 else '\n')) + + for n in range(len(pred)): + points = pred[n]['points'] + if not Polygon(points).is_valid: + continue + + detPol = points + detPols.append(detPol) + detPolPoints.append(points) + if len(gtDontCarePolsNum) > 0: + for dontCarePol in gtDontCarePolsNum: + dontCarePol = gtPols[dontCarePol] + intersected_area = get_intersection(dontCarePol, detPol) + pdDimensions = Polygon(detPol).area + precision = (0 if pdDimensions == 0 else intersected_area / + pdDimensions) + if precision > self.area_precision_constraint: + detDontCarePolsNum.append(len(detPols) - 1) + break + + evaluationLog += ( + 'DET polygons: ' + str(len(detPols)) + + (' (' + str(len(detDontCarePolsNum)) + + " don't care)\n" if len(detDontCarePolsNum) > 0 else '\n')) + + if len(gtPols) > 0 and len(detPols) > 0: + # Calculate IoU and precision matrixs + outputShape = [len(gtPols), len(detPols)] + iouMat = np.empty(outputShape) + gtRectMat = np.zeros(len(gtPols), np.int8) + detRectMat = np.zeros(len(detPols), np.int8) + for gtNum in range(len(gtPols)): + for detNum in range(len(detPols)): + pG = gtPols[gtNum] + pD = detPols[detNum] + iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG) + + for gtNum in range(len(gtPols)): + for detNum in range(len(detPols)): + if (gtRectMat[gtNum] == 0 and detRectMat[detNum] == 0 + and gtNum not in gtDontCarePolsNum + and detNum not in detDontCarePolsNum): + if iouMat[gtNum, detNum] > self.iou_constraint: + gtRectMat[gtNum] = 1 + detRectMat[detNum] = 1 + detMatched += 1 + pairs.append({'gt': gtNum, 'det': detNum}) + detMatchedNums.append(detNum) + evaluationLog 
+= ('Match GT #' + str(gtNum) + + ' with Det #' + str(detNum) + + '\n') + + numGtCare = len(gtPols) - len(gtDontCarePolsNum) + numDetCare = len(detPols) - len(detDontCarePolsNum) + if numGtCare == 0: + precision = float(0) if numDetCare > 0 else float(1) + else: + precision = 0 if numDetCare == 0 else float( + detMatched) / numDetCare + + matchedSum += detMatched + numGlobalCareGt += numGtCare + numGlobalCareDet += numDetCare + + perSampleMetrics = { + 'gtCare': numGtCare, + 'detCare': numDetCare, + 'detMatched': detMatched, + } + return perSampleMetrics + + def combine_results(self, results): + numGlobalCareGt = 0 + numGlobalCareDet = 0 + matchedSum = 0 + for result in results: + numGlobalCareGt += result['gtCare'] + numGlobalCareDet += result['detCare'] + matchedSum += result['detMatched'] + + methodRecall = (0 if numGlobalCareGt == 0 else float(matchedSum) / + numGlobalCareGt) + methodPrecision = (0 if numGlobalCareDet == 0 else float(matchedSum) / + numGlobalCareDet) + methodHmean = (0 if methodRecall + methodPrecision == 0 else 2 * + methodRecall * methodPrecision / + (methodRecall + methodPrecision)) + methodMetrics = { + 'precision': methodPrecision, + 'recall': methodRecall, + 'hmean': methodHmean, + } + + return methodMetrics + + +if __name__ == '__main__': + evaluator = DetectionIoUEvaluator() + gts = [[ + { + 'points': [(0, 0), (1, 0), (1, 1), (0, 1)], + 'text': 1234, + 'ignore': False, + }, + { + 'points': [(2, 2), (3, 2), (3, 3), (2, 3)], + 'text': 5678, + 'ignore': False, + }, + ]] + preds = [[{ + 'points': [(0.1, 0.1), (1, 0), (1, 1), (0, 1)], + 'text': 123, + 'ignore': False, + }]] + results = [] + for gt, pred in zip(gts, preds): + results.append(evaluator.evaluate_image(gt, pred)) + metrics = evaluator.combine_results(results) + print(metrics) diff --git a/opendet/postprocess/db_postprocess.py b/opendet/postprocess/db_postprocess.py index dd6b199..76c9f14 100644 --- a/opendet/postprocess/db_postprocess.py +++ b/opendet/postprocess/db_postprocess.py @@ -1,6 +1,5 @@ import numpy as np import cv2 -import torch from shapely.geometry import Polygon import pyclipper """ @@ -208,16 +207,21 @@ def box_score_slow(self, bitmap, contour): cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype('int32'), 1) return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] - def __call__(self, outs_dict, shape_list): + def __call__(self, outs_dict, batch, **kwargs): + self.thresh = kwargs.get('mask_thresh', self.thresh) + self.box_thresh = kwargs.get('box_thresh', self.box_thresh) + self.unclip_ratio = kwargs.get('unclip_ratio', self.unclip_ratio) + self.box_type = kwargs.get('box_type', self.box_type) + self.score_mode = kwargs.get('score_mode', self.score_mode) pred = outs_dict['maps'] - if isinstance(pred, torch.Tensor): + if kwargs.get('torch_tensor', True): pred = pred.detach().cpu().numpy() pred = pred[:, 0, :, :] segmentation = pred > self.thresh boxes_batch = [] for batch_index in range(pred.shape[0]): - src_h, src_w, ratio_h, ratio_w = shape_list[batch_index] + src_h, src_w, ratio_h, ratio_w = batch[1][batch_index] if self.dilation_kernel is not None: mask = cv2.dilate( np.array(segmentation[batch_index]).astype(np.uint8), diff --git a/opendet/preprocess/__init__.py b/opendet/preprocess/__init__.py index 851cb65..ac77d4f 100644 --- a/opendet/preprocess/__init__.py +++ b/opendet/preprocess/__init__.py @@ -1,10 +1,19 @@ +import copy import io - import cv2 import numpy as np from PIL import Image +from importlib import import_module -from .db_resize_for_test import DetResizeForTest 
+MODULE_MAPPING = { + 'DetResizeForTest': '.db_resize_for_test', + 'CopyPaste': '.crop_paste', + 'IaaAugment': '.iaa_augment', + 'EastRandomCropData': '.crop_resize', + 'DetLabelEncode': '.db_label_encode', + 'MakeBorderMap': '.db_label_encode', + 'MakeShrinkMap': '.db_label_encode', +} class NormalizeImage(object): @@ -134,21 +143,28 @@ def __call__(self, data): return data -def create_operators(op_param_list, global_config=None): - """create operators based on the config. +def dynamic_import(class_name): + module_path = MODULE_MAPPING.get(class_name) + if not module_path: + raise ValueError(f'Unsupported class: {class_name}') + + module = import_module(module_path, package=__package__) + return getattr(module, class_name) - Args: - params(list): a dict list, used to create some operators - """ - assert isinstance(op_param_list, list), 'operator config should be a list' + +def create_operators(op_param_list, global_config=None): ops = [] - for operator in op_param_list: - assert isinstance(operator, - dict) and len(operator) == 1, 'yaml format error' - op_name = list(operator)[0] - param = {} if operator[op_name] is None else operator[op_name] - if global_config is not None: + for op_info in op_param_list: + op_name = list(op_info.keys())[0] + param = copy.deepcopy(op_info[op_name]) or {} + + if global_config: param.update(global_config) - op = eval(op_name)(**param) - ops.append(op) + + if op_name in globals(): + op_class = globals()[op_name] + else: + op_class = dynamic_import(op_name) + + ops.append(op_class(**param)) return ops diff --git a/opendet/preprocess/crop_paste.py b/opendet/preprocess/crop_paste.py new file mode 100644 index 0000000..e204a72 --- /dev/null +++ b/opendet/preprocess/crop_paste.py @@ -0,0 +1,178 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import cv2 +import random +import numpy as np +from PIL import Image +from shapely.geometry import Polygon + +from .iaa_augment import IaaAugment +from .crop_resize import is_poly_outside_rect +from tools.infer.utility import get_rotate_crop_image + + +class CopyPaste(object): + + def __init__(self, objects_paste_ratio=0.2, limit_paste=True, **kwargs): + self.ext_data_num = 1 + self.objects_paste_ratio = objects_paste_ratio + self.limit_paste = limit_paste + augmenter_args = [{'type': 'Resize', 'args': {'size': [0.5, 3]}}] + self.aug = IaaAugment(augmenter_args) + + def __call__(self, data): + point_num = data['polys'].shape[1] + src_img = data['image'] + src_polys = data['polys'].tolist() + src_texts = data['texts'] + src_ignores = data['ignore_tags'].tolist() + ext_data = data['ext_data'][0] + ext_image = ext_data['image'] + ext_polys = ext_data['polys'] + ext_texts = ext_data['texts'] + ext_ignores = ext_data['ignore_tags'] + + indexes = [i for i in range(len(ext_ignores)) if not ext_ignores[i]] + select_num = max( + 1, min(int(self.objects_paste_ratio * len(ext_polys)), 30)) + + random.shuffle(indexes) + select_idxs = indexes[:select_num] + select_polys = ext_polys[select_idxs] + select_ignores = ext_ignores[select_idxs] + + src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB) + ext_image = cv2.cvtColor(ext_image, cv2.COLOR_BGR2RGB) + src_img = Image.fromarray(src_img).convert('RGBA') + for idx, poly, tag in zip(select_idxs, select_polys, select_ignores): + box_img = get_rotate_crop_image(ext_image, poly) + + src_img, box = self.paste_img(src_img, box_img, src_polys) + if box is not None: + box = box.tolist() + for _ in range(len(box), point_num): + box.append(box[-1]) + src_polys.append(box) + src_texts.append(ext_texts[idx]) + src_ignores.append(tag) + src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR) + h, w = src_img.shape[:2] + src_polys = np.array(src_polys) + src_polys[:, :, 0] = np.clip(src_polys[:, :, 0], 0, w) + src_polys[:, :, 1] = np.clip(src_polys[:, :, 1], 0, h) + data['image'] = src_img + data['polys'] = src_polys + data['texts'] = src_texts + data['ignore_tags'] = np.array(src_ignores) + return data + + def paste_img(self, src_img, box_img, src_polys): + box_img_pil = Image.fromarray(box_img).convert('RGBA') + src_w, src_h = src_img.size + box_w, box_h = box_img_pil.size + + angle = np.random.randint(0, 360) + box = np.array([[[0, 0], [box_w, 0], [box_w, box_h], [0, box_h]]]) + box = rotate_bbox(box_img, box, angle)[0] + box_img_pil = box_img_pil.rotate(angle, expand=1) + box_w, box_h = box_img_pil.width, box_img_pil.height + if src_w - box_w < 0 or src_h - box_h < 0: + return src_img, None + + paste_x, paste_y = self.select_coord(src_polys, box, src_w - box_w, + src_h - box_h) + if paste_x is None: + return src_img, None + box[:, 0] += paste_x + box[:, 1] += paste_y + r, g, b, A = box_img_pil.split() + src_img.paste(box_img_pil, (paste_x, paste_y), mask=A) + + return src_img, box + + def select_coord(self, src_polys, box, endx, endy): + if self.limit_paste: + xmin, ymin, xmax, ymax = ( + box[:, 0].min(), + box[:, 1].min(), + box[:, 0].max(), + box[:, 1].max(), + ) + for _ in range(50): + paste_x = random.randint(0, endx) + paste_y = random.randint(0, endy) + xmin1 = xmin + paste_x + xmax1 = xmax + paste_x + ymin1 = ymin + paste_y + ymax1 = ymax + paste_y + + num_poly_in_rect = 0 + for poly in src_polys: + if not is_poly_outside_rect(poly, xmin1, ymin1, + xmax1 - xmin1, ymax1 - ymin1): + num_poly_in_rect += 1 + break + if num_poly_in_rect == 0: + return 
paste_x, paste_y + return None, None + else: + paste_x = random.randint(0, endx) + paste_y = random.randint(0, endy) + return paste_x, paste_y + + +def get_union(pD, pG): + return Polygon(pD).union(Polygon(pG)).area + + +def get_intersection_over_union(pD, pG): + return get_intersection(pD, pG) / get_union(pD, pG) + + +def get_intersection(pD, pG): + return Polygon(pD).intersection(Polygon(pG)).area + + +def rotate_bbox(img, text_polys, angle, scale=1): + """ + from https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/augment.py + Args: + img: np.ndarray + text_polys: np.ndarray N*4*2 + angle: int + scale: int + + Returns: + + """ + w = img.shape[1] + h = img.shape[0] + + rangle = np.deg2rad(angle) + nw = abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w) + nh = abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w) + rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, scale) + rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0])) + rot_mat[0, 2] += rot_move[0] + rot_mat[1, 2] += rot_move[1] + + # ---------------------- rotate box ---------------------- + rot_text_polys = list() + for bbox in text_polys: + point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1])) + point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1])) + point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1])) + point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1])) + rot_text_polys.append([point1, point2, point3, point4]) + return np.array(rot_text_polys, dtype=np.float32) diff --git a/opendet/preprocess/crop_resize.py b/opendet/preprocess/crop_resize.py index fa67a47..4604b4d 100644 --- a/opendet/preprocess/crop_resize.py +++ b/opendet/preprocess/crop_resize.py @@ -1,4 +1,5 @@ import cv2 +import numpy as np def padding_image(img, size=(640, 640)): @@ -37,6 +38,156 @@ def padding_image(img, size=(640, 640)): return padded_img +def is_poly_outside_rect(poly, x, y, w, h): + poly = np.array(poly) + if poly[:, 0].max() < x or poly[:, 0].min() > x + w: + return True + if poly[:, 1].max() < y or poly[:, 1].min() > y + h: + return True + return False + + +def split_regions(axis): + regions = [] + min_axis = 0 + for i in range(1, axis.shape[0]): + if axis[i] != axis[i - 1] + 1: + region = axis[min_axis:i] + min_axis = i + regions.append(region) + return regions + + +def random_select(axis, max_size): + xx = np.random.choice(axis, size=2) + xmin = np.min(xx) + xmax = np.max(xx) + xmin = np.clip(xmin, 0, max_size - 1) + xmax = np.clip(xmax, 0, max_size - 1) + return xmin, xmax + + +def region_wise_random_select(regions, max_size): + selected_index = list(np.random.choice(len(regions), 2)) + selected_values = [] + for index in selected_index: + axis = regions[index] + xx = int(np.random.choice(axis, size=1)) + selected_values.append(xx) + xmin = min(selected_values) + xmax = max(selected_values) + return xmin, xmax + + +def crop_area(im, text_polys, min_crop_side_ratio, max_tries): + h, w, _ = im.shape + h_array = np.zeros(h, dtype=np.int32) + w_array = np.zeros(w, dtype=np.int32) + for points in text_polys: + points = np.round(points, decimals=0).astype(np.int32) + minx = np.min(points[:, 0]) + maxx = np.max(points[:, 0]) + w_array[minx:maxx] = 1 + miny = np.min(points[:, 1]) + maxy = np.max(points[:, 1]) + h_array[miny:maxy] = 1 + # ensure the cropped area not across a text + h_axis = np.where(h_array == 0)[0] + w_axis = np.where(w_array == 0)[0] + + if len(h_axis) == 0 or len(w_axis) == 0: + return 0, 0, w, h + + h_regions = split_regions(h_axis) + 
w_regions = split_regions(w_axis) + + for i in range(max_tries): + if len(w_regions) > 1: + xmin, xmax = region_wise_random_select(w_regions, w) + else: + xmin, xmax = random_select(w_axis, w) + if len(h_regions) > 1: + ymin, ymax = region_wise_random_select(h_regions, h) + else: + ymin, ymax = random_select(h_axis, h) + + if (xmax - xmin < min_crop_side_ratio * w or ymax - ymin < min_crop_side_ratio * h): + # area too small + continue + num_poly_in_rect = 0 + for poly in text_polys: + if not is_poly_outside_rect(poly, xmin, ymin, xmax - xmin, + ymax - ymin): + num_poly_in_rect += 1 + break + + if num_poly_in_rect > 0: + return xmin, ymin, xmax - xmin, ymax - ymin + + return 0, 0, w, h + + +class EastRandomCropData(object): + + def __init__( + self, + size=(640, 640), + max_tries=10, + min_crop_side_ratio=0.1, + keep_ratio=True, + **kwargs, + ): + self.size = size + self.max_tries = max_tries + self.min_crop_side_ratio = min_crop_side_ratio + self.keep_ratio = keep_ratio + + def __call__(self, data): + img = data['image'] + text_polys = data['polys'] + ignore_tags = data['ignore_tags'] + texts = data['texts'] + all_care_polys = [ + text_polys[i] for i, tag in enumerate(ignore_tags) if not tag + ] + # 计算crop区域 + crop_x, crop_y, crop_w, crop_h = crop_area(img, all_care_polys, + self.min_crop_side_ratio, + self.max_tries) + # crop 图片 保持比例填充 + scale_w = self.size[0] / crop_w + scale_h = self.size[1] / crop_h + scale = min(scale_w, scale_h) + h = int(crop_h * scale) + w = int(crop_w * scale) + if self.keep_ratio: + padimg = np.zeros((self.size[1], self.size[0], img.shape[2]), + img.dtype) + padimg[:h, :w] = cv2.resize( + img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h)) + img = padimg + else: + img = cv2.resize( + img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], + tuple(self.size), + ) + # crop 文本框 + text_polys_crop = [] + ignore_tags_crop = [] + texts_crop = [] + for poly, text, tag in zip(text_polys, texts, ignore_tags): + poly = ((poly - (crop_x, crop_y)) * scale).tolist() + if not is_poly_outside_rect(poly, 0, 0, w, h): + text_polys_crop.append(poly) + ignore_tags_crop.append(tag) + texts_crop.append(text) + data['image'] = img + data['polys'] = np.array(text_polys_crop) + data['ignore_tags'] = ignore_tags_crop + data['texts'] = texts_crop + return data + + class CropResize(object): def __init__(self, size=(640, 640), interpolation=cv2.INTER_LINEAR): diff --git a/opendet/preprocess/db_label_encode.py b/opendet/preprocess/db_label_encode.py new file mode 100644 index 0000000..83a6cb2 --- /dev/null +++ b/opendet/preprocess/db_label_encode.py @@ -0,0 +1,313 @@ +import numpy as np +import json +import cv2 + +np.seterr(divide='ignore', invalid='ignore') +import pyclipper +from shapely.geometry import Polygon +import warnings + +warnings.simplefilter('ignore') + + +class DetLabelEncode(object): + + def __init__(self, **kwargs): + pass + + def __call__(self, data): + label = data['label'] + label = json.loads(label) + nBox = len(label) + boxes, txts, txt_tags = [], [], [] + for bno in range(0, nBox): + box = label[bno]['points'] + txt = label[bno]['transcription'] + boxes.append(box) + txts.append(txt) + if txt in ['*', '###']: + txt_tags.append(True) + else: + txt_tags.append(False) + if len(boxes) == 0: + return None + boxes = self.expand_points_num(boxes) + boxes = np.array(boxes, dtype=np.float32) + txt_tags = np.array(txt_tags, dtype=np.bool_) + + data['polys'] = boxes + data['texts'] = txts + data['ignore_tags'] = txt_tags + return data + + def order_points_clockwise(self, 
pts): + rect = np.zeros((4, 2), dtype='float32') + s = pts.sum(axis=1) + rect[0] = pts[np.argmin(s)] + rect[2] = pts[np.argmax(s)] + tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0) + diff = np.diff(np.array(tmp), axis=1) + rect[1] = tmp[np.argmin(diff)] + rect[3] = tmp[np.argmax(diff)] + return rect + + def expand_points_num(self, boxes): + max_points_num = 0 + for box in boxes: + if len(box) > max_points_num: + max_points_num = len(box) + ex_boxes = [] + for box in boxes: + ex_box = box + [box[-1]] * (max_points_num - len(box)) + ex_boxes.append(ex_box) + return ex_boxes + + +class MakeBorderMap(object): + + def __init__(self, + shrink_ratio=0.4, + thresh_min=0.3, + thresh_max=0.7, + **kwargs): + self.shrink_ratio = shrink_ratio + self.thresh_min = thresh_min + self.thresh_max = thresh_max + if 'total_epoch' in kwargs and 'epoch' in kwargs and kwargs[ + 'epoch'] != 'None': + self.shrink_ratio = self.shrink_ratio + 0.2 * kwargs[ + 'epoch'] / float(kwargs['total_epoch']) + + def __call__(self, data): + img = data['image'] + text_polys = data['polys'] + ignore_tags = data['ignore_tags'] + + canvas = np.zeros(img.shape[:2], dtype=np.float32) + mask = np.zeros(img.shape[:2], dtype=np.float32) + + for i in range(len(text_polys)): + if ignore_tags[i]: + continue + self.draw_border_map(text_polys[i], canvas, mask=mask) + canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min + + data['threshold_map'] = canvas + data['threshold_mask'] = mask + return data + + def draw_border_map(self, polygon, canvas, mask): + polygon = np.array(polygon) + assert polygon.ndim == 2 + assert polygon.shape[1] == 2 + + polygon_shape = Polygon(polygon) + if polygon_shape.area <= 0: + return + distance = (polygon_shape.area * (1 - np.power(self.shrink_ratio, 2)) / + polygon_shape.length) + subject = [tuple(l) for l in polygon] + padding = pyclipper.PyclipperOffset() + padding.AddPath(subject, pyclipper.JT_ROUND, + pyclipper.ET_CLOSEDPOLYGON) + + padded_polygon = np.array(padding.Execute(distance)[0]) + cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0) + + xmin = padded_polygon[:, 0].min() + xmax = padded_polygon[:, 0].max() + ymin = padded_polygon[:, 1].min() + ymax = padded_polygon[:, 1].max() + width = xmax - xmin + 1 + height = ymax - ymin + 1 + + polygon[:, 0] = polygon[:, 0] - xmin + polygon[:, 1] = polygon[:, 1] - ymin + + xs = np.broadcast_to( + np.linspace(0, width - 1, num=width).reshape(1, width), + (height, width)) + ys = np.broadcast_to( + np.linspace(0, height - 1, num=height).reshape(height, 1), + (height, width)) + + distance_map = np.zeros((polygon.shape[0], height, width), + dtype=np.float32) + for i in range(polygon.shape[0]): + j = (i + 1) % polygon.shape[0] + absolute_distance = self._distance(xs, ys, polygon[i], polygon[j]) + distance_map[i] = np.clip(absolute_distance / distance, 0, 1) + distance_map = distance_map.min(axis=0) + + xmin_valid = min(max(0, xmin), canvas.shape[1] - 1) + xmax_valid = min(max(0, xmax), canvas.shape[1] - 1) + ymin_valid = min(max(0, ymin), canvas.shape[0] - 1) + ymax_valid = min(max(0, ymax), canvas.shape[0] - 1) + canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax( + 1 - distance_map[ymin_valid - ymin:ymax_valid - ymax + height, + xmin_valid - xmin:xmax_valid - xmax + width, ], + canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1], + ) + + def _distance(self, xs, ys, point_1, point_2): + """ + compute the distance from point to a line + ys: coordinates in the first axis + xs: coordinates in the second 
axis + point_1, point_2: (x, y), the end of the line + """ + height, width = xs.shape[:2] + square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - + point_1[1]) + square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - + point_2[1]) + square_distance = np.square(point_1[0] - + point_2[0]) + np.square(point_1[1] - + point_2[1]) + + cosin = (square_distance - square_distance_1 - square_distance_2) / ( + 2 * np.sqrt(square_distance_1 * square_distance_2)) + square_sin = 1 - np.square(cosin) + square_sin = np.nan_to_num(square_sin) + result = np.sqrt(square_distance_1 * square_distance_2 * square_sin / + square_distance) + + result[cosin < 0] = np.sqrt( + np.fmin(square_distance_1, square_distance_2))[cosin < 0] + # self.extend_line(point_1, point_2, result) + return result + + def extend_line(self, point_1, point_2, result, shrink_ratio): + ex_point_1 = ( + int( + round(point_1[0] + (point_1[0] - point_2[0]) * + (1 + shrink_ratio))), + int( + round(point_1[1] + (point_1[1] - point_2[1]) * + (1 + shrink_ratio))), + ) + cv2.line( + result, + tuple(ex_point_1), + tuple(point_1), + 4096.0, + 1, + lineType=cv2.LINE_AA, + shift=0, + ) + ex_point_2 = ( + int( + round(point_2[0] + (point_2[0] - point_1[0]) * + (1 + shrink_ratio))), + int( + round(point_2[1] + (point_2[1] - point_1[1]) * + (1 + shrink_ratio))), + ) + cv2.line( + result, + tuple(ex_point_2), + tuple(point_2), + 4096.0, + 1, + lineType=cv2.LINE_AA, + shift=0, + ) + return ex_point_1, ex_point_2 + + +class MakeShrinkMap(object): + r""" + Making binary mask from detection data with ICDAR format. + Typically following the process of class `MakeICDARData`. + """ + + def __init__(self, min_text_size=8, shrink_ratio=0.4, **kwargs): + self.min_text_size = min_text_size + self.shrink_ratio = shrink_ratio + if 'total_epoch' in kwargs and 'epoch' in kwargs and kwargs[ + 'epoch'] != 'None': + self.shrink_ratio = self.shrink_ratio + 0.2 * kwargs[ + 'epoch'] / float(kwargs['total_epoch']) + + def __call__(self, data): + image = data['image'] + text_polys = data['polys'] + ignore_tags = data['ignore_tags'] + + h, w = image.shape[:2] + text_polys, ignore_tags = self.validate_polygons( + text_polys, ignore_tags, h, w) + gt = np.zeros((h, w), dtype=np.float32) + mask = np.ones((h, w), dtype=np.float32) + for i in range(len(text_polys)): + polygon = text_polys[i] + height = max(polygon[:, 1]) - min(polygon[:, 1]) + width = max(polygon[:, 0]) - min(polygon[:, 0]) + if ignore_tags[i] or min(height, width) < self.min_text_size: + cv2.fillPoly(mask, + polygon.astype(np.int32)[np.newaxis, :, :], 0) + ignore_tags[i] = True + else: + polygon_shape = Polygon(polygon) + subject = [tuple(l) for l in polygon] + padding = pyclipper.PyclipperOffset() + padding.AddPath(subject, pyclipper.JT_ROUND, + pyclipper.ET_CLOSEDPOLYGON) + shrunk = [] + + # Increase the shrink ratio every time we get multiple polygon returned back + possible_ratios = np.arange(self.shrink_ratio, 1, + self.shrink_ratio) + np.append(possible_ratios, 1) + # print(possible_ratios) + for ratio in possible_ratios: + # print(f"Change shrink ratio to {ratio}") + distance = (polygon_shape.area * (1 - np.power(ratio, 2)) / + polygon_shape.length) + shrunk = padding.Execute(-distance) + if len(shrunk) == 1: + break + + if shrunk == []: + cv2.fillPoly(mask, + polygon.astype(np.int32)[np.newaxis, :, :], 0) + ignore_tags[i] = True + continue + + for each_shrink in shrunk: + shrink = np.array(each_shrink).reshape(-1, 2) + cv2.fillPoly(gt, [shrink.astype(np.int32)], 1) + + data['shrink_map'] = 
gt + data['shrink_mask'] = mask + return data + + def validate_polygons(self, polygons, ignore_tags, h, w): + """ + polygons (numpy.array, required): of shape (num_instances, num_points, 2) + """ + if len(polygons) == 0: + return polygons, ignore_tags + assert len(polygons) == len(ignore_tags) + for polygon in polygons: + polygon[:, 0] = np.clip(polygon[:, 0], 0, w - 1) + polygon[:, 1] = np.clip(polygon[:, 1], 0, h - 1) + + for i in range(len(polygons)): + area = self.polygon_area(polygons[i]) + if abs(area) < 1: + ignore_tags[i] = True + if area > 0: + polygons[i] = polygons[i][::-1, :] + return polygons, ignore_tags + + def polygon_area(self, polygon): + """ + compute polygon area + """ + area = 0 + q = polygon[-1] + for p in polygon: + area += p[0] * q[1] - p[1] * q[0] + q = p + return area / 2.0 diff --git a/opendet/preprocess/db_resize_for_test.py b/opendet/preprocess/db_resize_for_test.py index dbd95f0..e974884 100644 --- a/opendet/preprocess/db_resize_for_test.py +++ b/opendet/preprocess/db_resize_for_test.py @@ -27,6 +27,8 @@ def __init__(self, **kwargs): def __call__(self, data): img = data['image'] + if 'max_sile_len' in data: + self.limit_side_len = data['max_sile_len'] src_h, src_w, _ = img.shape if sum([src_h, src_w]) < 64: img = self.image_padding(img) diff --git a/opendet/preprocess/iaa_augment.py b/opendet/preprocess/iaa_augment.py new file mode 100644 index 0000000..1817faa --- /dev/null +++ b/opendet/preprocess/iaa_augment.py @@ -0,0 +1,230 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
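> Editor's note, before the new augmentation module below: a minimal sketch of how the DB ground-truth transforms added above (`DetLabelEncode`, `MakeShrinkMap`, `MakeBorderMap`) could be chained on a toy sample. The import path and the sample label are assumptions for illustration only; in the actual training pipeline these ops are presumably assembled from the dataset config rather than called by hand.

```python
# Hypothetical standalone use of the DB ground-truth transforms added above;
# the import path and the toy label below are assumptions for illustration.
import json

import numpy as np

from opendet.preprocess.db_label_encode import (DetLabelEncode, MakeBorderMap,
                                                MakeShrinkMap)

data = {
    'image': np.zeros((64, 128, 3), dtype=np.uint8),
    'label': json.dumps([
        {'points': [[10, 10], [60, 10], [60, 30], [10, 30]], 'transcription': 'text'},
        {'points': [[70, 35], [110, 35], [110, 55], [70, 55]], 'transcription': '###'},  # ignored box
    ]),
}

# DetLabelEncode parses the JSON label into polys/texts/ignore_tags,
# MakeShrinkMap rasterizes the shrunk text kernels (shrink_map + shrink_mask),
# MakeBorderMap rasterizes the threshold map around each kept polygon.
for op in (DetLabelEncode(), MakeShrinkMap(shrink_ratio=0.4), MakeBorderMap(shrink_ratio=0.4)):
    data = op(data)

print(data['shrink_map'].shape, data['threshold_map'].shape)  # (64, 128) (64, 128)
```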
+""" +This code is refer from: +https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/iaa_augment.py +""" +import os + +# Prevent automatic updates in Albumentations for stability in augmentation behavior +os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' + +import numpy as np +import albumentations as A +from albumentations.core.transforms_interface import DualTransform +from albumentations.augmentations.geometric import functional as fgeometric +from packaging import version + +ALBU_VERSION = version.parse(A.__version__) +IS_ALBU_NEW_VERSION = ALBU_VERSION >= version.parse('1.4.15') + + +# Custom resize transformation mimicking Imgaug's behavior with scaling +class ImgaugLikeResize(DualTransform): + + def __init__(self, scale_range=(0.5, 3.0), interpolation=1, p=1.0): + super(ImgaugLikeResize, self).__init__(p) + self.scale_range = scale_range + self.interpolation = interpolation + + # Resize the image based on a randomly chosen scale within the scale range + def apply(self, img, scale=1.0, **params): + height, width = img.shape[:2] + new_height = int(height * scale) + new_width = int(width * scale) + + if IS_ALBU_NEW_VERSION: + return fgeometric.resize(img, (new_height, new_width), + interpolation=self.interpolation) + return fgeometric.resize(img, + new_height, + new_width, + interpolation=self.interpolation) + + # Apply the same scaling transformation to keypoints (e.g., polygon points) + def apply_to_keypoints(self, keypoints, scale=1.0, **params): + return np.array([(x * scale, y * scale) + tuple(rest) + for x, y, *rest in keypoints]) + + # Get random scale parameter within the specified range + def get_params(self): + scale = np.random.uniform(self.scale_range[0], self.scale_range[1]) + return {'scale': scale} + + +# Builder class to translate custom augmenter arguments into Albumentations-compatible format +class AugmenterBuilder(object): + + def __init__(self): + # Map common Imgaug transformations to equivalent Albumentations transforms + self.imgaug_to_albu = { + 'Fliplr': 'HorizontalFlip', + 'Flipud': 'VerticalFlip', + 'Affine': 'Affine', + # Additional mappings can be added here if needed + } + + # Recursive method to construct augmentation pipeline based on provided arguments + def build(self, args, root=True): + if args is None or len(args) == 0: + return None + elif isinstance(args, list): + # Build the full augmentation sequence if it's a root-level call + if root: + sequence = [self.build(value, root=False) for value in args] + return A.Compose( + sequence, + keypoint_params=A.KeypointParams(format='xy', + remove_invisible=False), + ) + else: + # Build individual augmenters for nested arguments + augmenter_type = args[0] + augmenter_args = args[1] if len(args) > 1 else {} + augmenter_args_mapped = self.map_arguments( + augmenter_type, augmenter_args) + augmenter_type_mapped = self.imgaug_to_albu.get( + augmenter_type, augmenter_type) + if augmenter_type_mapped == 'Resize': + return ImgaugLikeResize(**augmenter_args_mapped) + else: + cls = getattr(A, augmenter_type_mapped) + return cls( + **{ + k: self.to_tuple_if_list(v) + for k, v in augmenter_args_mapped.items() + }) + elif isinstance(args, dict): + # Process individual transformation specified as dictionary + augmenter_type = args['type'] + augmenter_args = args.get('args', {}) + augmenter_args_mapped = self.map_arguments(augmenter_type, + augmenter_args) + augmenter_type_mapped = self.imgaug_to_albu.get( + augmenter_type, augmenter_type) + if augmenter_type_mapped == 'Resize': + return 
ImgaugLikeResize(**augmenter_args_mapped) + else: + cls = getattr(A, augmenter_type_mapped) + return cls( + **{ + k: self.to_tuple_if_list(v) + for k, v in augmenter_args_mapped.items() + }) + else: + raise RuntimeError('Unknown augmenter arg: ' + str(args)) + + # Map arguments to expected format for each augmenter type + def map_arguments(self, augmenter_type, augmenter_args): + augmenter_args = augmenter_args.copy( + ) # Avoid modifying the original arguments + if augmenter_type == 'Resize': + # Ensure size is a valid 2-element list or tuple + size = augmenter_args.get('size') + if size: + if not isinstance(size, (list, tuple)) or len(size) != 2: + raise ValueError( + f"'size' must be a list or tuple of two numbers, but got {size}" + ) + min_scale, max_scale = size + return { + 'scale_range': (min_scale, max_scale), + 'interpolation': 1, # Linear interpolation + 'p': 1.0, + } + else: + return { + 'scale_range': (1.0, 1.0), + 'interpolation': 1, + 'p': 1.0 + } + elif augmenter_type == 'Affine': + # Map rotation to a tuple and ensure p=1.0 to apply transformation + rotate = augmenter_args.get('rotate', 0) + if isinstance(rotate, list): + rotate = tuple(rotate) + elif isinstance(rotate, (int, float)): + rotate = (float(rotate), float(rotate)) + augmenter_args['rotate'] = rotate + augmenter_args['p'] = 1.0 + return augmenter_args + else: + # For other augmenters, ensure 'p' probability is specified + p = augmenter_args.get('p', 1.0) + augmenter_args['p'] = p + return augmenter_args + + # Convert lists to tuples for Albumentations compatibility + def to_tuple_if_list(self, obj): + if isinstance(obj, list): + return tuple(obj) + return obj + + +# Wrapper class for image and polygon transformations using Imgaug-style augmentation +class IaaAugment: + + def __init__(self, augmenter_args=None, **kwargs): + if augmenter_args is None: + # Default augmenters if none are specified + augmenter_args = [ + { + 'type': 'Fliplr', + 'args': { + 'p': 0.5 + } + }, + { + 'type': 'Affine', + 'args': { + 'rotate': [-10, 10] + } + }, + { + 'type': 'Resize', + 'args': { + 'size': [0.5, 3] + } + }, + ] + self.augmenter = AugmenterBuilder().build(augmenter_args) + + # Apply the augmentations to image and polygon data + def __call__(self, data): + image = data['image'] + + if self.augmenter: + # Flatten polygons to individual keypoints for transformation + keypoints = [] + keypoints_lengths = [] + for poly in data['polys']: + keypoints.extend([tuple(point) for point in poly]) + keypoints_lengths.append(len(poly)) + + # Apply the augmentation pipeline to image and keypoints + transformed = self.augmenter(image=image, keypoints=keypoints) + data['image'] = transformed['image'] + + # Extract transformed keypoints and reconstruct polygon structures + transformed_keypoints = transformed['keypoints'] + + # Reassemble polygons from transformed keypoints + new_polys = [] + idx = 0 + for length in keypoints_lengths: + new_poly = transformed_keypoints[idx:idx + length] + new_polys.append(np.array([kp[:2] for kp in new_poly])) + idx += length + data['polys'] = np.array(new_polys) + return data diff --git a/openrec/losses/__init__.py b/openrec/losses/__init__.py index 7af475d..25fc17d 100644 --- a/openrec/losses/__init__.py +++ b/openrec/losses/__init__.py @@ -1,40 +1,43 @@ import copy - +from importlib import import_module from torch import nn -from .abinet_loss import ABINetLoss -from .ar_loss import ARLoss -from .cdistnet_loss import CDistNetLoss -from .ce_loss import CELoss -from .cppd_loss import CPPDLoss -from .ctc_loss 
import CTCLoss -from .igtr_loss import IGTRLoss -from .lister_loss import LISTERLoss -from .lpv_loss import LPVLoss -from .mgp_loss import MGPLoss -from .parseq_loss import PARSeqLoss -from .robustscanner_loss import RobustScannerLoss -from .smtr_loss import SMTRLoss -from .srn_loss import SRNLoss -from .visionlan_loss import VisionLANLoss -from .cam_loss import CAMLoss -from .seed_loss import SEEDLoss - -support_dict = [ - 'CTCLoss', 'ARLoss', 'CELoss', 'CPPDLoss', 'ABINetLoss', 'CDistNetLoss', - 'VisionLANLoss', 'PARSeqLoss', 'IGTRLoss', 'SMTRLoss', 'LPVLoss', - 'RobustScannerLoss', 'SRNLoss', 'LISTERLoss', 'GTCLoss', 'MGPLoss', - 'CAMLoss', 'SEEDLoss' -] +name_to_module = { + 'ABINetLoss': '.abinet_loss', + 'ARLoss': '.ar_loss', + 'CDistNetLoss': '.cdistnet_loss', + 'CELoss': '.ce_loss', + 'CPPDLoss': '.cppd_loss', + 'CTCLoss': '.ctc_loss', + 'IGTRLoss': '.igtr_loss', + 'LISTERLoss': '.lister_loss', + 'LPVLoss': '.lpv_loss', + 'MGPLoss': '.mgp_loss', + 'PARSeqLoss': '.parseq_loss', + 'RobustScannerLoss': '.robustscanner_loss', + 'SEEDLoss': '.seed_loss', + 'SMTRLoss': '.smtr_loss', + 'SRNLoss': '.srn_loss', + 'VisionLANLoss': '.visionlan_loss', + 'CAMLoss': '.cam_loss', +} def build_loss(config): config = copy.deepcopy(config) module_name = config.pop('name') - assert module_name in support_dict, Exception( - 'loss only support {}'.format(support_dict)) - module_class = eval(module_name)(**config) - return module_class + + if module_name in globals(): + module_class = globals()[module_name] + else: + assert module_name in name_to_module, Exception( + '{} is not supported. The losses in {} are supportes'.format( + module_name, list(name_to_module.keys()))) + module_path = name_to_module[module_name] + module = import_module(module_path, package=__package__) + module_class = getattr(module, module_name) + + return module_class(**config) class GTCLoss(nn.Module): @@ -46,7 +49,10 @@ def __init__(self, zero_infinity=True, **kwargs): super(GTCLoss, self).__init__() - self.ctc_loss = CTCLoss(zero_infinity=zero_infinity) + # 动态构建CTCLoss + ctc_config = {'name': 'CTCLoss', 'zero_infinity': zero_infinity} + self.ctc_loss = build_loss(ctc_config) + # 构建GTC损失 self.gtc_loss = build_loss(gtc_loss) self.gtc_weight = gtc_weight self.ctc_weight = ctc_weight diff --git a/openrec/modeling/decoders/__init__.py b/openrec/modeling/decoders/__init__.py index f570fd0..91dc019 100644 --- a/openrec/modeling/decoders/__init__.py +++ b/openrec/modeling/decoders/__init__.py @@ -1,48 +1,52 @@ import torch.nn as nn +from importlib import import_module __all__ = ['build_decoder'] +class_to_module = { + 'ABINetDecoder': '.abinet_decoder', + 'ASTERDecoder': '.aster_decoder', + 'CDistNetDecoder': '.cdistnet_decoder', + 'CPPDDecoder': '.cppd_decoder', + 'RCTCDecoder': '.rctc_decoder', + 'CTCDecoder': '.ctc_decoder', + 'DANDecoder': '.dan_decoder', + 'IGTRDecoder': '.igtr_decoder', + 'LISTERDecoder': '.lister_decoder', + 'LPVDecoder': '.lpv_decoder', + 'MGPDecoder': '.mgp_decoder', + 'NRTRDecoder': '.nrtr_decoder', + 'PARSeqDecoder': '.parseq_decoder', + 'RobustScannerDecoder': '.robustscanner_decoder', + 'SARDecoder': '.sar_decoder', + 'SMTRDecoder': '.smtr_decoder', + 'SMTRDecoderNumAttn': '.smtr_decoder_nattn', + 'SRNDecoder': '.srn_decoder', + 'VisionLANDecoder': '.visionlan_decoder', + 'MATRNDecoder': '.matrn_decoder', + 'CAMDecoder': '.cam_decoder', + 'OTEDecoder': '.ote_decoder', + 'BUSDecoder': '.bus_decoder', + 'DptrParseq': '.dptr_parseq_clip_b_decoder', +} -def build_decoder(config): - # rec decoder - from 
.abinet_decoder import ABINetDecoder - from .aster_decoder import ASTERDecoder - from .cdistnet_decoder import CDistNetDecoder - from .cppd_decoder import CPPDDecoder - from .rctc_decoder import RCTCDecoder - from .ctc_decoder import CTCDecoder - from .dan_decoder import DANDecoder - from .igtr_decoder import IGTRDecoder - from .lister_decoder import LISTERDecoder - from .lpv_decoder import LPVDecoder - from .mgp_decoder import MGPDecoder - from .nrtr_decoder import NRTRDecoder - from .parseq_decoder import PARSeqDecoder - from .robustscanner_decoder import RobustScannerDecoder - from .sar_decoder import SARDecoder - from .smtr_decoder import SMTRDecoder - from .smtr_decoder_nattn import SMTRDecoderNumAttn - from .srn_decoder import SRNDecoder - from .visionlan_decoder import VisionLANDecoder - from .matrn_decoder import MATRNDecoder - from .cam_decoder import CAMDecoder - from .ote_decoder import OTEDecoder - from .bus_decoder import BUSDecoder - support_dict = [ - 'CTCDecoder', 'NRTRDecoder', 'CPPDDecoder', 'ABINetDecoder', - 'CDistNetDecoder', 'VisionLANDecoder', 'PARSeqDecoder', 'IGTRDecoder', - 'SMTRDecoder', 'LPVDecoder', 'SARDecoder', 'RobustScannerDecoder', - 'SRNDecoder', 'ASTERDecoder', 'RCTCDecoder', 'LISTERDecoder', - 'GTCDecoder', 'SMTRDecoderNumAttn', 'MATRNDecoder', 'MGPDecoder', - 'DANDecoder', 'CAMDecoder', 'OTEDecoder', 'BUSDecoder' - ] +def build_decoder(config): module_name = config.pop('name') - assert module_name in support_dict, Exception( - 'decoder only support {}'.format(support_dict)) - module_class = eval(module_name)(**config) - return module_class + + # Check if the class is defined in current module (e.g., GTCDecoder) + if module_name in globals(): + module_class = globals()[module_name] + else: + if module_name not in class_to_module: + raise ValueError(f'Unsupported decoder: {module_name}') + module_str = class_to_module[module_name] + # Dynamically import the module and get the class + module = import_module(module_str, package=__package__) + module_class = getattr(module, module_name) + + return module_class(**config) class GTCDecoder(nn.Module): diff --git a/openrec/modeling/decoders/ctc_decoder.py b/openrec/modeling/decoders/ctc_decoder.py index aca917b..aac753a 100644 --- a/openrec/modeling/decoders/ctc_decoder.py +++ b/openrec/modeling/decoders/ctc_decoder.py @@ -132,7 +132,7 @@ def forward(self, x): z = self.conv2(z) # SVTR global block B, C, H, W = z.shape - z = z.flatten(2).transpose(1, 2) + z = z.flatten(2).transpose(1, 2).contiguous() for blk in self.svtr_block: z = blk(z) z = self.norm(z) diff --git a/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py b/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py new file mode 100644 index 0000000..566b4a2 --- /dev/null +++ b/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py @@ -0,0 +1,1398 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +from itertools import permutations +from collections import OrderedDict +import hashlib +import os +import gzip +import html +import urllib +import warnings +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.nn.modules import transformer +from typing import Any, Optional, Tuple, List, Union +from pkg_resources import packaging +from PIL import Image +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from tqdm import tqdm +from functools import lru_cache + +import ftfy +import regex as re + +try: + from torchvision.transforms import InterpolationMode + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). 
+ """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} + self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text + + +if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): + warnings.warn("PyTorch version 1.7.1 or higher is recommended") + + +__all__ = ["available_models", "load", "tokenize"] +_tokenizer = SimpleTokenizer() + +_MODELS = { + "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", + "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", + "RN50x64": 
"https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", + "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", + "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", + "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", +} + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +def build_model(state_dict: dict): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * grid_size + else: + counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_resolution = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) + + model = CLIP( + embed_dim, + image_resolution, vision_layers, vision_width, vision_patch_size, + context_length, vocab_size, transformer_width, transformer_heads, transformer_layers + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + if key in state_dict: + del state_dict[key] + + convert_weights(model) + model.load_state_dict(state_dict) + return model.eval() + + +def _download(url: str, root: str): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + expected_sha256 = url.split("/")[-2] + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if 
os.path.isfile(download_target): + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: + raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def _convert_image_to_rgb(image): + return image.convert("RGB") + + +def _transform(n_px): + return Compose([ + Resize(n_px, interpolation=BICUBIC), + CenterCrop(n_px), + _convert_image_to_rgb, + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + + +def available_models() -> List[str]: + """Returns the names of available CLIP models""" + return list(_MODELS.keys()) + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], key=x, value=x, + 
embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + return x.squeeze(0) + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.relu3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + def stem(x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = 
LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class VisionTransformer(nn.Module): + def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x) + if self.proj is not None: + x = x @ self.proj + + return x + + +class CLIP(nn.Module): + def __init__(self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int + ): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width + ) + else: + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask() + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = 
nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # take features from the eot embedding (eot_token is the highest number in each sequence) + output = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + output = torch.cat([output.unsqueeze(1), x], dim=1) + + return output + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm(dim=1, keepdim=True) + text_features = text_features / text_features.norm(dim=1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +class FMU(nn.Module): + """A Transformer decoder layer supporting two-stream attention (XLNet) + This implements a pre-LN decoder, as opposed to the post-LN default in PyTorch.""" + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='gelu', + layer_norm_eps=1e-5): + super().__init__() + self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True) + # Implementation of 
Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm = nn.LayerNorm(d_model, eps=layer_norm_eps) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = transformer._get_activation_fn(activation) + + def __setstate__(self, state): + if 'activation' not in state: + state['activation'] = F.gelu + super().__setstate__(state) + + def forward(self, query: Tensor, memory: Tensor): + """Forward pass for a single stream (i.e. content or query) + tgt_norm is just a LayerNorm'd tgt. Added as a separate parameter for efficiency. + Both tgt_kv and memory are expected to be LayerNorm'd too. + memory is LayerNorm'd by ViT. + """ + query1, ca_weights = self.cross_attn(query, memory, memory) + query = query + self.dropout1(query1) + + query2 = self.linear2(self.dropout2(self.activation(self.linear1(self.norm(query))))) + query = query + self.dropout3(query2) + + return query + + +class DecoderLayer(nn.Module): + """A Transformer decoder layer supporting two-stream attention (XLNet) This + implements a pre-LN decoder, as opposed to the post-LN default in + PyTorch.""" + + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation='gelu', + layer_norm_eps=1e-5, + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, + nhead, + dropout=dropout, + batch_first=True) + self.cross_attn = nn.MultiheadAttention(d_model, + nhead, + dropout=dropout, + batch_first=True) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm_q = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm_c = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = transformer._get_activation_fn(activation) + + def __setstate__(self, state): + if 'activation' not in state: + state['activation'] = F.gelu + super().__setstate__(state) + + def forward_stream( + self, + tgt: Tensor, + tgt_norm: Tensor, + tgt_kv: Tensor, + memory: Tensor, + tgt_mask: Optional[Tensor], + tgt_key_padding_mask: Optional[Tensor], + ): + """Forward pass for a single stream (i.e. content or query) tgt_norm is + just a LayerNorm'd tgt. + + Added as a separate parameter for efficiency. Both tgt_kv and memory + are expected to be LayerNorm'd too. memory is LayerNorm'd by ViT. 
+ """ + tgt2, sa_weights = self.self_attn( + tgt_norm, + tgt_kv, + tgt_kv, + attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask) + + tgt = tgt + self.dropout1(tgt2) + + tgt2, ca_weights = self.cross_attn(self.norm1(tgt), memory, memory) + self.attn_map = ca_weights + tgt = tgt + self.dropout2(tgt2) + + tgt2 = self.linear2( + self.dropout(self.activation(self.linear1(self.norm2(tgt))))) + tgt = tgt + self.dropout3(tgt2) + return tgt, sa_weights, ca_weights + + def forward( + self, + query, + content, + memory, + query_mask: Optional[Tensor] = None, + content_mask: Optional[Tensor] = None, + content_key_padding_mask: Optional[Tensor] = None, + update_content: bool = True, + ): + query_norm = self.norm_q(query) + content_norm = self.norm_c(content) + query = self.forward_stream(query, query_norm, content_norm, memory, + query_mask, content_key_padding_mask)[0] + if update_content: + content = self.forward_stream(content, content_norm, content_norm, + memory, content_mask, + content_key_padding_mask)[0] + return query, content + + +class Decoder(nn.Module): + __constants__ = ['norm'] + + def __init__(self, decoder_layer, num_layers, norm): + super().__init__() + self.layers = transformer._get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + query, + content, + memory, + query_mask: Optional[Tensor] = None, + content_mask: Optional[Tensor] = None, + content_key_padding_mask: Optional[Tensor] = None, + ): + for i, mod in enumerate(self.layers): + last = i == len(self.layers) - 1 + query, content = mod( + query, + content, + memory, + query_mask, + content_mask, + content_key_padding_mask, + update_content=not last, + ) + query = self.norm(query) + return query + + +class TokenEmbedding(nn.Module): + + def __init__(self, charset_size: int, embed_dim: int): + super().__init__() + self.embedding = nn.Embedding(charset_size, embed_dim) + self.embed_dim = embed_dim + + def forward(self, tokens: torch.Tensor): + return math.sqrt(self.embed_dim) * self.embedding(tokens) + + +def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + + device : Union[str, torch.device] + The device to put the loaded model + + jit : bool + Whether to load the optimized JIT model or more hackable non-JIT model (default). + + download_root: str + path to download the model files; by default, it uses "~/.cache/clip" + + Returns + ------- + model : torch.nn.Module + The CLIP model + + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if name in _MODELS: + model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {available_models()}") + + with open(model_path, 'rb') as opened_file: + try: + # loading JIT archive + model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") + jit = False + state_dict = torch.load(opened_file, map_location="cpu") + + if not jit: + model = build_model(state_dict or model.state_dict()).to(device) + if str(device) == "cpu": + model.float() + return model, _transform(model.visual.input_resolution) + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + + return model, _transform(model.input_resolution.item()) + +def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + + context_length : int + The context length to use; all CLIP models use 77 as the context length + + truncate: bool + Whether to truncate the text in case its encoding is longer than the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. + We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 
+ """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder["<|startoftext|>"] + eot_token = _tokenizer.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + else: + result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + +class DptrParseq(nn.Module): + + def __init__(self, + in_channels, + out_channels, + max_label_length=25, + embed_dim=512, + dec_num_heads=8, + dec_mlp_ratio=4, + dec_depth=6, + perm_num=6, + perm_forward=True, + perm_mirrored=True, + decode_ar=True, + refine_iters=1, + dropout=0.1, + is_pretrain=True, + ORP_path=None, + **kwargs: Any) -> None: + super().__init__() + self.pad_id = out_channels - 1 + self.eos_id = 0 + self.bos_id = out_channels - 2 + self.max_label_length = max_label_length + self.decode_ar = decode_ar + self.refine_iters = refine_iters + self.is_pretrain = is_pretrain + if not is_pretrain: + self.token_query = nn.Parameter(torch.Tensor(1, 26, embed_dim)) + self.fmu = FMU(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout) + + decoder_layer = DecoderLayer(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout) + self.decoder = Decoder(decoder_layer, + num_layers=dec_depth, + norm=nn.LayerNorm(embed_dim)) + + # Perm/attn mask stuff + self.rng = np.random.default_rng() + self.max_gen_perms = perm_num // 2 if perm_mirrored else perm_num + self.perm_forward = perm_forward + self.perm_mirrored = perm_mirrored + + # We don't predict nor + self.head = nn.Linear(embed_dim, out_channels - 2) + self.text_embed = TokenEmbedding(out_channels, embed_dim) + + # +1 for + self.pos_queries = nn.Parameter( + torch.Tensor(1, max_label_length + 1, embed_dim)) + self.dropout = nn.Dropout(p=dropout) + # Encoder has its own init. 
+ self.apply(self._init_weights) + nn.init.trunc_normal_(self.pos_queries, std=0.02) + + if is_pretrain: + self.clip_encoder, preprocess = load("ViT-B/16") + for p in self.clip_encoder.parameters(): + p.requires_grad = False + if ORP_path is None: + background_image_folder_path = 'background_mages_folder/path' + self.background_features = self.get_noise(background_image_folder_path, preprocess) + torch.save(self.background_features, 'save/noise/to/ORP_path') + else: + self.background_features = torch.load(ORP_path, map_location='cpu') + + def _init_weights(self, module: nn.Module): + """Initialize the weights using the typical initialization schemes used + in SOTA models.""" + + if isinstance(module, nn.Linear): + nn.init.trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.trunc_normal_(module.weight, std=0.02) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_(module.weight, + mode='fan_out', + nonlinearity='relu') + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) + + @torch.jit.ignore + def no_weight_decay(self): + param_names = {'text_embed.embedding.weight', 'pos_queries'} + return param_names + + def get_noise(self, background_image_path, preprocess): + image_paths = [os.path.join(background_image_path, filename) for filename in os.listdir(image_folder_path) if + filename.endswith(('.png', '.jpg', '.jpeg'))] + features = [] + for image_path in image_paths: + image = Image.open(image_path) + input = preprocess(image).unsqueeze(0).to(self._device) + with torch.no_grad(): + feature = self.clip_encoder.encode_image(input) + features.append(feature) + image.close() + return torch.cat(features).cpu().numpy() + + def clip_encode(self, labels): + text_inputs = torch.cat([tokenize(f"a photo of a '{c}'") for c in labels]).to(self._device) + + return self.clip_encoder.encode_text(text_inputs) + + def decode( + self, + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: Optional[Tensor] = None, + tgt_padding_mask: Optional[Tensor] = None, + tgt_query: Optional[Tensor] = None, + tgt_query_mask: Optional[Tensor] = None, + pos_query: torch.Tensor = None, + ): + N, L = tgt.shape + # stands for the null context. We only supply position information for characters after . 
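+        # Illustrative note: for tgt = [[BOS], c1, c2] (L = 3), null_ctx below is the plain
+        # embedding of [BOS] with no positional term, while c1 and c2 receive pos_query[:, :L-1]
+        # added to their character embeddings before the two parts are concatenated.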
+ null_ctx = self.text_embed(tgt[:, :1]) + + if tgt_query is None: + tgt_query = pos_query[:, :L] + tgt_emb = pos_query[:, :L - 1] + self.text_embed(tgt[:, 1:]) + tgt_emb = self.dropout(torch.cat([null_ctx, tgt_emb], dim=1)) + + tgt_query = self.dropout(tgt_query) + return self.decoder(tgt_query, tgt_emb, memory, tgt_query_mask, + tgt_mask, tgt_padding_mask) + + def forward(self, memory, data=None, pos_query=None): + # print(memory.shape, data[0].shape) + if self.training: + if self.is_pretrain: + return self.training_step(None, pos_query, data[0], memory) + return self.training_step(memory, pos_query, data[0], None) + else: + if self.is_pretrain: + return self.forward_test(None, memory, pos_query) + return self.forward_test(memory, None, pos_query) + + def forward_test(self, + memory: Tensor, clip_ids, + pos_query: Tensor = None, + max_length: Optional[int] = None) -> Tensor: + testing = max_length is None + max_length = (self.max_label_length if max_length is None else min( + max_length, self.max_label_length)) + + if self.is_pretrain: + memory = self.clip_encoder.encode_text(clip_ids) + else: + bs = memory.shape[0] + token_query = self.token_query.expand(bs, -1, -1) + memory = self.fmu(token_query, memory) + _device = memory.get_device() + bs = memory.shape[0] + # +1 for at end of sequence. + num_steps = max_length + 1 + # memory = self.encode(images) + + # Query positions up to `num_steps` + if pos_query is None: + pos_queries = self.pos_queries[:, :num_steps].expand(bs, -1, -1) + else: + pos_queries = pos_query + + # Special case for the forward permutation. Faster than using `generate_attn_masks()` + tgt_mask = query_mask = torch.triu( + torch.full((num_steps, num_steps), float('-inf'), device=_device), + 1) + self.attn_maps = [] + if self.decode_ar: + tgt_in = torch.full((bs, num_steps), + self.pad_id, + dtype=torch.long, + device=_device) + tgt_in[:, 0] = self.bos_id + + logits = [] + for i in range(num_steps): + j = i + 1 # next token index + # Efficient decoding: + # Input the context up to the ith token. We use only one query (at position = i) at a time. + # This works because of the lookahead masking effect of the canonical (forward) AR context. + # Past tokens have no access to future tokens, hence are fixed once computed. + tgt_out = self.decode( + tgt_in[:, :j], + memory, + tgt_mask[:j, :j], + tgt_query=pos_queries[:, i:j], + tgt_query_mask=query_mask[i:j, :j], + pos_query=pos_queries, + ) + self.attn_maps.append(self.decoder.layers[-1].attn_map) + # the next token probability is in the output's ith token position + p_i = self.head(tgt_out) + logits.append(p_i) + if j < num_steps: + # greedy decode. add the next token index to the target input + tgt_in[:, j] = p_i.squeeze().argmax(-1) + # Efficient batch decoding: If all output words have at least one EOS token, end decoding. + if testing and (tgt_in == self.eos_id).any(dim=-1).all(): + break + + logits = torch.cat(logits, dim=1) + else: + # No prior context, so input is just . We query all positions. + tgt_in = torch.full((bs, 1), + self.bos_id, + dtype=torch.long, + device=_device) + tgt_out = self.decode(tgt_in, + memory, + tgt_query=pos_queries, + pos_query=pos_queries) + logits = self.head(tgt_out) + + if self.refine_iters: + # For iterative refinement, we always use a 'cloze' mask. + # We can derive it from the AR forward mask by unmasking the token context to the right. 
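+            # For example, with num_steps = 3 the AR query mask is
+            #     [[0, -inf, -inf],
+            #      [0,    0, -inf],
+            #      [0,    0,    0]];
+            # zeroing everything two or more columns right of the diagonal leaves only the
+            # j == i + 1 entries masked, so each query sees every context token except the
+            # one corresponding to its own output position (a cloze-style mask).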
+ query_mask[torch.triu( + torch.ones(num_steps, + num_steps, + dtype=torch.bool, + device=_device), 2)] = 0 + bos = torch.full((bs, 1), + self.bos_id, + dtype=torch.long, + device=_device) + for i in range(self.refine_iters): + # Prior context is the previous output. + tgt_in = torch.cat([bos, logits[:, :-1].argmax(-1)], dim=1) + tgt_len = tgt_in.shape[1] + tgt_padding_mask = (tgt_in == self.eos_id).int().cumsum( + -1) > 0 # mask tokens beyond the first EOS token. + tgt_out = self.decode( + tgt_in, + memory, + tgt_mask[:tgt_len, :tgt_len], + tgt_padding_mask, + tgt_query=pos_queries, + tgt_query_mask=query_mask[:, :tgt_len], + pos_query=pos_queries, + ) + logits = self.head(tgt_out) + + return F.softmax(logits, -1) + + def gen_tgt_perms(self, tgt, _device): + """Generate shared permutations for the whole batch. + + This works because the same attention mask can be used for the shorter + sequences because of the padding mask. + """ + # We don't permute the position of BOS, we permute EOS separately + max_num_chars = tgt.shape[1] - 2 + # Special handling for 1-character sequences + if max_num_chars == 1: + return torch.arange(3, device=_device).unsqueeze(0) + perms = [torch.arange(max_num_chars, device=_device) + ] if self.perm_forward else [] + # Additional permutations if needed + max_perms = math.factorial(max_num_chars) + if self.perm_mirrored: + max_perms //= 2 + num_gen_perms = min(self.max_gen_perms, max_perms) + # For 4-char sequences and shorter, we generate all permutations and sample from the pool to avoid collisions + # Note that this code path might NEVER get executed since the labels in a mini-batch typically exceed 4 chars. + if max_num_chars < 5: + # Pool of permutations to sample from. We only need the first half (if complementary option is selected) + # Special handling for max_num_chars == 4 which correctly divides the pool into the flipped halves + if max_num_chars == 4 and self.perm_mirrored: + selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21] + else: + selector = list(range(max_perms)) + perm_pool = torch.as_tensor(list( + permutations(range(max_num_chars), max_num_chars)), + device=_device)[selector] + # If the forward permutation is always selected, no need to add it to the pool for sampling + if self.perm_forward: + perm_pool = perm_pool[1:] + perms = torch.stack(perms) + if len(perm_pool): + i = self.rng.choice(len(perm_pool), + size=num_gen_perms - len(perms), + replace=False) + perms = torch.cat([perms, perm_pool[i]]) + else: + perms.extend([ + torch.randperm(max_num_chars, device=_device) + for _ in range(num_gen_perms - len(perms)) + ]) + perms = torch.stack(perms) + if self.perm_mirrored: + # Add complementary pairs + comp = perms.flip(-1) + # Stack in such a way that the pairs are next to each other. + perms = torch.stack([perms, comp + ]).transpose(0, 1).reshape(-1, max_num_chars) + # NOTE: + # The only meaningful way of permuting the EOS position is by moving it one character position at a time. + # However, since the number of permutations = T! and number of EOS positions = T + 1, the number of possible EOS + # positions will always be much less than the number of permutations (unless a low perm_num is set). + # Thus, it would be simpler to just train EOS using the full and null contexts rather than trying to evenly + # distribute it across the chosen number of permutations. 
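+        # For example, with max_num_chars = 3, perm_forward=True and perm_mirrored=True, the
+        # first pair of permutations is [0, 1, 2] and its mirror [2, 1, 0]; after the BOS/EOS
+        # indices are attached below they become [0, 1, 2, 3, 4] and [0, 4, 3, 2, 1]
+        # (the second row is rewritten as the strict reverse ordering).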
+ # Add position indices of BOS and EOS + bos_idx = perms.new_zeros((len(perms), 1)) + eos_idx = perms.new_full((len(perms), 1), max_num_chars + 1) + perms = torch.cat([bos_idx, perms + 1, eos_idx], dim=1) + # Special handling for the reverse direction. This does two things: + # 1. Reverse context for the characters + # 2. Null context for [EOS] (required for learning to predict [EOS] in NAR mode) + if len(perms) > 1: + perms[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1, + device=_device) + return perms + + def generate_attn_masks(self, perm, _device): + """Generate attention masks given a sequence permutation (includes pos. + for bos and eos tokens) + + :param perm: the permutation sequence. i = 0 is always the BOS + :return: lookahead attention masks + """ + sz = perm.shape[0] + mask = torch.zeros((sz, sz), device=_device) + for i in range(sz): + query_idx = perm[i] + masked_keys = perm[i + 1:] + mask[query_idx, masked_keys] = float('-inf') + content_mask = mask[:-1, :-1].clone() + mask[torch.eye(sz, dtype=torch.bool, + device=_device)] = float('-inf') # mask "self" + query_mask = mask[1:, :-1] + return content_mask, query_mask + + def training_step(self, memory, pos_query, tgt_ids, clip_ids): + bs = tgt_ids.shape[0] + if self.is_pretrain: + memory = self.clip_encoder.encode_text(clip_ids) + n = memory.shape[1] + B, N, D = self.background_features.shape + random_B = np.random.choice(B, bs, replace=False) + random_N = np.random.choice(N, n, replace=False) + noise = self.background_features[random_B][:, random_N] + noise = torch.from_numpy(noise).to(memory.get_device()) + memory = memory + noise * 1e-1 + else: + token_query = self.token_query.expand(bs, -1, -1) + memory = self.fmu(token_query, memory) + + if pos_query is None: + pos_query = self.pos_queries.expand(bs, -1, -1) + # Prepare the target sequences (input and output) + tgt_perms = self.gen_tgt_perms(tgt_ids, memory.get_device()) + tgt_in = tgt_ids[:, :-1] + tgt_out = tgt_ids[:, 1:] + + # The [EOS] token is not depended upon by any other token in any permutation ordering + tgt_padding_mask = (tgt_in == self.pad_id) | (tgt_in == self.eos_id) + + loss = 0 + loss_numel = 0 + n = (tgt_out != self.pad_id).sum().item() + for i, perm in enumerate(tgt_perms): + tgt_mask, query_mask = self.generate_attn_masks( + perm, memory.get_device()) + # print("tgt_in:", tgt_in, "tgt_out:", tgt_out, "tgt_padding_mask:", tgt_padding_mask) + # print('tgt_mask:', tgt_mask) + # print('query_mask:', query_mask) + # print(tgt_in.shape, memory.shape, tgt_mask.shape, tgt_padding_mask.shape, query_mask.shape, pos_query.shape) + out = self.decode( + tgt_in, + memory, + tgt_mask, + tgt_padding_mask, + tgt_query_mask=query_mask, + pos_query=pos_query, + ) + # print('out:', out) + logits = self.head(out) + # print('logits:', logits) + if i == 0: + final_out = logits + loss += n * F.cross_entropy(logits.flatten(end_dim=1), + tgt_out.flatten(), + ignore_index=self.pad_id) + loss_numel += n + # After the second iteration (i.e. 
done with canonical and reverse orderings), + # remove the [EOS] tokens for the succeeding perms + if i == 1: + tgt_out = torch.where(tgt_out == self.eos_id, self.pad_id, + tgt_out) + n = (tgt_out != self.pad_id).sum().item() + loss /= loss_numel + + # self.log('loss', loss) + return [loss, final_out] diff --git a/openrec/modeling/decoders/igtr_decoder.py b/openrec/modeling/decoders/igtr_decoder.py index 491e84b..8586b73 100644 --- a/openrec/modeling/decoders/igtr_decoder.py +++ b/openrec/modeling/decoders/igtr_decoder.py @@ -245,6 +245,42 @@ def forward(self, question_f, prompt_f, visual_f, mask=None): class IGTRDecoder(nn.Module): + """ + IGTRDecoder is a neural network module designed for decoding tasks in OCR (Optical Character Recognition) systems. + It utilizes a combination of embedding layers, multi-head attention layers, and linear layers to process input sequences + and generate output sequences. + + Args: + in_channels (int): Number of input channels. + dim (int): Dimension of the model. + out_channels (int): Number of output channels. + num_layer (int, optional): Number of layers in the decoder. Default is 2. + drop_path_rate (float, optional): Drop path rate for stochastic depth. Default is 0.1. + max_len (int, optional): Maximum length of the sequence. Default is 25. + vis_seq (int, optional): Length of the visual sequence. Default is 50. + ch (bool, optional): Flag for character embedding. Default is False. + ar (bool, optional): Flag for autoregressive decoding. Default is False. + refine_iter (int, optional): Number of refinement iterations. Default is 0. + quesall (bool, optional): Flag to use all questions. Default is True. + next_pred (bool, optional): Flag for next prediction. Default is False. + ds (bool, optional): Flag for downsampling. Default is False. + pos2d (bool, optional): Flag for 2D positional embedding. Default is False. + check_search (bool, optional): Flag for checking search. Default is False. + max_size (list, optional): Maximum size for 2D positional embedding. Default is [8, 32]. + **kwargs: Additional keyword arguments. + + Methods: + _init_weights(m): Initializes the weights of the module. + no_weight_decay(): Returns the parameters that should not have weight decay. + question_encoder(targets, train_i): Encodes the questions based on the targets and training index. + forward(x, data=None): Forward pass of the decoder. Calls either forward_train or forward_test based on the mode. + forward_test(x): Forward pass during testing. + forward_train(x, targets=None): Forward pass during training. + + Returns: + Depending on the mode (training or testing), the forward method returns either the loss and logits (during training) + or the predicted indices and probabilities (during testing). + """ def __init__(self, in_channels, @@ -426,6 +462,33 @@ def forward(self, x, data=None): return self.forward_test(x) def forward_test(self, x): + """ + Perform the forward pass for the test phase. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width). + + Returns: + torch.Tensor or List[torch.Tensor]: The output logits or a list containing predicted indices and probabilities. + + The function handles different modes of operation based on the attributes: + - `self.ds`: Determines if positional embedding is added to the input tensor. + - `self.pos2d`: Determines if the positional embedding is 2D. + - `self.ar`: Determines if autoregressive decoding is used. + - `self.check_search`: Determines if beam search is used. 
+ - `self.next_pred`: Determines if next token prediction is used. + - `self.refine_iter`: Number of refinement iterations for the predictions. + + The function performs the following steps: + 1. Adds positional embeddings to the input tensor if required. + 2. Initializes the BOS (beginning of sequence) prompt. + 3. Depending on the mode, performs decoding using different strategies: + - Beam search decoding. + - Autoregressive decoding. + - Next token prediction. + 4. If refinement iterations are specified, refines the predictions. + 5. Returns the final logits or the predicted indices and probabilities. + """ if not self.ds: visual_f = x + self.vis_pos_embed elif self.pos2d: @@ -477,7 +540,6 @@ def forward_test(self, x): if j < self.max_len - 1: # greedy decode. add the next token index to the target input tgt_in[:, j] = p_i.squeeze().argmax(-1) - # Efficient batch decoding: If all output words have at least one EOS token, end decoding. if (tgt_in == self.eos).any(dim=-1).all(): break @@ -652,6 +714,34 @@ def forward_test(self, x): return F.softmax(logits, -1) def forward_train(self, x, targets=None): + """ + Forward pass for training the model. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, ...). + targets (list, optional): List of target tensors. The list should contain: + - targets[1]: Tensor of shape (batch_size, ...), prompt position indices. + - targets[2]: Tensor of shape (batch_size, ...), prompt character indices. + - targets[3]: Tensor of shape (batch_size, ...), question position indices. + - targets[4]: Tensor of shape (batch_size, ...), question 1 answers. + - targets[5]: Tensor of shape (batch_size, ...), question 2 character indices. + - targets[6]: Tensor of shape (batch_size, ...), question 2 answers. + - targets[7]: Tensor of shape (batch_size, ..., 2), question 3 character indices and answers. + - targets[8]: Tensor of shape (batch_size, ...), question 4 character numbers. + - targets[9]: Tensor of shape (batch_size, ...), question lengths. + - targets[10]: Tensor of shape (batch_size, ...), prompt lengths. + - targets[11]: Tensor of shape (batch_size, ...), question 4 answers. + + Returns: + list: A list containing: + - loss (dict): Dictionary containing the total loss and individual losses for each question. + - 'loss': Total loss. + - 'loss1': Loss for question 1. + - 'loss2': Loss for question 2. + - 'loss3': Loss for question 3. + - 'loss4': Loss for question 4. + - logits (torch.Tensor): Logits for question 1 predictions. 
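+
+        Example (illustrative; the constructor values below are placeholders):
+            decoder = IGTRDecoder(in_channels=256, dim=256, out_channels=97)
+            loss_dict, logits = decoder.forward_train(x, targets)  # targets laid out as above
+            total_loss = loss_dict['loss']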
+ """ bs = x.shape[0] answer_token = torch.tile(self.answer_query, (bs, 1, 1)) diff --git a/openrec/modeling/decoders/smtr_decoder.py b/openrec/modeling/decoders/smtr_decoder.py index 01bd7ee..5c7cc0a 100644 --- a/openrec/modeling/decoders/smtr_decoder.py +++ b/openrec/modeling/decoders/smtr_decoder.py @@ -240,26 +240,26 @@ def forward_test_bi(self, x): next_id = torch.full([1, self.sub_str_len], self.bos_next, dtype=torch.long, - device=x.get_device()) + device=x.device) pre_id = torch.full([1, self.sub_str_len], self.bos_pre, dtype=torch.long, - device=x.get_device()) + device=x.device) # prompt_next_bos = self.char_embed(prompt_id) - # pred_prob_list = torch.full([bs, self.sub_str_len], self.ignore_index, dtype=torch.long, device=x.get_device()) + # pred_prob_list = torch.full([bs, self.sub_str_len], self.ignore_index, dtype=torch.long, device=x.device) next_pred_id_list = torch.full([1, self.max_len], self.ignore_index, dtype=torch.long, - device=x.get_device()) + device=x.device) pre_pred_id_list = torch.full([1, self.max_len], self.ignore_index, dtype=torch.long, - device=x.get_device()) + device=x.device) next_logits_all = [] pre_logits_all = [] mask_pad = torch.zeros([bs, 1], dtype=torch.float32, - device=x.get_device()) + device=x.device) for j in range(0, min(70, self.max_len - 1)): prompt_char_next = torch.concat([ @@ -317,7 +317,7 @@ def forward_test_bi(self, x): ques_next = self.next_token.tile([1, 1, 1, 1]).squeeze(1) mask_pad = torch.zeros([1, 1], dtype=torch.float32, - device=x.get_device()) + device=x.device) for j in range(0, min(70, self.max_len - 1)): prompt_next = torch.concat([ @@ -371,17 +371,17 @@ def forward_test_bi_attn(self, x): prompt_next_embed = self.prompt_next_embed.squeeze(1) prompt_pre_embed = self.prompt_pre_embed.squeeze(1) - next_id = torch.full([1, self.sub_str_len], self.bos_next, dtype=torch.long, device=x.get_device()) - pre_id = torch.full([1, self.sub_str_len], self.bos_pre, dtype=torch.long, device=x.get_device()) + next_id = torch.full([1, self.sub_str_len], self.bos_next, dtype=torch.long, device=x.device) + pre_id = torch.full([1, self.sub_str_len], self.bos_pre, dtype=torch.long, device=x.device) # prompt_next_bos = self.char_embed(prompt_id) - # pred_prob_list = torch.full([bs, self.sub_str_len], self.ignore_index, dtype=torch.long, device=x.get_device()) - next_pred_id_list = torch.full([1, self.max_len], self.ignore_index, dtype=torch.long, device=x.get_device()) - pre_pred_id_list = torch.full([1, self.max_len], self.ignore_index, dtype=torch.long, device=x.get_device()) + # pred_prob_list = torch.full([bs, self.sub_str_len], self.ignore_index, dtype=torch.long, device=x.device) + next_pred_id_list = torch.full([1, self.max_len], self.ignore_index, dtype=torch.long, device=x.device) + pre_pred_id_list = torch.full([1, self.max_len], self.ignore_index, dtype=torch.long, device=x.device) next_logits_all = [] pre_logits_all = [] attn_map_next = [] attn_map_pre = [] - mask_pad = torch.zeros([bs, 1], dtype=torch.float32, device=x.get_device()) + mask_pad = torch.zeros([bs, 1], dtype=torch.float32, device=x.device) for j in range(0, min(70, self.max_len-1)): prompt_char_next = torch.concat([prompt_next_embed[:, :1, :], prompt_next_embed[:, 1:, :] + self.char_embed(next_id)], 1) # b, sub_l, dim @@ -428,7 +428,7 @@ def forward_test_bi_attn(self, x): next_logits_all_mid = [] attn_map_next_mid = [] ques_next = self.next_token.tile([1, 1, 1, 1]).squeeze(1) - mask_pad = torch.zeros([1, 1], dtype=torch.float32, device=x.get_device()) + mask_pad = 
torch.zeros([1, 1], dtype=torch.float32, device=x.device) for j in range(0, min(70, self.max_len-1)): prompt_next = torch.concat([prompt_next_embed[:, :1, :], prompt_next_embed[:, 1:, :] + self.char_embed(next_id)], 1) # b, sub_l, dim @@ -474,15 +474,15 @@ def forward_test(self, x): prompt_id = torch.full([bs, self.sub_str_len], self.bos_next, dtype=torch.long, - device=x.get_device()) + device=x.device) pred_id_list = torch.full([bs, self.max_len], self.ignore_index, dtype=torch.long, - device=x.get_device()) + device=x.device) logits_all = [] mask_pad = torch.zeros([bs, 1], dtype=torch.float32, - device=x.get_device()) + device=x.device) for j in range(0, self.max_len - 1): prompt_next = torch.concat([ @@ -522,15 +522,15 @@ def forward_test(self, x): prompt_id = torch.full([bs, self.sub_str_len], self.bos_pre, dtype=torch.long, - device=x.get_device()) + device=x.device) pred_id_list = torch.full([bs, self.max_len], self.ignore_index, dtype=torch.long, - device=x.get_device()) + device=x.device) logits_all = [] mask_pad = torch.zeros([bs, 1], dtype=torch.float32, - device=x.get_device()) + device=x.device) for j in range(0, self.max_len - 1): prompt_next = torch.concat([ @@ -599,7 +599,7 @@ def forward_train(self, x, targets=None): mask_pad = torch.zeros([bs * (max_len_curr + max_len_curr_pre), 1], dtype=torch.float32, - device=x.get_device()) + device=x.device) mask = torch.concat([mask_next, mask_pre], 1).flatten(0, 1) mask = torch.concat([mask_pad, mask], 1) next_pre = next_pre.flatten(0, 1) diff --git a/openrec/modeling/encoders/__init__.py b/openrec/modeling/encoders/__init__.py index a3496c4..390fe21 100644 --- a/openrec/modeling/encoders/__init__.py +++ b/openrec/modeling/encoders/__init__.py @@ -1,39 +1,41 @@ __all__ = ['build_encoder'] +from importlib import import_module + +name_to_module = { + 'MobileNetV1Enhance': '.rec_mv1_enhance', + 'ResNet31': '.rec_resnet_31', + 'MobileNetV3': '.rec_mobilenet_v3', + 'PPLCNetV3': '.rec_lcnetv3', + 'PPHGNet_small': '.rec_hgnet', + 'ResNet': '.rec_resnet_vd', + 'MTB': '.rec_nrtr_mtb', + 'SVTRNet': '.svtrnet', + 'ResNet45': '.rec_resnet_45', + 'ViT': '.vit', + 'SVTRNet2DPos': '.svtrnet2dpos', + 'SVTRv2': '.svtrv2', + 'FocalSVTR': '.focalsvtr', + 'ResNet_FPN': '.rec_resnet_fpn', + 'ResNet_ASTER': '.resnet31_rnn', + 'SVTRv2LNConv': '.svtrv2_lnconv', + 'SVTRv2LNConvTwo33': '.svtrv2_lnconv_two33', + 'CAMEncoder': '.cam_encoder', + 'ConvNeXtV2': '.convnextv2', + 'AutoSTREncoder': '.autostr_encoder', + 'NRTREncoder': '.nrtr_encoder', + 'RepSVTREncoder': '.repvit', +} + def build_encoder(config): - # from .rec_mobilenet_v3 import MobileNetV3 - from .focalsvtr import FocalSVTR - from .rec_hgnet import PPHGNet_small - from .rec_lcnetv3 import PPLCNetV3 - from .rec_mv1_enhance import MobileNetV1Enhance - from .rec_nrtr_mtb import MTB - from .rec_resnet_31 import ResNet31 - from .rec_resnet_45 import ResNet45 - from .rec_resnet_fpn import ResNet_FPN - from .rec_resnet_vd import ResNet - from .resnet31_rnn import ResNet_ASTER - from .svtrnet import SVTRNet - from .svtrnet2dpos import SVTRNet2DPos - from .svtrv2 import SVTRv2 - from .svtrv2_lnconv import SVTRv2LNConv - from .svtrv2_lnconv_two33 import SVTRv2LNConvTwo33 - from .vit import ViT - from .cam_encoder import CAMEncoder - from .convnextv2 import ConvNeXtV2 - from .autostr_encoder import AutoSTREncoder - from .nrtr_encoder import NRTREncoder - from .repvit import RepSVTREncoder - support_dict = [ - 'MobileNetV1Enhance', 'ResNet31', 'MobileNetV3', 'PPLCNetV3', - 'PPHGNet_small', 'ResNet', 'MTB', 
'SVTRNet', 'ResNet45', 'ViT', - 'SVTRNet2DPos', 'SVTRv2', 'FocalSVTR', 'ResNet_FPN', 'ResNet_ASTER', - 'SVTRv2LNConv', 'SVTRv2LNConvTwo33', 'CAMEncoder', 'ConvNeXtV2', - 'AutoSTREncoder', 'NRTREncoder', 'RepSVTREncoder' - ] module_name = config.pop('name') - assert module_name in support_dict, Exception( - 'when encoder of rec model only support {}'.format(support_dict)) - module_class = eval(module_name)(**config) + assert module_name in name_to_module, Exception( + f'Encoder only supports: {list(name_to_module.keys())}') + + module_path = name_to_module[module_name] + mod = import_module(module_path, package=__package__) + module_class = getattr(mod, module_name)(**config) + return module_class diff --git a/openrec/postprocess/__init__.py b/openrec/postprocess/__init__.py index e12a57d..68707cf 100644 --- a/openrec/postprocess/__init__.py +++ b/openrec/postprocess/__init__.py @@ -1,27 +1,25 @@ import copy +from importlib import import_module __all__ = ['build_post_process'] -from .abinet_postprocess import ABINetLabelDecode -from .ar_postprocess import ARLabelDecode -from .ce_postprocess import CELabelDecode -from .char_postprocess import CharLabelDecode -from .cppd_postprocess import CPPDLabelDecode -from .ctc_postprocess import CTCLabelDecode -from .igtr_postprocess import IGTRLabelDecode -from .lister_postprocess import LISTERLabelDecode -from .mgp_postprocess import MPGLabelDecode -from .nrtr_postprocess import NRTRLabelDecode -from .smtr_postprocess import SMTRLabelDecode -from .srn_postprocess import SRNLabelDecode -from .visionlan_postprocess import VisionLANLabelDecode - -support_dict = [ - 'CTCLabelDecode', 'CharLabelDecode', 'CELabelDecode', 'CPPDLabelDecode', - 'NRTRLabelDecode', 'ABINetLabelDecode', 'ARLabelDecode', 'IGTRLabelDecode', - 'VisionLANLabelDecode', 'SMTRLabelDecode', 'SRNLabelDecode', - 'LISTERLabelDecode', 'GTCLabelDecode', 'MPGLabelDecode' -] +# 定义类名到模块路径的映射 +module_mapping = { + 'CTCLabelDecode': '.ctc_postprocess', + 'CharLabelDecode': '.char_postprocess', + 'CELabelDecode': '.ce_postprocess', + 'CPPDLabelDecode': '.cppd_postprocess', + 'NRTRLabelDecode': '.nrtr_postprocess', + 'ABINetLabelDecode': '.abinet_postprocess', + 'ARLabelDecode': '.ar_postprocess', + 'IGTRLabelDecode': '.igtr_postprocess', + 'VisionLANLabelDecode': '.visionlan_postprocess', + 'SMTRLabelDecode': '.smtr_postprocess', + 'SRNLabelDecode': '.srn_postprocess', + 'LISTERLabelDecode': '.lister_postprocess', + 'MPGLabelDecode': '.mgp_postprocess', + 'GTCLabelDecode': '.' 
# 当前模块中的类 +} def build_post_process(config, global_config=None): @@ -29,10 +27,21 @@ def build_post_process(config, global_config=None): module_name = config.pop('name') if global_config is not None: config.update(global_config) - assert module_name in support_dict, Exception( - 'post process only support {}'.format(support_dict)) - module_class = eval(module_name)(**config) - return module_class + + assert module_name in module_mapping, Exception( + 'post process only support {}'.format(list(module_mapping.keys()))) + + module_path = module_mapping[module_name] + + # 处理当前模块中的类 + if module_path == '.': + module_class = globals()[module_name] + else: + # 动态导入模块 + module = import_module(module_path, package=__package__) + module_class = getattr(module, module_name) + + return module_class(**config) class GTCLabelDecode(object): @@ -48,9 +57,14 @@ def __init__(self, gtc_label_decode['character_dict_path'] = character_dict_path gtc_label_decode['use_space_char'] = use_space_char self.gtc_label_decode = build_post_process(gtc_label_decode) - self.ctc_label_decode = CTCLabelDecode( - character_dict_path=character_dict_path, - use_space_char=use_space_char) + self.ctc_label_decode = build_post_process({ + 'name': + 'CTCLabelDecode', + 'character_dict_path': + character_dict_path, + 'use_space_char': + use_space_char + }) self.gtc_character = self.gtc_label_decode.character self.ctc_character = self.ctc_label_decode.character self.only_gtc = only_gtc diff --git a/openrec/postprocess/abinet_postprocess.py b/openrec/postprocess/abinet_postprocess.py index 2cbae49..2ec286d 100644 --- a/openrec/postprocess/abinet_postprocess.py +++ b/openrec/postprocess/abinet_postprocess.py @@ -29,7 +29,7 @@ def __call__(self, preds, batch=None, *args, **kwargs): text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) if batch is None: return text - label = self.decode(batch[1].cpu().numpy()) + label = self.decode(batch[1]) return text, label def add_special_char(self, dict_character): diff --git a/openrec/postprocess/ar_postprocess.py b/openrec/postprocess/ar_postprocess.py index 0045a7c..25a0104 100644 --- a/openrec/postprocess/ar_postprocess.py +++ b/openrec/postprocess/ar_postprocess.py @@ -30,7 +30,7 @@ def __call__(self, preds, batch=None, *args, **kwargs): if batch is None: return text label = batch[1] - label = self.decode(label[:, 1:].detach().cpu().numpy()) + label = self.decode(label[:, 1:]) return text, label def add_special_char(self, dict_character): diff --git a/openrec/postprocess/cppd_postprocess.py b/openrec/postprocess/cppd_postprocess.py index 4e044c3..0b415ab 100644 --- a/openrec/postprocess/cppd_postprocess.py +++ b/openrec/postprocess/cppd_postprocess.py @@ -34,7 +34,7 @@ def __call__(self, preds, batch=None, *args, **kwargs): if batch is None: return text label = batch[1] - label = self.decode(label.detach().cpu().numpy()) + label = self.decode(label) return text, label def add_special_char(self, dict_character): diff --git a/openrec/postprocess/ctc_postprocess.py b/openrec/postprocess/ctc_postprocess.py index fd16013..8c164cd 100644 --- a/openrec/postprocess/ctc_postprocess.py +++ b/openrec/postprocess/ctc_postprocess.py @@ -1,7 +1,6 @@ import re import numpy as np -import torch class BaseRecLabelDecode(object): @@ -102,16 +101,16 @@ def __init__(self, super(CTCLabelDecode, self).__init__(character_dict_path, use_space_char) - def __call__(self, preds, batch=None, *args, **kwargs): + def __call__(self, preds, batch=None, **kwargs): # preds = preds['res'] - if isinstance(preds, 
torch.Tensor): + if kwargs.get('torch_tensor', True): preds = preds.detach().cpu().numpy() preds_idx = preds.argmax(axis=2) preds_prob = preds.max(axis=2) text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True) if batch is None: return text - label = self.decode(batch[1].cpu().numpy()) + label = self.decode(batch[1]) return text, label def add_special_char(self, dict_character): diff --git a/openrec/postprocess/igtr_postprocess.py b/openrec/postprocess/igtr_postprocess.py index 8fc12ca..e5edff7 100644 --- a/openrec/postprocess/igtr_postprocess.py +++ b/openrec/postprocess/igtr_postprocess.py @@ -53,7 +53,7 @@ def __call__(self, preds, batch=None, *args, **kwargs): if batch is None: return text label = batch[1] - label = self.decode(label.detach().cpu().numpy()) + label = self.decode(label) return text, label def add_special_char(self, dict_character): diff --git a/openrec/postprocess/lister_postprocess.py b/openrec/postprocess/lister_postprocess.py index 6b4750a..54a61a9 100644 --- a/openrec/postprocess/lister_postprocess.py +++ b/openrec/postprocess/lister_postprocess.py @@ -26,7 +26,7 @@ def __call__(self, preds, batch=None, *args, **kwargs): if batch is None: return text label = batch[1] - label = self.decode(label.detach().cpu().numpy()) + label = self.decode(label) return text, label def add_special_char(self, dict_character): diff --git a/openrec/postprocess/mgp_postprocess.py b/openrec/postprocess/mgp_postprocess.py index 8677fc9..65f854d 100644 --- a/openrec/postprocess/mgp_postprocess.py +++ b/openrec/postprocess/mgp_postprocess.py @@ -37,7 +37,7 @@ def __call__(self, preds, batch=None, *args, **kwargs): if batch is None: return char_text label = batch[1] - label = self.char_decode(label[:, 1:].detach().cpu().numpy()) + label = self.char_decode(label[:, 1:]) if self.only_char: return char_text, label else: diff --git a/openrec/postprocess/nrtr_postprocess.py b/openrec/postprocess/nrtr_postprocess.py index fcee427..2bfb48f 100644 --- a/openrec/postprocess/nrtr_postprocess.py +++ b/openrec/postprocess/nrtr_postprocess.py @@ -33,7 +33,7 @@ def __call__(self, preds, batch=None, *args, **kwargs): is_remove_duplicate=False) if batch is None: return text - label = self.decode(batch[1][:, 1:].cpu().numpy()) + label = self.decode(batch[1][:, 1:]) else: if isinstance(preds, torch.Tensor): preds = preds.detach().cpu().numpy() @@ -44,7 +44,7 @@ def __call__(self, preds, batch=None, *args, **kwargs): is_remove_duplicate=False) if batch is None: return text - label = self.decode(batch[1][:, 1:].cpu().numpy()) + label = self.decode(batch[1][:, 1:]) return text, label def add_special_char(self, dict_character): diff --git a/openrec/postprocess/smtr_postprocess.py b/openrec/postprocess/smtr_postprocess.py index 2546943..5ead0be 100644 --- a/openrec/postprocess/smtr_postprocess.py +++ b/openrec/postprocess/smtr_postprocess.py @@ -33,7 +33,7 @@ def __call__(self, preds, batch=None, *args, **kwargs): if batch is None: return text label = batch[1] - label = self.decode(label[:, 1:].detach().cpu().numpy()) + label = self.decode(label[:, 1:]) return text, label def add_special_char(self, dict_character): diff --git a/openrec/postprocess/srn_postprocess.py b/openrec/postprocess/srn_postprocess.py index 24b6d66..690e42e 100644 --- a/openrec/postprocess/srn_postprocess.py +++ b/openrec/postprocess/srn_postprocess.py @@ -71,7 +71,7 @@ def __call__(self, preds, batch=None, *args, **kwargs): if batch is None: return text - label = batch[1].cpu().numpy() + label = batch[1] # 
print(f"label.shape:{label.shape}") label = self.decode(label, is_remove_duplicate=False) return text, label diff --git a/openrec/postprocess/visionlan_postprocess.py b/openrec/postprocess/visionlan_postprocess.py index e0bfe82..04a11b9 100644 --- a/openrec/postprocess/visionlan_postprocess.py +++ b/openrec/postprocess/visionlan_postprocess.py @@ -77,5 +77,5 @@ def __call__(self, preds, batch=None, *args, **kwargs): text.append((preds_text, float(preds_prob))) if batch is None: return text - label = self.decode(label.detach().cpu().numpy()) + label = self.decode(label) return text, label diff --git a/openrec/preprocess/__init__.py b/openrec/preprocess/__init__.py index 0558b30..a5ea2be 100644 --- a/openrec/preprocess/__init__.py +++ b/openrec/preprocess/__init__.py @@ -1,71 +1,33 @@ import io +import copy +import importlib import cv2 import numpy as np from PIL import Image -from .abinet_label_encode import ABINetLabelEncode -from .ar_label_encode import ARLabelEncode -from .ce_label_encode import CELabelEncode -from .char_label_encode import CharLabelEncode -from .cppd_label_encode import CPPDLabelEncode -from .ctc_label_encode import CTCLabelEncode -from .ep_label_encode import EPLabelEncode -from .igtr_label_encode import IGTRLabelEncode -from .mgp_label_encode import MGPLabelEncode -from .rec_aug import ABINetAug -from .rec_aug import BaseDataAugmentation as BDA -from .rec_aug import PARSeqAug, PARSeqAugPIL, SVTRAug -from .resize import (ABINetResize, CDistNetResize, LongResize, RecTVResize, - RobustScannerRecResizeImg, SliceResize, SliceTVResize, - SRNRecResizeImg, SVTRResize, VisionLANResize, - RecDynamicResize) -from .smtr_label_encode import SMTRLabelEncode -from .srn_label_encode import SRNLabelEncode -from .visionlan_label_encode import VisionLANLabelEncode -from .cam_label_encode import CAMLabelEncode - - -class KeepKeys(object): + +class KeepKeys: def __init__(self, keep_keys, **kwargs): self.keep_keys = keep_keys def __call__(self, data): - data_list = [] - for key in self.keep_keys: - data_list.append(data[key]) - return data_list + return [data[key] for key in self.keep_keys] -def transform(data, ops=None): - """transform.""" - if ops is None: - ops = [] - for op in ops: - data = op(data) - if data is None: - return None - return data - - -class Fasttext(object): +class Fasttext: def __init__(self, path='None', **kwargs): - # pip install fasttext==0.9.1 import fasttext - self.fast_model = fasttext.load_model(path) def __call__(self, data): - label = data['label'] - fast_label = self.fast_model[label] - data['fast_label'] = fast_label + data['fast_label'] = self.fast_model[data['label']] return data -class DecodeImage(object): - """decode image.""" +class DecodeImage: def __init__(self, img_mode='RGB', @@ -77,23 +39,15 @@ def __init__(self, self.ignore_orientation = ignore_orientation def __call__(self, data): - img = data['image'] - - assert type(img) is bytes and len( - img) > 0, "invalid input 'img' in DecodeImage" - img = np.frombuffer(img, dtype='uint8') - if self.ignore_orientation: - img = cv2.imdecode( - img, cv2.IMREAD_IGNORE_ORIENTATION | cv2.IMREAD_COLOR) - else: - img = cv2.imdecode(img, 1) - if img is None: - return None + assert isinstance(data['image'], bytes) and len(data['image']) > 0 + img = np.frombuffer(data['image'], dtype='uint8') + + flags = cv2.IMREAD_IGNORE_ORIENTATION | cv2.IMREAD_COLOR if self.ignore_orientation else 1 + img = cv2.imdecode(img, flags) + if self.img_mode == 'GRAY': img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) elif self.img_mode == 
'RGB': - assert img.shape[2] == 3, 'invalid shape of image[%s]' % ( - img.shape) img = img[:, :, ::-1] if self.channel_first: @@ -103,45 +57,93 @@ def __call__(self, data): return data -class DecodeImagePIL(object): - """decode image.""" +class DecodeImagePIL: def __init__(self, img_mode='RGB', **kwargs): self.img_mode = img_mode def __call__(self, data): - img = data['image'] - assert type(img) is bytes and len( - img) > 0, "invalid input 'img' in DecodeImage" - img = data['image'] - buf = io.BytesIO(img) - img = Image.open(buf).convert('RGB') + assert isinstance(data['image'], bytes) and len(data['image']) > 0 + img = Image.open(io.BytesIO(data['image'])).convert('RGB') + if self.img_mode == 'Gray': img = img.convert('L') elif self.img_mode == 'BGR': - img = np.array(img)[:, :, ::-1] # 将图片转为numpy格式,并将最后一维通道倒序 - img = Image.fromarray(np.uint8(img)) + img = Image.fromarray(np.array(img)[:, :, ::-1]) + data['image'] = img return data -def create_operators(op_param_list, global_config=None): - """create operators based on the config. +def transform(data, ops=None): + """transform.""" + if ops is None: + ops = [] + for op in ops: + data = op(data) + if data is None: + return None + return data + + +# 类名到模块的映射 +MODULE_MAPPING = { + 'ABINetLabelEncode': '.abinet_label_encode', + 'ARLabelEncode': '.ar_label_encode', + 'CELabelEncode': '.ce_label_encode', + 'CharLabelEncode': '.char_label_encode', + 'CPPDLabelEncode': '.cppd_label_encode', + 'CTCLabelEncode': '.ctc_label_encode', + 'EPLabelEncode': '.ep_label_encode', + 'IGTRLabelEncode': '.igtr_label_encode', + 'MGPLabelEncode': '.mgp_label_encode', + 'SMTRLabelEncode': '.smtr_label_encode', + 'SRNLabelEncode': '.srn_label_encode', + 'VisionLANLabelEncode': '.visionlan_label_encode', + 'CAMLabelEncode': '.cam_label_encode', + 'ABINetAug': '.rec_aug', + 'BDA': '.rec_aug', + 'PARSeqAug': '.rec_aug', + 'PARSeqAugPIL': '.rec_aug', + 'SVTRAug': '.rec_aug', + 'ABINetResize': '.resize', + 'CDistNetResize': '.resize', + 'LongResize': '.resize', + 'RecTVResize': '.resize', + 'RobustScannerRecResizeImg': '.resize', + 'SliceResize': '.resize', + 'SliceTVResize': '.resize', + 'SRNRecResizeImg': '.resize', + 'SVTRResize': '.resize', + 'VisionLANResize': '.resize', + 'RecDynamicResize': '.resize', +} + + +def dynamic_import(class_name): + module_path = MODULE_MAPPING.get(class_name) + if not module_path: + raise ValueError(f'Unsupported class: {class_name}') + + module = importlib.import_module(module_path, package=__package__) + return getattr(module, class_name) - Args: - params(list): a dict list, used to create some operators - """ - assert isinstance(op_param_list, list), 'operator config should be a list' + +def create_operators(op_param_list, global_config=None): ops = [] - for operator in op_param_list: - assert isinstance(operator, - dict) and len(operator) == 1, 'yaml format error' - op_name = list(operator)[0] - param = {} if operator[op_name] is None else operator[op_name] - if global_config is not None: + for op_info in op_param_list: + op_name = list(op_info.keys())[0] + param = copy.deepcopy(op_info[op_name]) or {} + + if global_config: param.update(global_config) - op = eval(op_name)(**param) - ops.append(op) + + if op_name in globals(): + op_class = globals()[op_name] + else: + op_class = dynamic_import(op_name) + + ops.append(op_class(**param)) return ops @@ -154,14 +156,13 @@ def __init__(self, character_dict_path=None, use_space_char=False, **kwargs): - self.gtc_label_encode = eval(gtc_label_encode['name'])( + self.gtc_label_encode = 
dynamic_import(gtc_label_encode['name'])( max_text_length=max_text_length, character_dict_path=character_dict_path, use_space_char=use_space_char, **gtc_label_encode) - self.ctc_label_encode = CTCLabelEncode(max_text_length, - character_dict_path, - use_space_char) + self.ctc_label_encode = dynamic_import('CTCLabelEncode')( + max_text_length, character_dict_path, use_space_char) def __call__(self, data): data_ctc = self.ctc_label_encode({'label': data['label']}) diff --git a/openrec/preprocess/dptr_label_encode.py b/openrec/preprocess/dptr_label_encode.py new file mode 100644 index 0000000..0fd4f23 --- /dev/null +++ b/openrec/preprocess/dptr_label_encode.py @@ -0,0 +1,157 @@ +import re +from abc import ABC, abstractmethod +from itertools import groupby +from typing import List, Optional, Tuple +import numpy as np +import torch +from torch import Tensor +from torch.nn.utils.rnn import pad_sequence +import unicodedata +from ..modeling.decoders.dptr_parseq_clip_b_decoder import tokenize + +class CharsetAdapter: + """Transforms labels according to the target charset.""" + + def __init__(self, target_charset) -> None: + super().__init__() + self.lowercase_only = target_charset == target_charset.lower() + self.uppercase_only = target_charset == target_charset.upper() + self.unsupported = re.compile(f'[^{re.escape(target_charset)}]') + + def __call__(self, label): + if self.lowercase_only: + label = label.lower() + elif self.uppercase_only: + label = label.upper() + # Remove unsupported characters + label = self.unsupported.sub('', label) + return label + + +class BaseTokenizer(ABC): +# eos=0, a=1, bos=37, pad=38 + def __init__(self, charset: str, specials_first: tuple = (), specials_last: tuple = ()) -> None: + self._itos = specials_first + tuple(charset) + specials_last + self._stoi = {s: i for i, s in enumerate(self._itos)} + # print("stoi:", self._stoi) + + def __len__(self): + return len(self._itos) + + def _tok2ids(self, tokens: str) -> List[int]: + # print("tokens", tokens) + return [self._stoi[s] for s in tokens] + + def _ids2tok(self, token_ids: List[int], join: bool = True) -> str: + tokens = [self._itos[i] for i in token_ids] + return ''.join(tokens) if join else tokens + + @abstractmethod + def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor: + """Encode a batch of labels to a representation suitable for the model. + + Args: + labels: List of labels. Each can be of arbitrary length. + device: Create tensor on this device. + + Returns: + Batched tensor representation padded to the max label length. Shape: N, L + """ + raise NotImplementedError + + @abstractmethod + def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]: + """Internal method which performs the necessary filtering prior to decoding.""" + raise NotImplementedError + + def decode(self, token_dists: Tensor, raw: bool = False) -> Tuple[List[str], List[Tensor]]: + """Decode a batch of token distributions. + + Args: + token_dists: softmax probabilities over the token distribution. 
Shape: N, L, C + raw: return unprocessed labels (will return list of list of strings) + + Returns: + list of string labels (arbitrary length) and + their corresponding sequence probabilities as a list of Tensors + """ + batch_tokens = [] + batch_probs = [] + for dist in token_dists: + probs, ids = dist.max(-1) # greedy selection + if not raw: + probs, ids = self._filter(probs, ids) + tokens = self._ids2tok(ids, not raw) + batch_tokens.append(tokens) + batch_probs.append(probs) + return batch_tokens, batch_probs + + +class Tokenizer(BaseTokenizer): + BOS = '[B]' + EOS = '[E]' + PAD = '[P]' + + def __init__(self, charset: str) -> None: + specials_first = (self.EOS,) + specials_last = (self.BOS, self.PAD) + super().__init__(charset, specials_first, specials_last) + self.eos_id, self.bos_id, self.pad_id = [self._stoi[s] for s in specials_first + specials_last] + + def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor: + batch = [self.bos_id] + self._tok2ids(labels) + [self.eos_id] + return batch + # return pad_sequence(batch, batch_first=True, padding_value=self.pad_id) + + def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]: + ids = ids.tolist() + try: + eos_idx = ids.index(self.eos_id) + except ValueError: + eos_idx = len(ids) # Nothing to truncate. + # Truncate after EOS + ids = ids[:eos_idx] + probs = probs[:eos_idx + 1] # but include prob. for EOS (if it exists) + return probs, ids + +class DPTRLabelEncode(Tokenizer): + """Convert between text-label and text-index.""" + def __init__(self, max_text_length=25, character_dict_path=None, **kwargs): + self.max_length = max_text_length + charset = get_alpha(character_dict_path) + charset = ''.join(charset) + # print(charset) + super(DPTRLabelEncode, self).__init__(charset) + + def __call__(self, data, normalize_unicode=True): + text = data['label'] + + if normalize_unicode: + text = unicodedata.normalize('NFKD', text).encode('ascii', 'ignore').decode() + text = ''.join(text.split()) + if len(text) == 0 or len(text) > self.max_length: + return None + + text_ids = self.encode(text) + clip_ids = tokenize(f"a photo of a '{text}'") + text_ids = text_ids + [self.pad_id] * (self.max_length + 2 - len(text_ids)) + # print(text, len(text_ids), len(clip_ids[0])) + data['clip_label'] = np.array(clip_ids[0]) + data['label'] = np.array(text_ids) + return data + + def add_special_char(self, dict_character): + dict_character = [self.EOS] + dict_character + [self.BOS, self.PAD] + return dict_character + +def get_alpha(alpha_path): + character_str = [] + with open(alpha_path, 'rb') as fin: + lines = fin.readlines() + for line in lines: + line = line.decode('utf-8').strip('\n').strip('\r\n') + character_str.append(line) + dict_character = list(character_str) + if 'arabic' in alpha_path: + reverse = True + return dict_character \ No newline at end of file diff --git a/openrec/preprocess/igtr_label_encode.py b/openrec/preprocess/igtr_label_encode.py index 50020d3..fab502f 100644 --- a/openrec/preprocess/igtr_label_encode.py +++ b/openrec/preprocess/igtr_label_encode.py @@ -192,15 +192,26 @@ def add_special_char(self, dict_character): return dict_character def encode(self, text): - """convert text-label into text-index. - input: - text: text labels of each image. [batch_size] - - output: - text: concatenated text index for CTCLoss. - [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)] - length: length of each text. 
[batch_size] """ + Encodes the given text into a list of character IDs and generates various lists for question and prompt sequences. + + Args: + text (str): The input text to be encoded. + + Returns: + tuple: A tuple containing: + - text_list (list): A list of character IDs corresponding to the input text. + - char_num (list): A list of character counts for each character ID. + - ques_list (list): A list of question sequences, each sequence is a list of [position, character ID, character count]. + - prompt_list (list): A list of prompt sequences, each sequence is a list of [position, character ID, character count]. + + Notes: + - If the input text is empty, the function returns None. + - The function handles rare and unrare characters differently. + - The function supports both lowercased and original text based on the `self.lower` attribute. + - The function generates additional sequences if the length of the input text is greater than 1. + """ + if len(text) == 0: return None if self.lower: diff --git a/openrec/preprocess/rec_aug.py b/openrec/preprocess/rec_aug.py index cdc882a..8f66418 100644 --- a/openrec/preprocess/rec_aug.py +++ b/openrec/preprocess/rec_aug.py @@ -3,9 +3,6 @@ import cv2 import numpy as np from PIL import Image -from torchvision.transforms import Compose - -from .abinet_aug import CVColorJitter, CVDeterioration, CVGeometry, SVTRDeterioration, SVTRGeometry from .parseq_aug import rand_augment_transform @@ -41,6 +38,8 @@ def __init__(self, deterioration_p=0.25, colorjitter_p=0.25, **kwargs): + from torchvision.transforms import Compose + from .abinet_aug import CVColorJitter, CVDeterioration, CVGeometry self.transforms = Compose([ CVGeometry( degrees=45, @@ -75,6 +74,8 @@ def __init__(self, deterioration_p=0.25, colorjitter_p=0.25, **kwargs): + from torchvision.transforms import Compose + from .abinet_aug import CVColorJitter, SVTRDeterioration, SVTRGeometry self.transforms = Compose([ SVTRGeometry( aug_type=aug_type, diff --git a/openrec/preprocess/resize.py b/openrec/preprocess/resize.py index bfdb615..941b3e8 100644 --- a/openrec/preprocess/resize.py +++ b/openrec/preprocess/resize.py @@ -3,10 +3,7 @@ import cv2 import numpy as np -import torch from PIL import Image -from torchvision import transforms as T -from torchvision.transforms import functional as F class CDistNetResize(object): @@ -86,6 +83,9 @@ def __call__(self, data): class RecTVResize(object): def __init__(self, image_shape=[32, 128], padding=True, **kwargs): + from torchvision import transforms as T + from torchvision.transforms import functional as F + self.F = F self.padding = padding self.image_shape = image_shape self.interpolation = T.InterpolationMode.BICUBIC @@ -108,11 +108,11 @@ def __call__(self, data): resized_w = imgW else: resized_w = int(math.ceil(imgH * ratio)) - resized_image = F.resize(img, (imgH, resized_w), - interpolation=self.interpolation) + resized_image = self.F.resize(img, (imgH, resized_w), + interpolation=self.interpolation) img = self.transforms(resized_image) if resized_w < imgW: - img = F.pad(img, [0, 0, imgW - resized_w, 0], fill=0.) + img = self.F.pad(img, [0, 0, imgW - resized_w, 0], fill=0.) 
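+        # valid_ratio records the fraction of the target width imgW covered by real (unpadded)
+        # image content; it is 1.0 when no right-padding was needed.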
valid_ratio = min(1.0, float(resized_w / imgW)) data['image'] = img data['valid_ratio'] = valid_ratio @@ -185,6 +185,11 @@ def __init__(self, max_ratio=12, base_h=32, **kwargs): + import torch + from torchvision import transforms as T + from torchvision.transforms import functional as F + self.F = F + self.torch = torch self.image_shape = image_shape self.padding = padding self.max_ratio = max_ratio @@ -202,14 +207,14 @@ def __call__(self, data): w, h = img.size w_ratio = ((w // h) // 2) * 2 w_ratio = max(6, w_ratio) - img = F.resize(img, (self.base_h, self.base_h * w_ratio), - interpolation=self.interpolation) + img = self.F.resize(img, (self.base_h, self.base_h * w_ratio), + interpolation=self.interpolation) img = self.transforms(img) img_list = [] for i in range(0, w_ratio // 2 - 1): img_list.append(img[None, :, :, i * 2 * self.base_h:(i * 2 + 4) * self.base_h]) - data['image'] = torch.concat(img_list, 0) + data['image'] = self.torch.concat(img_list, 0) data['valid_ratio'] = float(w_ratio) / w return data @@ -223,6 +228,9 @@ def __init__(self, max_ratio=12, base_h=32, **kwargs): + from torchvision import transforms as T + from torchvision.transforms import functional as F + self.F = F self.padding = padding self.image_shape = image_shape self.max_ratio = max_ratio @@ -256,11 +264,11 @@ def __call__(self, data): resized_w = imgW else: resized_w = int(math.ceil(imgH * ratio)) - resized_image = F.resize(img, (imgH, resized_w), - interpolation=self.interpolation) + resized_image = self.F.resize(img, (imgH, resized_w), + interpolation=self.interpolation) img = self.transforms(resized_image) if resized_w < imgW: - img = F.pad(img, [0, 0, imgW - resized_w, 0], fill=0.) + img = self.F.pad(img, [0, 0, imgW - resized_w, 0], fill=0.) valid_ratio = min(1.0, float(resized_w / imgW)) data['image'] = img data['valid_ratio'] = valid_ratio diff --git a/requirements.txt b/requirements.txt index 1ab8901..01c064a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,8 @@ imgaug lmdb numpy -opencv-python<=4.6.0.66 +opencv-python +pyclipper pyyaml rapidfuzz tqdm diff --git a/tools/__init__.py b/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tools/create_lmdb_dataset.py b/tools/create_lmdb_dataset.py index a96aa92..d9a3a44 100644 --- a/tools/create_lmdb_dataset.py +++ b/tools/create_lmdb_dataset.py @@ -99,6 +99,7 @@ def createDataset(data_list, outputPath, checkValid=True): if __name__ == '__main__': data_dir = './Union14M-L/' + # downloading the filtered_label_list from https://drive.google.com/drive/folders/1x1LC8C_W-Frl3sGV9i9_i_OD-bqNdodJ?usp=drive_link label_file_list = [ './Union14M-L/train_annos/filter_jsonl_mmocr0.x/filter_train_challenging.jsonl.txt', './Union14M-L/train_annos/filter_jsonl_mmocr0.x/filter_train_easy.jsonl.txt', diff --git a/tools/data/__init__.py b/tools/data/__init__.py index ad1a203..8638758 100644 --- a/tools/data/__init__.py +++ b/tools/data/__init__.py @@ -1,71 +1,87 @@ import os import sys +import copy +import importlib __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) -import copy - from torch.utils.data import DataLoader, DistributedSampler -from tools.data.lmdb_dataset import LMDBDataSet -from tools.data.lmdb_dataset_test import LMDBDataSetTest -from tools.data.multi_scale_sampler import MultiScaleSampler -from tools.data.ratio_dataset import RatioDataSet -from tools.data.ratio_dataset_test import RatioDataSetTest -from tools.data.ratio_dataset_tvresize_test import 
RatioDataSetTVResizeTest -from tools.data.ratio_dataset_tvresize import RatioDataSetTVResize -from tools.data.ratio_sampler import RatioSampler -from tools.data.simple_dataset import MultiScaleDataSet, SimpleDataSet -from tools.data.strlmdb_dataset import STRLMDBDataSet +# 定义支持的 Dataset 类及其对应的模块路径 +DATASET_MODULES = { + 'SimpleDataSet': 'tools.data.simple_dataset', + 'LMDBDataSet': 'tools.data.lmdb_dataset', + 'TextLMDBDataSet': 'tools.data.text_lmdb_dataset', + 'MultiScaleDataSet': 'tools.data.simple_dataset', + 'STRLMDBDataSet': 'tools.data.strlmdb_dataset', + 'LMDBDataSetTest': 'tools.data.lmdb_dataset_test', + 'RatioDataSet': 'tools.data.ratio_dataset', + 'RatioDataSetTest': 'tools.data.ratio_dataset_test', + 'RatioDataSetTVResize': 'tools.data.ratio_dataset_tvresize', + 'RatioDataSetTVResizeTest': 'tools.data.ratio_dataset_tvresize_test' +} + +# 定义支持的 Sampler 类及其对应的模块路径 +SAMPLER_MODULES = { + 'MultiScaleSampler': 'tools.data.multi_scale_sampler', + 'RatioSampler': 'tools.data.ratio_sampler' +} __all__ = [ 'build_dataloader', - 'transform', - 'create_operators', ] -def build_dataloader(config, mode, logger, seed=None, epoch=3): +def build_dataloader(config, mode, logger, seed=None, epoch=3, task='rec'): config = copy.deepcopy(config) + mode = mode.capitalize() # 确保 mode 是首字母大写形式(Train/Eval/Test) + + # 获取 dataset 配置 + dataset_config = config[mode]['dataset'] + module_name = dataset_config['name'] + + # 动态导入 dataset 类 + if module_name not in DATASET_MODULES: + raise ValueError( + f'Unsupported dataset: {module_name}. Supported datasets: {list(DATASET_MODULES.keys())}' + ) + + dataset_module = importlib.import_module(DATASET_MODULES[module_name]) + dataset_class = getattr(dataset_module, module_name) + dataset = dataset_class(config, mode, logger, seed, epoch=epoch, task=task) - support_dict = [ - 'SimpleDataSet', 'LMDBDataSet', 'MultiScaleDataSet', 'STRLMDBDataSet', - 'LMDBDataSetTest', 'RatioDataSet', 'RatioDataSetTest', - 'RatioDataSetTVResize', 'RatioDataSetTVResizeTest' - ] - module_name = config[mode]['dataset']['name'] - assert module_name in support_dict, Exception( - 'DataSet only support {}/{}'.format(support_dict, module_name)) - assert mode in ['Train', 'Eval', - 'Test'], 'Mode should be Train, Eval or Test.' - - dataset = eval(module_name)(config, mode, logger, seed, epoch=epoch) + # DataLoader 配置 loader_config = config[mode]['loader'] batch_size = loader_config['batch_size_per_card'] drop_last = loader_config['drop_last'] shuffle = loader_config['shuffle'] num_workers = loader_config['num_workers'] - if 'pin_memory' in loader_config.keys(): - pin_memory = loader_config['use_shared_memory'] - else: - pin_memory = False + pin_memory = loader_config.get('pin_memory', False) sampler = None batch_sampler = None if 'sampler' in config[mode]: - config_sampler = config[mode]['sampler'] - sampler_name = config_sampler.pop('name') - batch_sampler = eval(sampler_name)(dataset, **config_sampler) + sampler_config = config[mode]['sampler'] + sampler_name = sampler_config.pop('name') + + if sampler_name not in SAMPLER_MODULES: + raise ValueError( + f'Unsupported sampler: {sampler_name}. Supported samplers: {list(SAMPLER_MODULES.keys())}' + ) + + sampler_module = importlib.import_module(SAMPLER_MODULES[sampler_name]) + sampler_class = getattr(sampler_module, sampler_name) + batch_sampler = sampler_class(dataset, **sampler_config) elif config['Global']['distributed'] and mode == 'Train': sampler = DistributedSampler(dataset=dataset, shuffle=shuffle) if 'collate_fn' in loader_config: from . 
import collate_fn - collate_fn = getattr(collate_fn, loader_config['collate_fn'])() else: collate_fn = None + if batch_sampler is None: data_loader = DataLoader( dataset=dataset, @@ -84,11 +100,14 @@ def build_dataloader(config, mode, logger, seed=None, epoch=3): pin_memory=pin_memory, collate_fn=collate_fn, ) + + # 检查数据加载器是否为空 if len(data_loader) == 0: logger.error( - f'No Images in {mode.lower()} dataloader, please ensure\n' - '\t1. The images num in the train label_file_list should be larger than or equal with batch size.\n' - '\t2. The annotation file and path in the configuration file are provided normally.\n' - '\t3. The BatchSize is large than images.') + f'No Images in {mode.lower()} dataloader. Please check:\n' + '\t1. The images num in the train label_file_list should be >= batch size.\n' + '\t2. The annotation file and path in the configuration are correct.\n' + '\t3. The BatchSize is not larger than the number of images.') sys.exit() + return data_loader diff --git a/tools/data/lmdb_dataset.py b/tools/data/lmdb_dataset.py index be17c50..5978b05 100644 --- a/tools/data/lmdb_dataset.py +++ b/tools/data/lmdb_dataset.py @@ -10,7 +10,7 @@ class LMDBDataSet(Dataset): - def __init__(self, config, mode, logger, seed=None, epoch=1): + def __init__(self, config, mode, logger, seed=None, epoch=1, task='rec'): super(LMDBDataSet, self).__init__() global_config = config['Global'] diff --git a/tools/data/lmdb_dataset_test.py b/tools/data/lmdb_dataset_test.py index 4fffda7..5ebe6a8 100644 --- a/tools/data/lmdb_dataset_test.py +++ b/tools/data/lmdb_dataset_test.py @@ -49,7 +49,8 @@ def __init__(self, remove_whitespace: bool = True, normalize_unicode: bool = True, unlabelled: bool = False, - transform=None): + transform=None, + task='rec'): dataset_config = config[mode]['dataset'] global_config = config['Global'] max_label_len = global_config['max_text_length'] diff --git a/tools/data/ratio_dataset.py b/tools/data/ratio_dataset.py index 514ac3d..7a7433c 100644 --- a/tools/data/ratio_dataset.py +++ b/tools/data/ratio_dataset.py @@ -13,7 +13,7 @@ class RatioDataSet(Dataset): - def __init__(self, config, mode, logger, seed=None, epoch=1): + def __init__(self, config, mode, logger, seed=None, epoch=1, task='rec'): super(RatioDataSet, self).__init__() self.ds_width = config[mode]['dataset'].get('ds_width', True) global_config = config['Global'] diff --git a/tools/data/ratio_dataset_test.py b/tools/data/ratio_dataset_test.py index 18adfd6..ca34d7f 100644 --- a/tools/data/ratio_dataset_test.py +++ b/tools/data/ratio_dataset_test.py @@ -34,7 +34,7 @@ def __call__(self, label): class RatioDataSetTest(Dataset): - def __init__(self, config, mode, logger, seed=None, epoch=1): + def __init__(self, config, mode, logger, seed=None, epoch=1, task='rec'): super(RatioDataSetTest, self).__init__() self.ds_width = config[mode]['dataset'].get('ds_width', True) global_config = config['Global'] diff --git a/tools/data/ratio_dataset_tvresize.py b/tools/data/ratio_dataset_tvresize.py index e0f0469..5dbda1f 100644 --- a/tools/data/ratio_dataset_tvresize.py +++ b/tools/data/ratio_dataset_tvresize.py @@ -15,7 +15,7 @@ class RatioDataSetTVResize(Dataset): - def __init__(self, config, mode, logger, seed=None, epoch=1): + def __init__(self, config, mode, logger, seed=None, epoch=1, task='rec'): super(RatioDataSetTVResize, self).__init__() self.ds_width = config[mode]['dataset'].get('ds_width', True) global_config = config['Global'] diff --git a/tools/data/ratio_dataset_tvresize_test.py 
b/tools/data/ratio_dataset_tvresize_test.py index abf6ff5..832ec14 100644 --- a/tools/data/ratio_dataset_tvresize_test.py +++ b/tools/data/ratio_dataset_tvresize_test.py @@ -36,7 +36,7 @@ def __call__(self, label): class RatioDataSetTVResizeTest(Dataset): - def __init__(self, config, mode, logger, seed=None, epoch=1): + def __init__(self, config, mode, logger, seed=None, epoch=1, task='rec'): super(RatioDataSetTVResizeTest, self).__init__() self.ds_width = config[mode]['dataset'].get('ds_width', True) global_config = config['Global'] diff --git a/tools/data/ratio_sampler.py b/tools/data/ratio_sampler.py index e0c9d72..51e9c2a 100644 --- a/tools/data/ratio_sampler.py +++ b/tools/data/ratio_sampler.py @@ -56,7 +56,7 @@ def __init__(self, self.base_im_w = base_im_w # Get the GPU and node related information - num_replicas = torch.cuda.device_count() + num_replicas = torch.cuda.device_count() if torch.cuda.is_available() else 1 # rank = dist.get_rank() rank = (int(os.environ['LOCAL_RANK']) if 'LOCAL_RANK' in os.environ else 0) diff --git a/tools/data/simple_dataset.py b/tools/data/simple_dataset.py index 9e79304..61477c9 100644 --- a/tools/data/simple_dataset.py +++ b/tools/data/simple_dataset.py @@ -8,12 +8,12 @@ import numpy as np from torch.utils.data import Dataset -from openrec.preprocess import create_operators, transform +from openrec.preprocess import transform class SimpleDataSet(Dataset): - def __init__(self, config, mode, logger, seed=None, epoch=0): + def __init__(self, config, mode, logger, seed=None, epoch=0, task='rec'): super(SimpleDataSet, self).__init__() self.logger = logger self.mode = mode.lower() @@ -42,7 +42,10 @@ def __init__(self, config, mode, logger, seed=None, epoch=0): self.shuffle_data_random() self.set_epoch_as_seed(self.seed, dataset_config) - + if task == 'rec': + from openrec.preprocess import create_operators + elif task == 'det': + from opendet.preprocess import create_operators self.ops = create_operators(dataset_config['transforms'], global_config) self.ext_op_transform_idx = dataset_config.get('ext_op_transform_idx', diff --git a/tools/data/strlmdb_dataset.py b/tools/data/strlmdb_dataset.py index 76ce439..4b69dd0 100644 --- a/tools/data/strlmdb_dataset.py +++ b/tools/data/strlmdb_dataset.py @@ -18,7 +18,8 @@ def __init__(self, config, mode, logger, seed=None, epoch=1, gpu_i=0): loader_config = config[mode]['loader'] loader_config['batch_size_per_card'] # data_dir = dataset_config['data_dir'] - data_dir = '../training_aug_lmdb_noerror/ep' + str(epoch) + data_dir = '../training_aug_lmdb_noerror/ep' + str( + epoch % 20 if epoch % 20 != 0 else 20) self.do_shuffle = loader_config['shuffle'] self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir) diff --git a/tools/data/text_lmdb_dataset.py b/tools/data/text_lmdb_dataset.py new file mode 100644 index 0000000..23b82d9 --- /dev/null +++ b/tools/data/text_lmdb_dataset.py @@ -0,0 +1,131 @@ +import os +import lmdb +import numpy as np +from torch.utils.data import Dataset + +from openrec.preprocess import create_operators, transform + + +class TextLMDBDataSet(Dataset): + + def __init__(self, config, mode, logger, seed=None, epoch=1, task='rec'): + super(TextLMDBDataSet, self).__init__() + + global_config = config['Global'] + dataset_config = config[mode]['dataset'] + loader_config = config[mode]['loader'] + loader_config['batch_size_per_card'] + data_dir = dataset_config['data_dir'] + self.do_shuffle = loader_config['shuffle'] + + self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir) + 
logger.info(f'Initialize indexs of datasets: {data_dir}') + self.data_idx_order_list = self.dataset_traversal() + if self.do_shuffle: + np.random.shuffle(self.data_idx_order_list) + self.ops = create_operators(dataset_config['transforms'], + global_config) + self.ext_op_transform_idx = dataset_config.get('ext_op_transform_idx', + 1) + + ratio_list = dataset_config.get('ratio_list', [1.0]) + self.need_reset = True in [x < 1 for x in ratio_list] + + def load_hierarchical_lmdb_dataset(self, data_dir): + lmdb_sets = {} + dataset_idx = 0 + for dirpath, dirnames, filenames in os.walk(data_dir + '/'): + if not dirnames: + env = lmdb.open( + dirpath, + max_readers=32, + readonly=True, + lock=False, + readahead=False, + meminit=False, + ) + txn = env.begin(write=False) + num_samples = int(txn.get('num-samples'.encode())) + lmdb_sets[dataset_idx] = { + 'dirpath': dirpath, + 'env': env, + 'txn': txn, + 'num_samples': num_samples, + } + dataset_idx += 1 + return lmdb_sets + + def dataset_traversal(self): + lmdb_num = len(self.lmdb_sets) + total_sample_num = 0 + for lno in range(lmdb_num): + total_sample_num += self.lmdb_sets[lno]['num_samples'] + data_idx_order_list = np.zeros((total_sample_num, 2)) + beg_idx = 0 + for lno in range(lmdb_num): + tmp_sample_num = self.lmdb_sets[lno]['num_samples'] + end_idx = beg_idx + tmp_sample_num + data_idx_order_list[beg_idx:end_idx, 0] = lno + data_idx_order_list[beg_idx:end_idx, + 1] = list(range(tmp_sample_num)) + data_idx_order_list[beg_idx:end_idx, 1] += 1 + beg_idx = beg_idx + tmp_sample_num + return data_idx_order_list + + def get_ext_data(self): + ext_data_num = 0 + for op in self.ops: + if hasattr(op, 'ext_data_num'): + ext_data_num = getattr(op, 'ext_data_num') + break + load_data_ops = self.ops[:self.ext_op_transform_idx] + ext_data = [] + + while len(ext_data) < ext_data_num: + lmdb_idx, file_idx = self.data_idx_order_list[np.random.randint( + len(self))] + lmdb_idx = int(lmdb_idx) + file_idx = int(file_idx) + sample_info = self.get_lmdb_sample_info( + self.lmdb_sets[lmdb_idx]['txn'], file_idx) + if sample_info is None: + continue + label = sample_info + data = {'label': label} + data = transform(data, load_data_ops) + if data is None: + continue + ext_data.append(data) + return ext_data + + def get_lmdb_sample_info(self, + txn, + index, + normalize_unicode=True, + remove_whitespace=True, + max_length=True): + label_key = 'label-%09d'.encode() % index + label = txn.get(label_key) + if label is None: + return None + label = label.decode('utf-8') + + return label + + def __getitem__(self, idx): + lmdb_idx, file_idx = self.data_idx_order_list[idx] + lmdb_idx = int(lmdb_idx) + file_idx = int(file_idx) + sample_info = self.get_lmdb_sample_info( + self.lmdb_sets[lmdb_idx]['txn'], file_idx) + if sample_info is None: + return self.__getitem__(np.random.randint(self.__len__())) + label = sample_info + data = {'label': label} + outs = transform(data, self.ops) + if outs is None: + return self.__getitem__(np.random.randint(self.__len__())) + return outs + + def __len__(self): + return self.data_idx_order_list.shape[0] diff --git a/tools/download/download_dataset.py b/tools/download/download_dataset.py new file mode 100644 index 0000000..50e5a3b --- /dev/null +++ b/tools/download/download_dataset.py @@ -0,0 +1,32 @@ +import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) + +sys.path.append(__dir__) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..', '..'))) + +from 
engine import Config +from utility import ArgsParser +import download.utils +from torchvision.datasets.utils import extract_archive + +def main(cfg): + urls, filename_paths, check_validity = download.utils.get_dataset_info(cfg) + for url, filename_path in zip(urls, filename_paths): + print(f"Downloading {filename_path} from {url} . . .") + download.utils.urlretrieve(url=url, filename=filename_path, check_validity=check_validity) + if not filename_path.endswith(".mdb"): + extract_archive(from_path=filename_path, to_path=cfg["root"], remove_finished=True) + + print("Downloads finished!") + +if __name__ == "__main__": + FLAGS = ArgsParser().parse_args() + cfg = Config(FLAGS.config) + FLAGS = vars(FLAGS) + opt = FLAGS.pop('opt') + cfg.merge_dict(FLAGS) + cfg.merge_dict(opt) + main(cfg.cfg) diff --git a/tools/download/utils.py b/tools/download/utils.py new file mode 100644 index 0000000..f83ae39 --- /dev/null +++ b/tools/download/utils.py @@ -0,0 +1,23 @@ +import urllib +import ssl +from tqdm import tqdm +import os + +def get_dataset_info(cfg): + download_urls, filenames, check_validity = cfg["download_links"], cfg["filenames"], cfg["check_validity"] + return download_urls, filenames, check_validity + +# Modified from torchvision as some datasets cant pass the certificate validity check: +# https://github.com/pytorch/vision/blob/868a3b42f4bffe29e4414ad7e4c7d9d0b4690ecb/torchvision/datasets/utils.py#L27C1-L32C40 +def urlretrieve(url, filename, chunk_size=1024 * 32, check_validity=True): + os.makedirs(os.path.dirname(filename), exist_ok=True) + ctx = ssl.create_default_context() + if not check_validity: + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + request = urllib.request.Request(url) + with urllib.request.urlopen(request, context=ctx) as response: + with open(filename, "wb") as fh, tqdm(total=response.length, unit="B", unit_scale=True) as pbar: + while chunk := response.read(chunk_size): + fh.write(chunk) + pbar.update(len(chunk)) \ No newline at end of file diff --git a/tools/engine/__init__.py b/tools/engine/__init__.py deleted file mode 100644 index a54e750..0000000 --- a/tools/engine/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from . 
import config, trainer -from .config import * -from .trainer import * - -__all__ = config.__all__ + trainer.__all__ diff --git a/tools/engine/trainer.py b/tools/engine/trainer.py index e2382b4..7802a36 100644 --- a/tools/engine/trainer.py +++ b/tools/engine/trainer.py @@ -1,18 +1,13 @@ -import copy import datetime import os import random import time import numpy as np -import torch from tqdm import tqdm -from openrec.losses import build_loss -from openrec.metrics import build_metric -from openrec.modeling import build_model -from openrec.optimizer import build_optimizer -from openrec.postprocess import build_post_process +import torch +import torch.distributed from tools.data import build_dataloader from tools.utils.ckpt import load_ckpt, save_ckpt from tools.utils.logging import get_logger @@ -31,9 +26,9 @@ def get_parameter_number(model): class Trainer(object): - def __init__(self, cfg, mode='train'): + def __init__(self, cfg, mode='train', task='rec'): self.cfg = cfg.cfg - + self.task = task self.local_rank = (int(os.environ['LOCAL_RANK']) if 'LOCAL_RANK' in os.environ else 0) self.set_device(self.cfg['Global']['device']) @@ -64,7 +59,7 @@ def __init__(self, cfg, mode='train'): self.writer = SummaryWriter(self.cfg['Global']['output_dir']) self.logger = get_logger( - 'openrec', + 'openrec' if task == 'rec' else 'opendet', os.path.join(self.cfg['Global']['output_dir'], 'train.log') if 'train' in mode else None, ) @@ -74,10 +69,6 @@ def __init__(self, cfg, mode='train'): if self.cfg['Global']['device'] == 'gpu' and self.device.type == 'cpu': self.logger.info('cuda is not available, auto switch to cpu') - self.grad_clip_val = self.cfg['Global'].get('grad_clip_val', 0) - self.all_ema = self.cfg['Global'].get('all_ema', True) - self.use_ema = self.cfg['Global'].get('use_ema', True) - self.set_random_seed(self.cfg['Global'].get('seed', 48)) # build data loader @@ -86,43 +77,38 @@ def __init__(self, cfg, mode='train'): cfg.save( os.path.join(self.cfg['Global']['output_dir'], 'config.yml'), self.cfg) - self.train_dataloader = build_dataloader(self.cfg, 'Train', - self.logger) + self.train_dataloader = build_dataloader(self.cfg, + 'Train', + self.logger, + task=task) self.logger.info( f'train dataloader has {len(self.train_dataloader)} iters') self.valid_dataloader = None if 'eval' in mode and self.cfg['Eval']: - self.valid_dataloader = build_dataloader(self.cfg, 'Eval', - self.logger) + self.valid_dataloader = build_dataloader(self.cfg, + 'Eval', + self.logger, + task=task) self.logger.info( f'valid dataloader has {len(self.valid_dataloader)} iters') - # build post process - self.post_process_class = build_post_process(self.cfg['PostProcess'], - self.cfg['Global']) - # build model - # for rec algorithm - char_num = self.post_process_class.get_character_num() - self.cfg['Architecture']['Decoder']['out_channels'] = char_num + if task == 'rec': + self._init_rec_model() + elif task == 'det': + self._init_det_model() + else: + raise NotImplementedError - self.model = build_model(self.cfg['Architecture']) self.logger.info(get_parameter_number(model=self.model)) self.model = self.model.to(self.device) - if self.local_rank == 0: - ema_model = build_model(self.cfg['Architecture']) - self.ema_model = ema_model.to(self.device) - self.ema_model.eval() - use_sync_bn = self.cfg['Global'].get('use_sync_bn', False) if use_sync_bn: self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm( self.model) self.logger.info('convert_sync_batchnorm') - # build loss - self.loss_class = build_loss(self.cfg['Loss']) - + 
from openrec.optimizer import build_optimizer self.optimizer, self.lr_scheduler = None, None if self.train_dataloader is not None: # build optim @@ -133,8 +119,7 @@ def __init__(self, cfg, mode='train'): step_each_epoch=len(self.train_dataloader), model=self.model, ) - - self.eval_class = build_metric(self.cfg['Metric']) + self.grad_clip_val = self.cfg['Global'].get('grad_clip_val', 0) self.status = load_ckpt(self.model, self.cfg, self.optimizer, self.lr_scheduler) @@ -150,6 +135,41 @@ def __init__(self, cfg, mode='train'): self.logger.info( f'run with torch {torch.__version__} and device {self.device}') + def _init_rec_model(self): + from openrec.losses import build_loss as build_rec_loss + from openrec.metrics import build_metric as build_rec_metric + from openrec.modeling import build_model as build_rec_model + from openrec.postprocess import build_post_process as build_rec_post_process + + # build post process + self.post_process_class = build_rec_post_process( + self.cfg['PostProcess'], self.cfg['Global']) + # build model + # for rec algorithm + char_num = self.post_process_class.get_character_num() + self.cfg['Architecture']['Decoder']['out_channels'] = char_num + self.model = build_rec_model(self.cfg['Architecture']) + # build loss + self.loss_class = build_rec_loss(self.cfg['Loss']) + # build metric + self.eval_class = build_rec_metric(self.cfg['Metric']) + + def _init_det_model(self): + from opendet.losses import build_loss as build_det_loss + from opendet.metrics import build_metric as build_det_metric + from opendet.modeling import build_model as build_det_model + from opendet.postprocess import build_post_process as build_det_post_process + + # build post process + self.post_process_class = build_det_post_process( + self.cfg['PostProcess'], self.cfg['Global']) + # build detmodel + self.model = build_det_model(self.cfg['Architecture']) + # build loss + self.loss_class = build_det_loss(self.cfg['Loss']) + # build metric + self.eval_class = build_det_metric(self.cfg['Metric']) + def load_params(self, params): self.model.load_state_dict(params) @@ -211,43 +231,43 @@ def train(self): 'an evaluation is run every {} iterations'.format( start_eval_step, eval_batch_step)) + save_epoch_step = self.cfg['Global'].get('save_epoch_step', [0, 1]) + start_save_epoch = save_epoch_step[0] + save_epoch_step = save_epoch_step[1] + start_epoch = self.status.get('epoch', 1) - best_metric = self.status.get('metrics', {}) - if self.eval_class.main_indicator not in best_metric: - best_metric[self.eval_class.main_indicator] = 0 - ema_best_metric = self.status.get('metrics', {}) - ema_best_metric[self.eval_class.main_indicator] = 0 + self.best_metric = self.status.get('metrics', {}) + if self.eval_class.main_indicator not in self.best_metric: + self.best_metric[self.eval_class.main_indicator] = 0 train_stats = TrainingStats(log_smooth_window, ['lr']) self.model.train() total_samples = 0 train_reader_cost = 0.0 train_batch_cost = 0.0 - best_iter = 0 - ema_stpe = 1 - ema_eval_iter = 0 - loss_avg = 0. 
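# Illustrative usage (a minimal sketch, not itself part of the patch): with the
# task-aware _init_rec_model/_init_det_model split added above, one Trainer
# drives either the recognition stack (openrec.*) or the detection stack
# (opendet.*) via the `task` argument. The call below mirrors tools/eval_det.py
# added later in this patch; the YAML path is the detection config referenced
# elsewhere in the diff.
from tools.engine.config import Config
from tools.engine.trainer import Trainer

cfg = Config('./configs/det/dbnet/repvit_db.yml')
trainer = Trainer(cfg, mode='eval', task='det')  # task='rec' selects the openrec builders
metric = trainer.eval()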
reader_start = time.time() eta_meter = AverageMeter() for epoch in range(start_epoch, epoch_num + 1): if self.train_dataloader.dataset.need_reset: - self.train_dataloader = build_dataloader( - self.cfg, - 'Train', - self.logger, - epoch=epoch % 20 if epoch % 20 != 0 else 20, - ) + self.train_dataloader = build_dataloader(self.cfg, + 'Train', + self.logger, + epoch=epoch, + task=self.task) for idx, batch in enumerate(self.train_dataloader): - batch = [t.to(self.device) for t in batch] + batch_tensor = [t.to(self.device) for t in batch] + batch_numpy = [t.numpy() for t in batch] self.optimizer.zero_grad() train_reader_cost += time.time() - reader_start # use amp if self.scaler: - with torch.cuda.amp.autocast(): - preds = self.model(batch[0], data=batch[1:]) - loss = self.loss_class(preds, batch) + with torch.cuda.amp.autocast( + enabled=self.device.type == 'cuda'): + preds = self.model(batch_tensor[0], + data=batch_tensor[1:]) + loss = self.loss_class(preds, batch_tensor) self.scaler.scale(loss['loss']).backward() if self.grad_clip_val > 0: torch.nn.utils.clip_grad_norm_( @@ -256,8 +276,8 @@ def train(self): self.scaler.step(self.optimizer) self.scaler.update() else: - preds = self.model(batch[0], data=batch[1:]) - loss = self.loss_class(preds, batch) + preds = self.model(batch_tensor[0], data=batch_tensor[1:]) + loss = self.loss_class(preds, batch_tensor) avg_loss = loss['loss'] avg_loss.backward() if self.grad_clip_val > 0: @@ -268,9 +288,9 @@ def train(self): if cal_metric_during_train: # only rec and cls need post_result = self.post_process_class(preds, - batch, + batch_numpy, training=True) - self.eval_class(post_result, batch, training=True) + self.eval_class(post_result, batch_numpy, training=True) metric = self.eval_class.get_metric() train_stats.update(metric) @@ -282,66 +302,6 @@ def train(self): self.lr_scheduler.step() - if self.local_rank == 0 and self.use_ema and epoch > ( - epoch_num - epoch_num // 10): - with torch.no_grad(): - loss_currn = loss['loss'].detach().cpu().numpy().mean() - loss_avg = ((loss_avg * - (ema_stpe - 1)) + loss_currn) / (ema_stpe) - if ema_stpe == 1: - - # current_weight = copy.deepcopy(self.model.module.state_dict()) - ema_state_dict = copy.deepcopy( - self.model.module.state_dict() if self. - cfg['Global']['distributed'] else self.model. - state_dict()) - self.ema_model.load_state_dict(ema_state_dict) - # if global_step > (epoch_num - epoch_num//10)*max_iter: - elif loss_currn <= loss_avg or self.all_ema: - # eval_batch_step = 500 - current_weight = copy.deepcopy( - self.model.module.state_dict() if self. - cfg['Global']['distributed'] else self.model. 
- state_dict()) - k1 = 1 / (ema_stpe + 1) - k2 = 1 - k1 - for k, v in ema_state_dict.items(): - # v = (v * (ema_stpe - 1) + current_weight[k])/ema_stpe - v = v * k2 + current_weight[k] * k1 - # v.req = True - ema_state_dict[k] = v - # ema_stpe += 1 - self.ema_model.load_state_dict(ema_state_dict) - ema_stpe += 1 - if global_step > start_eval_step and ( - global_step - - start_eval_step) % eval_batch_step == 0: - ema_cur_metric = self.eval_ema() - ema_cur_metric_str = f"cur ema metric, {', '.join(['{}: {}'.format(k, v) for k, v in ema_cur_metric.items()])}" - self.logger.info(ema_cur_metric_str) - state = { - 'epoch': epoch, - 'global_step': global_step, - 'state_dict': self.ema_model.state_dict(), - 'optimizer': None, - 'scheduler': None, - 'config': self.cfg, - 'metrics': ema_cur_metric, - } - save_path = os.path.join( - self.cfg['Global']['output_dir'], - 'ema_' + str(ema_eval_iter) + '.pth') - torch.save(state, save_path) - self.logger.info(f'save ema ckpt to {save_path}') - ema_eval_iter += 1 - if ema_cur_metric[self.eval_class. - main_indicator] >= ema_best_metric[ - self.eval_class.main_indicator]: - ema_best_metric.update(ema_cur_metric) - ema_best_metric['best_epoch'] = epoch - best_ema_str = f"best metric, {', '.join(['{}: {}'.format(k, v) for k, v in ema_best_metric.items()])}" - self.logger.info(best_ema_str) - # logger stats = { k: float(v) @@ -356,8 +316,7 @@ def train(self): self.writer.add_scalar(f'TRAIN/{k}', v, global_step) if self.local_rank == 0 and ( - (global_step > 0 and global_step % print_batch_step == 0) - or (idx >= len(self.train_dataloader) - 1)): + (global_step > 0 and global_step % print_batch_step == 0) or (idx >= len(self.train_dataloader) - 1)): logs = train_stats.log() eta_sec = ( @@ -377,101 +336,15 @@ def train(self): train_reader_cost = 0.0 train_batch_cost = 0.0 reader_start = time.time() - # eval + # eval iter step if (global_step > start_eval_step and - (global_step - start_eval_step) % eval_batch_step - == 0) and self.local_rank == 0: - cur_metric = self.eval() - cur_metric_str = f"cur metric, {', '.join(['{}: {}'.format(k, v) for k, v in cur_metric.items()])}" - self.logger.info(cur_metric_str) - - # logger metric - if self.writer is not None: - for k, v in cur_metric.items(): - if isinstance(v, (float, int)): - self.writer.add_scalar(f'EVAL/{k}', - cur_metric[k], - global_step) - - if (cur_metric[self.eval_class.main_indicator] >= - best_metric[self.eval_class.main_indicator]): - best_metric.update(cur_metric) - best_metric['best_epoch'] = epoch - if self.writer is not None: - self.writer.add_scalar( - f'EVAL/best_{self.eval_class.main_indicator}', - best_metric[self.eval_class.main_indicator], - global_step, - ) - if epoch > (epoch_num - epoch_num // 10 - 2): - save_ckpt(self.model, - self.cfg, - self.optimizer, - self.lr_scheduler, - epoch, - global_step, - best_metric, - is_best=True, - prefix='best_' + str(best_iter)) - best_iter += 1 - # else: - save_ckpt(self.model, - self.cfg, - self.optimizer, - self.lr_scheduler, - epoch, - global_step, - best_metric, - is_best=True, - prefix=None) - best_str = f"best metric, {', '.join(['{}: {}'.format(k, v) for k, v in best_metric.items()])}" - self.logger.info(best_str) + (global_step - start_eval_step) % eval_batch_step == 0) and self.local_rank == 0: + self.eval_step(global_step, epoch) + + # eval epoch step if self.local_rank == 0 and epoch > start_eval_epoch and ( epoch - start_eval_epoch) % eval_epoch_step == 0: - cur_metric = self.eval() - cur_metric_str = f"cur metric, {', '.join(['{}: 
{}'.format(k, v) for k, v in cur_metric.items()])}" - self.logger.info(cur_metric_str) - - # logger metric - if self.writer is not None: - for k, v in cur_metric.items(): - if isinstance(v, (float, int)): - self.writer.add_scalar(f'EVAL/{k}', cur_metric[k], - global_step) - - if (cur_metric[self.eval_class.main_indicator] >= - best_metric[self.eval_class.main_indicator]): - best_metric.update(cur_metric) - best_metric['best_epoch'] = epoch - if self.writer is not None: - self.writer.add_scalar( - f'EVAL/best_{self.eval_class.main_indicator}', - best_metric[self.eval_class.main_indicator], - global_step, - ) - if epoch > (epoch_num - epoch_num // 10 - 2): - save_ckpt(self.model, - self.cfg, - self.optimizer, - self.lr_scheduler, - epoch, - global_step, - best_metric, - is_best=True, - prefix='best_' + str(best_iter)) - best_iter += 1 - # else: - save_ckpt(self.model, - self.cfg, - self.optimizer, - self.lr_scheduler, - epoch, - global_step, - best_metric, - is_best=True, - prefix=None) - best_str = f"best metric, {', '.join(['{}: {}'.format(k, v) for k, v in best_metric.items()])}" - self.logger.info(best_str) + self.eval_step(global_step, epoch) if self.local_rank == 0: save_ckpt(self.model, @@ -480,53 +353,65 @@ def train(self): self.lr_scheduler, epoch, global_step, - best_metric, + self.best_metric, is_best=False, prefix=None) - if epoch > (epoch_num - epoch_num // 10 - 2): + if epoch > start_save_epoch and ( + epoch - start_save_epoch) % save_epoch_step == 0: save_ckpt(self.model, self.cfg, self.optimizer, self.lr_scheduler, epoch, global_step, - best_metric, + self.best_metric, is_best=False, prefix='epoch_' + str(epoch)) - if self.use_ema and epoch > (epoch_num - epoch_num // 10): - # if global_step > start_eval_step and (global_step - start_eval_step) % eval_batch_step == 0: - ema_cur_metric = self.eval_ema() - ema_cur_metric_str = f"cur ema metric, {', '.join(['{}: {}'.format(k, v) for k, v in ema_cur_metric.items()])}" - self.logger.info(ema_cur_metric_str) - state = { - 'epoch': epoch, - 'global_step': global_step, - 'state_dict': self.ema_model.state_dict(), - 'optimizer': None, - 'scheduler': None, - 'config': self.cfg, - 'metrics': ema_cur_metric, - } - save_path = os.path.join( - self.cfg['Global']['output_dir'], - 'ema_' + str(ema_eval_iter) + '.pth') - torch.save(state, save_path) - self.logger.info(f'save ema ckpt to {save_path}') - ema_eval_iter += 1 - if (ema_cur_metric[self.eval_class.main_indicator] >= - ema_best_metric[self.eval_class.main_indicator]): - ema_best_metric.update(ema_cur_metric) - ema_best_metric['best_epoch'] = epoch - # ema_cur_metric_str = f"best ema metric, {', '.join(['{}: {}'.format(k, v) for k, v in ema_best_metric.items()])}" - best_ema_str = f"best metric, {', '.join(['{}: {}'.format(k, v) for k, v in ema_best_metric.items()])}" - self.logger.info(best_ema_str) - best_str = f"best metric, {', '.join(['{}: {}'.format(k, v) for k, v in best_metric.items()])}" + + best_str = f"best metric, {', '.join(['{}: {}'.format(k, v) for k, v in self.best_metric.items()])}" self.logger.info(best_str) if self.writer is not None: self.writer.close() if torch.cuda.device_count() > 1: + torch.distributed.barrier() torch.distributed.destroy_process_group() + def eval_step(self, global_step, epoch): + cur_metric = self.eval() + cur_metric_str = f"cur metric, {', '.join(['{}: {}'.format(k, v) for k, v in cur_metric.items()])}" + self.logger.info(cur_metric_str) + + # logger metric + if self.writer is not None: + for k, v in cur_metric.items(): + if isinstance(v, 
(float, int)): + self.writer.add_scalar(f'EVAL/{k}', cur_metric[k], + global_step) + + if (cur_metric[self.eval_class.main_indicator] >= + self.best_metric[self.eval_class.main_indicator]): + self.best_metric.update(cur_metric) + self.best_metric['best_epoch'] = epoch + + if self.writer is not None: + self.writer.add_scalar( + f'EVAL/best_{self.eval_class.main_indicator}', + self.best_metric[self.eval_class.main_indicator], + global_step, + ) + + save_ckpt(self.model, + self.cfg, + self.optimizer, + self.lr_scheduler, + epoch, + global_step, + self.best_metric, + is_best=True, + prefix=None) + best_str = f"best metric, {', '.join(['{}: {}'.format(k, v) for k, v in self.best_metric.items()])}" + self.logger.info(best_str) + def eval(self): self.model.eval() with torch.no_grad(): @@ -540,19 +425,22 @@ def eval(self): ) sum_images = 0 for idx, batch in enumerate(self.valid_dataloader): - batch = [t.to(self.device) for t in batch] + batch_tensor = [t.to(self.device) for t in batch] + batch_numpy = [t.numpy() for t in batch] start = time.time() if self.scaler: - with torch.cuda.amp.autocast(): - preds = self.model(batch[0], data=batch[1:]) + with torch.cuda.amp.autocast( + enabled=self.device.type == 'cuda'): + preds = self.model(batch_tensor[0], + data=batch_tensor[1:]) else: - preds = self.model(batch[0], data=batch[1:]) + preds = self.model(batch_tensor[0], data=batch_tensor[1:]) total_time += time.time() - start # Obtain usable results from post-processing methods # Evaluate the results of the current batch - post_result = self.post_process_class(preds, batch) - self.eval_class(post_result, batch) + post_result = self.post_process_class(preds, batch_numpy) + self.eval_class(post_result, batch_numpy) pbar.update(1) total_frame += len(batch[0]) @@ -565,44 +453,6 @@ def eval(self): metric['fps'] = total_frame / total_time return metric - def eval_ema(self): - # self.model.eval() - with torch.no_grad(): - total_frame = 0.0 - total_time = 0.0 - pbar = tqdm( - total=len(self.valid_dataloader), - desc='eval ema_model:', - position=0, - leave=True, - ) - sum_images = 0 - for idx, batch in enumerate(self.valid_dataloader): - batch = [t.to(self.device) for t in batch] - start = time.time() - if self.scaler: - with torch.cuda.amp.autocast(): - preds = self.ema_model(batch[0], data=batch[1:]) - else: - preds = self.ema_model(batch[0], data=batch[1:]) - - total_time += time.time() - start - # Obtain usable results from post-processing methods - # Evaluate the results of the current batch - post_result = self.post_process_class(preds, batch) - self.eval_class(post_result, batch) - - pbar.update(1) - total_frame += len(batch[0]) - sum_images += 1 - # Get final metric,eg. 
acc or hmean - metric = self.eval_class.get_metric() - - pbar.close() - # self.model.train() - metric['fps'] = total_frame / total_time - return metric - def test_dataloader(self): starttime = time.time() count = 0 diff --git a/tools/eval_det.py b/tools/eval_det.py new file mode 100644 index 0000000..dd0f1be --- /dev/null +++ b/tools/eval_det.py @@ -0,0 +1,42 @@ +import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) + +sys.path.append(__dir__) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) + +from tools.engine.config import Config +from tools.engine.trainer import Trainer +from tools.utility import ArgsParser + + +def parse_args(): + parser = ArgsParser() + args = parser.parse_args() + return args + + +def main(): + FLAGS = parse_args() + cfg = Config(FLAGS.config) + FLAGS = vars(FLAGS) + opt = FLAGS.pop('opt') + cfg.merge_dict(FLAGS) + cfg.merge_dict(opt) + trainer = Trainer(cfg, mode='eval', task='det') + + best_model_dict = trainer.status.get('metrics', {}) + trainer.logger.info('metric in ckpt ***************') + for k, v in best_model_dict.items(): + trainer.logger.info('{}:{}'.format(k, v)) + + metric = trainer.eval() + + trainer.logger.info('metric eval ***************') + for k, v in metric.items(): + trainer.logger.info('{}:{}'.format(k, v)) + + +if __name__ == '__main__': + main() diff --git a/tools/eval_rec.py b/tools/eval_rec.py index 649c20c..509a280 100644 --- a/tools/eval_rec.py +++ b/tools/eval_rec.py @@ -6,7 +6,8 @@ sys.path.append(__dir__) sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) -from tools.engine import Config, Trainer +from tools.engine.config import Config +from tools.engine.trainer import Trainer from tools.utility import ArgsParser diff --git a/tools/eval_rec_all_ch.py b/tools/eval_rec_all_ch.py index 1b7671b..efa8d22 100644 --- a/tools/eval_rec_all_ch.py +++ b/tools/eval_rec_all_ch.py @@ -9,7 +9,8 @@ sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) from tools.data import build_dataloader -from tools.engine import Config, Trainer +from tools.engine.config import Config +from tools.engine.trainer import Trainer from tools.utility import ArgsParser diff --git a/tools/eval_rec_all_en.py b/tools/eval_rec_all_en.py index a8c11c1..400f51f 100644 --- a/tools/eval_rec_all_en.py +++ b/tools/eval_rec_all_en.py @@ -9,7 +9,8 @@ sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) from tools.data import build_dataloader -from tools.engine import Config, Trainer +from tools.engine.config import Config +from tools.engine.trainer import Trainer from tools.utility import ArgsParser diff --git a/tools/eval_rec_all_long.py b/tools/eval_rec_all_long.py index ae60e5b..4ea58f8 100644 --- a/tools/eval_rec_all_long.py +++ b/tools/eval_rec_all_long.py @@ -10,7 +10,8 @@ import numpy as np from tools.data import build_dataloader -from tools.engine import Config, Trainer +from tools.engine.config import Config +from tools.engine.trainer import Trainer from tools.utility import ArgsParser diff --git a/tools/eval_rec_all_long_simple.py b/tools/eval_rec_all_long_simple.py index a86d4df..7c2cfe4 100644 --- a/tools/eval_rec_all_long_simple.py +++ b/tools/eval_rec_all_long_simple.py @@ -10,7 +10,8 @@ import numpy as np from tools.data import build_dataloader -from tools.engine import Config, Trainer +from tools.engine.config import Config +from tools.engine.trainer import Trainer from tools.utility import ArgsParser diff --git a/tools/export_rec.py b/tools/export_rec.py index 882ed3a..17cb5d6 100644 --- 
a/tools/export_rec.py +++ b/tools/export_rec.py @@ -9,7 +9,7 @@ from openrec.modeling import build_model from openrec.postprocess import build_post_process -from tools.engine import Config +from tools.engine.config import Config from tools.infer_rec import build_rec_process from tools.utility import ArgsParser from tools.utils.ckpt import load_ckpt diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 0906014..7a50516 100644 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -13,7 +13,7 @@ from openrec.postprocess import build_post_process from openrec.preprocess import create_operators, transform -from tools.engine import Config +from tools.engine.config import Config from tools.infer.onnx_engine import ONNXEngine from tools.infer.utility import check_gpu, parse_args from tools.utils.logging import get_logger diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 944f8cb..ae1b05c 100644 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -3,11 +3,11 @@ import cv2 import numpy as np -import torch import PIL from PIL import Image, ImageDraw, ImageFont import random + def str2bool(v): return v.lower() in ('true', 'yes', 't', 'y', '1') @@ -77,44 +77,43 @@ def parse_args(): parser = init_args() return parser.parse_args() -def create_font(txt, sz, font_path="./doc/fonts/simfang.ttf"): + +def create_font(txt, sz, font_path='./doc/fonts/simfang.ttf'): font_size = int(sz[1] * 0.99) - font = ImageFont.truetype(font_path, font_size, encoding="utf-8") - if int(PIL.__version__.split(".")[0]) < 10: + font = ImageFont.truetype(font_path, font_size, encoding='utf-8') + if int(PIL.__version__.split('.')[0]) < 10: length = font.getsize(txt)[0] else: length = font.getlength(txt) if length > sz[0]: font_size = int(font_size * sz[0] / length) - font = ImageFont.truetype(font_path, font_size, encoding="utf-8") + font = ImageFont.truetype(font_path, font_size, encoding='utf-8') return font -def draw_box_txt_fine(img_size, box, txt, font_path="./doc/fonts/simfang.ttf"): + +def draw_box_txt_fine(img_size, box, txt, font_path='./doc/fonts/simfang.ttf'): box_height = int( - math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2) - ) + math.sqrt((box[0][0] - box[3][0])**2 + (box[0][1] - box[3][1])**2)) box_width = int( - math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2) - ) + math.sqrt((box[0][0] - box[1][0])**2 + (box[0][1] - box[1][1])**2)) if box_height > 2 * box_width and box_height > 30: - img_text = Image.new("RGB", (box_height, box_width), (255, 255, 255)) + img_text = Image.new('RGB', (box_height, box_width), (255, 255, 255)) draw_text = ImageDraw.Draw(img_text) if txt: font = create_font(txt, (box_height, box_width), font_path) draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font) img_text = img_text.transpose(Image.ROTATE_270) else: - img_text = Image.new("RGB", (box_width, box_height), (255, 255, 255)) + img_text = Image.new('RGB', (box_width, box_height), (255, 255, 255)) draw_text = ImageDraw.Draw(img_text) if txt: font = create_font(txt, (box_width, box_height), font_path) draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font) - pts1 = np.float32( - [[0, 0], [box_width, 0], [box_width, box_height], [0, box_height]] - ) + pts1 = np.float32([[0, 0], [box_width, 0], [box_width, box_height], + [0, box_height]]) pts2 = np.array(box, dtype=np.float32) M = cv2.getPerspectiveTransform(pts1, pts2) @@ -129,13 +128,14 @@ def draw_box_txt_fine(img_size, box, txt, font_path="./doc/fonts/simfang.ttf"): ) return 
img_right_text

+
 def draw_ocr_box_txt(
     image,
     boxes,
     txts=None,
     scores=None,
     drop_score=0.5,
-    font_path="./doc/fonts/simfang.ttf",
+    font_path='./doc/fonts/simfang.ttf',
 ):
     h, w = image.height, image.width
     img_left = image.copy()
@@ -148,7 +148,8 @@ def draw_ocr_box_txt(
     for idx, (box, txt) in enumerate(zip(boxes, txts)):
         if scores is not None and scores[idx] < drop_score:
             continue
-        color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
+        color = (random.randint(0, 255), random.randint(0, 255),
+                 random.randint(0, 255))
         if isinstance(box[0], list):
             box = list(map(tuple, box))
         draw_left.polygon(box, fill=color)
@@ -157,7 +158,7 @@ def draw_ocr_box_txt(
         cv2.polylines(img_right_text, [pts], True, color, 1)
         img_right = cv2.bitwise_and(img_right, img_right_text)
     img_left = Image.blend(image, img_left, 0.5)
-    img_show = Image.new("RGB", (w * 2, h), (255, 255, 255))
+    img_show = Image.new('RGB', (w * 2, h), (255, 255, 255))
     img_show.paste(img_left, (0, 0, w, h))
     img_show.paste(Image.fromarray(img_right), (w, 0, w * 2, h))
     return np.array(img_show)
@@ -225,6 +226,7 @@ def get_minarea_rect_crop(img, points):
 def check_gpu(use_gpu):
+    import torch
     if use_gpu and not torch.cuda.is_available():
         use_gpu = False
     return use_gpu
diff --git a/tools/infer_det.py b/tools/infer_det.py
index 67cdef3..dba47bc 100644
--- a/tools/infer_det.py
+++ b/tools/infer_det.py
@@ -1,6 +1,7 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+from pathlib import Path

 import time
 import numpy as np
@@ -16,15 +17,74 @@
 import cv2
 import json

-import torch
-from tools.engine import Config
+from tools.engine.config import Config
 from tools.utility import ArgsParser
-from tools.utils.ckpt import load_ckpt
 from tools.utils.logging import get_logger
 from tools.utils.utility import get_image_file_list

+logger = get_logger()
+
+root_dir = Path(__file__).resolve().parent
+DEFAULT_CFG_PATH_DET = str(root_dir / '../configs/det/dbnet/repvit_db.yml')
+
+MODEL_NAME_DET = './openocr_det_repvit_ch.pth'  # model file name
+DOWNLOAD_URL_DET = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_det_repvit_ch.pth'  # model file URL
+MODEL_NAME_DET_ONNX = './openocr_det_model.onnx'  # model file name
+DOWNLOAD_URL_DET_ONNX = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_det_model.onnx'  # model file URL
+
+
+def check_and_download_model(model_name: str, url: str):
+    """
+    Check whether the pretrained model exists; if not, download it from the
+    given URL into a fixed cache directory.
+
+    Args:
+        model_name (str): name of the model file, e.g. "model.pt"
+        url (str): download URL of the model file
+
+    Returns:
+        str: full path to the model file
+    """
+    if os.path.exists(model_name):
+        return model_name
+
+    # Fixed cache path under the user's home directory: ".cache/openocr"
+    cache_dir = Path.home() / '.cache' / 'openocr'
+    model_path = cache_dir / model_name
+
+    # If the model file already exists, return its path directly
+    if model_path.exists():
+        logger.info(f'Model already exists at: {model_path}')
+        return str(model_path)
+
+    # Otherwise, download the model
+    logger.info(f'Model not found. Downloading from {url}...')
+
+    # Create the cache directory if it does not exist
+    cache_dir.mkdir(parents=True, exist_ok=True)
+
+    try:
+        # Download the file
+        import urllib.request
+        with urllib.request.urlopen(url) as response, open(model_path,
+                                                           'wb') as out_file:
+            out_file.write(response.read())
+        logger.info(f'Model downloaded and saved at: {model_path}')
+        return str(model_path)
+
+    except Exception as e:
+        logger.error(f'Error downloading the model: {e}')
+        # Prompt the user to download the model manually
+        logger.error(
+            f'Unable to download the model automatically. 
' + f'Please download the model manually from the following URL:\n{url}\n' + f'and save it to: {model_name} or {model_path}') + raise RuntimeError( + f'Failed to download the model. Please download it manually from {url} ' + f'and save it to {model_path}') from e + def replace_batchnorm(net): + import torch for child_name, child in net.named_children(): if hasattr(child, 'fuse'): fused = child.fuse() @@ -36,117 +96,6 @@ def replace_batchnorm(net): replace_batchnorm(child) -def padding_image(img, size=(640, 640)): - """ - Padding an image using OpenCV: - - If the image is smaller than the target size, pad it to 640x640. - - If the image is larger than the target size, split it into multiple 640x640 images and record positions. - - :param image_path: Path to the input image. - :param output_dir: Directory to save the output images. - :param size: The target size for padding or splitting (default 640x640). - :return: List of tuples containing the coordinates of the top-left corner of each cropped 640x640 image. - """ - - img_height, img_width = img.shape[:2] - target_width, target_height = size - - # If image is smaller than target size, pad the image to 640x640 - - # Calculate padding amounts (top, bottom, left, right) - pad_top = 0 - pad_bottom = target_height - img_height - pad_left = 0 - pad_right = target_width - img_width - - # Pad the image (white padding, border type: constant) - padded_img = cv2.copyMakeBorder(img, - pad_top, - pad_bottom, - pad_left, - pad_right, - cv2.BORDER_CONSTANT, - value=[0, 0, 0]) - - # Return the padded area positions (top-left and bottom-right coordinates of the original image) - return padded_img - - -def resize_image(img, size=(640, 640), over_lap=64): - """ - Resize an image using OpenCV: - - If the image is smaller than the target size, pad it to 640x640. - - If the image is larger than the target size, split it into multiple 640x640 images and record positions. - - :param image_path: Path to the input image. - :param output_dir: Directory to save the output images. - :param size: The target size for padding or splitting (default 640x640). - :return: List of tuples containing the coordinates of the top-left corner of each cropped 640x640 image. 
- """ - - img_height, img_width = img.shape[:2] - target_width, target_height = size - - # If image is smaller than target size, pad the image to 640x640 - if img_width <= target_width and img_height <= target_height: - # Calculate padding amounts (top, bottom, left, right) - if img_width == target_width and img_height == target_height: - return [img], [[0, 0, img_width, img_height]] - padded_img = padding_image(img, size) - - # Return the padded area positions (top-left and bottom-right coordinates of the original image) - return [padded_img], [[0, 0, img_width, img_height]] - - img_height, img_width = img.shape[:2] - # If image is larger than or equal to target size, crop it into 640x640 tiles - crop_positions = [] - count = 0 - cropped_img_list = [] - for top in range(0, img_height - over_lap, target_height - over_lap): - for left in range(0, img_width - over_lap, target_width - over_lap): - # Calculate the bottom and right boundaries for the crop - right = min(left + target_width, img_width) - bottom = min(top + target_height, img_height) - if right >= img_width: - right = img_width - left = max(0, right - target_width) - if bottom >= img_height: - bottom = img_height - top = max(0, bottom - target_height) - # Crop the image - cropped_img = img[top:bottom, left:right] - if bottom - top < target_height or right - left < target_width: - cropped_img = padding_image(cropped_img, size) - count += 1 - cropped_img_list.append(cropped_img) - - # Record the position of the cropped image - crop_positions.append([left, top, right, bottom]) - - return cropped_img_list, crop_positions - - -def restore_preds(preds, crop_positions, original_size): - - restored_pred = torch.zeros((1, 1, original_size[0], original_size[1]), - dtype=preds.dtype, - device=preds.device) - count = 0 - for cropped_pred, (left, top, right, bottom) in zip(preds, crop_positions): - - crop_height = bottom - top - crop_width = right - left - - corp_vis_img = cropped_pred[:, :crop_height, :crop_width] - mask = corp_vis_img > 0.3 - count += 1 - restored_pred[:, :, top:top + crop_height, left:left + - crop_width] += mask[:, :crop_height, :crop_width].to( - preds.dtype) - - return restored_pred - - def draw_det_res(dt_boxes, img, img_name, save_path): src_im = img for box in dt_boxes: @@ -159,48 +108,68 @@ def draw_det_res(dt_boxes, img, img_name, save_path): def set_device(device, numId=0): + import torch if device == 'gpu' and torch.cuda.is_available(): device = torch.device(f'cuda:{numId}') else: + logger.info('GPU is not available, using CPU.') device = torch.device('cpu') return device class OpenDetector(object): - def __init__(self, config=None, numId=0): + def __init__(self, + config=None, + backend='torch', + onnx_model_path=None, + numId=0): """ - 初始化函数。 - Args: - config (dict, optional): 配置文件,默认为None。如果为None,则使用默认配置文件。 - numId (int, optional): 设备编号,默认为0。 - - Returns: - None - - Raises: - 无 + config (dict, optional): 配置信息。默认为None。 + backend (str): 'torch' 或 'onnx' + onnx_model_path (str): ONNX模型路径(仅当backend='onnx'时需要) + numId (int, optional): 设备编号。默认为0。 """ if config is None: - config = Config('./configs/det/dbnet/repvit_db.yml').cfg + config = Config(DEFAULT_CFG_PATH_DET).cfg + + self._init_common(config) + backend = backend if config['Global'].get( + 'backend', None) is None else config['Global']['backend'] + self.backend = backend + if backend == 'torch': + import torch + self.torch = torch + if config['Architecture']['algorithm'] == 'DB_mobile': + if not os.path.exists(config['Global']['pretrained_model']): + 
config['Global'][ + 'pretrained_model'] = check_and_download_model( + MODEL_NAME_DET, DOWNLOAD_URL_DET) + self._init_torch_model(config, numId) + elif backend == 'onnx': + from tools.infer.onnx_engine import ONNXEngine + onnx_model_path = onnx_model_path if config['Global'].get( + 'onnx_model_path', + None) is None else config['Global']['onnx_model_path'] + if onnx_model_path is None: + if config['Architecture']['algorithm'] == 'DB_mobile': + onnx_model_path = check_and_download_model( + MODEL_NAME_DET_ONNX, DOWNLOAD_URL_DET_ONNX) + else: + raise ValueError('ONNX模式需要指定onnx_model_path参数') + self.onnx_det_engine = ONNXEngine( + onnx_model_path, use_gpu=config['Global']['device'] == 'gpu') + else: + raise ValueError("backend参数必须是'torch'或'onnx'") - from opendet.modeling import build_model as build_det_model + def _init_common(self, config): from opendet.postprocess import build_post_process from opendet.preprocess import create_operators, transform - self.transform = transform global_config = config['Global'] - - # build model - self.model = build_det_model(config['Architecture']) - self.model.eval() - load_ckpt(self.model, config) - replace_batchnorm(self.model.backbone) - self.device = set_device(config['Global']['device'], numId=numId) - self.model.to(device=self.device) - # create data ops + self.transform = transform transforms = [] for op in config['Eval']['dataset']['transforms']: op_name = list(op)[0] @@ -211,86 +180,34 @@ def __init__(self, config=None, numId=0): transforms.append(op) self.ops = create_operators(transforms, global_config) - - save_res_path = config['Global']['save_res_path'] - if not os.path.exists(os.path.dirname(save_res_path)): - os.makedirs(os.path.dirname(save_res_path)) - # build post process self.post_process_class = build_post_process(config['PostProcess'], global_config) - def crop_infer( - self, - img_path=None, - img_numpy_list=None, - img_numpy=None, - ): - if img_numpy is not None: - img_numpy_list = [img_numpy] - num_img = 1 - elif img_path is not None: - num_img = len(img_path) - elif img_numpy_list is not None: - num_img = len(img_numpy_list) - else: - raise Exception('No input image path or numpy array.') - results = [] - for img_idx in range(num_img): - if img_numpy_list is not None: - img = img_numpy_list[img_idx] - data = {'image': img} - elif img_path is not None: - with open(img_path[img_idx], 'rb') as f: - img = f.read() - data = {'image': img} - data = self.transform(data, self.ops[:1]) - src_img_ori = data['image'] - img_height, img_width = src_img_ori.shape[:2] - - target_size = 640 - over_lap = 64 - if img_height > img_width: - r_h = target_size * 2 - over_lap - r_w = img_width * (target_size * 2 - over_lap) // img_height - else: - r_w = target_size * 2 - over_lap - r_h = img_height * (target_size * 2 - over_lap) // img_width - src_img = cv2.resize(src_img_ori, (r_w, r_h)) - shape_list_ori = np.array([[ - img_height, img_width, - float(r_h) / img_height, - float(r_w) / img_width - ]]) - img_height, img_width = src_img.shape[:2] - cropped_img_list, crop_positions = resize_image(src_img, - size=(target_size, - target_size), - over_lap=over_lap) - - image_list = [] - shape_list = [] - for img in cropped_img_list: - batch_i = self.transform({'image': img}, self.ops[-3:-1]) - image_list.append(batch_i['image']) - shape_list.append([640, 640, 1, 1]) - images = np.array(image_list) - shape_list = np.array(shape_list) - images = torch.from_numpy(images).to(device=self.device) + def _init_torch_model(self, config, numId=0): - t_start = time.time() - 
preds = self.model(images) - torch.cuda.synchronize() - t_cost = time.time() - t_start + from opendet.modeling import build_model as build_det_model + from tools.utils.ckpt import load_ckpt - preds['maps'] = restore_preds(preds['maps'], crop_positions, - (img_height, img_width)) - post_result = self.post_process_class(preds, shape_list_ori) - info = {'boxes': post_result[0]['points'], 'elapse': t_cost} - results.append(info) - return results + # build model + self.model = build_det_model(config['Architecture']) + self.model.eval() + load_ckpt(self.model, config) + if config['Architecture']['algorithm'] == 'DB_mobile': + replace_batchnorm(self.model.backbone) + self.device = set_device(config['Global']['device'], numId=numId) + self.model.to(device=self.device) - def __call__(self, img_path=None, img_numpy_list=None, img_numpy=None): + def _inference_onnx(self, images): + # ONNX输入需要为numpy数组 + return self.onnx_det_engine.run(images) + + def __call__(self, + img_path=None, + img_numpy_list=None, + img_numpy=None, + return_mask=False, + **kwargs): """ 对输入图像进行处理,并返回处理结果。 @@ -328,35 +245,49 @@ def __call__(self, img_path=None, img_numpy_list=None, img_numpy=None): img = f.read() data = {'image': img} data = self.transform(data, self.ops[:1]) + if kwargs.get('det_input_size', None) is not None: + data['max_sile_len'] = kwargs['det_input_size'] batch = self.transform(data, self.ops[1:]) images = np.expand_dims(batch[0], axis=0) shape_list = np.expand_dims(batch[1], axis=0) - images = torch.from_numpy(images).to(device=self.device) - with torch.no_grad(): - t_start = time.time() - preds = self.model(images) - t_cost = time.time() - t_start - post_result = self.post_process_class(preds, shape_list) + t_start = time.time() + + if self.backend == 'torch': + images = self.torch.from_numpy(images).to(device=self.device) + with self.torch.no_grad(): + preds = self.model(images) + kwargs['torch_tensor'] = True + elif self.backend == 'onnx': + preds_det = self._inference_onnx(images) + preds = {'maps': preds_det[0]} + kwargs['torch_tensor'] = False + + t_cost = time.time() - t_start + post_result = self.post_process_class(preds, [None, shape_list], + **kwargs) info = {'boxes': post_result[0]['points'], 'elapse': t_cost} + if return_mask: + if isinstance(preds['maps'], self.torch.Tensor): + mask = preds['maps'].detach().cpu().numpy() + else: + mask = preds['maps'] + info['mask'] = mask results.append(info) return results -@torch.no_grad() def main(cfg): - logger = get_logger() is_visualize = cfg['Global'].get('is_visualize', False) model = OpenDetector(cfg) - save_res_path = cfg['Global']['output_dir'] + save_res_path = './det_results/' if not os.path.exists(save_res_path): os.makedirs(save_res_path) sample_num = 0 with open(save_res_path + '/det_results.txt', 'wb') as fout: for file in get_image_file_list(cfg['Global']['infer_img']): - preds_result = model(img_path=file)[0] logger.info('{} infer_img: {}, time cost: {}'.format( sample_num, file, preds_result['elapse'])) @@ -368,14 +299,16 @@ def main(cfg): dt_boxes_json.append(tmp_json) if is_visualize: src_img = cv2.imread(file) - save_det_path = save_res_path + '/det_results/' - draw_det_res(boxes, src_img, file, save_det_path) + draw_det_res(boxes, src_img, file, save_res_path) logger.info('The detected Image saved in {}'.format( - os.path.join(save_det_path, os.path.basename(file)))) + os.path.join(save_res_path, os.path.basename(file)))) otstr = file + '\t' + json.dumps(dt_boxes_json) + '\n' logger.info('results: {}'.format(json.dumps(dt_boxes_json))) 
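# Illustrative usage (a minimal sketch, not itself part of the patch):
# OpenDetector can now run either the PyTorch checkpoint or an exported ONNX
# model, selected via `backend`; the call pattern mirrors main() above and the
# OpenOCR wrapper in tools/infer_e2e.py later in this patch. 'example.jpg' is a
# placeholder image path.
from tools.engine.config import Config
from tools.infer_det import OpenDetector

cfg = Config('./configs/det/dbnet/repvit_db.yml').cfg
detector = OpenDetector(cfg, backend='torch')  # or backend='onnx' with onnx_model_path=...
res = detector(img_path='example.jpg')[0]      # dict with 'boxes' and 'elapse'
print(len(res['boxes']), res['elapse'])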
fout.write(otstr.encode()) sample_num += 1 + logger.info( + f"Results saved to {os.path.join(save_res_path, 'det_results.txt')}.)" + ) logger.info('success!') diff --git a/tools/infer_e2e.py b/tools/infer_e2e.py index 54633c9..fbe143d 100644 --- a/tools/infer_e2e.py +++ b/tools/infer_e2e.py @@ -3,6 +3,7 @@ from __future__ import print_function import os +from pathlib import Path import sys __dir__ = os.path.dirname(os.path.abspath(__file__)) @@ -17,32 +18,37 @@ import cv2 import json from PIL import Image -import torch from tools.utils.utility import get_image_file_list, check_and_read from tools.infer_rec import OpenRecognizer from tools.infer_det import OpenDetector -from tools.engine import Config +from tools.engine.config import Config from tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop, draw_ocr_box_txt +from tools.utils.logging import get_logger +root_dir = Path(__file__).resolve().parent +DEFAULT_CFG_PATH_DET = str(root_dir / '../configs/det/dbnet/repvit_db.yml') +DEFAULT_CFG_PATH_REC_SERVER = str(root_dir / + '../configs/rec/svtrv2/svtrv2_ch.yml') +DEFAULT_CFG_PATH_REC = str(root_dir / '../configs/rec/svtrv2/repsvtr_ch.yml') -def set_device(device): - if device == 'gpu' and torch.cuda.is_available(): - device = torch.device('cuda:0') - else: - device = torch.device('cpu') - return device +logger = get_logger() def check_and_download_font(font_path): if not os.path.exists(font_path): - print(f"Downloading '{font_path}' ...") + cache_dir = Path.home() / '.cache' / 'openocr' + font_path = str(cache_dir / font_path) + if os.path.exists(font_path): + return font_path + logger.info(f"Downloading '{font_path}' ...") try: import urllib.request font_url = 'https://shuiche-shop.oss-cn-chengdu.aliyuncs.com/fonts/simfang.ttf' urllib.request.urlretrieve(font_url, font_path) - print(f'Downloading font success: {font_path}') + logger.info(f'Downloading font success: {font_path}') except Exception as e: - print(f'Downloading font error: {e}') + logger.info(f'Downloading font error: {e}') + return font_path def sorted_boxes(dt_boxes): @@ -71,7 +77,14 @@ def sorted_boxes(dt_boxes): class OpenOCR(object): - def __init__(self, mode='mobile', drop_score=0.5, det_box_type='quad'): + def __init__(self, + mode='mobile', + backend='torch', + onnx_det_model_path=None, + onnx_rec_model_path=None, + drop_score=0.5, + det_box_type='quad', + device='gpu'): """ 初始化函数,用于初始化OCR引擎的相关配置和组件。 @@ -84,16 +97,20 @@ def __init__(self, mode='mobile', drop_score=0.5, det_box_type='quad'): 无返回值。 """ - cfg_det = Config( - './configs/det/dbnet/repvit_db.yml').cfg # mobile model + cfg_det = Config(DEFAULT_CFG_PATH_DET).cfg # mobile model + cfg_det['Global']['device'] = device if mode == 'server': - cfg_rec = Config( - './configs/det/svtrv2/svtrv2_ch.yml').cfg # server model + cfg_rec = Config(DEFAULT_CFG_PATH_REC_SERVER).cfg # server model else: - cfg_rec = Config( - './configs/rec/svtrv2/repsvtr_ch.yml').cfg # mobile model - self.text_detector = OpenDetector(cfg_det) - self.text_recognizer = OpenRecognizer(cfg_rec) + cfg_rec = Config(DEFAULT_CFG_PATH_REC).cfg # mobile model + + cfg_rec['Global']['device'] = device + + self.text_detector = OpenDetector(cfg_det, + backend=backend, + onnx_model_path=onnx_det_model_path) + self.text_recognizer = OpenRecognizer( + cfg_rec, backend=backend, onnx_model_path=onnx_rec_model_path) self.det_box_type = det_box_type self.drop_score = drop_score @@ -114,14 +131,19 @@ def infer_single_image(self, img_numpy, ori_img, crop_infer=False, - rec_batch_num=6): + 
rec_batch_num=6, + return_mask=False, + **kwargs): start = time.time() if crop_infer: dt_boxes = self.text_detector.crop_infer( img_numpy=img_numpy)[0]['boxes'] else: - dt_boxes = self.text_detector(img_numpy=img_numpy)[0]['boxes'] - # print(dt_boxes) + det_res = self.text_detector(img_numpy=img_numpy, + return_mask=return_mask, + **kwargs)[0] + dt_boxes = det_res['boxes'] + # logger.info(dt_boxes) det_time_cost = time.time() - start if dt_boxes is None: @@ -155,6 +177,13 @@ def infer_single_image(self, avg_rec_time_cost = rec_time_cost_sig / len(dt_boxes) if len( dt_boxes) > 0 else 0.0 + if return_mask: + return filter_boxes, filter_rec_res, { + 'time_cost': det_time_cost + rec_time_cost, + 'detection_time': det_time_cost, + 'recognition_time': rec_time_cost, + 'avg_rec_time_cost': avg_rec_time_cost + }, det_res['mask'] return filter_boxes, filter_rec_res, { 'time_cost': det_time_cost + rec_time_cost, @@ -169,7 +198,9 @@ def __call__(self, is_visualize=False, img_numpy=None, rec_batch_num=6, - crop_infer=False): + crop_infer=False, + return_mask=False, + **kwargs): """ img_path: str, optional, default=None Path to the directory containing images or the image filename. @@ -194,11 +225,21 @@ def __call__(self, time_dicts = [] for index, img in enumerate(img_numpy): ori_img = img.copy() - dt_boxes, rec_res, time_dict = self.infer_single_image( - img_numpy=img, - ori_img=ori_img, - crop_infer=crop_infer, - rec_batch_num=rec_batch_num) + if return_mask: + dt_boxes, rec_res, time_dict, mask = self.infer_single_image( + img_numpy=img, + ori_img=ori_img, + crop_infer=crop_infer, + rec_batch_num=rec_batch_num, + return_mask=return_mask, + **kwargs) + else: + dt_boxes, rec_res, time_dict = self.infer_single_image( + img_numpy=img, + ori_img=ori_img, + crop_infer=crop_infer, + rec_batch_num=rec_batch_num, + **kwargs) if dt_boxes is None: results.append([]) time_dicts.append({}) @@ -210,6 +251,8 @@ def __call__(self, } for i in range(len(dt_boxes))] results.append(res) time_dicts.append(time_dict) + if return_mask: + return results, time_dicts, mask return results, time_dicts image_file_list = get_image_file_list(img_path) @@ -225,7 +268,8 @@ def __call__(self, imgs = [img] else: imgs = img - print(f'Processing {idx+1}/{len(image_file_list)}: {image_file}') + logger.info( + f'Processing {idx+1}/{len(image_file_list)}: {image_file}') res_list = [] time_dicts = [] @@ -235,7 +279,8 @@ def __call__(self, img_numpy=img_numpy, ori_img=ori_img, crop_infer=crop_infer, - rec_batch_num=rec_batch_num) + rec_batch_num=rec_batch_num, + **kwargs) if dt_boxes is None: res_list.append([]) time_dicts.append({}) @@ -252,10 +297,10 @@ def __call__(self, time_dicts)): if len(res) > 0: - print(f'Results: {res}.') - print(f'Time cost: {time_dict}.') + logger.info(f'Results: {res}.') + logger.info(f'Time cost: {time_dict}.') else: - print('No text detected.') + logger.info('No text detected.') if len(res_list) > 1: save_pred = (os.path.basename(image_file) + '_' + @@ -274,12 +319,12 @@ def __call__(self, if is_visualize and len(res) > 0: if idx == 0: font_path = './simfang.ttf' - check_and_download_font(font_path) + font_path = check_and_download_font(font_path) os.makedirs(save_dir, exist_ok=True) draw_img_save_dir = os.path.join( save_dir, 'vis_results/') os.makedirs(draw_img_save_dir, exist_ok=True) - print( + logger.info( f'Visualized results will be saved to {draw_img_save_dir}.' 
) dt_boxes = [res[i]['points'] for i in range(len(res))] @@ -320,14 +365,15 @@ def __call__(self, 'w', encoding='utf-8') as f: f.writelines(save_results) - print( + logger.info( f"Results saved to {os.path.join(save_dir, 'system_results.txt')}." ) if is_visualize: - print(f'Visualized results saved to {draw_img_save_dir}.') + logger.info( + f'Visualized results saved to {draw_img_save_dir}.') return save_results, time_dicts_return else: - print('No text detected.') + logger.info('No text detected.') return None, None @@ -342,6 +388,19 @@ def main(): type=str, default='mobile', help="Mode of the OCR system, e.g., 'mobile' or 'server'.") + parser.add_argument( + '--backend', + type=str, + default='torch', + help="Backend of the OCR system, e.g., 'torch' or 'onnx'.") + parser.add_argument('--onnx_det_model_path', + type=str, + default=None, + help='Path to the ONNX model for text detection.') + parser.add_argument('--onnx_rec_model_path', + type=str, + default=None, + help='Path to the ONNX model for text recognition.') parser.add_argument( '--save_dir', type=str, @@ -356,16 +415,29 @@ def main(): type=float, default=0.5, help='Score threshold for text recognition.') + parser.add_argument('--device', + type=str, + default='gpu', + help='Device to use for inference.') args = parser.parse_args() img_path = args.img_path mode = args.mode + backend = args.backend + onnx_det_model_path = args.onnx_det_model_path + onnx_rec_model_path = args.onnx_rec_model_path save_dir = args.save_dir is_visualize = args.is_vis drop_score = args.drop_score - - text_sys = OpenOCR(mode=mode, drop_score=drop_score, - det_box_type='quad') # det_box_type: 'quad' or 'poly' + device = args.device + + text_sys = OpenOCR(mode=mode, + backend=backend, + onnx_det_model_path=onnx_det_model_path, + onnx_rec_model_path=onnx_rec_model_path, + drop_score=drop_score, + det_box_type='quad', + device=device) # det_box_type: 'quad' or 'poly' text_sys(img_path=img_path, save_dir=save_dir, is_visualize=is_visualize) diff --git a/tools/infer_e2e_parallel.py b/tools/infer_e2e_parallel.py index f13a5e6..c9b1697 100644 --- a/tools/infer_e2e_parallel.py +++ b/tools/infer_e2e_parallel.py @@ -20,7 +20,7 @@ from tools.infer_rec import OpenRecognizer from tools.infer_det import OpenDetector from tools.infer_e2e import check_and_download_font, sorted_boxes -from tools.engine import Config +from tools.engine.config import Config from tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop, draw_ocr_box_txt diff --git a/tools/infer_rec.py b/tools/infer_rec.py index 033e788..9e79d50 100644 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -1,4 +1,5 @@ import os +from pathlib import Path import sys import time @@ -8,15 +9,75 @@ sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) import numpy as np -import torch -from torchvision import transforms as T -from torchvision.transforms import functional as F -from tools.engine import Config +from tools.engine.config import Config from tools.utility import ArgsParser -from tools.utils.ckpt import load_ckpt from tools.utils.logging import get_logger from tools.utils.utility import get_image_file_list -from tools.infer_det import replace_batchnorm + +logger = get_logger() + +root_dir = Path(__file__).resolve().parent +DEFAULT_CFG_PATH_REC_SERVER = str(root_dir / + '../configs/rec/svtrv2/svtrv2_ch.yml') +DEFAULT_CFG_PATH_REC = str(root_dir / '../configs/rec/svtrv2/repsvtr_ch.yml') +DEFAULT_DICT_PATH_REC = str(root_dir / './utils/ppocr_keys_v1.txt') + +MODEL_NAME_REC = 
'./openocr_repsvtr_ch.pth'  # model file name
+DOWNLOAD_URL_REC = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_repsvtr_ch.pth'  # model file URL
+MODEL_NAME_REC_SERVER = './openocr_svtrv2_ch.pth'  # model file name
+DOWNLOAD_URL_REC_SERVER = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_svtrv2_ch.pth'  # model file URL
+MODEL_NAME_REC_ONNX = './openocr_rec_model.onnx'  # model file name
+DOWNLOAD_URL_REC_ONNX = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_rec_model.onnx'  # model file URL
+
+
+def check_and_download_model(model_name: str, url: str):
+    """
+    Check whether the pretrained model exists; if not, download it from the given URL into a fixed cache directory.
+
+    Args:
+        model_name (str): name of the model file, e.g. "model.pt"
+        url (str): download URL of the model file
+
+    Returns:
+        str: full path to the model file
+    """
+    if os.path.exists(model_name):
+        return model_name
+
+    # Fixed cache directory ".cache/openocr" under the user's home directory
+    cache_dir = Path.home() / '.cache' / 'openocr'
+    model_path = cache_dir / model_name
+
+    # If the model file already exists, return its path directly
+    if model_path.exists():
+        logger.info(f'Model already exists at: {model_path}')
+        return str(model_path)
+
+    # Otherwise, download the model
+    logger.info(f'Model not found. Downloading from {url}...')
+
+    # Create the cache directory if it does not exist
+    cache_dir.mkdir(parents=True, exist_ok=True)
+
+    try:
+        # Download the file
+        import urllib.request
+        with urllib.request.urlopen(url) as response, open(model_path,
+                                                           'wb') as out_file:
+            out_file.write(response.read())
+        logger.info(f'Model downloaded and saved at: {model_path}')
+        return str(model_path)
+
+    except Exception as e:
+        logger.error(f'Error downloading the model: {e}')
+        # Ask the user to download the model manually
+        logger.error(
+            f'Unable to download the model automatically. '
+            f'Please download the model manually from the following URL:\n{url}\n'
+            f'and save it to: {model_name} or {model_path}')
+        raise RuntimeError(
+            f'Failed to download the model. 
Please download it manually from {url} ' + f'and save it to {model_path}') from e class RatioRecTVReisze(object): @@ -26,6 +87,10 @@ def __init__(self, cfg): self.base_shape = cfg['Eval']['dataset'].get( 'base_shape', [[64, 64], [96, 48], [112, 40], [128, 32]]) self.base_h = cfg['Eval']['dataset'].get('base_h', 32) + + from torchvision import transforms as T + from torchvision.transforms import functional as F + self.F = F self.interpolation = T.InterpolationMode.BICUBIC transforms = [] transforms.extend([ @@ -50,8 +115,8 @@ def __call__(self, data): ratio_resize, self.base_h ] resized_w = imgW - resized_image = F.resize(img, (imgH, resized_w), - interpolation=self.interpolation) + resized_image = self.F.resize(img, (imgH, resized_w), + interpolation=self.interpolation) img = self.transforms(resized_image) data['image'] = img return data @@ -81,71 +146,108 @@ def build_rec_process(cfg): def set_device(device, numId=0): + import torch if device == 'gpu' and torch.cuda.is_available(): device = torch.device(f'cuda:{numId}') else: + logger.info('GPU is not available, using CPU.') device = torch.device('cpu') return device -class OpenRecognizer(object): +class OpenRecognizer: - def __init__(self, config=None, mode='mobile', numId=0): + def __init__(self, + config=None, + mode='mobile', + backend='torch', + onnx_model_path=None, + numId=0): """ - 初始化方法。 - Args: config (dict, optional): 配置信息。默认为None。 mode (str, optional): 模式,'server' 或 'mobile'。默认为'mobile'。 + backend (str): 'torch' 或 'onnx' + onnx_model_path (str): ONNX模型路径(仅当backend='onnx'时需要) numId (int, optional): 设备编号。默认为0。 - - Returns: - None - - Raises: - 无 - """ - if config is None: - if mode == 'server': - config = Config( - './configs/det/svtrv2/svtrv2_ch.yml').cfg # server model - else: - config = Config( - './configs/rec/svtrv2/repsvtr_ch.yml').cfg # mobile model - global_config = config['Global'] + if config is None: + config_file = DEFAULT_CFG_PATH_REC_SERVER if mode == 'server' else DEFAULT_CFG_PATH_REC + config = Config(config_file).cfg self.cfg = config - if global_config['pretrained_model'] is None: - global_config[ - 'pretrained_model'] = global_config['output_dir'] + '/best.pth' - # build post process - from openrec.modeling import build_model as build_rec_model + # 公共初始化 + self._init_common() + backend = backend if config['Global'].get( + 'backend', None) is None else config['Global']['backend'] + self.backend = backend + if backend == 'torch': + import torch + self.torch = torch + self._init_torch_model(numId) + elif backend == 'onnx': + from tools.infer.onnx_engine import ONNXEngine + onnx_model_path = onnx_model_path if config['Global'].get( + 'onnx_model_path', + None) is None else config['Global']['onnx_model_path'] + if not onnx_model_path: + if self.cfg['Architecture']['algorithm'] == 'SVTRv2_mobile': + onnx_model_path = check_and_download_model( + MODEL_NAME_REC_ONNX, DOWNLOAD_URL_REC_ONNX) + else: + raise ValueError('ONNX模式需要指定onnx_model_path参数') + self.onnx_rec_engine = ONNXEngine( + onnx_model_path, use_gpu=config['Global']['device'] == 'gpu') + else: + raise ValueError("backend参数必须是'torch'或'onnx'") + + def _init_common(self): + # 初始化公共组件 from openrec.postprocess import build_post_process from openrec.preprocess import create_operators, transform self.transform = transform - self.post_process_class = build_post_process(config['PostProcess'], - global_config) - + # 构建预处理流程 + algorithm_name = self.cfg['Architecture']['algorithm'] + if algorithm_name in ['SVTRv2_mobile', 'SVTRv2_server']: + 
self.cfg['Global']['character_dict_path'] = DEFAULT_DICT_PATH_REC + self.post_process_class = build_post_process(self.cfg['PostProcess'], + self.cfg['Global']) char_num = self.post_process_class.get_character_num() - config['Architecture']['Decoder']['out_channels'] = char_num - # print(char_num) - self.model = build_rec_model(config['Architecture']) - load_ckpt(self.model, config) - - # exit(0) - self.device = set_device(global_config['device'], numId=numId) - self.model.eval() - replace_batchnorm(self.model.encoder) - self.model.to(device=self.device) - + self.cfg['Architecture']['Decoder']['out_channels'] = char_num transforms, ratio_resize_flag = build_rec_process(self.cfg) - global_config['infer_mode'] = True - self.ops = create_operators(transforms, global_config) + self.ops = create_operators(transforms, self.cfg['Global']) if ratio_resize_flag: ratio_resize = RatioRecTVReisze(cfg=self.cfg) self.ops.insert(-1, ratio_resize) + def _init_torch_model(self, numId): + from tools.utils.ckpt import load_ckpt + from tools.infer_det import replace_batchnorm + # PyTorch专用初始化 + algorithm_name = self.cfg['Architecture']['algorithm'] + if algorithm_name in ['SVTRv2_mobile', 'SVTRv2_server']: + if not os.path.exists(self.cfg['Global']['pretrained_model']): + pretrained_model = check_and_download_model( + MODEL_NAME_REC, DOWNLOAD_URL_REC + ) if algorithm_name == 'SVTRv2_mobile' else check_and_download_model( + MODEL_NAME_REC_SERVER, DOWNLOAD_URL_REC_SERVER) + self.cfg['Global']['pretrained_model'] = pretrained_model + + from openrec.modeling import build_model as build_rec_model + + self.model = build_rec_model(self.cfg['Architecture']) + load_ckpt(self.model, self.cfg) + + self.device = set_device(self.cfg['Global']['device'], numId) + self.model.to(self.device) + self.model.eval() + if algorithm_name == 'SVTRv2_mobile': + replace_batchnorm(self.model.encoder) + + def _inference_onnx(self, images): + # ONNX输入需要为numpy数组 + return self.onnx_rec_engine.run(images) + def __call__(self, img_path=None, img_numpy_list=None, @@ -204,38 +306,39 @@ def __call__(self, ]: valid_ratio = np.expand_dims(batch[-1], axis=0) batch_others.append(valid_ratio) - # others = [torch.from_numpy(valid_ratio).to(device=self.device)] - resized_image = batch[0] + + resized_image = batch[0] if isinstance( + batch[0], np.ndarray) else batch[0].numpy() h, w = resized_image.shape[-2:] max_width = max(max_width, w) max_height = max(max_height, h) batch_data.append(batch[0]) - padded_batch_data = [] - for resized_image in batch_data: - padded_image = np.zeros([1, 3, max_height, max_width], - dtype=np.float32) - h, w = resized_image.shape[-2:] - - # Apply padding (bottom-right padding) - padded_image[:, :, :h, : - w] = resized_image # 0 is typically used for padding - padded_batch_data.append(padded_image) + padded_batch = np.zeros( + (len(batch_data), 3, max_height, max_width), dtype=np.float32) + for i, img in enumerate(batch_data): + h, w = img.shape[-2:] + padded_batch[i, :, :h, :w] = img if batch_others: others = np.concatenate(batch_others, axis=0) else: others = None - images = np.concatenate(padded_batch_data, axis=0) - images = torch.from_numpy(images).to(device=self.device) - - with torch.no_grad(): - t_start = time.time() - preds = self.model(images, others) - torch.cuda.synchronize() - t_cost = time.time() - t_start - post_results = self.post_process_class(preds) - + t_start = time.time() + if self.backend == 'torch': + images = self.torch.from_numpy(padded_batch).to( + device=self.device) + with self.torch.no_grad(): + 
preds = self.model(images, others)  # bs, len, num_classes
+            torch_tensor = True
+        elif self.backend == 'onnx':
+            # ONNX inference
+            preds = self._inference_onnx(padded_batch)
+            preds = preds[0]  # bs, len, num_classes
+            torch_tensor = False
+        t_cost = time.time() - t_start
+        post_results = self.post_process_class(preds,
+                                                torch_tensor=torch_tensor)
         for i, post_result in enumerate(post_results):
             if img_path is not None:
                 info = {
@@ -256,10 +359,9 @@ def __call__(self,
 
 
 def main(cfg):
-    logger = get_logger()
     model = OpenRecognizer(cfg)
 
-    save_res_path = cfg['Global']['output_dir']
+    save_res_path = './rec_results/'
 
     if not os.path.exists(save_res_path):
         os.makedirs(save_res_path)
@@ -272,9 +374,7 @@ def main(cfg):
     sample_num = 0
     with open(save_res_path + '/rec_results.txt', 'wb') as fout:
         for file in get_image_file_list(cfg['Global']['infer_img']):
-
             preds_result = model(img_path=file, batch_num=1)[0]
-
             rec_text = preds_result['text']
             score = preds_result['score']
             t_cost = preds_result['elapse']
@@ -287,6 +387,9 @@
             t_sum += t_cost
             fout.write(otstr.encode())
             sample_num += 1
+    logger.info(
+        f"Results saved to {os.path.join(save_res_path, 'rec_results.txt')}."
+    )
 
     print(text_len_num)
     w_avg_t_cost = []
diff --git a/tools/toonnx.py b/tools/toonnx.py
new file mode 100644
index 0000000..68250d7
--- /dev/null
+++ b/tools/toonnx.py
@@ -0,0 +1,77 @@
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
+
+import torch
+
+from tools.engine.config import Config
+from tools.utility import ArgsParser
+from tools.utils.logging import get_logger
+
+
+def to_onnx(model, dummy_input, dynamic_axes, save_path='model.onnx'):
+    input_axis_name = ['batch_size', 'channel', 'in_width', 'in_height']
+    output_axis_name = ['batch_size', 'channel', 'out_width', 'out_height']
+    torch.onnx.export(
+        model.to('cpu'),
+        dummy_input,
+        save_path,
+        input_names=['input'],
+        output_names=['output'],  # the model's output names
+        dynamic_axes={
+            'input': {axis: input_axis_name[axis]
+                      for axis in dynamic_axes},
+            'output': {axis: output_axis_name[axis]
+                       for axis in dynamic_axes},
+        },
+    )
+
+
+def main(cfg):
+    _cfg = cfg.cfg
+    logger = get_logger()
+    global_config = _cfg['Global']
+
+    export_dir = global_config.get('export_dir', '')
+
+    if _cfg['Architecture']['algorithm'] == 'SVTRv2_mobile':
+        from tools.infer_rec import OpenRecognizer
+        model = OpenRecognizer(_cfg).model
+        dynamic_axes = [0, 3]
+        dummy_input = torch.randn([1, 3, 48, 320], device='cpu')
+        if not export_dir:
+            export_dir = os.path.join(
+                global_config.get('output_dir', 'output'), 'export_rec')
+        save_path = os.path.join(export_dir, 'rec_model.onnx')
+    if _cfg['Architecture']['algorithm'] == 'DB_mobile':
+        from tools.infer_det import OpenDetector
+        model = OpenDetector(_cfg).model
+        dynamic_axes = [0, 2, 3]
+        dummy_input = torch.randn([1, 3, 960, 960], device='cpu')
+        if not export_dir:
+            export_dir = os.path.join(
+                global_config.get('output_dir', 'output'), 'export_det')
+        save_path = os.path.join(export_dir, 'det_model.onnx')
+
+    os.makedirs(export_dir, exist_ok=True)
+    to_onnx(model, dummy_input, dynamic_axes, save_path)
+    logger.info(f'Finished exporting the model to {save_path}')
+
+
+def parse_args():
+    parser = ArgsParser()
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == '__main__':
+    FLAGS = parse_args()
+    cfg = Config(FLAGS.config)
+    FLAGS = vars(FLAGS)
+    opt = FLAGS.pop('opt')
+    cfg.merge_dict(FLAGS)
+    cfg.merge_dict(opt)
+ main(cfg) diff --git a/tools/train_det.py b/tools/train_det.py new file mode 100644 index 0000000..511d521 --- /dev/null +++ b/tools/train_det.py @@ -0,0 +1,40 @@ +import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) + +sys.path.append(__dir__) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) + +from tools.engine.config import Config +from tools.engine.trainer import Trainer +from tools.utility import ArgsParser + + +def parse_args(): + parser = ArgsParser() + parser.add_argument( + '--eval', + action='store_true', + default=True, + help='Whether to perform evaluation in train', + ) + args = parser.parse_args() + return args + + +def main(): + FLAGS = parse_args() + cfg = Config(FLAGS.config) + FLAGS = vars(FLAGS) + opt = FLAGS.pop('opt') + cfg.merge_dict(FLAGS) + cfg.merge_dict(opt) + trainer = Trainer(cfg, + mode='train_eval' if FLAGS['eval'] else 'train', + task='det') + trainer.train() + + +if __name__ == '__main__': + main() diff --git a/tools/train_rec.py b/tools/train_rec.py index 1351c97..3d12420 100644 --- a/tools/train_rec.py +++ b/tools/train_rec.py @@ -6,7 +6,8 @@ sys.path.append(__dir__) sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) -from tools.engine import Config, Trainer +from tools.engine.config import Config +from tools.engine.trainer import Trainer from tools.utility import ArgsParser @@ -29,7 +30,9 @@ def main(): opt = FLAGS.pop('opt') cfg.merge_dict(FLAGS) cfg.merge_dict(opt) - trainer = Trainer(cfg, mode='train_eval' if FLAGS['eval'] else 'train') + trainer = Trainer(cfg, + mode='train_eval' if FLAGS['eval'] else 'train', + task='rec') trainer.train() diff --git a/tools/utils/logging.py b/tools/utils/logging.py index bd7d641..7eed207 100644 --- a/tools/utils/logging.py +++ b/tools/utils/logging.py @@ -2,8 +2,6 @@ import sys import logging import functools -import torch -import torch.distributed as dist logger_initialized = {}
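Taken together, the changes above let the end-to-end pipeline run either the default PyTorch weights or ONNX models exported with `tools/toonnx.py`. The sketch below (not part of the patch) shows how the new `backend`, `onnx_det_model_path`, and `onnx_rec_model_path` arguments of `OpenOCR` are intended to be combined; the `.onnx` paths and the image directory are placeholders, and for the default `SVTRv2_mobile` recognizer the released ONNX model can also be fetched into `~/.cache/openocr` automatically when no path is given.

```python
# Minimal usage sketch of the ONNX backend wired up in this patch.
# The .onnx paths are placeholders, e.g. files produced by tools/toonnx.py.
from tools.infer_e2e import OpenOCR

ocr = OpenOCR(mode='mobile',
              backend='onnx',
              onnx_det_model_path='./output/export_det/det_model.onnx',
              onnx_rec_model_path='./output/export_rec/rec_model.onnx',
              det_box_type='quad',
              device='cpu')

# img_path may be a single image or a directory; per-image predictions are
# also written to <save_dir>/system_results.txt.
results, time_dicts = ocr(img_path='./testimgs',
                          save_dir='./e2e_results/',
                          is_visualize=True)
```

The same call with `backend='torch'` (the default) loads the PyTorch checkpoints instead.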
Method
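`OpenOCR.__call__` also accepts in-memory images, and the new `return_mask=True` flag additionally returns the detector's raw output map alongside the boxes and transcriptions. A minimal sketch under the same assumptions (`sample.jpg` is a placeholder file name):

```python
# Sketch of in-memory inference with the new return_mask option.
import cv2
from tools.infer_e2e import OpenOCR

ocr = OpenOCR(mode='mobile', backend='torch', device='gpu')
img = cv2.imread('sample.jpg')  # any image readable by OpenCV

# img_numpy takes a list of numpy images; mask is the detector output map.
results, time_dicts, mask = ocr(img_numpy=[img], return_mask=True)
for item in results[0]:  # one dict per detected text box (may be empty)
    print(item)
```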