前端更新, inference.py更新
1) 前端支持剪贴板粘贴图片. 2) 前端支持模型配置. 3) 修改了inference.py的接口. 4) 删除了不必要的文件
This commit is contained in:
155
src/web.py
155
src/web.py
@@ -6,16 +6,22 @@ import shutil
|
||||
import streamlit as st
|
||||
|
||||
from PIL import Image
|
||||
from streamlit_paste_button import paste_image_button as pbutton
|
||||
from models.ocr_model.utils.inference import inference
|
||||
from models.ocr_model.model.TexTeller import TexTeller
|
||||
from utils import to_katex
|
||||
|
||||
|
||||
st.set_page_config(
|
||||
page_title="TexTeller",
|
||||
page_icon="🧮"
|
||||
)
|
||||
|
||||
html_string = '''
|
||||
<h1 style="color: black; text-align: center;">
|
||||
<img src="https://slackmojis.com/emojis/429-troll/download" width="50">
|
||||
TexTeller
|
||||
<img src="https://slackmojis.com/emojis/429-troll/download" width="50">
|
||||
<img src="https://raw.githubusercontent.com/OleehyO/TexTeller/main/assets/fire.svg" width="100">
|
||||
𝚃𝚎𝚡𝚃𝚎𝚕𝚕𝚎𝚛
|
||||
<img src="https://raw.githubusercontent.com/OleehyO/TexTeller/main/assets/fire.svg" width="100">
|
||||
</h1>
|
||||
'''
|
||||
|
||||
@@ -35,29 +41,6 @@ fail_gif_html = '''
|
||||
</h1>
|
||||
'''
|
||||
|
||||
tex = r'''
|
||||
\documentclass{{article}}
|
||||
\usepackage[
|
||||
left=1in, % 左边距
|
||||
right=1in, % 右边距
|
||||
top=1in, % 上边距
|
||||
bottom=1in,% 下边距
|
||||
paperwidth=40cm, % 页面宽度
|
||||
paperheight=40cm % 页面高度,这里以A4纸为例
|
||||
]{{geometry}}
|
||||
|
||||
\usepackage[utf8]{{inputenc}}
|
||||
\usepackage{{multirow,multicol,amsmath,amsfonts,amssymb,mathtools,bm,mathrsfs,wasysym,amsbsy,upgreek,mathalfa,stmaryrd,mathrsfs,dsfont,amsthm,amsmath,multirow}}
|
||||
|
||||
\begin{{document}}
|
||||
|
||||
{formula}
|
||||
|
||||
\pagenumbering{{gobble}}
|
||||
\end{{document}}
|
||||
'''
|
||||
|
||||
|
||||
@st.cache_resource
|
||||
def get_model():
|
||||
return TexTeller.from_pretrained(os.environ['CHECKPOINT_DIR'])
|
||||
@@ -73,6 +56,12 @@ def get_image_base64(img_file):
|
||||
img.save(buffered, format="PNG")
|
||||
return base64.b64encode(buffered.getvalue()).decode()
|
||||
|
||||
def on_file_upload():
|
||||
st.session_state["UPLOADED_FILE_CHANGED"] = True
|
||||
|
||||
def change_side_bar():
|
||||
st.session_state["CHANGE_SIDEBAR_FLAG"] = True
|
||||
|
||||
model = get_model()
|
||||
tokenizer = get_tokenizer()
|
||||
|
||||
@@ -80,37 +69,106 @@ if "start" not in st.session_state:
|
||||
st.session_state["start"] = 1
|
||||
st.toast('Hooray!', icon='🎉')
|
||||
|
||||
if "UPLOADED_FILE_CHANGED" not in st.session_state:
|
||||
st.session_state["UPLOADED_FILE_CHANGED"] = False
|
||||
|
||||
# ============================ pages =============================== #
|
||||
if "CHANGE_SIDEBAR_FLAG" not in st.session_state:
|
||||
st.session_state["CHANGE_SIDEBAR_FLAG"] = False
|
||||
|
||||
# ============================ begin sidebar =============================== #
|
||||
|
||||
with st.sidebar:
|
||||
num_beams = 1
|
||||
inf_mode = 'cpu'
|
||||
|
||||
st.markdown("# 🔨️ Config")
|
||||
st.markdown("")
|
||||
|
||||
model_type = st.selectbox(
|
||||
"Model type",
|
||||
("TexTeller", "None"),
|
||||
on_change=change_side_bar
|
||||
)
|
||||
if model_type == "TexTeller":
|
||||
num_beams = st.number_input(
|
||||
'Number of beams',
|
||||
min_value=1,
|
||||
max_value=20,
|
||||
step=1,
|
||||
on_change=change_side_bar
|
||||
)
|
||||
|
||||
inf_mode = st.radio(
|
||||
"Inference mode",
|
||||
("cpu", "cuda", "mps"),
|
||||
on_change=change_side_bar
|
||||
)
|
||||
|
||||
# ============================ end sidebar =============================== #
|
||||
|
||||
|
||||
|
||||
# ============================ begin pages =============================== #
|
||||
|
||||
st.markdown(html_string, unsafe_allow_html=True)
|
||||
|
||||
uploaded_file = st.file_uploader("",type=['jpg', 'png', 'pdf'])
|
||||
uploaded_file = st.file_uploader(
|
||||
" ",
|
||||
type=['jpg', 'png'],
|
||||
on_change=on_file_upload
|
||||
)
|
||||
|
||||
paste_result = pbutton(
|
||||
label="📋 Paste an image",
|
||||
background_color="#5BBCFF",
|
||||
hover_background_color="#3498db",
|
||||
)
|
||||
st.write("")
|
||||
|
||||
if st.session_state["CHANGE_SIDEBAR_FLAG"] == True:
|
||||
st.session_state["CHANGE_SIDEBAR_FLAG"] = False
|
||||
elif uploaded_file or paste_result.image_data is not None:
|
||||
if st.session_state["UPLOADED_FILE_CHANGED"] == False and paste_result.image_data is not None:
|
||||
uploaded_file = io.BytesIO()
|
||||
paste_result.image_data.save(uploaded_file, format='PNG')
|
||||
uploaded_file.seek(0)
|
||||
|
||||
if st.session_state["UPLOADED_FILE_CHANGED"] == True:
|
||||
st.session_state["UPLOADED_FILE_CHANGED"] = False
|
||||
|
||||
if uploaded_file:
|
||||
img = Image.open(uploaded_file)
|
||||
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
png_file_path = os.path.join(temp_dir, 'image.png')
|
||||
img.save(png_file_path, 'PNG')
|
||||
|
||||
img_base64 = get_image_base64(uploaded_file)
|
||||
with st.container(height=300):
|
||||
img_base64 = get_image_base64(uploaded_file)
|
||||
|
||||
st.markdown(f"""
|
||||
<style>
|
||||
.centered-container {{
|
||||
text-align: center;
|
||||
}}
|
||||
.centered-image {{
|
||||
display: block;
|
||||
margin-left: auto;
|
||||
margin-right: auto;
|
||||
max-height: 350px;
|
||||
max-width: 100%;
|
||||
}}
|
||||
</style>
|
||||
<div class="centered-container">
|
||||
<img src="data:image/png;base64,{img_base64}" class="centered-image" alt="Input image">
|
||||
</div>
|
||||
""", unsafe_allow_html=True)
|
||||
st.markdown(f"""
|
||||
<style>
|
||||
.centered-container {{
|
||||
text-align: center;
|
||||
}}
|
||||
.centered-image {{
|
||||
display: block;
|
||||
margin-left: auto;
|
||||
margin-right: auto;
|
||||
max-width: 500px;
|
||||
max-height: 500px;
|
||||
}}
|
||||
</style>
|
||||
<div class="centered-container">
|
||||
<img src="data:image/png;base64,{img_base64}" class="centered-image" alt="Input image">
|
||||
<p style="color:gray;">Input image ({img.height}✖️{img.width})</p>
|
||||
</div>
|
||||
""", unsafe_allow_html=True)
|
||||
@@ -123,15 +181,28 @@ if uploaded_file:
|
||||
model,
|
||||
tokenizer,
|
||||
[png_file_path],
|
||||
True if os.environ['USE_CUDA'] == 'True' else False,
|
||||
int(os.environ['NUM_BEAM'])
|
||||
inf_mode=inf_mode,
|
||||
num_beams=num_beams
|
||||
)[0]
|
||||
st.success('Completed!', icon="✅")
|
||||
st.markdown(suc_gif_html, unsafe_allow_html=True)
|
||||
katex_res = to_katex(TexTeller_result)
|
||||
st.text_area(":red[Predicted formula]", katex_res, height=150)
|
||||
st.text_area(":blue[*** 𝑃r𝑒d𝑖c𝑡e𝑑 𝑓o𝑟m𝑢l𝑎 ***]", katex_res, height=150)
|
||||
st.latex(katex_res)
|
||||
|
||||
st.write("")
|
||||
st.write("")
|
||||
|
||||
with st.expander(":star2: :gray[Tips for better results]"):
|
||||
st.markdown('''
|
||||
* :mag_right: Use a clear and high-resolution image.
|
||||
* :scissors: Crop images as accurately as possible.
|
||||
* :jigsaw: Split large multi line formulas into smaller ones.
|
||||
* :page_facing_up: Use images with **white background and black text** as much as possible.
|
||||
* :book: Use a font with good readability.
|
||||
''')
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
# ============================ pages =============================== #
|
||||
paste_result.image_data = None
|
||||
|
||||
# ============================ end pages =============================== #
|
||||
|
||||
Reference in New Issue
Block a user