Skip to content

Commit

Permalink
Refactor batch sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Jan 7, 2025
1 parent d17281e commit 4c5a180
Show file tree
Hide file tree
Showing 11 changed files with 130 additions and 126 deletions.
76 changes: 17 additions & 59 deletions surya/common/polygon.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
from typing import List, Optional

import numpy as np
from pydantic import BaseModel, field_validator, computed_field

from surya.postprocessing.util import rescale_bbox
Expand All @@ -11,23 +12,26 @@ class PolygonBox(BaseModel):
confidence: Optional[float] = None

@field_validator('polygon', mode='before')
@classmethod
def convert_bbox_to_polygon(cls, value):
    """Normalize the raw ``polygon`` input into four ``[x, y]`` corners.

    Runs before field validation (``mode='before'``), so it receives
    whatever the caller passed for ``polygon``.

    Accepted inputs:
      * a list/tuple of 4 numbers ``[x_min, y_min, x_max, y_max]``, which is
        expanded to the corners ``(x_min, y_min)``, ``(x_max, y_min)``,
        ``(x_max, y_max)``, ``(x_min, y_max)``;
      * a list/tuple of 4 two-element lists/tuples, returned unchanged;
      * a ``numpy`` array of shape ``(4, 2)``, converted with ``tolist()``.

    Raises:
        ValueError: if ``value`` matches none of the shapes above.
    """
    if isinstance(value, (list, tuple)) and len(value) == 4:
        # The element type must be checked BEFORE unpacking: a 4-corner
        # polygon also has length 4, and unpacking it as a bbox would build
        # a nonsensical polygon out of whole points.
        if all(isinstance(x, (int, float)) for x in value):
            x_min, y_min, x_max, y_max = value
            return [
                [x_min, y_min],
                [x_max, y_min],
                [x_max, y_max],
                [x_min, y_max],
            ]
        if all(isinstance(point, (list, tuple)) and len(point) == 2 for point in value):
            return value
    elif isinstance(value, np.ndarray):
        if value.shape == (4, 2):
            return value.tolist()

    raise ValueError(
        f"Input must be either a bbox [x_min, y_min, x_max, y_max] or a polygon with 4 corners [(x,y), (x,y), (x,y), (x,y)]. You passed {value}.")

@property
def height(self):
Expand Down Expand Up @@ -121,50 +125,4 @@ def shift(self, x_shift: float | None = None, y_shift: float | None = None):
corner[0] += x_shift
if y_shift is not None:
for corner in self.polygon:
corner[1] += y_shift


class Bbox(BaseModel):
    """Box stored as a flat list ``[x_min, y_min, x_max, y_max]``."""

    # Four coordinates; the length is enforced by ``check_4_elements``.
    bbox: List[float]

    @field_validator('bbox')
    @classmethod
    def check_4_elements(cls, v: List[float]) -> List[float]:
        """Reject any ``bbox`` that does not have exactly four entries."""
        if len(v) == 4:
            return v
        raise ValueError('bbox must have 4 elements')

    def rescale_bbox(self, orig_size, new_size):
        """Replace ``bbox`` with the result of the module-level ``rescale_bbox`` helper."""
        rescaled = rescale_bbox(self.bbox, orig_size, new_size)
        self.bbox = rescaled

    def round_bbox(self, divisor):
        """Floor every coordinate to a multiple of ``divisor`` (in place)."""
        self.bbox = [(coord // divisor) * divisor for coord in self.bbox]

    @property
    def height(self):
        """Vertical extent: ``bbox[3] - bbox[1]``."""
        y_low = self.bbox[1]
        y_high = self.bbox[3]
        return y_high - y_low

    @property
    def width(self):
        """Horizontal extent: ``bbox[2] - bbox[0]``."""
        x_low = self.bbox[0]
        x_high = self.bbox[2]
        return x_high - x_low

    @property
    def area(self):
        """Product of ``width`` and ``height``."""
        return self.width * self.height

    @property
    def polygon(self):
        """The box as four ``[x, y]`` corners, starting at ``(bbox[0], bbox[1])``."""
        x_low, y_low = self.bbox[0], self.bbox[1]
        x_high, y_high = self.bbox[2], self.bbox[3]
        return [[x_low, y_low], [x_high, y_low], [x_high, y_high], [x_low, y_high]]

    @property
    def center(self):
        """Midpoint of the box as ``[x, y]``."""
        mid_x = (self.bbox[0] + self.bbox[2]) / 2
        mid_y = (self.bbox[1] + self.bbox[3]) / 2
        return [mid_x, mid_y]

    def intersection_pct(self, other):
        """Return the fraction of this box's area that overlaps ``other``.

        Returns 0 when this box has zero area, which avoids dividing by zero.
        """
        own_area = self.area
        if own_area == 0:
            return 0

        # Overlap along each axis, clamped at 0 when the boxes are disjoint.
        overlap_w = max(0, min(self.bbox[2], other.bbox[2]) - max(self.bbox[0], other.bbox[0]))
        overlap_h = max(0, min(self.bbox[3], other.bbox[3]) - max(self.bbox[1], other.bbox[1]))
        return (overlap_w * overlap_h) / own_area
corner[1] += y_shift
17 changes: 14 additions & 3 deletions surya/common/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@

class BasePredictor:
model_loader_cls = ModelLoader
batch_size = None
default_batch_sizes = {
"cpu": 1,
"mps": 1,
"cuda": 1
}

def __init__(self, checkpoint: Optional[str] = None, device: torch.device | str | None = settings.TORCH_DEVICE_MODEL, dtype: Optional[torch.dtype | str] = settings.MODEL_DTYPE):
self.model = None
self.processor = None
Expand All @@ -21,9 +28,13 @@ def to(self, device_dtype: torch.device | str | None = None):
else:
raise ValueError("Model not loaded")

def get_batch_size(self):
    """Return the batch size this predictor should use.

    An explicitly configured ``self.batch_size`` always wins. Otherwise the
    value is looked up in ``self.default_batch_sizes`` by device type,
    falling back to the ``"cpu"`` entry for devices without an entry.

    Returns:
        The configured batch size, or the per-device default.
    """
    if self.batch_size is not None:
        return self.batch_size

    # Strip any device index ("cuda:1" -> "cuda") so an indexed device still
    # picks up its per-device default instead of silently falling back to
    # the CPU value. str() also covers a non-string device object.
    device_type = str(settings.TORCH_DEVICE_MODEL).split(":", 1)[0]
    return self.default_batch_sizes.get(device_type, self.default_batch_sizes["cpu"])

def __call__(self, *args, **kwargs):
raise NotImplementedError()
17 changes: 6 additions & 11 deletions surya/detection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@

class DetectionPredictor(BasePredictor):
model_loader_cls = DetectionModelLoader
batch_size = settings.DETECTOR_BATCH_SIZE
default_batch_sizes = {
"cpu": 8,
"mps": 8,
"cuda": 36
}

def __call__(self, images: List[Image.Image], batch_size=None, include_maps=False) -> List[TextDetectionResult]:
detection_generator = self.batch_detection(images, batch_size=batch_size, static_cache=settings.DETECTOR_STATIC_CACHE)
Expand All @@ -36,17 +42,6 @@ def __call__(self, images: List[Image.Image], batch_size=None, include_maps=Fals

return [future.result() for future in postprocessing_futures]

@staticmethod
def get_batch_size():
    """Return the detection batch size.

    ``settings.DETECTOR_BATCH_SIZE`` is used when it is set; otherwise the
    default is 36 on "cuda" and 8 on everything else.
    """
    configured = settings.DETECTOR_BATCH_SIZE
    if configured is not None:
        return configured
    # "mps" resolves to the same value as the generic fallback (8).
    return 36 if settings.TORCH_DEVICE_MODEL == "cuda" else 8

def pad_to_batch_size(self, tensor, batch_size):
current_batch_size = tensor.shape[0]
if current_batch_size >= batch_size:
Expand Down
1 change: 0 additions & 1 deletion surya/input/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import json



def get_name_from_path(path):
    """Return the file name of ``path`` up to (not including) its first dot."""
    filename = os.path.basename(path)
    # Everything from the first "." onward is dropped, so "a.tar.gz" -> "a".
    stem, _, _ = filename.partition(".")
    return stem

Expand Down
17 changes: 6 additions & 11 deletions surya/layout/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@

class LayoutPredictor(BasePredictor):
model_loader_cls = LayoutModelLoader
batch_size = settings.LAYOUT_BATCH_SIZE
default_batch_sizes = {
"cpu": 4,
"mps": 4,
"cuda": 32
}

def __call__(
self,
Expand All @@ -30,17 +36,6 @@ def __call__(
batch_size=batch_size
)

@staticmethod
def get_batch_size():
    """Return the layout batch size.

    ``settings.LAYOUT_BATCH_SIZE`` takes precedence when set; otherwise the
    default is 32 on "cuda" and 4 on any other device (including "mps").
    """
    override = settings.LAYOUT_BATCH_SIZE
    if override is not None:
        return override

    if settings.TORCH_DEVICE_MODEL == "cuda":
        return 32
    return 4

def batch_layout_detection(
self,
images: List[Image.Image],
Expand Down
5 changes: 1 addition & 4 deletions surya/layout/slicer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import math
from typing import List, Tuple

import cv2
import numpy as np
from PIL import Image

from surya.layout import LayoutResult
from surya.layout.schema import LayoutResult

SLICES_TYPE = Tuple[List[Image.Image], List[Tuple[int, int, int]]]

Expand Down
17 changes: 6 additions & 11 deletions surya/ocr_error/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@

class OCRErrorPredictor(BasePredictor):
model_loader_cls = OCRErrorModelLoader
batch_size = settings.OCR_ERROR_BATCH_SIZE
default_batch_sizes = {
"cpu": 8,
"mps": 8,
"cuda": 64
}

def __call__(
self,
Expand All @@ -22,17 +28,6 @@ def __call__(
):
return self.batch_ocr_error_detection(texts, batch_size)

@staticmethod
def get_batch_size():
    """Return the OCR-error batch size.

    Uses ``settings.OCR_ERROR_BATCH_SIZE`` when it is set. When it is None
    the value is 64 on "cuda" and 8 on any other device (including "mps").
    """
    size = settings.OCR_ERROR_BATCH_SIZE
    if size is None:
        on_cuda = settings.TORCH_DEVICE_MODEL == "cuda"
        size = 64 if on_cuda else 8
    return size

def batch_ocr_error_detection(
self,
texts: List[str],
Expand Down
19 changes: 6 additions & 13 deletions surya/recognition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@

class RecognitionPredictor(BasePredictor):
model_loader_cls = RecognitionModelLoader
batch_size = settings.RECOGNITION_BATCH_SIZE
default_batch_sizes = {
"cpu": 32,
"mps": 64,
"cuda": 256
}

def __call__(
self,
Expand Down Expand Up @@ -170,17 +176,6 @@ def slice_bboxes(
"polygons": all_polygons
}

@staticmethod
def get_batch_size():
    """Return the recognition batch size.

    ``settings.RECOGNITION_BATCH_SIZE`` is returned when set. Otherwise the
    default depends on ``settings.TORCH_DEVICE_MODEL``: 64 for "mps",
    256 for "cuda", and 32 for anything else.
    """
    configured = settings.RECOGNITION_BATCH_SIZE
    if configured is not None:
        return configured

    per_device = {
        "mps": 64,  # 12GB RAM max
        "cuda": 256,
    }
    return per_device.get(settings.TORCH_DEVICE_MODEL, 32)

def pad_to_batch_size(self, tensor, batch_size):
current_batch_size = tensor.shape[0]
if current_batch_size >= batch_size:
Expand Down Expand Up @@ -240,8 +235,6 @@ def batch_recognition(
batch_images = [image.convert("RGB") for image in batch_images] # also copies the images
real_batch_size = len(batch_images)
batch_langs = languages[i:i + real_batch_size]
has_math = [lang and "_math" in lang for lang in batch_langs]

processed_batch = self.processor(text=[""] * len(batch_images), images=batch_images, langs=batch_langs)

batch_pixel_values = processed_batch["pixel_values"]
Expand Down
2 changes: 1 addition & 1 deletion surya/recognition/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List

from surya.recognition import TextLine
from surya.recognition.schema import TextLine


def sort_text_lines(lines: List[TextLine] | List[dict], tolerance=1.25):
Expand Down
Loading

0 comments on commit 4c5a180

Please sign in to comment.