Skip to content

Commit

Permalink
Added vietocr.vietocr
Browse files Browse the repository at this point in the history
  • Loading branch information
bmd1905 committed Jan 20, 2023
1 parent e666d3b commit 9ba1195
Show file tree
Hide file tree
Showing 11 changed files with 30 additions and 30 deletions.
Empty file added __init__.py
Empty file.
4 changes: 2 additions & 2 deletions vietocr/loader/dataloader_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import numpy as np
from PIL import Image
import random
from vietocr.model.vocab import Vocab
from vietocr.tool.translate import process_image
from vietocr.vietocr.model.vocab import Vocab
from vietocr.vietocr.tool.translate import process_image
import os
from collections import defaultdict
import math
Expand Down
4 changes: 2 additions & 2 deletions vietocr/model/backbone/cnn.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import torch
from torch import nn

import vietocr.model.backbone.vgg as vgg
from vietocr.model.backbone.resnet import Resnet50
import vietocr.vietocr.model.backbone.vgg as vgg
from vietocr.vietocr.model.backbone.resnet import Resnet50

class CNN(nn.Module):
def __init__(self, backbone, **kwargs):
Expand Down
20 changes: 10 additions & 10 deletions vietocr/model/trainer.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
from vietocr.optim.optim import ScheduledOptim
from vietocr.optim.labelsmoothingloss import LabelSmoothingLoss
from vietocr.vietocr.optim.optim import ScheduledOptim
from vietocr.vietocr.optim.labelsmoothingloss import LabelSmoothingLoss
from torch.optim import Adam, SGD, AdamW
from torch import nn
from vietocr.tool.translate import build_model
from vietocr.tool.translate import translate, batch_translate_beam_search
from vietocr.tool.utils import download_weights
from vietocr.tool.logger import Logger
from vietocr.loader.aug import ImgAugTransform
from vietocr.vietocr.tool.translate import build_model
from vietocr.vietocr.tool.translate import translate, batch_translate_beam_search
from vietocr.vietocr.tool.utils import download_weights
from vietocr.vietocr.tool.logger import Logger
from vietocr.vietocr.loader.aug import ImgAugTransform

import yaml
import torch
from vietocr.loader.dataloader_v1 import DataGen
from vietocr.loader.dataloader import OCRDataset, ClusterRandomSampler, Collator
from vietocr.vietocr.loader.dataloader_v1 import DataGen
from vietocr.vietocr.loader.dataloader import OCRDataset, ClusterRandomSampler, Collator
from torch.utils.data import DataLoader
from einops import rearrange
from torch.optim.lr_scheduler import CosineAnnealingLR, CyclicLR, OneCycleLR

import torchvision

from vietocr.tool.utils import compute_accuracy
from vietocr.vietocr.tool.utils import compute_accuracy
from PIL import Image
import numpy as np
import os
Expand Down
8 changes: 4 additions & 4 deletions vietocr/model/transformerocr.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from vietocr.model.backbone.cnn import CNN
from vietocr.model.seqmodel.transformer import LanguageTransformer
from vietocr.model.seqmodel.seq2seq import Seq2Seq
from vietocr.model.seqmodel.convseq2seq import ConvSeq2Seq
from vietocr.vietocr.model.backbone.cnn import CNN
from vietocr.vietocr.model.seqmodel.transformer import LanguageTransformer
from vietocr.vietocr.model.seqmodel.seq2seq import Seq2Seq
from vietocr.vietocr.model.seqmodel.convseq2seq import ConvSeq2Seq
from torch import nn

class VietOCR(nn.Module):
Expand Down
4 changes: 2 additions & 2 deletions vietocr/predict.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import argparse
from PIL import Image

from vietocr.tool.predictor import Predictor
from vietocr.tool.config import Cfg
from vietocr.vietocr.tool.predictor import Predictor
from vietocr.vietocr.tool.config import Cfg

def main():
parser = argparse.ArgumentParser()
Expand Down
4 changes: 2 additions & 2 deletions vietocr/tests/utest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from vietocr.loader.dataloader_v1 import DataGen
from vietocr.model.vocab import Vocab
from vietocr.vietocr.loader.dataloader_v1 import DataGen
from vietocr.vietocr.model.vocab import Vocab

def test_loader():
chars = 'aAàÀảẢãÃáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬbBcCdDđĐeEèÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆfFgGhHiIìÌỉỈĩĨíÍịỊjJkKlLmMnNoOòÒỏỎõÕóÓọỌôÔồỒổỔỗỖốỐộỘơƠờỜởỞỡỠớỚợỢpPqQrRsStTuUùÙủỦũŨúÚụỤưƯừỪửỬữỮứỨựỰvVwWxXyYỳỲỷỶỹỸýÝỵỴzZ0123456789!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ '
Expand Down
2 changes: 1 addition & 1 deletion vietocr/tool/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import yaml
from vietocr.tool.utils import download_config
from vietocr.vietocr.tool.utils import download_config

url_config = {
'vgg_transformer':'vgg-transformer.yml',
Expand Down
4 changes: 2 additions & 2 deletions vietocr/tool/predictor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from vietocr.tool.translate import build_model, translate, translate_beam_search, process_input, predict
from vietocr.tool.utils import download_weights
from vietocr.vietocr.tool.translate import build_model, translate, translate_beam_search, process_input, predict
from vietocr.vietocr.tool.utils import download_weights

import torch
from collections import defaultdict
Expand Down
6 changes: 3 additions & 3 deletions vietocr/tool/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from PIL import Image
from torch.nn.functional import log_softmax, softmax

from vietocr.model.transformerocr import VietOCR
from vietocr.model.vocab import Vocab
from vietocr.model.beam import Beam
from vietocr.vietocr.model.transformerocr import VietOCR
from vietocr.vietocr.model.vocab import Vocab
from vietocr.vietocr.model.beam import Beam

def batch_translate_beam_search(img, model, beam_size=4, candidates=1, max_seq_length=128, sos_token=1, eos_token=2):
# img: NxCxHxW
Expand Down
4 changes: 2 additions & 2 deletions vietocr/train.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse

from vietocr.model.trainer import Trainer
from vietocr.tool.config import Cfg
from vietocr.vietocr.model.trainer import Trainer
from vietocr.vietocr.tool.config import Cfg

def main():
parser = argparse.ArgumentParser()
Expand Down

0 comments on commit 9ba1195

Please sign in to comment.