inference.py支持katex语法

This commit is contained in:
三洋三洋
2024-04-06 10:09:15 +00:00
parent b5f7166e58
commit 1db514bdbf
3 changed files with 19 additions and 15 deletions

View File

@@ -3,6 +3,7 @@ import argparse
import cv2 as cv import cv2 as cv
from pathlib import Path from pathlib import Path
from utils import to_katex
from models.ocr_model.utils.inference import inference as latex_inference from models.ocr_model.utils.inference import inference as latex_inference
from models.ocr_model.model.TexTeller import TexTeller from models.ocr_model.model.TexTeller import TexTeller
from utils import load_det_tex_model, load_lang_models from utils import load_det_tex_model, load_lang_models
@@ -44,7 +45,8 @@ if __name__ == '__main__':
print('Inference...') print('Inference...')
if not args.mix: if not args.mix:
res = latex_inference(latex_rec_model, tokenizer, [img], args.cuda) res = latex_inference(latex_rec_model, tokenizer, [img], args.cuda)
print(res[0]) res = to_katex(res[0])
print(res)
else: else:
# latex_det_model = load_det_tex_model() # latex_det_model = load_det_tex_model()
# lang_model = load_lang_models()... # lang_model = load_lang_models()...

View File

@@ -1,8 +1,23 @@
import numpy as np import numpy as np
import re
from models.ocr_model.utils.inference import inference as latex_inference from models.ocr_model.utils.inference import inference as latex_inference
def to_katex(formula: str) -> str:
res = formula
res = re.sub(r'\\mbox\{([^}]*)\}', r'\1', res)
res = re.sub(r'boldmath\$(.*?)\$', r'bm{\1}', res)
res = re.sub(r'\\\[(.*?)\\\]', r'\1\\newline', res)
pattern = r'(\\(?:left|middle|right|big|Big|bigg|Bigg|bigl|Bigl|biggl|Biggl|bigm|Bigm|biggm|Biggm|bigr|Bigr|biggr|Biggr))\{([^}]*)\}'
replacement = r'\1\2'
res = re.sub(pattern, replacement, res)
if res.endswith(r'\newline'):
res = res[:-8]
return res
def load_lang_models(language: str): def load_lang_models(language: str):
... ...
# language: 'ch' or 'en' # language: 'ch' or 'en'

View File

@@ -4,11 +4,11 @@ import base64
import tempfile import tempfile
import shutil import shutil
import streamlit as st import streamlit as st
import re
from PIL import Image from PIL import Image
from models.ocr_model.utils.inference import inference from models.ocr_model.utils.inference import inference
from models.ocr_model.model.TexTeller import TexTeller from models.ocr_model.model.TexTeller import TexTeller
from utils import to_katex
html_string = ''' html_string = '''
@@ -66,19 +66,6 @@ def get_model():
def get_tokenizer(): def get_tokenizer():
return TexTeller.get_tokenizer(os.environ['TOKENIZER_DIR']) return TexTeller.get_tokenizer(os.environ['TOKENIZER_DIR'])
def to_katex(formula: str) -> str:
res = formula
res = re.sub(r'\\mbox\{([^}]*)\}', r'\1', res)
res = re.sub(r'boldmath\$(.*?)\$', r'bm{\1}', res)
res = re.sub(r'\\\[(.*?)\\\]', r'\1\\newline', res)
pattern = r'(\\(?:left|middle|right|big|Big|bigg|Bigg|bigl|Bigl|biggl|Biggl|bigm|Bigm|biggm|Biggm|bigr|Bigr|biggr|Biggr))\{([^}]*)\}'
replacement = r'\1\2'
res = re.sub(pattern, replacement, res)
if res.endswith(r'\newline'):
res = res[:-8]
return res
def get_image_base64(img_file): def get_image_base64(img_file):
buffered = io.BytesIO() buffered = io.BytesIO()
img_file.seek(0) img_file.seek(0)