Refine mix_inference

1. Add the formula number back to the isolated formula and merge multiple tags.
2. Remove the bold effect from inline formulas.
3. Change the split environment into the aligned environment.
This commit is contained in:
三洋三洋
2024-06-04 14:24:23 +00:00
parent c0e730f697
commit 760bd78c10

View File

@@ -12,7 +12,7 @@ from ..det_model.inference import predict as latex_det_predict
from ..det_model.Bbox import Bbox, draw_bboxes from ..det_model.Bbox import Bbox, draw_bboxes
from ..ocr_model.utils.inference import inference as latex_rec_predict from ..ocr_model.utils.inference import inference as latex_rec_predict
from ..ocr_model.utils.to_katex import to_katex from ..ocr_model.utils.to_katex import to_katex, change_all
MAXV = 999999999 MAXV = 999999999
@@ -153,7 +153,10 @@ def mix_inference(
tuple(img[-1, 0]), tuple(img[-1, -1])] tuple(img[-1, 0]), tuple(img[-1, -1])]
bg_color = np.array(Counter(corners).most_common(1)[0][0]) bg_color = np.array(Counter(corners).most_common(1)[0][0])
start_time = time.time()
latex_bboxes = latex_det_predict(img_path, latex_det_model, infer_config) latex_bboxes = latex_det_predict(img_path, latex_det_model, infer_config)
end_time = time.time()
print(f"latex_det_model time: {end_time - start_time:.2f}s")
latex_bboxes = sorted(latex_bboxes) latex_bboxes = sorted(latex_bboxes)
# log results # log results
draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(unmerged).png") draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(unmerged).png")
@@ -163,7 +166,10 @@ def mix_inference(
masked_img = mask_img(img, latex_bboxes, bg_color) masked_img = mask_img(img, latex_bboxes, bg_color)
det_model, rec_model = lang_ocr_models det_model, rec_model = lang_ocr_models
start_time = time.time()
det_prediction, _ = det_model(masked_img) det_prediction, _ = det_model(masked_img)
end_time = time.time()
print(f"ocr_det_model time: {end_time - start_time:.2f}s")
ocr_bboxes = [ ocr_bboxes = [
Bbox( Bbox(
p[0][0], p[0][1], p[3][1]-p[0][1], p[1][0]-p[0][0], p[0][0], p[0][1], p[3][1]-p[0][1], p[1][0]-p[0][0],
@@ -184,7 +190,10 @@ def mix_inference(
ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes)) ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes))
sliced_imgs: List[np.ndarray] = slice_from_image(img, ocr_bboxes) sliced_imgs: List[np.ndarray] = slice_from_image(img, ocr_bboxes)
start_time = time.time()
rec_predictions, _ = rec_model(sliced_imgs) rec_predictions, _ = rec_model(sliced_imgs)
end_time = time.time()
print(f"ocr_rec_model time: {end_time - start_time:.2f}s")
assert len(rec_predictions) == len(ocr_bboxes) assert len(rec_predictions) == len(ocr_bboxes)
for content, bbox in zip(rec_predictions, ocr_bboxes): for content, bbox in zip(rec_predictions, ocr_bboxes):
@@ -193,14 +202,18 @@ def mix_inference(
latex_imgs =[] latex_imgs =[]
for bbox in latex_bboxes: for bbox in latex_bboxes:
latex_imgs.append(img[bbox.p.y:bbox.p.y + bbox.h, bbox.p.x:bbox.p.x + bbox.w]) latex_imgs.append(img[bbox.p.y:bbox.p.y + bbox.h, bbox.p.x:bbox.p.x + bbox.w])
latex_rec_res = latex_rec_predict(*latex_rec_models, latex_imgs, accelerator, num_beams, max_tokens=200) start_time = time.time()
latex_rec_res = latex_rec_predict(*latex_rec_models, latex_imgs, accelerator, num_beams, max_tokens=800)
end_time = time.time()
print(f"latex_rec_model time: {end_time - start_time:.2f}s")
for bbox, content in zip(latex_bboxes, latex_rec_res): for bbox, content in zip(latex_bboxes, latex_rec_res):
bbox.content = to_katex(content) bbox.content = to_katex(content)
if bbox.label == "embedding": if bbox.label == "embedding":
bbox.content = " $" + bbox.content + "$ " bbox.content = " $" + bbox.content + "$ "
elif bbox.label == "isolated": elif bbox.label == "isolated":
bbox.content = '\n' + r"$$" + bbox.content + r"$$" + '\n' bbox.content = '\n\n' + r"$$" + bbox.content + r"$$" + '\n\n'
bboxes = sorted(ocr_bboxes + latex_bboxes) bboxes = sorted(ocr_bboxes + latex_bboxes)
if bboxes == []: if bboxes == []:
@@ -209,14 +222,43 @@ def mix_inference(
md = "" md = ""
prev = Bbox(bboxes[0].p.x, bboxes[0].p.y, -1, -1, label="guard") prev = Bbox(bboxes[0].p.x, bboxes[0].p.y, -1, -1, label="guard")
for curr in bboxes: for curr in bboxes:
if not prev.same_row(curr): # Add the formula number back to the isolated formula
md += "\n"
md += curr.content
if ( if (
prev.label == "isolated" prev.label == "isolated"
and curr.label == "text" and curr.label == "text"
and bool(re.fullmatch(r"\([1-9]\d*?\)", curr.content)) and prev.same_row(curr)
): ):
md += '\n' curr.content = curr.content.strip()
if curr.content.startswith('(') and curr.content.endswith(')'):
curr.content = curr.content[1:-1]
if re.search(r'\\tag\{.*\}$', md[:-4]) is not None:
# in case of multiple tag
md = md[:-5] + f', {curr.content}' + '}' + md[-4:]
else:
md = md[:-4] + f'\\tag{{{curr.content}}}' + md[-4:]
continue
if not prev.same_row(curr):
md += " "
if curr.label == "embedding":
# remove the bold effect from inline formulas
curr.content = change_all(curr.content, r'\bm', r' ', r'{', r'}', r'', r' ')
curr.content = change_all(curr.content, r'\boldsymbol', r' ', r'{', r'}', r'', r' ')
curr.content = change_all(curr.content, r'\textit', r' ', r'{', r'}', r'', r' ')
curr.content = change_all(curr.content, r'\textbf', r' ', r'{', r'}', r'', r' ')
curr.content = change_all(curr.content, r'\textbf', r' ', r'{', r'}', r'', r' ')
curr.content = change_all(curr.content, r'\mathbf', r' ', r'{', r'}', r'', r' ')
# change split environment into aligned
curr.content = curr.content.replace(r'\begin{split}', r'\begin{aligned}')
curr.content = curr.content.replace(r'\end{split}', r'\end{aligned}')
# remove extra spaces (keeping only one)
curr.content = re.sub(r' +', ' ', curr.content)
assert curr.content.startswith(' $') and curr.content.endswith('$ ')
curr.content = ' $' + curr.content[2:-2].strip() + '$ '
md += curr.content
prev = curr prev = curr
return md return md.strip()