import re
import time
from collections import Counter
from typing import Literal

import cv2
import numpy as np
import torch
from onnxruntime import InferenceSession
from optimum.onnxruntime import ORTModelForVision2Seq
from transformers import GenerationConfig, RobertaTokenizerFast

from texteller.constants import MAX_TOKEN_SIZE
from texteller.logger import get_logger
from texteller.paddleocr import predict_det, predict_rec
from texteller.types import Bbox, TexTellerModel
from texteller.utils import (
    add_newlines,
    bbox_merge,
    get_device,
    mask_img,
    readimgs,
    remove_style,
    slice_from_image,
    split_conflict,
    transform,
)

from .detection import latex_detect
from .format import format_latex
from .katex import to_katex

_logger = get_logger()


def img2latex(
    model: TexTellerModel,
    tokenizer: RobertaTokenizerFast,
    images: list[str] | list[np.ndarray],
    device: torch.device | None = None,
    out_format: Literal["latex", "katex"] = "latex",
    keep_style: bool = False,
    max_tokens: int = MAX_TOKEN_SIZE,
    num_beams: int = 1,
    no_repeat_ngram_size: int = 0,
) -> list[str]:
    """
    Convert images to LaTeX or KaTeX formatted strings.

    Args:
        model: The TexTeller or ORTModelForVision2Seq model instance
        tokenizer: The tokenizer for the model
        images: List of image paths or numpy arrays (RGB format)
        device: The torch device to use (defaults to available GPU or CPU)
        out_format: Output format, either "latex" or "katex"
        keep_style: Whether to keep the style of the LaTeX
        max_tokens: Maximum number of tokens to generate
        num_beams: Number of beams for beam search
        no_repeat_ngram_size: Size of n-grams to prevent repetition

    Returns:
        List of LaTeX or KaTeX strings corresponding to each input image

    Example usage:
        >>> import torch
        >>> from texteller import load_model, load_tokenizer, img2latex
        >>> model = load_model(model_path=None, use_onnx=False)
        >>> tokenizer = load_tokenizer(tokenizer_path=None)
        >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        >>> res = img2latex(model, tokenizer, ["path/to/image.png"], device=device, out_format="katex")
    """
    assert isinstance(images, list)
    assert len(images) > 0

    if device is None:
        device = get_device()
    if device.type != model.device.type:
        if isinstance(model, ORTModelForVision2Seq):
            # An ONNX Runtime session cannot be moved after creation,
            # so keep whatever device the model was loaded on
            _logger.warning(
                f"Onnxruntime device mismatch: detected {str(device)} but model is on "
                f"{str(model.device)}, using {str(model.device)} instead"
            )
        else:
            model = model.to(device=device)

    if isinstance(images[0], str):
        images = readimgs(images)
    else:
        # Already numpy arrays (RGB format)
        assert isinstance(images[0], np.ndarray)

    images = transform(images)
    pixel_values = torch.stack(images)

    generate_config = GenerationConfig(
        max_new_tokens=max_tokens,
        num_beams=num_beams,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        no_repeat_ngram_size=no_repeat_ngram_size,
    )
    pred = model.generate(
        pixel_values.to(model.device),
        generation_config=generate_config,
    )
    res = tokenizer.batch_decode(pred, skip_special_tokens=True)

    if out_format == "katex":
        res = [to_katex(r) for r in res]
    if not keep_style:
        res = [remove_style(r) for r in res]
    res = [format_latex(r) for r in res]
    res = [add_newlines(r) for r in res]
    return res


def paragraph2md(
    img_path: str,
    latexdet_model: InferenceSession,
    textdet_model: predict_det.TextDetector,
    textrec_model: predict_rec.TextRecognizer,
    latexrec_model: TexTellerModel,
    tokenizer: RobertaTokenizerFast,
    device: torch.device | None = None,
    num_beams: int = 1,
) -> str:
    """
    Convert a mixed image of text and formulas into a Markdown-formatted string.

    Args:
        img_path: Path to the input image
        latexdet_model: ONNX inference session for formula detection
        textdet_model: PaddleOCR text detector
        textrec_model: PaddleOCR text recognizer
        latexrec_model: The TexTeller formula recognition model
        tokenizer: The tokenizer for the formula recognition model
        device: The torch device to use (defaults to available GPU or CPU)
        num_beams: Number of beams for beam search

    Returns:
        A Markdown string with inline formulas wrapped in $...$ and isolated
        formulas wrapped in $$...$$
    """
    img = cv2.imread(img_path)
    # Estimate the background color as the most common of the four corner pixels
    corners = [tuple(img[0, 0]), tuple(img[0, -1]), tuple(img[-1, 0]), tuple(img[-1, -1])]
    bg_color = np.array(Counter(corners).most_common(1)[0][0])

    start_time = time.time()
    latex_bboxes = latex_detect(img_path, latexdet_model)
    end_time = time.time()
    _logger.info(f"latex_det_model time: {end_time - start_time:.2f}s")

    # Mask the detected formula regions with the background color so the
    # text detector does not pick them up a second time
    latex_bboxes = sorted(latex_bboxes)
    latex_bboxes = bbox_merge(latex_bboxes)
    masked_img = mask_img(img, latex_bboxes, bg_color)

    start_time = time.time()
    det_prediction, _ = textdet_model(masked_img)
    end_time = time.time()
    _logger.info(f"ocr_det_model time: {end_time - start_time:.2f}s")

    # Each detection is a quadrilateral (top-left, top-right, bottom-right,
    # bottom-left); convert it to an axis-aligned Bbox(x, y, h, w)
    ocr_bboxes = [
        Bbox(
            p[0][0],
            p[0][1],
            p[3][1] - p[0][1],
            p[1][0] - p[0][0],
            label="text",
            confidence=None,
            content=None,
        )
        for p in det_prediction
    ]

    ocr_bboxes = sorted(ocr_bboxes)
    ocr_bboxes = bbox_merge(ocr_bboxes)
    ocr_bboxes = split_conflict(ocr_bboxes, latex_bboxes)
    ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes))

    sliced_imgs: list[np.ndarray] = slice_from_image(img, ocr_bboxes)

    start_time = time.time()
    rec_predictions, _ = textrec_model(sliced_imgs)
    end_time = time.time()
    _logger.info(f"ocr_rec_model time: {end_time - start_time:.2f}s")

    assert len(rec_predictions) == len(ocr_bboxes)
    for content, bbox in zip(rec_predictions, ocr_bboxes):
        # Each recognition result is a (text, score) pair; keep the text
        bbox.content = content[0]

    # Crop each detected formula region from the original (unmasked) image
    latex_imgs = []
    for bbox in latex_bboxes:
        latex_imgs.append(img[bbox.p.y : bbox.p.y + bbox.h, bbox.p.x : bbox.p.x + bbox.w])

    start_time = time.time()
    latex_rec_res = img2latex(
        model=latexrec_model,
        tokenizer=tokenizer,
        images=latex_imgs,
        num_beams=num_beams,
        out_format="katex",
        device=device,
        keep_style=False,
    )
    end_time = time.time()
    _logger.info(f"latex_rec_model time: {end_time - start_time:.2f}s")

    for bbox, content in zip(latex_bboxes, latex_rec_res):
        if bbox.label == "embedding":
            bbox.content = " $" + content + "$ "
        elif bbox.label == "isolated":
            bbox.content = "\n\n" + r"$$" + content + r"$$" + "\n\n"

    # Interleave text and formula boxes in reading order
    bboxes = sorted(ocr_bboxes + latex_bboxes)
    if not bboxes:
        return ""

    md = ""
    prev = Bbox(bboxes[0].p.x, bboxes[0].p.y, -1, -1, label="guard")
    for curr in bboxes:
        # Add the formula number back to the isolated formula; md currently
        # ends with "$$\n\n", so the \tag is inserted before the closing "$$"
        if prev.label == "isolated" and curr.label == "text" and prev.same_row(curr):
            curr.content = curr.content.strip()
            if curr.content.startswith("(") and curr.content.endswith(")"):
                curr.content = curr.content[1:-1]
            if re.search(r"\\tag\{.*\}$", md[:-4]) is not None:
                # In case of multiple tags, append to the existing \tag{...}
                md = md[:-5] + f", {curr.content}" + "}" + md[-4:]
            else:
                md = md[:-4] + f"\\tag{{{curr.content}}}" + md[-4:]
            continue

        if not prev.same_row(curr):
            md += " "

        if curr.label == "embedding":
            # Remove the bold effect from inline formulas
            curr.content = remove_style(curr.content)
            # Change split environments into aligned
            curr.content = curr.content.replace(r"\begin{split}", r"\begin{aligned}")
            curr.content = curr.content.replace(r"\end{split}", r"\end{aligned}")
            # Remove extra spaces (keeping only one)
            curr.content = re.sub(r" +", " ", curr.content)
            # Normalize surrounding whitespace before checking the delimiters
            # (the content was stored as " $...$ " above)
            curr.content = curr.content.strip()
            assert curr.content.startswith("$") and curr.content.endswith("$")
            curr.content = " $" + curr.content.strip("$") + "$ "

        md += curr.content
        prev = curr

    return md.strip()
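
# Example usage for `paragraph2md` (a minimal sketch, kept as a comment so the
# module stays import-safe). `load_model` and `load_tokenizer` are the loaders
# shown in the `img2latex` docstring above, and `InferenceSession` is the
# standard onnxruntime constructor. The ONNX model path and the PaddleOCR
# argument objects (`det_args`, `rec_args`) are placeholders, not names this
# package is known to define -- build them however your installation provides.
#
#     >>> from onnxruntime import InferenceSession
#     >>> from texteller import load_model, load_tokenizer
#     >>> latexdet_model = InferenceSession("path/to/latexdet.onnx")  # placeholder path
#     >>> textdet_model = predict_det.TextDetector(det_args)    # det_args: assumed config
#     >>> textrec_model = predict_rec.TextRecognizer(rec_args)  # rec_args: assumed config
#     >>> latexrec_model = load_model(model_path=None, use_onnx=False)
#     >>> tokenizer = load_tokenizer(tokenizer_path=None)
#     >>> md = paragraph2md(
#     ...     "path/to/page.png",
#     ...     latexdet_model,
#     ...     textdet_model,
#     ...     textrec_model,
#     ...     latexrec_model,
#     ...     tokenizer,
#     ... )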