Skip to content

Commit

Permalink
add batch prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
pbcquoc committed Oct 11, 2021
1 parent 79d284e commit 549f5ee
Showing 1 changed file with 42 additions and 1 deletion.
43 changes: 42 additions & 1 deletion vietocr/tool/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from vietocr.tool.utils import download_weights

import torch
from collections import defaultdict

class Predictor():
def __init__(self, config):
Expand All @@ -21,7 +22,8 @@ def __init__(self, config):
self.config = config
self.model = model
self.vocab = vocab

self.device = device

def predict(self, img, return_prob=False):
img = process_input(img, self.config['dataset']['image_height'],
self.config['dataset']['image_min_width'], self.config['dataset']['image_max_width'])
Expand All @@ -42,3 +44,42 @@ def predict(self, img, return_prob=False):
return s, prob
else:
return s

def predict_batch(self, imgs, return_prob=False):
bucket = defaultdict(list)
bucket_idx = defaultdict(list)
bucket_pred = {}

sents, probs = [0]*len(imgs), [0]*len(imgs)

for i, img in enumerate(imgs):
img = process_input(img, self.config['dataset']['image_height'],
self.config['dataset']['image_min_width'], self.config['dataset']['image_max_width'])

bucket[img.shape[-1]].append(img)
bucket_idx[img.shape[-1]].append(i)


for k, batch in bucket.items():
batch = torch.cat(batch, 0).to(self.device)
s, prob = translate(batch, self.model)
prob = prob.tolist()

s = s.tolist()
s = self.vocab.batch_decode(s)

bucket_pred[k] = (s, prob)


for k in bucket_pred:
idx = bucket_idx[k]
sent, prob = bucket_pred[k]
for i, j in enumerate(idx):
sents[j] = sent[i]
probs[j] = prob[i]

if return_prob:
return sents, probs
else:
return sents

0 comments on commit 549f5ee

Please sign in to comment.