1) 实现了文本-公式混排识别; 2) 重构了项目结构
This commit is contained in:
99
src/web.py
99
src/web.py
@@ -7,10 +7,18 @@ import streamlit as st
|
||||
|
||||
from PIL import Image
|
||||
from streamlit_paste_button import paste_image_button as pbutton
|
||||
from models.ocr_model.utils.inference import inference
|
||||
from models.ocr_model.model.TexTeller import TexTeller
|
||||
from utils import to_katex
|
||||
from onnxruntime import InferenceSession
|
||||
|
||||
from models.utils import mix_inference
|
||||
from models.det_model.inference import PredictConfig
|
||||
|
||||
from models.ocr_model.model.TexTeller import TexTeller
|
||||
from models.ocr_model.utils.inference import inference as latex_recognition
|
||||
from models.ocr_model.utils.to_katex import to_katex
|
||||
|
||||
from surya.model.detection import segformer
|
||||
from surya.model.recognition.model import load_model
|
||||
from surya.model.recognition.processor import load_processor
|
||||
|
||||
st.set_page_config(
|
||||
page_title="TexTeller",
|
||||
@@ -42,13 +50,26 @@ fail_gif_html = '''
|
||||
'''
|
||||
|
||||
@st.cache_resource
|
||||
def get_model():
|
||||
def get_texteller():
|
||||
return TexTeller.from_pretrained(os.environ['CHECKPOINT_DIR'])
|
||||
|
||||
@st.cache_resource
|
||||
def get_tokenizer():
|
||||
return TexTeller.get_tokenizer(os.environ['TOKENIZER_DIR'])
|
||||
|
||||
@st.cache_resource
|
||||
def get_det_models():
|
||||
infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
|
||||
latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
|
||||
return infer_config, latex_det_model
|
||||
|
||||
@st.cache_resource()
|
||||
def get_ocr_models():
|
||||
det_processor, det_model = segformer.load_processor(), segformer.load_model()
|
||||
rec_model, rec_processor = load_model(), load_processor()
|
||||
lang_ocr_models = [det_model, det_processor, rec_model, rec_processor]
|
||||
return lang_ocr_models
|
||||
|
||||
def get_image_base64(img_file):
|
||||
buffered = io.BytesIO()
|
||||
img_file.seek(0)
|
||||
@@ -62,9 +83,6 @@ def on_file_upload():
|
||||
def change_side_bar():
|
||||
st.session_state["CHANGE_SIDEBAR_FLAG"] = True
|
||||
|
||||
model = get_model()
|
||||
tokenizer = get_tokenizer()
|
||||
|
||||
if "start" not in st.session_state:
|
||||
st.session_state["start"] = 1
|
||||
st.toast('Hooray!', icon='🎉')
|
||||
@@ -75,31 +93,34 @@ if "UPLOADED_FILE_CHANGED" not in st.session_state:
|
||||
if "CHANGE_SIDEBAR_FLAG" not in st.session_state:
|
||||
st.session_state["CHANGE_SIDEBAR_FLAG"] = False
|
||||
|
||||
if "INF_MODE" not in st.session_state:
|
||||
st.session_state["INF_MODE"] = "Formula only"
|
||||
|
||||
|
||||
# ============================ begin sidebar =============================== #
|
||||
|
||||
with st.sidebar:
|
||||
num_beams = 1
|
||||
inf_mode = 'cpu'
|
||||
|
||||
st.markdown("# 🔨️ Config")
|
||||
st.markdown("")
|
||||
|
||||
model_type = st.selectbox(
|
||||
"Model type",
|
||||
("TexTeller", "None"),
|
||||
inf_mode = st.selectbox(
|
||||
"Inference mode",
|
||||
("Formula only", "Text formula mixed"),
|
||||
on_change=change_side_bar
|
||||
)
|
||||
if model_type == "TexTeller":
|
||||
num_beams = st.number_input(
|
||||
'Number of beams',
|
||||
min_value=1,
|
||||
max_value=20,
|
||||
step=1,
|
||||
on_change=change_side_bar
|
||||
)
|
||||
|
||||
inf_mode = st.radio(
|
||||
"Inference mode",
|
||||
num_beams = st.number_input(
|
||||
'Number of beams',
|
||||
min_value=1,
|
||||
max_value=20,
|
||||
step=1,
|
||||
on_change=change_side_bar
|
||||
)
|
||||
|
||||
accelerator = st.radio(
|
||||
"Accelerator",
|
||||
("cpu", "cuda", "mps"),
|
||||
on_change=change_side_bar
|
||||
)
|
||||
@@ -107,9 +128,16 @@ with st.sidebar:
|
||||
# ============================ end sidebar =============================== #
|
||||
|
||||
|
||||
|
||||
# ============================ begin pages =============================== #
|
||||
|
||||
texteller = get_texteller()
|
||||
tokenizer = get_tokenizer()
|
||||
latex_rec_models = [texteller, tokenizer]
|
||||
|
||||
if inf_mode == "Text formula mixed":
|
||||
infer_config, latex_det_model = get_det_models()
|
||||
lang_ocr_models = get_ocr_models()
|
||||
|
||||
st.markdown(html_string, unsafe_allow_html=True)
|
||||
|
||||
uploaded_file = st.file_uploader(
|
||||
@@ -176,19 +204,26 @@ elif uploaded_file or paste_result.image_data is not None:
|
||||
st.write("")
|
||||
|
||||
with st.spinner("Predicting..."):
|
||||
uploaded_file.seek(0)
|
||||
TexTeller_result = inference(
|
||||
model,
|
||||
tokenizer,
|
||||
[png_file_path],
|
||||
inf_mode=inf_mode,
|
||||
num_beams=num_beams
|
||||
)[0]
|
||||
if inf_mode == "Formula only":
|
||||
TexTeller_result = latex_recognition(
|
||||
texteller,
|
||||
tokenizer,
|
||||
[png_file_path],
|
||||
accelerator=accelerator,
|
||||
num_beams=num_beams
|
||||
)[0]
|
||||
katex_res = to_katex(TexTeller_result)
|
||||
else:
|
||||
katex_res = mix_inference(png_file_path, "en", infer_config, latex_det_model, lang_ocr_models, latex_rec_models, accelerator, num_beams)
|
||||
|
||||
st.success('Completed!', icon="✅")
|
||||
st.markdown(suc_gif_html, unsafe_allow_html=True)
|
||||
katex_res = to_katex(TexTeller_result)
|
||||
st.text_area(":blue[*** 𝑃r𝑒d𝑖c𝑡e𝑑 𝑓o𝑟m𝑢l𝑎 ***]", katex_res, height=150)
|
||||
st.latex(katex_res)
|
||||
|
||||
if inf_mode == "Formula only":
|
||||
st.latex(katex_res)
|
||||
elif inf_mode == "Text formula mixed":
|
||||
st.markdown(katex_res)
|
||||
|
||||
st.write("")
|
||||
st.write("")
|
||||
|
||||
Reference in New Issue
Block a user