Refine mix_inference
1. Add the formula number back to the isolated formula and merge multiple tags. 2. Remove the bold effect from inline formulas. 3. Change the split environment into aligned.
This commit is contained in:
@@ -12,7 +12,7 @@ from ..det_model.inference import predict as latex_det_predict
|
|||||||
from ..det_model.Bbox import Bbox, draw_bboxes
|
from ..det_model.Bbox import Bbox, draw_bboxes
|
||||||
|
|
||||||
from ..ocr_model.utils.inference import inference as latex_rec_predict
|
from ..ocr_model.utils.inference import inference as latex_rec_predict
|
||||||
from ..ocr_model.utils.to_katex import to_katex
|
from ..ocr_model.utils.to_katex import to_katex, change_all
|
||||||
|
|
||||||
MAXV = 999999999
|
MAXV = 999999999
|
||||||
|
|
||||||
@@ -153,7 +153,10 @@ def mix_inference(
|
|||||||
tuple(img[-1, 0]), tuple(img[-1, -1])]
|
tuple(img[-1, 0]), tuple(img[-1, -1])]
|
||||||
bg_color = np.array(Counter(corners).most_common(1)[0][0])
|
bg_color = np.array(Counter(corners).most_common(1)[0][0])
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
latex_bboxes = latex_det_predict(img_path, latex_det_model, infer_config)
|
latex_bboxes = latex_det_predict(img_path, latex_det_model, infer_config)
|
||||||
|
end_time = time.time()
|
||||||
|
print(f"latex_det_model time: {end_time - start_time:.2f}s")
|
||||||
latex_bboxes = sorted(latex_bboxes)
|
latex_bboxes = sorted(latex_bboxes)
|
||||||
# log results
|
# log results
|
||||||
draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(unmerged).png")
|
draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(unmerged).png")
|
||||||
@@ -163,7 +166,10 @@ def mix_inference(
|
|||||||
masked_img = mask_img(img, latex_bboxes, bg_color)
|
masked_img = mask_img(img, latex_bboxes, bg_color)
|
||||||
|
|
||||||
det_model, rec_model = lang_ocr_models
|
det_model, rec_model = lang_ocr_models
|
||||||
|
start_time = time.time()
|
||||||
det_prediction, _ = det_model(masked_img)
|
det_prediction, _ = det_model(masked_img)
|
||||||
|
end_time = time.time()
|
||||||
|
print(f"ocr_det_model time: {end_time - start_time:.2f}s")
|
||||||
ocr_bboxes = [
|
ocr_bboxes = [
|
||||||
Bbox(
|
Bbox(
|
||||||
p[0][0], p[0][1], p[3][1]-p[0][1], p[1][0]-p[0][0],
|
p[0][0], p[0][1], p[3][1]-p[0][1], p[1][0]-p[0][0],
|
||||||
@@ -184,7 +190,10 @@ def mix_inference(
|
|||||||
ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes))
|
ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes))
|
||||||
|
|
||||||
sliced_imgs: List[np.ndarray] = slice_from_image(img, ocr_bboxes)
|
sliced_imgs: List[np.ndarray] = slice_from_image(img, ocr_bboxes)
|
||||||
|
start_time = time.time()
|
||||||
rec_predictions, _ = rec_model(sliced_imgs)
|
rec_predictions, _ = rec_model(sliced_imgs)
|
||||||
|
end_time = time.time()
|
||||||
|
print(f"ocr_rec_model time: {end_time - start_time:.2f}s")
|
||||||
|
|
||||||
assert len(rec_predictions) == len(ocr_bboxes)
|
assert len(rec_predictions) == len(ocr_bboxes)
|
||||||
for content, bbox in zip(rec_predictions, ocr_bboxes):
|
for content, bbox in zip(rec_predictions, ocr_bboxes):
|
||||||
@@ -193,14 +202,18 @@ def mix_inference(
|
|||||||
latex_imgs =[]
|
latex_imgs =[]
|
||||||
for bbox in latex_bboxes:
|
for bbox in latex_bboxes:
|
||||||
latex_imgs.append(img[bbox.p.y:bbox.p.y + bbox.h, bbox.p.x:bbox.p.x + bbox.w])
|
latex_imgs.append(img[bbox.p.y:bbox.p.y + bbox.h, bbox.p.x:bbox.p.x + bbox.w])
|
||||||
latex_rec_res = latex_rec_predict(*latex_rec_models, latex_imgs, accelerator, num_beams, max_tokens=200)
|
start_time = time.time()
|
||||||
|
latex_rec_res = latex_rec_predict(*latex_rec_models, latex_imgs, accelerator, num_beams, max_tokens=800)
|
||||||
|
end_time = time.time()
|
||||||
|
print(f"latex_rec_model time: {end_time - start_time:.2f}s")
|
||||||
|
|
||||||
for bbox, content in zip(latex_bboxes, latex_rec_res):
|
for bbox, content in zip(latex_bboxes, latex_rec_res):
|
||||||
bbox.content = to_katex(content)
|
bbox.content = to_katex(content)
|
||||||
if bbox.label == "embedding":
|
if bbox.label == "embedding":
|
||||||
bbox.content = " $" + bbox.content + "$ "
|
bbox.content = " $" + bbox.content + "$ "
|
||||||
elif bbox.label == "isolated":
|
elif bbox.label == "isolated":
|
||||||
bbox.content = '\n' + r"$$" + bbox.content + r"$$" + '\n'
|
bbox.content = '\n\n' + r"$$" + bbox.content + r"$$" + '\n\n'
|
||||||
|
|
||||||
|
|
||||||
bboxes = sorted(ocr_bboxes + latex_bboxes)
|
bboxes = sorted(ocr_bboxes + latex_bboxes)
|
||||||
if bboxes == []:
|
if bboxes == []:
|
||||||
@@ -209,14 +222,43 @@ def mix_inference(
|
|||||||
md = ""
|
md = ""
|
||||||
prev = Bbox(bboxes[0].p.x, bboxes[0].p.y, -1, -1, label="guard")
|
prev = Bbox(bboxes[0].p.x, bboxes[0].p.y, -1, -1, label="guard")
|
||||||
for curr in bboxes:
|
for curr in bboxes:
|
||||||
if not prev.same_row(curr):
|
# Add the formula number back to the isolated formula
|
||||||
md += "\n"
|
|
||||||
md += curr.content
|
|
||||||
if (
|
if (
|
||||||
prev.label == "isolated"
|
prev.label == "isolated"
|
||||||
and curr.label == "text"
|
and curr.label == "text"
|
||||||
and bool(re.fullmatch(r"\([1-9]\d*?\)", curr.content))
|
and prev.same_row(curr)
|
||||||
):
|
):
|
||||||
md += '\n'
|
curr.content = curr.content.strip()
|
||||||
|
if curr.content.startswith('(') and curr.content.endswith(')'):
|
||||||
|
curr.content = curr.content[1:-1]
|
||||||
|
|
||||||
|
if re.search(r'\\tag\{.*\}$', md[:-4]) is not None:
|
||||||
|
# in case of multiple tag
|
||||||
|
md = md[:-5] + f', {curr.content}' + '}' + md[-4:]
|
||||||
|
else:
|
||||||
|
md = md[:-4] + f'\\tag{{{curr.content}}}' + md[-4:]
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not prev.same_row(curr):
|
||||||
|
md += " "
|
||||||
|
|
||||||
|
if curr.label == "embedding":
|
||||||
|
# remove the bold effect from inline formulas
|
||||||
|
curr.content = change_all(curr.content, r'\bm', r' ', r'{', r'}', r'', r' ')
|
||||||
|
curr.content = change_all(curr.content, r'\boldsymbol', r' ', r'{', r'}', r'', r' ')
|
||||||
|
curr.content = change_all(curr.content, r'\textit', r' ', r'{', r'}', r'', r' ')
|
||||||
|
curr.content = change_all(curr.content, r'\textbf', r' ', r'{', r'}', r'', r' ')
|
||||||
|
curr.content = change_all(curr.content, r'\textbf', r' ', r'{', r'}', r'', r' ')
|
||||||
|
curr.content = change_all(curr.content, r'\mathbf', r' ', r'{', r'}', r'', r' ')
|
||||||
|
|
||||||
|
# change split environment into aligned
|
||||||
|
curr.content = curr.content.replace(r'\begin{split}', r'\begin{aligned}')
|
||||||
|
curr.content = curr.content.replace(r'\end{split}', r'\end{aligned}')
|
||||||
|
|
||||||
|
# remove extra spaces (keeping only one)
|
||||||
|
curr.content = re.sub(r' +', ' ', curr.content)
|
||||||
|
assert curr.content.startswith(' $') and curr.content.endswith('$ ')
|
||||||
|
curr.content = ' $' + curr.content[2:-2].strip() + '$ '
|
||||||
|
md += curr.content
|
||||||
prev = curr
|
prev = curr
|
||||||
return md
|
return md.strip()
|
||||||
|
|||||||
Reference in New Issue
Block a user