import os
import io
import base64
import tempfile
import shutil
import streamlit as st
from PIL import Image
from models.ocr_model.utils.inference import inference
from models.ocr_model.model.TexTeller import TexTeller
from utils import to_katex
# Header markup rendered at the top of the page via
# st.markdown(..., unsafe_allow_html=True). As written here it is only the text
# "TexTeller" — NOTE(review): any surrounding HTML may have been lost; confirm.
html_string = '''
TexTeller
'''
# Markup rendered right after a successful prediction; currently an empty string.
suc_gif_html = '''
'''
# Failure counterpart of suc_gif_html; currently an empty string and not
# referenced in this part of the file.
fail_gif_html = '''
'''
# Standalone LaTeX document template. Literal braces are doubled ({{ }}) while
# {formula} is single-braced, so it appears meant to be filled with
# tex.format(formula=...) — it is not used in this part of the file; confirm caller.
# The inline LaTeX comments (Chinese) read: left margin, right margin, top margin,
# bottom margin, page width, and "page height, using A4 paper as an example" —
# that last remark is inaccurate: the page is 40cm x 40cm, not A4.
# NOTE(review): amsmath, multirow and mathrsfs each appear twice in the package list.
tex = r'''
\documentclass{{article}}
\usepackage[
left=1in, % 左边距
right=1in, % 右边距
top=1in, % 上边距
bottom=1in,% 下边距
paperwidth=40cm, % 页面宽度
paperheight=40cm % 页面高度,这里以A4纸为例
]{{geometry}}
\usepackage[utf8]{{inputenc}}
\usepackage{{multirow,multicol,amsmath,amsfonts,amssymb,mathtools,bm,mathrsfs,wasysym,amsbsy,upgreek,mathalfa,stmaryrd,mathrsfs,dsfont,amsthm,amsmath,multirow}}
\begin{{document}}
{formula}
\pagenumbering{{gobble}}
\end{{document}}
'''
@st.cache_resource
def get_model():
    """Load the TexTeller OCR model from the ``CHECKPOINT_DIR`` directory.

    ``st.cache_resource`` keeps the loaded object across Streamlit reruns, so
    the checkpoint is read only once per server process.

    Raises:
        KeyError: if the ``CHECKPOINT_DIR`` environment variable is not set.
    """
    checkpoint_dir = os.environ['CHECKPOINT_DIR']
    return TexTeller.from_pretrained(checkpoint_dir)
@st.cache_resource
def get_tokenizer():
    """Load the tokenizer paired with the TexTeller model.

    The location comes from the ``TOKENIZER_DIR`` environment variable; a
    missing variable raises ``KeyError``. ``st.cache_resource`` keeps the
    loaded tokenizer across Streamlit reruns so it is built only once.
    """
    tokenizer_dir = os.environ['TOKENIZER_DIR']
    return TexTeller.get_tokenizer(tokenizer_dir)
def get_image_base64(img_file):
    """Re-encode an image file as PNG and return it as base64 text.

    Args:
        img_file: a seekable binary file-like object (e.g. a Streamlit
            upload) containing an image Pillow can decode. It is rewound to
            the start before reading; it is NOT closed by this function.

    Returns:
        The base64-encoded PNG bytes, decoded to an ASCII ``str``.

    Raises:
        OSError: if the data cannot be decoded as an image
            (``PIL.UnidentifiedImageError`` is an ``OSError`` subclass).
    """
    img_file.seek(0)
    buffered = io.BytesIO()
    # Use the context manager so the decoded image is released promptly.
    # Pillow only closes the underlying file when it opened it from a path,
    # so the caller's file object stays usable.
    with Image.open(img_file) as img:
        # PNG cannot store CMYK (common in print-oriented JPEGs); saving such
        # an image as PNG raises OSError, so convert to RGB first.
        if img.mode == 'CMYK':
            img = img.convert('RGB')
        img.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()
# Streamlit re-executes this script on every interaction; the cache_resource
# decorators on the loaders make these calls cheap after the first run.
model = get_model()
tokenizer = get_tokenizer()

# session_state survives reruns within one browser session, so the greeting
# toast is shown only on that session's first run.
_is_first_run = "start" not in st.session_state
if _is_first_run:
    st.session_state["start"] = 1
    st.toast('Hooray!', icon='🎉')
# ============================ pages =============================== #
st.markdown(html_string, unsafe_allow_html=True)

# 'pdf' was previously offered but PIL cannot decode PDFs, so every PDF upload
# crashed. A non-empty (visually hidden) label avoids Streamlit's
# empty-label accessibility warning.
uploaded_file = st.file_uploader(
    "Upload a formula image",
    type=['jpg', 'png'],
    label_visibility="collapsed",
)

if uploaded_file:
    # TemporaryDirectory removes the scratch directory even if decoding or
    # inference raises (or st.stop() aborts the run). The previous
    # mkdtemp()/rmtree() pair leaked it on any failure.
    with tempfile.TemporaryDirectory() as temp_dir:
        png_file_path = os.path.join(temp_dir, 'image.png')
        try:
            img = Image.open(uploaded_file)
            # PNG cannot store CMYK (common in print-oriented JPEGs).
            if img.mode == 'CMYK':
                img = img.convert('RGB')
            # inference() takes file paths, so persist the upload as a PNG.
            img.save(png_file_path, 'PNG')
            img_base64 = get_image_base64(uploaded_file)
        except OSError:
            # Includes PIL.UnidentifiedImageError (an OSError subclass) for
            # corrupt or unsupported uploads: report it instead of a traceback.
            st.error("Could not read the uploaded file as an image.", icon="❌")
            st.markdown(fail_gif_html, unsafe_allow_html=True)
            st.stop()

        st.markdown(f"""
Input image ({img.height}✖️{img.width})
""", unsafe_allow_html=True)
        st.write("")

        with st.spinner("Predicting..."):
            # inference() is called on a single path; keep its first result.
            TexTeller_result = inference(
                model,
                tokenizer,
                [png_file_path],
                os.environ['USE_CUDA'] == 'True',
                int(os.environ['NUM_BEAM']),
            )[0]

        st.success('Completed!', icon="✅")
        st.markdown(suc_gif_html, unsafe_allow_html=True)
        katex_res = to_katex(TexTeller_result)
        st.text_area(":red[Predicted formula]", katex_res, height=150)
        st.latex(katex_res)
# ============================ pages =============================== #