[chore] exclude paddleocr directory from pre-commit hooks
This commit is contained in:
1
texteller/models/utils/__init__.py
Normal file
1
texteller/models/utils/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .mix_inference import mix_inference
|
||||
BIN
texteller/models/utils/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
texteller/models/utils/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
texteller/models/utils/__pycache__/mix_inference.cpython-310.pyc
Normal file
BIN
texteller/models/utils/__pycache__/mix_inference.cpython-310.pyc
Normal file
Binary file not shown.
261
texteller/models/utils/mix_inference.py
Normal file
261
texteller/models/utils/mix_inference.py
Normal file
@@ -0,0 +1,261 @@
|
||||
import re
|
||||
import heapq
|
||||
import cv2
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
from collections import Counter
|
||||
from typing import List
|
||||
from PIL import Image
|
||||
|
||||
from ..det_model.inference import predict as latex_det_predict
|
||||
from ..det_model.Bbox import Bbox, draw_bboxes
|
||||
|
||||
from ..ocr_model.utils.inference import inference as latex_rec_predict
|
||||
from ..ocr_model.utils.to_katex import to_katex, change_all
|
||||
|
||||
MAXV = 999999999
|
||||
|
||||
|
||||
def mask_img(img, bboxes: List[Bbox], bg_color: np.ndarray) -> np.ndarray:
|
||||
mask_img = img.copy()
|
||||
for bbox in bboxes:
|
||||
mask_img[bbox.p.y : bbox.p.y + bbox.h, bbox.p.x : bbox.p.x + bbox.w] = bg_color
|
||||
return mask_img
|
||||
|
||||
|
||||
def bbox_merge(sorted_bboxes: List[Bbox]) -> List[Bbox]:
|
||||
if len(sorted_bboxes) == 0:
|
||||
return []
|
||||
bboxes = sorted_bboxes.copy()
|
||||
guard = Bbox(MAXV, bboxes[-1].p.y, -1, -1, label="guard")
|
||||
bboxes.append(guard)
|
||||
res = []
|
||||
prev = bboxes[0]
|
||||
for curr in bboxes:
|
||||
if prev.ur_point.x <= curr.p.x or not prev.same_row(curr):
|
||||
res.append(prev)
|
||||
prev = curr
|
||||
else:
|
||||
prev.w = max(prev.w, curr.ur_point.x - prev.p.x)
|
||||
return res
|
||||
|
||||
|
||||
def split_conflict(ocr_bboxes: List[Bbox], latex_bboxes: List[Bbox]) -> List[Bbox]:
|
||||
if latex_bboxes == []:
|
||||
return ocr_bboxes
|
||||
if ocr_bboxes == [] or len(ocr_bboxes) == 1:
|
||||
return ocr_bboxes
|
||||
|
||||
bboxes = sorted(ocr_bboxes + latex_bboxes)
|
||||
|
||||
# log results
|
||||
for idx, bbox in enumerate(bboxes):
|
||||
bbox.content = str(idx)
|
||||
draw_bboxes(Image.fromarray(img), bboxes, name="before_split_confict.png")
|
||||
|
||||
assert len(bboxes) > 1
|
||||
|
||||
heapq.heapify(bboxes)
|
||||
res = []
|
||||
candidate = heapq.heappop(bboxes)
|
||||
curr = heapq.heappop(bboxes)
|
||||
idx = 0
|
||||
while len(bboxes) > 0:
|
||||
idx += 1
|
||||
assert candidate.p.x <= curr.p.x or not candidate.same_row(curr)
|
||||
|
||||
if candidate.ur_point.x <= curr.p.x or not candidate.same_row(curr):
|
||||
res.append(candidate)
|
||||
candidate = curr
|
||||
curr = heapq.heappop(bboxes)
|
||||
elif candidate.ur_point.x < curr.ur_point.x:
|
||||
assert not (candidate.label != "text" and curr.label != "text")
|
||||
if candidate.label == "text" and curr.label == "text":
|
||||
candidate.w = curr.ur_point.x - candidate.p.x
|
||||
curr = heapq.heappop(bboxes)
|
||||
elif candidate.label != curr.label:
|
||||
if candidate.label == "text":
|
||||
candidate.w = curr.p.x - candidate.p.x
|
||||
res.append(candidate)
|
||||
candidate = curr
|
||||
curr = heapq.heappop(bboxes)
|
||||
else:
|
||||
curr.w = curr.ur_point.x - candidate.ur_point.x
|
||||
curr.p.x = candidate.ur_point.x
|
||||
heapq.heappush(bboxes, curr)
|
||||
curr = heapq.heappop(bboxes)
|
||||
|
||||
elif candidate.ur_point.x >= curr.ur_point.x:
|
||||
assert not (candidate.label != "text" and curr.label != "text")
|
||||
|
||||
if candidate.label == "text":
|
||||
assert curr.label != "text"
|
||||
heapq.heappush(
|
||||
bboxes,
|
||||
Bbox(
|
||||
curr.ur_point.x,
|
||||
candidate.p.y,
|
||||
candidate.h,
|
||||
candidate.ur_point.x - curr.ur_point.x,
|
||||
label="text",
|
||||
confidence=candidate.confidence,
|
||||
content=None,
|
||||
),
|
||||
)
|
||||
candidate.w = curr.p.x - candidate.p.x
|
||||
res.append(candidate)
|
||||
candidate = curr
|
||||
curr = heapq.heappop(bboxes)
|
||||
else:
|
||||
assert curr.label == "text"
|
||||
curr = heapq.heappop(bboxes)
|
||||
else:
|
||||
assert False
|
||||
res.append(candidate)
|
||||
res.append(curr)
|
||||
|
||||
# log results
|
||||
for idx, bbox in enumerate(res):
|
||||
bbox.content = str(idx)
|
||||
draw_bboxes(Image.fromarray(img), res, name="after_split_confict.png")
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def slice_from_image(img: np.ndarray, ocr_bboxes: List[Bbox]) -> List[np.ndarray]:
|
||||
sliced_imgs = []
|
||||
for bbox in ocr_bboxes:
|
||||
x, y = int(bbox.p.x), int(bbox.p.y)
|
||||
w, h = int(bbox.w), int(bbox.h)
|
||||
sliced_img = img[y : y + h, x : x + w]
|
||||
sliced_imgs.append(sliced_img)
|
||||
return sliced_imgs
|
||||
|
||||
|
||||
def mix_inference(
|
||||
img_path: str,
|
||||
infer_config,
|
||||
latex_det_model,
|
||||
lang_ocr_models,
|
||||
latex_rec_models,
|
||||
accelerator="cpu",
|
||||
num_beams=1,
|
||||
) -> str:
|
||||
'''
|
||||
Input a mixed image of formula text and output str (in markdown syntax)
|
||||
'''
|
||||
global img
|
||||
img = cv2.imread(img_path)
|
||||
corners = [tuple(img[0, 0]), tuple(img[0, -1]), tuple(img[-1, 0]), tuple(img[-1, -1])]
|
||||
bg_color = np.array(Counter(corners).most_common(1)[0][0])
|
||||
|
||||
start_time = time.time()
|
||||
latex_bboxes = latex_det_predict(img_path, latex_det_model, infer_config)
|
||||
end_time = time.time()
|
||||
print(f"latex_det_model time: {end_time - start_time:.2f}s")
|
||||
latex_bboxes = sorted(latex_bboxes)
|
||||
# log results
|
||||
draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(unmerged).png")
|
||||
latex_bboxes = bbox_merge(latex_bboxes)
|
||||
# log results
|
||||
draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(merged).png")
|
||||
masked_img = mask_img(img, latex_bboxes, bg_color)
|
||||
|
||||
det_model, rec_model = lang_ocr_models
|
||||
start_time = time.time()
|
||||
det_prediction, _ = det_model(masked_img)
|
||||
end_time = time.time()
|
||||
print(f"ocr_det_model time: {end_time - start_time:.2f}s")
|
||||
ocr_bboxes = [
|
||||
Bbox(
|
||||
p[0][0],
|
||||
p[0][1],
|
||||
p[3][1] - p[0][1],
|
||||
p[1][0] - p[0][0],
|
||||
label="text",
|
||||
confidence=None,
|
||||
content=None,
|
||||
)
|
||||
for p in det_prediction
|
||||
]
|
||||
# log results
|
||||
draw_bboxes(Image.fromarray(img), ocr_bboxes, name="ocr_bboxes(unmerged).png")
|
||||
|
||||
ocr_bboxes = sorted(ocr_bboxes)
|
||||
ocr_bboxes = bbox_merge(ocr_bboxes)
|
||||
# log results
|
||||
draw_bboxes(Image.fromarray(img), ocr_bboxes, name="ocr_bboxes(merged).png")
|
||||
ocr_bboxes = split_conflict(ocr_bboxes, latex_bboxes)
|
||||
ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes))
|
||||
|
||||
sliced_imgs: List[np.ndarray] = slice_from_image(img, ocr_bboxes)
|
||||
start_time = time.time()
|
||||
rec_predictions, _ = rec_model(sliced_imgs)
|
||||
end_time = time.time()
|
||||
print(f"ocr_rec_model time: {end_time - start_time:.2f}s")
|
||||
|
||||
assert len(rec_predictions) == len(ocr_bboxes)
|
||||
for content, bbox in zip(rec_predictions, ocr_bboxes):
|
||||
bbox.content = content[0]
|
||||
|
||||
latex_imgs = []
|
||||
for bbox in latex_bboxes:
|
||||
latex_imgs.append(img[bbox.p.y : bbox.p.y + bbox.h, bbox.p.x : bbox.p.x + bbox.w])
|
||||
start_time = time.time()
|
||||
latex_rec_res = latex_rec_predict(
|
||||
*latex_rec_models, latex_imgs, accelerator, num_beams, max_tokens=800
|
||||
)
|
||||
end_time = time.time()
|
||||
print(f"latex_rec_model time: {end_time - start_time:.2f}s")
|
||||
|
||||
for bbox, content in zip(latex_bboxes, latex_rec_res):
|
||||
bbox.content = to_katex(content)
|
||||
if bbox.label == "embedding":
|
||||
bbox.content = " $" + bbox.content + "$ "
|
||||
elif bbox.label == "isolated":
|
||||
bbox.content = '\n\n' + r"$$" + bbox.content + r"$$" + '\n\n'
|
||||
|
||||
bboxes = sorted(ocr_bboxes + latex_bboxes)
|
||||
if bboxes == []:
|
||||
return ""
|
||||
|
||||
md = ""
|
||||
prev = Bbox(bboxes[0].p.x, bboxes[0].p.y, -1, -1, label="guard")
|
||||
for curr in bboxes:
|
||||
# Add the formula number back to the isolated formula
|
||||
if prev.label == "isolated" and curr.label == "text" and prev.same_row(curr):
|
||||
curr.content = curr.content.strip()
|
||||
if curr.content.startswith('(') and curr.content.endswith(')'):
|
||||
curr.content = curr.content[1:-1]
|
||||
|
||||
if re.search(r'\\tag\{.*\}$', md[:-4]) is not None:
|
||||
# in case of multiple tag
|
||||
md = md[:-5] + f', {curr.content}' + '}' + md[-4:]
|
||||
else:
|
||||
md = md[:-4] + f'\\tag{{{curr.content}}}' + md[-4:]
|
||||
continue
|
||||
|
||||
if not prev.same_row(curr):
|
||||
md += " "
|
||||
|
||||
if curr.label == "embedding":
|
||||
# remove the bold effect from inline formulas
|
||||
curr.content = change_all(curr.content, r'\bm', r' ', r'{', r'}', r'', r' ')
|
||||
curr.content = change_all(curr.content, r'\boldsymbol', r' ', r'{', r'}', r'', r' ')
|
||||
curr.content = change_all(curr.content, r'\textit', r' ', r'{', r'}', r'', r' ')
|
||||
curr.content = change_all(curr.content, r'\textbf', r' ', r'{', r'}', r'', r' ')
|
||||
curr.content = change_all(curr.content, r'\textbf', r' ', r'{', r'}', r'', r' ')
|
||||
curr.content = change_all(curr.content, r'\mathbf', r' ', r'{', r'}', r'', r' ')
|
||||
|
||||
# change split environment into aligned
|
||||
curr.content = curr.content.replace(r'\begin{split}', r'\begin{aligned}')
|
||||
curr.content = curr.content.replace(r'\end{split}', r'\end{aligned}')
|
||||
|
||||
# remove extra spaces (keeping only one)
|
||||
curr.content = re.sub(r' +', ' ', curr.content)
|
||||
assert curr.content.startswith(' $') and curr.content.endswith('$ ')
|
||||
curr.content = ' $' + curr.content[2:-2].strip() + '$ '
|
||||
md += curr.content
|
||||
prev = curr
|
||||
return md.strip()
|
||||
Reference in New Issue
Block a user