[refactor] Init

This commit is contained in:
OleehyO
2025-04-16 14:23:02 +00:00
parent 0e32f3f3bf
commit 06edd104e2
101 changed files with 1854 additions and 2758 deletions

View File

@@ -0,0 +1,3 @@
"""
CLI commands for TexTeller
"""

View File

@@ -0,0 +1,51 @@
"""
CLI command for formula inference from images.
"""
import click
from texteller.api import img2latex, load_model, load_tokenizer
@click.command()
@click.argument("image_path", type=click.Path(exists=True, file_okay=True, dir_okay=False))
@click.option(
"--model-path",
type=click.Path(exists=True, file_okay=False, dir_okay=True),
default=None,
help="Path to the model dir path, if not provided, will use model from huggingface repo",
)
@click.option(
"--tokenizer-path",
type=click.Path(exists=True, file_okay=False, dir_okay=True),
default=None,
help="Path to the tokenizer dir path, if not provided, will use tokenizer from huggingface repo",
)
@click.option(
"--output-format",
type=click.Choice(["latex", "katex"]),
default="katex",
help="Output format, either latex or katex",
)
@click.option(
"--keep-style",
is_flag=True,
default=False,
help="Whether to keep the style of the LaTeX (e.g. bold, italic, etc.)",
)
def inference(image_path, model_path, tokenizer_path, output_format, keep_style):
"""
CLI command for formula inference from images.
"""
model = load_model(model_dir=model_path)
tknz = load_tokenizer(tokenizer_dir=tokenizer_path)
pred = img2latex(
model=model,
tokenizer=tknz,
images=[image_path],
out_format=output_format,
keep_style=keep_style,
)[0]
click.echo(f"Predicted LaTeX: ```\n{pred}\n```")

View File

@@ -0,0 +1,106 @@
"""
CLI commands for launching server.
"""
import sys
import time
import click
from ray import serve
from texteller.globals import Globals
from texteller.utils import get_device
@click.command()
@click.option(
"-ckpt",
"--checkpoint_dir",
type=click.Path(exists=True, file_okay=False, dir_okay=True),
default=None,
help="Path to the checkpoint directory, if not provided, will use model from huggingface repo",
)
@click.option(
"-tknz",
"--tokenizer_dir",
type=click.Path(exists=True, file_okay=False, dir_okay=True),
default=None,
help="Path to the tokenizer directory, if not provided, will use tokenizer from huggingface repo",
)
@click.option(
"-p",
"--port",
type=int,
default=8000,
help="Port to run the server on",
)
@click.option(
"--num-replicas",
type=int,
default=1,
help="Number of replicas to run the server on",
)
@click.option(
"--ncpu-per-replica",
type=float,
default=1.0,
help="Number of CPUs per replica",
)
@click.option(
"--ngpu-per-replica",
type=float,
default=1.0,
help="Number of GPUs per replica",
)
@click.option(
"--num-beams",
type=int,
default=1,
help="Number of beams to use",
)
@click.option(
"--use-onnx",
is_flag=True,
default=False,
help="Use ONNX runtime",
)
def launch(
checkpoint_dir,
tokenizer_dir,
port,
num_replicas,
ncpu_per_replica,
ngpu_per_replica,
num_beams,
use_onnx,
):
"""Launch the api server"""
device = get_device()
if ngpu_per_replica > 0 and device.type != "cuda":
click.echo(
click.style(
f"Error: --ngpu-per-replica > 0 but detected device is {device.type}",
fg="red",
)
)
sys.exit(1)
Globals().num_replicas = num_replicas
Globals().ncpu_per_replica = ncpu_per_replica
Globals().ngpu_per_replica = ngpu_per_replica
from texteller.cli.commands.launch.server import Ingress, TexTellerServer
serve.start(http_options={"host": "0.0.0.0", "port": port})
rec_server = TexTellerServer.bind(
checkpoint_dir=checkpoint_dir,
tokenizer_dir=tokenizer_dir,
use_onnx=use_onnx,
num_beams=num_beams,
)
ingress = Ingress.bind(rec_server)
serve.run(ingress, route_prefix="/predict")
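# block forever so the Ray Serve deployments stay up after serve.run returns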
while True:
time.sleep(1)
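
Once launched, the server answers on the `/predict` route with the image sent as a multipart form field named `img` (the key the ingress reads). A hedged client sketch using `requests`, with host and port matching the defaults above:

```python
import requests

# hypothetical image path; "img" matches the form field read by the ingress
with open("formula.png", "rb") as f:
    resp = requests.post("http://127.0.0.1:8000/predict", files={"img": f})

print(resp.text)  # predicted LaTeX/KaTeX string
```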

View File

@@ -0,0 +1,69 @@
import numpy as np
import cv2
from starlette.requests import Request
from ray import serve
from ray.serve.handle import DeploymentHandle
from texteller.api import load_model, load_tokenizer, img2latex
from texteller.utils import get_device
from texteller.globals import Globals
from typing import Literal
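# Globals() is populated by the launch command before this module is imported,
# so the deployment options below pick up the CLI-provided values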
@serve.deployment(
num_replicas=Globals().num_replicas,
ray_actor_options={
"num_cpus": Globals().ncpu_per_replica,
"num_gpus": Globals().ngpu_per_replica * 1.0 / 2,
},
)
class TexTellerServer:
def __init__(
self,
checkpoint_dir: str,
tokenizer_dir: str,
use_onnx: bool = False,
out_format: Literal["latex", "katex"] = "katex",
keep_style: bool = False,
num_beams: int = 1,
) -> None:
self.model = load_model(
model_dir=checkpoint_dir,
use_onnx=use_onnx,
)
self.tokenizer = load_tokenizer(tokenizer_dir=tokenizer_dir)
self.num_beams = num_beams
self.out_format = out_format
self.keep_style = keep_style
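# move only the torch model to the detected device; the ONNX model is used as loaded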
if not use_onnx:
self.model = self.model.to(get_device())
def predict(self, image_nparray: np.ndarray) -> str:
return img2latex(
model=self.model,
tokenizer=self.tokenizer,
images=[image_nparray],
device=get_device(),
out_format=self.out_format,
keep_style=self.keep_style,
num_beams=self.num_beams,
)[0]
@serve.deployment()
class Ingress:
def __init__(self, rec_server: DeploymentHandle) -> None:
self.texteller_server = rec_server
async def __call__(self, request: Request) -> str:
form = await request.form()
img_rb = await form["img"].read()
img_nparray = np.frombuffer(img_rb, np.uint8)
img_nparray = cv2.imdecode(img_nparray, cv2.IMREAD_COLOR)
img_nparray = cv2.cvtColor(img_nparray, cv2.COLOR_BGR2RGB)
pred = await self.texteller_server.predict.remote(img_nparray)
return pred

View File

@@ -0,0 +1,9 @@
import os
import click
from pathlib import Path
@click.command()
def web():
"""Launch the web interface for TexTeller."""
os.system(f"streamlit run {Path(__file__).parent / 'streamlit_demo.py'}")

View File

@@ -0,0 +1,225 @@
import base64
import io
import os
import re
import shutil
import tempfile
import streamlit as st
from PIL import Image
from streamlit_paste_button import paste_image_button as pbutton
from texteller.api import (
img2latex,
load_latexdet_model,
load_model,
load_textdet_model,
load_textrec_model,
load_tokenizer,
paragraph2md,
)
from texteller.cli.commands.web.style import (
HEADER_HTML,
IMAGE_EMBED_HTML,
IMAGE_INFO_HTML,
SUCCESS_GIF_HTML,
)
from texteller.utils import str2device
st.set_page_config(page_title="TexTeller", page_icon="🧮")
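# st.cache_resource keeps the loaded models and tokenizer alive across Streamlit reruns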
@st.cache_resource
def get_texteller(use_onnx):
return load_model(use_onnx=use_onnx)
@st.cache_resource
def get_tokenizer():
return load_tokenizer()
@st.cache_resource
def get_latexdet_model():
return load_latexdet_model()
@st.cache_resource()
def get_textrec_model():
return load_textrec_model()
@st.cache_resource()
def get_textdet_model():
return load_textdet_model()
def get_image_base64(img_file):
buffered = io.BytesIO()
img_file.seek(0)
img = Image.open(img_file)
img.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode()
def on_file_upload():
st.session_state["UPLOADED_FILE_CHANGED"] = True
def change_side_bar():
st.session_state["CHANGE_SIDEBAR_FLAG"] = True
if "start" not in st.session_state:
st.session_state["start"] = 1
st.toast("Hooray!", icon="🎉")
if "UPLOADED_FILE_CHANGED" not in st.session_state:
st.session_state["UPLOADED_FILE_CHANGED"] = False
if "CHANGE_SIDEBAR_FLAG" not in st.session_state:
st.session_state["CHANGE_SIDEBAR_FLAG"] = False
if "INF_MODE" not in st.session_state:
st.session_state["INF_MODE"] = "Formula recognition"
# ====== <sidebar> ======
with st.sidebar:
num_beams = 1
st.markdown("# 🔨️ Config")
st.markdown("")
inf_mode = st.selectbox(
"Inference mode",
("Formula recognition", "Paragraph recognition"),
on_change=change_side_bar,
)
num_beams = st.number_input(
"Number of beams", min_value=1, max_value=20, step=1, on_change=change_side_bar
)
device = st.radio("device", ("cpu", "cuda", "mps"), on_change=change_side_bar)
st.markdown("## Seedup")
use_onnx = st.toggle("ONNX Runtime ")
# ====== </sidebar> ======
# ====== <page> ======
latexrec_model = get_texteller(use_onnx)
tokenizer = get_tokenizer()
if inf_mode == "Paragraph recognition":
latexdet_model = get_latexdet_model()
textrec_model = get_textrec_model()
textdet_model = get_textdet_model()
st.markdown(HEADER_HTML, unsafe_allow_html=True)
uploaded_file = st.file_uploader(" ", type=["jpg", "png"], on_change=on_file_upload)
paste_result = pbutton(
label="📋 Paste an image",
background_color="#5BBCFF",
hover_background_color="#3498db",
)
st.write("")
if st.session_state["CHANGE_SIDEBAR_FLAG"] is True:
st.session_state["CHANGE_SIDEBAR_FLAG"] = False
elif uploaded_file or paste_result.image_data is not None:
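# with no fresh upload, fall back to the pasted image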
if st.session_state["UPLOADED_FILE_CHANGED"] is False and paste_result.image_data is not None:
uploaded_file = io.BytesIO()
paste_result.image_data.save(uploaded_file, format="PNG")
uploaded_file.seek(0)
if st.session_state["UPLOADED_FILE_CHANGED"] is True:
st.session_state["UPLOADED_FILE_CHANGED"] = False
img = Image.open(uploaded_file)
temp_dir = tempfile.mkdtemp()
png_fpath = os.path.join(temp_dir, "image.png")
img.save(png_fpath, "PNG")
with st.container(height=300):
img_base64 = get_image_base64(uploaded_file)
st.markdown(
IMAGE_EMBED_HTML.format(img_base64=img_base64),
unsafe_allow_html=True,
)
st.markdown(
IMAGE_INFO_HTML.format(img_height=img.height, img_width=img.width),
unsafe_allow_html=True,
)
st.write("")
with st.spinner("Predicting..."):
if inf_mode == "Formula recognition":
pred = img2latex(
model=latexrec_model,
tokenizer=tokenizer,
images=[png_fpath],
device=str2device(device),
out_format="katex",
num_beams=num_beams,
keep_style=False,
)[0]
else:
pred = paragraph2md(
img_path=png_fpath,
latexdet_model=latexdet_model,
textdet_model=textdet_model,
textrec_model=textrec_model,
latexrec_model=latexrec_model,
tokenizer=tokenizer,
device=str2device(device),
num_beams=num_beams,
)
st.success("Completed!", icon="")
# st.markdown(SUCCESS_GIF_HTML, unsafe_allow_html=True)
# st.text_area("Predicted LaTeX", pred, height=150)
if inf_mode == "Formula recognition":
st.code(pred, language="latex")
elif inf_mode == "Paragraph recognition":
st.code(pred, language="markdown")
else:
raise ValueError(f"Invalid inference mode: {inf_mode}")
if inf_mode == "Formula recognition":
st.latex(pred)
elif inf_mode == "Paragraph recognition":
mixed_res = re.split(r"(\$\$.*?\$\$)", pred, flags=re.DOTALL)
for text in mixed_res:
if text.startswith("$$") and text.endswith("$$"):
st.latex(text.strip("$$"))
else:
st.markdown(text)
st.write("")
st.write("")
with st.expander(":star2: :gray[Tips for better results]"):
st.markdown("""
* :mag_right: Use a clear and high-resolution image.
* :scissors: Crop images as accurately as possible.
* :jigsaw: Split large multi-line formulas into smaller ones.
* :page_facing_up: Use images with **white background and black text** as much as possible.
* :book: Use a font with good readability.
""")
shutil.rmtree(temp_dir)
paste_result.image_data = None
# ====== </page> ======

View File

@@ -0,0 +1,55 @@
from texteller.utils import lines_dedent
HEADER_HTML = lines_dedent("""
<h1 style="color: black; text-align: center;">
<img src="https://raw.githubusercontent.com/OleehyO/TexTeller/main/assets/fire.svg" width="100">
𝚃𝚎𝚡𝚃𝚎𝚕𝚕𝚎𝚛
<img src="https://raw.githubusercontent.com/OleehyO/TexTeller/main/assets/fire.svg" width="100">
</h1>
""")
SUCCESS_GIF_HTML = lines_dedent("""
<h1 style="color: black; text-align: center;">
<img src="https://slackmojis.com/emojis/90621-clapclap-e/download" width="50">
<img src="https://slackmojis.com/emojis/90621-clapclap-e/download" width="50">
<img src="https://slackmojis.com/emojis/90621-clapclap-e/download" width="50">
</h1>
""")
FAIL_GIF_HTML = lines_dedent("""
<h1 style="color: black; text-align: center;">
<img src="https://slackmojis.com/emojis/51439-allthethings_intensifies/download">
<img src="https://slackmojis.com/emojis/51439-allthethings_intensifies/download">
<img src="https://slackmojis.com/emojis/51439-allthethings_intensifies/download">
</h1>
""")
IMAGE_EMBED_HTML = lines_dedent("""
<style>
.centered-container {{
text-align: center;
}}
.centered-image {{
display: block;
margin-left: auto;
margin-right: auto;
max-height: 350px;
max-width: 100%;
}}
</style>
<div class="centered-container">
<img src="data:image/png;base64,{img_base64}" class="centered-image" alt="Input image">
</div>
""")
IMAGE_INFO_HTML = lines_dedent("""
<style>
.centered-container {{
text-align: center;
}}
</style>
<div class="centered-container">
<p style="color:gray;">Input image ({img_height}✖️{img_width})</p>
</div>
""")