Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/sml2h3/ddddocr
Browse files Browse the repository at this point in the history
� Conflicts:
�	ddddocr/__init__.py
  • Loading branch information
sml2h3 committed Feb 26, 2022
2 parents b94785e + 97f23ad commit 2cf2924
Showing 1 changed file with 24 additions and 55 deletions.
79 changes: 24 additions & 55 deletions ddddocr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import io
import os
import base64
import json
import pathlib
import onnxruntime
from PIL import Image, ImageChops
import numpy as np
Expand All @@ -29,26 +29,13 @@ class TypeError(Exception):

class DdddOcr(object):
def __init__(self, ocr: bool = True, det: bool = False, old: bool = False, use_gpu: bool = False,
device_id: int = 0, show_ad=True, import_onnx_path: str = "", charsets_path: str = ""):
device_id: int = 0, show_ad=True):
if show_ad:
print("欢迎使用ddddocr,本项目专注带动行业内卷,个人博客:wenanzhe.com")
print("训练数据支持来源于:http://146.56.204.113:19199/preview")
print("爬虫框架feapder可快速一键接入,快速开启爬虫之旅:https://github.com/Boris-code/feapder")
self.use_import_onnx = False
self.__word = False
self.__resize = []
self.__channel = 1
if import_onnx_path != "":
det = False
ocr = False
self.__graph_path = import_onnx_path
with open(charsets_path, 'r', encoding="utf-8") as f:
info = json.loads(f.read())
self.__charset = info['charset']
self.__word = info['word']
self.__resize = info['image']
self.__channel = info['channel']
self.use_import_onnx = True


if det:
ocr = False
print("开启det后自动关闭ocr")
Expand Down Expand Up @@ -1453,7 +1440,7 @@ def __init__(self, ocr: bool = True, det: bool = False, old: bool = False, use_g
self.__providers = [
'CPUExecutionProvider',
]
if ocr or det or self.use_import_onnx:
if ocr or det:
self.__ort_session = onnxruntime.InferenceSession(self.__graph_path, providers=self.__providers)

def preproc(self, img, input_size, swap=(2, 0, 1)):
Expand Down Expand Up @@ -1594,53 +1581,35 @@ def get_bbox(self, image_bytes):
return []
return result

def classification(self, img_bytes: bytes = None, img_base64: str = None):
def classification(self, img):
if self.det:
raise TypeError("当前识别类型为目标检测")
if img_bytes:
image = Image.open(io.BytesIO(img_bytes))
else:
image = base64_to_image(img_base64)
if not self.use_import_onnx:
image = image.resize((int(image.size[0] * (64 / image.size[1])), 64), Image.ANTIALIAS).convert('L')
if not isinstance(img, (bytes, str, pathlib.PurePath, Image.Image)):
raise TypeError("未知图片类型")
if isinstance(img, bytes):
image = Image.open(io.BytesIO(img))
elif isinstance(img, Image.Image):
image = img.copy()
elif isinstance(img, str):
image = base64_to_image(img)
else:
if self.__resize[0] == -1:
if self.__word:
image = image.resize((self.__resize[1], self.__resize[1]), Image.ANTIALIAS)
else:
image = image.resize((int(image.size[0] * (self.__resize[1] / image.size[1])), self.__resize[1]), Image.ANTIALIAS)
else:
image = image.resize((self.__resize[0], self.__resize[1]), Image.ANTIALIAS)
if self.__channel == 1:
image = image.convert('L')
else:
image = image.convert('RGB')
assert isinstance(img, pathlib.PurePath)
image = Image.open(img)
image = image.resize((int(image.size[0] * (64 / image.size[1])), 64), Image.ANTIALIAS).convert('L')
image = np.array(image).astype(np.float32)
image = np.expand_dims(image, axis=0) / 255.
if not self.use_import_onnx:
image = (image - 0.5) / 0.5
else:
if self.__channel == 1:
image = (image - 0.456) / 0.224
else:
image = (image - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])

image = (image - 0.5) / 0.5
ort_inputs = {'input1': np.array([image])}
ort_outs = self.__ort_session.run(None, ort_inputs)
result = []

last_item = 0
if self.__word:
for item in ort_outs[1]:
for item in ort_outs[0][0]:
if item == last_item:
continue
else:
last_item = item
if item != 0:
result.append(self.__charset[item])
else:
for item in ort_outs[0][0]:
if item == last_item:
continue
else:
last_item = item
if item != 0:
result.append(self.__charset[item])

return ''.join(result)

Expand Down

0 comments on commit 2cf2924

Please sign in to comment.