From 85d558f772f1e2bacfabe170791b60bcd092d4d2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E4=B8=89=E6=B4=8B=E4=B8=89=E6=B4=8B?= <1258009915@qq.com>
Date: Mon, 27 May 2024 17:03:48 +0000
Subject: [PATCH] Mixed recognition: switch OCR backend from surya to PaddleOCR

---
 src/web.py | 53 ++++++++++++++++++++++++++++++++---------------------
 1 file changed, 32 insertions(+), 21 deletions(-)

diff --git a/src/web.py b/src/web.py
index 9244fdc..92d11c6 100644
--- a/src/web.py
+++ b/src/web.py
@@ -1,5 +1,6 @@
 import os
 import io
+import re
 import base64
 import tempfile
 import shutil
@@ -8,6 +9,8 @@ import streamlit as st
 from PIL import Image
 from streamlit_paste_button import paste_image_button as pbutton
 from onnxruntime import InferenceSession
+from models.thrid_party.paddleocr.infer import predict_det, predict_rec
+from models.thrid_party.paddleocr.infer import utility
 
 from models.utils import mix_inference
 from models.det_model.inference import PredictConfig
@@ -16,9 +19,6 @@ from models.ocr_model.model.TexTeller import TexTeller
 from models.ocr_model.utils.inference import inference as latex_recognition
 from models.ocr_model.utils.to_katex import to_katex
 
-from surya.model.detection import segformer
-from surya.model.recognition.model import load_model
-from surya.model.recognition.processor import load_processor
 
 st.set_page_config(
     page_title="TexTeller",
@@ -64,11 +64,27 @@ def get_det_models():
     return infer_config, latex_det_model
 
 @st.cache_resource()
-def get_ocr_models():
-    det_processor, det_model = segformer.load_processor(), segformer.load_model()
-    rec_model, rec_processor = load_model(), load_processor()
-    lang_ocr_models = [det_model, det_processor, rec_model, rec_processor]
-    return lang_ocr_models
+def get_ocr_models(accelerator):
+    use_gpu = accelerator == 'cuda'
+
+    SIZE_LIMIT = 20 * 1024 * 1024
+    det_model_dir = "./models/thrid_party/paddleocr/checkpoints/det/default_model.onnx"
+    rec_model_dir = "./models/thrid_party/paddleocr/checkpoints/rec/default_model.onnx"
+    # In onnxruntime, the detection model runs faster on CPU than on GPU
+    det_use_gpu = False
+    rec_use_gpu = use_gpu and os.path.getsize(rec_model_dir) >= SIZE_LIMIT
+
+    paddleocr_args = utility.parse_args()
+    paddleocr_args.use_onnx = True
+    paddleocr_args.det_model_dir = det_model_dir
+    paddleocr_args.rec_model_dir = rec_model_dir
+
+    paddleocr_args.use_gpu = det_use_gpu
+    detector = predict_det.TextDetector(paddleocr_args)
+    paddleocr_args.use_gpu = rec_use_gpu
+    recognizer = predict_rec.TextRecognizer(paddleocr_args)
+    return [detector, recognizer]
+
 
 def get_image_base64(img_file):
     buffered = io.BytesIO()
@@ -111,16 +127,6 @@ with st.sidebar:
         on_change=change_side_bar
     )
 
-    if inf_mode == "Text formula mixed":
-        lang = st.selectbox(
-            "Language",
-            ("English", "Chinese")
-        )
-        if lang == "English":
-            lang = "en"
-        elif lang == "Chinese":
-            lang = "zh"
-
     num_beams = st.number_input(
         'Number of beams',
         min_value=1,
@@ -147,7 +153,7 @@
 latex_rec_models = [texteller, tokenizer]
 
 if inf_mode == "Text formula mixed":
     infer_config, latex_det_model = get_det_models()
-    lang_ocr_models = get_ocr_models()
+    lang_ocr_models = get_ocr_models(accelerator)
 
 st.markdown(html_string, unsafe_allow_html=True)
 
@@ -225,7 +231,7 @@ elif uploaded_file or paste_result.image_data is not None:
                 )[0]
                 katex_res = to_katex(TexTeller_result)
             else:
-                katex_res = mix_inference(png_file_path, lang, infer_config, latex_det_model, lang_ocr_models, latex_rec_models, accelerator, num_beams)
+                katex_res = mix_inference(png_file_path, infer_config, latex_det_model, lang_ocr_models, latex_rec_models, accelerator, num_beams)
 
             st.success('Completed!', icon="✅")
             st.markdown(suc_gif_html, unsafe_allow_html=True)
@@ -234,7 +240,12 @@ elif uploaded_file or paste_result.image_data is not None:
             if inf_mode == "Formula only":
                 st.latex(katex_res)
             elif inf_mode == "Text formula mixed":
-                st.markdown(katex_res)
+                mixed_res = re.split(r'(\n\$\$.*?\$\$\n)', katex_res)
+                for text in mixed_res:
+                    if text.startswith('\n$$') and text.endswith('$$\n'):
+                        st.latex(text[3:-3])
+                    else:
+                        st.markdown(text)
 
 st.write("")
 st.write("")
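
Note on the size gate in get_ocr_models: for small ONNX models the host-to-device
transfer tends to outweigh any GPU speedup, so the patch keeps detection on CPU
unconditionally and only gives the recognizer CUDA once its weights reach
SIZE_LIMIT. A minimal sketch of that decision in isolation; the helper name
should_use_gpu is invented for illustration, while the 20 MiB threshold comes
from the patch:

    import os

    SIZE_LIMIT = 20 * 1024 * 1024  # 20 MiB, same threshold as get_ocr_models

    def should_use_gpu(model_path: str, accelerator: str) -> bool:
        # Hypothetical helper: GPU only pays off once the model is large
        # enough to amortize host<->device transfer overhead (in onnxruntime).
        return accelerator == 'cuda' and os.path.getsize(model_path) >= SIZE_LIMIT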
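
Note on the new display loop: re.split with a capturing group keeps the matched
"\n$$ ... $$\n" delimiters in the result list, so formula and text segments come
back interleaved and in order. A standalone sketch of that behavior; the sample
string is hypothetical, and prints stand in for st.latex/st.markdown:

    import re

    # Hypothetical mix_inference output: prose around one display formula,
    # delimited by "\n$$ ... $$\n" as the regex in web.py expects.
    katex_res = "Euler's identity\n$$ e^{i\\pi} + 1 = 0 $$\nfollows from the series."

    # The capturing group makes re.split keep the delimiters in the output list.
    mixed_res = re.split(r'(\n\$\$.*?\$\$\n)', katex_res)
    for text in mixed_res:
        if text.startswith('\n$$') and text.endswith('$$\n'):
            print('latex   :', text[3:-3].strip())   # web.py sends this to st.latex
        else:
            print('markdown:', text)                 # web.py sends this to st.markdown

One caveat: '.' does not match newlines by default, so a formula body that
itself spans multiple lines is not caught by this pattern and falls through
to st.markdown.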