Files
TexTeller/texteller/models/utils/mix_inference.py

262 lines
9.2 KiB
Python
Raw Normal View History

import re
import heapq
import cv2
import time
import numpy as np
from collections import Counter
from typing import List
from PIL import Image
2024-04-21 00:48:24 +08:00
from ..det_model.inference import predict as latex_det_predict
from ..det_model.Bbox import Bbox, draw_bboxes
from ..ocr_model.utils.inference import inference as latex_rec_predict
from ..ocr_model.utils.to_katex import to_katex, change_all
MAXV = 999999999
def mask_img(img, bboxes: List[Bbox], bg_color: np.ndarray) -> np.ndarray:
mask_img = img.copy()
for bbox in bboxes:
mask_img[bbox.p.y : bbox.p.y + bbox.h, bbox.p.x : bbox.p.x + bbox.w] = bg_color
return mask_img
def bbox_merge(sorted_bboxes: List[Bbox]) -> List[Bbox]:
if len(sorted_bboxes) == 0:
return []
bboxes = sorted_bboxes.copy()
guard = Bbox(MAXV, bboxes[-1].p.y, -1, -1, label="guard")
bboxes.append(guard)
res = []
prev = bboxes[0]
for curr in bboxes:
if prev.ur_point.x <= curr.p.x or not prev.same_row(curr):
res.append(prev)
prev = curr
else:
prev.w = max(prev.w, curr.ur_point.x - prev.p.x)
return res
def split_conflict(ocr_bboxes: List[Bbox], latex_bboxes: List[Bbox]) -> List[Bbox]:
if latex_bboxes == []:
return ocr_bboxes
if ocr_bboxes == [] or len(ocr_bboxes) == 1:
return ocr_bboxes
bboxes = sorted(ocr_bboxes + latex_bboxes)
2024-04-21 00:48:24 +08:00
# log results
for idx, bbox in enumerate(bboxes):
bbox.content = str(idx)
draw_bboxes(Image.fromarray(img), bboxes, name="before_split_confict.png")
assert len(bboxes) > 1
heapq.heapify(bboxes)
res = []
candidate = heapq.heappop(bboxes)
curr = heapq.heappop(bboxes)
idx = 0
while len(bboxes) > 0:
idx += 1
assert candidate.p.x <= curr.p.x or not candidate.same_row(curr)
if candidate.ur_point.x <= curr.p.x or not candidate.same_row(curr):
res.append(candidate)
candidate = curr
curr = heapq.heappop(bboxes)
elif candidate.ur_point.x < curr.ur_point.x:
assert not (candidate.label != "text" and curr.label != "text")
if candidate.label == "text" and curr.label == "text":
candidate.w = curr.ur_point.x - candidate.p.x
curr = heapq.heappop(bboxes)
elif candidate.label != curr.label:
if candidate.label == "text":
candidate.w = curr.p.x - candidate.p.x
res.append(candidate)
candidate = curr
curr = heapq.heappop(bboxes)
else:
curr.w = curr.ur_point.x - candidate.ur_point.x
curr.p.x = candidate.ur_point.x
heapq.heappush(bboxes, curr)
curr = heapq.heappop(bboxes)
elif candidate.ur_point.x >= curr.ur_point.x:
assert not (candidate.label != "text" and curr.label != "text")
if candidate.label == "text":
assert curr.label != "text"
heapq.heappush(
bboxes,
Bbox(
curr.ur_point.x,
candidate.p.y,
candidate.h,
candidate.ur_point.x - curr.ur_point.x,
label="text",
confidence=candidate.confidence,
content=None,
),
)
candidate.w = curr.p.x - candidate.p.x
res.append(candidate)
candidate = curr
curr = heapq.heappop(bboxes)
else:
assert curr.label == "text"
curr = heapq.heappop(bboxes)
else:
assert False
res.append(candidate)
res.append(curr)
2024-04-21 00:48:24 +08:00
# log results
for idx, bbox in enumerate(res):
bbox.content = str(idx)
draw_bboxes(Image.fromarray(img), res, name="after_split_confict.png")
2024-04-21 00:48:24 +08:00
return res
def slice_from_image(img: np.ndarray, ocr_bboxes: List[Bbox]) -> List[np.ndarray]:
sliced_imgs = []
for bbox in ocr_bboxes:
x, y = int(bbox.p.x), int(bbox.p.y)
w, h = int(bbox.w), int(bbox.h)
sliced_img = img[y : y + h, x : x + w]
sliced_imgs.append(sliced_img)
return sliced_imgs
def mix_inference(
img_path: str,
infer_config,
latex_det_model,
lang_ocr_models,
latex_rec_models,
accelerator="cpu",
num_beams=1,
) -> str:
'''
Input a mixed image of formula text and output str (in markdown syntax)
'''
global img
img = cv2.imread(img_path)
corners = [tuple(img[0, 0]), tuple(img[0, -1]), tuple(img[-1, 0]), tuple(img[-1, -1])]
bg_color = np.array(Counter(corners).most_common(1)[0][0])
start_time = time.time()
latex_bboxes = latex_det_predict(img_path, latex_det_model, infer_config)
end_time = time.time()
print(f"latex_det_model time: {end_time - start_time:.2f}s")
latex_bboxes = sorted(latex_bboxes)
2024-04-21 00:48:24 +08:00
# log results
draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(unmerged).png")
latex_bboxes = bbox_merge(latex_bboxes)
2024-04-21 00:48:24 +08:00
# log results
draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(merged).png")
masked_img = mask_img(img, latex_bboxes, bg_color)
det_model, rec_model = lang_ocr_models
start_time = time.time()
det_prediction, _ = det_model(masked_img)
end_time = time.time()
print(f"ocr_det_model time: {end_time - start_time:.2f}s")
ocr_bboxes = [
Bbox(
p[0][0],
p[0][1],
p[3][1] - p[0][1],
p[1][0] - p[0][0],
label="text",
confidence=None,
content=None,
)
for p in det_prediction
]
# log results
draw_bboxes(Image.fromarray(img), ocr_bboxes, name="ocr_bboxes(unmerged).png")
ocr_bboxes = sorted(ocr_bboxes)
ocr_bboxes = bbox_merge(ocr_bboxes)
# log results
draw_bboxes(Image.fromarray(img), ocr_bboxes, name="ocr_bboxes(merged).png")
ocr_bboxes = split_conflict(ocr_bboxes, latex_bboxes)
ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes))
sliced_imgs: List[np.ndarray] = slice_from_image(img, ocr_bboxes)
start_time = time.time()
rec_predictions, _ = rec_model(sliced_imgs)
end_time = time.time()
print(f"ocr_rec_model time: {end_time - start_time:.2f}s")
assert len(rec_predictions) == len(ocr_bboxes)
for content, bbox in zip(rec_predictions, ocr_bboxes):
bbox.content = content[0]
latex_imgs = []
for bbox in latex_bboxes:
latex_imgs.append(img[bbox.p.y : bbox.p.y + bbox.h, bbox.p.x : bbox.p.x + bbox.w])
start_time = time.time()
latex_rec_res = latex_rec_predict(
*latex_rec_models, latex_imgs, accelerator, num_beams, max_tokens=800
)
end_time = time.time()
print(f"latex_rec_model time: {end_time - start_time:.2f}s")
for bbox, content in zip(latex_bboxes, latex_rec_res):
bbox.content = to_katex(content)
if bbox.label == "embedding":
bbox.content = " $" + bbox.content + "$ "
elif bbox.label == "isolated":
bbox.content = '\n\n' + r"$$" + bbox.content + r"$$" + '\n\n'
bboxes = sorted(ocr_bboxes + latex_bboxes)
if bboxes == []:
return ""
md = ""
prev = Bbox(bboxes[0].p.x, bboxes[0].p.y, -1, -1, label="guard")
for curr in bboxes:
# Add the formula number back to the isolated formula
if prev.label == "isolated" and curr.label == "text" and prev.same_row(curr):
curr.content = curr.content.strip()
if curr.content.startswith('(') and curr.content.endswith(')'):
curr.content = curr.content[1:-1]
if re.search(r'\\tag\{.*\}$', md[:-4]) is not None:
# in case of multiple tag
md = md[:-5] + f', {curr.content}' + '}' + md[-4:]
else:
md = md[:-4] + f'\\tag{{{curr.content}}}' + md[-4:]
continue
if not prev.same_row(curr):
md += " "
if curr.label == "embedding":
# remove the bold effect from inline formulas
curr.content = change_all(curr.content, r'\bm', r' ', r'{', r'}', r'', r' ')
curr.content = change_all(curr.content, r'\boldsymbol', r' ', r'{', r'}', r'', r' ')
curr.content = change_all(curr.content, r'\textit', r' ', r'{', r'}', r'', r' ')
curr.content = change_all(curr.content, r'\textbf', r' ', r'{', r'}', r'', r' ')
curr.content = change_all(curr.content, r'\textbf', r' ', r'{', r'}', r'', r' ')
curr.content = change_all(curr.content, r'\mathbf', r' ', r'{', r'}', r'', r' ')
# change split environment into aligned
curr.content = curr.content.replace(r'\begin{split}', r'\begin{aligned}')
curr.content = curr.content.replace(r'\end{split}', r'\end{aligned}')
# remove extra spaces (keeping only one)
curr.content = re.sub(r' +', ' ', curr.content)
assert curr.content.startswith(' $') and curr.content.endswith('$ ')
curr.content = ' $' + curr.content[2:-2].strip() + '$ '
md += curr.content
prev = curr
return md.strip()