This commit is contained in:
三洋三洋
2024-03-18 15:48:04 +00:00
parent 5d089b5a7f
commit 74341c7e8a
6 changed files with 330 additions and 118 deletions

View File

@@ -2,13 +2,65 @@ import os
import io
import base64
import tempfile
import time
import subprocess
import shutil
import streamlit as st
from PIL import Image
from PIL import Image, ImageChops
from pathlib import Path
from pdf2image import convert_from_path
from models.ocr_model.utils.inference import inference
from models.ocr_model.model.TexTeller import TexTeller
html_string = '''
<h1 style="color: black; text-align: center;">
<img src="https://slackmojis.com/emojis/429-troll/download" width="50">
TexTeller
<img src="https://slackmojis.com/emojis/429-troll/download" width="50">
</h1>
'''
suc_gif_html = '''
<h1 style="color: black; text-align: center;">
<img src="https://slackmojis.com/emojis/90621-clapclap-e/download" width="50">
<img src="https://slackmojis.com/emojis/90621-clapclap-e/download" width="50">
<img src="https://slackmojis.com/emojis/90621-clapclap-e/download" width="50">
</h1>
'''
fail_gif_html = '''
<h1 style="color: black; text-align: center;">
<img src="https://slackmojis.com/emojis/51439-allthethings_intensifies/download" >
<img src="https://slackmojis.com/emojis/51439-allthethings_intensifies/download" >
<img src="https://slackmojis.com/emojis/51439-allthethings_intensifies/download" >
</h1>
'''
tex = r'''
\documentclass{{article}}
\usepackage[
left=1in, % 左边距
right=1in, % 右边距
top=1in, % 上边距
bottom=1in,% 下边距
paperwidth=40cm, % 页面宽度
paperheight=40cm % 页面高度这里以A4纸为例
]{{geometry}}
\usepackage[utf8]{{inputenc}}
\usepackage{{multirow,multicol,amsmath,amsfonts,amssymb,mathtools,bm,mathrsfs,wasysym,amsbsy,upgreek,mathalfa,stmaryrd,mathrsfs,dsfont,amsthm,amsmath,multirow}}
\begin{{document}}
{formula}
\pagenumbering{{gobble}}
\end{{document}}
'''
@st.cache_resource
def get_model():
return TexTeller.from_pretrained(os.environ['CHECKPOINT_DIR'])
@@ -18,24 +70,74 @@ def get_model():
def get_tokenizer():
return TexTeller.get_tokenizer(os.environ['TOKENIZER_DIR'])
def get_image_base64(img_file):
buffered = io.BytesIO()
img_file.seek(0)
img = Image.open(img_file)
img.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode()
def rendering(formula: str, out_img_path: Path) -> bool:
build_dir = out_img_path / 'build'
build_dir.mkdir(exist_ok=True, parents=True)
f = build_dir / 'formula.tex'
f.touch(exist_ok=True)
f.write_text(tex.format(formula=formula))
p = subprocess.Popen([
'xelatex',
f'-output-directory={build_dir}',
'-interaction=nonstopmode',
'-halt-on-error',
f'{f}'
])
p.communicate()
return p.returncode == 0
def pdf_to_pngbytes(pdf_path):
images = convert_from_path(pdf_path, first_page=1, last_page=1)
trimmed_images = trim(images[0])
png_image_bytes = io.BytesIO()
trimmed_images.save(png_image_bytes, format='PNG')
png_image_bytes.seek(0)
return png_image_bytes
def trim(im):
bg = Image.new(im.mode, im.size, im.getpixel((0,0)))
diff = ImageChops.difference(im, bg)
diff = ImageChops.add(diff, diff, 2.0, -100)
bbox = diff.getbbox()
if bbox:
return im.crop(bbox)
return im
model = get_model()
tokenizer = get_tokenizer()
# check if xelatex is installed
xelatex_installed = os.system('which xelatex > /dev/null 2>&1') == 0
if "start" not in st.session_state:
st.session_state["start"] = 1
if xelatex_installed:
st.toast('Hooray!', icon='🎉')
time.sleep(0.5)
st.toast("Xelatex have been detected.", icon='')
else:
st.error('xelatex is not installed. Please install it before using TexTeller.')
# ============================ pages =============================== #
html_string = '''
<h1 style="color: orange; text-align: center;">
✨ TexTeller ✨
</h1>
'''
st.markdown(html_string, unsafe_allow_html=True)
if "start" not in st.session_state:
st.balloons()
st.session_state["start"] = 1
uploaded_file = st.file_uploader("",type=['jpg', 'png', 'pdf'])
uploaded_file = st.file_uploader("",type=['jpg', 'png'])
if xelatex_installed:
st.caption('🥳 Xelatex have been detected, rendered image will be displayed in the web page.')
else:
st.caption('😭 Xelatex is not detected, please check the resulting latex code by yourself, or check ... to have your xelatex setup ready.')
if uploaded_file:
img = Image.open(uploaded_file)
@@ -44,13 +146,6 @@ if uploaded_file:
png_file_path = os.path.join(temp_dir, 'image.png')
img.save(png_file_path, 'PNG')
def get_image_base64(img_file):
buffered = io.BytesIO()
img_file.seek(0)
img = Image.open(img_file)
img.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode()
img_base64 = get_image_base64(uploaded_file)
st.markdown(f"""
@@ -62,7 +157,8 @@ if uploaded_file:
display: block;
margin-left: auto;
margin-right: auto;
max-width: 700px;
max-width: 500px;
max-height: 500px;
}}
</style>
<div class="centered-container">
@@ -71,7 +167,6 @@ if uploaded_file:
</div>
""", unsafe_allow_html=True)
st.write("")
st.write("")
with st.spinner("Predicting..."):
@@ -83,11 +178,54 @@ if uploaded_file:
True if os.environ['USE_CUDA'] == 'True' else False,
int(os.environ['NUM_BEAM'])
)[0]
if not xelatex_installed:
st.markdown(fail_gif_html, unsafe_allow_html=True)
st.warning('Unable to find xelatex to render image. Please check the prediction results yourself.', icon="🤡")
txt = st.text_area(
":red[Predicted formula]",
TeXTeller_result,
height=150,
)
else:
is_successed = rendering(TeXTeller_result, Path(temp_dir))
if is_successed:
# st.code(TeXTeller_result, language='latex')
# st.subheader(':rainbow[Predict] :sunglasses:', divider='rainbow')
st.subheader(':sunglasses:', divider='gray')
st.latex(TeXTeller_result)
st.code(TeXTeller_result, language='latex')
st.success('Done!')
img_base64 = get_image_base64(pdf_to_pngbytes(Path(temp_dir) / 'build' / 'formula.pdf'))
st.markdown(suc_gif_html, unsafe_allow_html=True)
st.success('Successfully rendered!', icon="")
txt = st.text_area(
":red[Predicted formula]",
TeXTeller_result,
height=150,
)
# st.latex(TeXTeller_result)
st.markdown(f"""
<style>
.centered-container {{
text-align: center;
}}
.centered-image {{
display: block;
margin-left: auto;
margin-right: auto;
max-width: 500px;
max-height: 500px;
}}
</style>
<div class="centered-container">
<img src="data:image/png;base64,{img_base64}" class="centered-image" alt="Input image">
</div>
""", unsafe_allow_html=True)
else:
st.markdown(fail_gif_html, unsafe_allow_html=True)
st.error('Rendering failed. You can try using a higher resolution image or splitting the multi line formula into a single line for better results.', icon="")
txt = st.text_area(
":red[Predicted formula]",
TeXTeller_result,
height=150,
)
shutil.rmtree(temp_dir)
# ============================ pages =============================== #