forked from vba34520/chineseocr_lite
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathocr.py
133 lines (118 loc) · 4.97 KB
/
ocr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# -*- coding: utf-8 -*-
import cv2
import torch
import numpy as np
from PIL import Image
from io import BytesIO
from crnn import LiteCrnn, CRNNHandle
from psenet import PSENet, PSENetHandel
from application import idcard, trainTicket
from crnn.keys import alphabetChinese as alphabet
from angle_class import AangleClassHandle, shufflenet_v2_x0_5
def crop_rect(img, rect, alph=0.15):
    """Crop a rotated rectangle out of an image.

    Args:
        img: source image (PIL image or ndarray-convertible).
        rect: ((cx, cy), (w, h), angle) rotated-rect, cv2 minAreaRect style.
        alph: padding factor; each side is enlarged by alph * min(w, h).

    Returns:
        The upright crop as a PIL Image.
    """
    arr = np.asarray(img)
    center, size, angle = rect[0], rect[1], rect[2]
    pad = min(size) * alph
    rows, cols = arr.shape[0], arr.shape[1]
    if angle > -45:
        center = tuple(map(int, center))
        box = (int(int(size[0]) + pad), int(int(size[1]) + pad))
        rot_angle = angle
    else:
        # Near-vertical rect: swap width/height and compensate the angle.
        center = tuple(map(int, center))
        box = (int(int(rect[1][1]) + pad), int(int(rect[1][0]) + pad))
        rot_angle = angle - 270
    M = cv2.getRotationMatrix2D(center, rot_angle, 1)
    rotated = cv2.warpAffine(arr, M, (cols, rows))
    patch = cv2.getRectSubPix(rotated, box, center)
    return Image.fromarray(patch)
def crnnRec(im, rects_re, f=1.0):
    """Recognize the text inside each detected rotated rectangle.

    Args:
        im: source image as a numpy array (H x W x C).
        rects_re: iterable of (degree, w, h, cx, cy) rotated rects.
        f: scale factor applied to the coordinates/sizes in the output.

    Returns:
        List of dicts with keys 'cx', 'cy', 'text', 'w', 'h', 'degree';
        rects whose recognition fails or yields empty text are skipped.
    """
    results = []
    im = Image.fromarray(im)
    for rect in rects_re:
        degree, w, h, cx, cy = rect
        partImg = crop_rect(im, ((cx, cy), (h, w), degree))
        newW, newH = partImg.size
        partImg_array = np.uint8(partImg)
        # A crop much taller than wide is likely vertical text: stand it up.
        if newH > 1.5 * newW:
            partImg_array = np.rot90(partImg_array, 1)
        # Classify the line's orientation and undo any 180-degree flip.
        angel_index = angle_handle.predict(partImg_array)
        angel_class = lable_map_dict[angel_index]
        rotate_angle = rotae_map_dict[angel_class]
        if rotate_angle != 0:
            partImg_array = np.rot90(partImg_array, rotate_angle // 90)
        partImg = Image.fromarray(partImg_array).convert("RGB")
        partImg_ = partImg.convert('L')
        try:
            if crnn_vertical_handle is not None and angel_class in ["shudao", "shuzhen"]:
                simPred = crnn_vertical_handle.predict(partImg_)
            else:
                simPred = crnn_handle.predict(partImg_)  # recognized text
        except Exception:
            # Best-effort: skip crops the recognizer cannot handle instead of
            # aborting the page. A bare `except:` here would also swallow
            # KeyboardInterrupt/SystemExit, so catch Exception explicitly.
            continue
        if simPred.strip() != '':
            results.append({'cx': cx * f, 'cy': cy * f, 'text': simPred,
                            'w': newW * f, 'h': newH * f, 'degree': degree})
    return results
def text_predict(img):
    """Detect text boxes with PSENet, then recognize each line with CRNN."""
    _, _, detected_rects, _ = text_handle.predict(img, long_size=pse_long_size)
    return crnnRec(np.array(img), detected_rects)
# Select CPU or GPU for inference.
# NOTE(review): because 0 is falsy, `gpu_id = 0` always falls through to the
# CPU branch even when CUDA is available — confirm whether CUDA device 0 was
# meant to be selectable. `device` is only printed here; the model handles
# below receive `gpu_id` directly.
gpu_id = 0
if gpu_id and isinstance(gpu_id, int) and torch.cuda.is_available():
    device = torch.device("cuda:{}".format(gpu_id))
else:
    device = torch.device("cpu")
print('device:', device)
# PSENet text-detection model.
pse_scale = 1
pse_long_size = 960  # long-side length images are resized to before detection
pse_model_type = "mobilenetv2"
pse_model_path = "models/psenet_lite_mbv2.pth"
text_detect_net = PSENet(backbone=pse_model_type, pretrained=False, result_num=6, scale=pse_scale)
text_handle = PSENetHandel(pse_model_path, text_detect_net, pse_scale, gpu_id=gpu_id)
# CRNN text-recognition models: one for horizontal lines, one for vertical.
nh = 256  # hidden units in the CRNN's LSTM
crnn_model_path = "models/crnn_lite_lstm_dw_v2.pth"
crnn_net = LiteCrnn(32, 1, len(alphabet) + 1, nh, n_rnn=2, leakyRelu=False, lstmFlag=True)
crnn_handle = CRNNHandle(crnn_model_path, crnn_net, gpu_id=gpu_id)
crnn_vertical_model_path = "models/crnn_dw_lstm_vertical.pth"
crnn_vertical_net = LiteCrnn(32, 1, len(alphabet) + 1, nh, n_rnn=2, leakyRelu=False, lstmFlag=True)
crnn_vertical_handle = CRNNHandle(crnn_vertical_model_path, crnn_vertical_net, gpu_id=gpu_id)
# Text-line orientation classifier (4 classes).
lable_map_dict = {0: "hengdao", 1: "hengzhen", 2: "shudao", 3: "shuzhen"}  # hengdao: horizontal text upside-down; the others analogous (hengzhen: horizontal upright, shudao: vertical inverted, shuzhen: vertical upright)
rotae_map_dict = {"hengdao": 180, "hengzhen": 0, "shudao": 180, "shuzhen": 0}  # rotation (degrees) each orientation class needs before recognition
angle_model_path = "models/shufflenetv2_05.pth"
angle_net = shufflenet_v2_x0_5(num_classes=len(lable_map_dict), pretrained=False)
angle_handle = AangleClassHandle(angle_model_path, angle_net, gpu_id=gpu_id)
def result(img):
    """OCR raw image bytes and return text plus structured parses.

    Args:
        img: the image file content as bytes.

    Returns:
        Dict with the recognized text lines ('文本') and stringified
        train-ticket ('火车票') and ID-card ('身份证') extractions.
    """
    image = np.array(Image.open(BytesIO(img)).convert('RGB'))
    predictions = text_predict(image)
    back = {'文本': [item['text'] for item in predictions]}
    back['火车票'] = str(trainTicket.trainTicket(predictions))
    back['身份证'] = str(idcard.idcard(predictions))
    return back
if __name__ == '__main__':
    # Quick demo: run detection + recognition on a sample ID-card image.
    demo_path = './test/idcard-demo.jpg'
    demo_img = np.array(Image.open(demo_path).convert('RGB'))
    predictions = text_predict(demo_img)
    print('文本预测:', [p['text'] for p in predictions])
    # Train-ticket structured extraction
    res = trainTicket.trainTicket(predictions)
    print('火车票预测:', res)
    # ID-card structured extraction
    res = idcard.idcard(predictions)
    print('身份证预测:', res)