[refactor] Init

This commit is contained in:
OleehyO
2025-04-16 14:23:02 +00:00
parent 0e32f3f3bf
commit 06edd104e2
101 changed files with 1854 additions and 2758 deletions

241
texteller/api/inference.py Normal file
View File

@@ -0,0 +1,241 @@
import re
import time
from collections import Counter
from typing import Literal
import cv2
import numpy as np
import torch
from onnxruntime import InferenceSession
from optimum.onnxruntime import ORTModelForVision2Seq
from transformers import GenerationConfig, RobertaTokenizerFast
from texteller.constants import MAX_TOKEN_SIZE
from texteller.logger import get_logger
from texteller.paddleocr import predict_det, predict_rec
from texteller.types import Bbox, TexTellerModel
from texteller.utils import (
bbox_merge,
get_device,
mask_img,
readimgs,
remove_style,
slice_from_image,
split_conflict,
transform,
add_newlines,
)
from .detection import latex_detect
from .format import format_latex
from .katex import to_katex
_logger = get_logger()
def img2latex(
model: TexTellerModel,
tokenizer: RobertaTokenizerFast,
images: list[str] | list[np.ndarray],
device: torch.device | None = None,
out_format: Literal["latex", "katex"] = "latex",
keep_style: bool = False,
max_tokens: int = MAX_TOKEN_SIZE,
num_beams: int = 1,
no_repeat_ngram_size: int = 0,
) -> list[str]:
"""
Convert images to LaTeX or KaTeX formatted strings.
Args:
model: The TexTeller or ORTModelForVision2Seq model instance
tokenizer: The tokenizer for the model
images: List of image paths or numpy arrays (RGB format)
device: The torch device to use (defaults to available GPU or CPU)
out_format: Output format, either "latex" or "katex"
keep_style: Whether to keep the style of the LaTeX
max_tokens: Maximum number of tokens to generate
num_beams: Number of beams for beam search
no_repeat_ngram_size: Size of n-grams to prevent repetition
Returns:
List of LaTeX or KaTeX strings corresponding to each input image
Example usage:
>>> import torch
>>> from texteller import load_model, load_tokenizer, img2latex
>>> model = load_model(model_path=None, use_onnx=False)
>>> tokenizer = load_tokenizer(tokenizer_path=None)
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
>>> res = img2latex(model, tokenizer, ["path/to/image.png"], device=device, out_format="katex")
"""
assert isinstance(images, list)
assert len(images) > 0
if device is None:
device = get_device()
if device.type != model.device.type:
if isinstance(model, ORTModelForVision2Seq):
_logger.warning(
f"Onnxruntime device mismatch: detected {str(device)} but model is on {str(model.device)}, using {str(model.device)} instead"
)
else:
model = model.to(device=device)
if isinstance(images[0], str):
images = readimgs(images)
else: # already numpy array(rgb format)
assert isinstance(images[0], np.ndarray)
images = images
images = transform(images)
pixel_values = torch.stack(images)
generate_config = GenerationConfig(
max_new_tokens=max_tokens,
num_beams=num_beams,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
bos_token_id=tokenizer.bos_token_id,
no_repeat_ngram_size=no_repeat_ngram_size,
)
pred = model.generate(
pixel_values.to(model.device),
generation_config=generate_config,
)
res = tokenizer.batch_decode(pred, skip_special_tokens=True)
if out_format == "katex":
res = [to_katex(r) for r in res]
if not keep_style:
res = [remove_style(r) for r in res]
res = [format_latex(r) for r in res]
res = [add_newlines(r) for r in res]
return res
def paragraph2md(
img_path: str,
latexdet_model: InferenceSession,
textdet_model: predict_det.TextDetector,
textrec_model: predict_rec.TextRecognizer,
latexrec_model: TexTellerModel,
tokenizer: RobertaTokenizerFast,
device: torch.device | None = None,
num_beams=1,
) -> str:
"""
Input a mixed image of formula text and output str (in markdown syntax)
"""
img = cv2.imread(img_path)
corners = [tuple(img[0, 0]), tuple(img[0, -1]), tuple(img[-1, 0]), tuple(img[-1, -1])]
bg_color = np.array(Counter(corners).most_common(1)[0][0])
start_time = time.time()
latex_bboxes = latex_detect(img_path, latexdet_model)
end_time = time.time()
_logger.info(f"latex_det_model time: {end_time - start_time:.2f}s")
latex_bboxes = sorted(latex_bboxes)
latex_bboxes = bbox_merge(latex_bboxes)
masked_img = mask_img(img, latex_bboxes, bg_color)
start_time = time.time()
det_prediction, _ = textdet_model(masked_img)
end_time = time.time()
_logger.info(f"ocr_det_model time: {end_time - start_time:.2f}s")
ocr_bboxes = [
Bbox(
p[0][0],
p[0][1],
p[3][1] - p[0][1],
p[1][0] - p[0][0],
label="text",
confidence=None,
content=None,
)
for p in det_prediction
]
ocr_bboxes = sorted(ocr_bboxes)
ocr_bboxes = bbox_merge(ocr_bboxes)
ocr_bboxes = split_conflict(ocr_bboxes, latex_bboxes)
ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes))
sliced_imgs: list[np.ndarray] = slice_from_image(img, ocr_bboxes)
start_time = time.time()
rec_predictions, _ = textrec_model(sliced_imgs)
end_time = time.time()
_logger.info(f"ocr_rec_model time: {end_time - start_time:.2f}s")
assert len(rec_predictions) == len(ocr_bboxes)
for content, bbox in zip(rec_predictions, ocr_bboxes):
bbox.content = content[0]
latex_imgs = []
for bbox in latex_bboxes:
latex_imgs.append(img[bbox.p.y : bbox.p.y + bbox.h, bbox.p.x : bbox.p.x + bbox.w])
start_time = time.time()
latex_rec_res = img2latex(
model=latexrec_model,
tokenizer=tokenizer,
images=latex_imgs,
num_beams=num_beams,
out_format="katex",
device=device,
keep_style=False,
)
end_time = time.time()
_logger.info(f"latex_rec_model time: {end_time - start_time:.2f}s")
for bbox, content in zip(latex_bboxes, latex_rec_res):
if bbox.label == "embedding":
bbox.content = " $" + content + "$ "
elif bbox.label == "isolated":
bbox.content = "\n\n" + r"$$" + content + r"$$" + "\n\n"
bboxes = sorted(ocr_bboxes + latex_bboxes)
if bboxes == []:
return ""
md = ""
prev = Bbox(bboxes[0].p.x, bboxes[0].p.y, -1, -1, label="guard")
for curr in bboxes:
# Add the formula number back to the isolated formula
if prev.label == "isolated" and curr.label == "text" and prev.same_row(curr):
curr.content = curr.content.strip()
if curr.content.startswith("(") and curr.content.endswith(")"):
curr.content = curr.content[1:-1]
if re.search(r"\\tag\{.*\}$", md[:-4]) is not None:
# in case of multiple tag
md = md[:-5] + f", {curr.content}" + "}" + md[-4:]
else:
md = md[:-4] + f"\\tag{{{curr.content}}}" + md[-4:]
continue
if not prev.same_row(curr):
md += " "
if curr.label == "embedding":
# remove the bold effect from inline formulas
curr.content = remove_style(curr.content)
# change split environment into aligned
curr.content = curr.content.replace(r"\begin{split}", r"\begin{aligned}")
curr.content = curr.content.replace(r"\end{split}", r"\end{aligned}")
# remove extra spaces (keeping only one)
curr.content = re.sub(r" +", " ", curr.content)
assert curr.content.startswith("$") and curr.content.endswith("$")
curr.content = " $" + curr.content.strip("$") + "$ "
md += curr.content
prev = curr
return md.strip()