[refactor] Init
This commit is contained in:
241
texteller/api/inference.py
Normal file
241
texteller/api/inference.py
Normal file
@@ -0,0 +1,241 @@
|
||||
import re
|
||||
import time
|
||||
from collections import Counter
|
||||
from typing import Literal
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from onnxruntime import InferenceSession
|
||||
from optimum.onnxruntime import ORTModelForVision2Seq
|
||||
from transformers import GenerationConfig, RobertaTokenizerFast
|
||||
|
||||
from texteller.constants import MAX_TOKEN_SIZE
|
||||
from texteller.logger import get_logger
|
||||
from texteller.paddleocr import predict_det, predict_rec
|
||||
from texteller.types import Bbox, TexTellerModel
|
||||
from texteller.utils import (
|
||||
bbox_merge,
|
||||
get_device,
|
||||
mask_img,
|
||||
readimgs,
|
||||
remove_style,
|
||||
slice_from_image,
|
||||
split_conflict,
|
||||
transform,
|
||||
add_newlines,
|
||||
)
|
||||
|
||||
from .detection import latex_detect
|
||||
from .format import format_latex
|
||||
from .katex import to_katex
|
||||
|
||||
_logger = get_logger()
|
||||
|
||||
|
||||
def img2latex(
    model: TexTellerModel,
    tokenizer: RobertaTokenizerFast,
    images: list[str] | list[np.ndarray],
    device: torch.device | None = None,
    out_format: Literal["latex", "katex"] = "latex",
    keep_style: bool = False,
    max_tokens: int = MAX_TOKEN_SIZE,
    num_beams: int = 1,
    no_repeat_ngram_size: int = 0,
) -> list[str]:
    """
    Convert images to LaTeX or KaTeX formatted strings.

    Args:
        model: The TexTeller or ORTModelForVision2Seq model instance
        tokenizer: The tokenizer for the model
        images: Non-empty list of image paths or numpy arrays (RGB format)
        device: The torch device to use (defaults to the result of ``get_device()``)
        out_format: Output format, either "latex" or "katex"
        keep_style: Whether to keep the style of the LaTeX
        max_tokens: Maximum number of tokens to generate
        num_beams: Number of beams for beam search
        no_repeat_ngram_size: Size of n-grams to prevent repetition

    Returns:
        List of LaTeX or KaTeX strings corresponding to each input image

    Raises:
        TypeError: If ``images`` is not a list, or its first element is neither
            a ``str`` nor a ``numpy.ndarray``.
        ValueError: If ``images`` is empty.

    Example usage:
        >>> import torch
        >>> from texteller import load_model, load_tokenizer, img2latex

        >>> model = load_model(model_path=None, use_onnx=False)
        >>> tokenizer = load_tokenizer(tokenizer_path=None)
        >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        >>> res = img2latex(model, tokenizer, ["path/to/image.png"], device=device, out_format="katex")
    """
    # Validate with explicit exceptions: `assert` statements are removed when
    # Python runs with -O, which would let bad input fail later and obscurely.
    if not isinstance(images, list):
        raise TypeError(f"images must be a list, got {type(images).__name__}")
    if not images:
        raise ValueError("images must contain at least one image")

    if device is None:
        device = get_device()

    if device.type != model.device.type:
        if isinstance(model, ORTModelForVision2Seq):
            # The ONNX Runtime model is not moved; inference stays on its own device.
            _logger.warning(
                f"Onnxruntime device mismatch: detected {str(device)} but model is on {str(model.device)}, using {str(model.device)} instead"
            )
        else:
            model = model.to(device=device)

    # Only the first element is inspected to decide how the batch is given.
    if isinstance(images[0], str):
        images = readimgs(images)
    elif not isinstance(images[0], np.ndarray):
        raise TypeError(
            "images must contain file paths (str) or numpy arrays, "
            f"got {type(images[0]).__name__}"
        )

    pixel_values = torch.stack(transform(images))

    generate_config = GenerationConfig(
        max_new_tokens=max_tokens,
        num_beams=num_beams,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        no_repeat_ngram_size=no_repeat_ngram_size,
    )
    # Use model.device (not `device`): for the ONNX case above they may differ.
    pred = model.generate(
        pixel_values.to(model.device),
        generation_config=generate_config,
    )

    res = tokenizer.batch_decode(pred, skip_special_tokens=True)

    # Post-processing order matters: KaTeX conversion, then style removal,
    # then formatting and newline insertion.
    if out_format == "katex":
        res = [to_katex(r) for r in res]

    if not keep_style:
        res = [remove_style(r) for r in res]

    res = [format_latex(r) for r in res]
    res = [add_newlines(r) for r in res]
    return res
|
||||
|
||||
|
||||
def paragraph2md(
    img_path: str,
    latexdet_model: InferenceSession,
    textdet_model: predict_det.TextDetector,
    textrec_model: predict_rec.TextRecognizer,
    latexrec_model: TexTellerModel,
    tokenizer: RobertaTokenizerFast,
    device: torch.device | None = None,
    num_beams: int = 1,
) -> str:
    """
    Convert an image mixing plain text and formulas into a Markdown string.

    Steps, in order: detect formula boxes, mask them out of the image, detect
    and recognize the remaining text, recognize each formula crop with
    ``img2latex`` (KaTeX output), then merge all boxes in sorted order into
    one Markdown string.

    Args:
        img_path: Path of the image to read with ``cv2.imread``
        latexdet_model: Session passed to ``latex_detect`` for formula detection
        textdet_model: Text detector, called on the masked image
        textrec_model: Text recognizer, called on the cropped text regions
        latexrec_model: Formula recognition model forwarded to ``img2latex``
        tokenizer: Tokenizer forwarded to ``img2latex``
        device: Device forwarded to ``img2latex``
        num_beams: Number of beams forwarded to ``img2latex``

    Returns:
        The Markdown text, stripped of surrounding whitespace; an empty string
        when no text or formula box remains.

    Raises:
        ValueError: If the image at ``img_path`` cannot be read.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise ValueError(f"Failed to read image: {img_path}")

    # Take the most frequent of the four corner pixels as the background color.
    corners = [tuple(img[0, 0]), tuple(img[0, -1]), tuple(img[-1, 0]), tuple(img[-1, -1])]
    bg_color = np.array(Counter(corners).most_common(1)[0][0])

    start_time = time.time()
    latex_bboxes = latex_detect(img_path, latexdet_model)
    end_time = time.time()
    _logger.info(f"latex_det_model time: {end_time - start_time:.2f}s")
    latex_bboxes = sorted(latex_bboxes)
    latex_bboxes = bbox_merge(latex_bboxes)
    masked_img = mask_img(img, latex_bboxes, bg_color)

    start_time = time.time()
    det_prediction, _ = textdet_model(masked_img)
    end_time = time.time()
    _logger.info(f"ocr_det_model time: {end_time - start_time:.2f}s")
    # NOTE(review): each prediction is indexed as four (x, y) points; the
    # positional args look like (x, y, height, width) - confirm against Bbox.
    ocr_bboxes = [
        Bbox(
            p[0][0],
            p[0][1],
            p[3][1] - p[0][1],
            p[1][0] - p[0][0],
            label="text",
            confidence=None,
            content=None,
        )
        for p in det_prediction
    ]

    ocr_bboxes = sorted(ocr_bboxes)
    ocr_bboxes = bbox_merge(ocr_bboxes)
    ocr_bboxes = split_conflict(ocr_bboxes, latex_bboxes)
    ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes))

    sliced_imgs: list[np.ndarray] = slice_from_image(img, ocr_bboxes)
    start_time = time.time()
    rec_predictions, _ = textrec_model(sliced_imgs)
    end_time = time.time()
    _logger.info(f"ocr_rec_model time: {end_time - start_time:.2f}s")

    assert len(rec_predictions) == len(ocr_bboxes)
    for content, bbox in zip(rec_predictions, ocr_bboxes):
        bbox.content = content[0]

    latex_imgs = [
        img[bbox.p.y : bbox.p.y + bbox.h, bbox.p.x : bbox.p.x + bbox.w]
        for bbox in latex_bboxes
    ]
    if latex_imgs:
        start_time = time.time()
        latex_rec_res = img2latex(
            model=latexrec_model,
            tokenizer=tokenizer,
            images=latex_imgs,
            num_beams=num_beams,
            out_format="katex",
            device=device,
            keep_style=False,
        )
        end_time = time.time()
        _logger.info(f"latex_rec_model time: {end_time - start_time:.2f}s")
    else:
        # img2latex rejects an empty batch, so a text-only image must skip
        # formula recognition instead of failing.
        latex_rec_res = []

    for bbox, content in zip(latex_bboxes, latex_rec_res):
        if bbox.label == "embedding":
            bbox.content = " $" + content + "$ "
        elif bbox.label == "isolated":
            bbox.content = "\n\n" + r"$$" + content + r"$$" + "\n\n"

    bboxes = sorted(ocr_bboxes + latex_bboxes)
    if not bboxes:
        return ""

    md = ""
    # Sentinel "previous" box placed at the first box's position.
    prev = Bbox(bboxes[0].p.x, bboxes[0].p.y, -1, -1, label="guard")
    for curr in bboxes:
        # Add the formula number back to the isolated formula
        if prev.label == "isolated" and curr.label == "text" and prev.same_row(curr):
            curr.content = curr.content.strip()
            if curr.content.startswith("(") and curr.content.endswith(")"):
                curr.content = curr.content[1:-1]

            # md currently ends with the isolated formula's closing "$$\n\n"
            # (4 characters), so md[:-4] is everything before that delimiter.
            if re.search(r"\\tag\{.*\}$", md[:-4]) is not None:
                # in case of multiple tag: md[:-5] also drops the closing "}"
                # of the existing \tag so the new number is appended inside it
                md = md[:-5] + f", {curr.content}" + "}" + md[-4:]
            else:
                md = md[:-4] + f"\\tag{{{curr.content}}}" + md[-4:]
            # prev is deliberately left on the isolated formula so further
            # text boxes on the same row are merged into the same tag.
            continue

        if not prev.same_row(curr):
            md += " "

        if curr.label == "embedding":
            # remove the bold effect from inline formulas
            curr.content = remove_style(curr.content)

            # change split environment into aligned
            curr.content = curr.content.replace(r"\begin{split}", r"\begin{aligned}")
            curr.content = curr.content.replace(r"\end{split}", r"\end{aligned}")

            # remove extra spaces (keeping only one)
            curr.content = re.sub(r" +", " ", curr.content)
            # NOTE(review): content was built as " $...$ " above, so this
            # presumes remove_style strips outer whitespace - confirm.
            assert curr.content.startswith("$") and curr.content.endswith("$")
            curr.content = " $" + curr.content.strip("$") + "$ "
        md += curr.content
        prev = curr

    return md.strip()
|
||||
Reference in New Issue
Block a user