[refactor] Init
This commit is contained in:
24
texteller/api/__init__.py
Normal file
24
texteller/api/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from .detection import latex_detect
from .format import format_latex
from .inference import img2latex, paragraph2md
from .katex import to_katex
from .load import (
    load_latexdet_model,
    load_model,
    load_textdet_model,
    load_textrec_model,
    load_tokenizer,
)

# Public API of the texteller.api package.
__all__ = [
    "to_katex",
    "format_latex",
    "img2latex",
    "paragraph2md",
    "load_model",
    "load_tokenizer",
    "load_latexdet_model",
    "load_textrec_model",
    "load_textdet_model",
    "latex_detect",
]
|
||||
4
texteller/api/criterias/__init__.py
Normal file
4
texteller/api/criterias/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .ngram import DetectRepeatingNgramCriteria

# Public API of the criterias subpackage.
__all__ = ["DetectRepeatingNgramCriteria"]
|
||||
63
texteller/api/criterias/ngram.py
Normal file
63
texteller/api/criterias/ngram.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import torch
|
||||
from transformers import StoppingCriteria
|
||||
|
||||
|
||||
class DetectRepeatingNgramCriteria(StoppingCriteria):
    """Stopping criteria that halts generation once any n-gram repeats.

    A set of every n-gram observed so far is maintained. On each call the
    n-gram ending at the newest token is looked up in that set: a hit stops
    generation, a miss records the n-gram and lets generation continue.
    """

    def __init__(self, n: int):
        """
        Args:
            n (int): Size of the n-gram window to monitor; must be positive.
        """
        if n <= 0:
            raise ValueError("n-gram size 'n' must be positive.")
        self.n = n
        # Hashable tuples of token ids for every n-gram seen so far.
        self.seen_ngrams = set()

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.
            scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
                Prediction scores.

        Return:
            `bool`: `True` if generation should stop, `False` otherwise.
        """
        num_tokens = input_ids.shape[1]

        # Too short to contain even a single n-gram.
        if num_tokens < self.n:
            return False

        # Only the first sequence of the batch is inspected; supporting
        # batch_size > 1 would require one n-gram set per sequence (or an
        # explicit stop-if-any policy).
        newest_ngram = tuple(input_ids[0][-self.n :].tolist())

        if newest_ngram in self.seen_ngrams:
            return True  # Repetition detected: stop generation.

        self.seen_ngrams.add(newest_ngram)
        return False  # New n-gram: keep generating.
|
||||
3
texteller/api/detection/__init__.py
Normal file
3
texteller/api/detection/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .detect import latex_detect

# Public API of the detection subpackage.
__all__ = ["latex_detect"]
|
||||
48
texteller/api/detection/detect.py
Normal file
48
texteller/api/detection/detect.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from typing import List
|
||||
|
||||
from onnxruntime import InferenceSession
|
||||
|
||||
from texteller.types import Bbox
|
||||
|
||||
from .preprocess import Compose
|
||||
|
||||
# Inference-time configuration for the exported layout-detection model
# (keys mirror a PaddleDetection-style inference config).
_config = {
    "mode": "paddle",
    # Minimum confidence for a detection to be kept.
    "draw_threshold": 0.5,
    "metric": "COCO",
    "use_dynamic_shape": False,
    "arch": "DETR",
    "min_subgraph_size": 3,
    # Preprocessing pipeline executed by `Compose` (see .preprocess):
    # resize to 1600x1600, pass-through normalization, HWC -> CHW permute.
    "preprocess": [
        {"interp": 2, "keep_ratio": False, "target_size": [1600, 1600], "type": "Resize"},
        {
            "mean": [0.0, 0.0, 0.0],
            "norm_type": "none",
            "std": [1.0, 1.0, 1.0],
            "type": "NormalizeImage",
        },
        {"type": "Permute"},
    ],
    # Class names, indexed by the integer class id in the model output.
    "label_list": ["isolated", "embedding"],
}
|
||||
|
||||
|
||||
def latex_detect(img_path: str, predictor: InferenceSession) -> List[Bbox]:
    """Detect LaTeX formula regions in an image.

    Args:
        img_path: Path to the input image (anything `Compose`/`decode_image`
            accepts, i.e. a path string or a raw byte ndarray).
        predictor: ONNX Runtime inference session for the detection model.

    Returns:
        Bounding boxes whose confidence exceeds ``_config["draw_threshold"]``,
        labelled with a class name from ``_config["label_list"]``.
    """
    transforms = Compose(_config["preprocess"])
    inputs = transforms(img_path)
    inputs_name = [var.name for var in predictor.get_inputs()]
    # Add a leading batch dimension to every tensor the model expects.
    inputs = {k: inputs[k][None,] for k in inputs_name}

    # Each output row: [class_id, score, xmin, ymin, xmax, ymax].
    outputs = predictor.run(output_names=None, input_feed=inputs)[0]

    # Use the configured threshold instead of a hard-coded duplicate of it.
    threshold = _config["draw_threshold"]
    res = []
    for output in outputs:
        score = output[1]
        if score <= threshold:
            continue  # Skip low-confidence detections before any other work.
        cls_name = _config["label_list"][int(output[0])]
        xmin = int(max(output[2], 0))
        ymin = int(max(output[3], 0))
        xmax = int(output[4])
        ymax = int(output[5])
        res.append(Bbox(xmin, ymin, ymax - ymin, xmax - xmin, cls_name, score))

    return res
|
||||
161
texteller/api/detection/preprocess.py
Normal file
161
texteller/api/detection/preprocess.py
Normal file
@@ -0,0 +1,161 @@
|
||||
import copy
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def decode_image(img_path):
    """Load an image as an RGB array plus its initial preprocessing info.

    Accepts either a filesystem path or an already-read byte buffer as an
    ``np.ndarray``. Returns the decoded RGB image and a dict holding its
    shape and a unit scale factor (updated later by `Resize`).
    """
    if isinstance(img_path, str):
        with open(img_path, "rb") as f:
            data = np.frombuffer(f.read(), dtype="uint8")
    else:
        assert isinstance(img_path, np.ndarray)
        data = img_path

    # cv2 decodes to BGR, but the pipeline works in RGB.
    im = cv2.cvtColor(cv2.imdecode(data, 1), cv2.COLOR_BGR2RGB)
    img_info = {
        "im_shape": np.array(im.shape[:2], dtype=np.float32),
        "scale_factor": np.array([1.0, 1.0], dtype=np.float32),
    }
    return im, img_info
|
||||
|
||||
|
||||
class Resize(object):
    """Resize an image to `target_size`, optionally preserving aspect ratio.

    Args:
        target_size (int | list): target (h, w); a single int means square
        keep_ratio (bool): whether keep_ratio or not, default true
        interp (int): cv2 interpolation method of resize
    """

    def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
        if isinstance(target_size, int):
            target_size = [target_size, target_size]
        self.target_size = target_size
        self.keep_ratio = keep_ratio
        self.interp = interp

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image (np.ndarray)
            im_info (dict): info of image
        Returns:
            im (np.ndarray): processed image (np.ndarray)
            im_info (dict): info of processed image, with "im_shape" and
                "scale_factor" updated to reflect the resize
        """
        assert len(self.target_size) == 2
        assert self.target_size[0] > 0 and self.target_size[1] > 0
        # (unused `im_channel` local removed)
        im_scale_y, im_scale_x = self.generate_scale(im)
        im = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=self.interp)
        im_info["im_shape"] = np.array(im.shape[:2]).astype("float32")
        im_info["scale_factor"] = np.array([im_scale_y, im_scale_x]).astype("float32")
        return im, im_info

    def generate_scale(self, im):
        """
        Args:
            im (np.ndarray): image (np.ndarray)
        Returns:
            im_scale_y: the resize ratio of Y
            im_scale_x: the resize ratio of X
        """
        origin_shape = im.shape[:2]
        if self.keep_ratio:
            # Scale so the short side reaches the smaller target dimension,
            # shrinking further if the long side would overshoot the larger.
            im_size_min = np.min(origin_shape)
            im_size_max = np.max(origin_shape)
            target_size_min = np.min(self.target_size)
            target_size_max = np.max(self.target_size)
            im_scale = float(target_size_min) / float(im_size_min)
            if np.round(im_scale * im_size_max) > target_size_max:
                im_scale = float(target_size_max) / float(im_size_max)
            im_scale_x = im_scale
            im_scale_y = im_scale
        else:
            # Independent per-axis scaling to hit the exact target shape.
            resize_h, resize_w = self.target_size
            im_scale_y = resize_h / float(origin_shape[0])
            im_scale_x = resize_w / float(origin_shape[1])
        return im_scale_y, im_scale_x
|
||||
|
||||
|
||||
class NormalizeImage(object):
    """Normalize pixel values of an image.

    Args:
        mean (list): per-channel values subtracted from the image
        std (list): per-channel values the image is divided by
        is_scale (bool): whether to divide the image by 255 first
        norm_type (str): type in ['mean_std', 'none']
    """

    def __init__(self, mean, std, is_scale=True, norm_type="mean_std"):
        self.mean = mean
        self.std = std
        self.is_scale = is_scale
        self.norm_type = norm_type

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image (np.ndarray)
            im_info (dict): info of image, passed through unchanged
        Returns:
            im (np.ndarray): processed image (np.ndarray)
            im_info (dict): info of processed image
        """
        im = im.astype(np.float32, copy=False)
        if self.is_scale:
            im *= 1.0 / 255.0

        if self.norm_type == "mean_std":
            # Broadcast the per-channel stats over the H and W dimensions.
            channel_mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
            channel_std = np.array(self.std)[np.newaxis, np.newaxis, :]
            im -= channel_mean
            im /= channel_std
        return im, im_info
|
||||
|
||||
|
||||
class Permute(object):
    """Convert an image from HWC layout to CHW layout."""

    def __init__(self):
        super(Permute, self).__init__()

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image in HWC layout
            im_info (dict): info of image, passed through unchanged
        Returns:
            tuple: (contiguous CHW copy of the image, im_info)
        """
        return im.transpose((2, 0, 1)).copy(), im_info
|
||||
|
||||
|
||||
class Compose:
    """Preprocessing pipeline assembled from a config list.

    Each entry is a dict whose "type" key names one of the transform classes
    defined in this module; the remaining keys become constructor arguments.
    """

    def __init__(self, transforms):
        # Explicit whitelist instead of eval(): only known transform classes
        # can be instantiated from config, and an unknown name raises a
        # KeyError instead of executing arbitrary config text.
        known_ops = {
            "Resize": Resize,
            "NormalizeImage": NormalizeImage,
            "Permute": Permute,
        }
        self.transforms = []
        for op_info in transforms:
            new_op_info = op_info.copy()
            op_type = new_op_info.pop("type")
            self.transforms.append(known_ops[op_type](**new_op_info))

    def __call__(self, img_path):
        """Decode `img_path`, run every transform, and return model inputs.

        Returns a copy of the final `im_info` dict with the processed image
        stored under the "image" key.
        """
        img, im_info = decode_image(img_path)
        for t in self.transforms:
            img, im_info = t(img, im_info)
        inputs = copy.deepcopy(im_info)
        inputs["image"] = img
        return inputs
|
||||
653
texteller/api/format.py
Normal file
653
texteller/api/format.py
Normal file
@@ -0,0 +1,653 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Python implementation of tex-fmt, a LaTeX formatter.
|
||||
Based on the Rust implementation at https://github.com/WGUNDERWOOD/tex-fmt
|
||||
"""
|
||||
|
||||
import re
from dataclasses import dataclass, field
from typing import List, Optional, Tuple
|
||||
|
||||
# Constants
LINE_END = "\n"
ITEM = "\\item"
DOC_BEGIN = "\\begin{document}"
DOC_END = "\\end{document}"
ENV_BEGIN = "\\begin{"
ENV_END = "\\end{"
# Prefixes used when a wrapped line is continued on the next line.
TEXT_LINE_START = ""
COMMENT_LINE_START = "% "

# Opening and closing delimiters
OPENS = ['{', '(', '[']
CLOSES = ['}', ')', ']']

# Names of LaTeX verbatim environments
VERBATIMS = ["verbatim", "Verbatim", "lstlisting", "minted", "comment"]
VERBATIMS_BEGIN = [f"\\begin{{{v}}}" for v in VERBATIMS]
VERBATIMS_END = [f"\\end{{{v}}}" for v in VERBATIMS]

# Regex patterns for sectioning commands
SPLITTING = [
    r"\\begin\{",
    r"\\end\{",
    r"\\item(?:$|[^a-zA-Z])",
    r"\\(?:sub){0,2}section\*?\{",
    r"\\chapter\*?\{",
    r"\\part\*?\{",
]

# Compiled regexes
SPLITTING_STRING = f"({'|'.join(SPLITTING)})"
# Three or more consecutive newlines (collapsed to a single blank line).
RE_NEWLINES = re.compile(f"{LINE_END}{LINE_END}({LINE_END})+")
# Trailing spaces before a line ending.
RE_TRAIL = re.compile(f" +{LINE_END}")
RE_SPLITTING = re.compile(SPLITTING_STRING)
# A splitting pattern preceded by other non-whitespace content on the line.
RE_SPLITTING_SHARED_LINE = re.compile(f"(?:\\S.*?)(?:{SPLITTING_STRING}.*)")
# Same, but capturing the preceding content ("prev") and the rest ("env").
RE_SPLITTING_SHARED_LINE_CAPTURE = re.compile(f"(?P<prev>\\S.*?)(?P<env>{SPLITTING_STRING}.*)")
|
||||
|
||||
|
||||
@dataclass
class Args:
    """Formatter configuration."""

    tabchar: str = " "  # character used for indentation
    tabsize: int = 4  # characters per indentation level
    wrap: bool = False  # whether to wrap long lines
    wraplen: int = 80  # maximum allowed line length when wrapping
    wrapmin: int = 40  # minimum width considered when picking a wrap point
    # default_factory avoids the shared-mutable-default pitfall that the
    # previous `lists: List[str] = None` pattern worked around manually.
    lists: List[str] = field(default_factory=list)
    verbosity: int = 0  # 0 = quiet; >= 3 emits TRACE logs

    def __post_init__(self):
        # Backward compatibility: an explicit `lists=None` still yields [].
        if self.lists is None:
            self.lists = []
|
||||
|
||||
|
||||
@dataclass
class Ignore:
    """Information on the ignored state of a line."""

    actual: bool = False  # currently inside an ignore block
    visual: bool = False  # this line should be left unformatted

    @classmethod
    def new(cls):
        """Return the default (not ignored) state."""
        return cls(False, False)
|
||||
|
||||
|
||||
@dataclass
class Verbatim:
    """Information on the verbatim state of a line."""

    actual: int = 0  # current verbatim environment nesting depth
    visual: bool = False  # this line should be treated as verbatim text

    @classmethod
    def new(cls):
        """Return the default (depth zero) state."""
        return cls(0, False)
|
||||
|
||||
|
||||
@dataclass
class Indent:
    """Information on the indentation state of a line."""

    actual: int = 0  # running depth after this line's net change
    visual: int = 0  # depth actually applied when printing this line

    @classmethod
    def new(cls):
        """Return the zero-indent state."""
        return cls(0, 0)
|
||||
|
||||
|
||||
@dataclass
class State:
    """Information on the current state during formatting."""

    linum_old: int = 1  # 1-based line number in the source text
    linum_new: int = 1  # 1-based line number in the formatted output
    # default_factory gives every State fresh sub-state objects instead of
    # the None-then-replace dance in __post_init__.
    ignore: Ignore = field(default_factory=Ignore.new)
    indent: Indent = field(default_factory=Indent.new)
    verbatim: Verbatim = field(default_factory=Verbatim.new)
    linum_last_zero_indent: int = 1  # last output line printed at column 0

    def __post_init__(self):
        # Backward compatibility: explicit None arguments are normalized.
        if self.ignore is None:
            self.ignore = Ignore.new()
        if self.indent is None:
            self.indent = Indent.new()
        if self.verbatim is None:
            self.verbatim = Verbatim.new()
|
||||
|
||||
|
||||
@dataclass
class Pattern:
    """Record whether a line contains certain patterns."""

    contains_env_begin: bool = False
    contains_env_end: bool = False
    contains_item: bool = False
    contains_splitting: bool = False
    contains_comment: bool = False

    @classmethod
    def new(cls, s: str):
        """Scan string `s` and record which patterns it contains."""
        has_splitting = RE_SPLITTING.search(s) is not None
        # The env/item flags are only checked when a splitting pattern
        # matched at all, so skip the substring tests otherwise.
        return cls(
            contains_env_begin=has_splitting and ENV_BEGIN in s,
            contains_env_end=has_splitting and ENV_END in s,
            contains_item=has_splitting and ITEM in s,
            contains_splitting=has_splitting,
            contains_comment='%' in s,
        )
|
||||
|
||||
|
||||
@dataclass
class Log:
    """A single formatter log message."""

    level: str  # "INFO", "WARN" or "TRACE"
    file: str  # name of the file being formatted
    message: str  # human-readable description
    linum_new: Optional[int] = None  # output line number, when relevant
    linum_old: Optional[int] = None  # source line number, when relevant
    line: Optional[str] = None  # offending line content, when relevant
|
||||
|
||||
|
||||
def find_comment_index(line: str, pattern: Pattern) -> Optional[int]:
    """Return the index of the first comment '%' in `line`, or None.

    A '%' that immediately follows a backslash (an escaped percent such as
    "\\%") does not start a comment.
    """
    if not pattern.contains_comment:
        return None

    in_command = False
    for idx, ch in enumerate(line):
        if ch == '\\':
            in_command = True
        elif in_command and not ch.isalpha():
            # First non-letter ends the command name; this character is
            # consumed by the command rather than starting a comment.
            in_command = False
        elif ch == '%' and not in_command:
            return idx

    return None
|
||||
|
||||
|
||||
def contains_ignore_skip(line: str) -> bool:
    """Return True if `line` ends with a single-line skip directive."""
    return line.endswith("% tex-fmt: skip")


def contains_ignore_begin(line: str) -> bool:
    """Return True if `line` ends with a directive opening an ignore block."""
    return line.endswith("% tex-fmt: off")


def contains_ignore_end(line: str) -> bool:
    """Return True if `line` ends with a directive closing an ignore block."""
    return line.endswith("% tex-fmt: on")
|
||||
|
||||
|
||||
def get_ignore(line: str, state: State, logs: List[Log], file: str, warn: bool) -> Ignore:
    """Determine whether a line should be ignored.

    Handles the three directives: single-line skip, block begin ("off") and
    block end ("on"). When `warn` is set, unbalanced begin/end directives
    are reported through `logs`.
    """

    def _warn(message: str) -> None:
        # Emit a warning tied to the current source/output line numbers.
        logs.append(
            Log(
                level="WARN",
                file=file,
                message=message,
                linum_new=state.linum_new,
                linum_old=state.linum_old,
                line=line,
            )
        )

    if contains_ignore_skip(line):
        # Skip affects only this line; the block state is untouched.
        return Ignore(actual=state.ignore.actual, visual=True)

    if contains_ignore_begin(line):
        if warn and state.ignore.actual:
            _warn("Cannot begin ignore block:")
        return Ignore(actual=True, visual=True)

    if contains_ignore_end(line):
        if warn and not state.ignore.actual:
            _warn("No ignore block to end.")
        return Ignore(actual=False, visual=True)

    # Plain line: inherit the enclosing block's ignored state.
    return Ignore(actual=state.ignore.actual, visual=state.ignore.actual)
|
||||
|
||||
|
||||
def get_verbatim_diff(line: str, pattern: Pattern) -> int:
    """Return the change in verbatim nesting depth caused by `line`."""
    if pattern.contains_env_begin and any(env in line for env in VERBATIMS_BEGIN):
        return 1
    if pattern.contains_env_end and any(env in line for env in VERBATIMS_END):
        return -1
    return 0
|
||||
|
||||
|
||||
def get_verbatim(
    line: str, state: State, logs: List[Log], file: str, warn: bool, pattern: Pattern
) -> Verbatim:
    """Determine whether a line is in a verbatim environment."""
    actual = state.verbatim.actual + get_verbatim_diff(line, pattern)
    # A line counts as verbatim if we are inside an environment either before
    # or after processing it, so the \begin/\end lines themselves count too.
    visual = actual > 0 or state.verbatim.actual > 0

    if warn and actual < 0:
        # More \end than \begin lines were seen for verbatim environments.
        logs.append(
            Log(
                level="WARN",
                file=file,
                message="Verbatim count is negative.",
                linum_new=state.linum_new,
                linum_old=state.linum_old,
                line=line,
            )
        )

    return Verbatim(actual=actual, visual=visual)
|
||||
|
||||
|
||||
def get_diff(line: str, pattern: Pattern, lists_begin: List[str], lists_end: List[str]) -> int:
    """Calculate the total indentation change due to the current line."""
    diff = 0

    if pattern.contains_env_begin and ENV_BEGIN in line:
        # \begin{document} carries no global indentation.
        if DOC_BEGIN in line:
            return 0
        # Environments indent once; list environments indent twice so that
        # their \item lines can sit one level deeper.
        diff += 2 if any(env in line for env in lists_begin) else 1
    elif pattern.contains_env_end and ENV_END in line:
        # \end{document} carries no global indentation.
        if DOC_END in line:
            return 0
        diff -= 2 if any(env in line for env in lists_end) else 1

    # Every opening delimiter indents; every closing delimiter dedents.
    diff += sum(1 if ch in OPENS else -1 if ch in CLOSES else 0 for ch in line)

    return diff
|
||||
|
||||
|
||||
def get_back(line: str, pattern: Pattern, state: State, lists_end: List[str]) -> int:
    """Calculate how many levels the current line itself is dedented."""
    # Nothing to dedent when no indentation has accumulated.
    if state.indent.actual == 0:
        return 0

    if pattern.contains_env_end and ENV_END in line:
        # \end{document} carries no global indentation.
        if DOC_END in line:
            return 0
        # List environments were double-indented, so they dedent twice.
        if any(env in line for env in lists_end):
            return 2
        return 1

    # \item lines sit one level shallower than the list body.
    if pattern.contains_item and ITEM in line:
        return 1

    return 0
|
||||
|
||||
|
||||
def get_indent(
    line: str,
    prev_indent: Indent,
    pattern: Pattern,
    state: State,
    lists_begin: List[str],
    lists_end: List[str],
) -> Indent:
    """Calculate the indent record for a line from the previous indent.

    `actual` accumulates the line's net depth change, while `visual` is the
    depth the line itself is printed with: the previous depth minus any
    dedent the line triggers, floored at zero.
    """
    return Indent(
        actual=prev_indent.actual + get_diff(line, pattern, lists_begin, lists_end),
        visual=max(0, prev_indent.actual - get_back(line, pattern, state, lists_end)),
    )
|
||||
|
||||
|
||||
def calculate_indent(
    line: str,
    state: State,
    logs: List[Log],
    file: str,
    args: Args,
    pattern: Pattern,
    lists_begin: List[str],
    lists_end: List[str],
) -> Indent:
    """Calculate the indent for a line and update `state` in place."""
    indent = get_indent(line, state.indent, pattern, state, lists_begin, lists_end)
    state.indent = indent

    # Track the most recent line printed at column zero; it is reported in
    # the end-of-file warning when indentation fails to return to zero.
    if indent.visual == 0:
        state.linum_last_zero_indent = state.linum_new

    return indent
|
||||
|
||||
|
||||
def apply_indent(line: str, indent: Indent, args: Args, indent_char: str) -> str:
    """Strip the line's own leading whitespace and re-indent it.

    Whitespace-only lines collapse to the empty string.
    """
    if not line.strip():
        return ""
    return indent_char * (indent.visual * args.tabsize) + line.lstrip()
|
||||
|
||||
|
||||
def needs_wrap(line: str, indent_length: int, args: Args) -> bool:
    """Return True when wrapping is enabled and the indented line is too long."""
    if not args.wrap:
        return False
    return len(line) + indent_length > args.wraplen
|
||||
|
||||
|
||||
def find_wrap_point(line: str, indent_length: int, args: Args) -> Optional[int]:
    """Find the index of the space where a long line should be broken.

    Candidate break points are unescaped spaces that follow at least one
    non-'%' character; the last candidate at or before the wrap boundary
    wins. Returns None when no suitable break point exists.
    """
    wrap_point = None
    seen_content = False
    prev_char = None
    wrap_boundary = args.wrapmin - indent_length

    for idx, ch in enumerate(line):
        # Stop scanning past the boundary once a candidate has been found.
        if idx + 1 > wrap_boundary and wrap_point is not None:
            break
        if ch == ' ' and prev_char != '\\':
            if seen_content:
                wrap_point = idx
        elif ch != '%':
            seen_content = True
        prev_char = ch

    return wrap_point
|
||||
|
||||
|
||||
def apply_wrap(
    line: str,
    indent_length: int,
    state: State,
    file: str,
    args: Args,
    logs: List[Log],
    pattern: Pattern,
) -> Optional[List[str]]:
    """Wrap a long line, returning [head, continuation-prefix, tail].

    Returns None (with a warning logged) when no valid wrap point exists.
    """

    def _log(level: str, message: str) -> None:
        logs.append(
            Log(
                level=level,
                file=file,
                message=message,
                linum_new=state.linum_new,
                linum_old=state.linum_old,
                line=line,
            )
        )

    if args.verbosity >= 3:  # Trace level
        _log("TRACE", "Wrapping long line.")

    wrap_point = find_wrap_point(line, indent_length, args)
    comment_index = find_comment_index(line, pattern)

    if wrap_point is None or wrap_point > args.wraplen:
        _log("WARN", "Line cannot be wrapped.")
        return None

    # When the break lands inside a trailing comment, the continuation line
    # must itself start as a comment to stay commented out.
    if comment_index is not None and wrap_point > comment_index:
        continuation_prefix = COMMENT_LINE_START
    else:
        continuation_prefix = TEXT_LINE_START

    return [line[:wrap_point], continuation_prefix, line[wrap_point + 1 :]]
|
||||
|
||||
|
||||
def needs_split(line: str, pattern: Pattern) -> bool:
    """Check if line contains content which should be split onto a new line."""
    # No splitting pattern sharing the line with earlier content -> no split.
    if not (pattern.contains_splitting and RE_SPLITTING_SHARED_LINE.search(line)):
        return False

    # A split point inside a trailing comment must be left alone; one before
    # the comment (or on a comment-free line) triggers a split.
    comment_index = find_comment_index(line, pattern)
    if comment_index is None:
        return True

    match = RE_SPLITTING_SHARED_LINE_CAPTURE.search(line)
    return not (match and match.start(2) > comment_index)
|
||||
|
||||
|
||||
def split_line(line: str, state: State, file: str, args: Args, logs: List[Log]) -> Tuple[str, str]:
    """Split `line` before the first splitting pattern it shares with text.

    Returns (content before the pattern, pattern plus remainder); when no
    pattern shares the line, returns (line, "").
    """
    match = RE_SPLITTING_SHARED_LINE_CAPTURE.search(line)
    if match is None:
        return line, ""

    if args.verbosity >= 3:  # Trace level
        logs.append(
            Log(
                level="TRACE",
                file=file,
                message="Placing environment on new line.",
                linum_new=state.linum_new,
                linum_old=state.linum_old,
                line=line,
            )
        )

    return match.group('prev'), match.group('env')
|
||||
|
||||
|
||||
def set_ignore_and_report(
    line: str, temp_state: State, logs: List[Log], file: str, pattern: Pattern
) -> bool:
    """Sets the ignore and verbatim flags in the given State based on line and returns whether line should be ignored."""
    # Order matters: get_ignore reads the previous ignore state before it is
    # replaced, and get_verbatim likewise reads the previous verbatim depth.
    temp_state.ignore = get_ignore(line, temp_state, logs, file, True)
    temp_state.verbatim = get_verbatim(line, temp_state, logs, file, True, pattern)

    # The line is skipped by the formatter when it is visually ignored or
    # inside a verbatim environment.
    return temp_state.verbatim.visual or temp_state.ignore.visual
|
||||
|
||||
|
||||
def clean_text(text: str, args: Args) -> str:
    """Normalize text: collapse blank-line runs, expand tabs, strip trailing spaces."""
    # Collapse three or more consecutive newlines into a single blank line.
    text = RE_NEWLINES.sub(f"{LINE_END}{LINE_END}", text)

    # Expand tabs unless tab is the configured indentation character.
    if args.tabchar != '\t':
        text = text.replace('\t', ' ' * args.tabsize)

    # Drop trailing spaces at line ends.
    return RE_TRAIL.sub(LINE_END, text)
|
||||
|
||||
|
||||
def remove_trailing_spaces(text: str) -> str:
    """Strip spaces that immediately precede a line ending."""
    return RE_TRAIL.sub(LINE_END, text)
|
||||
|
||||
|
||||
def remove_trailing_blank_lines(text: str) -> str:
    """Trim trailing whitespace from the file, leaving exactly one final newline."""
    return text.rstrip() + LINE_END
|
||||
|
||||
|
||||
def indents_return_to_zero(state: State) -> bool:
    """Report whether all indentation was closed by the end of the file."""
    return not state.indent.actual
|
||||
|
||||
|
||||
def format_latex(text: str) -> str:
    """Format LaTeX text with default formatting options.

    This is the main API function for formatting LaTeX text: it runs the
    internal formatter with a default `Args` configuration and discards the
    log messages.

    Args:
        text: LaTeX text to format

    Returns:
        Formatted LaTeX text
    """
    formatted, _logs = _format_latex(text, "input.tex", Args())
    return formatted.strip()
|
||||
|
||||
|
||||
def _format_latex(old_text: str, file: str, args: Args) -> Tuple[str, List[Log]]:
    """Internal function to format a LaTeX string.

    Processes the cleaned input line by line through a work queue: a line
    may be split (an environment moved onto a new line) or wrapped, in which
    case the produced pieces are pushed back onto the queue and re-processed
    before any further source lines are consumed.

    Args:
        old_text: Raw LaTeX source to format.
        file: File name used only in log messages.
        args: Formatter configuration.

    Returns:
        Tuple of (formatted text, list of log messages).
    """
    logs = []
    logs.append(Log(level="INFO", file=file, message="Formatting started."))

    # Clean the source file
    old_text = clean_text(old_text, args)
    old_lines = list(enumerate(old_text.splitlines(), 1))

    # Initialize
    state = State()
    queue = []  # (source line number, line text) pairs awaiting processing
    new_text = ""

    # Select the character used for indentation
    indent_char = '\t' if args.tabchar == '\t' else ' '

    # Get any extra environments to be indented as lists
    lists_begin = [f"\\begin{{{l}}}" for l in args.lists]
    lists_end = [f"\\end{{{l}}}" for l in args.lists]

    while True:
        if queue:
            linum_old, line = queue.pop(0)

            # Read the patterns present on this line
            pattern = Pattern.new(line)

            # Temporary state for working on this line; it is committed to
            # `state` only when the line is emitted (not when it is re-queued
            # by the wrapping branch below).
            temp_state = State(
                linum_old=linum_old,
                linum_new=state.linum_new,
                ignore=Ignore(state.ignore.actual, state.ignore.visual),
                indent=Indent(state.indent.actual, state.indent.visual),
                verbatim=Verbatim(state.verbatim.actual, state.verbatim.visual),
                linum_last_zero_indent=state.linum_last_zero_indent,
            )

            # If the line should not be ignored...
            if not set_ignore_and_report(line, temp_state, logs, file, pattern):
                # Check if the line should be split because of a pattern that should begin on a new line
                if needs_split(line, pattern):
                    # Split the line into two...
                    this_line, next_line = split_line(line, temp_state, file, args, logs)
                    # ...and queue the second part for formatting
                    if next_line:
                        queue.insert(0, (linum_old, next_line))
                    line = this_line

                # Calculate the indent based on the current state and the patterns in the line
                indent = calculate_indent(
                    line, temp_state, logs, file, args, pattern, lists_begin, lists_end
                )

                indent_length = indent.visual * args.tabsize

                # Wrap the line before applying the indent, and loop back if the line needed wrapping
                if needs_wrap(line.lstrip(), indent_length, args):
                    wrapped_lines = apply_wrap(
                        line.lstrip(), indent_length, temp_state, file, args, logs, pattern
                    )
                    if wrapped_lines:
                        # Re-queue tail then head so the head is processed
                        # next; state is NOT committed for this iteration.
                        this_line, next_line_start, next_line = wrapped_lines
                        queue.insert(0, (linum_old, next_line_start + next_line))
                        queue.insert(0, (linum_old, this_line))
                        continue

                # Lastly, apply the indent if the line didn't need wrapping
                line = apply_indent(line, indent, args, indent_char)

            # Add line to new text
            state = temp_state
            new_text += line + LINE_END
            state.linum_new += 1
        elif old_lines:
            # Queue drained: feed the next source line into the queue.
            linum_old, line = old_lines.pop(0)
            queue.append((linum_old, line))
        else:
            break

    if not indents_return_to_zero(state):
        msg = f"Indent does not return to zero. Last non-indented line is line {state.linum_last_zero_indent}"
        logs.append(Log(level="WARN", file=file, message=msg))

    new_text = remove_trailing_spaces(new_text)
    new_text = remove_trailing_blank_lines(new_text)
    logs.append(Log(level="INFO", file=file, message="Formatting complete."))

    return new_text, logs
|
||||
241
texteller/api/inference.py
Normal file
241
texteller/api/inference.py
Normal file
@@ -0,0 +1,241 @@
|
||||
import re
|
||||
import time
|
||||
from collections import Counter
|
||||
from typing import Literal
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from onnxruntime import InferenceSession
|
||||
from optimum.onnxruntime import ORTModelForVision2Seq
|
||||
from transformers import GenerationConfig, RobertaTokenizerFast
|
||||
|
||||
from texteller.constants import MAX_TOKEN_SIZE
|
||||
from texteller.logger import get_logger
|
||||
from texteller.paddleocr import predict_det, predict_rec
|
||||
from texteller.types import Bbox, TexTellerModel
|
||||
from texteller.utils import (
|
||||
bbox_merge,
|
||||
get_device,
|
||||
mask_img,
|
||||
readimgs,
|
||||
remove_style,
|
||||
slice_from_image,
|
||||
split_conflict,
|
||||
transform,
|
||||
add_newlines,
|
||||
)
|
||||
|
||||
from .detection import latex_detect
|
||||
from .format import format_latex
|
||||
from .katex import to_katex
|
||||
|
||||
# Module-level logger shared by the inference helpers.
# NOTE(review): load.py calls get_logger(__name__) but this module calls it
# with no arguments — confirm whether the default logger name is intended here.
_logger = get_logger()
|
||||
|
||||
|
||||
def img2latex(
    model: TexTellerModel,
    tokenizer: RobertaTokenizerFast,
    images: list[str] | list[np.ndarray],
    device: torch.device | None = None,
    out_format: Literal["latex", "katex"] = "latex",
    keep_style: bool = False,
    max_tokens: int = MAX_TOKEN_SIZE,
    num_beams: int = 1,
    no_repeat_ngram_size: int = 0,
) -> list[str]:
    """
    Convert images to LaTeX or KaTeX formatted strings.

    Args:
        model: The TexTeller or ORTModelForVision2Seq model instance
        tokenizer: The tokenizer for the model
        images: List of image paths or numpy arrays (RGB format)
        device: The torch device to use (defaults to available GPU or CPU)
        out_format: Output format, either "latex" or "katex"
        keep_style: Whether to keep the style of the LaTeX
        max_tokens: Maximum number of tokens to generate
        num_beams: Number of beams for beam search
        no_repeat_ngram_size: Size of n-grams to prevent repetition (0 disables)

    Returns:
        List of LaTeX or KaTeX strings corresponding to each input image

    Example usage:
        >>> import torch
        >>> from texteller import load_model, load_tokenizer, img2latex

        >>> model = load_model(model_path=None, use_onnx=False)
        >>> tokenizer = load_tokenizer(tokenizer_path=None)
        >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        >>> res = img2latex(model, tokenizer, ["path/to/image.png"], device=device, out_format="katex")
    """
    assert isinstance(images, list)
    assert len(images) > 0

    if device is None:
        device = get_device()

    if device.type != model.device.type:
        if isinstance(model, ORTModelForVision2Seq):
            # ONNX sessions are bound to a device at load time and cannot be
            # moved afterwards, so fall back to the model's own device.
            _logger.warning(
                f"Onnxruntime device mismatch: detected {str(device)} but model is on {str(model.device)}, using {str(model.device)} instead"
            )
        else:
            model = model.to(device=device)

    # Accept either file paths (loaded here) or ready-made RGB numpy arrays.
    if isinstance(images[0], str):
        images = readimgs(images)
    else:
        assert isinstance(images[0], np.ndarray)

    images = transform(images)
    pixel_values = torch.stack(images)

    # Greedy/beam decoding; sampling is disabled for deterministic output.
    generate_config = GenerationConfig(
        max_new_tokens=max_tokens,
        num_beams=num_beams,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        no_repeat_ngram_size=no_repeat_ngram_size,
    )
    pred = model.generate(
        pixel_values.to(model.device),
        generation_config=generate_config,
    )

    res = tokenizer.batch_decode(pred, skip_special_tokens=True)

    if out_format == "katex":
        res = [to_katex(r) for r in res]

    if not keep_style:
        res = [remove_style(r) for r in res]

    res = [format_latex(r) for r in res]
    res = [add_newlines(r) for r in res]
    return res
|
||||
|
||||
|
||||
def paragraph2md(
    img_path: str,
    latexdet_model: InferenceSession,
    textdet_model: predict_det.TextDetector,
    textrec_model: predict_rec.TextRecognizer,
    latexrec_model: TexTellerModel,
    tokenizer: RobertaTokenizerFast,
    device: torch.device | None = None,
    num_beams: int = 1,
) -> str:
    """
    Recognize a mixed text/formula image and render it as Markdown.

    Args:
        img_path: Path to the input image on disk.
        latexdet_model: ONNX session that detects LaTeX formula regions.
        textdet_model: PaddleOCR text-region detector.
        textrec_model: PaddleOCR text recognizer.
        latexrec_model: TexTeller model converting formula crops to LaTeX.
        tokenizer: Tokenizer paired with ``latexrec_model``.
        device: Torch device for formula recognition (auto-detected if None).
        num_beams: Beam-search width for formula recognition.

    Returns:
        Markdown string with inline formulas as ``$...$`` and isolated
        formulas as ``$$...$$`` blocks.
    """
    img = cv2.imread(img_path)
    # Estimate the page background color as the most common corner pixel.
    corners = [tuple(img[0, 0]), tuple(img[0, -1]), tuple(img[-1, 0]), tuple(img[-1, -1])]
    bg_color = np.array(Counter(corners).most_common(1)[0][0])

    start_time = time.time()
    latex_bboxes = latex_detect(img_path, latexdet_model)
    end_time = time.time()
    _logger.info(f"latex_det_model time: {end_time - start_time:.2f}s")
    latex_bboxes = sorted(latex_bboxes)
    latex_bboxes = bbox_merge(latex_bboxes)
    # Paint formula regions with the background color so the text detector
    # does not pick them up a second time.
    masked_img = mask_img(img, latex_bboxes, bg_color)

    start_time = time.time()
    det_prediction, _ = textdet_model(masked_img)
    end_time = time.time()
    _logger.info(f"ocr_det_model time: {end_time - start_time:.2f}s")
    # Each detection is a quadrilateral of 4 points; convert to Bbox(x, y, h, w).
    # NOTE(review): assumes axis-aligned quads with p[0]=top-left,
    # p[1]=top-right, p[3]=bottom-left — confirm against the detector's
    # point ordering.
    ocr_bboxes = [
        Bbox(
            p[0][0],
            p[0][1],
            p[3][1] - p[0][1],
            p[1][0] - p[0][0],
            label="text",
            confidence=None,
            content=None,
        )
        for p in det_prediction
    ]

    ocr_bboxes = sorted(ocr_bboxes)
    ocr_bboxes = bbox_merge(ocr_bboxes)
    # Resolve overlaps between text boxes and formula boxes, then keep
    # only boxes still labeled as text.
    ocr_bboxes = split_conflict(ocr_bboxes, latex_bboxes)
    ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes))

    sliced_imgs: list[np.ndarray] = slice_from_image(img, ocr_bboxes)
    start_time = time.time()
    rec_predictions, _ = textrec_model(sliced_imgs)
    end_time = time.time()
    _logger.info(f"ocr_rec_model time: {end_time - start_time:.2f}s")

    # Attach each recognized text string to its bounding box.
    assert len(rec_predictions) == len(ocr_bboxes)
    for content, bbox in zip(rec_predictions, ocr_bboxes):
        bbox.content = content[0]

    # Crop each detected formula region and run formula recognition in batch.
    latex_imgs = []
    for bbox in latex_bboxes:
        latex_imgs.append(img[bbox.p.y : bbox.p.y + bbox.h, bbox.p.x : bbox.p.x + bbox.w])
    start_time = time.time()
    latex_rec_res = img2latex(
        model=latexrec_model,
        tokenizer=tokenizer,
        images=latex_imgs,
        num_beams=num_beams,
        out_format="katex",
        device=device,
        keep_style=False,
    )
    end_time = time.time()
    _logger.info(f"latex_rec_model time: {end_time - start_time:.2f}s")

    # Wrap formulas: inline ("embedding") as $...$, display ("isolated")
    # as $$...$$ on their own lines.
    for bbox, content in zip(latex_bboxes, latex_rec_res):
        if bbox.label == "embedding":
            bbox.content = " $" + content + "$ "
        elif bbox.label == "isolated":
            bbox.content = "\n\n" + r"$$" + content + r"$$" + "\n\n"

    # Interleave text and formula boxes in reading order.
    bboxes = sorted(ocr_bboxes + latex_bboxes)
    if bboxes == []:
        return ""

    md = ""
    # Sentinel "previous" box so the first iteration never triggers the
    # tag-reattachment branch below.
    prev = Bbox(bboxes[0].p.x, bboxes[0].p.y, -1, -1, label="guard")
    for curr in bboxes:
        # Add the formula number back to the isolated formula
        if prev.label == "isolated" and curr.label == "text" and prev.same_row(curr):
            curr.content = curr.content.strip()
            # Strip surrounding parentheses from an equation number like "(3)".
            if curr.content.startswith("(") and curr.content.endswith(")"):
                curr.content = curr.content[1:-1]

            # md currently ends with '$$\n\n' (4 chars); insert \tag{...}
            # just before the closing '$$'. md[:-5] additionally drops the
            # '}' of an already-present \tag so numbers can be appended.
            if re.search(r"\\tag\{.*\}$", md[:-4]) is not None:
                # in case of multiple tag
                md = md[:-5] + f", {curr.content}" + "}" + md[-4:]
            else:
                md = md[:-4] + f"\\tag{{{curr.content}}}" + md[-4:]
            continue

        # Separate boxes on different rows with a space.
        if not prev.same_row(curr):
            md += " "

        if curr.label == "embedding":
            # remove the bold effect from inline formulas
            curr.content = remove_style(curr.content)

            # change split environment into aligned
            curr.content = curr.content.replace(r"\begin{split}", r"\begin{aligned}")
            curr.content = curr.content.replace(r"\end{split}", r"\end{aligned}")

            # remove extra spaces (keeping only one)
            curr.content = re.sub(r" +", " ", curr.content)
            assert curr.content.startswith("$") and curr.content.endswith("$")
            curr.content = " $" + curr.content.strip("$") + "$ "
        md += curr.content
        prev = curr

    return md.strip()
|
||||
118
texteller/api/katex.py
Normal file
118
texteller/api/katex.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import re
|
||||
|
||||
from ..utils.latex import change_all
|
||||
from .format import format_latex
|
||||
|
||||
|
||||
def _rm_dollar_surr(content):
|
||||
pattern = re.compile(r'\\[a-zA-Z]+\$.*?\$|\$.*?\$')
|
||||
matches = pattern.findall(content)
|
||||
|
||||
for match in matches:
|
||||
if not re.match(r'\\[a-zA-Z]+', match):
|
||||
new_match = match.strip('$')
|
||||
content = content.replace(match, ' ' + new_match + ' ')
|
||||
|
||||
return content
|
||||
|
||||
|
||||
def to_katex(formula: str) -> str:
    """Normalize a LaTeX formula into KaTeX-renderable markup.

    Strips layout-only TeX constructs (boxes, manual spacing, font-size
    arguments), rewrites bold/emphasis commands into KaTeX equivalents,
    removes brace groups after sizing delimiters, merges consecutive
    ``\\text{...}`` runs, and collapses whitespace.

    Args:
        formula: Raw LaTeX string as produced by the recognition model.

    Returns:
        A cleaned, formatted string suitable for KaTeX rendering.
    """
    res = formula
    # remove mbox surrounding
    res = change_all(res, r'\mbox ', r' ', r'{', r'}', r'', r'')
    res = change_all(res, r'\mbox', r' ', r'{', r'}', r'', r'')
    # remove hbox surrounding (drop any explicit width first)
    res = re.sub(r'\\hbox to ?-? ?\d+\.\d+(pt)?\{', r'\\hbox{', res)
    res = change_all(res, r'\hbox', r' ', r'{', r'}', r'', r' ')
    # remove raise offsets
    res = re.sub(r'\\raise ?-? ?\d+\.\d+(pt)?', r' ', res)
    # remove makebox (drop the width argument first)
    res = re.sub(r'\\makebox ?\[\d+\.\d+(pt)?\]\{', r'\\makebox{', res)
    res = change_all(res, r'\makebox', r' ', r'{', r'}', r'', r' ')
    # remove vbox, raisebox and scalebox surroundings (drop size arguments)
    res = re.sub(r'\\raisebox\{-? ?\d+\.\d+(pt)?\}\{', r'\\raisebox{', res)
    res = re.sub(r'\\scalebox\{-? ?\d+\.\d+(pt)?\}\{', r'\\scalebox{', res)
    res = change_all(res, r'\scalebox', r' ', r'{', r'}', r'', r' ')
    res = change_all(res, r'\raisebox', r' ', r'{', r'}', r'', r' ')
    res = change_all(res, r'\vbox', r' ', r'{', r'}', r'', r' ')

    # Re-delimit font-size switches from {...} to $...$ (command kept as-is).
    size_commands = [
        r'\Huge',
        r'\huge',
        r'\LARGE',
        r'\Large',
        r'\large',
        r'\normalsize',
        r'\small',
        r'\footnotesize',
        r'\tiny',
    ]
    for cmd in size_commands:
        res = change_all(res, cmd, cmd, r'$', r'$', '{', '}')
    # Map bold/emphasis commands onto KaTeX-friendly equivalents.
    res = change_all(res, r'\mathbf', r'\bm', r'{', r'}', r'{', r'}')
    res = change_all(res, r'\boldmath ', r'\bm', r'{', r'}', r'{', r'}')
    res = change_all(res, r'\boldmath', r'\bm', r'{', r'}', r'{', r'}')
    res = change_all(res, r'\boldmath ', r'\bm', r'$', r'$', r'{', r'}')
    res = change_all(res, r'\boldmath', r'\bm', r'$', r'$', r'{', r'}')
    res = change_all(res, r'\scriptsize', r'\scriptsize', r'$', r'$', r'{', r'}')
    res = change_all(res, r'\emph', r'\textit', r'{', r'}', r'{', r'}')
    res = change_all(res, r'\emph ', r'\textit', r'{', r'}', r'{', r'}')

    # remove bold command
    res = change_all(res, r'\bm', r' ', r'{', r'}', r'', r'')

    # Strip the brace group following sizing delimiters such as \left / \big.
    delimiter_commands = [
        r'\left',
        r'\middle',
        r'\right',
        r'\big',
        r'\Big',
        r'\bigg',
        r'\Bigg',
        r'\bigl',
        r'\Bigl',
        r'\biggl',
        r'\Biggl',
        r'\bigm',
        r'\Bigm',
        r'\biggm',
        r'\Biggm',
        r'\bigr',
        r'\Bigr',
        r'\biggr',
        r'\Biggr',
    ]
    for cmd in delimiter_commands:
        res = change_all(res, cmd, cmd, r'{', r'}', r'', r'')

    # Flatten display-math blocks \[...\] into inline content + \newline.
    res = re.sub(r'\\\[(.*?)\\\]', r'\1\\newline', res)

    # Drop a trailing \newline introduced by the substitution above.
    res = res.removesuffix(r'\newline')

    # Collapse explicit TeX spacing commands into a single space.
    res = re.sub(r'(\\,){1,}', ' ', res)
    res = re.sub(r'(\\!){1,}', ' ', res)
    res = re.sub(r'(\\;){1,}', ' ', res)
    res = re.sub(r'(\\:){1,}', ' ', res)
    res = re.sub(r'\\vspace\{.*?}', '', res)

    # merge consecutive text
    def merge_texts(match):
        texts = match.group(0)
        merged_content = ''.join(re.findall(r'\\text\{([^}]*)\}', texts))
        return f'\\text{{{merged_content}}}'

    res = re.sub(r'(\\text\{[^}]*\}\s*){2,}', merge_texts, res)

    res = res.replace(r'\bf ', '')
    res = _rm_dollar_surr(res)

    # remove extra spaces (keeping only one)
    res = re.sub(r' +', ' ', res)

    # format latex
    res = res.strip()
    res = format_latex(res)

    return res
|
||||
66
texteller/api/load.py
Normal file
66
texteller/api/load.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from pathlib import Path
|
||||
|
||||
import wget
|
||||
from onnxruntime import InferenceSession
|
||||
from transformers import RobertaTokenizerFast
|
||||
|
||||
from texteller.constants import LATEX_DET_MODEL_URL, TEXT_DET_MODEL_URL, TEXT_REC_MODEL_URL
|
||||
from texteller.globals import Globals
|
||||
from texteller.logger import get_logger
|
||||
from texteller.models import TexTeller
|
||||
from texteller.paddleocr import predict_det, predict_rec
|
||||
from texteller.paddleocr.utility import parse_args
|
||||
from texteller.utils import cuda_available, mkdir, resolve_path
|
||||
from texteller.types import TexTellerModel
|
||||
|
||||
# Module-level logger named after this module.
_logger = get_logger(__name__)
|
||||
|
||||
|
||||
def load_model(model_dir: str | None = None, use_onnx: bool = False) -> TexTellerModel:
    """Load the TexTeller formula-recognition model.

    Args:
        model_dir: Checkpoint directory, passed through to
            ``TexTeller.from_pretrained`` (``None`` uses its default).
        use_onnx: Load the ONNX-runtime variant instead of the PyTorch one.
    """
    model = TexTeller.from_pretrained(model_dir, use_onnx=use_onnx)
    return model
|
||||
|
||||
|
||||
def load_tokenizer(tokenizer_dir: str | None = None) -> RobertaTokenizerFast:
    """Load the tokenizer paired with the TexTeller model.

    Args:
        tokenizer_dir: Tokenizer directory, passed through to
            ``TexTeller.get_tokenizer`` (``None`` uses its default).
    """
    tokenizer = TexTeller.get_tokenizer(tokenizer_dir)
    return tokenizer
|
||||
|
||||
|
||||
def load_latexdet_model() -> InferenceSession:
    """Download (if not cached) and load the LaTeX-region detection ONNX model."""
    model_path = _maybe_download(LATEX_DET_MODEL_URL)
    provider = "CUDAExecutionProvider" if cuda_available() else "CPUExecutionProvider"
    return InferenceSession(resolve_path(model_path), providers=[provider])
|
||||
|
||||
|
||||
def load_textrec_model() -> predict_rec.TextRecognizer:
    """Download (if not cached) and construct the PaddleOCR text recognizer."""
    model_path = _maybe_download(TEXT_REC_MODEL_URL)
    args = parse_args()
    args.use_onnx = True
    args.rec_model_dir = resolve_path(model_path)
    args.use_gpu = cuda_available()
    return predict_rec.TextRecognizer(args)
|
||||
|
||||
|
||||
def load_textdet_model() -> predict_det.TextDetector:
    """Download (if not cached) and construct the PaddleOCR text detector."""
    model_path = _maybe_download(TEXT_DET_MODEL_URL)
    args = parse_args()
    args.use_onnx = True
    args.det_model_dir = resolve_path(model_path)
    args.use_gpu = cuda_available()
    return predict_det.TextDetector(args)
|
||||
|
||||
|
||||
def _maybe_download(url: str, dirpath: str | None = None, force: bool = False) -> Path:
    """Return the local path of *url*'s file, downloading it if necessary.

    Args:
        url: Remote file URL; the URL basename is used as the local filename.
        dirpath: Target directory (defaults to the package cache directory).
        force: Re-download even if the file already exists.

    Returns:
        Path to the (possibly just downloaded) local file.
    """
    if dirpath is None:
        dirpath = Globals().cache_dir
    mkdir(dirpath)

    fname = Path(url).name
    fpath = Path(dirpath) / fname
    if not fpath.exists() or force:
        # python-wget does not overwrite an existing file — it writes to a
        # deduplicated name instead, which would leave the stale copy in
        # place on a forced re-download. Remove it first.
        fpath.unlink(missing_ok=True)
        _logger.info(f"Downloading {fname} from {url} to {fpath}")
        wget.download(url, resolve_path(fpath))

    return fpath
|
||||
Reference in New Issue
Block a user