From dbbec511efa7c84e2c1400cbd2a332c115e5b934 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=89=E6=B4=8B=E4=B8=89=E6=B4=8B?= <1258009915@qq.com> Date: Tue, 4 Jun 2024 14:24:23 +0000 Subject: [PATCH] Refine mix_inference 1. Add the formula number back to the isolated formula and merge multiple tags. 2. Remove bold effect from inline formulas 3. Change split environment into aligned --- src/models/utils/mix_inference.py | 60 ++++++++++++++++++++++++++----- 1 file changed, 51 insertions(+), 9 deletions(-) diff --git a/src/models/utils/mix_inference.py b/src/models/utils/mix_inference.py index f063197..9da3c85 100644 --- a/src/models/utils/mix_inference.py +++ b/src/models/utils/mix_inference.py @@ -12,7 +12,7 @@ from ..det_model.inference import predict as latex_det_predict from ..det_model.Bbox import Bbox, draw_bboxes from ..ocr_model.utils.inference import inference as latex_rec_predict -from ..ocr_model.utils.to_katex import to_katex +from ..ocr_model.utils.to_katex import to_katex, change_all MAXV = 999999999 @@ -153,7 +153,10 @@ def mix_inference( tuple(img[-1, 0]), tuple(img[-1, -1])] bg_color = np.array(Counter(corners).most_common(1)[0][0]) + start_time = time.time() latex_bboxes = latex_det_predict(img_path, latex_det_model, infer_config) + end_time = time.time() + print(f"latex_det_model time: {end_time - start_time:.2f}s") latex_bboxes = sorted(latex_bboxes) # log results draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(unmerged).png") @@ -163,7 +166,10 @@ def mix_inference( masked_img = mask_img(img, latex_bboxes, bg_color) det_model, rec_model = lang_ocr_models + start_time = time.time() det_prediction, _ = det_model(masked_img) + end_time = time.time() + print(f"ocr_det_model time: {end_time - start_time:.2f}s") ocr_bboxes = [ Bbox( p[0][0], p[0][1], p[3][1]-p[0][1], p[1][0]-p[0][0], @@ -184,7 +190,10 @@ def mix_inference( ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes)) sliced_imgs: List[np.ndarray] = 
slice_from_image(img, ocr_bboxes) + start_time = time.time() rec_predictions, _ = rec_model(sliced_imgs) + end_time = time.time() + print(f"ocr_rec_model time: {end_time - start_time:.2f}s") assert len(rec_predictions) == len(ocr_bboxes) for content, bbox in zip(rec_predictions, ocr_bboxes): @@ -193,14 +202,18 @@ def mix_inference( latex_imgs =[] for bbox in latex_bboxes: latex_imgs.append(img[bbox.p.y:bbox.p.y + bbox.h, bbox.p.x:bbox.p.x + bbox.w]) - latex_rec_res = latex_rec_predict(*latex_rec_models, latex_imgs, accelerator, num_beams, max_tokens=200) + start_time = time.time() + latex_rec_res = latex_rec_predict(*latex_rec_models, latex_imgs, accelerator, num_beams, max_tokens=800) + end_time = time.time() + print(f"latex_rec_model time: {end_time - start_time:.2f}s") for bbox, content in zip(latex_bboxes, latex_rec_res): bbox.content = to_katex(content) if bbox.label == "embedding": bbox.content = " $" + bbox.content + "$ " elif bbox.label == "isolated": - bbox.content = '\n' + r"$$" + bbox.content + r"$$" + '\n' + bbox.content = '\n\n' + r"$$" + bbox.content + r"$$" + '\n\n' + bboxes = sorted(ocr_bboxes + latex_bboxes) if bboxes == []: @@ -209,14 +222,43 @@ def mix_inference( md = "" prev = Bbox(bboxes[0].p.x, bboxes[0].p.y, -1, -1, label="guard") for curr in bboxes: - if not prev.same_row(curr): - md += "\n" - md += curr.content + # Add the formula number back to the isolated formula if ( prev.label == "isolated" and curr.label == "text" - and bool(re.fullmatch(r"\([1-9]\d*?\)", curr.content)) + and prev.same_row(curr) ): - md += '\n' + curr.content = curr.content.strip() + if curr.content.startswith('(') and curr.content.endswith(')'): + curr.content = curr.content[1:-1] + + if re.search(r'\\tag\{.*\}$', md[:-4]) is not None: + # in case of multiple tag + md = md[:-5] + f', {curr.content}' + '}' + md[-4:] + else: + md = md[:-4] + f'\\tag{{{curr.content}}}' + md[-4:] + continue + + if not prev.same_row(curr): + md += " " + + if curr.label == "embedding": + # 
remove the bold effect from inline formulas + curr.content = change_all(curr.content, r'\bm', r' ', r'{', r'}', r'', r' ') + curr.content = change_all(curr.content, r'\boldsymbol', r' ', r'{', r'}', r'', r' ') + curr.content = change_all(curr.content, r'\textit', r' ', r'{', r'}', r'', r' ') + curr.content = change_all(curr.content, r'\textbf', r' ', r'{', r'}', r'', r' ') + curr.content = change_all(curr.content, r'\textbf', r' ', r'{', r'}', r'', r' ') + curr.content = change_all(curr.content, r'\mathbf', r' ', r'{', r'}', r'', r' ') + + # change split environment into aligned + curr.content = curr.content.replace(r'\begin{split}', r'\begin{aligned}') + curr.content = curr.content.replace(r'\end{split}', r'\end{aligned}') + + # remove extra spaces (keeping only one) + curr.content = re.sub(r' +', ' ', curr.content) + assert curr.content.startswith(' $') and curr.content.endswith('$ ') + curr.content = ' $' + curr.content[2:-2].strip() + '$ ' + md += curr.content prev = curr - return md + return md.strip()