diff --git a/src/inference.py b/src/inference.py index 99a2992..c0e263c 100644 --- a/src/inference.py +++ b/src/inference.py @@ -3,6 +3,7 @@ import argparse import cv2 as cv from pathlib import Path +from utils import to_katex from models.ocr_model.utils.inference import inference as latex_inference from models.ocr_model.model.TexTeller import TexTeller from utils import load_det_tex_model, load_lang_models @@ -44,7 +45,8 @@ if __name__ == '__main__': print('Inference...') if not args.mix: res = latex_inference(latex_rec_model, tokenizer, [img], args.cuda) - print(res[0]) + res = to_katex(res[0]) + print(res) else: # latex_det_model = load_det_tex_model() # lang_model = load_lang_models()... diff --git a/src/utils.py b/src/utils.py index e6e4a9d..7822a76 100644 --- a/src/utils.py +++ b/src/utils.py @@ -1,8 +1,23 @@ import numpy as np +import re from models.ocr_model.utils.inference import inference as latex_inference +def to_katex(formula: str) -> str: + res = formula + res = re.sub(r'\\mbox\{([^}]*)\}', r'\1', res) + res = re.sub(r'boldmath\$(.*?)\$', r'bm{\1}', res) + res = re.sub(r'\\\[(.*?)\\\]', r'\1\\newline', res) + + pattern = r'(\\(?:left|middle|right|big|Big|bigg|Bigg|bigl|Bigl|biggl|Biggl|bigm|Bigm|biggm|Biggm|bigr|Bigr|biggr|Biggr))\{([^}]*)\}' + replacement = r'\1\2' + res = re.sub(pattern, replacement, res) + if res.endswith(r'\newline'): + res = res[:-8] + return res + + def load_lang_models(language: str): ... # language: 'ch' or 'en' diff --git a/src/web.py b/src/web.py index 88a4381..9b53a59 100644 --- a/src/web.py +++ b/src/web.py @@ -4,11 +4,11 @@ import base64 import tempfile import shutil import streamlit as st -import re from PIL import Image from models.ocr_model.utils.inference import inference from models.ocr_model.model.TexTeller import TexTeller +from utils import to_katex html_string = ''' @@ -66,19 +66,6 @@ def get_model(): def get_tokenizer(): return TexTeller.get_tokenizer(os.environ['TOKENIZER_DIR']) -def to_katex(formula: str) -> str: - res = formula - res = re.sub(r'\\mbox\{([^}]*)\}', r'\1', res) - res = re.sub(r'boldmath\$(.*?)\$', r'bm{\1}', res) - res = re.sub(r'\\\[(.*?)\\\]', r'\1\\newline', res) - - pattern = r'(\\(?:left|middle|right|big|Big|bigg|Bigg|bigl|Bigl|biggl|Biggl|bigm|Bigm|biggm|Biggm|bigr|Bigr|biggr|Biggr))\{([^}]*)\}' - replacement = r'\1\2' - res = re.sub(pattern, replacement, res) - if res.endswith(r'\newline'): - res = res[:-8] - return res - def get_image_base64(img_file): buffered = io.BytesIO() img_file.seek(0)