diff --git a/.gitignore b/.gitignore index 6262738..a166590 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,9 @@ __pycache__/ inference/ inference_results/ output/ +e2e_results/ +rec_results/ +det_results/ train_data/ log/ *.DS_Store diff --git a/README.md b/README.md index d0f19c8..bf67a07 100644 --- a/README.md +++ b/README.md @@ -1,57 +1,64 @@ # OpenOCR -We aim to establishing a unified benchmark for training and evaluating models for scene text detection and recognition. Based on this benchmark, we introduce an accurate and efficient general OCR system, OpenOCR. Additionally, this repository will serve as the official codebase for the OCR team from the [FVL](https://fvl.fudan.edu.cn) Laboratory, Fudan University. +We aim to establish a unified benchmark for training and evaluating models in scene text detection and recognition. Building on this benchmark, we introduce a general OCR system with accuracy and efficiency, **OpenOCR**. This repository also serves as the official codebase of the OCR team from the [FVL Laboratory](https://fvl.fudan.edu.cn), Fudan University. We sincerely welcome the researcher to recommend OCR or relevant algorithms and point out any potential factual errors or bugs. Upon receiving the suggestions, we will promptly evaluate and critically reproduce them. We look forward to collaborating with you to advance the development of OpenOCR and continuously contribute to the OCR community! ## Features -- 🔥**OpenOCR: A general OCR system for accuracy and efficiency** - - ⚡\[[Quick Start](#quick-start)\] \[[Demo](<>)(TODO)\] +- 🔥**OpenOCR: A general OCR system with accuracy and efficiency** + - ⚡\[[Quick Start](#quick-start)\] \[[Model](https://github.com/Topdu/OpenOCR/releases/tag/develop0.0.1)\] \[[ModelScope Demo](https://modelscope.cn/studios/topdktu/OpenOCR-Demo)\] \[[Hugging Face Demo](https://huggingface.co/spaces/topdu/OpenOCR-Demo)\] \[[Local Demo](#local-demo)\] \[[PaddleOCR Implementation](https://paddlepaddle.github.io/PaddleOCR/latest/algorithm/text_recognition/algorithm_rec_svtrv2.html)\] - [Introduction](./docs/openocr.md) - - A practical version of the model builds on SVTRv2. - - Outperforming [PP-OCRv4](<>) released by [PaddleOCR](<>) by 4.5% on the [OCR competition leaderboard](<>). - - [x] Supporting Chinese and English text detection and recognition. - - [x] Providing server model and mobile model. - - [ ] Fine-tuning OpenOCR on a custom dataset - - [ ] Export to ONNX engine + - A practical OCR system building on SVTRv2. + - Outperforms [PP-OCRv4](https://paddlepaddle.github.io/PaddleOCR/latest/ppocr/model_list.html) baseline by 4.5% on the [OCR competition leaderboard](https://aistudio.baidu.com/competition/detail/1131/0/leaderboard) in terms of accuracy, while preserving quite similar inference speed. + - [x] Supports Chinese and English text detection and recognition. + - [x] Provides server model and mobile model. + - [ ] Fine-tunes OpenOCR on a custom dataset. + - [ ] ONNX model export for wider compatibility. 
- 🔥**SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition** - - \[[Paper](../configs/rec/svtrv2/SVTRv2.pdf)\] \[[Model](./configs/rec/svtrv2/readme.md#11-models-and-results)\] \[[Config, Training and Inference](./configs/rec/svtrv2/readme.md#3-model-training--evaluation)\] + - \[[Paper](https://arxiv.org/abs/2411.15858)\] \[[Doc](./configs/rec/svtrv2/)\] \[[Model](./configs/rec/svtrv2/readme.md#11-models-and-results)\] \[[Datasets](./docs/svtrv2.md#downloading-datasets)\] \[[Config, Training and Inference](./configs/rec/svtrv2/readme.md#3-model-training--evaluation)\] \[[Benchmark](./docs/svtrv2.md#results-benchmark--configs--checkpoints)\] - [Introduction](./docs/svtrv2.md) - - Developing a unified training and evaluation benchmark for Scene Text Recognition - - Supporting for 24 Scene Text Recognition methods trained from scratch on large-scale real datasets, and will continue to add the latest methods. - - Improving results by 20-30% compared to training on synthetic datasets. + - A unified training and evaluation benchmark (on top of [Union14M](https://github.com/Mountchicken/Union14M?tab=readme-ov-file#3-union14m-dataset)) for Scene Text Recognition + - Supports 24 Scene Text Recognition methods trained from scratch on the large-scale real dataset [Union14M-L-Filter](./docs/svtrv2.md#dataset-details), and will continue to add the latest methods. + - Improves accuracy by 20-30% compared to models trained based on synthetic datasets. - Towards Arbitrary-Shaped Text Recognition and Language modeling with a Single Visual Model. - - Surpasses Attention-based Decoder Methods across challenging scenarios in terms of accuracy and speed - - [Get Started](./docs/svtrv2.md#get-started-with-training-a-sota-scene-text-recognition-model-from-scratch) with training a SoTA Scene Text Recognition model from scratch. + - Surpasses Attention-based Encoder-Decoder Methods across challenging scenarios in terms of accuracy and speed + - [Get Started](./docs/svtrv2.md#get-started-with-training-a-sota-scene-text-recognition-model-from-scratch) with training a SOTA Scene Text Recognition model from scratch. ## Ours STR algorithms -- [**DPTR**](<>) (*Shuai Zhao, Yongkun Du, Zhineng Chen\*, Yu-Gang Jiang. Decoder Pre-Training with only Text for Scene Text Recognition,* ACM MM 2024. [paper](https://arxiv.org/abs/2408.05706)) -- [**IGTR**](./configs/rec/igtr/) (*Yongkun Du, Zhineng Chen\*, Yuchen Su, Caiyan Jia, Yu-Gang Jiang. Instruction-Guided Scene Text Recognition,* Under TPAMI minor revison 2024. [Doc](./configs/rec/igtr/readme.md), [paper](https://arxiv.org/abs/2401.17851)) -- [**SVTRv2**](./configs/rec/svtrv2) (*Yongkun Du, Zhineng Chen\*, Hongtao Xie, Caiyan Jia, Yu-Gang Jiang. SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition,* 2024. [paper](./configs/rec/svtrv2/SVTRv2.pdf)) -- [**SMTR&FocalSVTR**](./configs/rec/smtr/) (*Yongkun Du, Zhineng Chen\*, Caiyan Jia, Xieping Gao, Yu-Gang Jiang. Out of Length Text Recognition with Sub-String Matching,* 2024. [paper](https://arxiv.org/abs/2407.12317)) -- [**CDistNet**](./configs/rec/cdistnet/) (*Tianlun Zheng, Zhineng Chen\*, Shancheng Fang, Hongtao Xie, Yu-Gang Jiang. CDistNet: Perceiving Multi-Domain Character Distance for Robust Text Recognition,* IJCV 2024. [paper](https://link.springer.com/article/10.1007/s11263-023-01880-0)) -- **MRN** (*Tianlun Zheng, Zhineng Chen\*, Bingchen Huang, Wei Zhang, Yu-Gang Jiang. MRN: Multiplexed routing network for incremental multilingual text recognition,* ICCV 2023. 
[paper](https://openaccess.thecvf.com/content/ICCV2023/html/Zheng_MRN_Multiplexed_Routing_Network_for_Incremental_Multilingual_Text_Recognition_ICCV_2023_paper.html)) -- **TPS++** (*Tianlun Zheng, Zhineng Chen\*, Jinfeng Bai, Hongtao Xie, Yu-Gang Jiang. TPS++: Attention-Enhanced Thin-Plate Spline for Scene Text Recognition,* IJCAI 2023. [paper](https://arxiv.org/abs/2305.05322)) -- [**CPPD**](./configs/rec/cppd/) (*Yongkun Du, Zhineng Chen\*, Caiyan Jia, Xiaoting Yin, Chenxia Li, Yuning Du, Yu-Gang Jiang. Context Perception Parallel Decoder for Scene Text Recognition,* Under TPAMI minor revision 2023. [PaddleOCR Doc](https://github.com/Topdu/PaddleOCR/blob/main/doc/doc_ch/algorithm_rec_cppd.md), [paper](https://arxiv.org/abs/2307.12270)) -- [**SVTR**](./configs/rec/svtr/) (*Yongkun Du, Zhineng Chen\*, Caiyan Jia, Xiaoting Yin, Tianlun Zheng, Chenxia Li, Yuning Du, Yu-Gang Jiang. SVTR: Scene Text Recognition with a Single Visual Model,* IJCAI 2022 (Long). [PaddleOCR Doc](https://github.com/Topdu/PaddleOCR/blob/main/doc/doc_ch/algorithm_rec_svtr.md), [paper](https://www.ijcai.org/proceedings/2022/124)) -- [**NRTR**](./configs/rec/nrtr/) (*Fenfen Sheng, Zhineng Chen\*, Bo Xu. NRTR: A No-Recurrence Sequence-to-Sequence Model For Scene Text Recognition,* ICDAR 2019. [paper](https://arxiv.org/abs/1806.00926)) +- [**SMTR&FocalSVTR**](./configs/rec/smtr/) (*Yongkun Du, Zhineng Chen\*, Caiyan Jia, Xieping Gao, Yu-Gang Jiang. Out of Length Text Recognition with Sub-String Matching,* AAAI 2025. [Doc](./configs/rec/smtr/), [Paper](https://arxiv.org/abs/2407.12317)) +- [**DPTR**](./configs/rec/dptr/) (*Shuai Zhao, Yongkun Du, Zhineng Chen\*, Yu-Gang Jiang. Decoder Pre-Training with only Text for Scene Text Recognition,* ACM MM 2024. [Paper](https://arxiv.org/abs/2408.05706)) +- [**IGTR**](./configs/rec/igtr/) (*Yongkun Du, Zhineng Chen\*, Yuchen Su, Caiyan Jia, Yu-Gang Jiang. Instruction-Guided Scene Text Recognition,* TPAMI (accepted). [Doc](./configs/rec/igtr), [Paper](https://arxiv.org/abs/2401.17851)) +- [**SVTRv2**](./configs/rec/svtrv2) (*Yongkun Du, Zhineng Chen\*, Hongtao Xie, Caiyan Jia, Yu-Gang Jiang. SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition,* 2024. [Doc](./configs/rec/svtrv2/), [Paper](https://arxiv.org/abs/2411.15858)) +- [**CDistNet**](./configs/rec/cdistnet/) (*Tianlun Zheng, Zhineng Chen\*, Shancheng Fang, Hongtao Xie, Yu-Gang Jiang. CDistNet: Perceiving Multi-Domain Character Distance for Robust Text Recognition,* IJCV 2024. [Paper](https://link.springer.com/article/10.1007/s11263-023-01880-0)) +- **MRN** (*Tianlun Zheng, Zhineng Chen\*, Bingchen Huang, Wei Zhang, Yu-Gang Jiang. MRN: Multiplexed routing network for incremental multilingual text recognition,* ICCV 2023. [Paper](https://openaccess.thecvf.com/content/ICCV2023/html/Zheng_MRN_Multiplexed_Routing_Network_for_Incremental_Multilingual_Text_Recognition_ICCV_2023_paper.html), [Code](https://github.com/simplify23/MRN)) +- **TPS++** (*Tianlun Zheng, Zhineng Chen\*, Jinfeng Bai, Hongtao Xie, Yu-Gang Jiang. TPS++: Attention-Enhanced Thin-Plate Spline for Scene Text Recognition,* IJCAI 2023. [Paper](https://arxiv.org/abs/2305.05322), [Code](https://github.com/simplify23/TPS_PP)) +- [**CPPD**](./configs/rec/cppd/) (*Yongkun Du, Zhineng Chen\*, Caiyan Jia, Xiaoting Yin, Chenxia Li, Yuning Du, Yu-Gang Jiang. Context Perception Parallel Decoder for Scene Text Recognition,* Under TPAMI minor revision 2023. 
[PaddleOCR Doc](https://github.com/Topdu/PaddleOCR/blob/main/doc/doc_ch/algorithm_rec_cppd.md), [Paper](https://arxiv.org/abs/2307.12270)) +- [**SVTR**](./configs/rec/svtr/) (*Yongkun Du, Zhineng Chen\*, Caiyan Jia, Xiaoting Yin, Tianlun Zheng, Chenxia Li, Yuning Du, Yu-Gang Jiang. SVTR: Scene Text Recognition with a Single Visual Model,* IJCAI 2022 (Long). [PaddleOCR Doc](https://github.com/Topdu/PaddleOCR/blob/main/doc/doc_ch/algorithm_rec_svtr.md), [Paper](https://www.ijcai.org/proceedings/2022/124)) +- [**NRTR**](./configs/rec/nrtr/) (*Fenfen Sheng, Zhineng Chen\*, Bo Xu. NRTR: A No-Recurrence Sequence-to-Sequence Model For Scene Text Recognition,* ICDAR 2019. [Paper](https://arxiv.org/abs/1806.00926)) ## Recent Updates +- **2024.12.31**: Our paper [IGTR](https://arxiv.org/abs/2401.17851) is accepted by TPAMI. Accessible in [Doc](./configs/rec/igtr/). + +- **2024.12.16**: Our paper [SMTR](https://arxiv.org/abs/2407.12317) is accepted by AAAI 2025. Accessible in [Doc](./configs/rec/smtr/). + +- **2024.12.03**: The pre-training code for [DPTR](https://arxiv.org/abs/2408.05706) is merged. + - **🔥 2024.11.23 release notes**: - - **OpenOCR: A general OCR system for accuracy and efficiency** - - ⚡\[[Quick Start](#quick-start)\] \[[Demo](<>)(TODO)\] + + - **OpenOCR: A general OCR system with accuracy and efficiency** + - ⚡\[[Quick Start](#quick-start)\] \[[Model](https://github.com/Topdu/OpenOCR/releases/tag/develop0.0.1)\] \[[ModelScope Demo](https://modelscope.cn/studios/topdktu/OpenOCR-Demo)\] \[[Hugging Face Demo](https://huggingface.co/spaces/topdu/OpenOCR-Demo)\] \[[Local Demo](#local-demo)\] \[[PaddleOCR Implementation](https://paddlepaddle.github.io/PaddleOCR/latest/algorithm/text_recognition/algorithm_rec_svtrv2.html)\] - [Introduction](./docs/openocr.md) - **SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition** - - \[[Paper](../configs/rec/svtrv2/SVTRv2.pdf)\] \[[Model](./configs/rec/svtrv2/readme.md#11-models-and-results)\] \[[Config, Training and Inference](./configs/rec/svtrv2/readme.md#3-model-training--evaluation)\] + - \[[Paper](https://arxiv.org/abs/2411.15858)\] \[[Doc](./configs/rec/svtrv2/)\] \[[Model](./configs/rec/svtrv2/readme.md#11-models-and-results)\] \[[Datasets](./docs/svtrv2.md#downloading-datasets)\] \[[Config, Training and Inference](./configs/rec/svtrv2/readme.md#3-model-training--evaluation)\] \[[Benchmark](./docs/svtrv2.md#results--configs--checkpoints)\] - [Introduction](./docs/svtrv2.md) - - [Get Started](./docs/svtrv2.md#get-started-with-training-a-sota-scene-text-recognition-model-from-scratch) with training a SoTA Scene Text Recognition model from scratch. + - [Get Started](./docs/svtrv2.md#get-started-with-training-a-sota-scene-text-recognition-model-from-scratch) with training a SOTA Scene Text Recognition model from scratch. -## ⚡[Quick Start](./docs/openocr.md#quick-start) +## Quick Start -#### Dependencies: +### Dependencies: - [PyTorch](http://pytorch.org/) version >= 1.13.0 - Python version >= 3.7 @@ -59,12 +66,15 @@ We sincerely welcome the researcher to recommend OCR or relevant algorithms and ```shell conda create -n openocr python==3.8 conda activate openocr +# install gpu version torch conda install pytorch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 pytorch-cuda=11.8 -c pytorch -c nvidia +# or cpu version +conda install pytorch torchvision torchaudio cpuonly -c pytorch ``` After installing dependencies, the following two installation methods are available. Either one can be chosen. -#### 1. Python Modules +### 1. 
Python Modules ```shell pip install openocr-python @@ -79,19 +89,21 @@ engine = OpenOCR() img_path = '/path/img_path or /path/img_file' result, elapse = engine(img_path) -print(result) -print(elapse) # Server mode -engine = OpenOCR(mode='server') +# engine = OpenOCR(mode='server') ``` -#### 2. Clone this repository: +### 2. Clone this repository: ```shell git clone https://github.com/Topdu/OpenOCR.git cd OpenOCR pip install -r requirements.txt +wget https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_det_repvit_ch.pth +wget https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_repsvtr_ch.pth +# Rec Server model +# wget https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_svtrv2_ch.pth ``` **Usage**: @@ -99,14 +111,22 @@ pip install -r requirements.txt ```shell # OpenOCR system: Det + Rec model python tools/infer_e2e.py --img_path=/path/img_fold or /path/img_file - # Det model python tools/infer_det.py --c ./configs/det/dbnet/repvit_db.yml --o Global.infer_img=/path/img_fold or /path/img_file - # Rec model python tools/infer_rec.py --c ./configs/rec/svtrv2/repsvtr_ch.yml --o Global.infer_img=/path/img_fold or /path/img_file ``` +#### Local Demo + +```shell +pip install gradio==4.20.0 +wget https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/OCR_e2e_img.tar +tar xf OCR_e2e_img.tar +# start demo +python demo_gradio.py +``` + ## Reproduction schedule: ### Scene Text Recognition @@ -117,7 +137,7 @@ python tools/infer_rec.py --c ./configs/rec/svtrv2/repsvtr_ch.yml --o Global.inf | [ASTER](./configs/rec/aster/) | [TPAMI 2019](https://ieeexplore.ieee.org/document/8395027) | ✅ | ✅ | [pretto0](https://github.com/pretto0) | | [NRTR](./configs/rec/nrtr/) | [ICDAR 2019](https://arxiv.org/abs/1806.00926) | ✅ | ✅ | | | [SAR](./configs/rec/sar/) | [AAAI 2019](https://aaai.org/papers/08610-show-attend-and-read-a-simple-and-strong-baseline-for-irregular-text-recognition/) | ✅ | ✅ | [pretto0](https://github.com/pretto0) | -| [MORAN](./configs/rec/moran/) | [PR 2019](https://www.sciencedirect.com/science/article/abs/pii/S0031320319300263) | ✅ | ✅ | Debug | +| [MORAN](./configs/rec/moran/) | [PR 2019](https://www.sciencedirect.com/science/article/abs/pii/S0031320319300263) | ✅ | ✅ | | | [DAN](./configs/rec/dan/) | [AAAI 2020](https://arxiv.org/pdf/1912.10205) | ✅ | ✅ | | | [RobustScanner](./configs/rec/robustscanner/) | [ECCV 2020](https://www.ecva.net/papers/eccv_2020/papers_ECCV/html/3160_ECCV_2020_paper.php) | ✅ | ✅ | [pretto0](https://github.com/pretto0) | | [AutoSTR](./configs/rec/autostr/) | [ECCV 2020](https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123690732.pdf) | ✅ | ✅ | | @@ -125,6 +145,7 @@ python tools/infer_rec.py --c ./configs/rec/svtrv2/repsvtr_ch.yml --o Global.inf | [SEED](./configs/rec/seed/) | [CVPR 2020](https://openaccess.thecvf.com/content_CVPR_2020/html/Qiao_SEED_Semantics_Enhanced_Encoder-Decoder_Framework_for_Scene_Text_Recognition_CVPR_2020_paper.html) | ✅ | ✅ | | | [ABINet](./configs/rec/abinet/) | [CVPR 2021](https://openaccess.thecvf.com//content/CVPR2021/html/Fang_Read_Like_Humans_Autonomous_Bidirectional_and_Iterative_Language_Modeling_for_CVPR_2021_paper.html) | ✅ | ✅ | [YesianRohn](https://github.com/YesianRohn) | | [VisionLAN](./configs/rec/visionlan/) | [ICCV 2021](https://openaccess.thecvf.com/content/ICCV2021/html/Wang_From_Two_to_One_A_New_Scene_Text_Recognizer_With_ICCV_2021_paper.html) | ✅ | ✅ | [YesianRohn](https://github.com/YesianRohn) | +| PIMNet | [ACM MM 
2021](https://dl.acm.org/doi/10.1145/3474085.3475238) | | | TODO | | [SVTR](./configs/rec/svtrs/) | [IJCAI 2022](https://www.ijcai.org/proceedings/2022/124) | ✅ | ✅ | | | [PARSeq](./configs/rec/parseq/) | [ECCV 2022](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136880177.pdf) | ✅ | ✅ | | | [MATRN](./configs/rec/matrn/) | [ECCV 2022](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136880442.pdf) | ✅ | ✅ | | @@ -137,14 +158,14 @@ python tools/infer_rec.py --c ./configs/rec/svtrv2/repsvtr_ch.yml --o Global.inf | [BUSNet](./configs/rec/busnet/) | [AAAI 2024](https://ojs.aaai.org/index.php/AAAI/article/view/28402) | ✅ | ✅ | | | DCTC | [AAAI 2024](https://ojs.aaai.org/index.php/AAAI/article/view/28575) | | | TODO | | [CAM](./configs/rec/cam/) | [PR 2024](https://arxiv.org/abs/2402.13643) | ✅ | ✅ | | -| [OTE](./configs/rec/ote/) | [CVPR 2024](https://openaccess.thecvf.com/content/CVPR2024/papers/Xu_OTE_Exploring_Accurate_Scene_Text_Recognition_Using_One_Token_CVPR_2024_paper.pdf) | ✅ | ✅ | | +| [OTE](./configs/rec/ote/) | [CVPR 2024](https://openaccess.thecvf.com/content/CVPR2024/html/Xu_OTE_Exploring_Accurate_Scene_Text_Recognition_Using_One_Token_CVPR_2024_paper.html) | ✅ | ✅ | | | CFF | [IJCAI 2024](https://arxiv.org/abs/2407.05562) | | | TODO | -| DPTR | [ACM MM 2024](https://arxiv.org/abs/2408.05706) | | | TODO | +| [DPTR](./configs/rec/dptr/) | [ACM MM 2024](https://arxiv.org/abs/2408.05706) | | | [fd-zs](https://github.com/fd-zs) | | VIPTR | [ACM CIKM 2024](https://arxiv.org/abs/2401.10110) | | | TODO | -| [IGTR](./configs/rec/igtr/) | [2024](https://arxiv.org/abs/2401.17851) | ✅ | ✅ | | -| [SMTR](./configs/rec/smtr/) | [2024](https://arxiv.org/abs/2407.12317) | ✅ | ✅ | | +| [IGTR](./configs/rec/igtr/) | [TPAMI](https://arxiv.org/abs/2401.17851) | ✅ | ✅ | | +| [SMTR](./configs/rec/smtr/) | [AAAI 2025](https://arxiv.org/abs/2407.12317) | ✅ | ✅ | | | [FocalSVTR-CTC](./configs/rec/svtrs/) | [2024](https://arxiv.org/abs/2407.12317) | ✅ | ✅ | | -| [SVTRv2](./configs/rec/svtrv2/) | [2024](./configs/rec/svtrv2/SVTRv2.pdf) | ✅ | ✅ | | +| [SVTRv2](./configs/rec/svtrv2/) | [2024](https://arxiv.org/abs/2411.15858) | ✅ | ✅ | | | [ResNet+Trans-CTC](./configs/rec/svtrs/) | | ✅ | ✅ | | | [ViT-CTC](./configs/rec/svtrs/) | | ✅ | ✅ | | @@ -152,7 +173,7 @@ python tools/infer_rec.py --c ./configs/rec/svtrv2/repsvtr_ch.yml --o Global.inf ______________________________________________________________________ -Yiming Lei ([pretto0](https://github.com/pretto0)) and Xingsong Ye ([YesianRohn](https://github.com/YesianRohn)) from the [FVL](https://fvl.fudan.edu.cn) Laboratory, Fudan University, under the guidance of Professor Zhineng Chen, completed the majority of the algorithm reproduction work. Grateful for their outstanding contributions. +Yiming Lei ([pretto0](https://github.com/pretto0)), Xingsong Ye ([YesianRohn](https://github.com/YesianRohn)), and Shuai Zhao ([fd-zs](https://github.com/fd-zs)) from the [FVL Laboratory](https://fvl.fudan.edu.cn), Fudan University, with guidance from Dr. Zhineng Chen ([Homepage](https://zhinchenfd.github.io/)), completed the majority work of the algorithm reproduction. Grateful for their outstanding contributions. 
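Before moving on to detection, a short note on consuming the system's output: the Quick Start above returns `result` and `elapse` from the pip-installed engine without showing how to read them. The sketch below is a minimal, assumption-laden example: the image path is illustrative, and the per-box fields (`points`, `transcription`, `score`) and the `result[0]` nesting are assumed to mirror how `demo_gradio.py` (added later in this PR) reads the system's output.

```python
# Minimal sketch (assumptions noted above), not an authoritative API reference.
from openocr import OpenOCR

engine = OpenOCR()                  # mobile det + rec models, as in the Quick Start
# engine = OpenOCR(mode='server')   # heavier server models

result, elapse = engine('/path/to/image.jpg')  # illustrative path

# Assumption: result[0] holds the detections of the (single) input image, each a
# dict with 'points', 'transcription' and 'score', as used by demo_gradio.py.
for box in result[0]:
    print(box['transcription'], round(box['score'], 3), box['points'])
print('elapsed time:', elapse)
```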
### Scene Text Detection (STD) @@ -164,6 +185,23 @@ TODO ______________________________________________________________________ +## Citation + +If you find our method useful for your reserach, please cite: + +```bibtex +@article{Du2024SVTRv2, + title={SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition}, + author={Yongkun Du and Zhineng Chen and Hongtao Xie and Caiyan Jia and Yu-Gang Jiang}, + journal={CoRR}, + volume={abs/2411.15858}, + eprinttype={arXiv}, + year={2024}, + primaryClass={cs.CV}, + url={https://arxiv.org/abs/2411.15858} +} +``` + # Acknowledgement This codebase is built based on the [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR), [PytorchOCR](https://github.com/WenmuZhou/PytorchOCR), and [MMOCR](https://github.com/open-mmlab/mmocr). Thanks for their awesome work! diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..1fa1819 --- /dev/null +++ b/__init__.py @@ -0,0 +1,11 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) + +from tools.infer_e2e import OpenOCR, OpenDetector, OpenRecognizer diff --git a/configs/dataset/rec/evaluation.yaml b/configs/dataset/rec/evaluation.yaml new file mode 100644 index 0000000..7c10194 --- /dev/null +++ b/configs/dataset/rec/evaluation.yaml @@ -0,0 +1,41 @@ +root: ../evaluation +task: str +download_links: + # IC15_1811 + - https://drive.usercontent.google.com/download?id=1eGY0kXNV1qVxeUpoGzs-ioUO-ky7msH6&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1BWv7aLoLAT7avY326gXP3GJF48UZpuBC&authuser=0&confirm=t + # SVT + - https://drive.usercontent.google.com/download?id=1ecEZ4cJ7dIbTCZRltE0s5KzUotQWagH-&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1OygBP7i9R-3Pwi6WodCcW31J8CUMugOJ&authuser=0&confirm=t + # IIIT5k + - https://drive.usercontent.google.com/download?id=1PJ9_IvIGZTS5hHdGLnpKuYKZcCO8jE0E&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=10P3MixSBt1v8k8_6aFfziC33Z5IlM6Uf&authuser=0&confirm=t + # IC13_857 + - https://drive.usercontent.google.com/download?id=1-wMHOFBXJaOaY-UD00nDn6qw2s_8R4Vd&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1J1QCFtOFxFKiLJIgTqZ6eRo9Y5QGqHpA&authuser=0&confirm=t + # SVTP + - https://drive.usercontent.google.com/download?id=1kckwfZkdaHG8k_FW5IIJKUaYZkF21Hza&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1x61lm_ea7lvIdxNPMG-jy-5W0MxtdH0N&authuser=0&confirm=t + # CUTE80 + - https://drive.usercontent.google.com/download?id=1Zv_91c81tinLy5Je89HPr-5wUSnqXKIB&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1OuJ6QoJ9AlyNHIM9j2WedAPxTnac7kyY&authuser=0&confirm=t +filenames: + # IC15_1811 + - ../evaluation/IC15_1811/data.mdb + - ../evaluation/IC15_1811/lock.mdb + # SVT + - ../evaluation/SVT/data.mdb + - ../evaluation/SVT/lock.mdb + # IIIT5k + - ../evaluation/IIIT5k/data.mdb + - ../evaluation/IIIT5k/lock.mdb + # IC13_857 + - ../evaluation/IC13_857/data.mdb + - ../evaluation/IC13_857/lock.mdb + # SVTP + - ../evaluation/SVTP/data.mdb + - ../evaluation/SVTP/lock.mdb + # CUTE80 + - ../evaluation/CUTE80/data.mdb + - ../evaluation/CUTE80/lock.mdb +check_validity: true diff --git a/configs/dataset/rec/ltb.yaml b/configs/dataset/rec/ltb.yaml new file mode 100644 index 0000000..a0b03a5 --- /dev/null +++ 
b/configs/dataset/rec/ltb.yaml @@ -0,0 +1,9 @@ +root: ../ltb +task: str +download_links: + - https://drive.usercontent.google.com/download?id=16AEA1YGTsyVB44uEjKi4ZUV1snjCYBr4&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1xU4OStrOaI23bPG4flWAPWn2YrQe2bmY&authuser=0&confirm=t +filenames: + - ../ltb/data.mdb + - ../ltb/lock.mdb +check_validity: true diff --git a/configs/dataset/rec/mjsynth.yaml b/configs/dataset/rec/mjsynth.yaml new file mode 100644 index 0000000..7a9fc75 --- /dev/null +++ b/configs/dataset/rec/mjsynth.yaml @@ -0,0 +1,11 @@ +root: ../synth +task: str +download_links: + - https://drive.usercontent.google.com/download?id=1FIoplSFZ-BKQoRDHDXsVMKa844e-K8PD&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1eckTvaeRtlTZvbO2orrVz-cIuIk6i87K&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1PBXTf-2PnmEvJBsqzJqxxRwzhAZGTiMG&authuser=0&confirm=t +filenames: + - ../synth/MJ_train.zip + - ../synth/MJ_val.zip + - ../synth/MJ_test.zip +check_validity: true \ No newline at end of file diff --git a/configs/dataset/rec/openvino.yaml b/configs/dataset/rec/openvino.yaml new file mode 100644 index 0000000..f60e448 --- /dev/null +++ b/configs/dataset/rec/openvino.yaml @@ -0,0 +1,25 @@ +root: ../OpenVINO +task: str +download_links: + # train_1 + - https://drive.usercontent.google.com/download?id=1q23QAIRTyG0t-bBm4aAwRwiqB6VUfphw&authuser=0&confirm= + # train_2 + - https://drive.usercontent.google.com/download?id=1AtbaJljM68cbZqi5lcM92d9VkQUCbSqI&authuser=0&confirm= + # train_5 + - https://drive.usercontent.google.com/download?id=1dejstYnJ8_sESuO_uvwi__jT1B8gPxf3&authuser=0&confirm=t + # train_f + - https://drive.usercontent.google.com/download?id=1C4akchTc7-yi1OS_sJ3KP693UKcnecke&authuser=0&confirm=t + # validation + - https://drive.usercontent.google.com/download?id=17TRzSQhuK_juAxAv3KmX0y13pQP2cz6R&authuser=0&confirm=t +filenames: + # train_1 + - ../OpenVINO/train_1.zip + # train_2 + - ../OpenVINO/train_2.zip + # train_5 + - ../OpenVINO/train_5.zip + # train_f + - ../OpenVINO/train_f.zip + # validation + - ../OpenVINO/validation.zip +check_validity: true diff --git a/configs/dataset/rec/ost.yaml b/configs/dataset/rec/ost.yaml new file mode 100644 index 0000000..6690f12 --- /dev/null +++ b/configs/dataset/rec/ost.yaml @@ -0,0 +1,17 @@ +root: ../OST +task: str +download_links: + # OST heavy + - https://drive.usercontent.google.com/download?id=1RGpIFbD_SRlrzZFBoVF_LGvetNx1-5pg&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1Th4MfDf44k0EBpIqCLqVoGRu6G-FP1hq&authuser=0&confirm=t + # OST weak + - https://drive.usercontent.google.com/download?id=1z5CTDJucUnvALG12Q4UXk1DDKJDd8WJn&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1V17TTkX3sjpV7v0km_F2SDCK0tL3k_ls&authuser=0&confirm=t +filenames: + # OST heavy + - ../OST/heavy/data.mdb + - ../OST/heavy/lock.mdb + # OST weak + - ../OST/weak/data.mdb + - ../OST/weak/lock.mdb +check_validity: true diff --git a/configs/dataset/rec/synthtext.yaml b/configs/dataset/rec/synthtext.yaml new file mode 100644 index 0000000..4b84088 --- /dev/null +++ b/configs/dataset/rec/synthtext.yaml @@ -0,0 +1,7 @@ +root: ../synth +task: str +download_links: + - https://drive.usercontent.google.com/download?id=1T-enqkq6_l2HqrsV3da_h0oJ7CUKu_oc&authuser=0&confirm=t +filenames: + - ../synth/ST.zip +check_validity: true diff --git a/configs/dataset/rec/test.yaml b/configs/dataset/rec/test.yaml new file mode 100644 index 0000000..b043ba4 --- /dev/null +++ 
b/configs/dataset/rec/test.yaml @@ -0,0 +1,77 @@ +root: ../test +task: str +download_links: + # IC13_857 + - https://drive.usercontent.google.com/download?id=1PZSCbe6_DI8MlCqCRWXGT2PP92_frIXq&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1qkN7NDg0zUHxUiZHAeEatDTqlsgpFWp3&authuser=0&confirm=t + # IC15_2077 + - https://drive.usercontent.google.com/download?id=1dFkY3DNbr-Mepn3TWBiA9COEJ63fGFcp&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1UvVwLNZ3tS1YdTBa8MulPzjeVezKaDro&authuser=0&confirm=t + # SVTP + - https://drive.usercontent.google.com/download?id=1aofeerilxJ7J3S7QxuCEXbmXTpz8Xshx&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1rJ1KoO4K_VUxEAUN_bMgBGzK8_JZAAno&authuser=0&confirm=t + # IIIT5k + - https://drive.usercontent.google.com/download?id=1XFO2M1Kbgwv3-iTNTmhQXAEjNmKYOeoT&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1stwK2hFsyaV7HHsEG9EYgnUQebNb2_nG&authuser=0&confirm=t + # COCOv1.4 + - https://drive.usercontent.google.com/download?id=1Se2QSGS19xx7Gfy-SUdX9mlAOr2eYsfA&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1xvekFi389QfkH7yS0JIVV0QzjhUspjDv&authuser=0&confirm=t + # IC15_1811 + - https://drive.usercontent.google.com/download?id=1pHsw8wrThD9EGEE6AusQLZozefSj4iyR&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1TXZ1qHuKAksaAlvd3qMv4IHKnN-IJW9a&authuser=0&confirm=t + # Uber + - https://drive.usercontent.google.com/download?id=1L2j6BZeLTGQ1FIl8HB_D3AFiWLltGV5r&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=12DUj28yzLWxFO_gfMfSjTkRujYD5MNEE&authuser=0&confirm=t + # IC13_1095 + - https://drive.usercontent.google.com/download?id=1fu8onMt3Z6fDLNAiHcm-sQ2qCXduE-FU&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1OQAZtLj8U2Cl4L0ErGFsz6vGIVTTWasD&authuser=0&confirm=t + # IC13_1015 + - https://drive.usercontent.google.com/download?id=1mbsfuvWB282HYfn9tbqcj1nUDkLXcSNB&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1QGogU_hV-oN7iY2POutdD2LDcmK6plnV&authuser=0&confirm=t + # ArT + - https://drive.usercontent.google.com/download?id=1-53knSy-uTSngCG7wyBngVyTuTCmdnWl&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=172EsSaf7BVaB1ORtohi-Jc_8SuUKZGGf&authuser=0&confirm=t + # SVT + - https://drive.usercontent.google.com/download?id=1p7aVUr9Yr7c4X4YUBvk2-YP28rraHjn9&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1ALmhvSleZ0yf-lcdbQPP3M9Zc3oqnXij&authuser=0&confirm=t + # CUTE80 + - https://drive.usercontent.google.com/download?id=1Ujr4axHKnu54P2rIGUhkjdM6XlhDYrI_&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1DvZi9L3MqjO2zRUyCg3YvP4qMAt2bsme&authuser=0&confirm=t +filenames: + # IC13_857 + - ../test/IC13_857/data.mdb + - ../test/IC13_857/lock.mdb + # IC15_2077 + - ../test/IC15_2077/data.mdb + - ../test/IC15_2077/lock.mdb + # SVTP + - ../test/SVTP/data.mdb + - ../test/SVTP/lock.mdb + # IIIT5k + - ../test/IIIT5k/data.mdb + - ../test/IIIT5k/lock.mdb + # COCOv1.4 + - ../test/COCOv1.4/data.mdb + - ../test/COCOv1.4/lock.mdb + # IC15_1811 + - ../test/IC15_1811/data.mdb + - ../test/IC15_1811/lock.mdb + # Uber + - ../test/Uber/data.mdb + - ../test/Uber/lock.mdb + # IC13_1095 + - ../test/IC13_1095/data.mdb + - ../test/IC13_1095/lock.mdb + # IC13_1015 + - ../test/IC13_1015/data.mdb + - ../test/IC13_1015/lock.mdb + # ArT + - ../test/ArT/data.mdb + - ../test/ArT/lock.mdb + # SVT + - 
../test/SVT/data.mdb + - ../test/SVT/lock.mdb + # CUTE80 + - ../test/CUTE80/data.mdb + - ../test/CUTE80/lock.mdb +check_validity: true diff --git a/configs/dataset/rec/textocr.yaml b/configs/dataset/rec/textocr.yaml new file mode 100644 index 0000000..abfb4b7 --- /dev/null +++ b/configs/dataset/rec/textocr.yaml @@ -0,0 +1,13 @@ +root: ../TextOCR +task: str +download_links: + # train + - https://drive.usercontent.google.com/download?id=1jVjJFno4pnsU0Cp_kn4MIXQrChmELy92&authuser=0&confirm= + # val + - https://drive.usercontent.google.com/download?id=1ubIRu01MXIek6OvInu-XjaIbw6277-vw&authuser=0&confirm=t +filenames: + # train + - ../TextOCR/train.zip + # val + - ../TextOCR/val.zip +check_validity: true diff --git a/configs/dataset/rec/textocr_horizontal.yaml b/configs/dataset/rec/textocr_horizontal.yaml new file mode 100644 index 0000000..1ccec04 --- /dev/null +++ b/configs/dataset/rec/textocr_horizontal.yaml @@ -0,0 +1,13 @@ +root: ../TextOCR_horizontal +task: str +download_links: + # train + - https://drive.usercontent.google.com/download?id=1sWH6J11xbjQb8SH7fdG_8mIKVI81ZQy5&authuser=0&confirm= + # val + - https://drive.usercontent.google.com/download?id=1gIE-AU2o-5hvg288-bjphO6UkI5AEQ2d&authuser=0&confirm=t +filenames: + # train + - ../TextOCR_horizontal/train.zip + # val + - ../TextOCR_horizontal/val.zip +check_validity: true diff --git a/configs/dataset/rec/union14m_b.yaml b/configs/dataset/rec/union14m_b.yaml new file mode 100644 index 0000000..7479207 --- /dev/null +++ b/configs/dataset/rec/union14m_b.yaml @@ -0,0 +1,47 @@ +root: ../u14m +task: str +download_links: + # artistic + - https://drive.usercontent.google.com/download?id=1Je2DTuFHnkXDI99yDnm9Anl5naWaCQwd&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1xtT_Q0juBJUIvAG55qBxoVNNTECd2usZ&authuser=0&confirm=t + # contextless + - https://drive.usercontent.google.com/download?id=1_0OzyzWhZOmGrHkayFTVrzhrQrNRDRPR&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1PPgC42y3xoM9bR0HQFbDYbcT3PzMdD_y&authuser=0&confirm=t + # salient + - https://drive.usercontent.google.com/download?id=1tHLMYBmTqRnxvFOTT3dfLfQiundqFWfd&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=13NQgpAtCK0kh9M5E2pAUmKKEp6Qu5Xwj&authuser=0&confirm=t + # multi_words + - https://drive.usercontent.google.com/download?id=1IlnDKX3V_Vp9gsDGFB0xoqsVLH1vtxUI&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1mFFjC7C0CwevvkwFU9YeVbZBdps_3Qpb&authuser=0&confirm=t + # curve + - https://drive.usercontent.google.com/download?id=1MxhMd85cmhUtI2lmtXhZQuFk7lav0_fw&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1N03g-4e-kJG2mRvlM0c5TrwWAkd-iG-Q&authuser=0&confirm=t + # general + - https://drive.usercontent.google.com/download?id=1Oqt7OaycP466NWoDmoJ3FqS8YP3YRgvu&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1K0MrX5eYNt8IIGFHXCwg0_oI5OF5PPFO&authuser=0&confirm=t + # multi_oriented + - https://drive.usercontent.google.com/download?id=1TKZFcZPVk0ThqfF-AGhJk_OCLg0ykKbv&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1PAoLMUWuR7O2-7XRoKkNzQcSiznErQzD&authuser=0&confirm=t +filenames: + # artistic + - ../u14m/artistic/data.mdb + - ../u14m/artistic/lock.mdb + # contextless + - ../u14m/contextless/data.mdb + - ../u14m/contextless/lock.mdb + # salient + - ../u14m/salient/data.mdb + - ../u14m/salient/lock.mdb + # multi_words + - ../u14m/multi_words/data.mdb + - ../u14m/multi_words/lock.mdb + # curve + - 
../u14m/curve/data.mdb + - ../u14m/curve/lock.mdb + # general + - ../u14m/general/data.mdb + - ../u14m/general/lock.mdb + # multi_oriented + - ../u14m/multi_oriented/data.mdb + - ../u14m/multi_oriented/lock.mdb +check_validity: true diff --git a/configs/dataset/rec/union14m_l_filtered.yaml b/configs/dataset/rec/union14m_l_filtered.yaml new file mode 100644 index 0000000..86ab60e --- /dev/null +++ b/configs/dataset/rec/union14m_l_filtered.yaml @@ -0,0 +1,35 @@ +root: ../Union14M-L-LMDB-Filtered +task: str +download_links: + # train_challenging + - https://drive.usercontent.google.com/download?id=1etwzBgGHjsFsb0sygsaRnKbanW2PMe07&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1ly6FJfPjItwGlVQ-ifTrzzM3rVu3Ezhr&authuser=0&confirm=t + # train_easy + - https://drive.usercontent.google.com/download?id=1_zeNluTnywIaa5h3PN-Ah9tKyByypot7&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1caYLeQHDidXgVBDi9IWXbO1gg__DYq9a&authuser=0&confirm=t + # train_hard + - https://drive.usercontent.google.com/download?id=1eP6s2xyYPZX9gykvWA4VSOc3Fqul_UB_&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1-ZlCvocX8P5uVRclUXp_5DNGLDzd16EO&authuser=0&confirm=t + # train_medium + - https://drive.usercontent.google.com/download?id=1s_CoaLNJEr-UxHYiqZ5jOcliMCFiRUUy&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1Wpj6WVpZ5Ily77kVwfQ18CiZBzkgmEnF&authuser=0&confirm=t + # train_normal + - https://drive.usercontent.google.com/download?id=1jPt44arlAswl9cXZjzmVcdpptdTPpJ3I&authuser=0&confirm=t + - https://drive.usercontent.google.com/download?id=1Rfc5kE03AzOUv7B_eYcBhUV8KMQ2MZ1m&authuser=0&confirm=t +filenames: + # train_challenging + - ../Union14M-L-LMDB-Filtered/train_challenging/data.mdb + - ../Union14M-L-LMDB-Filtered/train_challenging/lock.mdb + # train_easy + - ../Union14M-L-LMDB-Filtered/train_easy/data.mdb + - ../Union14M-L-LMDB-Filtered/train_easy/lock.mdb + # train_hard + - ../Union14M-L-LMDB-Filtered/train_hard/data.mdb + - ../Union14M-L-LMDB-Filtered/train_hard/lock.mdb + # train_medium + - ../Union14M-L-LMDB-Filtered/train_medium/data.mdb + - ../Union14M-L-LMDB-Filtered/train_medium/lock.mdb + # train_normal + - ../Union14M-L-LMDB-Filtered/train_normal/data.mdb + - ../Union14M-L-LMDB-Filtered/train_normal/lock.mdb +check_validity: true diff --git a/configs/det/dbnet/repvit_db.yml b/configs/det/dbnet/repvit_db.yml index c9b1bc1..689f655 100644 --- a/configs/det/dbnet/repvit_db.yml +++ b/configs/det/dbnet/repvit_db.yml @@ -53,7 +53,7 @@ Architecture: PostProcess: name: DBPostProcess thresh: 0.3 - box_thresh: 0.4 + box_thresh: 0.6 max_candidates: 1000 unclip_ratio: 1.5 score_mode: 'slow' diff --git a/configs/rec/dptr/dptr_parseq_pretrain.yml b/configs/rec/dptr/dptr_parseq_pretrain.yml new file mode 100644 index 0000000..32b8adf --- /dev/null +++ b/configs/rec/dptr/dptr_parseq_pretrain.yml @@ -0,0 +1,88 @@ +Global: + device: gpu + epoch_num: 20 + log_smooth_window: 20 + print_batch_step: 10 + output_dir: /share/ckpt/zhaoshuai/openocr/dptr_parseq/ + eval_epoch_step: [0, 1] + eval_batch_step: [0, 500] + cal_metric_during_train: True + pretrained_model: + checkpoints: + use_tensorboard: false + infer_img: + # for data or label process + character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt + max_text_length: &max_text_length 25 + use_space_char: &use_space_char False + use_amp: True + save_res_path: /share/ckpt/zhaoshuai/openocr/dptr_parseq/predicts_dptr_parseq.txt + grad_clip_val: 20 + +Optimizer: 
+ name: AdamW + lr: 0.001485 # 2gpus 384bs/gpu + weight_decay: 0. + filter_bias_and_bn: False + +LRScheduler: + name: OneCycleLR + warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep + cycle_momentum: False + +Architecture: + model_type: rec + algorithm: DPTR + Decoder: + name: DptrParseq + decode_ar: True + refine_iters: 1 + is_pretrain: True + ORP_path: /share/ckpt/zhaoshuai/parseq/clip_background.pth + +Loss: + name: PARSeqLoss + +PostProcess: + name: ARLabelDecode + character_dict_path: *character_dict_path + use_space_char: *use_space_char + +Metric: + name: RecMetric + main_indicator: acc + is_filter: True + +Train: + dataset: + name: TextLMDBDataSet + data_dir: /share/test/zhaoshuai/parseq-data/data/train/real/ArT + transforms: + - DPTRLabelEncode: # Class handling label + character_dict_path: *character_dict_path + use_space_char: *use_space_char + max_text_length: *max_text_length + - KeepKeys: + keep_keys: ['clip_label', 'label'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 256 + drop_last: True + num_workers: 4 + +Eval: + dataset: + name: TextLMDBDataSet + data_dir: /share/test/zhaoshuai/parseq-data/data/val + transforms: + - DPTRLabelEncode: # Class handling label + character_dict_path: *character_dict_path + use_space_char: *use_space_char + max_text_length: *max_text_length + - KeepKeys: + keep_keys: ['clip_label', 'label'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 256 + num_workers: 2 diff --git a/configs/rec/igtr/readme.md b/configs/rec/igtr/readme.md index 61c7d33..dc420f7 100644 --- a/configs/rec/igtr/readme.md +++ b/configs/rec/igtr/readme.md @@ -13,11 +13,12 @@ Paper: -> [Instruction-Guided Scene Text Recognition](https://arxiv.org/abs/2401.17851) -> Yongkun Du, Zhineng Chen, Yuchen Su, Caiyan Jia, Yu-Gang Jiang +> [Instruction-Guided Scene Text Recognition](https://arxiv.org/abs/2401.17851), +> Yongkun Du, Zhineng Chen, Yuchen Su, Caiyan Jia, Yu-Gang Jiang, +> TPAMI -Multi-modal models show appealing performance in visual recognition tasks recently, as free-form text-guided training evokes the ability to understand fine-grained visual content. However, current models are either inefficient or cannot be trivially upgraded to scene text recognition (STR) due to the composition difference between natural and text images. We propose a novel instruction-guided scene text recognition (IGTR) paradigm that formulates STR as an instruction learning problem and understands text images by predicting character attributes, e.g., character frequency, position, etc. IGTR first devises $\\left \\langle condition,question,answer\\right \\rangle$ instruction triplets, providing rich and diverse descriptions of character attributes. To effectively learn these attributes through question-answering, IGTR develops lightweight instruction encoder, cross-modal feature fusion module and multi-task answer head, which guides nuanced text image understanding. Furthermore, IGTR realizes different recognition pipelines simply by using different instructions, enabling a character-understanding-based text reasoning paradigm that considerably differs from current methods. Experiments on English and Chinese benchmarks show that IGTR outperforms existing models by significant margins, while maintaining a small model size and efficient inference speed. 
Moreover, by adjusting the sampling of instructions, IGTR offers an elegant way to tackle the recognition of both rarely appearing and morphologically similar characters, which were previous challenges. +Multi-modal models have shown appealing performance in visual recognition tasks, as free-form text-guided training evokes the ability to understand fine-grained visual content. However, current models cannot be trivially applied to scene text recognition (STR) due to the compositional difference between natural and text images. We propose a novel instruction-guided scene text recognition (IGTR) paradigm that formulates STR as an instruction learning problem and understands text images by predicting character attributes, e.g., character frequency, position, etc. IGTR first devises $\\left \\langle condition,question,answer\\right \\rangle$ instruction triplets, providing rich and diverse descriptions of character attributes. To effectively learn these attributes through question-answering, IGTR develops a lightweight instruction encoder, a cross-modal feature fusion module and a multi-task answer head, which guides nuanced text image understanding. Furthermore, IGTR realizes different recognition pipelines simply by using different instructions, enabling a character-understanding-based text reasoning paradigm that differs from current methods considerably. Experiments on English and Chinese benchmarks show that IGTR outperforms existing models by significant margins, while maintaining a small model size and fast inference speed. Moreover, by adjusting the sampling of instructions, IGTR offers an elegant way to tackle the recognition of rarely appearing and morphologically similar characters, which were previous challenges. The accuracy (%) and model files of IGTR on the public dataset of scene text recognition are as follows: @@ -88,11 +89,11 @@ pip install -r requirements.txt #### Dataset Preparation -[English dataset download](https://github.com/baudm/parseq) +- [English dataset download](https://github.com/baudm/parseq) -[Union14M-L download](https://github.com/Mountchicken/Union14M) +- [Union14M-L-LMDB-Filtered download](https://drive.google.com/drive/folders/1OlDWJZgvd6s4S09S3IGeAI90jI0i7AB_?usp=sharing) -[Chinese dataset download](https://github.com/fudanvi/benchmarking-chinese-text-recognition#download) +- [Chinese dataset download](https://github.com/fudanvi/benchmarking-chinese-text-recognition#download) The expected filesystem structure is as follows: @@ -143,7 +144,7 @@ u14m # lmdb format ├── multi_oriented ├── multi_words └── salient -Union14M-LMDB-L # lmdb format +Union14M-L-LMDB-Filtered # lmdb format ├── train_challenging ├── train_easy ├── train_hard @@ -168,13 +169,15 @@ Evaluation: ```shell # The configuration file is available from the link provided in the table above. 
# en -python tools/eval_rec_all_ratio.py --c PATH/svtr_base_igtr_syn.yml +python tools/eval_rec_all_en.py --c PATH/svtr_base_igtr_syn.yml # ch python tools/eval_rec_all_ch.py --c PATH/svtr_base_igtr_ch_aug.yml ``` ## Citation +If you find our method useful for your research, please cite: + ```bibtex @article{Du2024IGTR, title = {Instruction-Guided Scene Text Recognition}, diff --git a/configs/rec/smtr/readme.md b/configs/rec/smtr/readme.md index ddda487..52bbedd 100644 --- a/configs/rec/smtr/readme.md +++ b/configs/rec/smtr/readme.md @@ -13,8 +13,9 @@ Paper: -> [Out of Length Text Recognition with Sub-String Matching](https://arxiv.org/abs/2407.12317) -> Yongkun Du, Zhineng Chen\*, Caiyan Jia, Xieping Gao, Yu-Gang Jiang +> [Out of Length Text Recognition with Sub-String Matching](https://arxiv.org/abs/2407.12317). +> Yongkun Du, Zhineng Chen\*, Caiyan Jia, Xieping Gao, Yu-Gang Jiang. +> AAAI 2025 Scene Text Recognition (STR) methods have demonstrated robust performance in word-level text recognition. However, in applications the text image is sometimes long as it is detected containing multiple horizontal words. This triggers the requirement to build long text recognition models from readily available short word-level text datasets, which has been less studied previously. In this paper, we term this the Out of Length (OOL) text recognition. We establish the first Long Text Benchmark (LTB) to facilitate the assessment of different methods in long text recognition. Meanwhile, we propose a novel method called OOL Text Recognition with sub-String Matching (SMTR). SMTR comprises two cross-attention-based modules: one encodes a sub-string containing multiple characters into next and previous queries, and the other employs the queries to attend to the image features, matching the sub-string and simultaneously recognizing its next and previous character. SMTR can recognize text of arbitrary length by iterating the process above. To avoid being trapped in recognizing highly similar sub-strings, we introduce a regularization training to compel SMTR to effectively discover subtle differences between similar sub-strings for precise matching. In addition, we propose an inference augmentation to alleviate confusion caused by identical sub-strings and improve the overall recognition efficiency. Extensive experimental results reveal that SMTR, even when trained exclusively on short text, outperforms existing methods in public short text benchmarks and exhibits a clear advantage on LTB.
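To make the iterative decoding idea above concrete, the toy sketch below walks a sliding sub-string over the prediction one character at a time. It only illustrates the control flow: the dummy predictor, function names, and the forward-only greedy loop are invented here and are not the repository's implementation (the real SMTR uses two cross-attention modules, also predicts the previous character, and adds the regularization training and inference augmentation described above). The sub-string length of 5 mirrors the `sub_str_len: 5` setting used by the SMTR-style decoder configs in this PR.

```python
# Toy, forward-only sketch of sub-string-driven decoding (illustrative only; see
# the note above). A fixed-length sub-string of the text decoded so far serves as
# the query for predicting the next character, and the loop repeats until an
# end-of-sequence symbol appears, so the output length is unbounded.
from typing import Callable, Sequence

EOS = '</s>'
SUB_STR_LEN = 5  # mirrors sub_str_len: 5 in the SMTR-style configs


def decode_forward(image_features: Sequence[float],
                   predict_next: Callable[[str, Sequence[float]], str],
                   max_len: int = 200) -> str:
    """Greedy decoding: the last SUB_STR_LEN characters form the query."""
    text = ''
    while len(text) < max_len:
        query = text[-SUB_STR_LEN:]
        next_char = predict_next(query, image_features)
        if next_char == EOS:
            break
        text += next_char
    return text


# A dummy predictor standing in for the cross-attention modules: it simply
# "reads" a fixed string so the control flow can be executed end to end.
_TARGET = 'OUT OF LENGTH TEXT RECOGNITION'


def _dummy_predictor(query: str, feats: Sequence[float]) -> str:
    pos = len(query) if len(query) < SUB_STR_LEN else _TARGET.index(query) + SUB_STR_LEN
    return _TARGET[pos] if pos < len(_TARGET) else EOS


if __name__ == '__main__':
    print(decode_forward([0.0] * 16, _dummy_predictor))  # -> OUT OF LENGTH TEXT RECOGNITION
```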
@@ -23,31 +24,31 @@ The accuracy (%) and model files of SMTR on the public dataset of scene text rec - Syn: Synth dataset(MJ+ST) from [PARSeq](https://github.com/baudm/parseq) -- U14M: Union14M-L from [Union14M](https://github.com/Mountchicken/Union14M/) +- Union14M-L-LMDB-Filtered: A filtered version of [Union14M](https://github.com/Mountchicken/Union14M/) - Test on Long Text Benchmark ([Download LTB](https://drive.google.com/drive/folders/1NChdlw7ustbXtlFBmh_0xnHvRkffb9Ge?usp=sharing)): -| Model | Training Data | LTB | Config&Model&Log | -| :-------: | :-----------: | :--: | :---------------------------------------------------------------------------------------------: | -| SMTR | Syn | 39.6 | [link](https://drive.google.com/drive/folders/11SplakPPOFDMhPixv7ABNgjeTg4jKyfU?usp=sharing) | -| SMTR | U14M | 51.0 | [link](https://drive.google.com/drive/folders/1-K5O0d0q9fhY5fJvU6nn5fFFtSMnbE_-?usp=drive_link) | -| FocalSVTR | U14M | 42.1 | [link](https://drive.google.com/drive/folders/100xF5wFr7xSCVBYM1h_0d_8xv5Qeqobp?usp=sharing) | +| Model | Training Data | LTB | Config&Model&Log | +| :-------: | :----------------------: | :--: | :---------------------------------------------------------------------------------------------: | +| SMTR | Syn | 39.6 | [link](https://drive.google.com/drive/folders/11SplakPPOFDMhPixv7ABNgjeTg4jKyfU?usp=sharing) | +| SMTR | Union14M-L-LMDB-Filtered | 51.0 | [link](https://drive.google.com/drive/folders/1-K5O0d0q9fhY5fJvU6nn5fFFtSMnbE_-?usp=drive_link) | +| FocalSVTR | Union14M-L-LMDB-Filtered | 42.1 | [link](https://drive.google.com/drive/folders/100xF5wFr7xSCVBYM1h_0d_8xv5Qeqobp?usp=sharing) | - Test on Common Benchmarks from [PARSeq](https://github.com/baudm/parseq): -| Model | Training Data | IC13
857 | SVT | IIIT5k
3000 | IC15
1811 | SVTP | CUTE80 | Avg | Config&Model&Log | -| :-------: | :-----------: | :----------: | :--: | :-------------: | :-----------: | :--: | :----: | :---: | :---------------------: | -| SMTR | Syn | 97.4 | 94.9 | 97.4 | 88.4 | 89.9 | 96.2 | 94.02 | Same as the above table | -| SMTR | U14M | 98.3 | 97.4 | 99.0 | 90.1 | 92.7 | 97.9 | 95.90 | Same as the above table | -| FocalSVTR | U14M | 97.3 | 96.3 | 98.2 | 87.4 | 88.4 | 96.2 | 93.97 | Same as the above table | +| Model | Training Data | IC13
857 | SVT | IIIT5k
3000 | IC15
1811 | SVTP | CUTE80 | Avg | Config&Model&Log | +| :-------: | :----------------------: | :----------: | :--: | :-------------: | :-----------: | :--: | :----: | :---: | :---------------------: | +| SMTR | Syn | 97.4 | 94.9 | 97.4 | 88.4 | 89.9 | 96.2 | 94.02 | Same as the above table | +| SMTR | Union14M-L-LMDB-Filtered | 98.3 | 97.4 | 99.0 | 90.1 | 92.7 | 97.9 | 95.90 | Same as the above table | +| FocalSVTR | Union14M-L-LMDB-Filtered | 97.3 | 96.3 | 98.2 | 87.4 | 88.4 | 96.2 | 93.97 | Same as the above table | - Test on Union14M-L benchmark from [Union14M](https://github.com/Mountchicken/Union14M/). -| Model | Traing Data | Curve | Multi-
Oriented | Artistic | Contextless | Salient | Multi-
word | General | Avg | Config&Model&Log | -| :-------: | :---------: | :---: | :-----------------: | :------: | :---------: | :-----: | :-------------: | :-----: | :---: | :---------------------: | -| SMTR | Syn | 74.2 | 30.6 | 58.5 | 67.6 | 79.6 | 75.1 | 67.9 | 64.79 | Same as the above table | -| SMTR | U14M | 89.1 | 87.7 | 76.8 | 83.9 | 84.6 | 89.3 | 83.7 | 85.00 | Same as the above table | -| FocalSVTR | U14M | 77.7 | 62.4 | 65.7 | 78.6 | 71.6 | 81.3 | 79.2 | 73.80 | Same as the above table | +| Model | Training Data | Curve | Multi-
Oriented | Artistic | Contextless | Salient | Multi-
word | General | Avg | Config&Model&Log | +| :-------: | :----------------------: | :---: | :-----------------: | :------: | :---------: | :-----: | :-------------: | :-----: | :---: | :---------------------: | +| SMTR | Syn | 74.2 | 30.6 | 58.5 | 67.6 | 79.6 | 75.1 | 67.9 | 64.79 | Same as the above table | +| SMTR | Union14M-L-LMDB-Filtered | 89.1 | 87.7 | 76.8 | 83.9 | 84.6 | 89.3 | 83.7 | 85.00 | Same as the above table | +| FocalSVTR | Union14M-L-LMDB-Filtered | 77.7 | 62.4 | 65.7 | 78.6 | 71.6 | 81.3 | 79.2 | 73.80 | Same as the above table | - Training and test on Chinese dataset, from [Chinese Benckmark](https://github.com/FudanVI/benchmarking-chinese-text-recognition). @@ -79,7 +80,7 @@ pip install -r requirements.txt - [English dataset download](https://github.com/baudm/parseq) -- [Union14M-L download](https://github.com/Mountchicken/Union14M) +- [Union14M-L-LMDB-Filtered download](https://drive.google.com/drive/folders/1OlDWJZgvd6s4S09S3IGeAI90jI0i7AB_?usp=sharing) - [Chinese dataset download](https://github.com/fudanvi/benchmarking-chinese-text-recognition#download) @@ -135,7 +136,7 @@ u14m # lmdb format ├── multi_words └── salient ltb # download link: https://drive.google.com/drive/folders/1NChdlw7ustbXtlFBmh_0xnHvRkffb9Ge?usp=sharing -Union14M-LMDB-L # lmdb format +Union14M-L-LMDB-Filtered # lmdb format ├── train_challenging ├── train_easy ├── train_hard @@ -160,7 +161,7 @@ Evaluation: ```shell # en -python tools/eval_rec_all_ratio.py --c configs/rec/smtr/focalsvtr_smtr.yml +python tools/eval_rec_all_en.py --c configs/rec/smtr/focalsvtr_smtr.yml # long text python tools/eval_rec_all_long_simple.py --c configs/rec/smtr/focalsvtr_smtr_long.yml # ch @@ -169,6 +170,8 @@ python tools/eval_rec_all_ch.py --c configs/rec/smtr/focalsvtr_smtr_ch.yml ## Citation +If you find our method useful for your reserach, please cite: + ```bibtex @article{Du2024SMTR, title = {Out of Length Text Recognition with Sub-String Matching}, diff --git a/configs/rec/svtrv2/SVTRv2.pdf b/configs/rec/svtrv2/SVTRv2.pdf deleted file mode 100644 index 0be5128..0000000 Binary files a/configs/rec/svtrv2/SVTRv2.pdf and /dev/null differ diff --git a/configs/rec/svtrv2/readme.md b/configs/rec/svtrv2/readme.md index aade0b5..31abbf6 100644 --- a/configs/rec/svtrv2/readme.md +++ b/configs/rec/svtrv2/readme.md @@ -18,7 +18,7 @@ Paper: -> [SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition](./SVTRv2.pdf) +> [SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition](https://arxiv.org/abs/2411.15858) > Yongkun Du, Zhineng Chen\*, Hongtao Xie, Caiyan Jia, Yu-Gang Jiang @@ -30,11 +30,11 @@ The accuracy (%) and model files of SVTRv2 on the public dataset of scene text r Download all Configs, Models, and Logs from [Google Drive](https://drive.google.com/drive/folders/1i2EZVT-oxfDIDdhwQRm9E6Fk8s6qD3C1?usp=sharing). 
-| Model | Model size | FPS | -| :------: | :--------: | :-: | -| SVTRv2-T | 5.13 | 5.0 | -| SVTRv2-S | 11.25 | 5.3 | -| SVTRv2-B | 19.76 | 7.0 | +| Model | Model size | Latency | +| :------: | :--------: | :-----: | +| SVTRv2-T | 5.13 | 5.0 | +| SVTRv2-S | 11.25 | 5.3 | +| SVTRv2-B | 19.76 | 7.0 | - Test on Common Benchmarks from [PARSeq](https://github.com/baudm/parseq): @@ -102,10 +102,10 @@ Referring to [Downloading Datasets](../../../docs/svtrv2.md#downloading-datasets CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 tools/train_rec.py --c configs/rec/svtrv2/svtrv2_rctc.yml # Second stage -CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 tools/train_rec.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml --o Global.pretrained_model=./output/rec/u14m_filter/svtrv2_rctc/best.pth +CUDA_VISIBLE_DEVICES=4,5,6,7 python -m torch.distributed.launch --master_port=23332 --nproc_per_node=4 tools/train_rec.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml --o Global.pretrained_model=./output/rec/u14m_filter/svtrv2_rctc/best.pth # For Multi RTX 4090 -NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 tools/train_rec.py --c configs/rec/svtrv2/svtrv2_rctc.yml +NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --master_port=23333 --nproc_per_node=4 tools/train_rec.py --c configs/rec/svtrv2/svtrv2_rctc.yml # 20epoch runs for about 6 hours ``` @@ -113,9 +113,10 @@ NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.laun ```shell # short text: Common, Union14M-Benchmark, OST -python tools/eval_rec_all_ratio.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml -# long text -python tools/eval_rec_all_long_simple.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml +python tools/eval_rec_all_en.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_infer.yml + +# long text: LTB +python tools/eval_rec_all_long.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_infer.yml --o Eval.loader.max_ratio=20 ``` After a successful run, the results are saved in a csv file in `output_dir` in the config file. @@ -123,7 +124,7 @@ After a successful run, the results are saved in a csv file in `output_dir` in t ### Inference ```shell -python tools/infer_rec.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml --o Global.infer_img=/path/img_fold or /path/img_file +python tools/infer_rec.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_infer.yml --o Global.infer_img=/path/img_fold or /path/img_file ``` ### Latency Measurement @@ -131,15 +132,22 @@ python tools/infer_rec.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml --o Gl Firstly, downloading the IIIT5K images from [Google Drive](https://drive.google.com/drive/folders/1Po1LSBQb87DxGJuAgLNxhsJ-pdXxpIfS?usp=drive_link). 
Then, running the following command: ```shell -python tools/infer_rec.py --c configs/rec/SVTRv2/svtrv2_rctc.yml --o Global.infer_img=../iiit5k_test_image +python tools/infer_rec.py --c configs/rec/SVTRv2/svtrv2_smtr_gtc_rctc_infer.yml --o Global.infer_img=../iiit5k_test_image ``` ## Citation +If you find our method useful for your reserach, please cite: + ```bibtex -@article{Du2024SVTRv4, - title = {SVTRv2: Scene Text Recognition with a Single Visual Model}, - author = {Yongkun Du, Zhineng Chen\*, Hongtao Xie, Caiyan Jia, Yu-Gang Jiang}, - year = {2024} +@article{Du2024SVTRv2, + title={SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition}, + author={Yongkun Du and Zhineng Chen and Hongtao Xie and Caiyan Jia and Yu-Gang Jiang}, + journal={CoRR}, + volume={abs/2411.15858}, + eprinttype={arXiv}, + year={2024}, + primaryClass={cs.CV}, + url={https://arxiv.org/abs/2411.15858} } ``` diff --git a/configs/rec/svtrv2/repsvtr_ch.yml b/configs/rec/svtrv2/repsvtr_ch.yml index 033d6e7..fd68312 100644 --- a/configs/rec/svtrv2/repsvtr_ch.yml +++ b/configs/rec/svtrv2/repsvtr_ch.yml @@ -34,7 +34,7 @@ LRScheduler: Architecture: model_type: rec - algorithm: SVTRv2 + algorithm: SVTRv2_mobile Transform: Encoder: name: RepSVTREncoder @@ -53,6 +53,7 @@ Loss: PostProcess: name: CTCLabelDecode + character_dict_path: *character_dict_path Metric: name: RecMetric diff --git a/configs/rec/svtrv2/svtrv2_ch.yml b/configs/rec/svtrv2/svtrv2_ch.yml index a6538df..6ee2e8f 100644 --- a/configs/rec/svtrv2/svtrv2_ch.yml +++ b/configs/rec/svtrv2/svtrv2_ch.yml @@ -34,7 +34,7 @@ LRScheduler: Architecture: model_type: rec - algorithm: SVTRv2 + algorithm: SVTRv2_server Transform: Encoder: name: SVTRv2LNConvTwo33 @@ -65,6 +65,7 @@ Loss: PostProcess: name: CTCLabelDecode + character_dict_path: *character_dict_path Metric: name: RecMetric diff --git a/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml b/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml index aa36492..20a9d7c 100644 --- a/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml +++ b/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml @@ -3,7 +3,7 @@ Global: epoch_num: 20 log_smooth_window: 20 print_batch_step: 10 - output_dir: ./output/rec/u14m_filter/svtrv2_smtr_gtc_rctc_maxratio12 + output_dir: ./output/rec/u14m_filter/svtrv2_smtr_gtc_rctc save_epoch_step: 1 # evaluation is run every 2000 iterations eval_batch_step: [0, 500] diff --git a/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_infer.yml b/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_infer.yml new file mode 100644 index 0000000..2361eb9 --- /dev/null +++ b/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_infer.yml @@ -0,0 +1,151 @@ +Global: + device: gpu + epoch_num: 20 + log_smooth_window: 20 + print_batch_step: 10 + output_dir: ./output/rec/u14m_filter/svtrv2_smtr_gtc_rctc + save_epoch_step: 1 + # evaluation is run every 2000 iterations + eval_batch_step: [0, 500] + eval_epoch_step: [0, 1] + cal_metric_during_train: True + pretrained_model: + # ./output/rec/u14m_filter/svtrv2_rctc/best.pth + checkpoints: + use_tensorboard: false + infer_img: + # for data or label process + character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en + # ./tools/utils/ppocr_keys_v1.txt # ch + max_text_length: &max_text_length 25 + use_space_char: &use_space_char False + save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_smtr_gtc_rctc.txt + use_amp: True + +Optimizer: + name: AdamW + lr: 0.000325 # for 4gpus bs128/gpu + weight_decay: 0.05 + filter_bias_and_bn: True + +LRScheduler: + name: OneCycleLR + warmup_epoch: 1.5 # pct_start 0.075*20 = 
1.5ep + cycle_momentum: False + +Architecture: + model_type: rec + algorithm: SVTRv2 + in_channels: 3 + Transform: + Encoder: + name: SVTRv2LNConvTwo33 + use_pos_embed: False + dims: [128, 256, 384] + depths: [6, 6, 6] + num_heads: [4, 8, 12] + mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']] + local_k: [[5, 5], [5, 5], [-1, -1]] + sub_k: [[1, 1], [2, 1], [-1, -1]] + last_stage: false + feat2d: True + Decoder: + name: GTCDecoder + infer_gtc: False + detach: False + gtc_decoder: + name: SMTRDecoder + num_layer: 1 + ds: True + max_len: *max_text_length + next_mode: &next True + sub_str_len: &subsl 5 + ctc_decoder: + name: RCTCDecoder + +Loss: + name: CTCLoss + zero_infinity: True + +PostProcess: + name: CTCLabelDecode + character_dict_path: *character_dict_path + use_space_char: *use_space_char + +Metric: + name: RecMetric + main_indicator: acc + is_filter: True + +Train: + dataset: + name: RatioDataSetTVResize + ds_width: True + padding: false + data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging', + '../Union14M-L-LMDB-Filtered/filter_train_hard', + '../Union14M-L-LMDB-Filtered/filter_train_medium', + '../Union14M-L-LMDB-Filtered/filter_train_normal', + '../Union14M-L-LMDB-Filtered/filter_train_easy', + ] + transforms: + - DecodeImagePIL: # load image + img_mode: RGB + - PARSeqAugPIL: + - CTCLabelEncode: # Class handling label + character_dict_path: *character_dict_path + use_space_char: *use_space_char + max_text_length: *max_text_length + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + sampler: + name: RatioSampler + scales: [[128, 32]] # w, h + # divide_factor: to ensure the width and height dimensions can be devided by downsampling multiple + first_bs: &bs 128 + fix_bs: false + divided_factor: [4, 16] # w, h + is_training: True + loader: + shuffle: True + batch_size_per_card: *bs + drop_last: True + max_ratio: &max_ratio 12 + num_workers: 4 + +Eval: + dataset: + name: RatioDataSetTVResize + ds_width: True + padding: False + data_dir_list: [ + '../evaluation/CUTE80', + '../evaluation/IC13_857', + '../evaluation/IC15_1811', + '../evaluation/IIIT5k', + '../evaluation/SVT', + '../evaluation/SVTP', + ] + transforms: + - DecodeImagePIL: # load image + img_mode: RGB + - CTCLabelEncode: # Class handling label + character_dict_path: *character_dict_path + use_space_char: *use_space_char + max_text_length: *max_text_length + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + sampler: + name: RatioSampler + scales: [[128, 32]] # w, h + # divide_factor: to ensure the width and height dimensions can be devided by downsampling multiple + first_bs: *bs + fix_bs: false + divided_factor: [4, 16] # w, h + is_training: False + loader: + shuffle: False + drop_last: False + batch_size_per_card: *bs + max_ratio: *max_ratio + num_workers: 4 diff --git a/demo_gradio.py b/demo_gradio.py new file mode 100644 index 0000000..a40ba0e --- /dev/null +++ b/demo_gradio.py @@ -0,0 +1,184 @@ +# @Author: OpenOCR +# @Contact: 784990967@qq.com +import os +import gradio as gr # gradio==4.20.0 + +os.environ['FLAGS_allocator_strategy'] = 'auto_growth' +import cv2 +import numpy as np +import json +import time +from PIL import Image +from tools.infer_e2e import OpenOCR, check_and_download_font, draw_ocr_box_txt + +drop_score = 0.4 +text_sys = OpenOCR(drop_score=drop_score) +# warm up 5 times +if 
True: + img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8) + for i in range(5): + res = text_sys(img_numpy=img) +font_path = './simfang.ttf' +font_path = check_and_download_font(font_path) + + +def main(input_image, + det_input_size_textbox=960, + rec_drop_score=0.01, + mask_thresh=0.3, + box_thresh=0.6, + unclip_ratio=1.5, + det_score_mode='slow'): + img = input_image[:, :, ::-1] + starttime = time.time() + results, time_dict, mask = text_sys( + img_numpy=img, + return_mask=True, + det_input_size=int(det_input_size_textbox), + thresh=mask_thresh, + box_thresh=box_thresh, + unclip_ratio=unclip_ratio, + score_mode=det_score_mode) + elapse = time.time() - starttime + save_pred = json.dumps(results[0], ensure_ascii=False) + image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) + boxes = [res['points'] for res in results[0]] + txts = [res['transcription'] for res in results[0]] + scores = [res['score'] for res in results[0]] + draw_img = draw_ocr_box_txt( + image, + boxes, + txts, + scores, + drop_score=rec_drop_score, + font_path=font_path, + ) + mask = mask[0, 0, :, :] > mask_thresh + return save_pred, elapse, draw_img, mask.astype('uint8') * 255 + + +def get_all_file_names_including_subdirs(dir_path): + all_file_names = [] + + for root, dirs, files in os.walk(dir_path): + for file_name in files: + all_file_names.append(os.path.join(root, file_name)) + + file_names_only = [os.path.basename(file) for file in all_file_names] + return file_names_only + + +def list_image_paths(directory): + image_extensions = ('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff') + + image_paths = [] + + for root, dirs, files in os.walk(directory): + for file in files: + if file.lower().endswith(image_extensions): + relative_path = os.path.relpath(os.path.join(root, file), + directory) + full_path = os.path.join(directory, relative_path) + image_paths.append(full_path) + image_paths = sorted(image_paths) + return image_paths + + +def find_file_in_current_dir_and_subdirs(file_name): + for root, dirs, files in os.walk('.'): + if file_name in files: + relative_path = os.path.join(root, file_name) + return relative_path + + +e2e_img_example = list_image_paths('./OCR_e2e_img') + +if __name__ == '__main__': + css = '.image-container img { width: 100%; max-height: 320px;}' + + with gr.Blocks(css=css) as demo: + gr.HTML(""" +

+OpenOCR
+准确高效的通用 OCR 系统 (由FVL实验室 OCR Team 创建)
""" + ) + with gr.Row(): + with gr.Column(scale=1): + input_image = gr.Image(label='Input image', + elem_classes=['image-container']) + + examples = gr.Examples(examples=e2e_img_example, + inputs=input_image, + label='Examples') + downstream = gr.Button('Run') + + # 添加参数调节组件 + with gr.Column(): + with gr.Row(): + det_input_size_textbox = gr.Number( + label='Det Input Size', + value=960, + info='检测网络输入尺寸的最长边,默认为960。') + det_score_mode_dropdown = gr.Dropdown( + ['slow', 'fast'], + value='slow', + label='Det Score Mode', + info= + '文本框的置信度计算模式,默认为 slow。slow 模式计算速度较慢,但准确度较高。fast 模式计算速度较快,但准确度较低。' + ) + with gr.Row(): + rec_drop_score_slider = gr.Slider( + 0.0, + 1.0, + value=0.4, + step=0.01, + label='Recognition Drop Score', + info='识别置信度阈值,默认值为0.4。低于该阈值的识别结果和对应的文本框被丢弃。') + mask_thresh_slider = gr.Slider( + 0.0, + 1.0, + value=0.3, + step=0.01, + label='Mask Threshold', + info='Mask 阈值,用于二值化 mask,默认值为0.3。如果存在文本截断时,请调低该值。') + with gr.Row(): + box_thresh_slider = gr.Slider( + 0.0, + 1.0, + value=0.6, + step=0.01, + label='Box Threshold', + info='文本框置信度阈值,默认值为0.6。如果存在文本被漏检时,请调低该值。') + unclip_ratio_slider = gr.Slider( + 1.5, + 2.0, + value=1.5, + step=0.05, + label='Unclip Ratio', + info='文本框解析时的膨胀系数,默认值为1.5。值越大文本框越大。') + + with gr.Column(scale=1): + img_mask = gr.Image(label='mask', + interactive=False, + elem_classes=['image-container']) + img_output = gr.Image(label=' ', + interactive=False, + elem_classes=['image-container']) + + output = gr.Textbox(label='Result') + confidence = gr.Textbox(label='Latency') + + downstream.click(fn=main, + inputs=[ + input_image, det_input_size_textbox, + rec_drop_score_slider, mask_thresh_slider, + box_thresh_slider, unclip_ratio_slider, + det_score_mode_dropdown + ], + outputs=[ + output, + confidence, + img_output, + img_mask, + ]) + + demo.launch(share=True) diff --git a/demo_rec.py b/demo_rec.py deleted file mode 100644 index a1c604a..0000000 --- a/demo_rec.py +++ /dev/null @@ -1,131 +0,0 @@ -import os - -import gradio as gr -import numpy as np -import torch - -from openrec.modeling import build_model -from openrec.postprocess import build_post_process -from openrec.preprocess import create_operators, transform -from tools.engine import Config -from tools.utils.ckpt import load_ckpt - - -def build_rec_process(cfg): - transforms = [] - for op in cfg['Eval']['dataset']['transforms']: - op_name = list(op)[0] - if 'Label' in op_name: - continue - # TODO - elif op_name in ['DecodeImage']: - op[op_name]['gradio_infer_mode'] = True - - elif op_name in ['RecResizeImg']: - op[op_name]['infer_mode'] = True - elif op_name == 'KeepKeys': - if cfg['Architecture']['algorithm'] == 'SRN': - op[op_name]['keep_keys'] = [ - 'image', - 'encoder_word_pos', - 'gsrm_word_pos', - 'gsrm_slf_attn_bias1', - 'gsrm_slf_attn_bias2', - ] - elif cfg['Architecture']['algorithm'] == 'SAR': - op[op_name]['keep_keys'] = ['image', 'valid_ratio'] - elif cfg['Architecture']['algorithm'] == 'RobustScanner': - op[op_name]['keep_keys'] = [ - 'image', 'valid_ratio', 'word_positons' - ] - else: - op[op_name]['keep_keys'] = ['image'] - transforms.append(op) - return transforms - - -def get_all_file_names_including_subdirs(dir_path): - all_file_names = [] - - for root, dirs, files in os.walk(dir_path): - for file_name in files: - all_file_names.append(os.path.join(root, file_name)) - - file_names_only = [os.path.basename(file) for file in all_file_names] - return file_names_only - - -root_directory = './configs/rec' -yml_Config = get_all_file_names_including_subdirs(root_directory) - - -def 
find_file_in_current_dir_and_subdirs(file_name): - for root, dirs, files in os.walk('.'): - if file_name in files: - relative_path = os.path.join(root, file_name) - return relative_path - - -def predict(input_image, Model_type, OCR_type): - - path = find_file_in_current_dir_and_subdirs(Model_type) - - cfg = Config(path).cfg - post_process_class = build_post_process(cfg['PostProcess']) - global_config = cfg['Global'] - char_num = len(getattr(post_process_class, 'character')) - cfg['Architecture']['Decoder']['out_channels'] = char_num - model = build_model(cfg['Architecture']) - load_ckpt(model, cfg) - model.eval() - - transforms = build_rec_process(cfg) - global_config['infer_mode'] = True - ops = create_operators(transforms, global_config) - data = {'image': input_image} - batch = transform(data, ops) - others = None - images = np.expand_dims(batch[0], axis=0) - images = torch.from_numpy(images) - with torch.no_grad(): - preds = model(images, others) - post_result = post_process_class(preds) - return post_result[0][0], post_result[0][1] - - -if __name__ == '__main__': - - with gr.Blocks() as demo: - with gr.Row(): - with gr.Column(scale=1): - input_image = gr.Image(label='Input Image') - - # TODO - OCR_type = gr.Radio(['STR', 'STD', 'E2E'], label='模型类别') - - Model_type = gr.Dropdown(choices=yml_Config, label='现有模型配置文件') - - downstream = gr.Button('识别结果') - - with gr.Column(scale=1): - - # TODO - img_output = gr.Image(label='图片识别结果') - - output = gr.Textbox(label='文字识别结果') - confidence = gr.Textbox(label='置信度') - - downstream.click( - fn=predict, - inputs=[ - input_image, - Model_type, - OCR_type, - ], - outputs=[ - output, - confidence, - # TODO img_output, - ]) - - demo.launch(debug=True) diff --git a/docs/openocr.md b/docs/openocr.md index 4129747..2eb7b80 100644 --- a/docs/openocr.md +++ b/docs/openocr.md @@ -1,29 +1,34 @@ -# OpenOCR: A general OCR system for accuracy and efficiency +# OpenOCR: A general OCR system with accuracy and efficiency -We proposed strategies to comprehensively enhance CTC-based STR models and developed a novel CTC-based method, [SVTRv2](../configs/rec/svtrv2/). SVTRv2 can outperform previous attention-based STR methods in terms of accuracy while maintaining the advantages of CTC, such as fast inference and robust recognition of long text sequences. These features make SVTRv2 particularly well-suited for commercial applications. To this end, building on SVTRv2, we develop a practical version of the model from scratch on publicly available Chinese and English datasets. Combined with a detection model, this forms an accurate and efficient general OCR system, OpenOCR. Comparing with PP-OCRv4 released by PaddleOCR, OpenOCR achieve a 4.5% improvement on the [OCR competition leaderboard](https://aistudio.baidu.com/competition/detail/1131/0/leaderboard). +⚡\[[Quick Start](#quick-start)\] \[[Model](https://github.com/Topdu/OpenOCR/releases/tag/develop0.0.1)\] \[[ModelScope Demo](https://modelscope.cn/studios/topdktu/OpenOCR-Demo)\] \[[Hugging Face Demo](https://huggingface.co/spaces/topdu/OpenOCR-Demo)\] \[[Local Demo](#local-demo)\] \[[PaddleOCR Implementation](https://paddlepaddle.github.io/PaddleOCR/latest/algorithm/text_recognition/algorithm_rec_svtrv2.html)\] + +We proposed strategies to comprehensively enhance CTC-based STR models and developed a novel CTC-based method, [SVTRv2](../configs/rec/svtrv2/). 
SVTRv2 can outperform previous attention-based STR methods in terms of accuracy while maintaining the advantages of CTC, such as fast inference and robust recognition of long text. These features make SVTRv2 particularly well-suited for practical applications. To this end, building on SVTRv2, we develop a practical version of the model from scratch on publicly available Chinese and English datasets. Combined with a detection model, this forms a general OCR system with accuracy and efficiency, **OpenOCR**. Compared with the [PP-OCRv4](https://paddlepaddle.github.io/PaddleOCR/latest/ppocr/model_list.html) baseline on the [OCR competition leaderboard](https://aistudio.baidu.com/competition/detail/1131/0/leaderboard), OpenOCR (mobile) achieves a 4.5% improvement in accuracy while maintaining a similar inference speed on an NVIDIA 1080Ti GPU.

| Model | Config | E2E Metric | Downloading |
| ------------------- | ------------------------------------------------------------------------------ | ---------- | ------------------------------------------------------------------------------ |
| PP-OCRv4 | | 62.77% | [PaddleOCR Model List](../../ppocr/model_list.md) |
-| SVTRv2 (Rec Server) | [configs/rec/svtrv2/svtrv2_ch.yml](../configs/rec/svtrv2/svtrv2_ch.yml) | 68.81% | [Google Dirve ](https://drive.google.com/file/d/13LXbIVEyx2Aat3X_vVte4JQgQ7yJWdxH/view?usp=drive_link) |
-| RepSVTR (Mobile) | [Rec: configs/rec/svtrv2/repsvtr_ch.yml](../configs/rec/svtrv2/repsvtr_ch.yml)
[Det: configs/det/dbnet/repvit_db.yml](../configs/det/dbnet/repvit_db.yml) | 67.22% | [Rec: Google Drive](https://drive.google.com/file/d/1DNfarP_UmTqZnENjmmQHCexqzVmrIfLF/view?usp=drive_link)
[Det: Google Drive](https://drive.google.com/file/d/1eR6k5NitCvFEiGlYx1lAArVupIszfEmM/view?usp=drive_link) |
+| SVTRv2 (Rec Server) | [configs/rec/svtrv2/svtrv2_ch.yml](../configs/rec/svtrv2/svtrv2_ch.yml) | 68.81% | [Google Drive](https://drive.google.com/file/d/13LXbIVEyx2Aat3X_vVte4JQgQ7yJWdxH/view?usp=drive_link), [Github Released](https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_svtrv2_ch.pth) |
+| RepSVTR (Mobile) | [Rec: configs/rec/svtrv2/repsvtr_ch.yml](../configs/rec/svtrv2/repsvtr_ch.yml)
[Det: configs/det/dbnet/repvit_db.yml](../configs/det/dbnet/repvit_db.yml) | 67.22% | [Rec: Google Drive](https://drive.google.com/file/d/1DNfarP_UmTqZnENjmmQHCexqzVmrIfLF/view?usp=drive_link), [Github Released](https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_repsvtr_ch.pth)
[Det: Google Drive](https://drive.google.com/file/d/1eR6k5NitCvFEiGlYx1lAArVupIszfEmM/view?usp=drive_link), [Github Released](https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_det_repvit_ch.pth) | ## Quick Start -#### Dependencies: +### Dependencies: - [PyTorch](http://pytorch.org/) version >= 1.13.0 - Python version >= 3.7 ```shell -conda create -n openocre2e python==3.8 -conda activate openocre2e +conda create -n openocr python==3.8 +conda activate openocr +# install gpu version torch conda install pytorch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 pytorch-cuda=11.8 -c pytorch -c nvidia +# or cpu version +conda install pytorch torchvision torchaudio cpuonly -c pytorch ``` After installing dependencies, the following two installation methods are available. Either one can be chosen. -#### 1. Python Modules +### 1. Python Modules ```shell pip install openocr-python @@ -38,19 +43,21 @@ engine = OpenOCR() img_path = '/path/img_path or /path/img_file' result, elapse = engine(img_path) -print(result) -print(elapse) # Server mode -engine = OpenOCR(mode='server') +# engine = OpenOCR(mode='server') ``` -#### 2. Clone this repository: +### 2. Clone this repository: ```shell git clone https://github.com/Topdu/OpenOCR.git cd OpenOCR pip install -r requirements.txt +wget https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_det_repvit_ch.pth +wget https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_repsvtr_ch.pth +# Rec Server model +# wget https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_svtrv2_ch.pth ``` **Usage**: @@ -58,14 +65,22 @@ pip install -r requirements.txt ```shell # OpenOCR system: Det + Rec model python tools/infer_e2e.py --img_path=/path/img_fold or /path/img_file - # Det model python tools/infer_det.py --c ./configs/det/dbnet/repvit_db.yml --o Global.infer_img=/path/img_fold or /path/img_file - # Rec model python tools/infer_rec.py --c ./configs/rec/svtrv2/repsvtr_ch.yml --o Global.infer_img=/path/img_fold or /path/img_file ``` +#### Local Demo + +```shell +pip install gradio==4.20.0 +wget https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/OCR_e2e_img.tar +tar xf OCR_e2e_img.tar +# start demo +python demo_gradio.py +``` + ## Fine-tuning on a Custom dataset TODO @@ -88,6 +103,50 @@ TODO -The results show that OpenOCR’s detection model outperforms PP-OCRv4 in generating more complete and accurate text boundaries, effectively capturing entire text instances. This reflects its larger receptive field and its ability to avoid common issues like merging separate text instances or splitting a single instance into multiple fragments. +### Det + Rec System results -In terms of recognition, OpenOCR demonstrates superior adaptability to challenging scenarios, such as artistic fonts, handwriting, blur, low resolution, incomplete text, and occlusion. Notably, the OpenOCR mobile model performs at a level comparable to PP-OCRv4's larger server-side model, showcasing its efficiency and robustness. +
+<!-- figure: Det + Rec system result examples (OpenOCR vs. PP-OCRv4) -->
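To make the end-to-end results above easier to reproduce programmatically, here is a minimal sketch of running the Det + Rec system from Python. It assumes the `tools.infer_e2e.OpenOCR` interface used by `demo_gradio.py` in this patch (a BGR NumPy image in, a list of `{'points', 'transcription', 'score'}` records out); the image path below is a placeholder.

```python
# Minimal sketch, assuming the tools.infer_e2e.OpenOCR interface used by demo_gradio.py.
import cv2
from tools.infer_e2e import OpenOCR

text_sys = OpenOCR(drop_score=0.4)        # Det + Rec pipeline, same drop_score as the demo
img = cv2.imread('path/to/image.jpg')     # placeholder path; cv2 loads a BGR numpy array

# Same call pattern as demo_gradio.py: per-image results, a timing dict, and the detection mask.
# Optional kwargs (det_input_size, box_thresh, unclip_ratio, score_mode, ...) behave as in the demo.
results, time_dict, mask = text_sys(img_numpy=img, return_mask=True)

for res in results[0]:                    # each res: {'points', 'transcription', 'score'}
    print(res['transcription'], res['score'])
```

The command-line route shown in the Quick Start (`python tools/infer_e2e.py --img_path=...`) produces the same predictions without writing any code.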
+ +### **Detection Model Performance** + +In the examples provided, OpenOCR's detection model generates bounding boxes that are generally more comprehensive and better aligned with the boundaries of text instances compared to PP-OCRv4. In addition, OpenOCR excels in distinguishing separate text instances, avoiding errors such as merging two distinct text instances into one or splitting a single instance into multiple parts. This indicates superior handling of **semantic completeness and spatial understanding**, making it particularly effective for complex layouts. + +### **Recognition Model Generalization** + +OpenOCR's recognition model demonstrates enhanced generalization capabilities when compared to PP-OCRv4. It performs exceptionally well in recognizing text under difficult conditions, such as: + +- Artistic or stylized fonts. +- Handwritten text. +- Blurry or low-resolution images. +- Incomplete or occluded text. + +Remarkably, the **OpenOCR mobile recognition model** delivers results comparable to the larger and more resource-intensive **PP-OCRv4 server model**. This highlights OpenOCR's efficiency and accuracy, making it a versatile solution across different hardware platforms. + +### **System used in Real-World Scenarios** + +As shown in Det + Rec System results, OpenOCR demonstrates outstanding performance in practical scenarios, including documents, tables, invoices, and similar contexts. This underscores its potential as a **general-purpose OCR system**. It is capable of adapting to diverse use cases with high accuracy and reliability. + +## Citation + +If you find our method useful for your reserach, please cite: + +```bibtex +@article{Du2024SVTRv2, + title={SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition}, + author={Yongkun Du and Zhineng Chen and Hongtao Xie and Caiyan Jia and Yu-Gang Jiang}, + journal={CoRR}, + volume={abs/2411.15858}, + eprinttype={arXiv}, + year={2024}, + primaryClass={cs.CV}, + url={https://arxiv.org/abs/2411.15858} +} +``` diff --git a/docs/svtrv2.md b/docs/svtrv2.md index bb2be79..a0d6c4b 100644 --- a/docs/svtrv2.md +++ b/docs/svtrv2.md @@ -1,6 +1,6 @@ # SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition -\[[Paper](../configs/rec/svtrv2/SVTRv2.pdf)\] \[[Model](../configs/rec/svtrv2/readme.md#11-models-and-results)\] \[[Config, Training and Inference](../configs/rec/svtrv2/readme.md#3-model-training--evaluation)\] +\[[Paper](https://arxiv.org/abs/2411.15858)\] \[[Doc](../configs/rec/svtrv2/)\] \[[Model](../configs/rec/svtrv2/readme.md#11-models-and-results)\] \[[Datasets](#downloading-datasets)\] \[[Config, Training and Inference](../configs/rec/svtrv2/readme.md#3-model-training--evaluation)\] \[[Benchmark](#results-benchmark--configs--checkpoints)\] ## Introduction @@ -190,7 +190,7 @@ Union14M-L-LMDB-Filtered # lmdb format Union14M-L-Filtered | Union14M-L-Filter | [LMDB archives](https://drive.google.com/drive/folders/1OlDWJZgvd6s4S09S3IGeAI90jI0i7AB_?usp=sharing) | | | Evaluation | [LMDB archives](https://drive.google.com/drive/folders/1EW0_YvmRSdpVOkR2guTQFrGz7KNqNc66?usp=drive_link) | | -If you have downloaded Union14M-L, you can use [the filtered list of images](https://drive.google.com/drive/folders/1x1LC8C_W-Frl3sGV9i9_i_OD-bqNdodJ?usp=drive_link) to create an LMDB of the training set Union14M-L-Filter. 
+If you have downloaded Union14M-L, you can use [the filtered list of images](https://drive.google.com/drive/folders/1x1LC8C_W-Frl3sGV9i9_i_OD-bqNdodJ?usp=drive_link) to create an LMDB of the training set Union14M-L-Filter (detailed in [create_lmdb_dataset.py](../tools/create_lmdb_dataset.py)). #### **Test Set** @@ -211,9 +211,9 @@ If you have downloaded Union14M-L, you can use [the filtered list of images](htt ```shell # Multi GPU training -CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 tools/train_rec.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_maxratio12.yml +CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 tools/train_rec.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml # For Multi RTX 4090 -NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 tools/train_rec.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_maxratio12.yml +NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 tools/train_rec.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml # 20epoch runs for about 6 hours ``` @@ -221,9 +221,7 @@ NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.laun ```shell # short text: Common, Union14M-Benchmark, OST -python tools/eval_rec_all_ratio.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml -# long text -python tools/eval_rec_all_long_simple.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml +python tools/eval_rec_all_en.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml ``` After a successful run, the results are saved in a csv file in `output_dir` in the config file. @@ -242,23 +240,9 @@ Firstly, downloading the IIIT5K images from [Google Drive](https://drive.google. python tools/infer_rec.py --c configs/rec/svtrv2/svtrv2_rctc.yml --o Global.infer_img=../iiit5k_test_image ``` -## Results & Configs & Checkpoints: +## Results (Benchmark) & Configs & Checkpoints: -Downloading all model checkpoints from [Google Drive](<>) and [Baidu Yun](<>). - - +(TODO) Downloading all model checkpoints from [Google Drive](<>) and [Baidu Yun](<>). @@ -994,25 +978,10 @@ Downloading all model checkpoints from [Google Drive](<>) and [Baidu Yun](<>).
-**Note**: TF$\_n$ denotes the $n$-layer Transformer block. $Size$ denotes the number of parameters ($M$). $Latency$ is measured on one NVIDIA 1080Ti GPU with Pytorch Dynamic mode. +**Note**: TF$\_n$ denotes the $n$-layer Transformer block. $Size$ denotes the number of parameters ($M$). $Latency$ is measured on one NVIDIA 1080Ti GPU with Pytorch dynamic graph mode. ## Results when trained on synthetic datasets ($ST$ + $MJ$). - - @@ -1678,10 +1647,17 @@ Downloading all model checkpoints from [Google Drive](<>) and [Baidu Yun](<>). ## Citation +If you find our method useful for your reserach, please cite: + ```bibtex -@article{Du2024SVTRv4, - title = {SVTRv2: Scene Text Recognition with a Single Visual Model}, - author = {Yongkun Du, Zhineng Chen\*, Hongtao Xie, Caiyan Jia, Yu-Gang Jiang}, - year = {2024} +@article{Du2024SVTRv2, + title={SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition}, + author={Yongkun Du and Zhineng Chen and Hongtao Xie and Caiyan Jia and Yu-Gang Jiang}, + journal={CoRR}, + volume={abs/2411.15858}, + eprinttype={arXiv}, + year={2024}, + primaryClass={cs.CV}, + url={https://arxiv.org/abs/2411.15858} } ``` diff --git a/opendet/postprocess/db_postprocess.py b/opendet/postprocess/db_postprocess.py index dd6b199..60c1e7f 100644 --- a/opendet/postprocess/db_postprocess.py +++ b/opendet/postprocess/db_postprocess.py @@ -208,7 +208,12 @@ def box_score_slow(self, bitmap, contour): cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype('int32'), 1) return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] - def __call__(self, outs_dict, shape_list): + def __call__(self, outs_dict, shape_list, **kwargs): + self.thresh = kwargs.get('mask_thresh', self.thresh) + self.box_thresh = kwargs.get('box_thresh', self.box_thresh) + self.unclip_ratio = kwargs.get('unclip_ratio', self.unclip_ratio) + self.box_type = kwargs.get('box_type', self.box_type) + self.score_mode = kwargs.get('score_mode', self.score_mode) pred = outs_dict['maps'] if isinstance(pred, torch.Tensor): pred = pred.detach().cpu().numpy() diff --git a/opendet/preprocess/db_resize_for_test.py b/opendet/preprocess/db_resize_for_test.py index dbd95f0..e974884 100644 --- a/opendet/preprocess/db_resize_for_test.py +++ b/opendet/preprocess/db_resize_for_test.py @@ -27,6 +27,8 @@ def __init__(self, **kwargs): def __call__(self, data): img = data['image'] + if 'max_sile_len' in data: + self.limit_side_len = data['max_sile_len'] src_h, src_w, _ = img.shape if sum([src_h, src_w]) < 64: img = self.image_padding(img) diff --git a/openrec/modeling/decoders/__init__.py b/openrec/modeling/decoders/__init__.py index f570fd0..bbb9b21 100644 --- a/openrec/modeling/decoders/__init__.py +++ b/openrec/modeling/decoders/__init__.py @@ -28,6 +28,7 @@ def build_decoder(config): from .cam_decoder import CAMDecoder from .ote_decoder import OTEDecoder from .bus_decoder import BUSDecoder + # from .dptr_parseq_clip_b_decoder import DptrParseq support_dict = [ 'CTCDecoder', 'NRTRDecoder', 'CPPDDecoder', 'ABINetDecoder', @@ -35,7 +36,7 @@ def build_decoder(config): 'SMTRDecoder', 'LPVDecoder', 'SARDecoder', 'RobustScannerDecoder', 'SRNDecoder', 'ASTERDecoder', 'RCTCDecoder', 'LISTERDecoder', 'GTCDecoder', 'SMTRDecoderNumAttn', 'MATRNDecoder', 'MGPDecoder', - 'DANDecoder', 'CAMDecoder', 'OTEDecoder', 'BUSDecoder' + 'DANDecoder', 'CAMDecoder', 'OTEDecoder', 'BUSDecoder', 'DptrParseq' ] module_name = config.pop('name') diff --git a/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py 
b/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py new file mode 100644 index 0000000..566b4a2 --- /dev/null +++ b/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py @@ -0,0 +1,1398 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from itertools import permutations +from collections import OrderedDict +import hashlib +import os +import gzip +import html +import urllib +import warnings +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.nn.modules import transformer +from typing import Any, Optional, Tuple, List, Union +from pkg_resources import packaging +from PIL import Image +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from tqdm import tqdm +from functools import lru_cache + +import ftfy +import regex as re + +try: + from torchvision.transforms import InterpolationMode + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). 
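+    For example, get_pairs(('h', 'e', 'l', 'l', 'o')) returns {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')}.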
+ """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} + self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text + + +if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): + warnings.warn("PyTorch version 1.7.1 or higher is recommended") + + +__all__ = ["available_models", "load", "tokenize"] +_tokenizer = SimpleTokenizer() + +_MODELS = { + "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", + "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", + "RN50x64": 
"https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", + "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", + "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", + "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", +} + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +def build_model(state_dict: dict): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * grid_size + else: + counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_resolution = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) + + model = CLIP( + embed_dim, + image_resolution, vision_layers, vision_width, vision_patch_size, + context_length, vocab_size, transformer_width, transformer_heads, transformer_layers + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + if key in state_dict: + del state_dict[key] + + convert_weights(model) + model.load_state_dict(state_dict) + return model.eval() + + +def _download(url: str, root: str): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + expected_sha256 = url.split("/")[-2] + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if 
os.path.isfile(download_target): + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: + raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def _convert_image_to_rgb(image): + return image.convert("RGB") + + +def _transform(n_px): + return Compose([ + Resize(n_px, interpolation=BICUBIC), + CenterCrop(n_px), + _convert_image_to_rgb, + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + + +def available_models() -> List[str]: + """Returns the names of available CLIP models""" + return list(_MODELS.keys()) + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], key=x, value=x, + 
embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + return x.squeeze(0) + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.relu3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + def stem(x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = 
LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class VisionTransformer(nn.Module): + def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x) + if self.proj is not None: + x = x @ self.proj + + return x + + +class CLIP(nn.Module): + def __init__(self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int + ): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width + ) + else: + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask() + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = 
nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # take features from the eot embedding (eot_token is the highest number in each sequence) + output = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + output = torch.cat([output.unsqueeze(1), x], dim=1) + + return output + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm(dim=1, keepdim=True) + text_features = text_features / text_features.norm(dim=1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +class FMU(nn.Module): + """A Transformer decoder layer supporting two-stream attention (XLNet) + This implements a pre-LN decoder, as opposed to the post-LN default in PyTorch.""" + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='gelu', + layer_norm_eps=1e-5): + super().__init__() + self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True) + # Implementation of 
Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm = nn.LayerNorm(d_model, eps=layer_norm_eps) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = transformer._get_activation_fn(activation) + + def __setstate__(self, state): + if 'activation' not in state: + state['activation'] = F.gelu + super().__setstate__(state) + + def forward(self, query: Tensor, memory: Tensor): + """Forward pass for a single stream (i.e. content or query) + tgt_norm is just a LayerNorm'd tgt. Added as a separate parameter for efficiency. + Both tgt_kv and memory are expected to be LayerNorm'd too. + memory is LayerNorm'd by ViT. + """ + query1, ca_weights = self.cross_attn(query, memory, memory) + query = query + self.dropout1(query1) + + query2 = self.linear2(self.dropout2(self.activation(self.linear1(self.norm(query))))) + query = query + self.dropout3(query2) + + return query + + +class DecoderLayer(nn.Module): + """A Transformer decoder layer supporting two-stream attention (XLNet) This + implements a pre-LN decoder, as opposed to the post-LN default in + PyTorch.""" + + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation='gelu', + layer_norm_eps=1e-5, + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, + nhead, + dropout=dropout, + batch_first=True) + self.cross_attn = nn.MultiheadAttention(d_model, + nhead, + dropout=dropout, + batch_first=True) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm_q = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm_c = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = transformer._get_activation_fn(activation) + + def __setstate__(self, state): + if 'activation' not in state: + state['activation'] = F.gelu + super().__setstate__(state) + + def forward_stream( + self, + tgt: Tensor, + tgt_norm: Tensor, + tgt_kv: Tensor, + memory: Tensor, + tgt_mask: Optional[Tensor], + tgt_key_padding_mask: Optional[Tensor], + ): + """Forward pass for a single stream (i.e. content or query) tgt_norm is + just a LayerNorm'd tgt. + + Added as a separate parameter for efficiency. Both tgt_kv and memory + are expected to be LayerNorm'd too. memory is LayerNorm'd by ViT. 
+ """ + tgt2, sa_weights = self.self_attn( + tgt_norm, + tgt_kv, + tgt_kv, + attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask) + + tgt = tgt + self.dropout1(tgt2) + + tgt2, ca_weights = self.cross_attn(self.norm1(tgt), memory, memory) + self.attn_map = ca_weights + tgt = tgt + self.dropout2(tgt2) + + tgt2 = self.linear2( + self.dropout(self.activation(self.linear1(self.norm2(tgt))))) + tgt = tgt + self.dropout3(tgt2) + return tgt, sa_weights, ca_weights + + def forward( + self, + query, + content, + memory, + query_mask: Optional[Tensor] = None, + content_mask: Optional[Tensor] = None, + content_key_padding_mask: Optional[Tensor] = None, + update_content: bool = True, + ): + query_norm = self.norm_q(query) + content_norm = self.norm_c(content) + query = self.forward_stream(query, query_norm, content_norm, memory, + query_mask, content_key_padding_mask)[0] + if update_content: + content = self.forward_stream(content, content_norm, content_norm, + memory, content_mask, + content_key_padding_mask)[0] + return query, content + + +class Decoder(nn.Module): + __constants__ = ['norm'] + + def __init__(self, decoder_layer, num_layers, norm): + super().__init__() + self.layers = transformer._get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + query, + content, + memory, + query_mask: Optional[Tensor] = None, + content_mask: Optional[Tensor] = None, + content_key_padding_mask: Optional[Tensor] = None, + ): + for i, mod in enumerate(self.layers): + last = i == len(self.layers) - 1 + query, content = mod( + query, + content, + memory, + query_mask, + content_mask, + content_key_padding_mask, + update_content=not last, + ) + query = self.norm(query) + return query + + +class TokenEmbedding(nn.Module): + + def __init__(self, charset_size: int, embed_dim: int): + super().__init__() + self.embedding = nn.Embedding(charset_size, embed_dim) + self.embed_dim = embed_dim + + def forward(self, tokens: torch.Tensor): + return math.sqrt(self.embed_dim) * self.embedding(tokens) + + +def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + + device : Union[str, torch.device] + The device to put the loaded model + + jit : bool + Whether to load the optimized JIT model or more hackable non-JIT model (default). + + download_root: str + path to download the model files; by default, it uses "~/.cache/clip" + + Returns + ------- + model : torch.nn.Module + The CLIP model + + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if name in _MODELS: + model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {available_models()}") + + with open(model_path, 'rb') as opened_file: + try: + # loading JIT archive + model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") + jit = False + state_dict = torch.load(opened_file, map_location="cpu") + + if not jit: + model = build_model(state_dict or model.state_dict()).to(device) + if str(device) == "cpu": + model.float() + return model, _transform(model.visual.input_resolution) + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + + return model, _transform(model.input_resolution.item()) + +def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + + context_length : int + The context length to use; all CLIP models use 77 as the context length + + truncate: bool + Whether to truncate the text in case its encoding is longer than the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. + We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 
+ """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder["<|startoftext|>"] + eot_token = _tokenizer.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + else: + result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + +class DptrParseq(nn.Module): + + def __init__(self, + in_channels, + out_channels, + max_label_length=25, + embed_dim=512, + dec_num_heads=8, + dec_mlp_ratio=4, + dec_depth=6, + perm_num=6, + perm_forward=True, + perm_mirrored=True, + decode_ar=True, + refine_iters=1, + dropout=0.1, + is_pretrain=True, + ORP_path=None, + **kwargs: Any) -> None: + super().__init__() + self.pad_id = out_channels - 1 + self.eos_id = 0 + self.bos_id = out_channels - 2 + self.max_label_length = max_label_length + self.decode_ar = decode_ar + self.refine_iters = refine_iters + self.is_pretrain = is_pretrain + if not is_pretrain: + self.token_query = nn.Parameter(torch.Tensor(1, 26, embed_dim)) + self.fmu = FMU(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout) + + decoder_layer = DecoderLayer(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout) + self.decoder = Decoder(decoder_layer, + num_layers=dec_depth, + norm=nn.LayerNorm(embed_dim)) + + # Perm/attn mask stuff + self.rng = np.random.default_rng() + self.max_gen_perms = perm_num // 2 if perm_mirrored else perm_num + self.perm_forward = perm_forward + self.perm_mirrored = perm_mirrored + + # We don't predict nor + self.head = nn.Linear(embed_dim, out_channels - 2) + self.text_embed = TokenEmbedding(out_channels, embed_dim) + + # +1 for + self.pos_queries = nn.Parameter( + torch.Tensor(1, max_label_length + 1, embed_dim)) + self.dropout = nn.Dropout(p=dropout) + # Encoder has its own init. 
+ self.apply(self._init_weights) + nn.init.trunc_normal_(self.pos_queries, std=0.02) + + if is_pretrain: + self.clip_encoder, preprocess = load("ViT-B/16") + for p in self.clip_encoder.parameters(): + p.requires_grad = False + if ORP_path is None: + background_image_folder_path = 'background_images_folder/path' + self.background_features = self.get_noise(background_image_folder_path, preprocess) + torch.save(self.background_features, 'save/noise/to/ORP_path') + else: + self.background_features = torch.load(ORP_path, map_location='cpu') + + def _init_weights(self, module: nn.Module): + """Initialize the weights using the typical initialization schemes used + in SOTA models.""" + + if isinstance(module, nn.Linear): + nn.init.trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.trunc_normal_(module.weight, std=0.02) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_(module.weight, + mode='fan_out', + nonlinearity='relu') + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) + + @torch.jit.ignore + def no_weight_decay(self): + param_names = {'text_embed.embedding.weight', 'pos_queries'} + return param_names + + def get_noise(self, background_image_path, preprocess): + image_paths = [os.path.join(background_image_path, filename) for filename in os.listdir(background_image_path) if + filename.endswith(('.png', '.jpg', '.jpeg'))] + features = [] + for image_path in image_paths: + image = Image.open(image_path) + input = preprocess(image).unsqueeze(0).to(self._device) + with torch.no_grad(): + feature = self.clip_encoder.encode_image(input) + features.append(feature) + image.close() + return torch.cat(features).cpu().numpy() + + def clip_encode(self, labels): + text_inputs = torch.cat([tokenize(f"a photo of a '{c}'") for c in labels]).to(self._device) + + return self.clip_encoder.encode_text(text_inputs) + + def decode( + self, + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: Optional[Tensor] = None, + tgt_padding_mask: Optional[Tensor] = None, + tgt_query: Optional[Tensor] = None, + tgt_query_mask: Optional[Tensor] = None, + pos_query: torch.Tensor = None, + ): + N, L = tgt.shape + # BOS stands for the null context. We only supply position information for characters after BOS.
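As a side note on the ORP branch above, a minimal sketch of what `get_noise` computes: one frozen-CLIP image feature per background image, cached once and later injected as noise during pre-training. The function name `extract_background_features`, the folder path, and the `device` argument are illustrative assumptions; `clip_encoder` and `preprocess` stand for the objects returned by `load("ViT-B/16")`:

```python
import os
import torch
from PIL import Image

def extract_background_features(folder, clip_encoder, preprocess, device='cpu'):
    # Encode every background image once with the frozen CLIP image encoder.
    feats = []
    for name in os.listdir(folder):
        if not name.lower().endswith(('.png', '.jpg', '.jpeg')):
            continue
        with Image.open(os.path.join(folder, name)) as im:
            x = preprocess(im).unsqueeze(0).to(device)
            with torch.no_grad():
                feats.append(clip_encoder.encode_image(x))
    # Stacked features are cached and later added to CLIP text features as noise.
    return torch.cat(feats).cpu()
```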
+ null_ctx = self.text_embed(tgt[:, :1]) + + if tgt_query is None: + tgt_query = pos_query[:, :L] + tgt_emb = pos_query[:, :L - 1] + self.text_embed(tgt[:, 1:]) + tgt_emb = self.dropout(torch.cat([null_ctx, tgt_emb], dim=1)) + + tgt_query = self.dropout(tgt_query) + return self.decoder(tgt_query, tgt_emb, memory, tgt_query_mask, + tgt_mask, tgt_padding_mask) + + def forward(self, memory, data=None, pos_query=None): + # print(memory.shape, data[0].shape) + if self.training: + if self.is_pretrain: + return self.training_step(None, pos_query, data[0], memory) + return self.training_step(memory, pos_query, data[0], None) + else: + if self.is_pretrain: + return self.forward_test(None, memory, pos_query) + return self.forward_test(memory, None, pos_query) + + def forward_test(self, + memory: Tensor, clip_ids, + pos_query: Tensor = None, + max_length: Optional[int] = None) -> Tensor: + testing = max_length is None + max_length = (self.max_label_length if max_length is None else min( + max_length, self.max_label_length)) + + if self.is_pretrain: + memory = self.clip_encoder.encode_text(clip_ids) + else: + bs = memory.shape[0] + token_query = self.token_query.expand(bs, -1, -1) + memory = self.fmu(token_query, memory) + _device = memory.get_device() + bs = memory.shape[0] + # +1 for at end of sequence. + num_steps = max_length + 1 + # memory = self.encode(images) + + # Query positions up to `num_steps` + if pos_query is None: + pos_queries = self.pos_queries[:, :num_steps].expand(bs, -1, -1) + else: + pos_queries = pos_query + + # Special case for the forward permutation. Faster than using `generate_attn_masks()` + tgt_mask = query_mask = torch.triu( + torch.full((num_steps, num_steps), float('-inf'), device=_device), + 1) + self.attn_maps = [] + if self.decode_ar: + tgt_in = torch.full((bs, num_steps), + self.pad_id, + dtype=torch.long, + device=_device) + tgt_in[:, 0] = self.bos_id + + logits = [] + for i in range(num_steps): + j = i + 1 # next token index + # Efficient decoding: + # Input the context up to the ith token. We use only one query (at position = i) at a time. + # This works because of the lookahead masking effect of the canonical (forward) AR context. + # Past tokens have no access to future tokens, hence are fixed once computed. + tgt_out = self.decode( + tgt_in[:, :j], + memory, + tgt_mask[:j, :j], + tgt_query=pos_queries[:, i:j], + tgt_query_mask=query_mask[i:j, :j], + pos_query=pos_queries, + ) + self.attn_maps.append(self.decoder.layers[-1].attn_map) + # the next token probability is in the output's ith token position + p_i = self.head(tgt_out) + logits.append(p_i) + if j < num_steps: + # greedy decode. add the next token index to the target input + tgt_in[:, j] = p_i.squeeze().argmax(-1) + # Efficient batch decoding: If all output words have at least one EOS token, end decoding. + if testing and (tgt_in == self.eos_id).any(dim=-1).all(): + break + + logits = torch.cat(logits, dim=1) + else: + # No prior context, so input is just . We query all positions. + tgt_in = torch.full((bs, 1), + self.bos_id, + dtype=torch.long, + device=_device) + tgt_out = self.decode(tgt_in, + memory, + tgt_query=pos_queries, + pos_query=pos_queries) + logits = self.head(tgt_out) + + if self.refine_iters: + # For iterative refinement, we always use a 'cloze' mask. + # We can derive it from the AR forward mask by unmasking the token context to the right. 
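The autoregressive branch above relies on a standard look-ahead mask built with `torch.triu`; a tiny standalone illustration (4 steps, values shown for reference):

```python
import torch

num_steps = 4
tgt_mask = torch.triu(torch.full((num_steps, num_steps), float('-inf')), 1)
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])
# Row i (the query at step i) attends only to positions <= i, so past tokens are
# fixed once computed and decoding can proceed one query position at a time.
```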
+ query_mask[torch.triu( + torch.ones(num_steps, + num_steps, + dtype=torch.bool, + device=_device), 2)] = 0 + bos = torch.full((bs, 1), + self.bos_id, + dtype=torch.long, + device=_device) + for i in range(self.refine_iters): + # Prior context is the previous output. + tgt_in = torch.cat([bos, logits[:, :-1].argmax(-1)], dim=1) + tgt_len = tgt_in.shape[1] + tgt_padding_mask = (tgt_in == self.eos_id).int().cumsum( + -1) > 0 # mask tokens beyond the first EOS token. + tgt_out = self.decode( + tgt_in, + memory, + tgt_mask[:tgt_len, :tgt_len], + tgt_padding_mask, + tgt_query=pos_queries, + tgt_query_mask=query_mask[:, :tgt_len], + pos_query=pos_queries, + ) + logits = self.head(tgt_out) + + return F.softmax(logits, -1) + + def gen_tgt_perms(self, tgt, _device): + """Generate shared permutations for the whole batch. + + This works because the same attention mask can be used for the shorter + sequences because of the padding mask. + """ + # We don't permute the position of BOS, we permute EOS separately + max_num_chars = tgt.shape[1] - 2 + # Special handling for 1-character sequences + if max_num_chars == 1: + return torch.arange(3, device=_device).unsqueeze(0) + perms = [torch.arange(max_num_chars, device=_device) + ] if self.perm_forward else [] + # Additional permutations if needed + max_perms = math.factorial(max_num_chars) + if self.perm_mirrored: + max_perms //= 2 + num_gen_perms = min(self.max_gen_perms, max_perms) + # For 4-char sequences and shorter, we generate all permutations and sample from the pool to avoid collisions + # Note that this code path might NEVER get executed since the labels in a mini-batch typically exceed 4 chars. + if max_num_chars < 5: + # Pool of permutations to sample from. We only need the first half (if complementary option is selected) + # Special handling for max_num_chars == 4 which correctly divides the pool into the flipped halves + if max_num_chars == 4 and self.perm_mirrored: + selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21] + else: + selector = list(range(max_perms)) + perm_pool = torch.as_tensor(list( + permutations(range(max_num_chars), max_num_chars)), + device=_device)[selector] + # If the forward permutation is always selected, no need to add it to the pool for sampling + if self.perm_forward: + perm_pool = perm_pool[1:] + perms = torch.stack(perms) + if len(perm_pool): + i = self.rng.choice(len(perm_pool), + size=num_gen_perms - len(perms), + replace=False) + perms = torch.cat([perms, perm_pool[i]]) + else: + perms.extend([ + torch.randperm(max_num_chars, device=_device) + for _ in range(num_gen_perms - len(perms)) + ]) + perms = torch.stack(perms) + if self.perm_mirrored: + # Add complementary pairs + comp = perms.flip(-1) + # Stack in such a way that the pairs are next to each other. + perms = torch.stack([perms, comp + ]).transpose(0, 1).reshape(-1, max_num_chars) + # NOTE: + # The only meaningful way of permuting the EOS position is by moving it one character position at a time. + # However, since the number of permutations = T! and number of EOS positions = T + 1, the number of possible EOS + # positions will always be much less than the number of permutations (unless a low perm_num is set). + # Thus, it would be simpler to just train EOS using the full and null contexts rather than trying to evenly + # distribute it across the chosen number of permutations. 
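A similarly small illustration of how the refinement ("cloze") mask above is derived from the AR mask: every position more than one step to the right is unmasked again, leaving only the first superdiagonal hidden, which effectively stops a query from looking at the very token it is refining (the content sequence is shifted by the BOS token):

```python
import torch

n = 4
query_mask = torch.triu(torch.full((n, n), float('-inf')), 1)
query_mask[torch.triu(torch.ones(n, n, dtype=torch.bool), 2)] = 0
# tensor([[0., -inf,   0.,   0.],
#         [0.,   0., -inf,   0.],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])
```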
+ # Add position indices of BOS and EOS + bos_idx = perms.new_zeros((len(perms), 1)) + eos_idx = perms.new_full((len(perms), 1), max_num_chars + 1) + perms = torch.cat([bos_idx, perms + 1, eos_idx], dim=1) + # Special handling for the reverse direction. This does two things: + # 1. Reverse context for the characters + # 2. Null context for [EOS] (required for learning to predict [EOS] in NAR mode) + if len(perms) > 1: + perms[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1, + device=_device) + return perms + + def generate_attn_masks(self, perm, _device): + """Generate attention masks given a sequence permutation (includes pos. + for bos and eos tokens) + + :param perm: the permutation sequence. i = 0 is always the BOS + :return: lookahead attention masks + """ + sz = perm.shape[0] + mask = torch.zeros((sz, sz), device=_device) + for i in range(sz): + query_idx = perm[i] + masked_keys = perm[i + 1:] + mask[query_idx, masked_keys] = float('-inf') + content_mask = mask[:-1, :-1].clone() + mask[torch.eye(sz, dtype=torch.bool, + device=_device)] = float('-inf') # mask "self" + query_mask = mask[1:, :-1] + return content_mask, query_mask + + def training_step(self, memory, pos_query, tgt_ids, clip_ids): + bs = tgt_ids.shape[0] + if self.is_pretrain: + memory = self.clip_encoder.encode_text(clip_ids) + n = memory.shape[1] + B, N, D = self.background_features.shape + random_B = np.random.choice(B, bs, replace=False) + random_N = np.random.choice(N, n, replace=False) + noise = self.background_features[random_B][:, random_N] + noise = torch.from_numpy(noise).to(memory.get_device()) + memory = memory + noise * 1e-1 + else: + token_query = self.token_query.expand(bs, -1, -1) + memory = self.fmu(token_query, memory) + + if pos_query is None: + pos_query = self.pos_queries.expand(bs, -1, -1) + # Prepare the target sequences (input and output) + tgt_perms = self.gen_tgt_perms(tgt_ids, memory.get_device()) + tgt_in = tgt_ids[:, :-1] + tgt_out = tgt_ids[:, 1:] + + # The [EOS] token is not depended upon by any other token in any permutation ordering + tgt_padding_mask = (tgt_in == self.pad_id) | (tgt_in == self.eos_id) + + loss = 0 + loss_numel = 0 + n = (tgt_out != self.pad_id).sum().item() + for i, perm in enumerate(tgt_perms): + tgt_mask, query_mask = self.generate_attn_masks( + perm, memory.get_device()) + # print("tgt_in:", tgt_in, "tgt_out:", tgt_out, "tgt_padding_mask:", tgt_padding_mask) + # print('tgt_mask:', tgt_mask) + # print('query_mask:', query_mask) + # print(tgt_in.shape, memory.shape, tgt_mask.shape, tgt_padding_mask.shape, query_mask.shape, pos_query.shape) + out = self.decode( + tgt_in, + memory, + tgt_mask, + tgt_padding_mask, + tgt_query_mask=query_mask, + pos_query=pos_query, + ) + # print('out:', out) + logits = self.head(out) + # print('logits:', logits) + if i == 0: + final_out = logits + loss += n * F.cross_entropy(logits.flatten(end_dim=1), + tgt_out.flatten(), + ignore_index=self.pad_id) + loss_numel += n + # After the second iteration (i.e. 
done with canonical and reverse orderings), + # remove the [EOS] tokens for the succeeding perms + if i == 1: + tgt_out = torch.where(tgt_out == self.eos_id, self.pad_id, + tgt_out) + n = (tgt_out != self.pad_id).sum().item() + loss /= loss_numel + + # self.log('loss', loss) + return [loss, final_out] diff --git a/openrec/modeling/decoders/smtr_decoder.py b/openrec/modeling/decoders/smtr_decoder.py index 01bd7ee..5c7cc0a 100644 --- a/openrec/modeling/decoders/smtr_decoder.py +++ b/openrec/modeling/decoders/smtr_decoder.py @@ -240,26 +240,26 @@ def forward_test_bi(self, x): next_id = torch.full([1, self.sub_str_len], self.bos_next, dtype=torch.long, - device=x.get_device()) + device=x.device) pre_id = torch.full([1, self.sub_str_len], self.bos_pre, dtype=torch.long, - device=x.get_device()) + device=x.device) # prompt_next_bos = self.char_embed(prompt_id) - # pred_prob_list = torch.full([bs, self.sub_str_len], self.ignore_index, dtype=torch.long, device=x.get_device()) + # pred_prob_list = torch.full([bs, self.sub_str_len], self.ignore_index, dtype=torch.long, device=x.device) next_pred_id_list = torch.full([1, self.max_len], self.ignore_index, dtype=torch.long, - device=x.get_device()) + device=x.device) pre_pred_id_list = torch.full([1, self.max_len], self.ignore_index, dtype=torch.long, - device=x.get_device()) + device=x.device) next_logits_all = [] pre_logits_all = [] mask_pad = torch.zeros([bs, 1], dtype=torch.float32, - device=x.get_device()) + device=x.device) for j in range(0, min(70, self.max_len - 1)): prompt_char_next = torch.concat([ @@ -317,7 +317,7 @@ def forward_test_bi(self, x): ques_next = self.next_token.tile([1, 1, 1, 1]).squeeze(1) mask_pad = torch.zeros([1, 1], dtype=torch.float32, - device=x.get_device()) + device=x.device) for j in range(0, min(70, self.max_len - 1)): prompt_next = torch.concat([ @@ -371,17 +371,17 @@ def forward_test_bi_attn(self, x): prompt_next_embed = self.prompt_next_embed.squeeze(1) prompt_pre_embed = self.prompt_pre_embed.squeeze(1) - next_id = torch.full([1, self.sub_str_len], self.bos_next, dtype=torch.long, device=x.get_device()) - pre_id = torch.full([1, self.sub_str_len], self.bos_pre, dtype=torch.long, device=x.get_device()) + next_id = torch.full([1, self.sub_str_len], self.bos_next, dtype=torch.long, device=x.device) + pre_id = torch.full([1, self.sub_str_len], self.bos_pre, dtype=torch.long, device=x.device) # prompt_next_bos = self.char_embed(prompt_id) - # pred_prob_list = torch.full([bs, self.sub_str_len], self.ignore_index, dtype=torch.long, device=x.get_device()) - next_pred_id_list = torch.full([1, self.max_len], self.ignore_index, dtype=torch.long, device=x.get_device()) - pre_pred_id_list = torch.full([1, self.max_len], self.ignore_index, dtype=torch.long, device=x.get_device()) + # pred_prob_list = torch.full([bs, self.sub_str_len], self.ignore_index, dtype=torch.long, device=x.device) + next_pred_id_list = torch.full([1, self.max_len], self.ignore_index, dtype=torch.long, device=x.device) + pre_pred_id_list = torch.full([1, self.max_len], self.ignore_index, dtype=torch.long, device=x.device) next_logits_all = [] pre_logits_all = [] attn_map_next = [] attn_map_pre = [] - mask_pad = torch.zeros([bs, 1], dtype=torch.float32, device=x.get_device()) + mask_pad = torch.zeros([bs, 1], dtype=torch.float32, device=x.device) for j in range(0, min(70, self.max_len-1)): prompt_char_next = torch.concat([prompt_next_embed[:, :1, :], prompt_next_embed[:, 1:, :] + self.char_embed(next_id)], 1) # b, sub_l, dim @@ -428,7 +428,7 @@ def 
forward_test_bi_attn(self, x): next_logits_all_mid = [] attn_map_next_mid = [] ques_next = self.next_token.tile([1, 1, 1, 1]).squeeze(1) - mask_pad = torch.zeros([1, 1], dtype=torch.float32, device=x.get_device()) + mask_pad = torch.zeros([1, 1], dtype=torch.float32, device=x.device) for j in range(0, min(70, self.max_len-1)): prompt_next = torch.concat([prompt_next_embed[:, :1, :], prompt_next_embed[:, 1:, :] + self.char_embed(next_id)], 1) # b, sub_l, dim @@ -474,15 +474,15 @@ def forward_test(self, x): prompt_id = torch.full([bs, self.sub_str_len], self.bos_next, dtype=torch.long, - device=x.get_device()) + device=x.device) pred_id_list = torch.full([bs, self.max_len], self.ignore_index, dtype=torch.long, - device=x.get_device()) + device=x.device) logits_all = [] mask_pad = torch.zeros([bs, 1], dtype=torch.float32, - device=x.get_device()) + device=x.device) for j in range(0, self.max_len - 1): prompt_next = torch.concat([ @@ -522,15 +522,15 @@ def forward_test(self, x): prompt_id = torch.full([bs, self.sub_str_len], self.bos_pre, dtype=torch.long, - device=x.get_device()) + device=x.device) pred_id_list = torch.full([bs, self.max_len], self.ignore_index, dtype=torch.long, - device=x.get_device()) + device=x.device) logits_all = [] mask_pad = torch.zeros([bs, 1], dtype=torch.float32, - device=x.get_device()) + device=x.device) for j in range(0, self.max_len - 1): prompt_next = torch.concat([ @@ -599,7 +599,7 @@ def forward_train(self, x, targets=None): mask_pad = torch.zeros([bs * (max_len_curr + max_len_curr_pre), 1], dtype=torch.float32, - device=x.get_device()) + device=x.device) mask = torch.concat([mask_next, mask_pre], 1).flatten(0, 1) mask = torch.concat([mask_pad, mask], 1) next_pre = next_pre.flatten(0, 1) diff --git a/openrec/preprocess/__init__.py b/openrec/preprocess/__init__.py index 0558b30..8c216fa 100644 --- a/openrec/preprocess/__init__.py +++ b/openrec/preprocess/__init__.py @@ -24,6 +24,7 @@ from .srn_label_encode import SRNLabelEncode from .visionlan_label_encode import VisionLANLabelEncode from .cam_label_encode import CAMLabelEncode +# from .dptr_label_encode import DPTRLabelEncode class KeepKeys(object): diff --git a/openrec/preprocess/dptr_label_encode.py b/openrec/preprocess/dptr_label_encode.py new file mode 100644 index 0000000..0fd4f23 --- /dev/null +++ b/openrec/preprocess/dptr_label_encode.py @@ -0,0 +1,157 @@ +import re +from abc import ABC, abstractmethod +from itertools import groupby +from typing import List, Optional, Tuple +import numpy as np +import torch +from torch import Tensor +from torch.nn.utils.rnn import pad_sequence +import unicodedata +from ..modeling.decoders.dptr_parseq_clip_b_decoder import tokenize + +class CharsetAdapter: + """Transforms labels according to the target charset.""" + + def __init__(self, target_charset) -> None: + super().__init__() + self.lowercase_only = target_charset == target_charset.lower() + self.uppercase_only = target_charset == target_charset.upper() + self.unsupported = re.compile(f'[^{re.escape(target_charset)}]') + + def __call__(self, label): + if self.lowercase_only: + label = label.lower() + elif self.uppercase_only: + label = label.upper() + # Remove unsupported characters + label = self.unsupported.sub('', label) + return label + + +class BaseTokenizer(ABC): +# eos=0, a=1, bos=37, pad=38 + def __init__(self, charset: str, specials_first: tuple = (), specials_last: tuple = ()) -> None: + self._itos = specials_first + tuple(charset) + specials_last + self._stoi = {s: i for i, s in enumerate(self._itos)} 
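The recurring change in `smtr_decoder.py` above, replacing `x.get_device()` with `x.device`, is what makes CPU-only inference possible; a minimal sketch of the difference:

```python
import torch

x = torch.zeros(2, 3)  # CPU tensor
# x.get_device() is only meaningful for CUDA tensors (on CPU it errors or
# returns -1, depending on the PyTorch version), whereas x.device is a valid
# torch.device on both CPU and GPU:
mask = torch.zeros([1, 1], dtype=torch.float32, device=x.device)
```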
+ # print("stoi:", self._stoi) + + def __len__(self): + return len(self._itos) + + def _tok2ids(self, tokens: str) -> List[int]: + # print("tokens", tokens) + return [self._stoi[s] for s in tokens] + + def _ids2tok(self, token_ids: List[int], join: bool = True) -> str: + tokens = [self._itos[i] for i in token_ids] + return ''.join(tokens) if join else tokens + + @abstractmethod + def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor: + """Encode a batch of labels to a representation suitable for the model. + + Args: + labels: List of labels. Each can be of arbitrary length. + device: Create tensor on this device. + + Returns: + Batched tensor representation padded to the max label length. Shape: N, L + """ + raise NotImplementedError + + @abstractmethod + def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]: + """Internal method which performs the necessary filtering prior to decoding.""" + raise NotImplementedError + + def decode(self, token_dists: Tensor, raw: bool = False) -> Tuple[List[str], List[Tensor]]: + """Decode a batch of token distributions. + + Args: + token_dists: softmax probabilities over the token distribution. Shape: N, L, C + raw: return unprocessed labels (will return list of list of strings) + + Returns: + list of string labels (arbitrary length) and + their corresponding sequence probabilities as a list of Tensors + """ + batch_tokens = [] + batch_probs = [] + for dist in token_dists: + probs, ids = dist.max(-1) # greedy selection + if not raw: + probs, ids = self._filter(probs, ids) + tokens = self._ids2tok(ids, not raw) + batch_tokens.append(tokens) + batch_probs.append(probs) + return batch_tokens, batch_probs + + +class Tokenizer(BaseTokenizer): + BOS = '[B]' + EOS = '[E]' + PAD = '[P]' + + def __init__(self, charset: str) -> None: + specials_first = (self.EOS,) + specials_last = (self.BOS, self.PAD) + super().__init__(charset, specials_first, specials_last) + self.eos_id, self.bos_id, self.pad_id = [self._stoi[s] for s in specials_first + specials_last] + + def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor: + batch = [self.bos_id] + self._tok2ids(labels) + [self.eos_id] + return batch + # return pad_sequence(batch, batch_first=True, padding_value=self.pad_id) + + def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]: + ids = ids.tolist() + try: + eos_idx = ids.index(self.eos_id) + except ValueError: + eos_idx = len(ids) # Nothing to truncate. + # Truncate after EOS + ids = ids[:eos_idx] + probs = probs[:eos_idx + 1] # but include prob. 
for EOS (if it exists) + return probs, ids + +class DPTRLabelEncode(Tokenizer): + """Convert between text-label and text-index.""" + def __init__(self, max_text_length=25, character_dict_path=None, **kwargs): + self.max_length = max_text_length + charset = get_alpha(character_dict_path) + charset = ''.join(charset) + # print(charset) + super(DPTRLabelEncode, self).__init__(charset) + + def __call__(self, data, normalize_unicode=True): + text = data['label'] + + if normalize_unicode: + text = unicodedata.normalize('NFKD', text).encode('ascii', 'ignore').decode() + text = ''.join(text.split()) + if len(text) == 0 or len(text) > self.max_length: + return None + + text_ids = self.encode(text) + clip_ids = tokenize(f"a photo of a '{text}'") + text_ids = text_ids + [self.pad_id] * (self.max_length + 2 - len(text_ids)) + # print(text, len(text_ids), len(clip_ids[0])) + data['clip_label'] = np.array(clip_ids[0]) + data['label'] = np.array(text_ids) + return data + + def add_special_char(self, dict_character): + dict_character = [self.EOS] + dict_character + [self.BOS, self.PAD] + return dict_character + +def get_alpha(alpha_path): + character_str = [] + with open(alpha_path, 'rb') as fin: + lines = fin.readlines() + for line in lines: + line = line.decode('utf-8').strip('\n').strip('\r\n') + character_str.append(line) + dict_character = list(character_str) + if 'arabic' in alpha_path: + reverse = True + return dict_character \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 1ab8901..8697362 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ imgaug lmdb numpy opencv-python<=4.6.0.66 +pyclipper pyyaml rapidfuzz tqdm diff --git a/tools/__init__.py b/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tools/create_lmdb_dataset.py b/tools/create_lmdb_dataset.py index a96aa92..d9a3a44 100644 --- a/tools/create_lmdb_dataset.py +++ b/tools/create_lmdb_dataset.py @@ -99,6 +99,7 @@ def createDataset(data_list, outputPath, checkValid=True): if __name__ == '__main__': data_dir = './Union14M-L/' + # downloading the filtered_label_list from https://drive.google.com/drive/folders/1x1LC8C_W-Frl3sGV9i9_i_OD-bqNdodJ?usp=drive_link label_file_list = [ './Union14M-L/train_annos/filter_jsonl_mmocr0.x/filter_train_challenging.jsonl.txt', './Union14M-L/train_annos/filter_jsonl_mmocr0.x/filter_train_easy.jsonl.txt', diff --git a/tools/data/__init__.py b/tools/data/__init__.py index ad1a203..dbb6719 100644 --- a/tools/data/__init__.py +++ b/tools/data/__init__.py @@ -9,6 +9,7 @@ from torch.utils.data import DataLoader, DistributedSampler from tools.data.lmdb_dataset import LMDBDataSet +from tools.data.text_lmdb_dataset import TextLMDBDataSet from tools.data.lmdb_dataset_test import LMDBDataSetTest from tools.data.multi_scale_sampler import MultiScaleSampler from tools.data.ratio_dataset import RatioDataSet @@ -30,7 +31,7 @@ def build_dataloader(config, mode, logger, seed=None, epoch=3): config = copy.deepcopy(config) support_dict = [ - 'SimpleDataSet', 'LMDBDataSet', 'MultiScaleDataSet', 'STRLMDBDataSet', + 'SimpleDataSet', 'LMDBDataSet', 'TextLMDBDataSet', 'MultiScaleDataSet', 'STRLMDBDataSet', 'LMDBDataSetTest', 'RatioDataSet', 'RatioDataSetTest', 'RatioDataSetTVResize', 'RatioDataSetTVResizeTest' ] diff --git a/tools/data/ratio_sampler.py b/tools/data/ratio_sampler.py index e0c9d72..51e9c2a 100644 --- a/tools/data/ratio_sampler.py +++ b/tools/data/ratio_sampler.py @@ -56,7 +56,7 @@ def __init__(self, self.base_im_w = base_im_w # Get the GPU and 
node related information - num_replicas = torch.cuda.device_count() + num_replicas = torch.cuda.device_count() if torch.cuda.is_available() else 1 # rank = dist.get_rank() rank = (int(os.environ['LOCAL_RANK']) if 'LOCAL_RANK' in os.environ else 0) diff --git a/tools/data/text_lmdb_dataset.py b/tools/data/text_lmdb_dataset.py new file mode 100644 index 0000000..2bd8f37 --- /dev/null +++ b/tools/data/text_lmdb_dataset.py @@ -0,0 +1,127 @@ +import os +import cv2 +import lmdb +import numpy as np +from torch.utils.data import Dataset + +from openrec.preprocess import create_operators, transform + + +class TextLMDBDataSet(Dataset): + + def __init__(self, config, mode, logger, seed=None, epoch=1): + super(TextLMDBDataSet, self).__init__() + + global_config = config['Global'] + dataset_config = config[mode]['dataset'] + loader_config = config[mode]['loader'] + loader_config['batch_size_per_card'] + data_dir = dataset_config['data_dir'] + self.do_shuffle = loader_config['shuffle'] + + self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir) + logger.info(f'Initialize indexs of datasets: {data_dir}') + self.data_idx_order_list = self.dataset_traversal() + if self.do_shuffle: + np.random.shuffle(self.data_idx_order_list) + self.ops = create_operators(dataset_config['transforms'], + global_config) + self.ext_op_transform_idx = dataset_config.get('ext_op_transform_idx', + 1) + + ratio_list = dataset_config.get('ratio_list', [1.0]) + self.need_reset = True in [x < 1 for x in ratio_list] + + def load_hierarchical_lmdb_dataset(self, data_dir): + lmdb_sets = {} + dataset_idx = 0 + for dirpath, dirnames, filenames in os.walk(data_dir + '/'): + if not dirnames: + env = lmdb.open( + dirpath, + max_readers=32, + readonly=True, + lock=False, + readahead=False, + meminit=False, + ) + txn = env.begin(write=False) + num_samples = int(txn.get('num-samples'.encode())) + lmdb_sets[dataset_idx] = { + 'dirpath': dirpath, + 'env': env, + 'txn': txn, + 'num_samples': num_samples, + } + dataset_idx += 1 + return lmdb_sets + + def dataset_traversal(self): + lmdb_num = len(self.lmdb_sets) + total_sample_num = 0 + for lno in range(lmdb_num): + total_sample_num += self.lmdb_sets[lno]['num_samples'] + data_idx_order_list = np.zeros((total_sample_num, 2)) + beg_idx = 0 + for lno in range(lmdb_num): + tmp_sample_num = self.lmdb_sets[lno]['num_samples'] + end_idx = beg_idx + tmp_sample_num + data_idx_order_list[beg_idx:end_idx, 0] = lno + data_idx_order_list[beg_idx:end_idx, + 1] = list(range(tmp_sample_num)) + data_idx_order_list[beg_idx:end_idx, 1] += 1 + beg_idx = beg_idx + tmp_sample_num + return data_idx_order_list + + def get_ext_data(self): + ext_data_num = 0 + for op in self.ops: + if hasattr(op, 'ext_data_num'): + ext_data_num = getattr(op, 'ext_data_num') + break + load_data_ops = self.ops[:self.ext_op_transform_idx] + ext_data = [] + + while len(ext_data) < ext_data_num: + lmdb_idx, file_idx = self.data_idx_order_list[np.random.randint( + len(self))] + lmdb_idx = int(lmdb_idx) + file_idx = int(file_idx) + sample_info = self.get_lmdb_sample_info( + self.lmdb_sets[lmdb_idx]['txn'], file_idx) + if sample_info is None: + continue + label = sample_info + data = {'label': label} + data = transform(data, load_data_ops) + if data is None: + continue + ext_data.append(data) + return ext_data + + def get_lmdb_sample_info(self, txn, index, normalize_unicode=True, remove_whitespace=True, max_length=True): + label_key = 'label-%09d'.encode() % index + label = txn.get(label_key) + if label is None: + return None + label = 
label.decode('utf-8') + + return label + + def __getitem__(self, idx): + lmdb_idx, file_idx = self.data_idx_order_list[idx] + lmdb_idx = int(lmdb_idx) + file_idx = int(file_idx) + sample_info = self.get_lmdb_sample_info( + self.lmdb_sets[lmdb_idx]['txn'], file_idx) + if sample_info is None: + return self.__getitem__(np.random.randint(self.__len__())) + label = sample_info + data = {'label': label} + outs = transform(data, self.ops) + if outs is None: + return self.__getitem__(np.random.randint(self.__len__())) + return outs + + def __len__(self): + return self.data_idx_order_list.shape[0] diff --git a/tools/download/download_dataset.py b/tools/download/download_dataset.py new file mode 100644 index 0000000..50e5a3b --- /dev/null +++ b/tools/download/download_dataset.py @@ -0,0 +1,32 @@ +import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) + +sys.path.append(__dir__) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..', '..'))) + +from engine import Config +from utility import ArgsParser +import download.utils +from torchvision.datasets.utils import extract_archive + +def main(cfg): + urls, filename_paths, check_validity = download.utils.get_dataset_info(cfg) + for url, filename_path in zip(urls, filename_paths): + print(f"Downloading {filename_path} from {url} . . .") + download.utils.urlretrieve(url=url, filename=filename_path, check_validity=check_validity) + if not filename_path.endswith(".mdb"): + extract_archive(from_path=filename_path, to_path=cfg["root"], remove_finished=True) + + print("Downloads finished!") + +if __name__ == "__main__": + FLAGS = ArgsParser().parse_args() + cfg = Config(FLAGS.config) + FLAGS = vars(FLAGS) + opt = FLAGS.pop('opt') + cfg.merge_dict(FLAGS) + cfg.merge_dict(opt) + main(cfg.cfg) diff --git a/tools/download/utils.py b/tools/download/utils.py new file mode 100644 index 0000000..f83ae39 --- /dev/null +++ b/tools/download/utils.py @@ -0,0 +1,23 @@ +import urllib +import ssl +from tqdm import tqdm +import os + +def get_dataset_info(cfg): + download_urls, filenames, check_validity = cfg["download_links"], cfg["filenames"], cfg["check_validity"] + return download_urls, filenames, check_validity + +# Modified from torchvision as some datasets cant pass the certificate validity check: +# https://github.com/pytorch/vision/blob/868a3b42f4bffe29e4414ad7e4c7d9d0b4690ecb/torchvision/datasets/utils.py#L27C1-L32C40 +def urlretrieve(url, filename, chunk_size=1024 * 32, check_validity=True): + os.makedirs(os.path.dirname(filename), exist_ok=True) + ctx = ssl.create_default_context() + if not check_validity: + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + request = urllib.request.Request(url) + with urllib.request.urlopen(request, context=ctx) as response: + with open(filename, "wb") as fh, tqdm(total=response.length, unit="B", unit_scale=True) as pbar: + while chunk := response.read(chunk_size): + fh.write(chunk) + pbar.update(len(chunk)) \ No newline at end of file diff --git a/tools/engine/trainer.py b/tools/engine/trainer.py index e2382b4..5faaf3b 100644 --- a/tools/engine/trainer.py +++ b/tools/engine/trainer.py @@ -245,7 +245,7 @@ def train(self): train_reader_cost += time.time() - reader_start # use amp if self.scaler: - with torch.cuda.amp.autocast(): + with torch.cuda.amp.autocast(enabled=self.device.type == 'cuda'): preds = self.model(batch[0], data=batch[1:]) loss = self.loss_class(preds, batch) self.scaler.scale(loss['loss']).backward() @@ 
-543,7 +543,7 @@ def eval(self): batch = [t.to(self.device) for t in batch] start = time.time() if self.scaler: - with torch.cuda.amp.autocast(): + with torch.cuda.amp.autocast(enabled=self.device.type == 'cuda'): preds = self.model(batch[0], data=batch[1:]) else: preds = self.model(batch[0], data=batch[1:]) @@ -581,7 +581,7 @@ def eval_ema(self): batch = [t.to(self.device) for t in batch] start = time.time() if self.scaler: - with torch.cuda.amp.autocast(): + with torch.cuda.amp.autocast(enabled=self.device.type == 'cuda'): preds = self.ema_model(batch[0], data=batch[1:]) else: preds = self.ema_model(batch[0], data=batch[1:]) diff --git a/tools/infer_det.py b/tools/infer_det.py index 67cdef3..fd58760 100644 --- a/tools/infer_det.py +++ b/tools/infer_det.py @@ -1,6 +1,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from pathlib import Path import time import numpy as np @@ -23,6 +24,64 @@ from tools.utils.logging import get_logger from tools.utils.utility import get_image_file_list +logger = get_logger() + +root_dir = Path(__file__).resolve().parent +DEFAULT_CFG_PATH_DET = str(root_dir / '../configs/det/dbnet/repvit_db.yml') + +MODEL_NAME_DET = './openocr_det_repvit_ch.pth' # 模型文件名称 +DOWNLOAD_URL_DET = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_det_repvit_ch.pth' # 模型文件 URL + + +def check_and_download_model(model_name: str, url: str): + """ + 检查预训练模型是否存在,若不存在则从指定 URL 下载到固定缓存目录。 + + Args: + model_name (str): 模型文件的名称,例如 "model.pt" + url (str): 模型文件的下载地址 + + Returns: + str: 模型文件的完整路径 + """ + if os.path.exists(model_name): + return model_name + + # 固定缓存路径为用户主目录下的 ".cache/openocr" + cache_dir = Path.home() / '.cache' / 'openocr' + model_path = cache_dir / model_name + + # 如果模型文件已存在,直接返回路径 + if model_path.exists(): + logger.info(f'Model already exists at: {model_path}') + return str(model_path) + + # 如果文件不存在,下载模型 + logger.info(f'Model not found. Downloading from {url}...') + + # 创建缓存目录(如果不存在) + cache_dir.mkdir(parents=True, exist_ok=True) + + try: + # 下载文件 + import urllib.request + with urllib.request.urlopen(url) as response, open(model_path, + 'wb') as out_file: + out_file.write(response.read()) + logger.info(f'Model downloaded and saved at: {model_path}') + return str(model_path) + + except Exception as e: + logger.error(f'Error downloading the model: {e}') + # 提示用户手动下载 + logger.error( + f'Unable to download the model automatically. ' + f'Please download the model manually from the following URL:\n{url}\n' + f'and save it to: {model_name} or {model_path}') + raise RuntimeError( + f'Failed to download the model. 
Please download it manually from {url} ' + f'and save it to {model_path}') from e + def replace_batchnorm(net): for child_name, child in net.named_children(): @@ -162,6 +221,7 @@ def set_device(device, numId=0): if device == 'gpu' and torch.cuda.is_available(): device = torch.device(f'cuda:{numId}') else: + logger.info('GPU is not available, using CPU.') device = torch.device('cpu') return device @@ -184,7 +244,11 @@ def __init__(self, config=None, numId=0): """ if config is None: - config = Config('./configs/det/dbnet/repvit_db.yml').cfg + config = Config(DEFAULT_CFG_PATH_DET).cfg + + if not os.path.exists(config['Global']['pretrained_model']): + config['Global']['pretrained_model'] = check_and_download_model( + MODEL_NAME_DET, DOWNLOAD_URL_DET) from opendet.modeling import build_model as build_det_model from opendet.postprocess import build_post_process @@ -212,10 +276,6 @@ def __init__(self, config=None, numId=0): self.ops = create_operators(transforms, global_config) - save_res_path = config['Global']['save_res_path'] - if not os.path.exists(os.path.dirname(save_res_path)): - os.makedirs(os.path.dirname(save_res_path)) - # build post process self.post_process_class = build_post_process(config['PostProcess'], global_config) @@ -277,11 +337,10 @@ def crop_infer( images = np.array(image_list) shape_list = np.array(shape_list) images = torch.from_numpy(images).to(device=self.device) - - t_start = time.time() - preds = self.model(images) - torch.cuda.synchronize() - t_cost = time.time() - t_start + with torch.no_grad(): + t_start = time.time() + preds = self.model(images) + t_cost = time.time() - t_start preds['maps'] = restore_preds(preds['maps'], crop_positions, (img_height, img_width)) @@ -290,7 +349,12 @@ def crop_infer( results.append(info) return results - def __call__(self, img_path=None, img_numpy_list=None, img_numpy=None): + def __call__(self, + img_path=None, + img_numpy_list=None, + img_numpy=None, + return_mask=False, + **kwargs): """ 对输入图像进行处理,并返回处理结果。 @@ -328,6 +392,8 @@ def __call__(self, img_path=None, img_numpy_list=None, img_numpy=None): img = f.read() data = {'image': img} data = self.transform(data, self.ops[:1]) + if kwargs.get('det_input_size', None) is not None: + data['max_sile_len'] = kwargs['det_input_size'] batch = self.transform(data, self.ops[1:]) images = np.expand_dims(batch[0], axis=0) @@ -337,26 +403,30 @@ def __call__(self, img_path=None, img_numpy_list=None, img_numpy=None): t_start = time.time() preds = self.model(images) t_cost = time.time() - t_start - post_result = self.post_process_class(preds, shape_list) + post_result = self.post_process_class(preds, shape_list, **kwargs) info = {'boxes': post_result[0]['points'], 'elapse': t_cost} + if return_mask: + if isinstance(preds['maps'], torch.Tensor): + mask = preds['maps'].detach().cpu().numpy() + else: + mask = preds['maps'] + info['mask'] = mask results.append(info) return results @torch.no_grad() def main(cfg): - logger = get_logger() is_visualize = cfg['Global'].get('is_visualize', False) model = OpenDetector(cfg) - save_res_path = cfg['Global']['output_dir'] + save_res_path = './det_results/' if not os.path.exists(save_res_path): os.makedirs(save_res_path) sample_num = 0 with open(save_res_path + '/det_results.txt', 'wb') as fout: for file in get_image_file_list(cfg['Global']['infer_img']): - preds_result = model(img_path=file)[0] logger.info('{} infer_img: {}, time cost: {}'.format( sample_num, file, preds_result['elapse'])) @@ -368,14 +438,16 @@ def main(cfg): dt_boxes_json.append(tmp_json) if 
is_visualize: src_img = cv2.imread(file) - save_det_path = save_res_path + '/det_results/' - draw_det_res(boxes, src_img, file, save_det_path) + draw_det_res(boxes, src_img, file, save_res_path) logger.info('The detected Image saved in {}'.format( - os.path.join(save_det_path, os.path.basename(file)))) + os.path.join(save_res_path, os.path.basename(file)))) otstr = file + '\t' + json.dumps(dt_boxes_json) + '\n' logger.info('results: {}'.format(json.dumps(dt_boxes_json))) fout.write(otstr.encode()) sample_num += 1 + logger.info( + f"Results saved to {os.path.join(save_res_path, 'det_results.txt')}.)" + ) logger.info('success!') diff --git a/tools/infer_e2e.py b/tools/infer_e2e.py index 54633c9..482e3c9 100644 --- a/tools/infer_e2e.py +++ b/tools/infer_e2e.py @@ -3,6 +3,7 @@ from __future__ import print_function import os +from pathlib import Path import sys __dir__ = os.path.dirname(os.path.abspath(__file__)) @@ -17,32 +18,94 @@ import cv2 import json from PIL import Image -import torch from tools.utils.utility import get_image_file_list, check_and_read from tools.infer_rec import OpenRecognizer from tools.infer_det import OpenDetector from tools.engine import Config from tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop, draw_ocr_box_txt +from tools.utils.logging import get_logger +root_dir = Path(__file__).resolve().parent +DEFAULT_CFG_PATH_DET = str(root_dir / '../configs/det/dbnet/repvit_db.yml') +DEFAULT_CFG_PATH_REC_SERVER = str(root_dir / + '../configs/rec/svtrv2/svtrv2_ch.yml') +DEFAULT_CFG_PATH_REC = str(root_dir / '../configs/rec/svtrv2/repsvtr_ch.yml') -def set_device(device): - if device == 'gpu' and torch.cuda.is_available(): - device = torch.device('cuda:0') - else: - device = torch.device('cpu') - return device +logger = get_logger() + +MODEL_NAME_DET = './openocr_det_repvit_ch.pth' # 模型文件名称 +DOWNLOAD_URL_DET = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_det_repvit_ch.pth' # 模型文件 URL +MODEL_NAME_REC = './openocr_repsvtr_ch.pth' # 模型文件名称 +DOWNLOAD_URL_REC = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_repsvtr_ch.pth' # 模型文件 URL +MODEL_NAME_REC_SERVER = './openocr_svtrv2_ch.pth' # 模型文件名称 +DOWNLOAD_URL_REC_SERVER = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_svtrv2_ch.pth' # 模型文件 URL + + +def check_and_download_model(model_name: str, url: str): + """ + 检查预训练模型是否存在,若不存在则从指定 URL 下载到固定缓存目录。 + + Args: + model_name (str): 模型文件的名称,例如 "model.pt" + url (str): 模型文件的下载地址 + + Returns: + str: 模型文件的完整路径 + """ + if os.path.exists(model_name): + return model_name + + # 固定缓存路径为用户主目录下的 ".cache/openocr" + cache_dir = Path.home() / '.cache' / 'openocr' + model_path = cache_dir / model_name + + # 如果模型文件已存在,直接返回路径 + if model_path.exists(): + logger.info(f'Model already exists at: {model_path}') + return str(model_path) + + # 如果文件不存在,下载模型 + logger.info(f'Model not found. Downloading from {url}...') + + # 创建缓存目录(如果不存在) + cache_dir.mkdir(parents=True, exist_ok=True) + + try: + # 下载文件 + import urllib.request + with urllib.request.urlopen(url) as response, open(model_path, + 'wb') as out_file: + out_file.write(response.read()) + logger.info(f'Model downloaded and saved at: {model_path}') + return str(model_path) + + except Exception as e: + logger.error(f'Error downloading the model: {e}') + # 提示用户手动下载 + logger.error( + f'Unable to download the model automatically. 
' + f'Please download the model manually from the following URL:\n{url}\n' + f'and save it to: {model_name} or {model_path}') + raise RuntimeError( + f'Failed to download the model. Please download it manually from {url} ' + f'and save it to {model_path}') from e def check_and_download_font(font_path): if not os.path.exists(font_path): - print(f"Downloading '{font_path}' ...") + cache_dir = Path.home() / '.cache' / 'openocr' + font_path = str(cache_dir / font_path) + if os.path.exists(font_path): + return font_path + logger.info(f"Downloading '{font_path}' ...") try: import urllib.request font_url = 'https://shuiche-shop.oss-cn-chengdu.aliyuncs.com/fonts/simfang.ttf' urllib.request.urlretrieve(font_url, font_path) - print(f'Downloading font success: {font_path}') + logger.info(f'Downloading font success: {font_path}') except Exception as e: - print(f'Downloading font error: {e}') + logger.info(f'Downloading font error: {e}') + return font_path def sorted_boxes(dt_boxes): @@ -84,14 +147,18 @@ def __init__(self, mode='mobile', drop_score=0.5, det_box_type='quad'): 无返回值。 """ - cfg_det = Config( - './configs/det/dbnet/repvit_db.yml').cfg # mobile model + cfg_det = Config(DEFAULT_CFG_PATH_DET).cfg # mobile model + model_dir = check_and_download_model(MODEL_NAME_DET, DOWNLOAD_URL_DET) + cfg_det['Global']['pretrained_model'] = model_dir if mode == 'server': - cfg_rec = Config( - './configs/det/svtrv2/svtrv2_ch.yml').cfg # server model + cfg_rec = Config(DEFAULT_CFG_PATH_REC_SERVER).cfg # server model + model_dir = check_and_download_model(MODEL_NAME_REC_SERVER, + DOWNLOAD_URL_REC_SERVER) else: - cfg_rec = Config( - './configs/rec/svtrv2/repsvtr_ch.yml').cfg # mobile model + cfg_rec = Config(DEFAULT_CFG_PATH_REC).cfg # mobile model + model_dir = check_and_download_model(MODEL_NAME_REC, + DOWNLOAD_URL_REC) + cfg_rec['Global']['pretrained_model'] = model_dir self.text_detector = OpenDetector(cfg_det) self.text_recognizer = OpenRecognizer(cfg_rec) self.det_box_type = det_box_type @@ -114,14 +181,19 @@ def infer_single_image(self, img_numpy, ori_img, crop_infer=False, - rec_batch_num=6): + rec_batch_num=6, + return_mask=False, + **kwargs): start = time.time() if crop_infer: dt_boxes = self.text_detector.crop_infer( img_numpy=img_numpy)[0]['boxes'] else: - dt_boxes = self.text_detector(img_numpy=img_numpy)[0]['boxes'] - # print(dt_boxes) + det_res = self.text_detector(img_numpy=img_numpy, + return_mask=return_mask, + **kwargs)[0] + dt_boxes = det_res['boxes'] + # logger.info(dt_boxes) det_time_cost = time.time() - start if dt_boxes is None: @@ -155,6 +227,13 @@ def infer_single_image(self, avg_rec_time_cost = rec_time_cost_sig / len(dt_boxes) if len( dt_boxes) > 0 else 0.0 + if return_mask: + return filter_boxes, filter_rec_res, { + 'time_cost': det_time_cost + rec_time_cost, + 'detection_time': det_time_cost, + 'recognition_time': rec_time_cost, + 'avg_rec_time_cost': avg_rec_time_cost + }, det_res['mask'] return filter_boxes, filter_rec_res, { 'time_cost': det_time_cost + rec_time_cost, @@ -169,7 +248,9 @@ def __call__(self, is_visualize=False, img_numpy=None, rec_batch_num=6, - crop_infer=False): + crop_infer=False, + return_mask=False, + **kwargs): """ img_path: str, optional, default=None Path to the directory containing images or the image filename. 
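Putting the pieces of `tools/infer_e2e.py` together, a hedged usage sketch of the end-to-end pipeline (class and argument names are taken from this diff; `'path/to/images'` is a user-supplied directory or single image, and the detector/recognizer weights, plus the font when visualizing, are downloaded into `~/.cache/openocr` on first use):

```python
from tools.infer_e2e import OpenOCR

ocr = OpenOCR(mode='mobile', drop_score=0.5, det_box_type='quad')
# Returns (None, None) when no text is detected.
results, time_dicts = ocr(img_path='path/to/images', is_visualize=True)
```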
@@ -194,11 +275,21 @@ def __call__(self, time_dicts = [] for index, img in enumerate(img_numpy): ori_img = img.copy() - dt_boxes, rec_res, time_dict = self.infer_single_image( - img_numpy=img, - ori_img=ori_img, - crop_infer=crop_infer, - rec_batch_num=rec_batch_num) + if return_mask: + dt_boxes, rec_res, time_dict, mask = self.infer_single_image( + img_numpy=img, + ori_img=ori_img, + crop_infer=crop_infer, + rec_batch_num=rec_batch_num, + return_mask=return_mask, + **kwargs) + else: + dt_boxes, rec_res, time_dict = self.infer_single_image( + img_numpy=img, + ori_img=ori_img, + crop_infer=crop_infer, + rec_batch_num=rec_batch_num, + **kwargs) if dt_boxes is None: results.append([]) time_dicts.append({}) @@ -210,6 +301,8 @@ def __call__(self, } for i in range(len(dt_boxes))] results.append(res) time_dicts.append(time_dict) + if return_mask: + return results, time_dicts, mask return results, time_dicts image_file_list = get_image_file_list(img_path) @@ -225,7 +318,8 @@ def __call__(self, imgs = [img] else: imgs = img - print(f'Processing {idx+1}/{len(image_file_list)}: {image_file}') + logger.info( + f'Processing {idx+1}/{len(image_file_list)}: {image_file}') res_list = [] time_dicts = [] @@ -235,7 +329,8 @@ def __call__(self, img_numpy=img_numpy, ori_img=ori_img, crop_infer=crop_infer, - rec_batch_num=rec_batch_num) + rec_batch_num=rec_batch_num, + **kwargs) if dt_boxes is None: res_list.append([]) time_dicts.append({}) @@ -252,10 +347,10 @@ def __call__(self, time_dicts)): if len(res) > 0: - print(f'Results: {res}.') - print(f'Time cost: {time_dict}.') + logger.info(f'Results: {res}.') + logger.info(f'Time cost: {time_dict}.') else: - print('No text detected.') + logger.info('No text detected.') if len(res_list) > 1: save_pred = (os.path.basename(image_file) + '_' + @@ -274,12 +369,12 @@ def __call__(self, if is_visualize and len(res) > 0: if idx == 0: font_path = './simfang.ttf' - check_and_download_font(font_path) + font_path = check_and_download_font(font_path) os.makedirs(save_dir, exist_ok=True) draw_img_save_dir = os.path.join( save_dir, 'vis_results/') os.makedirs(draw_img_save_dir, exist_ok=True) - print( + logger.info( f'Visualized results will be saved to {draw_img_save_dir}.' ) dt_boxes = [res[i]['points'] for i in range(len(res))] @@ -320,14 +415,15 @@ def __call__(self, 'w', encoding='utf-8') as f: f.writelines(save_results) - print( + logger.info( f"Results saved to {os.path.join(save_dir, 'system_results.txt')}." 
) if is_visualize: - print(f'Visualized results saved to {draw_img_save_dir}.') + logger.info( + f'Visualized results saved to {draw_img_save_dir}.') return save_results, time_dicts_return else: - print('No text detected.') + logger.info('No text detected.') return None, None diff --git a/tools/infer_rec.py b/tools/infer_rec.py index 033e788..ce13f26 100644 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -1,4 +1,5 @@ import os +from pathlib import Path import sys import time @@ -18,6 +19,69 @@ from tools.utils.utility import get_image_file_list from tools.infer_det import replace_batchnorm +logger = get_logger() + +root_dir = Path(__file__).resolve().parent +DEFAULT_CFG_PATH_REC_SERVER = str(root_dir / + '../configs/rec/svtrv2/svtrv2_ch.yml') +DEFAULT_CFG_PATH_REC = str(root_dir / '../configs/rec/svtrv2/repsvtr_ch.yml') +DEFAULT_DICT_PATH_REC = str(root_dir / './utils/ppocr_keys_v1.txt') + +MODEL_NAME_REC = './openocr_repsvtr_ch.pth' # 模型文件名称 +DOWNLOAD_URL_REC = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_repsvtr_ch.pth' # 模型文件 URL +MODEL_NAME_REC_SERVER = './openocr_svtrv2_ch.pth' # 模型文件名称 +DOWNLOAD_URL_REC_SERVER = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_svtrv2_ch.pth' # 模型文件 URL + + +def check_and_download_model(model_name: str, url: str): + """ + 检查预训练模型是否存在,若不存在则从指定 URL 下载到固定缓存目录。 + + Args: + model_name (str): 模型文件的名称,例如 "model.pt" + url (str): 模型文件的下载地址 + + Returns: + str: 模型文件的完整路径 + """ + if os.path.exists(model_name): + return model_name + + # 固定缓存路径为用户主目录下的 ".cache/openocr" + cache_dir = Path.home() / '.cache' / 'openocr' + model_path = cache_dir / model_name + + # 如果模型文件已存在,直接返回路径 + if model_path.exists(): + logger.info(f'Model already exists at: {model_path}') + return str(model_path) + + # 如果文件不存在,下载模型 + logger.info(f'Model not found. Downloading from {url}...') + + # 创建缓存目录(如果不存在) + cache_dir.mkdir(parents=True, exist_ok=True) + + try: + # 下载文件 + import urllib.request + with urllib.request.urlopen(url) as response, open(model_path, + 'wb') as out_file: + out_file.write(response.read()) + logger.info(f'Model downloaded and saved at: {model_path}') + return str(model_path) + + except Exception as e: + logger.error(f'Error downloading the model: {e}') + # 提示用户手动下载 + logger.error( + f'Unable to download the model automatically. ' + f'Please download the model manually from the following URL:\n{url}\n' + f'and save it to: {model_name} or {model_path}') + raise RuntimeError( + f'Failed to download the model. 
Please download it manually from {url} ' + f'and save it to {model_path}') from e + class RatioRecTVReisze(object): @@ -84,6 +148,7 @@ def set_device(device, numId=0): if device == 'gpu' and torch.cuda.is_available(): device = torch.device(f'cuda:{numId}') else: + logger.info('GPU is not available, using CPU.') device = torch.device('cpu') return device @@ -109,11 +174,30 @@ def __init__(self, config=None, mode='mobile', numId=0): if config is None: if mode == 'server': config = Config( - './configs/det/svtrv2/svtrv2_ch.yml').cfg # server model + DEFAULT_CFG_PATH_REC_SERVER).cfg # server model + if not os.path.exists(config['Global']['pretrained_model']): + model_dir = check_and_download_model( + MODEL_NAME_REC_SERVER, DOWNLOAD_URL_REC_SERVER) else: - config = Config( - './configs/rec/svtrv2/repsvtr_ch.yml').cfg # mobile model - + config = Config(DEFAULT_CFG_PATH_REC).cfg # mobile model + if not os.path.exists(config['Global']['pretrained_model']): + model_dir = check_and_download_model( + MODEL_NAME_REC, DOWNLOAD_URL_REC) + config['Global']['pretrained_model'] = model_dir + config['Global']['character_dict_path'] = DEFAULT_DICT_PATH_REC + else: + if config['Architecture']['algorithm'] == 'SVTRv2_mobile': + if not os.path.exists(config['Global']['pretrained_model']): + config['Global'][ + 'pretrained_model'] = check_and_download_model( + MODEL_NAME_REC, DOWNLOAD_URL_REC) + config['Global']['character_dict_path'] = DEFAULT_DICT_PATH_REC + elif config['Architecture']['algorithm'] == 'SVTRv2_server': + if not os.path.exists(config['Global']['pretrained_model']): + config['Global'][ + 'pretrained_model'] = check_and_download_model( + MODEL_NAME_REC_SERVER, DOWNLOAD_URL_REC_SERVER) + config['Global']['character_dict_path'] = DEFAULT_DICT_PATH_REC global_config = config['Global'] self.cfg = config if global_config['pretrained_model'] is None: @@ -126,7 +210,6 @@ def __init__(self, config=None, mode='mobile', numId=0): self.transform = transform self.post_process_class = build_post_process(config['PostProcess'], global_config) - char_num = self.post_process_class.get_character_num() config['Architecture']['Decoder']['out_channels'] = char_num # print(char_num) @@ -232,7 +315,6 @@ def __call__(self, with torch.no_grad(): t_start = time.time() preds = self.model(images, others) - torch.cuda.synchronize() t_cost = time.time() - t_start post_results = self.post_process_class(preds) @@ -256,10 +338,9 @@ def __call__(self, def main(cfg): - logger = get_logger() model = OpenRecognizer(cfg) - save_res_path = cfg['Global']['output_dir'] + save_res_path = './rec_results/' if not os.path.exists(save_res_path): os.makedirs(save_res_path) @@ -272,9 +353,7 @@ def main(cfg): sample_num = 0 with open(save_res_path + '/rec_results.txt', 'wb') as fout: for file in get_image_file_list(cfg['Global']['infer_img']): - preds_result = model(img_path=file, batch_num=1)[0] - rec_text = preds_result['text'] score = preds_result['score'] t_cost = preds_result['elapse'] @@ -287,6 +366,9 @@ def main(cfg): t_sum += t_cost fout.write(otstr.encode()) sample_num += 1 + logger.info( + f"Results saved to {os.path.join(save_res_path, 'rec_results.txt')}.)" + ) print(text_len_num) w_avg_t_cost = []
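Analogously for `tools/infer_rec.py`, a hedged usage sketch of the recognizer on its own (when `config` is omitted, the SVTRv2 mobile or server weights and the dictionary are resolved automatically as shown above; `'path/to/word_image.jpg'` is a user-supplied cropped text image):

```python
from tools.infer_rec import OpenRecognizer

rec = OpenRecognizer(mode='mobile')  # or mode='server' for the SVTRv2 server model
res = rec(img_path='path/to/word_image.jpg', batch_num=1)[0]
print(res['text'], res['score'], res['elapse'])
```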
Method