[refactor] Init
This commit is contained in:
3
texteller/cli/commands/__init__.py
Normal file
3
texteller/cli/commands/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
CLI commands for TexTeller
|
||||
"""
|
||||
51
texteller/cli/commands/inference.py
Normal file
51
texteller/cli/commands/inference.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""
|
||||
CLI command for formula inference from images.
|
||||
"""
|
||||
|
||||
import click
|
||||
|
||||
from texteller.api import img2latex, load_model, load_tokenizer
|
||||
|
||||
|
||||
@click.command()
@click.argument("image_path", type=click.Path(exists=True, file_okay=True, dir_okay=False))
@click.option(
    "--model-path",
    type=click.Path(exists=True, file_okay=False, dir_okay=True),
    default=None,
    help="Path to the model dir path, if not provided, will use model from huggingface repo",
)
@click.option(
    "--tokenizer-path",
    type=click.Path(exists=True, file_okay=False, dir_okay=True),
    default=None,
    help="Path to the tokenizer dir path, if not provided, will use tokenizer from huggingface repo",
)
@click.option(
    "--output-format",
    type=click.Choice(["latex", "katex"]),
    default="katex",
    help="Output format, either latex or katex",
)
@click.option(
    "--keep-style",
    is_flag=True,
    default=False,
    help="Whether to keep the style of the LaTeX (e.g. bold, italic, etc.)",
)
def inference(image_path, model_path, tokenizer_path, output_format, keep_style):
    """Recognize the formula in IMAGE_PATH and print the predicted LaTeX."""
    # Both loaders fall back to the huggingface-hosted artifacts when the
    # corresponding path option is left unset (None).
    model = load_model(model_dir=model_path)
    tokenizer = load_tokenizer(tokenizer_dir=tokenizer_path)

    # img2latex works on a batch of images; we pass a single-image batch and
    # take the first (only) prediction.
    predictions = img2latex(
        model=model,
        tokenizer=tokenizer,
        images=[image_path],
        out_format=output_format,
        keep_style=keep_style,
    )
    click.echo(f"Predicted LaTeX: ```\n{predictions[0]}\n```")
|
||||
106
texteller/cli/commands/launch/__init__.py
Normal file
106
texteller/cli/commands/launch/__init__.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""
|
||||
CLI commands for launching server.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import time
|
||||
|
||||
import click
|
||||
from ray import serve
|
||||
|
||||
from texteller.globals import Globals
|
||||
from texteller.utils import get_device
|
||||
|
||||
|
||||
@click.command()
@click.option(
    "-ckpt",
    "--checkpoint_dir",
    type=click.Path(exists=True, file_okay=False, dir_okay=True),
    default=None,
    help="Path to the checkpoint directory, if not provided, will use model from huggingface repo",
)
@click.option(
    "-tknz",
    "--tokenizer_dir",
    type=click.Path(exists=True, file_okay=False, dir_okay=True),
    default=None,
    help="Path to the tokenizer directory, if not provided, will use tokenizer from huggingface repo",
)
@click.option(
    "-p",
    "--port",
    type=int,
    default=8000,
    help="Port to run the server on",
)
@click.option(
    "--num-replicas",
    type=int,
    default=1,
    help="Number of replicas to run the server on",
)
@click.option(
    "--ncpu-per-replica",
    type=float,
    default=1.0,
    help="Number of CPUs per replica",
)
@click.option(
    "--ngpu-per-replica",
    type=float,
    default=1.0,
    help="Number of GPUs per replica",
)
@click.option(
    "--num-beams",
    type=int,
    default=1,
    help="Number of beams to use",
)
@click.option(
    "--use-onnx",
    # Fix: `type=bool` removed — it is redundant (and misleading) together
    # with `is_flag=True`, which already makes this a boolean flag.
    is_flag=True,
    default=False,
    help="Use ONNX runtime",
)
def launch(
    checkpoint_dir,
    tokenizer_dir,
    port,
    num_replicas,
    ncpu_per_replica,
    ngpu_per_replica,
    num_beams,
    use_onnx,
):
    """Launch the api server"""
    device = get_device()
    # Fail fast: GPU replicas were requested but no CUDA device was detected.
    if ngpu_per_replica > 0 and device.type != "cuda":
        click.echo(
            click.style(
                f"Error: --ngpu-per-replica > 0 but detected device is {device.type}",
                fg="red",
            )
        )
        sys.exit(1)

    # The server module reads these values at import time (inside its
    # @serve.deployment options), so Globals MUST be populated before the
    # deferred import below — do not hoist this import to the top of the file.
    Globals().num_replicas = num_replicas
    Globals().ncpu_per_replica = ncpu_per_replica
    Globals().ngpu_per_replica = ngpu_per_replica
    from texteller.cli.commands.launch.server import Ingress, TexTellerServer

    serve.start(http_options={"host": "0.0.0.0", "port": port})
    rec_server = TexTellerServer.bind(
        checkpoint_dir=checkpoint_dir,
        tokenizer_dir=tokenizer_dir,
        use_onnx=use_onnx,
        num_beams=num_beams,
    )
    ingress = Ingress.bind(rec_server)

    serve.run(ingress, route_prefix="/predict")

    # Keep the driver process alive; exiting here would tear down the
    # Ray Serve application started above.
    while True:
        time.sleep(1)
|
||||
69
texteller/cli/commands/launch/server.py
Normal file
69
texteller/cli/commands/launch/server.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
from starlette.requests import Request
|
||||
from ray import serve
|
||||
from ray.serve.handle import DeploymentHandle
|
||||
|
||||
from texteller.api import load_model, load_tokenizer, img2latex
|
||||
from texteller.utils import get_device
|
||||
from texteller.globals import Globals
|
||||
from typing import Literal
|
||||
|
||||
|
||||
@serve.deployment(
    num_replicas=Globals().num_replicas,
    ray_actor_options={
        "num_cpus": Globals().ncpu_per_replica,
        # NOTE(review): the requested GPU fraction is halved here —
        # presumably so two actors can share one GPU; confirm intentional.
        "num_gpus": Globals().ngpu_per_replica * 1.0 / 2,
    },
)
class TexTellerServer:
    """Ray Serve deployment hosting the TexTeller image-to-LaTeX model.

    NOTE: the @serve.deployment options above read Globals() at import time,
    so the launcher must populate Globals() *before* importing this module.
    """

    def __init__(
        self,
        checkpoint_dir: str,
        tokenizer_dir: str,
        use_onnx: bool = False,
        out_format: Literal["latex", "katex"] = "katex",
        keep_style: bool = False,
        num_beams: int = 1,
    ) -> None:
        """Load the recognition model and tokenizer for this replica."""
        self.model = load_model(
            model_dir=checkpoint_dir,
            use_onnx=use_onnx,
        )
        self.tokenizer = load_tokenizer(tokenizer_dir=tokenizer_dir)
        self.num_beams = num_beams
        self.out_format = out_format
        self.keep_style = keep_style

        # Only the non-ONNX model is explicitly moved to the detected device;
        # the ONNX model is used as loaded.
        if not use_onnx:
            self.model = self.model.to(get_device())

    def predict(self, image_nparray: np.ndarray) -> str:
        """Run recognition on one RGB image array and return the LaTeX string.

        Callers (see Ingress) pass an RGB-ordered array; img2latex takes a
        batch, so the single image is wrapped in a list and the first
        prediction returned.
        """
        return img2latex(
            model=self.model,
            tokenizer=self.tokenizer,
            images=[image_nparray],
            device=get_device(),
            out_format=self.out_format,
            keep_style=self.keep_style,
            num_beams=self.num_beams,
        )[0]
|
||||
|
||||
|
||||
@serve.deployment()
class Ingress:
    """HTTP entrypoint: decodes an uploaded image and forwards it for recognition."""

    def __init__(self, rec_server: DeploymentHandle) -> None:
        # Handle to the TexTellerServer deployment that does the actual work.
        self.texteller_server = rec_server

    async def __call__(self, request: Request) -> str:
        # The image arrives as the "img" field of a multipart form.
        form_data = await request.form()
        raw_bytes = await form_data["img"].read()

        # Decode the raw bytes (OpenCV yields BGR) and convert to RGB before
        # handing the array to the recognition deployment.
        buffer = np.frombuffer(raw_bytes, np.uint8)
        bgr_image = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
        rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

        return await self.texteller_server.predict.remote(rgb_image)
|
||||
9
texteller/cli/commands/web/__init__.py
Normal file
9
texteller/cli/commands/web/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
import os
|
||||
import click
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@click.command()
def web():
    """Launch the web interface for TexTeller."""
    # Local import so this fix stays self-contained in the command body.
    import subprocess

    script = Path(__file__).parent / "streamlit_demo.py"
    # Fix: use an argument list instead of a shell-interpolated string
    # (os.system broke when the install path contained spaces or shell
    # metacharacters). The call blocks until streamlit exits, as before.
    subprocess.run(["streamlit", "run", str(script)], check=False)
|
||||
225
texteller/cli/commands/web/streamlit_demo.py
Normal file
225
texteller/cli/commands/web/streamlit_demo.py
Normal file
@@ -0,0 +1,225 @@
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
import streamlit as st
|
||||
from PIL import Image
|
||||
from streamlit_paste_button import paste_image_button as pbutton
|
||||
|
||||
from texteller.api import (
|
||||
img2latex,
|
||||
load_latexdet_model,
|
||||
load_model,
|
||||
load_textdet_model,
|
||||
load_textrec_model,
|
||||
load_tokenizer,
|
||||
paragraph2md,
|
||||
)
|
||||
from texteller.cli.commands.web.style import (
|
||||
HEADER_HTML,
|
||||
IMAGE_EMBED_HTML,
|
||||
IMAGE_INFO_HTML,
|
||||
SUCCESS_GIF_HTML,
|
||||
)
|
||||
from texteller.utils import str2device
|
||||
|
||||
st.set_page_config(page_title="TexTeller", page_icon="🧮")
|
||||
|
||||
|
||||
@st.cache_resource
def get_texteller(use_onnx):
    """Load the formula-recognition model once per session (cached).

    `use_onnx` selects the ONNX runtime backend in load_model.
    """
    return load_model(use_onnx=use_onnx)
|
||||
|
||||
|
||||
@st.cache_resource
def get_tokenizer():
    """Load the TexTeller tokenizer once per session (cached)."""
    return load_tokenizer()
|
||||
|
||||
|
||||
@st.cache_resource
def get_latexdet_model():
    """Load the LaTeX-region detection model once per session (cached)."""
    return load_latexdet_model()
|
||||
|
||||
|
||||
# Consistency fix: use the bare `@st.cache_resource` form like the sibling
# loaders above (behavior is identical to `@st.cache_resource()`).
@st.cache_resource
def get_textrec_model():
    """Load the text-recognition (OCR) model once per session (cached)."""
    return load_textrec_model()
|
||||
|
||||
|
||||
# Consistency fix: use the bare `@st.cache_resource` form like the sibling
# loaders above (behavior is identical to `@st.cache_resource()`).
@st.cache_resource
def get_textdet_model():
    """Load the text-detection model once per session (cached)."""
    return load_textdet_model()
|
||||
|
||||
|
||||
def get_image_base64(img_file):
    """Encode the given image file object as a base64 PNG string for HTML embedding."""
    # Rewind first: the file may already have been read by a previous consumer.
    img_file.seek(0)
    png_buffer = io.BytesIO()
    # Re-encode through PIL so any supported input format comes out as PNG.
    Image.open(img_file).save(png_buffer, format="PNG")
    return base64.b64encode(png_buffer.getvalue()).decode()
|
||||
|
||||
|
||||
def on_file_upload():
    """File-uploader on_change callback: flag that a new file was uploaded.

    The flag is consumed (and reset) by the main page logic below.
    """
    st.session_state["UPLOADED_FILE_CHANGED"] = True
|
||||
|
||||
|
||||
def change_side_bar():
    """Sidebar on_change callback: flag that a config widget changed.

    The main page uses this flag to skip re-running inference on a pure
    settings change.
    """
    st.session_state["CHANGE_SIDEBAR_FLAG"] = True
|
||||
|
||||
|
||||
# One-time session bootstrap: greet the user only on the very first run.
if "start" not in st.session_state:
    st.session_state["start"] = 1
    st.toast("Hooray!", icon="🎉")

# Set by the uploader's on_change callback; consumed and reset below.
if "UPLOADED_FILE_CHANGED" not in st.session_state:
    st.session_state["UPLOADED_FILE_CHANGED"] = False

# Set by sidebar widgets' on_change callbacks; consumed and reset below.
if "CHANGE_SIDEBAR_FLAG" not in st.session_state:
    st.session_state["CHANGE_SIDEBAR_FLAG"] = False

# NOTE(review): this key is initialized but the selectbox below binds its own
# local variable (`inf_mode`) — confirm this state entry is still needed.
if "INF_MODE" not in st.session_state:
    st.session_state["INF_MODE"] = "Formula recognition"
|
||||
|
||||
|
||||
# ====== <sidebar> ======

with st.sidebar:
    st.markdown("# 🔨️ Config")
    st.markdown("")

    # Every widget here flips CHANGE_SIDEBAR_FLAG via its on_change callback
    # so the main page can skip inference on a pure settings change.
    inf_mode = st.selectbox(
        "Inference mode",
        ("Formula recognition", "Paragraph recognition"),
        on_change=change_side_bar,
    )

    # Fix: the previous `num_beams = 1` initializer was a dead store —
    # st.number_input always assigns the value on every run.
    num_beams = st.number_input(
        "Number of beams", min_value=1, max_value=20, step=1, on_change=change_side_bar
    )

    device = st.radio("device", ("cpu", "cuda", "mps"), on_change=change_side_bar)

    # Fix: user-visible typo ("Seedup" -> "Speedup").
    st.markdown("## Speedup")
    # NOTE(review): trailing space in the toggle label kept byte-for-byte —
    # the label doubles as the widget key; confirm before cleaning it up.
    use_onnx = st.toggle("ONNX Runtime ")


# ====== </sidebar> ======
|
||||
|
||||
|
||||
# ====== <page> ======

# Model handles come from @st.cache_resource loaders, so re-runs are cheap.
latexrec_model = get_texteller(use_onnx)
tokenizer = get_tokenizer()

# Detection/OCR models are only needed for full-paragraph recognition.
if inf_mode == "Paragraph recognition":
    latexdet_model = get_latexdet_model()
    textrec_model = get_textrec_model()
    textdet_model = get_textdet_model()

st.markdown(HEADER_HTML, unsafe_allow_html=True)

uploaded_file = st.file_uploader(" ", type=["jpg", "png"], on_change=on_file_upload)

paste_result = pbutton(
    label="📋 Paste an image",
    background_color="#5BBCFF",
    hover_background_color="#3498db",
)
st.write("")

if st.session_state["CHANGE_SIDEBAR_FLAG"] is True:
    # This re-run was caused by a sidebar change: reset the flag and skip
    # inference until a new image arrives.
    st.session_state["CHANGE_SIDEBAR_FLAG"] = False
elif uploaded_file or paste_result.image_data is not None:
    # Prefer a freshly uploaded file; otherwise wrap the pasted image in an
    # in-memory PNG so both input sources look like a file object.
    if st.session_state["UPLOADED_FILE_CHANGED"] is False and paste_result.image_data is not None:
        uploaded_file = io.BytesIO()
        paste_result.image_data.save(uploaded_file, format="PNG")
        uploaded_file.seek(0)

    if st.session_state["UPLOADED_FILE_CHANGED"] is True:
        st.session_state["UPLOADED_FILE_CHANGED"] = False

    img = Image.open(uploaded_file)

    # The recognition APIs below take a file path, so persist the image to a
    # temp dir (removed at the end of this branch).
    temp_dir = tempfile.mkdtemp()
    png_fpath = os.path.join(temp_dir, "image.png")
    img.save(png_fpath, "PNG")

    # Preview the input image, centered, with its dimensions underneath.
    with st.container(height=300):
        img_base64 = get_image_base64(uploaded_file)

        st.markdown(
            IMAGE_EMBED_HTML.format(img_base64=img_base64),
            unsafe_allow_html=True,
        )

        st.markdown(
            IMAGE_INFO_HTML.format(img_height=img.height, img_width=img.width),
            unsafe_allow_html=True,
        )

    st.write("")

    with st.spinner("Predicting..."):
        if inf_mode == "Formula recognition":
            # Single-formula image -> KaTeX-compatible LaTeX string.
            pred = img2latex(
                model=latexrec_model,
                tokenizer=tokenizer,
                images=[png_fpath],
                device=str2device(device),
                out_format="katex",
                num_beams=num_beams,
                keep_style=False,
            )[0]
        else:
            # Mixed text+formula page -> Markdown with embedded $$...$$ blocks.
            pred = paragraph2md(
                img_path=png_fpath,
                latexdet_model=latexdet_model,
                textdet_model=textdet_model,
                textrec_model=textrec_model,
                latexrec_model=latexrec_model,
                tokenizer=tokenizer,
                device=str2device(device),
                num_beams=num_beams,
            )

    st.success("Completed!", icon="✅")
    # st.markdown(SUCCESS_GIF_HTML, unsafe_allow_html=True)
    # st.text_area("Predicted LaTeX", pred, height=150)
    # First show the raw prediction source as a copyable code block...
    if inf_mode == "Formula recognition":
        st.code(pred, language="latex")
    elif inf_mode == "Paragraph recognition":
        st.code(pred, language="markdown")
    else:
        raise ValueError(f"Invalid inference mode: {inf_mode}")

    # ...then render the same prediction.
    if inf_mode == "Formula recognition":
        st.latex(pred)
    elif inf_mode == "Paragraph recognition":
        # Split on display-math blocks so formulas go through st.latex and
        # everything else through st.markdown.
        mixed_res = re.split(r"(\$\$.*?\$\$)", pred, flags=re.DOTALL)
        for text in mixed_res:
            if text.startswith("$$") and text.endswith("$$"):
                st.latex(text.strip("$$"))
            else:
                st.markdown(text)

    st.write("")
    st.write("")

    with st.expander(":star2: :gray[Tips for better results]"):
        st.markdown("""
        * :mag_right: Use a clear and high-resolution image.
        * :scissors: Crop images as accurately as possible.
        * :jigsaw: Split large multi line formulas into smaller ones.
        * :page_facing_up: Use images with **white background and black text** as much as possible.
        * :book: Use a font with good readability.
        """)
    shutil.rmtree(temp_dir)

# Clear pasted data so the same paste is not re-processed on the next re-run.
paste_result.image_data = None


# ====== </page> ======
|
||||
55
texteller/cli/commands/web/style.py
Normal file
55
texteller/cli/commands/web/style.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from texteller.utils import lines_dedent

# HTML/CSS snippets injected into the Streamlit page via
# st.markdown(..., unsafe_allow_html=True).


# Page banner: the TexTeller title flanked by two fire icons.
HEADER_HTML = lines_dedent("""
    <h1 style="color: black; text-align: center;">
        <img src="https://raw.githubusercontent.com/OleehyO/TexTeller/main/assets/fire.svg" width="100">
        𝚃𝚎𝚡𝚃𝚎𝚕𝚕𝚎𝚛
        <img src="https://raw.githubusercontent.com/OleehyO/TexTeller/main/assets/fire.svg" width="100">
    </h1>
""")

# Celebration banner shown after a successful prediction (currently unused in
# the demo — the st.markdown call referencing it is commented out).
SUCCESS_GIF_HTML = lines_dedent("""
    <h1 style="color: black; text-align: center;">
        <img src="https://slackmojis.com/emojis/90621-clapclap-e/download" width="50">
        <img src="https://slackmojis.com/emojis/90621-clapclap-e/download" width="50">
        <img src="https://slackmojis.com/emojis/90621-clapclap-e/download" width="50">
    </h1>
""")

# Failure banner counterpart to SUCCESS_GIF_HTML.
FAIL_GIF_HTML = lines_dedent("""
    <h1 style="color: black; text-align: center;">
        <img src="https://slackmojis.com/emojis/51439-allthethings_intensifies/download">
        <img src="https://slackmojis.com/emojis/51439-allthethings_intensifies/download">
        <img src="https://slackmojis.com/emojis/51439-allthethings_intensifies/download">
    </h1>
""")

# Centered preview of the input image. This template goes through
# str.format(img_base64=...), so literal CSS braces are doubled ({{ }}).
IMAGE_EMBED_HTML = lines_dedent("""
    <style>
        .centered-container {{
            text-align: center;
        }}
        .centered-image {{
            display: block;
            margin-left: auto;
            margin-right: auto;
            max-height: 350px;
            max-width: 100%;
        }}
    </style>
    <div class="centered-container">
        <img src="data:image/png;base64,{img_base64}" class="centered-image" alt="Input image">
    </div>
""")

# Caption showing the input image dimensions; formatted with
# str.format(img_height=..., img_width=...), hence the doubled CSS braces.
IMAGE_INFO_HTML = lines_dedent("""
    <style>
        .centered-container {{
            text-align: center;
        }}
    </style>
    <div class="centered-container">
        <p style="color:gray;">Input image ({img_height}✖️{img_width})</p>
    </div>
""")
|
||||
Reference in New Issue
Block a user