Added mixed recognition

Change SuryaOCR to PaddleOCR
This commit is contained in:
三洋三洋
2024-05-27 17:03:48 +00:00
parent 2af1e067c1
commit 85d558f772

View File

@@ -1,5 +1,6 @@
import os
import io
import re
import base64
import tempfile
import shutil
@@ -8,6 +9,8 @@ import streamlit as st
from PIL import Image
from streamlit_paste_button import paste_image_button as pbutton
from onnxruntime import InferenceSession
from models.thrid_party.paddleocr.infer import predict_det, predict_rec
from models.thrid_party.paddleocr.infer import utility
from models.utils import mix_inference
from models.det_model.inference import PredictConfig
@@ -16,9 +19,6 @@ from models.ocr_model.model.TexTeller import TexTeller
from models.ocr_model.utils.inference import inference as latex_recognition
from models.ocr_model.utils.to_katex import to_katex
from surya.model.detection import segformer
from surya.model.recognition.model import load_model
from surya.model.recognition.processor import load_processor
st.set_page_config(
page_title="TexTeller",
@@ -64,11 +64,27 @@ def get_det_models():
return infer_config, latex_det_model
@st.cache_resource()
def get_ocr_models():
    """Load the surya text-OCR models.

    Returns a list of [detection model, detection processor,
    recognition model, recognition processor], in the order the
    mixed-inference pipeline expects.
    """
    detection_processor = segformer.load_processor()
    detection_model = segformer.load_model()
    recognition_model = load_model()
    recognition_processor = load_processor()
    return [
        detection_model,
        detection_processor,
        recognition_model,
        recognition_processor,
    ]
def get_ocr_models(accelerator: str) -> list:
    """Build the PaddleOCR text detector and recognizer (ONNX backend).

    Args:
        accelerator: 'cuda' to allow GPU inference; any other value
            forces CPU for both engines.

    Returns:
        [detector, recognizer]: a ``predict_det.TextDetector`` and a
        ``predict_rec.TextRecognizer`` sharing one parsed-args namespace.
    """
    use_gpu = accelerator == 'cuda'
    # Below this size the recognition model is cheap enough that GPU
    # transfer overhead outweighs any speedup — keep it on CPU.
    SIZE_LIMIT = 20 * 1024 * 1024
    # NOTE: 'thrid_party' is the real on-disk directory name (see the
    # module imports) — do not "fix" the spelling here.
    det_model_dir = "./models/thrid_party/paddleocr/checkpoints/det/default_model.onnx"
    rec_model_dir = "./models/thrid_party/paddleocr/checkpoints/rec/default_model.onnx"
    # The CPU inference of the detection model will be faster than the
    # GPU inference (in onnxruntime).
    det_use_gpu = False
    rec_use_gpu = use_gpu and os.path.getsize(rec_model_dir) >= SIZE_LIMIT
    paddleocr_args = utility.parse_args()
    paddleocr_args.use_onnx = True
    paddleocr_args.det_model_dir = det_model_dir
    paddleocr_args.rec_model_dir = rec_model_dir
    # The same args namespace is reused for both engines, so use_gpu must
    # be set immediately before constructing each one.
    paddleocr_args.use_gpu = det_use_gpu
    detector = predict_det.TextDetector(paddleocr_args)
    paddleocr_args.use_gpu = rec_use_gpu
    recognizer = predict_rec.TextRecognizer(paddleocr_args)
    return [detector, recognizer]
def get_image_base64(img_file):
buffered = io.BytesIO()
@@ -111,16 +127,6 @@ with st.sidebar:
on_change=change_side_bar
)
if inf_mode == "Text formula mixed":
lang = st.selectbox(
"Language",
("English", "Chinese")
)
if lang == "English":
lang = "en"
elif lang == "Chinese":
lang = "zh"
num_beams = st.number_input(
'Number of beams',
min_value=1,
@@ -147,7 +153,7 @@ latex_rec_models = [texteller, tokenizer]
if inf_mode == "Text formula mixed":
infer_config, latex_det_model = get_det_models()
lang_ocr_models = get_ocr_models()
lang_ocr_models = get_ocr_models(accelerator)
st.markdown(html_string, unsafe_allow_html=True)
@@ -225,7 +231,7 @@ elif uploaded_file or paste_result.image_data is not None:
)[0]
katex_res = to_katex(TexTeller_result)
else:
katex_res = mix_inference(png_file_path, lang, infer_config, latex_det_model, lang_ocr_models, latex_rec_models, accelerator, num_beams)
katex_res = mix_inference(png_file_path, infer_config, latex_det_model, lang_ocr_models, latex_rec_models, accelerator, num_beams)
st.success('Completed!', icon="")
st.markdown(suc_gif_html, unsafe_allow_html=True)
@@ -234,7 +240,12 @@ elif uploaded_file or paste_result.image_data is not None:
if inf_mode == "Formula only":
    st.latex(katex_res)
elif inf_mode == "Text formula mixed":
    # Split on display-math spans ("\n$$...$$\n"); the capturing group
    # keeps the delimiters in the result, so formulas survive the split.
    for segment in re.split(r'(\n\$\$.*?\$\$\n)', katex_res):
        is_formula = segment.startswith('\n$$') and segment.endswith('$$\n')
        if is_formula:
            # Strip the surrounding "\n$$" / "$$\n" before rendering.
            st.latex(segment[3:-3])
        else:
            st.markdown(segment)
st.write("")
st.write("")