[refactor] Init

This commit is contained in:
OleehyO
2025-04-16 14:23:02 +00:00
parent 0e32f3f3bf
commit 06edd104e2
101 changed files with 1854 additions and 2758 deletions

24
texteller/api/__init__.py Normal file
View File

@@ -0,0 +1,24 @@
from .detection import latex_detect
from .format import format_latex
from .inference import img2latex, paragraph2md
from .katex import to_katex
from .load import (
load_latexdet_model,
load_model,
load_textdet_model,
load_textrec_model,
load_tokenizer,
)
# Public API surface of the texteller.api package; keep in sync with the
# imports above.
__all__ = [
    "to_katex",
    "format_latex",
    "img2latex",
    "paragraph2md",
    "load_model",
    "load_tokenizer",
    "load_latexdet_model",
    "load_textrec_model",
    "load_textdet_model",
    "latex_detect",
]

View File

@@ -0,0 +1,4 @@
# Re-export the n-gram repetition stopping criteria for use with
# `transformers` generation.
from .ngram import DetectRepeatingNgramCriteria
__all__ = ["DetectRepeatingNgramCriteria"]

View File

@@ -0,0 +1,63 @@
import torch
from transformers import StoppingCriteria
class DetectRepeatingNgramCriteria(StoppingCriteria):
    """
    Stops generation efficiently if any n-gram repeats.

    A set of all encountered n-grams is maintained. At each call, every
    n-gram formed since the previous call is checked against the set; if
    any is already present, generation stops, otherwise they are added.

    Processing *all* new n-grams (rather than only the latest one) matters
    because the first call typically delivers many tokens at once (the
    prompt), and their n-grams would otherwise be silently skipped.

    Note: only the first sequence in the batch is inspected. For true
    batch support a separate set per batch item would be required.
    """
    def __init__(self, n: int):
        """
        Args:
            n (int): The size of the n-gram to check for repetition.

        Raises:
            ValueError: If ``n`` is not positive.
        """
        if n <= 0:
            raise ValueError("n-gram size 'n' must be positive.")
        self.n = n
        # Stores tuples of token IDs representing seen n-grams
        self.seen_ngrams = set()
        # Number of tokens of the (first) sequence already folded into
        # `seen_ngrams`; lets us handle several new tokens per call.
        self._processed_len = 0
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.
            scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
                Prediction scores.
        Return:
            `bool`: `True` if generation should stop, `False` otherwise.
        """
        # Only the first sequence in the batch is considered.
        sequence = input_ids[0].tolist()
        seq_length = len(sequence)
        # Need at least n tokens to form the first n-gram
        if seq_length < self.n:
            return False
        # An n-gram starting at i covers tokens i..i+n-1; all n-grams ending
        # before `_processed_len` were handled on earlier calls, so resume at
        # start index `_processed_len - n + 1` (clamped to 0).
        start = max(self._processed_len - self.n + 1, 0)
        for i in range(start, seq_length - self.n + 1):
            ngram = tuple(sequence[i : i + self.n])
            if ngram in self.seen_ngrams:
                return True  # Stop generation
            self.seen_ngrams.add(ngram)
        self._processed_len = seq_length
        return False  # Continue generation

View File

@@ -0,0 +1,3 @@
# Re-export the detection entry point as the package's public API.
from .detect import latex_detect
__all__ = ["latex_detect"]

View File

@@ -0,0 +1,48 @@
from typing import List
from onnxruntime import InferenceSession
from texteller.types import Bbox
from .preprocess import Compose
# Inference configuration for the LaTeX region detector. Mirrors the
# exported (Paddle-style) detector config: preprocessing pipeline, score
# threshold, and output class labels.
_config = {
    "mode": "paddle",
    "draw_threshold": 0.5,  # minimum confidence for a detection to be kept
    "metric": "COCO",
    "use_dynamic_shape": False,
    "arch": "DETR",
    "min_subgraph_size": 3,
    # Transform dicts consumed by `Compose`; "type" names a class in preprocess.py
    "preprocess": [
        {"interp": 2, "keep_ratio": False, "target_size": [1600, 1600], "type": "Resize"},
        {
            "mean": [0.0, 0.0, 0.0],
            "norm_type": "none",
            "std": [1.0, 1.0, 1.0],
            "type": "NormalizeImage",
        },
        {"type": "Permute"},
    ],
    # Class index -> label: display ("isolated") vs inline ("embedding") formulas
    "label_list": ["isolated", "embedding"],
}
def latex_detect(img_path: str, predictor: InferenceSession) -> List[Bbox]:
    """Run the ONNX LaTeX-region detector on one image.

    Args:
        img_path: Path to the input image.
        predictor: ONNX Runtime session for the detector model.

    Returns:
        List of ``Bbox`` for detections whose score exceeds the configured
        ``draw_threshold``, labeled "isolated" or "embedding".
    """
    transforms = Compose(_config["preprocess"])
    inputs = transforms(img_path)
    input_names = [var.name for var in predictor.get_inputs()]
    # Add the batch dimension expected by the model
    feed = {name: inputs[name][None,] for name in input_names}
    outputs = predictor.run(output_names=None, input_feed=feed)[0]
    # Use the configured threshold instead of a duplicated hard-coded 0.5
    threshold = _config["draw_threshold"]
    res = []
    # Each detector row is (class_id, score, xmin, ymin, xmax, ymax)
    for output in outputs:
        score = output[1]
        if score <= threshold:
            continue
        cls_name = _config["label_list"][int(output[0])]
        xmin = int(max(output[2], 0))
        ymin = int(max(output[3], 0))
        xmax = int(output[4])
        ymax = int(output[5])
        # NOTE: Bbox is constructed as (x, y, height, width) to match its
        # usage elsewhere in the codebase.
        res.append(Bbox(xmin, ymin, ymax - ymin, xmax - xmin, cls_name, score))
    return res

View File

@@ -0,0 +1,161 @@
import copy
import cv2
import numpy as np
def decode_image(img_path):
    """Decode an image path or raw encoded buffer into an RGB array.

    Args:
        img_path: Path to an image file, or a 1-D ``np.ndarray`` of
            encoded image bytes.

    Returns:
        tuple: (RGB image array, dict with 'im_shape' and 'scale_factor').
    """
    if isinstance(img_path, str):
        with open(img_path, "rb") as fh:
            data = np.frombuffer(fh.read(), dtype="uint8")
    else:
        assert isinstance(img_path, np.ndarray)
        data = img_path
    # cv2 decodes to BGR; convert so downstream transforms see RGB
    decoded = cv2.imdecode(data, 1)
    rgb = cv2.cvtColor(decoded, cv2.COLOR_BGR2RGB)
    info = {
        "im_shape": np.array(rgb.shape[:2], dtype=np.float32),
        "scale_factor": np.array([1.0, 1.0], dtype=np.float32),
    }
    return rgb, info
class Resize(object):
    """Resize an image to ``target_size``.

    Args:
        target_size (int | list): target [h, w]; an int is expanded to a square
        keep_ratio (bool): preserve aspect ratio, capping the longer side
        interp (int): OpenCV interpolation flag
    """
    def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
        if isinstance(target_size, int):
            target_size = [target_size, target_size]
        self.target_size = target_size
        self.keep_ratio = keep_ratio
        self.interp = interp
    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image (np.ndarray)
            im_info (dict): info of image
        Returns:
            im (np.ndarray): processed image (np.ndarray)
            im_info (dict): info of processed image, with 'im_shape' and
                'scale_factor' updated to reflect the resize
        """
        assert len(self.target_size) == 2
        assert self.target_size[0] > 0 and self.target_size[1] > 0
        im_scale_y, im_scale_x = self.generate_scale(im)
        im = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=self.interp)
        im_info["im_shape"] = np.array(im.shape[:2]).astype("float32")
        im_info["scale_factor"] = np.array([im_scale_y, im_scale_x]).astype("float32")
        return im, im_info
    def generate_scale(self, im):
        """
        Args:
            im (np.ndarray): image (np.ndarray)
        Returns:
            im_scale_y: the resize ratio of Y
            im_scale_x: the resize ratio of X
        """
        origin_shape = im.shape[:2]
        if self.keep_ratio:
            # Match the short side to the short target side, but never let
            # the long side exceed the long target side.
            im_size_min = np.min(origin_shape)
            im_size_max = np.max(origin_shape)
            target_size_min = np.min(self.target_size)
            target_size_max = np.max(self.target_size)
            im_scale = float(target_size_min) / float(im_size_min)
            if np.round(im_scale * im_size_max) > target_size_max:
                im_scale = float(target_size_max) / float(im_size_max)
            return im_scale, im_scale
        # Without keep_ratio, each axis is scaled independently.
        resize_h, resize_w = self.target_size
        im_scale_y = resize_h / float(origin_shape[0])
        im_scale_x = resize_w / float(origin_shape[1])
        return im_scale_y, im_scale_x
class NormalizeImage(object):
    """Scale and normalize an image.

    Args:
        mean (list): per-channel mean, subtracted when norm_type == 'mean_std'
        std (list): per-channel std, divided when norm_type == 'mean_std'
        is_scale (bool): divide by 255 first when True
        norm_type (str): type in ['mean_std', 'none']
    """
    def __init__(self, mean, std, is_scale=True, norm_type="mean_std"):
        self.mean = mean
        self.std = std
        self.is_scale = is_scale
        self.norm_type = norm_type
    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image (np.ndarray)
            im_info (dict): info of image, passed through unchanged
        Returns:
            im (np.ndarray): processed float32 image (np.ndarray)
            im_info (dict): info of processed image
        """
        im = im.astype(np.float32, copy=False)
        if self.is_scale:
            # In-place ops keep the array float32
            im *= 1.0 / 255.0
        if self.norm_type == "mean_std":
            chan_mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
            chan_std = np.array(self.std)[np.newaxis, np.newaxis, :]
            im -= chan_mean
            im /= chan_std
        return im, im_info
class Permute(object):
    """Transpose an image from HWC to CHW layout.

    Note: despite what earlier docs suggested, this class takes no
    ``to_bgr``/``channel_first`` arguments — it always performs the
    HWC -> CHW transpose and nothing else.
    """
    def __init__(self):
        # No configuration; kept as a class for Compose's uniform construction.
        super(Permute, self).__init__()
    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image in HWC layout
            im_info (dict): info of image, passed through unchanged
        Returns:
            im (np.ndarray): contiguous CHW copy of the image
            im_info (dict): info of processed image
        """
        im = im.transpose((2, 0, 1)).copy()
        return im, im_info
class Compose:
    """Chain preprocessing transforms built from config dicts.

    Each entry of *transforms* is a dict with a "type" key naming one of the
    transform classes in this module, plus that class's keyword arguments.
    """
    def __init__(self, transforms):
        # Explicit registry instead of eval() on config-supplied strings:
        # safer, and unknown names fail with a clear error.
        registry = {
            "Resize": Resize,
            "NormalizeImage": NormalizeImage,
            "Permute": Permute,
        }
        self.transforms = []
        for op_info in transforms:
            new_op_info = op_info.copy()
            op_type = new_op_info.pop("type")
            try:
                op_cls = registry[op_type]
            except KeyError as exc:
                raise ValueError(f"Unknown preprocess transform: {op_type!r}") from exc
            self.transforms.append(op_cls(**new_op_info))
    def __call__(self, img_path):
        """Decode *img_path* and run every transform, returning model inputs."""
        img, im_info = decode_image(img_path)
        for t in self.transforms:
            img, im_info = t(img, im_info)
        inputs = copy.deepcopy(im_info)
        inputs["image"] = img
        return inputs

653
texteller/api/format.py Normal file
View File

@@ -0,0 +1,653 @@
#!/usr/bin/env python3
"""
Python implementation of tex-fmt, a LaTeX formatter.
Based on the Rust implementation at https://github.com/WGUNDERWOOD/tex-fmt
"""
import re
from dataclasses import dataclass
from typing import List, Optional, Tuple
# Constants
LINE_END = "\n"
ITEM = "\\item"
DOC_BEGIN = "\\begin{document}"
DOC_END = "\\end{document}"
ENV_BEGIN = "\\begin{"
ENV_END = "\\end{"
TEXT_LINE_START = ""  # prefix for the continuation of a wrapped text line
COMMENT_LINE_START = "% "  # prefix for the continuation of a wrapped comment
# Opening and closing delimiters
OPENS = ['{', '(', '[']
CLOSES = ['}', ')', ']']
# Names of LaTeX verbatim environments
VERBATIMS = ["verbatim", "Verbatim", "lstlisting", "minted", "comment"]
VERBATIMS_BEGIN = [f"\\begin{{{v}}}" for v in VERBATIMS]
VERBATIMS_END = [f"\\end{{{v}}}" for v in VERBATIMS]
# Regex patterns for sectioning commands
SPLITTING = [
    r"\\begin\{",
    r"\\end\{",
    r"\\item(?:$|[^a-zA-Z])",
    r"\\(?:sub){0,2}section\*?\{",
    r"\\chapter\*?\{",
    r"\\part\*?\{",
]
# Compiled regexes
SPLITTING_STRING = f"({'|'.join(SPLITTING)})"
RE_NEWLINES = re.compile(f"{LINE_END}{LINE_END}({LINE_END})+")  # runs of 3+ newlines
RE_TRAIL = re.compile(f" +{LINE_END}")  # trailing spaces before a newline
RE_SPLITTING = re.compile(SPLITTING_STRING)
# A splitting pattern preceded by other non-space content on the same line
RE_SPLITTING_SHARED_LINE = re.compile(f"(?:\\S.*?)(?:{SPLITTING_STRING}.*)")
RE_SPLITTING_SHARED_LINE_CAPTURE = re.compile(f"(?P<prev>\\S.*?)(?P<env>{SPLITTING_STRING}.*)")
@dataclass
class Args:
    """Formatter configuration."""
    tabchar: str = " "  # character used for indentation (' ' or '\t')
    tabsize: int = 4  # indent width per level
    wrap: bool = False  # whether to wrap long lines
    wraplen: int = 80  # maximum allowed line length when wrapping
    wrapmin: int = 40  # minimum content width considered when choosing a wrap point
    lists: List[str] = None  # extra environments indented like lists (normalized in __post_init__)
    verbosity: int = 0  # >= 3 enables TRACE log messages
    def __post_init__(self):
        # Avoid the shared-mutable-default pitfall: give each instance its own list.
        if self.lists is None:
            self.lists = []
@dataclass
class Ignore:
    """Tracks the ignore-directive state of a line.

    ``actual`` is True while a ``% tex-fmt: off`` block is open; ``visual``
    is True when this particular line must be left untouched.
    """
    actual: bool = False
    visual: bool = False
    @classmethod
    def new(cls):
        """Return a fresh, non-ignored state."""
        return cls()
@dataclass
class Verbatim:
    """Tracks the verbatim-environment state of a line.

    ``actual`` is the current nesting depth of verbatim environments;
    ``visual`` is True when the line must be emitted verbatim.
    """
    actual: int = 0
    visual: bool = False
    @classmethod
    def new(cls):
        """Return a fresh state with zero nesting."""
        return cls()
@dataclass
class Indent:
    """Tracks the indentation state of a line.

    ``actual`` is the nesting level after the line; ``visual`` is the
    level the line itself is displayed with.
    """
    actual: int = 0
    visual: int = 0
    @classmethod
    def new(cls):
        """Return a fresh zero-indent state."""
        return cls()
@dataclass
class State:
    """Information on the current state during formatting."""
    linum_old: int = 1  # current line number in the source text
    linum_new: int = 1  # current line number in the output text
    ignore: Ignore = None  # ignore-directive state (created in __post_init__)
    indent: Indent = None  # indentation state (created in __post_init__)
    verbatim: Verbatim = None  # verbatim nesting state (created in __post_init__)
    linum_last_zero_indent: int = 1  # last output line rendered with zero indent
    def __post_init__(self):
        # Dataclass fields cannot default to fresh mutable instances, so the
        # sub-states are created here when not supplied by the caller.
        if self.ignore is None:
            self.ignore = Ignore.new()
        if self.indent is None:
            self.indent = Indent.new()
        if self.verbatim is None:
            self.verbatim = Verbatim.new()
@dataclass
class Pattern:
    """Record whether a line contains certain patterns."""
    contains_env_begin: bool = False
    contains_env_end: bool = False
    contains_item: bool = False
    contains_splitting: bool = False
    contains_comment: bool = False
    @classmethod
    def new(cls, s: str):
        """Scan *s* and record which formatting-relevant patterns occur.

        The environment/item flags are only meaningful (and only computed)
        when a splitting pattern is present at all.
        """
        has_splitting = RE_SPLITTING.search(s) is not None
        return cls(
            contains_env_begin=has_splitting and ENV_BEGIN in s,
            contains_env_end=has_splitting and ENV_END in s,
            contains_item=has_splitting and ITEM in s,
            contains_splitting=has_splitting,
            contains_comment='%' in s,
        )
@dataclass
class Log:
    """Log message."""
    level: str  # "INFO", "WARN" or "TRACE"
    file: str  # name of the file being formatted
    message: str  # human-readable description
    linum_new: Optional[int] = None  # output line number, when line-specific
    linum_old: Optional[int] = None  # source line number, when line-specific
    line: Optional[str] = None  # offending line content, when line-specific
def find_comment_index(line: str, pattern: Pattern) -> Optional[int]:
    """Return the index of the first genuine comment '%' in *line*, or None.

    A '%' escaped by a preceding backslash command is not a comment.
    """
    if not pattern.contains_comment:
        return None
    escaped = False
    for idx, ch in enumerate(line):
        if ch == '\\':
            escaped = True
        elif escaped and not ch.isalpha():
            # The escape/command ends at the first non-letter character.
            escaped = False
        elif ch == '%' and not escaped:
            return idx
    return None
def _has_directive(line: str, directive: str) -> bool:
    """True if *line* ends with the given tex-fmt directive comment."""
    return line.endswith(directive)
def contains_ignore_skip(line: str) -> bool:
    """Check if a line is marked with the single-line skip directive."""
    return _has_directive(line, "% tex-fmt: skip")
def contains_ignore_begin(line: str) -> bool:
    """Check if a line opens an ignore block."""
    return _has_directive(line, "% tex-fmt: off")
def contains_ignore_end(line: str) -> bool:
    """Check if a line closes an ignore block."""
    return _has_directive(line, "% tex-fmt: on")
def get_ignore(line: str, state: State, logs: List[Log], file: str, warn: bool) -> Ignore:
    """Determine whether a line should be ignored.

    Skip directives ignore only the current line; off/on directives open
    and close an ignore block. Mismatched directives produce warnings
    when *warn* is set.
    """
    previously_ignored = state.ignore.actual
    if contains_ignore_skip(line):
        # The directive line itself is skipped; the block state is unchanged.
        return Ignore(actual=previously_ignored, visual=True)
    if contains_ignore_begin(line):
        if warn and previously_ignored:
            logs.append(
                Log(
                    level="WARN",
                    file=file,
                    message="Cannot begin ignore block:",
                    linum_new=state.linum_new,
                    linum_old=state.linum_old,
                    line=line,
                )
            )
        return Ignore(actual=True, visual=True)
    if contains_ignore_end(line):
        if warn and not previously_ignored:
            logs.append(
                Log(
                    level="WARN",
                    file=file,
                    message="No ignore block to end.",
                    linum_new=state.linum_new,
                    linum_old=state.linum_old,
                    line=line,
                )
            )
        return Ignore(actual=False, visual=True)
    # No directive: the line is ignored iff an ignore block is open.
    return Ignore(actual=previously_ignored, visual=previously_ignored)
def get_verbatim_diff(line: str, pattern: Pattern) -> int:
    """Return the verbatim-depth change (+1, -1 or 0) caused by *line*."""
    if pattern.contains_env_begin and any(env in line for env in VERBATIMS_BEGIN):
        return 1
    if pattern.contains_env_end and any(env in line for env in VERBATIMS_END):
        return -1
    return 0
def get_verbatim(
    line: str, state: State, logs: List[Log], file: str, warn: bool, pattern: Pattern
) -> Verbatim:
    """Determine whether a line is in a verbatim environment.

    The line is rendered verbatim when the environment is open either
    before or after the line (so both delimiters stay untouched).
    """
    depth = state.verbatim.actual + get_verbatim_diff(line, pattern)
    if warn and depth < 0:
        logs.append(
            Log(
                level="WARN",
                file=file,
                message="Verbatim count is negative.",
                linum_new=state.linum_new,
                linum_old=state.linum_old,
                line=line,
            )
        )
    return Verbatim(actual=depth, visual=depth > 0 or state.verbatim.actual > 0)
def get_diff(line: str, pattern: Pattern, lists_begin: List[str], lists_end: List[str]) -> int:
    """Calculate total indentation change due to the current line."""
    total = 0
    if pattern.contains_env_begin and ENV_BEGIN in line:
        # \begin{document} gets no global indentation at all
        if DOC_BEGIN in line:
            return 0
        # List environments get a double indent (for their items)
        total += 2 if any(env in line for env in lists_begin) else 1
    elif pattern.contains_env_end and ENV_END in line:
        # \end{document} gets no global indentation at all
        if DOC_END in line:
            return 0
        total -= 2 if any(env in line for env in lists_end) else 1
    # Every opening delimiter indents, every closing one dedents
    for ch in line:
        if ch in OPENS:
            total += 1
        elif ch in CLOSES:
            total -= 1
    return total
def get_back(line: str, pattern: Pattern, state: State, lists_end: List[str]) -> int:
    """Calculate dedentation for the current line."""
    # Nothing to dedent when no indentation is present
    if state.indent.actual == 0:
        return 0
    if pattern.contains_env_end and ENV_END in line:
        # \end{document} carries no global indentation
        if DOC_END in line:
            return 0
        # List environments close a double indent
        return 2 if any(env in line for env in lists_end) else 1
    # \item lines are displayed one level shallower
    if pattern.contains_item and ITEM in line:
        return 1
    return 0
def get_indent(
    line: str,
    prev_indent: Indent,
    pattern: Pattern,
    state: State,
    lists_begin: List[str],
    lists_end: List[str],
) -> Indent:
    """Calculate the indent for a line.

    ``actual`` is the nesting level after the line is processed; ``visual``
    is what the line itself is displayed with (dedented for \\end, \\item).
    """
    change = get_diff(line, pattern, lists_begin, lists_end)
    dedent = get_back(line, pattern, state, lists_end)
    return Indent(
        actual=prev_indent.actual + change,
        visual=max(0, prev_indent.actual - dedent),
    )
def calculate_indent(
    line: str,
    state: State,
    logs: List[Log],
    file: str,
    args: Args,
    pattern: Pattern,
    lists_begin: List[str],
    lists_end: List[str],
) -> Indent:
    """Calculate the indent for a line and update the state."""
    new_indent = get_indent(line, state.indent, pattern, state, lists_begin, lists_end)
    state.indent = new_indent
    # Remember where the text was last flush-left, for diagnostics.
    if new_indent.visual == 0:
        state.linum_last_zero_indent = state.linum_new
    return new_indent
def apply_indent(line: str, indent: Indent, args: Args, indent_char: str) -> str:
    """Prefix *line* with its computed visual indentation.

    Blank (whitespace-only) lines are emptied entirely.
    """
    if not line.strip():
        return ""
    prefix = indent_char * (indent.visual * args.tabsize)
    return prefix + line.lstrip()
def needs_wrap(line: str, indent_length: int, args: Args) -> bool:
    """Check if a line (once indented) would exceed the wrap length."""
    if not args.wrap:
        return False
    return len(line) + indent_length > args.wraplen
def find_wrap_point(line: str, indent_length: int, args: Args) -> Optional[int]:
    """Find the best place to break a long line.

    Returns the index of the space to break at, or None when no suitable
    break point exists. Scanning continues past ``wrapmin`` until at least
    one candidate has been found.
    """
    wrap_point = None
    after_char = False  # set once some non-'%' content has been seen
    prev_char = None
    line_width = 0
    # Content narrower than this keeps scanning for a later break point
    wrap_boundary = args.wrapmin - indent_length
    for i, c in enumerate(line):
        line_width += 1
        if line_width > wrap_boundary and wrap_point is not None:
            break
        if c == ' ' and prev_char != '\\':
            # Candidate break: an unescaped space following real content
            if after_char:
                wrap_point = i
        elif c != '%':
            after_char = True
        prev_char = c
    return wrap_point
def apply_wrap(
    line: str,
    indent_length: int,
    state: State,
    file: str,
    args: Args,
    logs: List[Log],
    pattern: Pattern,
) -> Optional[List[str]]:
    """Wrap a long line into a short prefix and a suffix.

    Returns ``[prefix, continuation_start, suffix]`` or None when the line
    cannot be wrapped.
    """
    if args.verbosity >= 3:  # Trace level
        logs.append(
            Log(
                level="TRACE",
                file=file,
                message="Wrapping long line.",
                linum_new=state.linum_new,
                linum_old=state.linum_old,
                line=line,
            )
        )
    wrap_point = find_wrap_point(line, indent_length, args)
    comment_index = find_comment_index(line, pattern)
    if wrap_point is None or wrap_point > args.wraplen:
        logs.append(
            Log(
                level="WARN",
                file=file,
                message="Line cannot be wrapped.",
                linum_new=state.linum_new,
                linum_old=state.linum_old,
                line=line,
            )
        )
        return None
    # If the break falls inside a trailing comment, the continuation line
    # must itself start as a comment.
    inside_comment = comment_index is not None and wrap_point > comment_index
    continuation = COMMENT_LINE_START if inside_comment else TEXT_LINE_START
    return [line[:wrap_point], continuation, line[wrap_point + 1 :]]
def needs_split(line: str, pattern: Pattern) -> bool:
    """Check if line contains content which should be split onto a new line."""
    # A split is only relevant when a splitting pattern shares the line with
    # preceding non-space content
    contains_splittable_env = (
        pattern.contains_splitting and RE_SPLITTING_SHARED_LINE.search(line) is not None
    )
    # If we're not ignoring and we've matched an environment...
    if contains_splittable_env:
        # Return True if the comment index is None (which implies the split point must be in text),
        # otherwise compare the index of the comment with the split point
        comment_index = find_comment_index(line, pattern)
        if comment_index is None:
            return True
        # group(2) is the splitting construct itself
        match = RE_SPLITTING_SHARED_LINE_CAPTURE.search(line)
        if match and match.start(2) > comment_index:
            # If split point is past the comment index, don't split
            return False
        else:
            # Otherwise, split point is before comment and we do split
            return True
    else:
        # If ignoring or didn't match an environment, don't need a new line
        return False
def split_line(line: str, state: State, file: str, args: Args, logs: List[Log]) -> Tuple[str, str]:
    """Split *line* at the first splitting pattern.

    Returns (text before the pattern, pattern and everything after), or
    (*line*, "") when nothing matches.
    """
    m = RE_SPLITTING_SHARED_LINE_CAPTURE.search(line)
    if m is None:
        return line, ""
    if args.verbosity >= 3:  # Trace level
        logs.append(
            Log(
                level="TRACE",
                file=file,
                message="Placing environment on new line.",
                linum_new=state.linum_new,
                linum_old=state.linum_old,
                line=line,
            )
        )
    return m.group('prev'), m.group('env')
def set_ignore_and_report(
    line: str, temp_state: State, logs: List[Log], file: str, pattern: Pattern
) -> bool:
    """Update the ignore/verbatim flags on *temp_state* for *line*.

    Returns True when the line must be emitted untouched (ignored or
    inside a verbatim environment), warning on mismatched directives.
    """
    ignore = get_ignore(line, temp_state, logs, file, True)
    temp_state.ignore = ignore
    verbatim = get_verbatim(line, temp_state, logs, file, True, pattern)
    temp_state.verbatim = verbatim
    return verbatim.visual or ignore.visual
def clean_text(text: str, args: Args) -> str:
    """Cleans the given text by removing extra line breaks and trailing spaces."""
    # Collapse runs of three or more newlines into exactly two
    cleaned = RE_NEWLINES.sub(f"{LINE_END}{LINE_END}", text)
    # Expand tabs when the configured indent character is not a tab
    if args.tabchar != '\t':
        cleaned = cleaned.replace('\t', ' ' * args.tabsize)
    # Drop spaces that precede a line break
    return RE_TRAIL.sub(LINE_END, cleaned)
def remove_trailing_spaces(text: str) -> str:
    """Remove trailing spaces from line endings."""
    return RE_TRAIL.sub(LINE_END, text)
def remove_trailing_blank_lines(text: str) -> str:
    """Remove trailing blank lines, keeping exactly one final newline."""
    return text.rstrip() + LINE_END
def indents_return_to_zero(state: State) -> bool:
    """Check if indentation returns to zero at the end of the file."""
    return state.indent.actual == 0
def format_latex(text: str) -> str:
    """Format LaTeX text with default formatting options.

    This is the main API function for formatting LaTeX text; all
    formatting parameters take their pre-defined default values.

    Args:
        text: LaTeX text to format

    Returns:
        Formatted LaTeX text
    """
    # Logs are produced internally but not surfaced by this convenience API.
    formatted, _unused_logs = _format_latex(text, "input.tex", Args())
    return formatted.strip()
def _format_latex(old_text: str, file: str, args: Args) -> Tuple[str, List[Log]]:
    """Internal function to format a LaTeX string.

    Lines are fed through a work queue so that the pieces produced by
    splitting or wrapping are themselves re-examined before being emitted.
    Returns a tuple of (formatted text, accumulated log messages).
    """
    logs = []
    logs.append(Log(level="INFO", file=file, message="Formatting started."))
    # Clean the source file
    old_text = clean_text(old_text, args)
    old_lines = list(enumerate(old_text.splitlines(), 1))
    # Initialize
    state = State()
    queue = []
    new_text = ""
    # Select the character used for indentation
    indent_char = '\t' if args.tabchar == '\t' else ' '
    # Get any extra environments to be indented as lists
    lists_begin = [f"\\begin{{{l}}}" for l in args.lists]
    lists_end = [f"\\end{{{l}}}" for l in args.lists]
    while True:
        if queue:
            linum_old, line = queue.pop(0)
            # Read the patterns present on this line
            pattern = Pattern.new(line)
            # Temporary state for working on this line; committed to `state`
            # only when the line is actually emitted (not re-queued)
            temp_state = State(
                linum_old=linum_old,
                linum_new=state.linum_new,
                ignore=Ignore(state.ignore.actual, state.ignore.visual),
                indent=Indent(state.indent.actual, state.indent.visual),
                verbatim=Verbatim(state.verbatim.actual, state.verbatim.visual),
                linum_last_zero_indent=state.linum_last_zero_indent,
            )
            # If the line should not be ignored...
            if not set_ignore_and_report(line, temp_state, logs, file, pattern):
                # Check if the line should be split because of a pattern that should begin on a new line
                if needs_split(line, pattern):
                    # Split the line into two...
                    this_line, next_line = split_line(line, temp_state, file, args, logs)
                    # ...and queue the second part for formatting
                    if next_line:
                        queue.insert(0, (linum_old, next_line))
                    line = this_line
                # Calculate the indent based on the current state and the patterns in the line
                indent = calculate_indent(
                    line, temp_state, logs, file, args, pattern, lists_begin, lists_end
                )
                indent_length = indent.visual * args.tabsize
                # Wrap the line before applying the indent, and loop back if the line needed wrapping
                if needs_wrap(line.lstrip(), indent_length, args):
                    wrapped_lines = apply_wrap(
                        line.lstrip(), indent_length, temp_state, file, args, logs, pattern
                    )
                    if wrapped_lines:
                        this_line, next_line_start, next_line = wrapped_lines
                        queue.insert(0, (linum_old, next_line_start + next_line))
                        queue.insert(0, (linum_old, this_line))
                        continue
                # Lastly, apply the indent if the line didn't need wrapping
                line = apply_indent(line, indent, args, indent_char)
            # Add line to new text
            state = temp_state
            new_text += line + LINE_END
            state.linum_new += 1
        elif old_lines:
            # Feed the next source line into the work queue
            linum_old, line = old_lines.pop(0)
            queue.append((linum_old, line))
        else:
            break
    if not indents_return_to_zero(state):
        msg = f"Indent does not return to zero. Last non-indented line is line {state.linum_last_zero_indent}"
        logs.append(Log(level="WARN", file=file, message=msg))
    new_text = remove_trailing_spaces(new_text)
    new_text = remove_trailing_blank_lines(new_text)
    logs.append(Log(level="INFO", file=file, message="Formatting complete."))
    return new_text, logs

241
texteller/api/inference.py Normal file
View File

@@ -0,0 +1,241 @@
import re
import time
from collections import Counter
from typing import Literal
import cv2
import numpy as np
import torch
from onnxruntime import InferenceSession
from optimum.onnxruntime import ORTModelForVision2Seq
from transformers import GenerationConfig, RobertaTokenizerFast
from texteller.constants import MAX_TOKEN_SIZE
from texteller.logger import get_logger
from texteller.paddleocr import predict_det, predict_rec
from texteller.types import Bbox, TexTellerModel
from texteller.utils import (
bbox_merge,
get_device,
mask_img,
readimgs,
remove_style,
slice_from_image,
split_conflict,
transform,
add_newlines,
)
from .detection import latex_detect
from .format import format_latex
from .katex import to_katex
_logger = get_logger()
def img2latex(
    model: TexTellerModel,
    tokenizer: RobertaTokenizerFast,
    images: list[str] | list[np.ndarray],
    device: torch.device | None = None,
    out_format: Literal["latex", "katex"] = "latex",
    keep_style: bool = False,
    max_tokens: int = MAX_TOKEN_SIZE,
    num_beams: int = 1,
    no_repeat_ngram_size: int = 0,
) -> list[str]:
    """
    Convert images to LaTeX or KaTeX formatted strings.
    Args:
        model: The TexTeller or ORTModelForVision2Seq model instance
        tokenizer: The tokenizer for the model
        images: List of image paths or numpy arrays (RGB format)
        device: The torch device to use (defaults to available GPU or CPU)
        out_format: Output format, either "latex" or "katex"
        keep_style: Whether to keep the style of the LaTeX
        max_tokens: Maximum number of tokens to generate
        num_beams: Number of beams for beam search
        no_repeat_ngram_size: Size of n-grams to prevent repetition
    Returns:
        List of LaTeX or KaTeX strings corresponding to each input image
    Example usage:
        >>> import torch
        >>> from texteller import load_model, load_tokenizer, img2latex
        >>> model = load_model(model_path=None, use_onnx=False)
        >>> tokenizer = load_tokenizer(tokenizer_path=None)
        >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        >>> res = img2latex(model, tokenizer, ["path/to/image.png"], device=device, out_format="katex")
    """
    assert isinstance(images, list)
    assert len(images) > 0
    if device is None:
        device = get_device()
    if device.type != model.device.type:
        if isinstance(model, ORTModelForVision2Seq):
            # ONNX sessions cannot be moved after creation; keep the model's device
            _logger.warning(
                f"Onnxruntime device mismatch: detected {str(device)} but model is on {str(model.device)}, using {str(model.device)} instead"
            )
        else:
            model = model.to(device=device)
    if isinstance(images[0], str):
        images = readimgs(images)
    else:  # already numpy arrays (RGB format); dead `images = images` removed
        assert isinstance(images[0], np.ndarray)
    images = transform(images)
    pixel_values = torch.stack(images)
    generate_config = GenerationConfig(
        max_new_tokens=max_tokens,
        num_beams=num_beams,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        no_repeat_ngram_size=no_repeat_ngram_size,
    )
    pred = model.generate(
        pixel_values.to(model.device),
        generation_config=generate_config,
    )
    res = tokenizer.batch_decode(pred, skip_special_tokens=True)
    if out_format == "katex":
        res = [to_katex(r) for r in res]
        if not keep_style:
            res = [remove_style(r) for r in res]
    res = [format_latex(r) for r in res]
    res = [add_newlines(r) for r in res]
    return res
def paragraph2md(
    img_path: str,
    latexdet_model: InferenceSession,
    textdet_model: predict_det.TextDetector,
    textrec_model: predict_rec.TextRecognizer,
    latexrec_model: TexTellerModel,
    tokenizer: RobertaTokenizerFast,
    device: torch.device | None = None,
    num_beams=1,
) -> str:
    """
    Input a mixed image of formula text and output str (in markdown syntax)

    Pipeline: detect LaTeX regions, mask them out with the background
    color, run OCR detection/recognition on the remainder, run LaTeX
    recognition on the masked regions, then stitch everything back
    together in reading order.

    Args:
        img_path: path to the input image
        latexdet_model: ONNX session for LaTeX region detection
        textdet_model: OCR text detector
        textrec_model: OCR text recognizer
        latexrec_model: LaTeX recognition model
        tokenizer: tokenizer for the LaTeX recognition model
        device: torch device for LaTeX recognition (auto-detected if None)
        num_beams: beam count for LaTeX decoding

    Returns:
        Markdown string with inline ($...$) and display ($$...$$) math.
    """
    img = cv2.imread(img_path)
    # Estimate the background color by majority vote over the corner pixels
    corners = [tuple(img[0, 0]), tuple(img[0, -1]), tuple(img[-1, 0]), tuple(img[-1, -1])]
    bg_color = np.array(Counter(corners).most_common(1)[0][0])
    start_time = time.time()
    latex_bboxes = latex_detect(img_path, latexdet_model)
    end_time = time.time()
    _logger.info(f"latex_det_model time: {end_time - start_time:.2f}s")
    latex_bboxes = sorted(latex_bboxes)
    latex_bboxes = bbox_merge(latex_bboxes)
    # Paint over formula regions so the OCR detector only sees plain text
    masked_img = mask_img(img, latex_bboxes, bg_color)
    start_time = time.time()
    det_prediction, _ = textdet_model(masked_img)
    end_time = time.time()
    _logger.info(f"ocr_det_model time: {end_time - start_time:.2f}s")
    # Convert quadrilateral detections into (x, y, height, width) bboxes
    ocr_bboxes = [
        Bbox(
            p[0][0],
            p[0][1],
            p[3][1] - p[0][1],
            p[1][0] - p[0][0],
            label="text",
            confidence=None,
            content=None,
        )
        for p in det_prediction
    ]
    ocr_bboxes = sorted(ocr_bboxes)
    ocr_bboxes = bbox_merge(ocr_bboxes)
    ocr_bboxes = split_conflict(ocr_bboxes, latex_bboxes)
    ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes))
    sliced_imgs: list[np.ndarray] = slice_from_image(img, ocr_bboxes)
    start_time = time.time()
    rec_predictions, _ = textrec_model(sliced_imgs)
    end_time = time.time()
    _logger.info(f"ocr_rec_model time: {end_time - start_time:.2f}s")
    assert len(rec_predictions) == len(ocr_bboxes)
    # Attach the recognized text to its bbox
    for content, bbox in zip(rec_predictions, ocr_bboxes):
        bbox.content = content[0]
    latex_imgs = []
    for bbox in latex_bboxes:
        latex_imgs.append(img[bbox.p.y : bbox.p.y + bbox.h, bbox.p.x : bbox.p.x + bbox.w])
    start_time = time.time()
    latex_rec_res = img2latex(
        model=latexrec_model,
        tokenizer=tokenizer,
        images=latex_imgs,
        num_beams=num_beams,
        out_format="katex",
        device=device,
        keep_style=False,
    )
    end_time = time.time()
    _logger.info(f"latex_rec_model time: {end_time - start_time:.2f}s")
    # Wrap formulas: inline -> $...$, display -> $$...$$ on their own lines
    for bbox, content in zip(latex_bboxes, latex_rec_res):
        if bbox.label == "embedding":
            bbox.content = " $" + content + "$ "
        elif bbox.label == "isolated":
            bbox.content = "\n\n" + r"$$" + content + r"$$" + "\n\n"
    bboxes = sorted(ocr_bboxes + latex_bboxes)
    if bboxes == []:
        return ""
    md = ""
    # Sentinel bbox so the first iteration has a well-defined `prev`
    prev = Bbox(bboxes[0].p.x, bboxes[0].p.y, -1, -1, label="guard")
    for curr in bboxes:
        # Add the formula number back to the isolated formula
        if prev.label == "isolated" and curr.label == "text" and prev.same_row(curr):
            curr.content = curr.content.strip()
            if curr.content.startswith("(") and curr.content.endswith(")"):
                curr.content = curr.content[1:-1]
            # md ends with the isolated formula's closing "$$\n\n" (4 chars);
            # splice \tag inside the display math before the closing "$$"
            if re.search(r"\\tag\{.*\}$", md[:-4]) is not None:
                # in case of multiple tag
                md = md[:-5] + f", {curr.content}" + "}" + md[-4:]
            else:
                md = md[:-4] + f"\\tag{{{curr.content}}}" + md[-4:]
            continue
        if not prev.same_row(curr):
            md += " "
        if curr.label == "embedding":
            # remove the bold effect from inline formulas
            curr.content = remove_style(curr.content)
            # change split environment into aligned
            curr.content = curr.content.replace(r"\begin{split}", r"\begin{aligned}")
            curr.content = curr.content.replace(r"\end{split}", r"\end{aligned}")
            # remove extra spaces (keeping only one)
            curr.content = re.sub(r" +", " ", curr.content)
            assert curr.content.startswith("$") and curr.content.endswith("$")
            curr.content = " $" + curr.content.strip("$") + "$ "
        md += curr.content
        prev = curr
    return md.strip()

118
texteller/api/katex.py Normal file
View File

@@ -0,0 +1,118 @@
import re
from ..utils.latex import change_all
from .format import format_latex
def _rm_dollar_surr(content):
pattern = re.compile(r'\\[a-zA-Z]+\$.*?\$|\$.*?\$')
matches = pattern.findall(content)
for match in matches:
if not re.match(r'\\[a-zA-Z]+', match):
new_match = match.strip('$')
content = content.replace(match, ' ' + new_match + ' ')
return content
def to_katex(formula: str) -> str:
    """Normalize a LaTeX formula so KaTeX can render it.

    Box/size commands KaTeX does not support are stripped or rewritten,
    bold markup is removed, display-math ``\\[...\\]`` becomes ``\\newline``
    breaks, redundant spacing is collapsed, and the result is passed
    through :func:`format_latex`.
    """
    out = formula
    # strip \mbox{...} wrappers
    out = change_all(out, r'\mbox ', r' ', r'{', r'}', r'', r'')
    out = change_all(out, r'\mbox', r' ', r'{', r'}', r'', r'')
    # strip \hbox wrappers, dropping any explicit "to <len>" width
    out = re.sub(r'\\hbox to ?-? ?\d+\.\d+(pt)?\{', r'\\hbox{', out)
    out = change_all(out, r'\hbox', r' ', r'{', r'}', r'', r' ')
    # drop \raise offsets
    out = re.sub(r'\\raise ?-? ?\d+\.\d+(pt)?', r' ', out)
    # strip \makebox, dropping any explicit width argument
    out = re.sub(r'\\makebox ?\[\d+\.\d+(pt)?\]\{', r'\\makebox{', out)
    out = change_all(out, r'\makebox', r' ', r'{', r'}', r'', r' ')
    # strip \raisebox/\scalebox/\vbox wrappers (dropping their length args)
    out = re.sub(r'\\raisebox\{-? ?\d+\.\d+(pt)?\}\{', r'\\raisebox{', out)
    out = re.sub(r'\\scalebox\{-? ?\d+\.\d+(pt)?\}\{', r'\\scalebox{', out)
    for box_cmd in (r'\scalebox', r'\raisebox', r'\vbox'):
        out = change_all(out, box_cmd, r' ', r'{', r'}', r'', r' ')
    # re-delimit size commands that scope via $...$ with braces instead
    size_cmds = (
        r'\Huge', r'\huge', r'\LARGE', r'\Large', r'\large',
        r'\normalsize', r'\small', r'\footnotesize', r'\tiny',
    )
    for size_cmd in size_cmds:
        out = change_all(out, size_cmd, size_cmd, r'$', r'$', '{', '}')
    # rewrite bold / emphasis commands to KaTeX-friendly forms
    out = change_all(out, r'\mathbf', r'\bm', r'{', r'}', r'{', r'}')
    out = change_all(out, r'\boldmath ', r'\bm', r'{', r'}', r'{', r'}')
    out = change_all(out, r'\boldmath', r'\bm', r'{', r'}', r'{', r'}')
    out = change_all(out, r'\boldmath ', r'\bm', r'$', r'$', r'{', r'}')
    out = change_all(out, r'\boldmath', r'\bm', r'$', r'$', r'{', r'}')
    out = change_all(out, r'\scriptsize', r'\scriptsize', r'$', r'$', r'{', r'}')
    out = change_all(out, r'\emph', r'\textit', r'{', r'}', r'{', r'}')
    out = change_all(out, r'\emph ', r'\textit', r'{', r'}', r'{', r'}')
    # drop the bold command entirely
    out = change_all(out, r'\bm', r' ', r'{', r'}', r'', r'')
    # unwrap braces that directly follow sizing/delimiter commands
    delim_cmds = (
        r'\left', r'\middle', r'\right',
        r'\big', r'\Big', r'\bigg', r'\Bigg',
        r'\bigl', r'\Bigl', r'\biggl', r'\Biggl',
        r'\bigm', r'\Bigm', r'\biggm', r'\Biggm',
        r'\bigr', r'\Bigr', r'\biggr', r'\Biggr',
    )
    for delim_cmd in delim_cmds:
        out = change_all(out, delim_cmd, delim_cmd, r'{', r'}', r'', r'')
    # turn display-math \[...\] into inline content followed by \newline
    out = re.sub(r'\\\[(.*?)\\\]', r'\1\\newline', out)
    if out.endswith(r'\newline'):
        out = out[: -len(r'\newline')]
    # collapse runs of spacing commands into a single space
    out = re.sub(r'(\\,){1,}', ' ', out)
    out = re.sub(r'(\\!){1,}', ' ', out)
    out = re.sub(r'(\\;){1,}', ' ', out)
    out = re.sub(r'(\\:){1,}', ' ', out)
    out = re.sub(r'\\vspace\{.*?}', '', out)

    # merge runs of consecutive \text{...} groups into one
    def _merge_texts(match):
        joined = ''.join(re.findall(r'\\text\{([^}]*)\}', match.group(0)))
        return f'\\text{{{joined}}}'

    out = re.sub(r'(\\text\{[^}]*\}\s*){2,}', _merge_texts, out)
    out = out.replace(r'\bf ', '')
    out = _rm_dollar_surr(out)
    # collapse multiple spaces into one
    out = re.sub(r' +', ' ', out)
    # final project-level formatting pass
    return format_latex(out.strip())

66
texteller/api/load.py Normal file
View File

@@ -0,0 +1,66 @@
from pathlib import Path
import wget
from onnxruntime import InferenceSession
from transformers import RobertaTokenizerFast
from texteller.constants import LATEX_DET_MODEL_URL, TEXT_DET_MODEL_URL, TEXT_REC_MODEL_URL
from texteller.globals import Globals
from texteller.logger import get_logger
from texteller.models import TexTeller
from texteller.paddleocr import predict_det, predict_rec
from texteller.paddleocr.utility import parse_args
from texteller.utils import cuda_available, mkdir, resolve_path
from texteller.types import TexTellerModel
_logger = get_logger(__name__)
def load_model(model_dir: str | None = None, use_onnx: bool = False) -> TexTellerModel:
    """Load a TexTeller recognition model.

    Args:
        model_dir: Directory holding model weights; ``None`` uses the default.
        use_onnx: Load the ONNX-runtime variant instead of the PyTorch one.

    Returns:
        The loaded model instance.
    """
    model = TexTeller.from_pretrained(model_dir, use_onnx=use_onnx)
    return model
def load_tokenizer(tokenizer_dir: str | None = None) -> RobertaTokenizerFast:
    """Load the TexTeller tokenizer.

    Args:
        tokenizer_dir: Directory holding the tokenizer files; ``None`` uses
            the default.

    Returns:
        The tokenizer instance.
    """
    tokenizer = TexTeller.get_tokenizer(tokenizer_dir)
    return tokenizer
def load_latexdet_model() -> InferenceSession:
    """Download (if needed) and load the LaTeX-detection ONNX model.

    Returns:
        An ONNX Runtime session using CUDA when available, CPU otherwise.
    """
    model_path = _maybe_download(LATEX_DET_MODEL_URL)
    provider = "CUDAExecutionProvider" if cuda_available() else "CPUExecutionProvider"
    return InferenceSession(resolve_path(model_path), providers=[provider])
def load_textrec_model() -> predict_rec.TextRecognizer:
    """Download (if needed) and build the PaddleOCR text recognizer.

    Returns:
        A configured ``TextRecognizer`` running via ONNX, on GPU when available.
    """
    model_path = _maybe_download(TEXT_REC_MODEL_URL)
    args = parse_args()
    args.use_onnx = True
    args.rec_model_dir = resolve_path(model_path)
    args.use_gpu = cuda_available()
    return predict_rec.TextRecognizer(args)
def load_textdet_model() -> predict_det.TextDetector:
    """Download (if needed) and build the PaddleOCR text detector.

    Returns:
        A configured ``TextDetector`` running via ONNX, on GPU when available.
    """
    model_path = _maybe_download(TEXT_DET_MODEL_URL)
    args = parse_args()
    args.use_onnx = True
    args.det_model_dir = resolve_path(model_path)
    args.use_gpu = cuda_available()
    return predict_det.TextDetector(args)
def _maybe_download(url: str, dirpath: str | None = None, force: bool = False) -> Path:
    """Return the local path of the file at ``url``, downloading it if needed.

    Args:
        url: Remote URL of the model file; its basename becomes the local name.
        dirpath: Destination directory; ``None`` uses the global cache dir.
        force: Re-download even when the file already exists locally.

    Returns:
        Path to the (possibly freshly downloaded) local file.
    """
    if dirpath is None:
        dirpath = Globals().cache_dir
    mkdir(dirpath)
    fname = Path(url).name
    fpath = Path(dirpath) / fname
    if not fpath.exists() or force:
        _logger.info(f"Downloading {fname} from {url} to {fpath}")
        # Download to a temporary name and rename on success, so an
        # interrupted transfer never leaves a partial file at `fpath`
        # that later calls would mistake for a complete download.
        tmp_path = fpath.with_name(fname + ".part")
        try:
            wget.download(url, resolve_path(tmp_path))
            tmp_path.replace(fpath)  # atomic on the same filesystem
        finally:
            # Remove the temp file if the download failed midway.
            tmp_path.unlink(missing_ok=True)
    return fpath