diff --git a/.gitignore b/.gitignore
index 6262738..a166590 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,6 +10,9 @@ __pycache__/
inference/
inference_results/
output/
+e2e_results/
+rec_results/
+det_results/
train_data/
log/
*.DS_Store
diff --git a/README.md b/README.md
index d0f19c8..bf67a07 100644
--- a/README.md
+++ b/README.md
@@ -1,57 +1,64 @@
# OpenOCR
-We aim to establishing a unified benchmark for training and evaluating models for scene text detection and recognition. Based on this benchmark, we introduce an accurate and efficient general OCR system, OpenOCR. Additionally, this repository will serve as the official codebase for the OCR team from the [FVL](https://fvl.fudan.edu.cn) Laboratory, Fudan University.
+We aim to establish a unified benchmark for training and evaluating models in scene text detection and recognition. Building on this benchmark, we introduce **OpenOCR**, an accurate and efficient general OCR system. This repository also serves as the official codebase of the OCR team from the [FVL Laboratory](https://fvl.fudan.edu.cn), Fudan University.
We sincerely welcome researchers to recommend OCR or relevant algorithms and to point out any potential factual errors or bugs. Upon receiving such suggestions, we will promptly evaluate and critically reproduce them. We look forward to collaborating with you to advance the development of OpenOCR and continuously contribute to the OCR community!
## Features
-- 🔥**OpenOCR: A general OCR system for accuracy and efficiency**
- - ⚡\[[Quick Start](#quick-start)\] \[[Demo](<>)(TODO)\]
+- 🔥**OpenOCR: A general OCR system with accuracy and efficiency**
+ - ⚡\[[Quick Start](#quick-start)\] \[[Model](https://github.com/Topdu/OpenOCR/releases/tag/develop0.0.1)\] \[[ModelScope Demo](https://modelscope.cn/studios/topdktu/OpenOCR-Demo)\] \[[Hugging Face Demo](https://huggingface.co/spaces/topdu/OpenOCR-Demo)\] \[[Local Demo](#local-demo)\] \[[PaddleOCR Implementation](https://paddlepaddle.github.io/PaddleOCR/latest/algorithm/text_recognition/algorithm_rec_svtrv2.html)\]
- [Introduction](./docs/openocr.md)
- - A practical version of the model builds on SVTRv2.
- - Outperforming [PP-OCRv4](<>) released by [PaddleOCR](<>) by 4.5% on the [OCR competition leaderboard](<>).
- - [x] Supporting Chinese and English text detection and recognition.
- - [x] Providing server model and mobile model.
- - [ ] Fine-tuning OpenOCR on a custom dataset
- - [ ] Export to ONNX engine
+ - A practical OCR system built on SVTRv2.
+ - Outperforms the [PP-OCRv4](https://paddlepaddle.github.io/PaddleOCR/latest/ppocr/model_list.html) baseline by 4.5% in accuracy on the [OCR competition leaderboard](https://aistudio.baidu.com/competition/detail/1131/0/leaderboard), while maintaining comparable inference speed.
+ - [x] Supports Chinese and English text detection and recognition.
+ - [x] Provides server model and mobile model.
+ - [ ] Fine-tuning OpenOCR on a custom dataset.
+ - [ ] ONNX model export for wider compatibility.
- 🔥**SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition**
- - \[[Paper](../configs/rec/svtrv2/SVTRv2.pdf)\] \[[Model](./configs/rec/svtrv2/readme.md#11-models-and-results)\] \[[Config, Training and Inference](./configs/rec/svtrv2/readme.md#3-model-training--evaluation)\]
+ - \[[Paper](https://arxiv.org/abs/2411.15858)\] \[[Doc](./configs/rec/svtrv2/)\] \[[Model](./configs/rec/svtrv2/readme.md#11-models-and-results)\] \[[Datasets](./docs/svtrv2.md#downloading-datasets)\] \[[Config, Training and Inference](./configs/rec/svtrv2/readme.md#3-model-training--evaluation)\] \[[Benchmark](./docs/svtrv2.md#results-benchmark--configs--checkpoints)\]
- [Introduction](./docs/svtrv2.md)
- - Developing a unified training and evaluation benchmark for Scene Text Recognition
- - Supporting for 24 Scene Text Recognition methods trained from scratch on large-scale real datasets, and will continue to add the latest methods.
- - Improving results by 20-30% compared to training on synthetic datasets.
+ - A unified training and evaluation benchmark (on top of [Union14M](https://github.com/Mountchicken/Union14M?tab=readme-ov-file#3-union14m-dataset)) for Scene Text Recognition
+ - Supports 24 Scene Text Recognition methods trained from scratch on the large-scale real dataset [Union14M-L-Filter](./docs/svtrv2.md#dataset-details); the latest methods will continue to be added.
+ - Improves accuracy by 20-30% compared to models trained on synthetic datasets.
- Towards Arbitrary-Shaped Text Recognition and Language modeling with a Single Visual Model.
- - Surpasses Attention-based Decoder Methods across challenging scenarios in terms of accuracy and speed
- - [Get Started](./docs/svtrv2.md#get-started-with-training-a-sota-scene-text-recognition-model-from-scratch) with training a SoTA Scene Text Recognition model from scratch.
+ - Surpasses Attention-based Encoder-Decoder Methods across challenging scenarios in terms of accuracy and speed
+ - [Get Started](./docs/svtrv2.md#get-started-with-training-a-sota-scene-text-recognition-model-from-scratch) with training a SOTA Scene Text Recognition model from scratch.
## Our STR algorithms
-- [**DPTR**](<>) (*Shuai Zhao, Yongkun Du, Zhineng Chen\*, Yu-Gang Jiang. Decoder Pre-Training with only Text for Scene Text Recognition,* ACM MM 2024. [paper](https://arxiv.org/abs/2408.05706))
-- [**IGTR**](./configs/rec/igtr/) (*Yongkun Du, Zhineng Chen\*, Yuchen Su, Caiyan Jia, Yu-Gang Jiang. Instruction-Guided Scene Text Recognition,* Under TPAMI minor revison 2024. [Doc](./configs/rec/igtr/readme.md), [paper](https://arxiv.org/abs/2401.17851))
-- [**SVTRv2**](./configs/rec/svtrv2) (*Yongkun Du, Zhineng Chen\*, Hongtao Xie, Caiyan Jia, Yu-Gang Jiang. SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition,* 2024. [paper](./configs/rec/svtrv2/SVTRv2.pdf))
-- [**SMTR&FocalSVTR**](./configs/rec/smtr/) (*Yongkun Du, Zhineng Chen\*, Caiyan Jia, Xieping Gao, Yu-Gang Jiang. Out of Length Text Recognition with Sub-String Matching,* 2024. [paper](https://arxiv.org/abs/2407.12317))
-- [**CDistNet**](./configs/rec/cdistnet/) (*Tianlun Zheng, Zhineng Chen\*, Shancheng Fang, Hongtao Xie, Yu-Gang Jiang. CDistNet: Perceiving Multi-Domain Character Distance for Robust Text Recognition,* IJCV 2024. [paper](https://link.springer.com/article/10.1007/s11263-023-01880-0))
-- **MRN** (*Tianlun Zheng, Zhineng Chen\*, Bingchen Huang, Wei Zhang, Yu-Gang Jiang. MRN: Multiplexed routing network for incremental multilingual text recognition,* ICCV 2023. [paper](https://openaccess.thecvf.com/content/ICCV2023/html/Zheng_MRN_Multiplexed_Routing_Network_for_Incremental_Multilingual_Text_Recognition_ICCV_2023_paper.html))
-- **TPS++** (*Tianlun Zheng, Zhineng Chen\*, Jinfeng Bai, Hongtao Xie, Yu-Gang Jiang. TPS++: Attention-Enhanced Thin-Plate Spline for Scene Text Recognition,* IJCAI 2023. [paper](https://arxiv.org/abs/2305.05322))
-- [**CPPD**](./configs/rec/cppd/) (*Yongkun Du, Zhineng Chen\*, Caiyan Jia, Xiaoting Yin, Chenxia Li, Yuning Du, Yu-Gang Jiang. Context Perception Parallel Decoder for Scene Text Recognition,* Under TPAMI minor revision 2023. [PaddleOCR Doc](https://github.com/Topdu/PaddleOCR/blob/main/doc/doc_ch/algorithm_rec_cppd.md), [paper](https://arxiv.org/abs/2307.12270))
-- [**SVTR**](./configs/rec/svtr/) (*Yongkun Du, Zhineng Chen\*, Caiyan Jia, Xiaoting Yin, Tianlun Zheng, Chenxia Li, Yuning Du, Yu-Gang Jiang. SVTR: Scene Text Recognition with a Single Visual Model,* IJCAI 2022 (Long). [PaddleOCR Doc](https://github.com/Topdu/PaddleOCR/blob/main/doc/doc_ch/algorithm_rec_svtr.md), [paper](https://www.ijcai.org/proceedings/2022/124))
-- [**NRTR**](./configs/rec/nrtr/) (*Fenfen Sheng, Zhineng Chen\*, Bo Xu. NRTR: A No-Recurrence Sequence-to-Sequence Model For Scene Text Recognition,* ICDAR 2019. [paper](https://arxiv.org/abs/1806.00926))
+- [**SMTR&FocalSVTR**](./configs/rec/smtr/) (*Yongkun Du, Zhineng Chen\*, Caiyan Jia, Xieping Gao, Yu-Gang Jiang. Out of Length Text Recognition with Sub-String Matching,* AAAI 2025. [Doc](./configs/rec/smtr/), [Paper](https://arxiv.org/abs/2407.12317))
+- [**DPTR**](./configs/rec/dptr/) (*Shuai Zhao, Yongkun Du, Zhineng Chen\*, Yu-Gang Jiang. Decoder Pre-Training with only Text for Scene Text Recognition,* ACM MM 2024. [Paper](https://arxiv.org/abs/2408.05706))
+- [**IGTR**](./configs/rec/igtr/) (*Yongkun Du, Zhineng Chen\*, Yuchen Su, Caiyan Jia, Yu-Gang Jiang. Instruction-Guided Scene Text Recognition,* TPAMI (accepted). [Doc](./configs/rec/igtr), [Paper](https://arxiv.org/abs/2401.17851))
+- [**SVTRv2**](./configs/rec/svtrv2) (*Yongkun Du, Zhineng Chen\*, Hongtao Xie, Caiyan Jia, Yu-Gang Jiang. SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition,* 2024. [Doc](./configs/rec/svtrv2/), [Paper](https://arxiv.org/abs/2411.15858))
+- [**CDistNet**](./configs/rec/cdistnet/) (*Tianlun Zheng, Zhineng Chen\*, Shancheng Fang, Hongtao Xie, Yu-Gang Jiang. CDistNet: Perceiving Multi-Domain Character Distance for Robust Text Recognition,* IJCV 2024. [Paper](https://link.springer.com/article/10.1007/s11263-023-01880-0))
+- **MRN** (*Tianlun Zheng, Zhineng Chen\*, Bingchen Huang, Wei Zhang, Yu-Gang Jiang. MRN: Multiplexed routing network for incremental multilingual text recognition,* ICCV 2023. [Paper](https://openaccess.thecvf.com/content/ICCV2023/html/Zheng_MRN_Multiplexed_Routing_Network_for_Incremental_Multilingual_Text_Recognition_ICCV_2023_paper.html), [Code](https://github.com/simplify23/MRN))
+- **TPS++** (*Tianlun Zheng, Zhineng Chen\*, Jinfeng Bai, Hongtao Xie, Yu-Gang Jiang. TPS++: Attention-Enhanced Thin-Plate Spline for Scene Text Recognition,* IJCAI 2023. [Paper](https://arxiv.org/abs/2305.05322), [Code](https://github.com/simplify23/TPS_PP))
+- [**CPPD**](./configs/rec/cppd/) (*Yongkun Du, Zhineng Chen\*, Caiyan Jia, Xiaoting Yin, Chenxia Li, Yuning Du, Yu-Gang Jiang. Context Perception Parallel Decoder for Scene Text Recognition,* Under TPAMI minor revision 2023. [PaddleOCR Doc](https://github.com/Topdu/PaddleOCR/blob/main/doc/doc_ch/algorithm_rec_cppd.md), [Paper](https://arxiv.org/abs/2307.12270))
+- [**SVTR**](./configs/rec/svtr/) (*Yongkun Du, Zhineng Chen\*, Caiyan Jia, Xiaoting Yin, Tianlun Zheng, Chenxia Li, Yuning Du, Yu-Gang Jiang. SVTR: Scene Text Recognition with a Single Visual Model,* IJCAI 2022 (Long). [PaddleOCR Doc](https://github.com/Topdu/PaddleOCR/blob/main/doc/doc_ch/algorithm_rec_svtr.md), [Paper](https://www.ijcai.org/proceedings/2022/124))
+- [**NRTR**](./configs/rec/nrtr/) (*Fenfen Sheng, Zhineng Chen\*, Bo Xu. NRTR: A No-Recurrence Sequence-to-Sequence Model For Scene Text Recognition,* ICDAR 2019. [Paper](https://arxiv.org/abs/1806.00926))
## Recent Updates
+- **2024.12.31**: Our paper [IGTR](https://arxiv.org/abs/2401.17851) has been accepted by TPAMI. See the [Doc](./configs/rec/igtr/) for details.
+
+- **2024.12.16**: Our paper [SMTR](https://arxiv.org/abs/2407.12317) has been accepted by AAAI 2025. See the [Doc](./configs/rec/smtr/) for details.
+
+- **2024.12.03**: The pre-training code for [DPTR](https://arxiv.org/abs/2408.05706) has been merged.
+
- **🔥 2024.11.23 release notes**:
- - **OpenOCR: A general OCR system for accuracy and efficiency**
- - ⚡\[[Quick Start](#quick-start)\] \[[Demo](<>)(TODO)\]
+
+ - **OpenOCR: A general OCR system with accuracy and efficiency**
+ - ⚡\[[Quick Start](#quick-start)\] \[[Model](https://github.com/Topdu/OpenOCR/releases/tag/develop0.0.1)\] \[[ModelScope Demo](https://modelscope.cn/studios/topdktu/OpenOCR-Demo)\] \[[Hugging Face Demo](https://huggingface.co/spaces/topdu/OpenOCR-Demo)\] \[[Local Demo](#local-demo)\] \[[PaddleOCR Implementation](https://paddlepaddle.github.io/PaddleOCR/latest/algorithm/text_recognition/algorithm_rec_svtrv2.html)\]
- [Introduction](./docs/openocr.md)
- **SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition**
- - \[[Paper](../configs/rec/svtrv2/SVTRv2.pdf)\] \[[Model](./configs/rec/svtrv2/readme.md#11-models-and-results)\] \[[Config, Training and Inference](./configs/rec/svtrv2/readme.md#3-model-training--evaluation)\]
+ - \[[Paper](https://arxiv.org/abs/2411.15858)\] \[[Doc](./configs/rec/svtrv2/)\] \[[Model](./configs/rec/svtrv2/readme.md#11-models-and-results)\] \[[Datasets](./docs/svtrv2.md#downloading-datasets)\] \[[Config, Training and Inference](./configs/rec/svtrv2/readme.md#3-model-training--evaluation)\] \[[Benchmark](./docs/svtrv2.md#results--configs--checkpoints)\]
- [Introduction](./docs/svtrv2.md)
- - [Get Started](./docs/svtrv2.md#get-started-with-training-a-sota-scene-text-recognition-model-from-scratch) with training a SoTA Scene Text Recognition model from scratch.
+ - [Get Started](./docs/svtrv2.md#get-started-with-training-a-sota-scene-text-recognition-model-from-scratch) with training a SOTA Scene Text Recognition model from scratch.
-## ⚡[Quick Start](./docs/openocr.md#quick-start)
+## Quick Start
-#### Dependencies:
+### Dependencies:
- [PyTorch](http://pytorch.org/) version >= 1.13.0
- Python version >= 3.7
@@ -59,12 +66,15 @@ We sincerely welcome the researcher to recommend OCR or relevant algorithms and
```shell
conda create -n openocr python==3.8
conda activate openocr
+# install the GPU version of PyTorch
conda install pytorch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 pytorch-cuda=11.8 -c pytorch -c nvidia
+# or the CPU-only version
+conda install pytorch torchvision torchaudio cpuonly -c pytorch
```
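
A quick way to confirm which build is active in the `openocr` environment (optional sanity check; standard PyTorch API only):

```python
# Sanity check: print the installed PyTorch version and whether CUDA is visible.
import torch

print(torch.__version__)          # e.g. 2.2.0
print(torch.cuda.is_available())  # True for the GPU build with a working driver; False for cpuonly
```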
After installing dependencies, the following two installation methods are available. Either one can be chosen.
-#### 1. Python Modules
+### 1. Python Modules
```shell
pip install openocr-python
@@ -79,19 +89,21 @@ engine = OpenOCR()
img_path = '/path/img_path or /path/img_file'
result, elapse = engine(img_path)
-print(result)
-print(elapse)
# Server mode
-engine = OpenOCR(mode='server')
+# engine = OpenOCR(mode='server')
```
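
Putting the snippet above together, a minimal sketch of the pip-package workflow (the top-level `openocr` import path and the default mobile/server behaviour are assumptions inferred from this README, not a definitive API reference):

```python
# Minimal sketch using the pip package; import path assumed from this README.
from openocr import OpenOCR

engine = OpenOCR()                  # default (mobile) detection + recognition models
# engine = OpenOCR(mode='server')   # switch to the heavier server models

result, elapse = engine('/path/img_file')  # a single image or an image folder
print(result)   # detection and recognition output
print(elapse)   # elapsed time
```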
-#### 2. Clone this repository:
+### 2. Clone this repository:
```shell
git clone https://github.com/Topdu/OpenOCR.git
cd OpenOCR
pip install -r requirements.txt
+wget https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_det_repvit_ch.pth
+wget https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_repsvtr_ch.pth
+# Rec Server model
+# wget https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_svtrv2_ch.pth
```
**Usage**:
@@ -99,14 +111,22 @@ pip install -r requirements.txt
```shell
# OpenOCR system: Det + Rec model
python tools/infer_e2e.py --img_path=/path/img_fold or /path/img_file
-
# Det model
python tools/infer_det.py --c ./configs/det/dbnet/repvit_db.yml --o Global.infer_img=/path/img_fold or /path/img_file
-
# Rec model
python tools/infer_rec.py --c ./configs/rec/svtrv2/repsvtr_ch.yml --o Global.infer_img=/path/img_fold or /path/img_file
```
+#### Local Demo
+
+```shell
+pip install gradio==4.20.0
+wget https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/OCR_e2e_img.tar
+tar xf OCR_e2e_img.tar
+# start demo
+python demo_gradio.py
+```
+
## Reproduction schedule:
### Scene Text Recognition
@@ -117,7 +137,7 @@ python tools/infer_rec.py --c ./configs/rec/svtrv2/repsvtr_ch.yml --o Global.inf
| [ASTER](./configs/rec/aster/) | [TPAMI 2019](https://ieeexplore.ieee.org/document/8395027) | ✅ | ✅ | [pretto0](https://github.com/pretto0) |
| [NRTR](./configs/rec/nrtr/) | [ICDAR 2019](https://arxiv.org/abs/1806.00926) | ✅ | ✅ | |
| [SAR](./configs/rec/sar/) | [AAAI 2019](https://aaai.org/papers/08610-show-attend-and-read-a-simple-and-strong-baseline-for-irregular-text-recognition/) | ✅ | ✅ | [pretto0](https://github.com/pretto0) |
-| [MORAN](./configs/rec/moran/) | [PR 2019](https://www.sciencedirect.com/science/article/abs/pii/S0031320319300263) | ✅ | ✅ | Debug |
+| [MORAN](./configs/rec/moran/) | [PR 2019](https://www.sciencedirect.com/science/article/abs/pii/S0031320319300263) | ✅ | ✅ | |
| [DAN](./configs/rec/dan/) | [AAAI 2020](https://arxiv.org/pdf/1912.10205) | ✅ | ✅ | |
| [RobustScanner](./configs/rec/robustscanner/) | [ECCV 2020](https://www.ecva.net/papers/eccv_2020/papers_ECCV/html/3160_ECCV_2020_paper.php) | ✅ | ✅ | [pretto0](https://github.com/pretto0) |
| [AutoSTR](./configs/rec/autostr/) | [ECCV 2020](https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123690732.pdf) | ✅ | ✅ | |
@@ -125,6 +145,7 @@ python tools/infer_rec.py --c ./configs/rec/svtrv2/repsvtr_ch.yml --o Global.inf
| [SEED](./configs/rec/seed/) | [CVPR 2020](https://openaccess.thecvf.com/content_CVPR_2020/html/Qiao_SEED_Semantics_Enhanced_Encoder-Decoder_Framework_for_Scene_Text_Recognition_CVPR_2020_paper.html) | ✅ | ✅ | |
| [ABINet](./configs/rec/abinet/) | [CVPR 2021](https://openaccess.thecvf.com//content/CVPR2021/html/Fang_Read_Like_Humans_Autonomous_Bidirectional_and_Iterative_Language_Modeling_for_CVPR_2021_paper.html) | ✅ | ✅ | [YesianRohn](https://github.com/YesianRohn) |
| [VisionLAN](./configs/rec/visionlan/) | [ICCV 2021](https://openaccess.thecvf.com/content/ICCV2021/html/Wang_From_Two_to_One_A_New_Scene_Text_Recognizer_With_ICCV_2021_paper.html) | ✅ | ✅ | [YesianRohn](https://github.com/YesianRohn) |
+| PIMNet | [ACM MM 2021](https://dl.acm.org/doi/10.1145/3474085.3475238) | | | TODO |
| [SVTR](./configs/rec/svtrs/) | [IJCAI 2022](https://www.ijcai.org/proceedings/2022/124) | ✅ | ✅ | |
| [PARSeq](./configs/rec/parseq/) | [ECCV 2022](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136880177.pdf) | ✅ | ✅ | |
| [MATRN](./configs/rec/matrn/) | [ECCV 2022](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136880442.pdf) | ✅ | ✅ | |
@@ -137,14 +158,14 @@ python tools/infer_rec.py --c ./configs/rec/svtrv2/repsvtr_ch.yml --o Global.inf
| [BUSNet](./configs/rec/busnet/) | [AAAI 2024](https://ojs.aaai.org/index.php/AAAI/article/view/28402) | ✅ | ✅ | |
| DCTC | [AAAI 2024](https://ojs.aaai.org/index.php/AAAI/article/view/28575) | | | TODO |
| [CAM](./configs/rec/cam/) | [PR 2024](https://arxiv.org/abs/2402.13643) | ✅ | ✅ | |
-| [OTE](./configs/rec/ote/) | [CVPR 2024](https://openaccess.thecvf.com/content/CVPR2024/papers/Xu_OTE_Exploring_Accurate_Scene_Text_Recognition_Using_One_Token_CVPR_2024_paper.pdf) | ✅ | ✅ | |
+| [OTE](./configs/rec/ote/) | [CVPR 2024](https://openaccess.thecvf.com/content/CVPR2024/html/Xu_OTE_Exploring_Accurate_Scene_Text_Recognition_Using_One_Token_CVPR_2024_paper.html) | ✅ | ✅ | |
| CFF | [IJCAI 2024](https://arxiv.org/abs/2407.05562) | | | TODO |
-| DPTR | [ACM MM 2024](https://arxiv.org/abs/2408.05706) | | | TODO |
+| [DPTR](./configs/rec/dptr/) | [ACM MM 2024](https://arxiv.org/abs/2408.05706) | | | [fd-zs](https://github.com/fd-zs) |
| VIPTR | [ACM CIKM 2024](https://arxiv.org/abs/2401.10110) | | | TODO |
-| [IGTR](./configs/rec/igtr/) | [2024](https://arxiv.org/abs/2401.17851) | ✅ | ✅ | |
-| [SMTR](./configs/rec/smtr/) | [2024](https://arxiv.org/abs/2407.12317) | ✅ | ✅ | |
+| [IGTR](./configs/rec/igtr/) | [TPAMI](https://arxiv.org/abs/2401.17851) | ✅ | ✅ | |
+| [SMTR](./configs/rec/smtr/) | [AAAI 2025](https://arxiv.org/abs/2407.12317) | ✅ | ✅ | |
| [FocalSVTR-CTC](./configs/rec/svtrs/) | [2024](https://arxiv.org/abs/2407.12317) | ✅ | ✅ | |
-| [SVTRv2](./configs/rec/svtrv2/) | [2024](./configs/rec/svtrv2/SVTRv2.pdf) | ✅ | ✅ | |
+| [SVTRv2](./configs/rec/svtrv2/) | [2024](https://arxiv.org/abs/2411.15858) | ✅ | ✅ | |
| [ResNet+Trans-CTC](./configs/rec/svtrs/) | | ✅ | ✅ | |
| [ViT-CTC](./configs/rec/svtrs/) | | ✅ | ✅ | |
@@ -152,7 +173,7 @@ python tools/infer_rec.py --c ./configs/rec/svtrv2/repsvtr_ch.yml --o Global.inf
______________________________________________________________________
-Yiming Lei ([pretto0](https://github.com/pretto0)) and Xingsong Ye ([YesianRohn](https://github.com/YesianRohn)) from the [FVL](https://fvl.fudan.edu.cn) Laboratory, Fudan University, under the guidance of Professor Zhineng Chen, completed the majority of the algorithm reproduction work. Grateful for their outstanding contributions.
+Yiming Lei ([pretto0](https://github.com/pretto0)), Xingsong Ye ([YesianRohn](https://github.com/YesianRohn)), and Shuai Zhao ([fd-zs](https://github.com/fd-zs)) from the [FVL Laboratory](https://fvl.fudan.edu.cn), Fudan University, completed the majority of the algorithm reproduction work under the guidance of Dr. Zhineng Chen ([Homepage](https://zhinchenfd.github.io/)). We are grateful for their outstanding contributions.
### Scene Text Detection (STD)
@@ -164,6 +185,23 @@ TODO
______________________________________________________________________
+## Citation
+
+If you find our method useful for your research, please cite:
+
+```bibtex
+@article{Du2024SVTRv2,
+ title={SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition},
+ author={Yongkun Du and Zhineng Chen and Hongtao Xie and Caiyan Jia and Yu-Gang Jiang},
+ journal={CoRR},
+ volume={abs/2411.15858},
+ eprinttype={arXiv},
+ year={2024},
+ primaryClass={cs.CV},
+ url={https://arxiv.org/abs/2411.15858}
+}
+```
+
# Acknowledgement
This codebase is built on [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR), [PytorchOCR](https://github.com/WenmuZhou/PytorchOCR), and [MMOCR](https://github.com/open-mmlab/mmocr). Thanks for their awesome work!
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..1fa1819
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,11 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
+
+from tools.infer_e2e import OpenOCR, OpenDetector, OpenRecognizer
diff --git a/configs/dataset/rec/evaluation.yaml b/configs/dataset/rec/evaluation.yaml
new file mode 100644
index 0000000..7c10194
--- /dev/null
+++ b/configs/dataset/rec/evaluation.yaml
@@ -0,0 +1,41 @@
+root: ../evaluation
+task: str
+download_links:
+ # IC15_1811
+ - https://drive.usercontent.google.com/download?id=1eGY0kXNV1qVxeUpoGzs-ioUO-ky7msH6&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=1BWv7aLoLAT7avY326gXP3GJF48UZpuBC&authuser=0&confirm=t
+ # SVT
+ - https://drive.usercontent.google.com/download?id=1ecEZ4cJ7dIbTCZRltE0s5KzUotQWagH-&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=1OygBP7i9R-3Pwi6WodCcW31J8CUMugOJ&authuser=0&confirm=t
+ # IIIT5k
+ - https://drive.usercontent.google.com/download?id=1PJ9_IvIGZTS5hHdGLnpKuYKZcCO8jE0E&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=10P3MixSBt1v8k8_6aFfziC33Z5IlM6Uf&authuser=0&confirm=t
+ # IC13_857
+ - https://drive.usercontent.google.com/download?id=1-wMHOFBXJaOaY-UD00nDn6qw2s_8R4Vd&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=1J1QCFtOFxFKiLJIgTqZ6eRo9Y5QGqHpA&authuser=0&confirm=t
+ # SVTP
+ - https://drive.usercontent.google.com/download?id=1kckwfZkdaHG8k_FW5IIJKUaYZkF21Hza&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=1x61lm_ea7lvIdxNPMG-jy-5W0MxtdH0N&authuser=0&confirm=t
+ # CUTE80
+ - https://drive.usercontent.google.com/download?id=1Zv_91c81tinLy5Je89HPr-5wUSnqXKIB&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=1OuJ6QoJ9AlyNHIM9j2WedAPxTnac7kyY&authuser=0&confirm=t
+filenames:
+ # IC15_1811
+ - ../evaluation/IC15_1811/data.mdb
+ - ../evaluation/IC15_1811/lock.mdb
+ # SVT
+ - ../evaluation/SVT/data.mdb
+ - ../evaluation/SVT/lock.mdb
+ # IIIT5k
+ - ../evaluation/IIIT5k/data.mdb
+ - ../evaluation/IIIT5k/lock.mdb
+ # IC13_857
+ - ../evaluation/IC13_857/data.mdb
+ - ../evaluation/IC13_857/lock.mdb
+ # SVTP
+ - ../evaluation/SVTP/data.mdb
+ - ../evaluation/SVTP/lock.mdb
+ # CUTE80
+ - ../evaluation/CUTE80/data.mdb
+ - ../evaluation/CUTE80/lock.mdb
+check_validity: true
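
The dataset YAMLs in `configs/dataset/rec/` share one small schema (`root`, `task`, `download_links`, `filenames`, `check_validity`). As a rough illustration of how such a file could be consumed, here is a hypothetical downloader sketch; the repository's own download tooling may differ, and `fetch_dataset` is a made-up helper name:

```python
# Hypothetical consumer of the dataset YAMLs above; not the repo's own tooling.
import os
import yaml      # pip install pyyaml
import requests  # pip install requests

def fetch_dataset(cfg_path: str) -> None:
    """Download every file listed in a dataset YAML to its target path."""
    with open(cfg_path, 'r') as f:
        cfg = yaml.safe_load(f)
    for url, out_path in zip(cfg['download_links'], cfg['filenames']):
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        if os.path.exists(out_path):
            continue  # skip files that are already present
        resp = requests.get(url, stream=True)
        resp.raise_for_status()
        with open(out_path, 'wb') as f_out:
            for chunk in resp.iter_content(chunk_size=1 << 20):
                f_out.write(chunk)
        if cfg.get('check_validity'):
            # crude validity check: the downloaded file must not be empty
            assert os.path.getsize(out_path) > 0, f'empty download: {out_path}'

# Example: fetch_dataset('configs/dataset/rec/evaluation.yaml')
```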
diff --git a/configs/dataset/rec/ltb.yaml b/configs/dataset/rec/ltb.yaml
new file mode 100644
index 0000000..a0b03a5
--- /dev/null
+++ b/configs/dataset/rec/ltb.yaml
@@ -0,0 +1,9 @@
+root: ../ltb
+task: str
+download_links:
+ - https://drive.usercontent.google.com/download?id=16AEA1YGTsyVB44uEjKi4ZUV1snjCYBr4&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=1xU4OStrOaI23bPG4flWAPWn2YrQe2bmY&authuser=0&confirm=t
+filenames:
+ - ../ltb/data.mdb
+ - ../ltb/lock.mdb
+check_validity: true
diff --git a/configs/dataset/rec/mjsynth.yaml b/configs/dataset/rec/mjsynth.yaml
new file mode 100644
index 0000000..7a9fc75
--- /dev/null
+++ b/configs/dataset/rec/mjsynth.yaml
@@ -0,0 +1,11 @@
+root: ../synth
+task: str
+download_links:
+ - https://drive.usercontent.google.com/download?id=1FIoplSFZ-BKQoRDHDXsVMKa844e-K8PD&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=1eckTvaeRtlTZvbO2orrVz-cIuIk6i87K&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=1PBXTf-2PnmEvJBsqzJqxxRwzhAZGTiMG&authuser=0&confirm=t
+filenames:
+ - ../synth/MJ_train.zip
+ - ../synth/MJ_val.zip
+ - ../synth/MJ_test.zip
+check_validity: true
\ No newline at end of file
diff --git a/configs/dataset/rec/openvino.yaml b/configs/dataset/rec/openvino.yaml
new file mode 100644
index 0000000..f60e448
--- /dev/null
+++ b/configs/dataset/rec/openvino.yaml
@@ -0,0 +1,25 @@
+root: ../OpenVINO
+task: str
+download_links:
+ # train_1
+ - https://drive.usercontent.google.com/download?id=1q23QAIRTyG0t-bBm4aAwRwiqB6VUfphw&authuser=0&confirm=
+ # train_2
+ - https://drive.usercontent.google.com/download?id=1AtbaJljM68cbZqi5lcM92d9VkQUCbSqI&authuser=0&confirm=
+ # train_5
+ - https://drive.usercontent.google.com/download?id=1dejstYnJ8_sESuO_uvwi__jT1B8gPxf3&authuser=0&confirm=t
+ # train_f
+ - https://drive.usercontent.google.com/download?id=1C4akchTc7-yi1OS_sJ3KP693UKcnecke&authuser=0&confirm=t
+ # validation
+ - https://drive.usercontent.google.com/download?id=17TRzSQhuK_juAxAv3KmX0y13pQP2cz6R&authuser=0&confirm=t
+filenames:
+ # train_1
+ - ../OpenVINO/train_1.zip
+ # train_2
+ - ../OpenVINO/train_2.zip
+ # train_5
+ - ../OpenVINO/train_5.zip
+ # train_f
+ - ../OpenVINO/train_f.zip
+ # validation
+ - ../OpenVINO/validation.zip
+check_validity: true
diff --git a/configs/dataset/rec/ost.yaml b/configs/dataset/rec/ost.yaml
new file mode 100644
index 0000000..6690f12
--- /dev/null
+++ b/configs/dataset/rec/ost.yaml
@@ -0,0 +1,17 @@
+root: ../OST
+task: str
+download_links:
+ # OST heavy
+ - https://drive.usercontent.google.com/download?id=1RGpIFbD_SRlrzZFBoVF_LGvetNx1-5pg&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=1Th4MfDf44k0EBpIqCLqVoGRu6G-FP1hq&authuser=0&confirm=t
+ # OST weak
+ - https://drive.usercontent.google.com/download?id=1z5CTDJucUnvALG12Q4UXk1DDKJDd8WJn&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=1V17TTkX3sjpV7v0km_F2SDCK0tL3k_ls&authuser=0&confirm=t
+filenames:
+ # OST heavy
+ - ../OST/heavy/data.mdb
+ - ../OST/heavy/lock.mdb
+ # OST weak
+ - ../OST/weak/data.mdb
+ - ../OST/weak/lock.mdb
+check_validity: true
diff --git a/configs/dataset/rec/synthtext.yaml b/configs/dataset/rec/synthtext.yaml
new file mode 100644
index 0000000..4b84088
--- /dev/null
+++ b/configs/dataset/rec/synthtext.yaml
@@ -0,0 +1,7 @@
+root: ../synth
+task: str
+download_links:
+ - https://drive.usercontent.google.com/download?id=1T-enqkq6_l2HqrsV3da_h0oJ7CUKu_oc&authuser=0&confirm=t
+filenames:
+ - ../synth/ST.zip
+check_validity: true
diff --git a/configs/dataset/rec/test.yaml b/configs/dataset/rec/test.yaml
new file mode 100644
index 0000000..b043ba4
--- /dev/null
+++ b/configs/dataset/rec/test.yaml
@@ -0,0 +1,77 @@
+root: ../test
+task: str
+download_links:
+ # IC13_857
+ - https://drive.usercontent.google.com/download?id=1PZSCbe6_DI8MlCqCRWXGT2PP92_frIXq&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=1qkN7NDg0zUHxUiZHAeEatDTqlsgpFWp3&authuser=0&confirm=t
+ # IC15_2077
+ - https://drive.usercontent.google.com/download?id=1dFkY3DNbr-Mepn3TWBiA9COEJ63fGFcp&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=1UvVwLNZ3tS1YdTBa8MulPzjeVezKaDro&authuser=0&confirm=t
+ # SVTP
+ - https://drive.usercontent.google.com/download?id=1aofeerilxJ7J3S7QxuCEXbmXTpz8Xshx&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=1rJ1KoO4K_VUxEAUN_bMgBGzK8_JZAAno&authuser=0&confirm=t
+ # IIIT5k
+ - https://drive.usercontent.google.com/download?id=1XFO2M1Kbgwv3-iTNTmhQXAEjNmKYOeoT&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=1stwK2hFsyaV7HHsEG9EYgnUQebNb2_nG&authuser=0&confirm=t
+ # COCOv1.4
+ - https://drive.usercontent.google.com/download?id=1Se2QSGS19xx7Gfy-SUdX9mlAOr2eYsfA&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=1xvekFi389QfkH7yS0JIVV0QzjhUspjDv&authuser=0&confirm=t
+ # IC15_1811
+ - https://drive.usercontent.google.com/download?id=1pHsw8wrThD9EGEE6AusQLZozefSj4iyR&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=1TXZ1qHuKAksaAlvd3qMv4IHKnN-IJW9a&authuser=0&confirm=t
+ # Uber
+ - https://drive.usercontent.google.com/download?id=1L2j6BZeLTGQ1FIl8HB_D3AFiWLltGV5r&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=12DUj28yzLWxFO_gfMfSjTkRujYD5MNEE&authuser=0&confirm=t
+ # IC13_1095
+ - https://drive.usercontent.google.com/download?id=1fu8onMt3Z6fDLNAiHcm-sQ2qCXduE-FU&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=1OQAZtLj8U2Cl4L0ErGFsz6vGIVTTWasD&authuser=0&confirm=t
+ # IC13_1015
+ - https://drive.usercontent.google.com/download?id=1mbsfuvWB282HYfn9tbqcj1nUDkLXcSNB&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=1QGogU_hV-oN7iY2POutdD2LDcmK6plnV&authuser=0&confirm=t
+ # ArT
+ - https://drive.usercontent.google.com/download?id=1-53knSy-uTSngCG7wyBngVyTuTCmdnWl&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=172EsSaf7BVaB1ORtohi-Jc_8SuUKZGGf&authuser=0&confirm=t
+ # SVT
+ - https://drive.usercontent.google.com/download?id=1p7aVUr9Yr7c4X4YUBvk2-YP28rraHjn9&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=1ALmhvSleZ0yf-lcdbQPP3M9Zc3oqnXij&authuser=0&confirm=t
+ # CUTE80
+ - https://drive.usercontent.google.com/download?id=1Ujr4axHKnu54P2rIGUhkjdM6XlhDYrI_&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=1DvZi9L3MqjO2zRUyCg3YvP4qMAt2bsme&authuser=0&confirm=t
+filenames:
+ # IC13_857
+ - ../test/IC13_857/data.mdb
+ - ../test/IC13_857/lock.mdb
+ # IC15_2077
+ - ../test/IC15_2077/data.mdb
+ - ../test/IC15_2077/lock.mdb
+ # SVTP
+ - ../test/SVTP/data.mdb
+ - ../test/SVTP/lock.mdb
+ # IIIT5k
+ - ../test/IIIT5k/data.mdb
+ - ../test/IIIT5k/lock.mdb
+ # COCOv1.4
+ - ../test/COCOv1.4/data.mdb
+ - ../test/COCOv1.4/lock.mdb
+ # IC15_1811
+ - ../test/IC15_1811/data.mdb
+ - ../test/IC15_1811/lock.mdb
+ # Uber
+ - ../test/Uber/data.mdb
+ - ../test/Uber/lock.mdb
+ # IC13_1095
+ - ../test/IC13_1095/data.mdb
+ - ../test/IC13_1095/lock.mdb
+ # IC13_1015
+ - ../test/IC13_1015/data.mdb
+ - ../test/IC13_1015/lock.mdb
+ # ArT
+ - ../test/ArT/data.mdb
+ - ../test/ArT/lock.mdb
+ # SVT
+ - ../test/SVT/data.mdb
+ - ../test/SVT/lock.mdb
+ # CUTE80
+ - ../test/CUTE80/data.mdb
+ - ../test/CUTE80/lock.mdb
+check_validity: true
diff --git a/configs/dataset/rec/textocr.yaml b/configs/dataset/rec/textocr.yaml
new file mode 100644
index 0000000..abfb4b7
--- /dev/null
+++ b/configs/dataset/rec/textocr.yaml
@@ -0,0 +1,13 @@
+root: ../TextOCR
+task: str
+download_links:
+ # train
+ - https://drive.usercontent.google.com/download?id=1jVjJFno4pnsU0Cp_kn4MIXQrChmELy92&authuser=0&confirm=
+ # val
+ - https://drive.usercontent.google.com/download?id=1ubIRu01MXIek6OvInu-XjaIbw6277-vw&authuser=0&confirm=t
+filenames:
+ # train
+ - ../TextOCR/train.zip
+ # val
+ - ../TextOCR/val.zip
+check_validity: true
diff --git a/configs/dataset/rec/textocr_horizontal.yaml b/configs/dataset/rec/textocr_horizontal.yaml
new file mode 100644
index 0000000..1ccec04
--- /dev/null
+++ b/configs/dataset/rec/textocr_horizontal.yaml
@@ -0,0 +1,13 @@
+root: ../TextOCR_horizontal
+task: str
+download_links:
+ # train
+ - https://drive.usercontent.google.com/download?id=1sWH6J11xbjQb8SH7fdG_8mIKVI81ZQy5&authuser=0&confirm=
+ # val
+ - https://drive.usercontent.google.com/download?id=1gIE-AU2o-5hvg288-bjphO6UkI5AEQ2d&authuser=0&confirm=t
+filenames:
+ # train
+ - ../TextOCR_horizontal/train.zip
+ # val
+ - ../TextOCR_horizontal/val.zip
+check_validity: true
diff --git a/configs/dataset/rec/union14m_b.yaml b/configs/dataset/rec/union14m_b.yaml
new file mode 100644
index 0000000..7479207
--- /dev/null
+++ b/configs/dataset/rec/union14m_b.yaml
@@ -0,0 +1,47 @@
+root: ../u14m
+task: str
+download_links:
+ # artistic
+ - https://drive.usercontent.google.com/download?id=1Je2DTuFHnkXDI99yDnm9Anl5naWaCQwd&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=1xtT_Q0juBJUIvAG55qBxoVNNTECd2usZ&authuser=0&confirm=t
+ # contextless
+ - https://drive.usercontent.google.com/download?id=1_0OzyzWhZOmGrHkayFTVrzhrQrNRDRPR&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=1PPgC42y3xoM9bR0HQFbDYbcT3PzMdD_y&authuser=0&confirm=t
+ # salient
+ - https://drive.usercontent.google.com/download?id=1tHLMYBmTqRnxvFOTT3dfLfQiundqFWfd&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=13NQgpAtCK0kh9M5E2pAUmKKEp6Qu5Xwj&authuser=0&confirm=t
+ # multi_words
+ - https://drive.usercontent.google.com/download?id=1IlnDKX3V_Vp9gsDGFB0xoqsVLH1vtxUI&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=1mFFjC7C0CwevvkwFU9YeVbZBdps_3Qpb&authuser=0&confirm=t
+ # curve
+ - https://drive.usercontent.google.com/download?id=1MxhMd85cmhUtI2lmtXhZQuFk7lav0_fw&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=1N03g-4e-kJG2mRvlM0c5TrwWAkd-iG-Q&authuser=0&confirm=t
+ # general
+ - https://drive.usercontent.google.com/download?id=1Oqt7OaycP466NWoDmoJ3FqS8YP3YRgvu&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=1K0MrX5eYNt8IIGFHXCwg0_oI5OF5PPFO&authuser=0&confirm=t
+ # multi_oriented
+ - https://drive.usercontent.google.com/download?id=1TKZFcZPVk0ThqfF-AGhJk_OCLg0ykKbv&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=1PAoLMUWuR7O2-7XRoKkNzQcSiznErQzD&authuser=0&confirm=t
+filenames:
+ # artistic
+ - ../u14m/artistic/data.mdb
+ - ../u14m/artistic/lock.mdb
+ # contextless
+ - ../u14m/contextless/data.mdb
+ - ../u14m/contextless/lock.mdb
+ # salient
+ - ../u14m/salient/data.mdb
+ - ../u14m/salient/lock.mdb
+ # multi_words
+ - ../u14m/multi_words/data.mdb
+ - ../u14m/multi_words/lock.mdb
+ # curve
+ - ../u14m/curve/data.mdb
+ - ../u14m/curve/lock.mdb
+ # general
+ - ../u14m/general/data.mdb
+ - ../u14m/general/lock.mdb
+ # multi_oriented
+ - ../u14m/multi_oriented/data.mdb
+ - ../u14m/multi_oriented/lock.mdb
+check_validity: true
diff --git a/configs/dataset/rec/union14m_l_filtered.yaml b/configs/dataset/rec/union14m_l_filtered.yaml
new file mode 100644
index 0000000..86ab60e
--- /dev/null
+++ b/configs/dataset/rec/union14m_l_filtered.yaml
@@ -0,0 +1,35 @@
+root: ../Union14M-L-LMDB-Filtered
+task: str
+download_links:
+ # train_challenging
+ - https://drive.usercontent.google.com/download?id=1etwzBgGHjsFsb0sygsaRnKbanW2PMe07&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=1ly6FJfPjItwGlVQ-ifTrzzM3rVu3Ezhr&authuser=0&confirm=t
+ # train_easy
+ - https://drive.usercontent.google.com/download?id=1_zeNluTnywIaa5h3PN-Ah9tKyByypot7&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=1caYLeQHDidXgVBDi9IWXbO1gg__DYq9a&authuser=0&confirm=t
+ # train_hard
+ - https://drive.usercontent.google.com/download?id=1eP6s2xyYPZX9gykvWA4VSOc3Fqul_UB_&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=1-ZlCvocX8P5uVRclUXp_5DNGLDzd16EO&authuser=0&confirm=t
+ # train_medium
+ - https://drive.usercontent.google.com/download?id=1s_CoaLNJEr-UxHYiqZ5jOcliMCFiRUUy&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=1Wpj6WVpZ5Ily77kVwfQ18CiZBzkgmEnF&authuser=0&confirm=t
+ # train_normal
+ - https://drive.usercontent.google.com/download?id=1jPt44arlAswl9cXZjzmVcdpptdTPpJ3I&authuser=0&confirm=t
+ - https://drive.usercontent.google.com/download?id=1Rfc5kE03AzOUv7B_eYcBhUV8KMQ2MZ1m&authuser=0&confirm=t
+filenames:
+ # train_challenging
+ - ../Union14M-L-LMDB-Filtered/train_challenging/data.mdb
+ - ../Union14M-L-LMDB-Filtered/train_challenging/lock.mdb
+ # train_easy
+ - ../Union14M-L-LMDB-Filtered/train_easy/data.mdb
+ - ../Union14M-L-LMDB-Filtered/train_easy/lock.mdb
+ # train_hard
+ - ../Union14M-L-LMDB-Filtered/train_hard/data.mdb
+ - ../Union14M-L-LMDB-Filtered/train_hard/lock.mdb
+ # train_medium
+ - ../Union14M-L-LMDB-Filtered/train_medium/data.mdb
+ - ../Union14M-L-LMDB-Filtered/train_medium/lock.mdb
+ # train_normal
+ - ../Union14M-L-LMDB-Filtered/train_normal/data.mdb
+ - ../Union14M-L-LMDB-Filtered/train_normal/lock.mdb
+check_validity: true
diff --git a/configs/det/dbnet/repvit_db.yml b/configs/det/dbnet/repvit_db.yml
index c9b1bc1..689f655 100644
--- a/configs/det/dbnet/repvit_db.yml
+++ b/configs/det/dbnet/repvit_db.yml
@@ -53,7 +53,7 @@ Architecture:
PostProcess:
name: DBPostProcess
thresh: 0.3
- box_thresh: 0.4
+ box_thresh: 0.6
max_candidates: 1000
unclip_ratio: 1.5
score_mode: 'slow'
diff --git a/configs/rec/dptr/dptr_parseq_pretrain.yml b/configs/rec/dptr/dptr_parseq_pretrain.yml
new file mode 100644
index 0000000..32b8adf
--- /dev/null
+++ b/configs/rec/dptr/dptr_parseq_pretrain.yml
@@ -0,0 +1,88 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: /share/ckpt/zhaoshuai/openocr/dptr_parseq/
+ eval_epoch_step: [0, 1]
+ eval_batch_step: [0, 500]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ use_amp: True
+ save_res_path: /share/ckpt/zhaoshuai/openocr/dptr_parseq/predicts_dptr_parseq.txt
+ grad_clip_val: 20
+
+Optimizer:
+ name: AdamW
+ lr: 0.001485 # 2gpus 384bs/gpu
+ weight_decay: 0.
+ filter_bias_and_bn: False
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: DPTR
+ Decoder:
+ name: DptrParseq
+ decode_ar: True
+ refine_iters: 1
+ is_pretrain: True
+ ORP_path: /share/ckpt/zhaoshuai/parseq/clip_background.pth
+
+Loss:
+ name: PARSeqLoss
+
+PostProcess:
+ name: ARLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: TextLMDBDataSet
+ data_dir: /share/test/zhaoshuai/parseq-data/data/train/real/ArT
+ transforms:
+ - DPTRLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['clip_label', 'label'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: TextLMDBDataSet
+ data_dir: /share/test/zhaoshuai/parseq-data/data/val
+ transforms:
+ - DPTRLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['clip_label', 'label'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 2
diff --git a/configs/rec/igtr/readme.md b/configs/rec/igtr/readme.md
index 61c7d33..dc420f7 100644
--- a/configs/rec/igtr/readme.md
+++ b/configs/rec/igtr/readme.md
@@ -13,11 +13,12 @@
Paper:
-> [Instruction-Guided Scene Text Recognition](https://arxiv.org/abs/2401.17851)
-> Yongkun Du, Zhineng Chen, Yuchen Su, Caiyan Jia, Yu-Gang Jiang
+> [Instruction-Guided Scene Text Recognition](https://arxiv.org/abs/2401.17851),
+> Yongkun Du, Zhineng Chen, Yuchen Su, Caiyan Jia, Yu-Gang Jiang,
+> TPAMI
-Multi-modal models show appealing performance in visual recognition tasks recently, as free-form text-guided training evokes the ability to understand fine-grained visual content. However, current models are either inefficient or cannot be trivially upgraded to scene text recognition (STR) due to the composition difference between natural and text images. We propose a novel instruction-guided scene text recognition (IGTR) paradigm that formulates STR as an instruction learning problem and understands text images by predicting character attributes, e.g., character frequency, position, etc. IGTR first devises $\\left \\langle condition,question,answer\\right \\rangle$ instruction triplets, providing rich and diverse descriptions of character attributes. To effectively learn these attributes through question-answering, IGTR develops lightweight instruction encoder, cross-modal feature fusion module and multi-task answer head, which guides nuanced text image understanding. Furthermore, IGTR realizes different recognition pipelines simply by using different instructions, enabling a character-understanding-based text reasoning paradigm that considerably differs from current methods. Experiments on English and Chinese benchmarks show that IGTR outperforms existing models by significant margins, while maintaining a small model size and efficient inference speed. Moreover, by adjusting the sampling of instructions, IGTR offers an elegant way to tackle the recognition of both rarely appearing and morphologically similar characters, which were previous challenges.
+Multi-modal models have shown appealing performance in visual recognition tasks, as free-form text-guided training evokes the ability to understand fine-grained visual content. However, current models cannot be trivially applied to scene text recognition (STR) due to the compositional difference between natural and text images. We propose a novel instruction-guided scene text recognition (IGTR) paradigm that formulates STR as an instruction learning problem and understands text images by predicting character attributes, e.g., character frequency, position, etc. IGTR first devises $\\left \\langle condition,question,answer\\right \\rangle$ instruction triplets, providing rich and diverse descriptions of character attributes. To effectively learn these attributes through question-answering, IGTR develops a lightweight instruction encoder, a cross-modal feature fusion module and a multi-task answer head, which guides nuanced text image understanding. Furthermore, IGTR realizes different recognition pipelines simply by using different instructions, enabling a character-understanding-based text reasoning paradigm that differs from current methods considerably. Experiments on English and Chinese benchmarks show that IGTR outperforms existing models by significant margins, while maintaining a small model size and fast inference speed. Moreover, by adjusting the sampling of instructions, IGTR offers an elegant way to tackle the recognition of rarely appearing and morphologically similar characters, which were previous challenges.
The accuracy (%) and model files of IGTR on the public dataset of scene text recognition are as follows:
@@ -88,11 +89,11 @@ pip install -r requirements.txt
#### Dataset Preparation
-[English dataset download](https://github.com/baudm/parseq)
+- [English dataset download](https://github.com/baudm/parseq)
-[Union14M-L download](https://github.com/Mountchicken/Union14M)
+- [Union14M-L-LMDB-Filtered download](https://drive.google.com/drive/folders/1OlDWJZgvd6s4S09S3IGeAI90jI0i7AB_?usp=sharing)
-[Chinese dataset download](https://github.com/fudanvi/benchmarking-chinese-text-recognition#download)
+- [Chinese dataset download](https://github.com/fudanvi/benchmarking-chinese-text-recognition#download)
The expected filesystem structure is as follows:
@@ -143,7 +144,7 @@ u14m # lmdb format
├── multi_oriented
├── multi_words
└── salient
-Union14M-LMDB-L # lmdb format
+Union14M-L-LMDB-Filtered # lmdb format
├── train_challenging
├── train_easy
├── train_hard
@@ -168,13 +169,15 @@ Evaluation:
```shell
# The configuration file is available from the link provided in the table above.
# en
-python tools/eval_rec_all_ratio.py --c PATH/svtr_base_igtr_syn.yml
+python tools/eval_rec_all_en.py --c PATH/svtr_base_igtr_syn.yml
# ch
python tools/eval_rec_all_ch.py --c PATH/svtr_base_igtr_ch_aug.yml
```
## Citation
+If you find our method useful for your research, please cite:
+
```bibtex
@article{Du2024IGTR,
title = {Instruction-Guided Scene Text Recognition},
diff --git a/configs/rec/smtr/readme.md b/configs/rec/smtr/readme.md
index ddda487..52bbedd 100644
--- a/configs/rec/smtr/readme.md
+++ b/configs/rec/smtr/readme.md
@@ -13,8 +13,9 @@
Paper:
-> [Out of Length Text Recognition with Sub-String Matching](https://arxiv.org/abs/2407.12317)
-> Yongkun Du, Zhineng Chen\*, Caiyan Jia, Xieping Gao, Yu-Gang Jiang
+> [Out of Length Text Recognition with Sub-String Matching](https://arxiv.org/abs/2407.12317).
+> Yongkun Du, Zhineng Chen\*, Caiyan Jia, Xieping Gao, Yu-Gang Jiang.
+> AAAI 2025
Scene Text Recognition (STR) methods have demonstrated robust performance in word-level text recognition. However, in applications the text image is sometimes long because it is detected with multiple horizontal words. This triggers the requirement to build long text recognition models from readily available short word-level text datasets, which has been less studied previously. In this paper, we term this the Out of Length (OOL) text recognition. We establish the first Long Text Benchmark (LTB) to facilitate the assessment of different methods in long text recognition. Meanwhile, we propose a novel method called OOL Text Recognition with sub-String Matching (SMTR). SMTR comprises two cross-attention-based modules: one encodes a sub-string containing multiple characters into next and previous queries, and the other employs the queries to attend to the image features, matching the sub-string and simultaneously recognizing its next and previous character. SMTR can recognize text of arbitrary length by iterating the process above. To avoid being trapped in recognizing highly similar sub-strings, we introduce a regularization training to compel SMTR to effectively discover subtle differences between similar sub-strings for precise matching. In addition, we propose an inference augmentation to alleviate confusion caused by identical sub-strings and improve the overall recognition efficiency. Extensive experimental results reveal that SMTR, even when trained exclusively on short text, outperforms existing methods in public short text benchmarks and exhibits a clear advantage on LTB.
@@ -23,31 +24,31 @@ The accuracy (%) and model files of SMTR on the public dataset of scene text rec
- Syn: Synth dataset (MJ+ST) from [PARSeq](https://github.com/baudm/parseq)
-- U14M: Union14M-L from [Union14M](https://github.com/Mountchicken/Union14M/)
+- Union14M-L-LMDB-Filtered: A filtered version of [Union14M](https://github.com/Mountchicken/Union14M/)
- Test on Long Text Benchmark ([Download LTB](https://drive.google.com/drive/folders/1NChdlw7ustbXtlFBmh_0xnHvRkffb9Ge?usp=sharing)):
-| Model | Training Data | LTB | Config&Model&Log |
-| :-------: | :-----------: | :--: | :---------------------------------------------------------------------------------------------: |
-| SMTR | Syn | 39.6 | [link](https://drive.google.com/drive/folders/11SplakPPOFDMhPixv7ABNgjeTg4jKyfU?usp=sharing) |
-| SMTR | U14M | 51.0 | [link](https://drive.google.com/drive/folders/1-K5O0d0q9fhY5fJvU6nn5fFFtSMnbE_-?usp=drive_link) |
-| FocalSVTR | U14M | 42.1 | [link](https://drive.google.com/drive/folders/100xF5wFr7xSCVBYM1h_0d_8xv5Qeqobp?usp=sharing) |
+| Model | Training Data | LTB | Config&Model&Log |
+| :-------: | :----------------------: | :--: | :---------------------------------------------------------------------------------------------: |
+| SMTR | Syn | 39.6 | [link](https://drive.google.com/drive/folders/11SplakPPOFDMhPixv7ABNgjeTg4jKyfU?usp=sharing) |
+| SMTR | Union14M-L-LMDB-Filtered | 51.0 | [link](https://drive.google.com/drive/folders/1-K5O0d0q9fhY5fJvU6nn5fFFtSMnbE_-?usp=drive_link) |
+| FocalSVTR | Union14M-L-LMDB-Filtered | 42.1 | [link](https://drive.google.com/drive/folders/100xF5wFr7xSCVBYM1h_0d_8xv5Qeqobp?usp=sharing) |
- Test on Common Benchmarks from [PARSeq](https://github.com/baudm/parseq):
-| Model | Training Data | IC13<br/>857 | SVT | IIIT5k<br/>3000 | IC15<br/>1811 | SVTP | CUTE80 | Avg | Config&Model&Log |
-| :-------: | :-----------: | :----------: | :--: | :-------------: | :-----------: | :--: | :----: | :---: | :---------------------: |
-| SMTR | Syn | 97.4 | 94.9 | 97.4 | 88.4 | 89.9 | 96.2 | 94.02 | Same as the above table |
-| SMTR | U14M | 98.3 | 97.4 | 99.0 | 90.1 | 92.7 | 97.9 | 95.90 | Same as the above table |
-| FocalSVTR | U14M | 97.3 | 96.3 | 98.2 | 87.4 | 88.4 | 96.2 | 93.97 | Same as the above table |
+| Model | Training Data | IC13<br/>857 | SVT | IIIT5k<br/>3000 | IC15<br/>1811 | SVTP | CUTE80 | Avg | Config&Model&Log |
+| :-------: | :----------------------: | :----------: | :--: | :-------------: | :-----------: | :--: | :----: | :---: | :---------------------: |
+| SMTR | Syn | 97.4 | 94.9 | 97.4 | 88.4 | 89.9 | 96.2 | 94.02 | Same as the above table |
+| SMTR | Union14M-L-LMDB-Filtered | 98.3 | 97.4 | 99.0 | 90.1 | 92.7 | 97.9 | 95.90 | Same as the above table |
+| FocalSVTR | Union14M-L-LMDB-Filtered | 97.3 | 96.3 | 98.2 | 87.4 | 88.4 | 96.2 | 93.97 | Same as the above table |
- Test on Union14M-L benchmark from [Union14M](https://github.com/Mountchicken/Union14M/).
-| Model | Traing Data | Curve | Multi-<br/>Oriented | Artistic | Contextless | Salient | Multi-<br/>word | General | Avg | Config&Model&Log |
-| :-------: | :---------: | :---: | :-----------------: | :------: | :---------: | :-----: | :-------------: | :-----: | :---: | :---------------------: |
-| SMTR | Syn | 74.2 | 30.6 | 58.5 | 67.6 | 79.6 | 75.1 | 67.9 | 64.79 | Same as the above table |
-| SMTR | U14M | 89.1 | 87.7 | 76.8 | 83.9 | 84.6 | 89.3 | 83.7 | 85.00 | Same as the above table |
-| FocalSVTR | U14M | 77.7 | 62.4 | 65.7 | 78.6 | 71.6 | 81.3 | 79.2 | 73.80 | Same as the above table |
+| Model | Training Data | Curve | Multi-<br/>Oriented | Artistic | Contextless | Salient | Multi-<br/>word | General | Avg | Config&Model&Log |
+| :-------: | :----------------------: | :---: | :-----------------: | :------: | :---------: | :-----: | :-------------: | :-----: | :---: | :---------------------: |
+| SMTR | Syn | 74.2 | 30.6 | 58.5 | 67.6 | 79.6 | 75.1 | 67.9 | 64.79 | Same as the above table |
+| SMTR | Union14M-L-LMDB-Filtered | 89.1 | 87.7 | 76.8 | 83.9 | 84.6 | 89.3 | 83.7 | 85.00 | Same as the above table |
+| FocalSVTR | Union14M-L-LMDB-Filtered | 77.7 | 62.4 | 65.7 | 78.6 | 71.6 | 81.3 | 79.2 | 73.80 | Same as the above table |
- Training and testing on the Chinese dataset, from the [Chinese Benchmark](https://github.com/FudanVI/benchmarking-chinese-text-recognition).
@@ -79,7 +80,7 @@ pip install -r requirements.txt
- [English dataset download](https://github.com/baudm/parseq)
-- [Union14M-L download](https://github.com/Mountchicken/Union14M)
+- [Union14M-L-LMDB-Filtered download](https://drive.google.com/drive/folders/1OlDWJZgvd6s4S09S3IGeAI90jI0i7AB_?usp=sharing)
- [Chinese dataset download](https://github.com/fudanvi/benchmarking-chinese-text-recognition#download)
@@ -135,7 +136,7 @@ u14m # lmdb format
├── multi_words
└── salient
ltb # download link: https://drive.google.com/drive/folders/1NChdlw7ustbXtlFBmh_0xnHvRkffb9Ge?usp=sharing
-Union14M-LMDB-L # lmdb format
+Union14M-L-LMDB-Filtered # lmdb format
├── train_challenging
├── train_easy
├── train_hard
@@ -160,7 +161,7 @@ Evaluation:
```shell
# en
-python tools/eval_rec_all_ratio.py --c configs/rec/smtr/focalsvtr_smtr.yml
+python tools/eval_rec_all_en.py --c configs/rec/smtr/focalsvtr_smtr.yml
# long text
python tools/eval_rec_all_long_simple.py --c configs/rec/smtr/focalsvtr_smtr_long.yml
# ch
@@ -169,6 +170,8 @@ python tools/eval_rec_all_ch.py --c configs/rec/smtr/focalsvtr_smtr_ch.yml
## Citation
+If you find our method useful for your research, please cite:
+
```bibtex
@article{Du2024SMTR,
title = {Out of Length Text Recognition with Sub-String Matching},
diff --git a/configs/rec/svtrv2/SVTRv2.pdf b/configs/rec/svtrv2/SVTRv2.pdf
deleted file mode 100644
index 0be5128..0000000
Binary files a/configs/rec/svtrv2/SVTRv2.pdf and /dev/null differ
diff --git a/configs/rec/svtrv2/readme.md b/configs/rec/svtrv2/readme.md
index aade0b5..31abbf6 100644
--- a/configs/rec/svtrv2/readme.md
+++ b/configs/rec/svtrv2/readme.md
@@ -18,7 +18,7 @@
Paper:
-> [SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition](./SVTRv2.pdf)
+> [SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition](https://arxiv.org/abs/2411.15858)
> Yongkun Du, Zhineng Chen\*, Hongtao Xie, Caiyan Jia, Yu-Gang Jiang
@@ -30,11 +30,11 @@ The accuracy (%) and model files of SVTRv2 on the public dataset of scene text r
Download all Configs, Models, and Logs from [Google Drive](https://drive.google.com/drive/folders/1i2EZVT-oxfDIDdhwQRm9E6Fk8s6qD3C1?usp=sharing).
-| Model | Model size | FPS |
-| :------: | :--------: | :-: |
-| SVTRv2-T | 5.13 | 5.0 |
-| SVTRv2-S | 11.25 | 5.3 |
-| SVTRv2-B | 19.76 | 7.0 |
+| Model | Model size (M) | Latency (ms) |
+| :------: | :--------: | :-----: |
+| SVTRv2-T | 5.13 | 5.0 |
+| SVTRv2-S | 11.25 | 5.3 |
+| SVTRv2-B | 19.76 | 7.0 |
- Test on Common Benchmarks from [PARSeq](https://github.com/baudm/parseq):
@@ -102,10 +102,10 @@ Referring to [Downloading Datasets](../../../docs/svtrv2.md#downloading-datasets
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 tools/train_rec.py --c configs/rec/svtrv2/svtrv2_rctc.yml
# Second stage
-CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 tools/train_rec.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml --o Global.pretrained_model=./output/rec/u14m_filter/svtrv2_rctc/best.pth
+CUDA_VISIBLE_DEVICES=4,5,6,7 python -m torch.distributed.launch --master_port=23332 --nproc_per_node=4 tools/train_rec.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml --o Global.pretrained_model=./output/rec/u14m_filter/svtrv2_rctc/best.pth
# For Multi RTX 4090
-NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 tools/train_rec.py --c configs/rec/svtrv2/svtrv2_rctc.yml
+NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --master_port=23333 --nproc_per_node=4 tools/train_rec.py --c configs/rec/svtrv2/svtrv2_rctc.yml
# 20epoch runs for about 6 hours
```
@@ -113,9 +113,10 @@ NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.laun
```shell
# short text: Common, Union14M-Benchmark, OST
-python tools/eval_rec_all_ratio.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml
-# long text
-python tools/eval_rec_all_long_simple.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml
+python tools/eval_rec_all_en.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_infer.yml
+
+# long text: LTB
+python tools/eval_rec_all_long.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_infer.yml --o Eval.loader.max_ratio=20
```
After a successful run, the results are saved in a csv file in `output_dir` in the config file.
@@ -123,7 +124,7 @@ After a successful run, the results are saved in a csv file in `output_dir` in t
### Inference
```shell
-python tools/infer_rec.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml --o Global.infer_img=/path/img_fold or /path/img_file
+python tools/infer_rec.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_infer.yml --o Global.infer_img=/path/img_fold or /path/img_file
```
### Latency Measurement
@@ -131,15 +132,22 @@ python tools/infer_rec.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml --o Gl
First, download the IIIT5K images from [Google Drive](https://drive.google.com/drive/folders/1Po1LSBQb87DxGJuAgLNxhsJ-pdXxpIfS?usp=drive_link). Then, run the following command:
```shell
-python tools/infer_rec.py --c configs/rec/SVTRv2/svtrv2_rctc.yml --o Global.infer_img=../iiit5k_test_image
+python tools/infer_rec.py --c configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_infer.yml --o Global.infer_img=../iiit5k_test_image
```
## Citation
+If you find our method useful for your research, please cite:
+
```bibtex
-@article{Du2024SVTRv4,
- title = {SVTRv2: Scene Text Recognition with a Single Visual Model},
- author = {Yongkun Du, Zhineng Chen\*, Hongtao Xie, Caiyan Jia, Yu-Gang Jiang},
- year = {2024}
+@article{Du2024SVTRv2,
+ title={SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition},
+ author={Yongkun Du and Zhineng Chen and Hongtao Xie and Caiyan Jia and Yu-Gang Jiang},
+ journal={CoRR},
+ volume={abs/2411.15858},
+ eprinttype={arXiv},
+ year={2024},
+ primaryClass={cs.CV},
+ url={https://arxiv.org/abs/2411.15858}
}
```
diff --git a/configs/rec/svtrv2/repsvtr_ch.yml b/configs/rec/svtrv2/repsvtr_ch.yml
index 033d6e7..fd68312 100644
--- a/configs/rec/svtrv2/repsvtr_ch.yml
+++ b/configs/rec/svtrv2/repsvtr_ch.yml
@@ -34,7 +34,7 @@ LRScheduler:
Architecture:
model_type: rec
- algorithm: SVTRv2
+ algorithm: SVTRv2_mobile
Transform:
Encoder:
name: RepSVTREncoder
@@ -53,6 +53,7 @@ Loss:
PostProcess:
name: CTCLabelDecode
+ character_dict_path: *character_dict_path
Metric:
name: RecMetric
diff --git a/configs/rec/svtrv2/svtrv2_ch.yml b/configs/rec/svtrv2/svtrv2_ch.yml
index a6538df..6ee2e8f 100644
--- a/configs/rec/svtrv2/svtrv2_ch.yml
+++ b/configs/rec/svtrv2/svtrv2_ch.yml
@@ -34,7 +34,7 @@ LRScheduler:
Architecture:
model_type: rec
- algorithm: SVTRv2
+ algorithm: SVTRv2_server
Transform:
Encoder:
name: SVTRv2LNConvTwo33
@@ -65,6 +65,7 @@ Loss:
PostProcess:
name: CTCLabelDecode
+ character_dict_path: *character_dict_path
Metric:
name: RecMetric
diff --git a/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml b/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml
index aa36492..20a9d7c 100644
--- a/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml
+++ b/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc.yml
@@ -3,7 +3,7 @@ Global:
epoch_num: 20
log_smooth_window: 20
print_batch_step: 10
- output_dir: ./output/rec/u14m_filter/svtrv2_smtr_gtc_rctc_maxratio12
+ output_dir: ./output/rec/u14m_filter/svtrv2_smtr_gtc_rctc
save_epoch_step: 1
# evaluation is run every 2000 iterations
eval_batch_step: [0, 500]
diff --git a/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_infer.yml b/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_infer.yml
new file mode 100644
index 0000000..2361eb9
--- /dev/null
+++ b/configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_infer.yml
@@ -0,0 +1,151 @@
+Global:
+ device: gpu
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ output_dir: ./output/rec/u14m_filter/svtrv2_smtr_gtc_rctc
+ save_epoch_step: 1
+  # evaluation is run every 500 iterations
+ eval_batch_step: [0, 500]
+ eval_epoch_step: [0, 1]
+ cal_metric_during_train: True
+ pretrained_model:
+ # ./output/rec/u14m_filter/svtrv2_rctc/best.pth
+ checkpoints:
+ use_tensorboard: false
+ infer_img:
+ # for data or label process
+ character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt # 96en
+ # ./tools/utils/ppocr_keys_v1.txt # ch
+ max_text_length: &max_text_length 25
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/u14m_filter/predicts_svtrv2_smtr_gtc_rctc.txt
+ use_amp: True
+
+Optimizer:
+ name: AdamW
+ lr: 0.000325 # for 4gpus bs128/gpu
+ weight_decay: 0.05
+ filter_bias_and_bn: True
+
+LRScheduler:
+ name: OneCycleLR
+ warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep
+ cycle_momentum: False
+
+Architecture:
+ model_type: rec
+ algorithm: SVTRv2
+ in_channels: 3
+ Transform:
+ Encoder:
+ name: SVTRv2LNConvTwo33
+ use_pos_embed: False
+ dims: [128, 256, 384]
+ depths: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','FGlobal','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
+ local_k: [[5, 5], [5, 5], [-1, -1]]
+ sub_k: [[1, 1], [2, 1], [-1, -1]]
+ last_stage: false
+ feat2d: True
+ Decoder:
+ name: GTCDecoder
+ infer_gtc: False
+ detach: False
+ gtc_decoder:
+ name: SMTRDecoder
+ num_layer: 1
+ ds: True
+ max_len: *max_text_length
+ next_mode: &next True
+ sub_str_len: &subsl 5
+ ctc_decoder:
+ name: RCTCDecoder
+
+Loss:
+ name: CTCLoss
+ zero_infinity: True
+
+PostProcess:
+ name: CTCLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+ is_filter: True
+
+Train:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: false
+ data_dir_list: ['../Union14M-L-LMDB-Filtered/filter_train_challenging',
+ '../Union14M-L-LMDB-Filtered/filter_train_hard',
+ '../Union14M-L-LMDB-Filtered/filter_train_medium',
+ '../Union14M-L-LMDB-Filtered/filter_train_normal',
+ '../Union14M-L-LMDB-Filtered/filter_train_easy',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - PARSeqAugPIL:
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+      # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: &bs 128
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: *bs
+ drop_last: True
+ max_ratio: &max_ratio 12
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSetTVResize
+ ds_width: True
+ padding: False
+ data_dir_list: [
+ '../evaluation/CUTE80',
+ '../evaluation/IC13_857',
+ '../evaluation/IC15_1811',
+ '../evaluation/IIIT5k',
+ '../evaluation/SVT',
+ '../evaluation/SVTP',
+ ]
+ transforms:
+ - DecodeImagePIL: # load image
+ img_mode: RGB
+ - CTCLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+      # divided_factor: to ensure the width and height dimensions can be divided by the downsampling multiple
+ first_bs: *bs
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: *bs
+ max_ratio: *max_ratio
+ num_workers: 4
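For completeness, this new inference config can also be consumed programmatically, following the same pattern as the `demo_rec.py` script removed later in this diff. The sketch below assumes that `Global.pretrained_model` or `Global.checkpoints` points to a trained `.pth` file:

```python
from openrec.modeling import build_model
from openrec.postprocess import build_post_process
from tools.engine import Config
from tools.utils.ckpt import load_ckpt

cfg = Config('configs/rec/svtrv2/svtrv2_smtr_gtc_rctc_infer.yml').cfg

# The decoder output size is derived from the character dictionary.
post_process_class = build_post_process(cfg['PostProcess'])
cfg['Architecture']['Decoder']['out_channels'] = len(
    getattr(post_process_class, 'character'))

model = build_model(cfg['Architecture'])
load_ckpt(model, cfg)  # assumes a trained checkpoint is configured in Global
model.eval()
```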
diff --git a/demo_gradio.py b/demo_gradio.py
new file mode 100644
index 0000000..a40ba0e
--- /dev/null
+++ b/demo_gradio.py
@@ -0,0 +1,184 @@
+# @Author: OpenOCR
+# @Contact: 784990967@qq.com
+import os
+import gradio as gr # gradio==4.20.0
+
+os.environ['FLAGS_allocator_strategy'] = 'auto_growth'
+import cv2
+import numpy as np
+import json
+import time
+from PIL import Image
+from tools.infer_e2e import OpenOCR, check_and_download_font, draw_ocr_box_txt
+
+drop_score = 0.4
+text_sys = OpenOCR(drop_score=drop_score)
+# warm up 5 times
+img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
+for i in range(5):
+    res = text_sys(img_numpy=img)
+font_path = './simfang.ttf'
+font_path = check_and_download_font(font_path)
+
+
+def main(input_image,
+ det_input_size_textbox=960,
+ rec_drop_score=0.01,
+ mask_thresh=0.3,
+ box_thresh=0.6,
+ unclip_ratio=1.5,
+ det_score_mode='slow'):
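+    """Run the OpenOCR end-to-end pipeline on a single input image.
+
+    Returns the JSON-serialized detection/recognition results, the elapsed
+    time in seconds, the visualization image, and the binarized detection mask.
+    """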
+ img = input_image[:, :, ::-1]
+ starttime = time.time()
+ results, time_dict, mask = text_sys(
+ img_numpy=img,
+ return_mask=True,
+ det_input_size=int(det_input_size_textbox),
+ thresh=mask_thresh,
+ box_thresh=box_thresh,
+ unclip_ratio=unclip_ratio,
+ score_mode=det_score_mode)
+ elapse = time.time() - starttime
+ save_pred = json.dumps(results[0], ensure_ascii=False)
+ image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+ boxes = [res['points'] for res in results[0]]
+ txts = [res['transcription'] for res in results[0]]
+ scores = [res['score'] for res in results[0]]
+ draw_img = draw_ocr_box_txt(
+ image,
+ boxes,
+ txts,
+ scores,
+ drop_score=rec_drop_score,
+ font_path=font_path,
+ )
+ mask = mask[0, 0, :, :] > mask_thresh
+ return save_pred, elapse, draw_img, mask.astype('uint8') * 255
+
+
+def get_all_file_names_including_subdirs(dir_path):
+ all_file_names = []
+
+ for root, dirs, files in os.walk(dir_path):
+ for file_name in files:
+ all_file_names.append(os.path.join(root, file_name))
+
+ file_names_only = [os.path.basename(file) for file in all_file_names]
+ return file_names_only
+
+
+def list_image_paths(directory):
+ image_extensions = ('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff')
+
+ image_paths = []
+
+ for root, dirs, files in os.walk(directory):
+ for file in files:
+ if file.lower().endswith(image_extensions):
+ relative_path = os.path.relpath(os.path.join(root, file),
+ directory)
+ full_path = os.path.join(directory, relative_path)
+ image_paths.append(full_path)
+ image_paths = sorted(image_paths)
+ return image_paths
+
+
+def find_file_in_current_dir_and_subdirs(file_name):
+ for root, dirs, files in os.walk('.'):
+ if file_name in files:
+ relative_path = os.path.join(root, file_name)
+ return relative_path
+
+
+e2e_img_example = list_image_paths('./OCR_e2e_img')
+
+if __name__ == '__main__':
+ css = '.image-container img { width: 100%; max-height: 320px;}'
+
+ with gr.Blocks(css=css) as demo:
+ gr.HTML("""
+
准确高效的通用 OCR 系统 (由FVL实验室 OCR Team 创建)
""" + ) + with gr.Row(): + with gr.Column(scale=1): + input_image = gr.Image(label='Input image', + elem_classes=['image-container']) + + examples = gr.Examples(examples=e2e_img_example, + inputs=input_image, + label='Examples') + downstream = gr.Button('Run') + + # 添加参数调节组件 + with gr.Column(): + with gr.Row(): + det_input_size_textbox = gr.Number( + label='Det Input Size', + value=960, + info='检测网络输入尺寸的最长边,默认为960。') + det_score_mode_dropdown = gr.Dropdown( + ['slow', 'fast'], + value='slow', + label='Det Score Mode', + info= + '文本框的置信度计算模式,默认为 slow。slow 模式计算速度较慢,但准确度较高。fast 模式计算速度较快,但准确度较低。' + ) + with gr.Row(): + rec_drop_score_slider = gr.Slider( + 0.0, + 1.0, + value=0.4, + step=0.01, + label='Recognition Drop Score', + info='识别置信度阈值,默认值为0.4。低于该阈值的识别结果和对应的文本框被丢弃。') + mask_thresh_slider = gr.Slider( + 0.0, + 1.0, + value=0.3, + step=0.01, + label='Mask Threshold', + info='Mask 阈值,用于二值化 mask,默认值为0.3。如果存在文本截断时,请调低该值。') + with gr.Row(): + box_thresh_slider = gr.Slider( + 0.0, + 1.0, + value=0.6, + step=0.01, + label='Box Threshold', + info='文本框置信度阈值,默认值为0.6。如果存在文本被漏检时,请调低该值。') + unclip_ratio_slider = gr.Slider( + 1.5, + 2.0, + value=1.5, + step=0.05, + label='Unclip Ratio', + info='文本框解析时的膨胀系数,默认值为1.5。值越大文本框越大。') + + with gr.Column(scale=1): + img_mask = gr.Image(label='mask', + interactive=False, + elem_classes=['image-container']) + img_output = gr.Image(label=' ', + interactive=False, + elem_classes=['image-container']) + + output = gr.Textbox(label='Result') + confidence = gr.Textbox(label='Latency') + + downstream.click(fn=main, + inputs=[ + input_image, det_input_size_textbox, + rec_drop_score_slider, mask_thresh_slider, + box_thresh_slider, unclip_ratio_slider, + det_score_mode_dropdown + ], + outputs=[ + output, + confidence, + img_output, + img_mask, + ]) + + demo.launch(share=True) diff --git a/demo_rec.py b/demo_rec.py deleted file mode 100644 index a1c604a..0000000 --- a/demo_rec.py +++ /dev/null @@ -1,131 +0,0 @@ -import os - -import gradio as gr -import numpy as np -import torch - -from openrec.modeling import build_model -from openrec.postprocess import build_post_process -from openrec.preprocess import create_operators, transform -from tools.engine import Config -from tools.utils.ckpt import load_ckpt - - -def build_rec_process(cfg): - transforms = [] - for op in cfg['Eval']['dataset']['transforms']: - op_name = list(op)[0] - if 'Label' in op_name: - continue - # TODO - elif op_name in ['DecodeImage']: - op[op_name]['gradio_infer_mode'] = True - - elif op_name in ['RecResizeImg']: - op[op_name]['infer_mode'] = True - elif op_name == 'KeepKeys': - if cfg['Architecture']['algorithm'] == 'SRN': - op[op_name]['keep_keys'] = [ - 'image', - 'encoder_word_pos', - 'gsrm_word_pos', - 'gsrm_slf_attn_bias1', - 'gsrm_slf_attn_bias2', - ] - elif cfg['Architecture']['algorithm'] == 'SAR': - op[op_name]['keep_keys'] = ['image', 'valid_ratio'] - elif cfg['Architecture']['algorithm'] == 'RobustScanner': - op[op_name]['keep_keys'] = [ - 'image', 'valid_ratio', 'word_positons' - ] - else: - op[op_name]['keep_keys'] = ['image'] - transforms.append(op) - return transforms - - -def get_all_file_names_including_subdirs(dir_path): - all_file_names = [] - - for root, dirs, files in os.walk(dir_path): - for file_name in files: - all_file_names.append(os.path.join(root, file_name)) - - file_names_only = [os.path.basename(file) for file in all_file_names] - return file_names_only - - -root_directory = './configs/rec' -yml_Config = get_all_file_names_including_subdirs(root_directory) - - -def 
find_file_in_current_dir_and_subdirs(file_name): - for root, dirs, files in os.walk('.'): - if file_name in files: - relative_path = os.path.join(root, file_name) - return relative_path - - -def predict(input_image, Model_type, OCR_type): - - path = find_file_in_current_dir_and_subdirs(Model_type) - - cfg = Config(path).cfg - post_process_class = build_post_process(cfg['PostProcess']) - global_config = cfg['Global'] - char_num = len(getattr(post_process_class, 'character')) - cfg['Architecture']['Decoder']['out_channels'] = char_num - model = build_model(cfg['Architecture']) - load_ckpt(model, cfg) - model.eval() - - transforms = build_rec_process(cfg) - global_config['infer_mode'] = True - ops = create_operators(transforms, global_config) - data = {'image': input_image} - batch = transform(data, ops) - others = None - images = np.expand_dims(batch[0], axis=0) - images = torch.from_numpy(images) - with torch.no_grad(): - preds = model(images, others) - post_result = post_process_class(preds) - return post_result[0][0], post_result[0][1] - - -if __name__ == '__main__': - - with gr.Blocks() as demo: - with gr.Row(): - with gr.Column(scale=1): - input_image = gr.Image(label='Input Image') - - # TODO - OCR_type = gr.Radio(['STR', 'STD', 'E2E'], label='模型类别') - - Model_type = gr.Dropdown(choices=yml_Config, label='现有模型配置文件') - - downstream = gr.Button('识别结果') - - with gr.Column(scale=1): - - # TODO - img_output = gr.Image(label='图片识别结果') - - output = gr.Textbox(label='文字识别结果') - confidence = gr.Textbox(label='置信度') - - downstream.click( - fn=predict, - inputs=[ - input_image, - Model_type, - OCR_type, - ], - outputs=[ - output, - confidence, - # TODO img_output, - ]) - - demo.launch(debug=True) diff --git a/docs/openocr.md b/docs/openocr.md index 4129747..2eb7b80 100644 --- a/docs/openocr.md +++ b/docs/openocr.md @@ -1,29 +1,34 @@ -# OpenOCR: A general OCR system for accuracy and efficiency +# OpenOCR: A general OCR system with accuracy and efficiency -We proposed strategies to comprehensively enhance CTC-based STR models and developed a novel CTC-based method, [SVTRv2](../configs/rec/svtrv2/). SVTRv2 can outperform previous attention-based STR methods in terms of accuracy while maintaining the advantages of CTC, such as fast inference and robust recognition of long text sequences. These features make SVTRv2 particularly well-suited for commercial applications. To this end, building on SVTRv2, we develop a practical version of the model from scratch on publicly available Chinese and English datasets. Combined with a detection model, this forms an accurate and efficient general OCR system, OpenOCR. Comparing with PP-OCRv4 released by PaddleOCR, OpenOCR achieve a 4.5% improvement on the [OCR competition leaderboard](https://aistudio.baidu.com/competition/detail/1131/0/leaderboard). +⚡\[[Quick Start](#quick-start)\] \[[Model](https://github.com/Topdu/OpenOCR/releases/tag/develop0.0.1)\] \[[ModelScope Demo](https://modelscope.cn/studios/topdktu/OpenOCR-Demo)\] \[[Hugging Face Demo](https://huggingface.co/spaces/topdu/OpenOCR-Demo)\] \[[Local Demo](#local-demo)\] \[[PaddleOCR Implementation](https://paddlepaddle.github.io/PaddleOCR/latest/algorithm/text_recognition/algorithm_rec_svtrv2.html)\] + +We proposed strategies to comprehensively enhance CTC-based STR models and developed a novel CTC-based method, [SVTRv2](../configs/rec/svtrv2/). 
SVTRv2 can outperform previous attention-based STR methods in terms of accuracy while maintaining the advantages of CTC, such as fast inference and robust recognition of long text. These features make SVTRv2 particularly well-suited for practical applications. To this end, building on SVTRv2, we develop a practical version of the model from scratch on publicly available Chinese and English datasets. Combined with a detection model, this forms a general OCR system with accuracy and efficiency, **OpenOCR**. Comparing with [PP-OCRv4](https://paddlepaddle.github.io/PaddleOCR/latest/ppocr/model_list.html) baseline in the [OCR competition leaderboard](https://aistudio.baidu.com/competition/detail/1131/0/leaderboard), OpenOCR (mobile) achieve a 4.5% improvement in terms of accuracy, while preserving quite similar inference speed on NVIDIA 1080Ti GPU. | Model | Config | E2E Metric | Downloading | | ------------------- | ----------------------------------------------------------------------------------- | ---------- | ---------------------------------------------------------------------------------------- | | PP-OCRv4 | | 62.77% | [PaddleOCR Model List](../../ppocr/model_list.md) | -| SVTRv2 (Rec Server) | [configs/rec/svtrv2/svtrv2_ch.yml](../configs/rec/svtrv2/svtrv2_ch.yml) | 68.81% | [Google Dirve ](https://drive.google.com/file/d/13LXbIVEyx2Aat3X_vVte4JQgQ7yJWdxH/view?usp=drive_link) | -| RepSVTR (Mobile) | [Rec: configs/rec/svtrv2/repsvtr_ch.yml](../configs/rec/svtrv2/repsvtr_ch.yml)Method | @@ -1678,10 +1647,17 @@ Downloading all model checkpoints from [Google Drive](<>) and [Baidu Yun](<>). ## Citation +If you find our method useful for your reserach, please cite: + ```bibtex -@article{Du2024SVTRv4, - title = {SVTRv2: Scene Text Recognition with a Single Visual Model}, - author = {Yongkun Du, Zhineng Chen\*, Hongtao Xie, Caiyan Jia, Yu-Gang Jiang}, - year = {2024} +@article{Du2024SVTRv2, + title={SVTRv2: CTC Beats Encoder-Decoder Models in Scene Text Recognition}, + author={Yongkun Du and Zhineng Chen and Hongtao Xie and Caiyan Jia and Yu-Gang Jiang}, + journal={CoRR}, + volume={abs/2411.15858}, + eprinttype={arXiv}, + year={2024}, + primaryClass={cs.CV}, + url={https://arxiv.org/abs/2411.15858} } ``` diff --git a/opendet/postprocess/db_postprocess.py b/opendet/postprocess/db_postprocess.py index dd6b199..60c1e7f 100644 --- a/opendet/postprocess/db_postprocess.py +++ b/opendet/postprocess/db_postprocess.py @@ -208,7 +208,12 @@ def box_score_slow(self, bitmap, contour): cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype('int32'), 1) return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] - def __call__(self, outs_dict, shape_list): + def __call__(self, outs_dict, shape_list, **kwargs): + self.thresh = kwargs.get('mask_thresh', self.thresh) + self.box_thresh = kwargs.get('box_thresh', self.box_thresh) + self.unclip_ratio = kwargs.get('unclip_ratio', self.unclip_ratio) + self.box_type = kwargs.get('box_type', self.box_type) + self.score_mode = kwargs.get('score_mode', self.score_mode) pred = outs_dict['maps'] if isinstance(pred, torch.Tensor): pred = pred.detach().cpu().numpy() diff --git a/opendet/preprocess/db_resize_for_test.py b/opendet/preprocess/db_resize_for_test.py index dbd95f0..e974884 100644 --- a/opendet/preprocess/db_resize_for_test.py +++ b/opendet/preprocess/db_resize_for_test.py @@ -27,6 +27,8 @@ def __init__(self, **kwargs): def __call__(self, data): img = data['image'] + if 'max_sile_len' in data: + self.limit_side_len = data['max_sile_len'] src_h, src_w, _ = 
img.shape if sum([src_h, src_w]) < 64: img = self.image_padding(img) diff --git a/openrec/modeling/decoders/__init__.py b/openrec/modeling/decoders/__init__.py index f570fd0..bbb9b21 100644 --- a/openrec/modeling/decoders/__init__.py +++ b/openrec/modeling/decoders/__init__.py @@ -28,6 +28,7 @@ def build_decoder(config): from .cam_decoder import CAMDecoder from .ote_decoder import OTEDecoder from .bus_decoder import BUSDecoder + # from .dptr_parseq_clip_b_decoder import DptrParseq support_dict = [ 'CTCDecoder', 'NRTRDecoder', 'CPPDDecoder', 'ABINetDecoder', @@ -35,7 +36,7 @@ def build_decoder(config): 'SMTRDecoder', 'LPVDecoder', 'SARDecoder', 'RobustScannerDecoder', 'SRNDecoder', 'ASTERDecoder', 'RCTCDecoder', 'LISTERDecoder', 'GTCDecoder', 'SMTRDecoderNumAttn', 'MATRNDecoder', 'MGPDecoder', - 'DANDecoder', 'CAMDecoder', 'OTEDecoder', 'BUSDecoder' + 'DANDecoder', 'CAMDecoder', 'OTEDecoder', 'BUSDecoder', 'DptrParseq' ] module_name = config.pop('name') diff --git a/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py b/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py new file mode 100644 index 0000000..566b4a2 --- /dev/null +++ b/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py @@ -0,0 +1,1398 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from itertools import permutations +from collections import OrderedDict +import hashlib +import os +import gzip +import html +import urllib +import warnings +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.nn.modules import transformer +from typing import Any, Optional, Tuple, List, Union +from pkg_resources import packaging +from PIL import Image +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from tqdm import tqdm +from functools import lru_cache + +import ftfy +import regex as re + +try: + from torchvision.transforms import InterpolationMode + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. 
+ """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} + self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text + + +if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): + warnings.warn("PyTorch version 1.7.1 or higher is recommended") + + +__all__ = ["available_models", "load", "tokenize"] +_tokenizer = SimpleTokenizer() + +_MODELS = { + "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + "RN101": 
"https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", + "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", + "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", + "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", + "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", + "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", +} + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +def build_model(state_dict: dict): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * grid_size + else: + counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_resolution = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) + + model = CLIP( + embed_dim, + image_resolution, vision_layers, vision_width, vision_patch_size, + context_length, vocab_size, transformer_width, transformer_heads, transformer_layers + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + if key in state_dict: + del state_dict[key] + + convert_weights(model) + 
model.load_state_dict(state_dict) + return model.eval() + + +def _download(url: str, root: str): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + expected_sha256 = url.split("/")[-2] + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: + raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def _convert_image_to_rgb(image): + return image.convert("RGB") + + +def _transform(n_px): + return Compose([ + Resize(n_px, interpolation=BICUBIC), + CenterCrop(n_px), + _convert_image_to_rgb, + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + + +def available_models() -> List[str]: + """Returns the names of available CLIP models""" + return list(_MODELS.keys()) + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + return x.squeeze(0) + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.relu3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + def stem(x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) + + def forward(self, x: torch.Tensor): + return 
self.resblocks(x) + + +class VisionTransformer(nn.Module): + def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x) + if self.proj is not None: + x = x @ self.proj + + return x + + +class CLIP(nn.Module): + def __init__(self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int + ): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width + ) + else: + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask() + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.visual.layer1, self.visual.layer2, 
self.visual.layer3, self.visual.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # take features from the eot embedding (eot_token is the highest number in each sequence) + output = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + output = torch.cat([output.unsqueeze(1), x], dim=1) + + return output + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm(dim=1, keepdim=True) + text_features = text_features / text_features.norm(dim=1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +class FMU(nn.Module): + """A Transformer decoder layer supporting two-stream attention (XLNet) + This implements a pre-LN decoder, as opposed to the post-LN default in PyTorch.""" + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='gelu', + layer_norm_eps=1e-5): + super().__init__() + self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm = nn.LayerNorm(d_model, eps=layer_norm_eps) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = transformer._get_activation_fn(activation) + + def __setstate__(self, state): + if 'activation' not in state: + state['activation'] = F.gelu + super().__setstate__(state) + + def forward(self, query: Tensor, memory: Tensor): + """Forward pass for a single stream (i.e. content or query) + tgt_norm is just a LayerNorm'd tgt. Added as a separate parameter for efficiency. + Both tgt_kv and memory are expected to be LayerNorm'd too. 
+ memory is LayerNorm'd by ViT. + """ + query1, ca_weights = self.cross_attn(query, memory, memory) + query = query + self.dropout1(query1) + + query2 = self.linear2(self.dropout2(self.activation(self.linear1(self.norm(query))))) + query = query + self.dropout3(query2) + + return query + + +class DecoderLayer(nn.Module): + """A Transformer decoder layer supporting two-stream attention (XLNet) This + implements a pre-LN decoder, as opposed to the post-LN default in + PyTorch.""" + + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation='gelu', + layer_norm_eps=1e-5, + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, + nhead, + dropout=dropout, + batch_first=True) + self.cross_attn = nn.MultiheadAttention(d_model, + nhead, + dropout=dropout, + batch_first=True) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm_q = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm_c = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = transformer._get_activation_fn(activation) + + def __setstate__(self, state): + if 'activation' not in state: + state['activation'] = F.gelu + super().__setstate__(state) + + def forward_stream( + self, + tgt: Tensor, + tgt_norm: Tensor, + tgt_kv: Tensor, + memory: Tensor, + tgt_mask: Optional[Tensor], + tgt_key_padding_mask: Optional[Tensor], + ): + """Forward pass for a single stream (i.e. content or query) tgt_norm is + just a LayerNorm'd tgt. + + Added as a separate parameter for efficiency. Both tgt_kv and memory + are expected to be LayerNorm'd too. memory is LayerNorm'd by ViT. 
+ """ + tgt2, sa_weights = self.self_attn( + tgt_norm, + tgt_kv, + tgt_kv, + attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask) + + tgt = tgt + self.dropout1(tgt2) + + tgt2, ca_weights = self.cross_attn(self.norm1(tgt), memory, memory) + self.attn_map = ca_weights + tgt = tgt + self.dropout2(tgt2) + + tgt2 = self.linear2( + self.dropout(self.activation(self.linear1(self.norm2(tgt))))) + tgt = tgt + self.dropout3(tgt2) + return tgt, sa_weights, ca_weights + + def forward( + self, + query, + content, + memory, + query_mask: Optional[Tensor] = None, + content_mask: Optional[Tensor] = None, + content_key_padding_mask: Optional[Tensor] = None, + update_content: bool = True, + ): + query_norm = self.norm_q(query) + content_norm = self.norm_c(content) + query = self.forward_stream(query, query_norm, content_norm, memory, + query_mask, content_key_padding_mask)[0] + if update_content: + content = self.forward_stream(content, content_norm, content_norm, + memory, content_mask, + content_key_padding_mask)[0] + return query, content + + +class Decoder(nn.Module): + __constants__ = ['norm'] + + def __init__(self, decoder_layer, num_layers, norm): + super().__init__() + self.layers = transformer._get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + query, + content, + memory, + query_mask: Optional[Tensor] = None, + content_mask: Optional[Tensor] = None, + content_key_padding_mask: Optional[Tensor] = None, + ): + for i, mod in enumerate(self.layers): + last = i == len(self.layers) - 1 + query, content = mod( + query, + content, + memory, + query_mask, + content_mask, + content_key_padding_mask, + update_content=not last, + ) + query = self.norm(query) + return query + + +class TokenEmbedding(nn.Module): + + def __init__(self, charset_size: int, embed_dim: int): + super().__init__() + self.embedding = nn.Embedding(charset_size, embed_dim) + self.embed_dim = embed_dim + + def forward(self, tokens: torch.Tensor): + return math.sqrt(self.embed_dim) * self.embedding(tokens) + + +def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + + device : Union[str, torch.device] + The device to put the loaded model + + jit : bool + Whether to load the optimized JIT model or more hackable non-JIT model (default). + + download_root: str + path to download the model files; by default, it uses "~/.cache/clip" + + Returns + ------- + model : torch.nn.Module + The CLIP model + + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if name in _MODELS: + model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {available_models()}") + + with open(model_path, 'rb') as opened_file: + try: + # loading JIT archive + model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") + jit = False + state_dict = torch.load(opened_file, map_location="cpu") + + if not jit: + model = build_model(state_dict or model.state_dict()).to(device) + if str(device) == "cpu": + model.float() + return model, _transform(model.visual.input_resolution) + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + + return model, _transform(model.input_resolution.item()) + +def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + + context_length : int + The context length to use; all CLIP models use 77 as the context length + + truncate: bool + Whether to truncate the text in case its encoding is longer than the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. + We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 
+ """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder["<|startoftext|>"] + eot_token = _tokenizer.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + else: + result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + +class DptrParseq(nn.Module): + + def __init__(self, + in_channels, + out_channels, + max_label_length=25, + embed_dim=512, + dec_num_heads=8, + dec_mlp_ratio=4, + dec_depth=6, + perm_num=6, + perm_forward=True, + perm_mirrored=True, + decode_ar=True, + refine_iters=1, + dropout=0.1, + is_pretrain=True, + ORP_path=None, + **kwargs: Any) -> None: + super().__init__() + self.pad_id = out_channels - 1 + self.eos_id = 0 + self.bos_id = out_channels - 2 + self.max_label_length = max_label_length + self.decode_ar = decode_ar + self.refine_iters = refine_iters + self.is_pretrain = is_pretrain + if not is_pretrain: + self.token_query = nn.Parameter(torch.Tensor(1, 26, embed_dim)) + self.fmu = FMU(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout) + + decoder_layer = DecoderLayer(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout) + self.decoder = Decoder(decoder_layer, + num_layers=dec_depth, + norm=nn.LayerNorm(embed_dim)) + + # Perm/attn mask stuff + self.rng = np.random.default_rng() + self.max_gen_perms = perm_num // 2 if perm_mirrored else perm_num + self.perm_forward = perm_forward + self.perm_mirrored = perm_mirrored + + # We don't predict