[refactor] Init

This commit is contained in:
OleehyO
2025-04-16 14:23:02 +00:00
parent e0cbf2c99f
commit 0cba17d9ce
101 changed files with 1854 additions and 2758 deletions

View File

@@ -0,0 +1,225 @@
import base64
import io
import os
import re
import shutil
import tempfile
import streamlit as st
from PIL import Image
from streamlit_paste_button import paste_image_button as pbutton
from texteller.api import (
img2latex,
load_latexdet_model,
load_model,
load_textdet_model,
load_textrec_model,
load_tokenizer,
paragraph2md,
)
from texteller.cli.commands.web.style import (
HEADER_HTML,
IMAGE_EMBED_HTML,
IMAGE_INFO_HTML,
SUCCESS_GIF_HTML,
)
from texteller.utils import str2device
# Set the page title and icon; this is the first Streamlit call in the script.
st.set_page_config(page_title="TexTeller", page_icon="🧮")
@st.cache_resource
def get_texteller(use_onnx):
    """Return the recognition model built by ``load_model``.

    ``use_onnx`` is forwarded unchanged as the ``use_onnx`` keyword argument.
    The function is wrapped in ``st.cache_resource``.
    """
    texteller_model = load_model(use_onnx=use_onnx)
    return texteller_model
@st.cache_resource
def get_tokenizer():
    """Return the tokenizer built by ``load_tokenizer`` (no arguments).

    The function is wrapped in ``st.cache_resource``.
    """
    loaded_tokenizer = load_tokenizer()
    return loaded_tokenizer
@st.cache_resource
def get_latexdet_model():
    """Return the model built by ``load_latexdet_model`` (no arguments).

    The function is wrapped in ``st.cache_resource``.
    """
    detector = load_latexdet_model()
    return detector
@st.cache_resource
def get_textrec_model():
    """Return the model built by ``load_textrec_model`` (no arguments).

    The function is wrapped in ``st.cache_resource``; the decorator is written
    without parentheses to match the other cached loaders in this file.
    """
    return load_textrec_model()
@st.cache_resource
def get_textdet_model():
    """Return the model built by ``load_textdet_model`` (no arguments).

    The function is wrapped in ``st.cache_resource``; the decorator is written
    without parentheses to match the other cached loaders in this file.
    """
    return load_textdet_model()
def get_image_base64(img_file):
    """Re-encode an image stream as PNG and return it as a base64 string.

    ``img_file`` is rewound to its start before being handed to
    ``Image.open``; the PNG bytes are written to an in-memory buffer.
    """
    img_file.seek(0)
    png_buffer = io.BytesIO()
    Image.open(img_file).save(png_buffer, format="PNG")
    encoded_bytes = base64.b64encode(png_buffer.getvalue())
    return encoded_bytes.decode()
def on_file_upload():
    """Record in session state that the file uploader's value has changed."""
    session = st.session_state
    session["UPLOADED_FILE_CHANGED"] = True
def change_side_bar():
    """Record in session state that a sidebar widget has changed."""
    session = st.session_state
    session["CHANGE_SIDEBAR_FLAG"] = True
# First run of a session: mark it as started and show a one-off toast.
if "start" not in st.session_state:
    st.session_state["start"] = 1
    st.toast("Hooray!", icon="🎉")

# Seed the remaining session keys only when they are not already present.
for _key, _default in (
    ("UPLOADED_FILE_CHANGED", False),
    ("CHANGE_SIDEBAR_FLAG", False),
    ("INF_MODE", "Formula recognition"),
):
    if _key not in st.session_state:
        st.session_state[_key] = _default
# ====== <sidebar> ======
# Configuration widgets. The selectbox, number input and radio register
# change_side_bar as their on_change callback; the ONNX toggle does not.
with st.sidebar:
    st.markdown("# 🔨️ Config")
    st.markdown("")

    inf_mode = st.selectbox(
        "Inference mode",
        ("Formula recognition", "Paragraph recognition"),
        on_change=change_side_bar,
    )

    # No pre-assignment needed: number_input always returns a value (>= min_value).
    num_beams = st.number_input(
        "Number of beams", min_value=1, max_value=20, step=1, on_change=change_side_bar
    )

    device = st.radio("device", ("cpu", "cuda", "mps"), on_change=change_side_bar)

    st.markdown("## Speedup")
    use_onnx = st.toggle("ONNX Runtime ")
# ====== </sidebar> ======
# ====== <page> ======
# Fetch the models through the loader functions defined earlier in this file.
latexrec_model = get_texteller(use_onnx)
tokenizer = get_tokenizer()

# The detection / text-recognition models are only needed in paragraph mode.
if inf_mode == "Paragraph recognition":
    latexdet_model = get_latexdet_model()
    textrec_model = get_textrec_model()
    textdet_model = get_textdet_model()

st.markdown(HEADER_HTML, unsafe_allow_html=True)

uploaded_file = st.file_uploader(" ", type=["jpg", "png"], on_change=on_file_upload)

paste_result = pbutton(
    label="📋 Paste an image",
    background_color="#5BBCFF",
    hover_background_color="#3498db",
)
st.write("")

if st.session_state["CHANGE_SIDEBAR_FLAG"] is True:
    # This rerun was triggered by a sidebar change: clear the flag and skip
    # inference for this pass.
    st.session_state["CHANGE_SIDEBAR_FLAG"] = False
elif uploaded_file or paste_result.image_data is not None:
    # A pasted image takes the place of the uploaded file unless the uploader
    # itself has just changed.
    if st.session_state["UPLOADED_FILE_CHANGED"] is False and paste_result.image_data is not None:
        uploaded_file = io.BytesIO()
        paste_result.image_data.save(uploaded_file, format="PNG")
        uploaded_file.seek(0)

    if st.session_state["UPLOADED_FILE_CHANGED"] is True:
        st.session_state["UPLOADED_FILE_CHANGED"] = False

    img = Image.open(uploaded_file)

    # The inference functions take a file path, so the image is written to a
    # temporary directory. The context manager removes the directory even when
    # saving, inference or rendering raises.
    with tempfile.TemporaryDirectory() as temp_dir:
        png_fpath = os.path.join(temp_dir, "image.png")
        img.save(png_fpath, "PNG")

        with st.container(height=300):
            img_base64 = get_image_base64(uploaded_file)
            st.markdown(
                IMAGE_EMBED_HTML.format(img_base64=img_base64),
                unsafe_allow_html=True,
            )
        st.markdown(
            IMAGE_INFO_HTML.format(img_height=img.height, img_width=img.width),
            unsafe_allow_html=True,
        )

        st.write("")

        with st.spinner("Predicting..."):
            if inf_mode == "Formula recognition":
                pred = img2latex(
                    model=latexrec_model,
                    tokenizer=tokenizer,
                    images=[png_fpath],
                    device=str2device(device),
                    out_format="katex",
                    num_beams=num_beams,
                    keep_style=False,
                )[0]
            else:
                pred = paragraph2md(
                    img_path=png_fpath,
                    latexdet_model=latexdet_model,
                    textdet_model=textdet_model,
                    textrec_model=textrec_model,
                    latexrec_model=latexrec_model,
                    tokenizer=tokenizer,
                    device=str2device(device),
                    num_beams=num_beams,
                )

            # The icon must be an emoji (or omitted); an empty string is not
            # a valid icon value.
            st.success("Completed!", icon="✅")

            # Raw prediction as copyable source text.
            if inf_mode == "Formula recognition":
                st.code(pred, language="latex")
            elif inf_mode == "Paragraph recognition":
                st.code(pred, language="markdown")
            else:
                raise ValueError(f"Invalid inference mode: {inf_mode}")

            # Rendered prediction.
            if inf_mode == "Formula recognition":
                st.latex(pred)
            elif inf_mode == "Paragraph recognition":
                # Split into alternating text and $$...$$ display-math pieces;
                # the capture group keeps the math pieces in the result.
                mixed_res = re.split(r"(\$\$.*?\$\$)", pred, flags=re.DOTALL)
                for text in mixed_res:
                    if text.startswith("$$") and text.endswith("$$"):
                        # Remove exactly the two-character delimiters.
                        # str.strip("$$") would strip every leading/trailing
                        # "$", including ones that belong to the formula.
                        st.latex(text[2:-2])
                    else:
                        st.markdown(text)

            st.write("")
            st.write("")

            with st.expander(":star2: :gray[Tips for better results]"):
                st.markdown("""
                    * :mag_right: Use a clear and high-resolution image.
                    * :scissors: Crop images as accurately as possible.
                    * :jigsaw: Split large multi line formulas into smaller ones.
                    * :page_facing_up: Use images with **white background and black text** as much as possible.
                    * :book: Use a font with good readability.
                """)

    # Drop the pasted image now that it has been processed.
    paste_result.image_data = None
# ====== </page> ======