Added mixed text/formula recognition
Changed the OCR backend from surya to PaddleOCR
src/web.py (53 changed lines)
@@ -1,5 +1,6 @@
 import os
 import io
+import re
 import base64
 import tempfile
 import shutil
@@ -8,6 +9,8 @@ import streamlit as st
 from PIL import Image
 from streamlit_paste_button import paste_image_button as pbutton
 from onnxruntime import InferenceSession
+from models.thrid_party.paddleocr.infer import predict_det, predict_rec
+from models.thrid_party.paddleocr.infer import utility

 from models.utils import mix_inference
 from models.det_model.inference import PredictConfig
@@ -16,9 +19,6 @@ from models.ocr_model.model.TexTeller import TexTeller
 from models.ocr_model.utils.inference import inference as latex_recognition
 from models.ocr_model.utils.to_katex import to_katex

-from surya.model.detection import segformer
-from surya.model.recognition.model import load_model
-from surya.model.recognition.processor import load_processor

 st.set_page_config(
     page_title="TexTeller",
@@ -64,11 +64,27 @@ def get_det_models():
     return infer_config, latex_det_model

 @st.cache_resource()
-def get_ocr_models():
-    det_processor, det_model = segformer.load_processor(), segformer.load_model()
-    rec_model, rec_processor = load_model(), load_processor()
-    lang_ocr_models = [det_model, det_processor, rec_model, rec_processor]
-    return lang_ocr_models
+def get_ocr_models(accelerator):
+    use_gpu = accelerator == 'cuda'
+
+    SIZE_LIMIT = 20 * 1024 * 1024
+    det_model_dir = "./models/thrid_party/paddleocr/checkpoints/det/default_model.onnx"
+    rec_model_dir = "./models/thrid_party/paddleocr/checkpoints/rec/default_model.onnx"
+    # The CPU inference of the detection model will be faster than the GPU inference (in onnxruntime)
+    det_use_gpu = False
+    rec_use_gpu = use_gpu and not (os.path.getsize(rec_model_dir) < SIZE_LIMIT)
+
+    paddleocr_args = utility.parse_args()
+    paddleocr_args.use_onnx = True
+    paddleocr_args.det_model_dir = det_model_dir
+    paddleocr_args.rec_model_dir = rec_model_dir
+
+    paddleocr_args.use_gpu = det_use_gpu
+    detector = predict_det.TextDetector(paddleocr_args)
+    paddleocr_args.use_gpu = rec_use_gpu
+    recognizer = predict_rec.TextRecognizer(paddleocr_args)
+    return [detector, recognizer]
+

 def get_image_base64(img_file):
     buffered = io.BytesIO()
@@ -111,16 +127,6 @@ with st.sidebar:
         on_change=change_side_bar
     )

-    if inf_mode == "Text formula mixed":
-        lang = st.selectbox(
-            "Language",
-            ("English", "Chinese")
-        )
-        if lang == "English":
-            lang = "en"
-        elif lang == "Chinese":
-            lang = "zh"
-
     num_beams = st.number_input(
         'Number of beams',
         min_value=1,
@@ -147,7 +153,7 @@ latex_rec_models = [texteller, tokenizer]

 if inf_mode == "Text formula mixed":
     infer_config, latex_det_model = get_det_models()
-    lang_ocr_models = get_ocr_models()
+    lang_ocr_models = get_ocr_models(accelerator)

 st.markdown(html_string, unsafe_allow_html=True)

@@ -225,7 +231,7 @@ elif uploaded_file or paste_result.image_data is not None:
         )[0]
         katex_res = to_katex(TexTeller_result)
     else:
-        katex_res = mix_inference(png_file_path, lang, infer_config, latex_det_model, lang_ocr_models, latex_rec_models, accelerator, num_beams)
+        katex_res = mix_inference(png_file_path, infer_config, latex_det_model, lang_ocr_models, latex_rec_models, accelerator, num_beams)

     st.success('Completed!', icon="✅")
     st.markdown(suc_gif_html, unsafe_allow_html=True)
@@ -234,7 +240,12 @@ elif uploaded_file or paste_result.image_data is not None:
     if inf_mode == "Formula only":
         st.latex(katex_res)
     elif inf_mode == "Text formula mixed":
-        st.markdown(katex_res)
+        mixed_res = re.split(r'(\n\$\$.*?\$\$\n)', katex_res)
+        for text in mixed_res:
+            if text.startswith('\n$$') and text.endswith('$$\n'):
+                st.latex(text[3:-3])
+            else:
+                st.markdown(text)

     st.write("")
     st.write("")
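
A note on the device choice in the new get_ocr_models: the ONNX text-detection model always stays on the CPU (per the in-code comment, CPU inference is faster than GPU inference for it under onnxruntime), and text recognition only uses the GPU when CUDA is requested and the recognition checkpoint is at least 20 MB. A minimal standalone sketch of that heuristic; choose_devices is an illustrative name, not a function in the repository:

import os

SIZE_LIMIT = 20 * 1024 * 1024  # 20 MB, the same threshold as in web.py


def choose_devices(accelerator: str, rec_model_dir: str):
    """Sketch of the det/rec device selection used by the new get_ocr_models."""
    use_gpu = accelerator == 'cuda'
    # Detection always runs on CPU: the diff notes CPU inference of the
    # detection model is faster than GPU inference in onnxruntime.
    det_use_gpu = False
    # Recognition uses the GPU only when CUDA is requested and the ONNX
    # checkpoint is at least SIZE_LIMIT bytes.
    rec_use_gpu = use_gpu and not (os.path.getsize(rec_model_dir) < SIZE_LIMIT)
    return det_use_gpu, rec_use_gpu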
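The new mixed-mode rendering relies on re.split keeping the delimiters when the pattern contains a capturing group, so "$$ ... $$" blocks survive the split and can be routed to st.latex while everything else goes to st.markdown. A small sketch of that behavior outside Streamlit; the sample string is invented for illustration:

import re

# Invented sample in the shape web.py expects: plain text with a display
# formula wrapped in "\n$$ ... $$\n".
katex_res = "Euler's identity states that\n$$e^{i\\pi} + 1 = 0$$\nwhich ties five constants together."

# The capturing group makes re.split keep each matched "$$ ... $$" block
# in the result list instead of discarding it.
mixed_res = re.split(r'(\n\$\$.*?\$\$\n)', katex_res)

for text in mixed_res:
    if text.startswith('\n$$') and text.endswith('$$\n'):
        print("formula:", text[3:-3])   # web.py sends this chunk to st.latex
    else:
        print("text:   ", text)         # web.py sends this chunk to st.markdown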