Skip to content

Commit

Permalink
Fix bug that extended the region of removal/copy to the first row/column. Add a flag for storing visualization.
Browse files Browse the repository at this point in the history
  • Loading branch information
sohaib023 committed Dec 28, 2020
1 parent e71616d commit 8dc38e8
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 102 deletions.
116 changes: 16 additions & 100 deletions augmentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,10 @@ def replicate_row(self, idx_copy, idx_paste):
print("Not convex")
return False

if abs(idx_p1 - idx_paste) <= abs(idx_p2 - idx_paste):
if idx_c1 == 0:
return False

if abs(idx_p1 - idx_paste) <= abs(idx_p2 - idx_paste) and idx_p1 != 0:
idx_paste = idx_p1
else:
idx_paste = idx_p2 + 1
Expand Down Expand Up @@ -166,11 +169,14 @@ def remove_row(self, idx):
print("Not convex")
return False

if idx_1 == 0:
return False

y1 = self.t.gtCells[idx_1][0].y1
y2 = self.t.gtCells[idx_2][0].y2
h = y2 - y1

if h >= self.image.shape[0] * 0.6 or len(self.t.gtCells) - (idx_1 - idx_2 + 1) <= 3:
if h >= self.image.shape[0] * 0.6 or len(self.t.gtCells) - (idx_2 - idx_1 + 1) <= 3:
return False

new_shape = list(self.image.shape)
Expand Down Expand Up @@ -210,7 +216,10 @@ def replicate_column(self, idx_copy, idx_paste):
print("Not convex")
return False

if abs(idx_p1 - idx_paste) <= abs(idx_p2 - idx_paste):
if idx_c1 == 0:
return False

if abs(idx_p1 - idx_paste) <= abs(idx_p2 - idx_paste) and idx_p1 != 0:
idx_paste = idx_p1
else:
idx_paste = idx_p2 + 1
Expand Down Expand Up @@ -267,11 +276,14 @@ def remove_column(self, idx):
print("Not convex")
return False

if idx_1 == 0:
return False

x1 = self.t.gtCells[0][idx_1].x1
x2 = self.t.gtCells[0][idx_2].x2
w = x2 - x1

if w >= self.image.shape[1] * 0.7 or len(self.t.gtCells[0]) - (idx_1 - idx_2 + 1) <= 2:
if w >= self.image.shape[1] * 0.7 or len(self.t.gtCells[0]) - (idx_2 - idx_1 + 1) <= 2:
return False

new_shape = list(self.image.shape)
Expand Down Expand Up @@ -305,102 +317,6 @@ def remove_column(self, idx):

return True

def remove(self, to_remove, is_row):
    """Cut one row or column (plus any span attached to it) out of the table.

    The band of pixels between the two bounding separators is dropped from
    ``self.image``, OCR entries inside the band are discarded, OCR entries
    past the band are shifted back by the band size, and ``self.rows`` /
    ``self.columns`` and ``self.cells`` are rebuilt to match.

    Args:
        to_remove: Index of the row (or column) to remove. It is moved
            backwards by the most negative span found on that line, and
            the removal is widened by the largest span found at the
            adjusted index.
            NOTE(review): presumably a negative span marks a cell that
            continues a span started earlier -- confirm against the
            code that fills ``attrib['rowspan']`` / ``attrib['colspan']``.
        is_row: True to remove a row, False to remove a column.

    Returns:
        False when nothing was changed (the band would cover the whole
        image dimension, or too few separators would remain); True once
        the removal has been applied. Success used to fall through and
        return None; it now returns True, as the sibling removal methods
        do.
    """
    if is_row:
        # Step back to the first row of the span, then measure how far
        # the span reaches forward from there.
        to_remove += min([0] + [cell.attrib['rowspan'] for cell in self.cells[to_remove]])
        maxspan = max([0] + [cell.attrib['rowspan'] for cell in self.cells[to_remove]])
        removed = maxspan + 1

        y1 = self.rows[to_remove]
        y2 = self.rows[to_remove + removed]
        height = y2 - y1

        if height >= self.image.shape[0] or len(self.rows) - removed < 3:
            return False

        full_h, full_w = self.image.shape[:2]
        # Concatenating the two remaining slabs works for both 2-D and
        # multi-channel images.
        image_new = np.concatenate((self.image[:y1], self.image[y2:]), axis=0)

        # Called only for its remove_org=True effect: the OCR entries inside
        # the removed band are discarded.
        get_bounded_ocr(self.ocr, (0, y1), (full_w, y2), remove_org=True)
        # Everything from y1 down is pulled out, shifted up and re-added.
        self.ocr += translate_ocr(
            get_bounded_ocr(self.ocr, (0, y1), (full_w, full_h), remove_org=True),
            (0, y1 - y2),
        )

        dropped = self.rows[to_remove:to_remove + removed]
        self.rows = [row for row in self.rows if row not in dropped]
        self.rows[to_remove:] = [row - height for row in self.rows[to_remove:]]
        self.rows.sort()
        self.image = image_new

        new_cells = [[None] * (len(self.columns) - 1) for _ in range(len(self.rows) - 1)]
        for i, new_row in enumerate(new_cells):
            for j in range(len(new_row)):
                if i < to_remove:
                    new_row[j] = self.cells[i][j]
                    continue
                # Cells after the cut are re-indexed and moved up in place.
                cell = self.cells[i + removed][j]
                cell.attrib['y0'] -= height
                cell.attrib['y1'] -= height
                cell.attrib['startRow'] -= removed
                cell.attrib['endRow'] -= removed
                new_row[j] = cell
        self.cells = new_cells
    else:
        # Same procedure as the row branch, along the horizontal axis.
        to_remove += min([0] + [row[to_remove].attrib['colspan'] for row in self.cells])
        maxspan = max([0] + [row[to_remove].attrib['colspan'] for row in self.cells])
        removed = maxspan + 1

        x1 = self.columns[to_remove]
        x2 = self.columns[to_remove + removed]
        width = x2 - x1

        # The column limit (4) differs from the row limit (3) in the
        # original code; it is kept as-is.
        if width >= self.image.shape[1] or len(self.columns) - removed < 4:
            return False

        full_h, full_w = self.image.shape[:2]
        image_new = np.concatenate((self.image[:, :x1], self.image[:, x2:]), axis=1)

        # Called only for its remove_org=True effect (see the row branch).
        get_bounded_ocr(self.ocr, (x1, 0), (x2, full_h), remove_org=True)
        self.ocr += translate_ocr(
            get_bounded_ocr(self.ocr, (x1, 0), (full_w, full_h), remove_org=True),
            (x1 - x2, 0),
        )

        dropped = self.columns[to_remove:to_remove + removed]
        self.columns = [col for col in self.columns if col not in dropped]
        self.columns[to_remove:] = [col - width for col in self.columns[to_remove:]]
        self.columns.sort()
        self.image = image_new

        new_cells = [[None] * (len(self.columns) - 1) for _ in range(len(self.rows) - 1)]
        for i, new_row in enumerate(new_cells):
            for j in range(len(new_row)):
                if j < to_remove:
                    new_row[j] = self.cells[i][j]
                    continue
                cell = self.cells[i][j + removed]
                cell.attrib['x0'] -= width
                cell.attrib['x1'] -= width
                cell.attrib['startCol'] -= removed
                cell.attrib['endCol'] -= removed
                new_row[j] = cell
        self.cells = new_cells

    return True

def visualize(self, window="image"):
image = self.image.copy()
if len(self.image.shape) == 2 or self.image.shape[2] < 3:
Expand Down
14 changes: 12 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def ensure_exists(filename, log_data):
else:
return True

def process_files(image_dir, xml_dir, ocr_dir, out_dir, num_samples, log_file):
def process_files(image_dir, xml_dir, ocr_dir, out_dir, num_samples, log_file, visualize):
files = list(map(lambda name: os.path.basename(name).rsplit('.', 1)[0], glob.glob(os.path.join(xml_dir,'*.xml'))))

files.sort()
Expand Down Expand Up @@ -116,6 +116,9 @@ def process_files(image_dir, xml_dir, ocr_dir, out_dir, num_samples, log_file):
with open(os.path.join(out_dir, "ocr", table_name + '.pkl'), 'wb') as f:
pickle.dump(ocr, f)

if visualize:
table.visualize(img)
cv2.imwrite(os.path.join(args.out_dir,'vis', table_name + '.png'), img)
generated += 1
except Exception as e:
log_data.append("Exception thrown: " + str(e))
Expand Down Expand Up @@ -147,12 +150,19 @@ def process_files(image_dir, xml_dir, ocr_dir, out_dir, num_samples, log_file):
parser.add_argument("-log", "--log_file", type=str, required=False,
help="Output file path for error logging.")

parser.add_argument("-vis", "--visualize", action="store_true")


args = parser.parse_args()

os.makedirs(args.out_dir, exist_ok=True)
os.makedirs(os.path.join(args.out_dir,'images'), exist_ok=True)
os.makedirs(os.path.join(args.out_dir,'ocr'), exist_ok=True)
os.makedirs(os.path.join(args.out_dir,'gt'), exist_ok=True)

if args.visualize:
os.makedirs(os.path.join(args.out_dir,'vis'), exist_ok=True)

process_files(args.image_dir, args.xml_dir, args.ocr_dir, args.out_dir, args.num_samples, args.log_file)
os.makedirs(args.ocr_dir, exist_ok=True)

process_files(args.image_dir, args.xml_dir, args.ocr_dir, args.out_dir, args.num_samples, args.log_file, args.visualize)

0 comments on commit 8dc38e8

Please sign in to comment.