1) 实现了文本-公式混排识别; 2) 重构了项目结构

This commit is contained in:
三洋三洋
2024-04-21 00:05:14 +08:00
parent eab6e4c85d
commit 185b2e3db6
19 changed files with 753 additions and 296 deletions

View File

@@ -7,10 +7,18 @@ import streamlit as st
from PIL import Image
from streamlit_paste_button import paste_image_button as pbutton
from models.ocr_model.utils.inference import inference
from models.ocr_model.model.TexTeller import TexTeller
from utils import to_katex
from onnxruntime import InferenceSession
from models.utils import mix_inference
from models.det_model.inference import PredictConfig
from models.ocr_model.model.TexTeller import TexTeller
from models.ocr_model.utils.inference import inference as latex_recognition
from models.ocr_model.utils.to_katex import to_katex
from surya.model.detection import segformer
from surya.model.recognition.model import load_model
from surya.model.recognition.processor import load_processor
st.set_page_config(
page_title="TexTeller",
@@ -42,13 +50,26 @@ fail_gif_html = '''
'''
@st.cache_resource
def get_model():
def get_texteller():
return TexTeller.from_pretrained(os.environ['CHECKPOINT_DIR'])
@st.cache_resource
def get_tokenizer():
    """Return the TexTeller tokenizer, cached via ``st.cache_resource``.

    The tokenizer location is read from the ``TOKENIZER_DIR`` environment
    variable and handed to ``TexTeller.get_tokenizer``.

    Raises:
        KeyError: If ``TOKENIZER_DIR`` is not set in the environment.
    """
    tokenizer_dir = os.environ['TOKENIZER_DIR']
    return TexTeller.get_tokenizer(tokenizer_dir)
@st.cache_resource
def get_det_models():
    """Build the detection config and ONNX session, cached via ``st.cache_resource``.

    Both files are addressed with paths relative to the current working
    directory (``./models/det_model/model/...``).

    Returns:
        tuple: ``(infer_config, latex_det_model)`` -- the ``PredictConfig``
        built from ``infer_cfg.yml`` and the ``InferenceSession`` opened on
        ``rtdetr_r50vd_6x_coco.onnx``.
    """
    # The config is constructed first, then the ONNX session.
    config_and_session = (
        PredictConfig("./models/det_model/model/infer_cfg.yml"),
        InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx"),
    )
    return config_and_session
@st.cache_resource()
def get_ocr_models():
    """Load the surya detection and recognition models and their processors.

    The result is cached via ``st.cache_resource``.

    Returns:
        list: ``[det_model, det_processor, rec_model, rec_processor]``, in
        exactly that order.
        NOTE(review): ``mix_inference`` presumably unpacks the list
        positionally -- confirm before reordering.
    """
    # Load order: detection processor, detection model, recognition model,
    # recognition processor.
    text_det_processor = segformer.load_processor()
    text_det_model = segformer.load_model()
    text_rec_model = load_model()
    text_rec_processor = load_processor()
    return [text_det_model, text_det_processor, text_rec_model, text_rec_processor]
def get_image_base64(img_file):
buffered = io.BytesIO()
img_file.seek(0)
@@ -62,9 +83,6 @@ def on_file_upload():
def change_side_bar():
    """Record in the Streamlit session state that a sidebar widget changed.

    Used as the ``on_change`` callback of the sidebar widgets; it sets
    ``st.session_state["CHANGE_SIDEBAR_FLAG"]`` to ``True``.
    """
    flag_key = "CHANGE_SIDEBAR_FLAG"
    st.session_state[flag_key] = True
model = get_model()
tokenizer = get_tokenizer()
if "start" not in st.session_state:
st.session_state["start"] = 1
st.toast('Hooray!', icon='🎉')
@@ -75,31 +93,34 @@ if "UPLOADED_FILE_CHANGED" not in st.session_state:
if "CHANGE_SIDEBAR_FLAG" not in st.session_state:
st.session_state["CHANGE_SIDEBAR_FLAG"] = False
if "INF_MODE" not in st.session_state:
st.session_state["INF_MODE"] = "Formula only"
# ============================ begin sidebar =============================== #
with st.sidebar:
num_beams = 1
inf_mode = 'cpu'
st.markdown("# 🔨️ Config")
st.markdown("")
model_type = st.selectbox(
"Model type",
("TexTeller", "None"),
inf_mode = st.selectbox(
"Inference mode",
("Formula only", "Text formula mixed"),
on_change=change_side_bar
)
if model_type == "TexTeller":
num_beams = st.number_input(
'Number of beams',
min_value=1,
max_value=20,
step=1,
on_change=change_side_bar
)
inf_mode = st.radio(
"Inference mode",
num_beams = st.number_input(
'Number of beams',
min_value=1,
max_value=20,
step=1,
on_change=change_side_bar
)
accelerator = st.radio(
"Accelerator",
("cpu", "cuda", "mps"),
on_change=change_side_bar
)
@@ -107,9 +128,16 @@ with st.sidebar:
# ============================ end sidebar =============================== #
# ============================ begin pages =============================== #
# Fetch the formula recogniser and its tokenizer (both getters are wrapped in
# st.cache_resource) and bundle them as the `latex_rec_models` pair that is
# later passed to mix_inference.
texteller = get_texteller()
tokenizer = get_tokenizer()
latex_rec_models = [texteller, tokenizer]
# The detection and text-OCR models are only loaded when the sidebar's
# inference mode is "Text formula mixed"; in "Formula only" mode these names
# are never bound.
if inf_mode == "Text formula mixed":
infer_config, latex_det_model = get_det_models()
lang_ocr_models = get_ocr_models()
st.markdown(html_string, unsafe_allow_html=True)
uploaded_file = st.file_uploader(
@@ -176,19 +204,26 @@ elif uploaded_file or paste_result.image_data is not None:
st.write("")
with st.spinner("Predicting..."):
uploaded_file.seek(0)
TexTeller_result = inference(
model,
tokenizer,
[png_file_path],
inf_mode=inf_mode,
num_beams=num_beams
)[0]
if inf_mode == "Formula only":
TexTeller_result = latex_recognition(
texteller,
tokenizer,
[png_file_path],
accelerator=accelerator,
num_beams=num_beams
)[0]
katex_res = to_katex(TexTeller_result)
else:
katex_res = mix_inference(png_file_path, "en", infer_config, latex_det_model, lang_ocr_models, latex_rec_models, accelerator, num_beams)
st.success('Completed!', icon="")
st.markdown(suc_gif_html, unsafe_allow_html=True)
katex_res = to_katex(TexTeller_result)
st.text_area(":blue[*** 𝑃r𝑒d𝑖c𝑡e𝑑 𝑓o𝑟m𝑢l𝑎 ***]", katex_res, height=150)
st.latex(katex_res)
if inf_mode == "Formula only":
st.latex(katex_res)
elif inf_mode == "Text formula mixed":
st.markdown(katex_res)
st.write("")
st.write("")