Refine mix_inference

1. Add the formula number back to the isolated formula and merge multiple tags.
2. Remove the bold effect from inline formulas.
3. Change the split environment into aligned.
This commit is contained in:
三洋三洋
2024-06-04 14:24:23 +00:00
parent 29e626c984
commit dbbec511ef

View File

@@ -12,7 +12,7 @@ from ..det_model.inference import predict as latex_det_predict
from ..det_model.Bbox import Bbox, draw_bboxes
from ..ocr_model.utils.inference import inference as latex_rec_predict
from ..ocr_model.utils.to_katex import to_katex
from ..ocr_model.utils.to_katex import to_katex, change_all
MAXV = 999999999
@@ -153,7 +153,10 @@ def mix_inference(
tuple(img[-1, 0]), tuple(img[-1, -1])]
bg_color = np.array(Counter(corners).most_common(1)[0][0])
start_time = time.time()
latex_bboxes = latex_det_predict(img_path, latex_det_model, infer_config)
end_time = time.time()
print(f"latex_det_model time: {end_time - start_time:.2f}s")
latex_bboxes = sorted(latex_bboxes)
# log results
draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(unmerged).png")
@@ -163,7 +166,10 @@ def mix_inference(
masked_img = mask_img(img, latex_bboxes, bg_color)
det_model, rec_model = lang_ocr_models
start_time = time.time()
det_prediction, _ = det_model(masked_img)
end_time = time.time()
print(f"ocr_det_model time: {end_time - start_time:.2f}s")
ocr_bboxes = [
Bbox(
p[0][0], p[0][1], p[3][1]-p[0][1], p[1][0]-p[0][0],
@@ -184,7 +190,10 @@ def mix_inference(
ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes))
sliced_imgs: List[np.ndarray] = slice_from_image(img, ocr_bboxes)
start_time = time.time()
rec_predictions, _ = rec_model(sliced_imgs)
end_time = time.time()
print(f"ocr_rec_model time: {end_time - start_time:.2f}s")
assert len(rec_predictions) == len(ocr_bboxes)
for content, bbox in zip(rec_predictions, ocr_bboxes):
@@ -193,14 +202,18 @@ def mix_inference(
latex_imgs =[]
for bbox in latex_bboxes:
latex_imgs.append(img[bbox.p.y:bbox.p.y + bbox.h, bbox.p.x:bbox.p.x + bbox.w])
latex_rec_res = latex_rec_predict(*latex_rec_models, latex_imgs, accelerator, num_beams, max_tokens=200)
start_time = time.time()
latex_rec_res = latex_rec_predict(*latex_rec_models, latex_imgs, accelerator, num_beams, max_tokens=800)
end_time = time.time()
print(f"latex_rec_model time: {end_time - start_time:.2f}s")
for bbox, content in zip(latex_bboxes, latex_rec_res):
bbox.content = to_katex(content)
if bbox.label == "embedding":
bbox.content = " $" + bbox.content + "$ "
elif bbox.label == "isolated":
bbox.content = '\n' + r"$$" + bbox.content + r"$$" + '\n'
bbox.content = '\n\n' + r"$$" + bbox.content + r"$$" + '\n\n'
bboxes = sorted(ocr_bboxes + latex_bboxes)
if bboxes == []:
@@ -209,14 +222,43 @@ def mix_inference(
md = ""
prev = Bbox(bboxes[0].p.x, bboxes[0].p.y, -1, -1, label="guard")
for curr in bboxes:
if not prev.same_row(curr):
md += "\n"
md += curr.content
# Add the formula number back to the isolated formula
if (
prev.label == "isolated"
and curr.label == "text"
and bool(re.fullmatch(r"\([1-9]\d*?\)", curr.content))
and prev.same_row(curr)
):
md += '\n'
curr.content = curr.content.strip()
if curr.content.startswith('(') and curr.content.endswith(')'):
curr.content = curr.content[1:-1]
if re.search(r'\\tag\{.*\}$', md[:-4]) is not None:
# in case of multiple tag
md = md[:-5] + f', {curr.content}' + '}' + md[-4:]
else:
md = md[:-4] + f'\\tag{{{curr.content}}}' + md[-4:]
continue
if not prev.same_row(curr):
md += " "
if curr.label == "embedding":
# remove the bold effect from inline formulas
curr.content = change_all(curr.content, r'\bm', r' ', r'{', r'}', r'', r' ')
curr.content = change_all(curr.content, r'\boldsymbol', r' ', r'{', r'}', r'', r' ')
curr.content = change_all(curr.content, r'\textit', r' ', r'{', r'}', r'', r' ')
curr.content = change_all(curr.content, r'\textbf', r' ', r'{', r'}', r'', r' ')
curr.content = change_all(curr.content, r'\textbf', r' ', r'{', r'}', r'', r' ')
curr.content = change_all(curr.content, r'\mathbf', r' ', r'{', r'}', r'', r' ')
# change split environment into aligned
curr.content = curr.content.replace(r'\begin{split}', r'\begin{aligned}')
curr.content = curr.content.replace(r'\end{split}', r'\end{aligned}')
# remove extra spaces (keeping only one)
curr.content = re.sub(r' +', ' ', curr.content)
assert curr.content.startswith(' $') and curr.content.endswith('$ ')
curr.content = ' $' + curr.content[2:-2].strip() + '$ '
md += curr.content
prev = curr
return md
return md.strip()