Skip to content

Commit

Permalink
detector static cache [skip ci]
Browse files Browse the repository at this point in the history
  • Loading branch information
iammosespaulr committed Nov 5, 2024
1 parent 7032ac5 commit f277df1
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
15 changes: 11 additions & 4 deletions surya/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,15 @@ def get_batch_size():
batch_size = 36
return batch_size

fake_image = Image.new("RGB", (1200, 1200), color=(255, 255, 255))
def pad_to_batch_size(tensor, batch_size):
    """Pad *tensor* along dim 0 with zero rows until it has *batch_size* entries.

    Used so a static-cache model always sees a fixed batch dimension.
    Returns the tensor unchanged when it already holds at least
    ``batch_size`` items; otherwise appends zero-filled rows that match
    the input's dtype and device.
    """
    deficit = batch_size - tensor.shape[0]
    if deficit <= 0:
        # Already large enough — hand back the original object untouched.
        return tensor

    # Zero filler with the same trailing shape, dtype, and device as the input.
    filler = tensor.new_zeros((deficit, *tensor.shape[1:]))
    return torch.cat((tensor, filler), dim=0)

def batch_detection(
images: List,
Expand Down Expand Up @@ -73,12 +81,11 @@ def batch_detection(
split_index.extend([image_idx] * len(image_parts))
split_heights.extend(split_height)

if len(image_splits) < batch_size:
pad_size = batch_size - len(image_splits)
image_splits += [fake_image] * pad_size
image_splits = [prepare_image_detection(image, processor) for image in image_splits]
# Batch images in dim 0
batch = torch.stack(image_splits, dim=0).to(model.dtype).to(model.device)
if settings.DETECTOR_STATIC_CACHE or settings.LAYOUT_STATIC_CACHE:
batch = pad_to_batch_size(batch, batch_size)

with torch.inference_mode():
pred = model(pixel_values=batch)
Expand Down
2 changes: 2 additions & 0 deletions surya/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,4 +194,6 @@ def batch_table_recognition(images: List, table_cells: List[List[Dict]], model:

output_order.append(result)

del text_encoder_hidden_states

return output_order

0 comments on commit f277df1

Please sign in to comment.