1) 实现了文本-公式混排识别; 2) 重构了项目结构

This commit is contained in:
三洋三洋
2024-04-21 00:05:14 +08:00
parent 3746ddd427
commit f84168a00b
19 changed files with 753 additions and 296 deletions

View File

@@ -4,19 +4,22 @@ import numpy as np
from transformers import RobertaTokenizerFast, GenerationConfig
from typing import List, Union
from models.ocr_model.model.TexTeller import TexTeller
from models.ocr_model.utils.transforms import inference_transform
from models.ocr_model.utils.helpers import convert2rgb
from models.globals import MAX_TOKEN_SIZE
from .transforms import inference_transform
from .helpers import convert2rgb
from ..model.TexTeller import TexTeller
from ...globals import MAX_TOKEN_SIZE
def inference(
model: TexTeller,
tokenizer: RobertaTokenizerFast,
imgs: Union[List[str], List[np.ndarray]],
inf_mode: str = 'cpu',
accelerator: str = 'cpu',
num_beams: int = 1,
max_tokens = None
) -> List[str]:
if imgs == []:
return []
model.eval()
if isinstance(imgs[0], str):
imgs = convert2rgb(imgs)
@@ -26,11 +29,11 @@ def inference(
imgs = inference_transform(imgs)
pixel_values = torch.stack(imgs)
model = model.to(inf_mode)
pixel_values = pixel_values.to(inf_mode)
model = model.to(accelerator)
pixel_values = pixel_values.to(accelerator)
generate_config = GenerationConfig(
max_new_tokens=MAX_TOKEN_SIZE,
max_new_tokens=MAX_TOKEN_SIZE if max_tokens is None else max_tokens,
num_beams=num_beams,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,

View File

@@ -0,0 +1,110 @@
import re
def change(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r):
result = ""
i = 0
n = len(input_str)
while i < n:
if input_str[i:i+len(old_inst)] == old_inst:
# check if the old_inst is followed by old_surr_l
start = i + len(old_inst)
else:
result += input_str[i]
i += 1
continue
if start < n and input_str[start] == old_surr_l:
# found an old_inst followed by old_surr_l, now look for the matching old_surr_r
count = 1
j = start + 1
escaped = False
while j < n and count > 0:
if input_str[j] == '\\' and not escaped:
escaped = True
j += 1
continue
if input_str[j] == old_surr_r and not escaped:
count -= 1
if count == 0:
break
elif input_str[j] == old_surr_l and not escaped:
count += 1
escaped = False
j += 1
if count == 0:
assert j < n
assert input_str[start] == old_surr_l
assert input_str[j] == old_surr_r
inner_content = input_str[start + 1:j]
# Replace the content with new pattern
result += new_inst + new_surr_l + inner_content + new_surr_r
i = j + 1
continue
else:
assert count > 1
assert j == n
print("Warning: unbalanced surrogate pair in input string")
result += new_inst + new_surr_l
i = start + 1
continue
else:
i = start
if old_inst != new_inst and old_inst in result:
return change(result, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r)
else:
return result
def to_katex(formula: str) -> str:
res = formula
res = change(res, r'\mbox', r'', r'{', r'}', r'', r'')
origin_instructions = [
r'\Huge',
r'\huge',
r'\LARGE',
r'\Large',
r'\large',
r'\normalsize',
r'\small',
r'\footnotesize',
r'\scriptsize',
r'\tiny'
]
for (old_ins, new_ins) in zip(origin_instructions, origin_instructions):
res = change(res, old_ins, new_ins, r'$', r'$', '{', '}')
res = change(res, r'\boldmath', r'\bm', r'$', r'$', r'{', r'}')
origin_instructions = [
r'\left',
r'\middle',
r'\right',
r'\big',
r'\Big',
r'\bigg',
r'\Bigg',
r'\bigl',
r'\Bigl',
r'\biggl',
r'\Biggl',
r'\bigm',
r'\Bigm',
r'\biggm',
r'\Biggm',
r'\bigr',
r'\Bigr',
r'\biggr',
r'\Biggr'
]
for origin_ins in origin_instructions:
res = change(res, origin_ins, origin_ins, r'{', r'}', r'', r'')
res = re.sub(r'\\\[(.*?)\\\]', r'\1\\newline', res)
if res.endswith(r'\newline'):
res = res[:-8]
return res