From 35bc4e71a184d6acc811d1c8af757225fed0c21d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=89=E6=B4=8B=E4=B8=89=E6=B4=8B?= <1258009915@qq.com> Date: Sat, 6 Apr 2024 11:38:59 +0000 Subject: [PATCH] =?UTF-8?q?inference.py=E6=94=AF=E6=8C=81katex?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/inference.py | 14 ++++++++------ src/utils.py | 15 +++++++++++++++ src/web.py | 15 +-------------- 3 files changed, 24 insertions(+), 20 deletions(-) create mode 100644 src/utils.py diff --git a/src/inference.py b/src/inference.py index c6d6f61..2f4127a 100644 --- a/src/inference.py +++ b/src/inference.py @@ -1,8 +1,10 @@ import os import argparse +import cv2 as cv from pathlib import Path -from models.ocr_model.utils.inference import inference +from utils import to_katex +from models.ocr_model.utils.inference import inference as latex_inference from models.ocr_model.model.TexTeller import TexTeller @@ -21,16 +23,16 @@ if __name__ == '__main__': action='store_true', help='use cuda or not' ) - args = parser.parse_args() # You can use your own checkpoint and tokenizer path. print('Loading model and tokenizer...') - model = TexTeller.from_pretrained() + latex_rec_model = TexTeller.from_pretrained() tokenizer = TexTeller.get_tokenizer() print('Model and tokenizer loaded.') - img_path = [args.img] + img = cv.imread(args.img) print('Inference...') - res = inference(model, tokenizer, img_path, args.cuda) - print(res[0]) + res = latex_inference(latex_rec_model, tokenizer, [img], args.cuda) + res = to_katex(res[0]) + print(res) diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000..6131bae --- /dev/null +++ b/src/utils.py @@ -0,0 +1,15 @@ +import re + + +def to_katex(formula: str) -> str: + res = formula + res = re.sub(r'\\mbox\{([^}]*)\}', r'\1', res) + res = re.sub(r'boldmath\$(.*?)\$', r'bm{\1}', res) + res = re.sub(r'\\\[(.*?)\\\]', r'\1\\newline', res) + + pattern = r'(\\(?:left|middle|right|big|Big|bigg|Bigg|bigl|Bigl|biggl|Biggl|bigm|Bigm|biggm|Biggm|bigr|Bigr|biggr|Biggr))\{([^}]*)\}' + replacement = r'\1\2' + res = re.sub(pattern, replacement, res) + if res.endswith(r'\newline'): + res = res[:-8] + return res diff --git a/src/web.py b/src/web.py index 88a4381..9b53a59 100644 --- a/src/web.py +++ b/src/web.py @@ -4,11 +4,11 @@ import base64 import tempfile import shutil import streamlit as st -import re from PIL import Image from models.ocr_model.utils.inference import inference from models.ocr_model.model.TexTeller import TexTeller +from utils import to_katex html_string = ''' @@ -66,19 +66,6 @@ def get_model(): def get_tokenizer(): return TexTeller.get_tokenizer(os.environ['TOKENIZER_DIR']) -def to_katex(formula: str) -> str: - res = formula - res = re.sub(r'\\mbox\{([^}]*)\}', r'\1', res) - res = re.sub(r'boldmath\$(.*?)\$', r'bm{\1}', res) - res = re.sub(r'\\\[(.*?)\\\]', r'\1\\newline', res) - - pattern = r'(\\(?:left|middle|right|big|Big|bigg|Bigg|bigl|Bigl|biggl|Biggl|bigm|Bigm|biggm|Biggm|bigr|Bigr|biggr|Biggr))\{([^}]*)\}' - replacement = r'\1\2' - res = re.sub(pattern, replacement, res) - if res.endswith(r'\newline'): - res = res[:-8] - return res - def get_image_base64(img_file): buffered = io.BytesIO() img_file.seek(0)