Added mixed recognition
Change SuryaOCR to PaddleOCR
This commit is contained in:
53
src/web.py
53
src/web.py
@@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import io
|
import io
|
||||||
|
import re
|
||||||
import base64
|
import base64
|
||||||
import tempfile
|
import tempfile
|
||||||
import shutil
|
import shutil
|
||||||
@@ -8,6 +9,8 @@ import streamlit as st
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from streamlit_paste_button import paste_image_button as pbutton
|
from streamlit_paste_button import paste_image_button as pbutton
|
||||||
from onnxruntime import InferenceSession
|
from onnxruntime import InferenceSession
|
||||||
|
from models.thrid_party.paddleocr.infer import predict_det, predict_rec
|
||||||
|
from models.thrid_party.paddleocr.infer import utility
|
||||||
|
|
||||||
from models.utils import mix_inference
|
from models.utils import mix_inference
|
||||||
from models.det_model.inference import PredictConfig
|
from models.det_model.inference import PredictConfig
|
||||||
@@ -16,9 +19,6 @@ from models.ocr_model.model.TexTeller import TexTeller
|
|||||||
from models.ocr_model.utils.inference import inference as latex_recognition
|
from models.ocr_model.utils.inference import inference as latex_recognition
|
||||||
from models.ocr_model.utils.to_katex import to_katex
|
from models.ocr_model.utils.to_katex import to_katex
|
||||||
|
|
||||||
from surya.model.detection import segformer
|
|
||||||
from surya.model.recognition.model import load_model
|
|
||||||
from surya.model.recognition.processor import load_processor
|
|
||||||
|
|
||||||
st.set_page_config(
|
st.set_page_config(
|
||||||
page_title="TexTeller",
|
page_title="TexTeller",
|
||||||
@@ -64,11 +64,27 @@ def get_det_models():
|
|||||||
return infer_config, latex_det_model
|
return infer_config, latex_det_model
|
||||||
|
|
||||||
@st.cache_resource()
|
@st.cache_resource()
|
||||||
def get_ocr_models():
|
def get_ocr_models(accelerator):
|
||||||
det_processor, det_model = segformer.load_processor(), segformer.load_model()
|
use_gpu = accelerator == 'cuda'
|
||||||
rec_model, rec_processor = load_model(), load_processor()
|
|
||||||
lang_ocr_models = [det_model, det_processor, rec_model, rec_processor]
|
SIZE_LIMIT = 20 * 1024 * 1024
|
||||||
return lang_ocr_models
|
det_model_dir = "./models/thrid_party/paddleocr/checkpoints/det/default_model.onnx"
|
||||||
|
rec_model_dir = "./models/thrid_party/paddleocr/checkpoints/rec/default_model.onnx"
|
||||||
|
# In onnxruntime, CPU inference of the detection model is faster than GPU inference, so detection always runs on the CPU
|
||||||
|
det_use_gpu = False
|
||||||
|
rec_use_gpu = use_gpu and not (os.path.getsize(rec_model_dir) < SIZE_LIMIT)
|
||||||
|
|
||||||
|
paddleocr_args = utility.parse_args()
|
||||||
|
paddleocr_args.use_onnx = True
|
||||||
|
paddleocr_args.det_model_dir = det_model_dir
|
||||||
|
paddleocr_args.rec_model_dir = rec_model_dir
|
||||||
|
|
||||||
|
paddleocr_args.use_gpu = det_use_gpu
|
||||||
|
detector = predict_det.TextDetector(paddleocr_args)
|
||||||
|
paddleocr_args.use_gpu = rec_use_gpu
|
||||||
|
recognizer = predict_rec.TextRecognizer(paddleocr_args)
|
||||||
|
return [detector, recognizer]
|
||||||
|
|
||||||
|
|
||||||
def get_image_base64(img_file):
|
def get_image_base64(img_file):
|
||||||
buffered = io.BytesIO()
|
buffered = io.BytesIO()
|
||||||
@@ -111,16 +127,6 @@ with st.sidebar:
|
|||||||
on_change=change_side_bar
|
on_change=change_side_bar
|
||||||
)
|
)
|
||||||
|
|
||||||
if inf_mode == "Text formula mixed":
|
|
||||||
lang = st.selectbox(
|
|
||||||
"Language",
|
|
||||||
("English", "Chinese")
|
|
||||||
)
|
|
||||||
if lang == "English":
|
|
||||||
lang = "en"
|
|
||||||
elif lang == "Chinese":
|
|
||||||
lang = "zh"
|
|
||||||
|
|
||||||
num_beams = st.number_input(
|
num_beams = st.number_input(
|
||||||
'Number of beams',
|
'Number of beams',
|
||||||
min_value=1,
|
min_value=1,
|
||||||
@@ -147,7 +153,7 @@ latex_rec_models = [texteller, tokenizer]
|
|||||||
|
|
||||||
if inf_mode == "Text formula mixed":
|
if inf_mode == "Text formula mixed":
|
||||||
infer_config, latex_det_model = get_det_models()
|
infer_config, latex_det_model = get_det_models()
|
||||||
lang_ocr_models = get_ocr_models()
|
lang_ocr_models = get_ocr_models(accelerator)
|
||||||
|
|
||||||
st.markdown(html_string, unsafe_allow_html=True)
|
st.markdown(html_string, unsafe_allow_html=True)
|
||||||
|
|
||||||
@@ -225,7 +231,7 @@ elif uploaded_file or paste_result.image_data is not None:
|
|||||||
)[0]
|
)[0]
|
||||||
katex_res = to_katex(TexTeller_result)
|
katex_res = to_katex(TexTeller_result)
|
||||||
else:
|
else:
|
||||||
katex_res = mix_inference(png_file_path, lang, infer_config, latex_det_model, lang_ocr_models, latex_rec_models, accelerator, num_beams)
|
katex_res = mix_inference(png_file_path, infer_config, latex_det_model, lang_ocr_models, latex_rec_models, accelerator, num_beams)
|
||||||
|
|
||||||
st.success('Completed!', icon="✅")
|
st.success('Completed!', icon="✅")
|
||||||
st.markdown(suc_gif_html, unsafe_allow_html=True)
|
st.markdown(suc_gif_html, unsafe_allow_html=True)
|
||||||
@@ -234,7 +240,12 @@ elif uploaded_file or paste_result.image_data is not None:
|
|||||||
if inf_mode == "Formula only":
|
if inf_mode == "Formula only":
|
||||||
st.latex(katex_res)
|
st.latex(katex_res)
|
||||||
elif inf_mode == "Text formula mixed":
|
elif inf_mode == "Text formula mixed":
|
||||||
st.markdown(katex_res)
|
mixed_res = re.split(r'(\n\$\$.*?\$\$\n)', katex_res)
|
||||||
|
for text in mixed_res:
|
||||||
|
if text.startswith('\n$$') and text.endswith('$$\n'):
|
||||||
|
st.latex(text[3:-3])
|
||||||
|
else:
|
||||||
|
st.markdown(text)
|
||||||
|
|
||||||
st.write("")
|
st.write("")
|
||||||
st.write("")
|
st.write("")
|
||||||
|
|||||||
Reference in New Issue
Block a user