前端更新, inference.py更新

1) 前端支持剪贴板粘贴图片.
2) 前端支持模型配置.
3) 修改了inference.py的接口.
4) 删除了不必要的文件
This commit is contained in:
三洋三洋
2024-04-17 09:12:07 +00:00
parent 66d4902871
commit 3cebc2eb2a
11 changed files with 181 additions and 105 deletions

View File

@@ -18,10 +18,16 @@ if __name__ == '__main__':
help='path to the input image'
)
parser.add_argument(
'-cuda',
default=False,
action='store_true',
help='use cuda or not'
'--inference-mode',
type=str,
default='cpu',
help='Inference mode, select one of cpu, cuda, or mps'
)
parser.add_argument(
'--num-beam',
type=int,
default=1,
help='number of beam search for decoding'
)
args = parser.parse_args()
@@ -33,6 +39,6 @@ if __name__ == '__main__':
img = cv.imread(args.img)
print('Inference...')
res = latex_inference(latex_rec_model, tokenizer, [img], args.cuda)
res = latex_inference(latex_rec_model, tokenizer, [img], inf_mode=args.inference_mode, num_beams=args.num_beam)
res = to_katex(res[0])
print(res)

View File

@@ -14,7 +14,7 @@ def inference(
model: TexTeller,
tokenizer: RobertaTokenizerFast,
imgs_path: Union[List[str], List[np.ndarray]],
use_cuda: bool,
inf_mode: str = 'cpu',
num_beams: int = 1,
) -> List[str]:
model.eval()
@@ -26,9 +26,8 @@ def inference(
imgs = inference_transform(imgs)
pixel_values = torch.stack(imgs)
if use_cuda:
model = model.to('cuda')
pixel_values = pixel_values.to('cuda')
model = model.to(inf_mode)
pixel_values = pixel_values.to(inf_mode)
generate_config = GenerationConfig(
max_new_tokens=MAX_TOKEN_SIZE,

View File

@@ -11,22 +11,22 @@ if __name__ == '__main__':
os.chdir(Path(__file__).resolve().parent)
parser = argparse.ArgumentParser()
parser.add_argument(
'-img_dir',
type=str,
default="./subimages",
help='path to the directory containing input images'
'-img',
type=str,
required=True,
help='path to the input image'
)
parser.add_argument(
'-output_dir',
'--inference-mode',
type=str,
default="./results",
help='path to the output directory for storing recognition results'
default='cpu',
help='Inference mode, select one of cpu, cuda, or mps'
)
parser.add_argument(
'-cuda',
default=False,
action='store_true',
help='use cuda or not'
'--num-beam',
type=int,
default=1,
help='number of beam search for decoding'
)
args = parser.parse_args()
@@ -46,7 +46,7 @@ if __name__ == '__main__':
if img is not None:
print(f'Inference for {filename}...')
res = latex_inference(latex_rec_model, tokenizer, [img], args.cuda)
res = latex_inference(latex_rec_model, tokenizer, [img], inf_mode=args.inference_mode, num_beams=args.num_beam)
res = to_katex(res[0])
# Save the recognition result to a text file

View File

@@ -23,8 +23,8 @@ parser.add_argument('--num_replicas', type=int, default=1)
parser.add_argument('--ncpu_per_replica', type=float, default=1.0)
parser.add_argument('--ngpu_per_replica', type=float, default=0.0)
parser.add_argument('--use_cuda', action='store_true', default=False)
parser.add_argument('--num_beam', type=int, default=1)
parser.add_argument('--inference-mode', type=str, default='cpu')
parser.add_argument('--num_beams', type=int, default=1)
args = parser.parse_args()
if args.ngpu_per_replica > 0 and not args.use_cuda:
@@ -43,18 +43,21 @@ class TexTellerServer:
self,
checkpoint_path: str,
tokenizer_path: str,
use_cuda: bool = False,
num_beam: int = 1
inf_mode: str = 'cpu',
num_beams: int = 1
) -> None:
self.model = TexTeller.from_pretrained(checkpoint_path)
self.tokenizer = TexTeller.get_tokenizer(tokenizer_path)
self.use_cuda = use_cuda
self.num_beam = num_beam
self.inf_mode = inf_mode
self.num_beams = num_beams
self.model = self.model.to('cuda') if use_cuda else self.model
self.model = self.model.to(inf_mode) if inf_mode != 'cpu' else self.model
def predict(self, image_nparray) -> str:
return inference(self.model, self.tokenizer, [image_nparray], self.use_cuda, self.num_beam)[0]
return inference(
self.model, self.tokenizer, [image_nparray],
inf_mode=self.inf_mode, num_beams=self.num_beams
)[0]
@serve.deployment()
@@ -78,7 +81,11 @@ if __name__ == '__main__':
tknz_dir = args.tokenizer_dir
serve.start(http_options={"port": args.server_port})
texteller_server = TexTellerServer.bind(ckpt_dir, tknz_dir, use_cuda=args.use_cuda, num_beam=args.num_beam)
texteller_server = TexTellerServer.bind(
ckpt_dir, tknz_dir,
inf_mode=args.inference_mode,
num_beams=args.num_beams
)
ingress = Ingress.bind(texteller_server)
ingress_handle = serve.run(ingress, route_prefix="/predict")

View File

@@ -3,8 +3,6 @@ SETLOCAL ENABLEEXTENSIONS
set CHECKPOINT_DIR=default
set TOKENIZER_DIR=default
set USE_CUDA=False REM True or False (case-sensitive)
set NUM_BEAM=1
streamlit run web.py

View File

@@ -3,7 +3,5 @@ set -exu
export CHECKPOINT_DIR="default"
export TOKENIZER_DIR="default"
export USE_CUDA=False # True or False (case-sensitive)
export NUM_BEAM=1
streamlit run web.py

View File

@@ -6,16 +6,22 @@ import shutil
import streamlit as st
from PIL import Image
from streamlit_paste_button import paste_image_button as pbutton
from models.ocr_model.utils.inference import inference
from models.ocr_model.model.TexTeller import TexTeller
from utils import to_katex
st.set_page_config(
page_title="TexTeller",
page_icon="🧮"
)
html_string = '''
<h1 style="color: black; text-align: center;">
<img src="https://slackmojis.com/emojis/429-troll/download" width="50">
TexTeller
<img src="https://slackmojis.com/emojis/429-troll/download" width="50">
<img src="https://raw.githubusercontent.com/OleehyO/TexTeller/main/assets/fire.svg" width="100">
𝚃𝚎𝚡𝚃𝚎𝚕𝚕𝚎𝚛
<img src="https://raw.githubusercontent.com/OleehyO/TexTeller/main/assets/fire.svg" width="100">
</h1>
'''
@@ -35,29 +41,6 @@ fail_gif_html = '''
</h1>
'''
tex = r'''
\documentclass{{article}}
\usepackage[
left=1in, % 左边距
right=1in, % 右边距
top=1in, % 上边距
bottom=1in,% 下边距
paperwidth=40cm, % 页面宽度
paperheight=40cm % 页面高度这里以A4纸为例
]{{geometry}}
\usepackage[utf8]{{inputenc}}
\usepackage{{multirow,multicol,amsmath,amsfonts,amssymb,mathtools,bm,mathrsfs,wasysym,amsbsy,upgreek,mathalfa,stmaryrd,mathrsfs,dsfont,amsthm,amsmath,multirow}}
\begin{{document}}
{formula}
\pagenumbering{{gobble}}
\end{{document}}
'''
@st.cache_resource
def get_model():
return TexTeller.from_pretrained(os.environ['CHECKPOINT_DIR'])
@@ -73,6 +56,12 @@ def get_image_base64(img_file):
img.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode()
def on_file_upload():
st.session_state["UPLOADED_FILE_CHANGED"] = True
def change_side_bar():
st.session_state["CHANGE_SIDEBAR_FLAG"] = True
model = get_model()
tokenizer = get_tokenizer()
@@ -80,37 +69,106 @@ if "start" not in st.session_state:
st.session_state["start"] = 1
st.toast('Hooray!', icon='🎉')
if "UPLOADED_FILE_CHANGED" not in st.session_state:
st.session_state["UPLOADED_FILE_CHANGED"] = False
# ============================ pages =============================== #
if "CHANGE_SIDEBAR_FLAG" not in st.session_state:
st.session_state["CHANGE_SIDEBAR_FLAG"] = False
# ============================ begin sidebar =============================== #
with st.sidebar:
num_beams = 1
inf_mode = 'cpu'
st.markdown("# 🔨️ Config")
st.markdown("")
model_type = st.selectbox(
"Model type",
("TexTeller", "None"),
on_change=change_side_bar
)
if model_type == "TexTeller":
num_beams = st.number_input(
'Number of beams',
min_value=1,
max_value=20,
step=1,
on_change=change_side_bar
)
inf_mode = st.radio(
"Inference mode",
("cpu", "cuda", "mps"),
on_change=change_side_bar
)
# ============================ end sidebar =============================== #
# ============================ begin pages =============================== #
st.markdown(html_string, unsafe_allow_html=True)
uploaded_file = st.file_uploader("",type=['jpg', 'png', 'pdf'])
uploaded_file = st.file_uploader(
" ",
type=['jpg', 'png'],
on_change=on_file_upload
)
paste_result = pbutton(
label="📋 Paste an image",
background_color="#5BBCFF",
hover_background_color="#3498db",
)
st.write("")
if st.session_state["CHANGE_SIDEBAR_FLAG"] == True:
st.session_state["CHANGE_SIDEBAR_FLAG"] = False
elif uploaded_file or paste_result.image_data is not None:
if st.session_state["UPLOADED_FILE_CHANGED"] == False and paste_result.image_data is not None:
uploaded_file = io.BytesIO()
paste_result.image_data.save(uploaded_file, format='PNG')
uploaded_file.seek(0)
if st.session_state["UPLOADED_FILE_CHANGED"] == True:
st.session_state["UPLOADED_FILE_CHANGED"] = False
if uploaded_file:
img = Image.open(uploaded_file)
temp_dir = tempfile.mkdtemp()
png_file_path = os.path.join(temp_dir, 'image.png')
img.save(png_file_path, 'PNG')
img_base64 = get_image_base64(uploaded_file)
with st.container(height=300):
img_base64 = get_image_base64(uploaded_file)
st.markdown(f"""
<style>
.centered-container {{
text-align: center;
}}
.centered-image {{
display: block;
margin-left: auto;
margin-right: auto;
max-height: 350px;
max-width: 100%;
}}
</style>
<div class="centered-container">
<img src="data:image/png;base64,{img_base64}" class="centered-image" alt="Input image">
</div>
""", unsafe_allow_html=True)
st.markdown(f"""
<style>
.centered-container {{
text-align: center;
}}
.centered-image {{
display: block;
margin-left: auto;
margin-right: auto;
max-width: 500px;
max-height: 500px;
}}
</style>
<div class="centered-container">
<img src="data:image/png;base64,{img_base64}" class="centered-image" alt="Input image">
<p style="color:gray;">Input image ({img.height}✖️{img.width})</p>
</div>
""", unsafe_allow_html=True)
@@ -123,15 +181,28 @@ if uploaded_file:
model,
tokenizer,
[png_file_path],
True if os.environ['USE_CUDA'] == 'True' else False,
int(os.environ['NUM_BEAM'])
inf_mode=inf_mode,
num_beams=num_beams
)[0]
st.success('Completed!', icon="")
st.markdown(suc_gif_html, unsafe_allow_html=True)
katex_res = to_katex(TexTeller_result)
st.text_area(":red[Predicted formula]", katex_res, height=150)
st.text_area(":blue[*** 𝑃r𝑒d𝑖c𝑡e𝑑 𝑓o𝑟m𝑢l𝑎 ***]", katex_res, height=150)
st.latex(katex_res)
st.write("")
st.write("")
with st.expander(":star2: :gray[Tips for better results]"):
st.markdown('''
* :mag_right: Use a clear and high-resolution image.
* :scissors: Crop images as accurately as possible.
* :jigsaw: Split large multi line formulas into smaller ones.
* :page_facing_up: Use images with **white background and black text** as much as possible.
* :book: Use a font with good readability.
''')
shutil.rmtree(temp_dir)
# ============================ pages =============================== #
paste_result.image_data = None
# ============================ end pages =============================== #