Skip to content

Commit

Permalink
Modify prediction logic
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Dec 18, 2024
1 parent 4701f96 commit f8188f4
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 36 deletions.
2 changes: 1 addition & 1 deletion ocr_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def table_recognition(img, highres_img, skip_table_detection: bool) -> (Image.Im
(item.bbox[3] + table_bbox[1])
])
labels.append(item.label)
if hasattr(item, "row_id"):
if "Row" in item.label:
colors.append("blue")
else:
colors.append("red")
Expand Down
2 changes: 1 addition & 1 deletion surya/model/table_rec/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(
initializer_range=0.02,
layer_norm_eps=1e-5,
encoder_length=1024,
use_positional_embeddings=False,
use_positional_embeddings=True,
**kwargs,
):
super().__init__(**kwargs)
Expand Down
17 changes: 16 additions & 1 deletion surya/model/table_rec/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,15 @@ def resize_polygon(self, polygon, orig_size, new_size):

return polygon

def __call__(self, images: List[PIL.Image.Image] | None, query_items: List[dict], convert_images: bool = True, *args, **kwargs):
def __call__(
self,
images: List[PIL.Image.Image] | None,
query_items: List[dict],
columns: List[dict] | None = None,
convert_images: bool = True,
*args,
**kwargs
):
if convert_images:
assert len(images) == len(query_items)
assert len(images) > 0
Expand All @@ -75,6 +83,13 @@ def __call__(self, images: List[PIL.Image.Image] | None, query_items: List[dict]
[self.token_query_end_id] * col_count
])

# Add columns to end of decoder input
if columns:
columns = self.shaper.convert_polygons_to_bboxes(columns)
column_labels = self.shaper.dict_to_labels(columns)
for decoder_box in decoder_input_boxes:
decoder_box += column_labels

input_boxes = torch.tensor(decoder_input_boxes, dtype=torch.long)
input_boxes_mask = torch.ones_like(input_boxes, dtype=torch.long)

Expand Down
23 changes: 23 additions & 0 deletions surya/schema.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
from copy import deepcopy
from typing import List, Tuple, Any, Optional

from pydantic import BaseModel, field_validator, computed_field
Expand Down Expand Up @@ -71,6 +72,21 @@ def merge(self, other):
y2 = max(self.bbox[3], other.bbox[3])
self.polygon = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]

def intersection_polygon(self, other) -> List[List[float]]:
new_poly = []
for i in range(4):
if i == 0:
new_corner = [max(self.polygon[0][0], other.polygon[0][0]), max(self.polygon[0][1], other.polygon[0][1])]
elif i == 1:
new_corner = [min(self.polygon[1][0], other.polygon[1][0]), max(self.polygon[1][1], other.polygon[1][1])]
elif i == 2:
new_corner = [min(self.polygon[2][0], other.polygon[2][0]), min(self.polygon[2][1], other.polygon[2][1])]
elif i == 3:
new_corner = [max(self.polygon[3][0], other.polygon[3][0]), min(self.polygon[3][1], other.polygon[3][1])]
new_poly.append(new_corner)

return new_poly

def intersection_area(self, other, x_margin=0, y_margin=0):
x_overlap = self.x_overlap(other, x_margin)
y_overlap = self.y_overlap(other, y_margin)
Expand Down Expand Up @@ -190,10 +206,16 @@ class TableCell(PolygonBox):
row_id: int
colspan: int
within_row_id: int
cell_id: int
rowspan: int | None = None
merge_up: bool = False
merge_down: bool = False
col_id: int | None = None

@property
def label(self):
return f'{self.row_id} {self.rowspan}/{self.colspan}'


class TableRow(PolygonBox):
row_id: int
Expand All @@ -213,6 +235,7 @@ def label(self):

class TableResult(BaseModel):
cells: List[TableCell]
unmerged_cells: List[TableCell]
rows: List[TableRow]
cols: List[TableCol]
image_bbox: List[float]
2 changes: 1 addition & 1 deletion surya/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def TORCH_DEVICE_MODEL(self) -> str:
ORDER_BENCH_DATASET_NAME: str = "vikp/order_bench"

# Table Rec
TABLE_REC_MODEL_CHECKPOINT: str = "datalab-to/table_rec_2_test"
TABLE_REC_MODEL_CHECKPOINT: str = "datalab-to/table_rec_2_test2"
TABLE_REC_IMAGE_SIZE: Dict = {"height": 768, "width": 768}
TABLE_REC_MAX_BOXES: int = 150
TABLE_REC_BATCH_SIZE: Optional[int] = None
Expand Down
146 changes: 114 additions & 32 deletions surya/tables.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from copy import deepcopy
from itertools import chain
from typing import List
import torch
from PIL import Image
Expand All @@ -6,7 +8,7 @@
from surya.model.table_rec.columns import find_columns
from surya.model.table_rec.encoderdecoder import TableRecEncoderDecoderModel
from surya.model.table_rec.shaper import LabelShaper
from surya.schema import TableResult, TableCell, TableRow
from surya.schema import TableResult, TableCell, TableRow, TableCol, PolygonBox
from surya.settings import settings
from tqdm import tqdm
import numpy as np
Expand Down Expand Up @@ -74,7 +76,6 @@ def inference_loop(

model.decoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)

print(batch_input_ids)
with torch.inference_mode():
token_count = 0
all_done = torch.zeros(current_batch_size, dtype=torch.bool)
Expand Down Expand Up @@ -107,12 +108,13 @@ def inference_loop(
elif mode == "regression":
if k == "bbox":
k_logits *= BOX_DIM
k_logits = k_logits.tolist()
elif k == "colspan":
k_logits = k_logits.clamp(min=1)
box_property[k] = k_logits.tolist()
k_logits = int(k_logits.round().item())
box_property[k] = k_logits
box_properties.append(box_property)

print(box_properties[0])
all_done = all_done | torch.tensor(done, dtype=torch.bool)

if all_done.all():
Expand Down Expand Up @@ -173,69 +175,149 @@ def batch_table_recognition(images: List, model: TableRecEncoderDecoderModel, pr
with torch.inference_mode():
encoder_hidden_states = model.encoder(pixel_values=batch_pixel_values).last_hidden_state

row_predictions = inference_loop(model, encoder_hidden_states, batch_input_ids, current_batch_size, batch_size)
rowcol_predictions = inference_loop(model, encoder_hidden_states, batch_input_ids, current_batch_size, batch_size)

row_query_items = []
row_encoder_hidden_states = []
idx_map = []
for j, img_predictions in enumerate(row_predictions):
columns = []
for j, img_predictions in enumerate(rowcol_predictions):
for row_prediction in img_predictions:
polygon = shaper.convert_bbox_to_polygon(row_prediction["bbox"])
row_query_items.append({
"polygon": polygon,
"category": CATEGORY_TO_ID["Table-row"],
"colspan": 0,
"merges": 0,
})
row_encoder_hidden_states.append(encoder_hidden_states[j])
idx_map.append(j)
if row_prediction["category"] == CATEGORY_TO_ID["Table-row"]:
row_query_items.append({
"polygon": polygon,
"category": row_prediction["category"],
"colspan": 0,
"merges": 0,
})
row_encoder_hidden_states.append(encoder_hidden_states[j])
idx_map.append(j)
elif row_prediction["category"] == CATEGORY_TO_ID["Table-column"]:
columns.append({
"polygon": polygon,
"category": row_prediction["category"],
"colspan": 0,
"merges": 0,
})

row_encoder_hidden_states = torch.stack(row_encoder_hidden_states)
row_inputs = processor(images=None, query_items=row_query_items, convert_images=False)
row_inputs = processor(images=None, query_items=row_query_items, columns=columns, convert_images=False)
row_input_ids = row_inputs["input_ids"].to(model.device)
cell_predictions = []
"""
for j in tqdm(range(0, len(row_input_ids), batch_size), desc="Recognizing tables"):
cell_batch_hidden_states = row_encoder_hidden_states[j:j+batch_size]
cell_batch_input_ids = row_input_ids[j:j+batch_size]
cell_batch_size = len(cell_batch_input_ids)
cell_predictions.extend(
inference_loop(model, cell_batch_hidden_states, cell_batch_input_ids, cell_batch_size, batch_size)
)
"""

for j, (img_predictions, orig_size) in enumerate(zip(row_predictions, orig_sizes)):
for j, (img_predictions, orig_size) in enumerate(zip(rowcol_predictions, orig_sizes)):
row_cell_predictions = [c for i,c in enumerate(cell_predictions) if idx_map[i] == j]
# Each row prediction matches a cell prediction
#assert len(img_predictions) == len(row_cell_predictions)
rows = []
cells = []
for z, row_prediction in enumerate(img_predictions):
columns = []

cell_id = 0
row_predictions = [pred for pred in img_predictions if pred["category"] == CATEGORY_TO_ID["Table-row"]]
col_predictions = [pred for pred in img_predictions if pred["category"] == CATEGORY_TO_ID["Table-column"]]

for z, col_prediction in enumerate(col_predictions):
polygon = shaper.convert_bbox_to_polygon(col_prediction["bbox"])
polygon = processor.resize_polygon(polygon, (BOX_DIM, BOX_DIM), orig_size)
columns.append(
TableCol(
polygon=polygon,
col_id=z
)
)

for z, row_prediction in enumerate(row_predictions):
polygon = shaper.convert_bbox_to_polygon(row_prediction["bbox"])
polygon = processor.resize_polygon(polygon, (BOX_DIM, BOX_DIM), orig_size)
rows.append(TableRow(
row = TableRow(
polygon=polygon,
row_id=z
))
"""
for l, cell in enumerate(row_cell_predictions[z]):
polygon = shaper.convert_bbox_to_polygon(cell["bbox"])
)
rows.append(row)

spanning_cells = []
for l, spanning_cell in enumerate(row_cell_predictions[z]):
polygon = shaper.convert_bbox_to_polygon(spanning_cell["bbox"])
polygon = processor.resize_polygon(polygon, (BOX_DIM, BOX_DIM), orig_size)
cells.append(
spanning_cells.append(
TableCell(
polygon=polygon,
row_id=z,
rowspan=1,
cell_id=cell_id,
within_row_id=l,
colspan=max(1, int(cell["colspan"])),
merge_up=cell["merges"] in [MERGE_KEYS["merge_up"], MERGE_KEYS["merge_both"]],
merge_down=cell["merges"] in [MERGE_KEYS["merge_down"], MERGE_KEYS["merge_both"]],
colspan=max(1, int(spanning_cell["colspan"])),
merge_up=spanning_cell["merges"] in [MERGE_KEYS["merge_up"], MERGE_KEYS["merge_both"]],
merge_down=spanning_cell["merges"] in [MERGE_KEYS["merge_down"], MERGE_KEYS["merge_both"]],
)
)
"""
columns = find_columns(rows, cells)
cell_id += 1


used_spanning_cells = set()
for l, col in enumerate(columns):
cell_polygon = row.intersection_polygon(col)
cell_added = False
for zz, spanning_cell in enumerate(spanning_cells):
intersection_pct = PolygonBox(polygon=cell_polygon).intersection_pct(spanning_cell)
if intersection_pct > .5:
cell_added = True
if zz not in used_spanning_cells:
used_spanning_cells.add(zz)
cells.append(spanning_cell)

if not cell_added:
cells.append(
TableCell(
polygon=cell_polygon,
row_id=z,
rowspan=1,
cell_id=cell_id,
within_row_id=l,
colspan=1,
merge_up=False,
merge_down=False,
)
)
cell_id += 1

grid_cells = deepcopy([
[cell for cell in cells if cell.row_id == row.row_id]
for row in rows
])

for z, grid_row in enumerate(grid_cells[1:]):
prev_row = grid_cells[z]
for l, cell in enumerate(grid_row):
if l >= len(prev_row):
continue

above_cell = prev_row[l]
if above_cell.merge_down and cell.merge_up:
above_cell.merge(cell)
above_cell.rowspan += cell.rowspan
grid_row[l] = above_cell
merged_cells_all = list(chain.from_iterable(grid_cells))
used_ids = set()
merged_cells = []
for cell in merged_cells_all:
if cell.cell_id in used_ids:
continue
used_ids.add(cell.cell_id)
merged_cells.append(cell)


result = TableResult(
cells=cells,
cells=merged_cells,
unmerged_cells=cells,
rows=rows,
cols=columns,
image_bbox=[0, 0, orig_size[0], orig_size[1]],
Expand Down

0 comments on commit f8188f4

Please sign in to comment.