[refactor] Init

OleehyO
2025-04-16 14:23:02 +00:00
parent 0e32f3f3bf
commit 06edd104e2
101 changed files with 1854 additions and 2758 deletions

1
texteller/__init__.py Normal file
View File

@@ -0,0 +1 @@
from texteller.api import *

24
texteller/api/__init__.py Normal file
View File

@@ -0,0 +1,24 @@
from .detection import latex_detect
from .format import format_latex
from .inference import img2latex, paragraph2md
from .katex import to_katex
from .load import (
load_latexdet_model,
load_model,
load_textdet_model,
load_textrec_model,
load_tokenizer,
)
__all__ = [
"to_katex",
"format_latex",
"img2latex",
"paragraph2md",
"load_model",
"load_tokenizer",
"load_latexdet_model",
"load_textrec_model",
"load_textdet_model",
"latex_detect",
]
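Taken together these re-exports define the new top-level API. A minimal usage sketch, assuming the default Hugging Face checkpoints and a hypothetical local image "formula.png":

from texteller import img2latex, load_model, load_tokenizer

# Loaders fall back to the project's Hugging Face checkpoints when no path is given
model = load_model()
tokenizer = load_tokenizer()

# "formula.png" is a placeholder path; out_format="katex" also runs the KaTeX post-processing
print(img2latex(model, tokenizer, ["formula.png"], out_format="katex")[0])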

View File

@@ -0,0 +1,4 @@
from .ngram import DetectRepeatingNgramCriteria
__all__ = ["DetectRepeatingNgramCriteria"]

View File

@@ -1,16 +1,8 @@
import torch
-import numpy as np
-from transformers import RobertaTokenizerFast, GenerationConfig, StoppingCriteria
-from typing import List, Union
-from .transforms import inference_transform
-from .helpers import convert2rgb
-from ..model.TexTeller import TexTeller
-from ...globals import MAX_TOKEN_SIZE
+from transformers import StoppingCriteria
-class EfficientDetectRepeatingNgramCriteria(StoppingCriteria):
+class DetectRepeatingNgramCriteria(StoppingCriteria):
"""
Stops generation efficiently if any n-gram repeats.
@@ -69,48 +61,3 @@ class EfficientDetectRepeatingNgramCriteria(StoppingCriteria):
# It's a new n-gram, add it to the set and continue
self.seen_ngrams.add(last_ngram_tuple)
return False # Continue generation
def inference(
model: TexTeller,
tokenizer: RobertaTokenizerFast,
imgs: Union[List[str], List[np.ndarray]],
accelerator: str = 'cpu',
num_beams: int = 1,
max_tokens=None,
) -> List[str]:
if imgs == []:
return []
if hasattr(model, 'eval'):
# not onnx session, turn model.eval()
model.eval()
if isinstance(imgs[0], str):
imgs = convert2rgb(imgs)
else: # already numpy array(rgb format)
assert isinstance(imgs[0], np.ndarray)
imgs = imgs
imgs = inference_transform(imgs)
pixel_values = torch.stack(imgs)
if hasattr(model, 'eval'):
# not onnx session, move weights to device
model = model.to(accelerator)
pixel_values = pixel_values.to(accelerator)
generate_config = GenerationConfig(
max_new_tokens=MAX_TOKEN_SIZE if max_tokens is None else max_tokens,
num_beams=num_beams,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
bos_token_id=tokenizer.bos_token_id,
# no_repeat_ngram_size=10,
)
pred = model.generate(
pixel_values.to(model.device),
generation_config=generate_config,
# stopping_criteria=[EfficientDetectRepeatingNgramCriteria(20)],
)
res = tokenizer.batch_decode(pred, skip_special_tokens=True)
return res
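The renamed DetectRepeatingNgramCriteria is still a standard transformers StoppingCriteria, so it can be re-attached to generation the way the removed code's commented-out line suggested. A sketch, assuming model, pixel_values and generate_config are prepared as in img2latex later in this commit, and that DetectRepeatingNgramCriteria is imported from the new package above (its module path is not named in this diff); the n-gram size 20 simply mirrors the old comment:

from transformers import StoppingCriteriaList

# Stop decoding as soon as any 20-gram repeats (guards against degenerate loops)
pred = model.generate(
    pixel_values.to(model.device),
    generation_config=generate_config,
    stopping_criteria=StoppingCriteriaList([DetectRepeatingNgramCriteria(20)]),
)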

View File

@@ -0,0 +1,3 @@
from .detect import latex_detect
__all__ = ["latex_detect"]

View File

@@ -0,0 +1,48 @@
from typing import List
from onnxruntime import InferenceSession
from texteller.types import Bbox
from .preprocess import Compose
_config = {
"mode": "paddle",
"draw_threshold": 0.5,
"metric": "COCO",
"use_dynamic_shape": False,
"arch": "DETR",
"min_subgraph_size": 3,
"preprocess": [
{"interp": 2, "keep_ratio": False, "target_size": [1600, 1600], "type": "Resize"},
{
"mean": [0.0, 0.0, 0.0],
"norm_type": "none",
"std": [1.0, 1.0, 1.0],
"type": "NormalizeImage",
},
{"type": "Permute"},
],
"label_list": ["isolated", "embedding"],
}
def latex_detect(img_path: str, predictor: InferenceSession) -> List[Bbox]:
transforms = Compose(_config["preprocess"])
inputs = transforms(img_path)
inputs_name = [var.name for var in predictor.get_inputs()]
inputs = {k: inputs[k][None,] for k in inputs_name}
outputs = predictor.run(output_names=None, input_feed=inputs)[0]
res = []
for output in outputs:
cls_name = _config["label_list"][int(output[0])]
score = output[1]
xmin = int(max(output[2], 0))
ymin = int(max(output[3], 0))
xmax = int(output[4])
ymax = int(output[5])
if score > 0.5:
res.append(Bbox(xmin, ymin, ymax - ymin, xmax - xmin, cls_name, score))
return res
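A minimal sketch of driving the detector end to end (load_latexdet_model is added later in this commit and downloads the RT-DETR ONNX weights on first use; "page.png" is a placeholder path):

from texteller import latex_detect, load_latexdet_model

session = load_latexdet_model()            # onnxruntime.InferenceSession
for bbox in latex_detect("page.png", session):
    # label is "isolated" or "embedding", per _config["label_list"]
    print(bbox.label, bbox.confidence)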

View File

@@ -0,0 +1,161 @@
import copy
import cv2
import numpy as np
def decode_image(img_path):
if isinstance(img_path, str):
with open(img_path, "rb") as f:
im_read = f.read()
data = np.frombuffer(im_read, dtype="uint8")
else:
assert isinstance(img_path, np.ndarray)
data = img_path
im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
img_info = {
"im_shape": np.array(im.shape[:2], dtype=np.float32),
"scale_factor": np.array([1.0, 1.0], dtype=np.float32),
}
return im, img_info
class Resize(object):
"""resize image by target_size and max_size
Args:
target_size (int): the target size of image
keep_ratio (bool): whether keep_ratio or not, default true
interp (int): method of resize
"""
def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
if isinstance(target_size, int):
target_size = [target_size, target_size]
self.target_size = target_size
self.keep_ratio = keep_ratio
self.interp = interp
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
assert len(self.target_size) == 2
assert self.target_size[0] > 0 and self.target_size[1] > 0
im_channel = im.shape[2]
im_scale_y, im_scale_x = self.generate_scale(im)
im = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=self.interp)
im_info["im_shape"] = np.array(im.shape[:2]).astype("float32")
im_info["scale_factor"] = np.array([im_scale_y, im_scale_x]).astype("float32")
return im, im_info
def generate_scale(self, im):
"""
Args:
im (np.ndarray): image (np.ndarray)
Returns:
im_scale_x: the resize ratio of X
im_scale_y: the resize ratio of Y
"""
origin_shape = im.shape[:2]
im_c = im.shape[2]
if self.keep_ratio:
im_size_min = np.min(origin_shape)
im_size_max = np.max(origin_shape)
target_size_min = np.min(self.target_size)
target_size_max = np.max(self.target_size)
im_scale = float(target_size_min) / float(im_size_min)
if np.round(im_scale * im_size_max) > target_size_max:
im_scale = float(target_size_max) / float(im_size_max)
im_scale_x = im_scale
im_scale_y = im_scale
else:
resize_h, resize_w = self.target_size
im_scale_y = resize_h / float(origin_shape[0])
im_scale_x = resize_w / float(origin_shape[1])
return im_scale_y, im_scale_x
class NormalizeImage(object):
"""normalize image
Args:
mean (list): im - mean
std (list): im / std
is_scale (bool): whether need im / 255
norm_type (str): type in ['mean_std', 'none']
"""
def __init__(self, mean, std, is_scale=True, norm_type="mean_std"):
self.mean = mean
self.std = std
self.is_scale = is_scale
self.norm_type = norm_type
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im = im.astype(np.float32, copy=False)
if self.is_scale:
scale = 1.0 / 255.0
im *= scale
if self.norm_type == "mean_std":
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
im -= mean
im /= std
return im, im_info
class Permute(object):
"""permute image
Args:
to_bgr (bool): whether convert RGB to BGR
channel_first (bool): whether convert HWC to CHW
"""
def __init__(
self,
):
super(Permute, self).__init__()
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im = im.transpose((2, 0, 1)).copy()
return im, im_info
class Compose:
def __init__(self, transforms):
self.transforms = []
for op_info in transforms:
new_op_info = op_info.copy()
op_type = new_op_info.pop("type")
self.transforms.append(eval(op_type)(**new_op_info))
def __call__(self, img_path):
img, im_info = decode_image(img_path)
for t in self.transforms:
img, im_info = t(img, im_info)
inputs = copy.deepcopy(im_info)
inputs["image"] = img
return inputs
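Compose simply chains the configured transform classes by name and returns the feed dict the ONNX detector expects. A sketch using the preprocess list from the detection config shown earlier (run in detect.py's scope where _config is defined; "page.png" is a placeholder):

transforms = Compose(_config["preprocess"])   # Resize -> NormalizeImage -> Permute
inputs = transforms("page.png")
print(sorted(inputs.keys()))                  # ['im_shape', 'image', 'scale_factor']
print(inputs["image"].shape)                  # (3, 1600, 1600) after the CHW permute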

View File

@@ -5,9 +5,8 @@ Based on the Rust implementation at https://github.com/WGUNDERWOOD/tex-fmt
"""
import re
-import argparse
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Dict, Set
+from typing import List, Optional, Tuple
# Constants
LINE_END = "\n"
@@ -49,7 +48,7 @@ RE_SPLITTING_SHARED_LINE_CAPTURE = re.compile(f"(?P<prev>\\S.*?)(?P<env>{SPLITTI
@dataclass
class Args:
-"""Command line arguments and configuration."""
+"""Formatter configuration."""
tabchar: str = " "
tabsize: int = 4
@@ -542,13 +541,29 @@ def indents_return_to_zero(state: State) -> bool:
return state.indent.actual == 0
-def format_latex(
-old_text: str, file: str = "input.tex", args: Optional[Args] = None
-) -> Tuple[str, List[Log]]:
-"""Central function to format a LaTeX string."""
-if args is None:
-args = Args()
+def format_latex(text: str) -> str:
+"""Format LaTeX text with default formatting options.
+This is the main API function for formatting LaTeX text.
+It uses pre-defined default values for all formatting parameters.
+Args:
+text: LaTeX text to format
+Returns:
+Formatted LaTeX text
+"""
+# Use default configuration
+args = Args()
+file = "input.tex"
+# Format and return only the text
+formatted_text, _ = _format_latex(text, file, args)
+return formatted_text.strip()
+def _format_latex(old_text: str, file: str, args: Args) -> Tuple[str, List[Log]]:
+"""Internal function to format a LaTeX string."""
logs = []
logs.append(Log(level="INFO", file=file, message="Formatting started."))
@@ -636,63 +651,3 @@ def format_latex(
logs.append(Log(level="INFO", file=file, message="Formatting complete."))
return new_text, logs
def main():
"""Command-line entry point."""
parser = argparse.ArgumentParser(description="Format LaTeX files")
parser.add_argument("file", help="LaTeX file to format")
parser.add_argument(
"--tabchar",
choices=["space", "tab"],
default="space",
help="Character to use for indentation",
)
parser.add_argument("--tabsize", type=int, default=4, help="Number of spaces per indent level")
parser.add_argument("--wrap", action="store_true", help="Enable line wrapping")
parser.add_argument("--wraplen", type=int, default=80, help="Maximum line length")
parser.add_argument(
"--wrapmin", type=int, default=40, help="Minimum line length before wrapping"
)
parser.add_argument(
"--lists", nargs="+", default=[], help="Additional environments to indent as lists"
)
parser.add_argument("--verbose", "-v", action="count", default=0, help="Increase verbosity")
parser.add_argument("--output", "-o", help="Output file (default: overwrite input)")
args_parsed = parser.parse_args()
# Convert command line args to our Args class
args = Args(
tabchar="\t" if args_parsed.tabchar == "tab" else " ",
tabsize=args_parsed.tabsize,
wrap=args_parsed.wrap,
wraplen=args_parsed.wraplen,
wrapmin=args_parsed.wrapmin,
lists=args_parsed.lists,
verbosity=args_parsed.verbose,
)
# Read input file
with open(args_parsed.file, "r", encoding="utf-8") as f:
text = f.read()
# Format the text
formatted_text, logs = format_latex(text, args_parsed.file, args)
# Print logs if verbose
if args.verbosity > 0:
for log in logs:
if log.linum_new is not None:
print(f"{log.level} {log.file}:{log.linum_new}:{log.linum_old}: {log.message}")
else:
print(f"{log.level} {log.file}: {log.message}")
# Write output
output_file = args_parsed.output or args_parsed.file
with open(output_file, "w", encoding="utf-8") as f:
f.write(formatted_text)
if __name__ == "__main__":
main()
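With the argparse entry point removed, format_latex is now a single-argument library call re-exported from the package root. A minimal sketch:

from texteller import format_latex

# The defaults (4-space indent, no wrapping) come from Args(); only the formatted text is returned
print(format_latex(r"\begin{aligned}x &= 1 \\ y &= 2\end{aligned}"))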

241
texteller/api/inference.py Normal file
View File

@@ -0,0 +1,241 @@
import re
import time
from collections import Counter
from typing import Literal
import cv2
import numpy as np
import torch
from onnxruntime import InferenceSession
from optimum.onnxruntime import ORTModelForVision2Seq
from transformers import GenerationConfig, RobertaTokenizerFast
from texteller.constants import MAX_TOKEN_SIZE
from texteller.logger import get_logger
from texteller.paddleocr import predict_det, predict_rec
from texteller.types import Bbox, TexTellerModel
from texteller.utils import (
bbox_merge,
get_device,
mask_img,
readimgs,
remove_style,
slice_from_image,
split_conflict,
transform,
add_newlines,
)
from .detection import latex_detect
from .format import format_latex
from .katex import to_katex
_logger = get_logger()
def img2latex(
model: TexTellerModel,
tokenizer: RobertaTokenizerFast,
images: list[str] | list[np.ndarray],
device: torch.device | None = None,
out_format: Literal["latex", "katex"] = "latex",
keep_style: bool = False,
max_tokens: int = MAX_TOKEN_SIZE,
num_beams: int = 1,
no_repeat_ngram_size: int = 0,
) -> list[str]:
"""
Convert images to LaTeX or KaTeX formatted strings.
Args:
model: The TexTeller or ORTModelForVision2Seq model instance
tokenizer: The tokenizer for the model
images: List of image paths or numpy arrays (RGB format)
device: The torch device to use (defaults to available GPU or CPU)
out_format: Output format, either "latex" or "katex"
keep_style: Whether to keep the style of the LaTeX
max_tokens: Maximum number of tokens to generate
num_beams: Number of beams for beam search
no_repeat_ngram_size: Size of n-grams to prevent repetition
Returns:
List of LaTeX or KaTeX strings corresponding to each input image
Example usage:
>>> import torch
>>> from texteller import load_model, load_tokenizer, img2latex
>>> model = load_model(model_dir=None, use_onnx=False)
>>> tokenizer = load_tokenizer(tokenizer_dir=None)
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
>>> res = img2latex(model, tokenizer, ["path/to/image.png"], device=device, out_format="katex")
"""
assert isinstance(images, list)
assert len(images) > 0
if device is None:
device = get_device()
if device.type != model.device.type:
if isinstance(model, ORTModelForVision2Seq):
_logger.warning(
f"Onnxruntime device mismatch: detected {str(device)} but model is on {str(model.device)}, using {str(model.device)} instead"
)
else:
model = model.to(device=device)
if isinstance(images[0], str):
images = readimgs(images)
else: # already a numpy array (RGB format)
assert isinstance(images[0], np.ndarray)
images = images
images = transform(images)
pixel_values = torch.stack(images)
generate_config = GenerationConfig(
max_new_tokens=max_tokens,
num_beams=num_beams,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
bos_token_id=tokenizer.bos_token_id,
no_repeat_ngram_size=no_repeat_ngram_size,
)
pred = model.generate(
pixel_values.to(model.device),
generation_config=generate_config,
)
res = tokenizer.batch_decode(pred, skip_special_tokens=True)
if out_format == "katex":
res = [to_katex(r) for r in res]
if not keep_style:
res = [remove_style(r) for r in res]
res = [format_latex(r) for r in res]
res = [add_newlines(r) for r in res]
return res
def paragraph2md(
img_path: str,
latexdet_model: InferenceSession,
textdet_model: predict_det.TextDetector,
textrec_model: predict_rec.TextRecognizer,
latexrec_model: TexTellerModel,
tokenizer: RobertaTokenizerFast,
device: torch.device | None = None,
num_beams=1,
) -> str:
"""
Convert an image that mixes text and formulas into a Markdown string.
"""
img = cv2.imread(img_path)
corners = [tuple(img[0, 0]), tuple(img[0, -1]), tuple(img[-1, 0]), tuple(img[-1, -1])]
bg_color = np.array(Counter(corners).most_common(1)[0][0])
start_time = time.time()
latex_bboxes = latex_detect(img_path, latexdet_model)
end_time = time.time()
_logger.info(f"latex_det_model time: {end_time - start_time:.2f}s")
latex_bboxes = sorted(latex_bboxes)
latex_bboxes = bbox_merge(latex_bboxes)
masked_img = mask_img(img, latex_bboxes, bg_color)
start_time = time.time()
det_prediction, _ = textdet_model(masked_img)
end_time = time.time()
_logger.info(f"ocr_det_model time: {end_time - start_time:.2f}s")
ocr_bboxes = [
Bbox(
p[0][0],
p[0][1],
p[3][1] - p[0][1],
p[1][0] - p[0][0],
label="text",
confidence=None,
content=None,
)
for p in det_prediction
]
ocr_bboxes = sorted(ocr_bboxes)
ocr_bboxes = bbox_merge(ocr_bboxes)
ocr_bboxes = split_conflict(ocr_bboxes, latex_bboxes)
ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes))
sliced_imgs: list[np.ndarray] = slice_from_image(img, ocr_bboxes)
start_time = time.time()
rec_predictions, _ = textrec_model(sliced_imgs)
end_time = time.time()
_logger.info(f"ocr_rec_model time: {end_time - start_time:.2f}s")
assert len(rec_predictions) == len(ocr_bboxes)
for content, bbox in zip(rec_predictions, ocr_bboxes):
bbox.content = content[0]
latex_imgs = []
for bbox in latex_bboxes:
latex_imgs.append(img[bbox.p.y : bbox.p.y + bbox.h, bbox.p.x : bbox.p.x + bbox.w])
start_time = time.time()
latex_rec_res = img2latex(
model=latexrec_model,
tokenizer=tokenizer,
images=latex_imgs,
num_beams=num_beams,
out_format="katex",
device=device,
keep_style=False,
)
end_time = time.time()
_logger.info(f"latex_rec_model time: {end_time - start_time:.2f}s")
for bbox, content in zip(latex_bboxes, latex_rec_res):
if bbox.label == "embedding":
bbox.content = " $" + content + "$ "
elif bbox.label == "isolated":
bbox.content = "\n\n" + r"$$" + content + r"$$" + "\n\n"
bboxes = sorted(ocr_bboxes + latex_bboxes)
if bboxes == []:
return ""
md = ""
prev = Bbox(bboxes[0].p.x, bboxes[0].p.y, -1, -1, label="guard")
for curr in bboxes:
# Add the formula number back to the isolated formula
if prev.label == "isolated" and curr.label == "text" and prev.same_row(curr):
curr.content = curr.content.strip()
if curr.content.startswith("(") and curr.content.endswith(")"):
curr.content = curr.content[1:-1]
if re.search(r"\\tag\{.*\}$", md[:-4]) is not None:
# in case of multiple tags
md = md[:-5] + f", {curr.content}" + "}" + md[-4:]
else:
md = md[:-4] + f"\\tag{{{curr.content}}}" + md[-4:]
continue
if not prev.same_row(curr):
md += " "
if curr.label == "embedding":
# remove the bold effect from inline formulas
curr.content = remove_style(curr.content)
# change split environment into aligned
curr.content = curr.content.replace(r"\begin{split}", r"\begin{aligned}")
curr.content = curr.content.replace(r"\end{split}", r"\end{aligned}")
# remove extra spaces (keeping only one)
curr.content = re.sub(r" +", " ", curr.content)
assert curr.content.startswith("$") and curr.content.endswith("$")
curr.content = " $" + curr.content.strip("$") + "$ "
md += curr.content
prev = curr
return md.strip()
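A sketch of wiring paragraph2md to the loaders introduced in this commit ("paragraph.png" is a placeholder path; every loader falls back to the default remote checkpoints):

from texteller import (
    load_latexdet_model,
    load_model,
    load_textdet_model,
    load_textrec_model,
    load_tokenizer,
    paragraph2md,
)

md = paragraph2md(
    img_path="paragraph.png",
    latexdet_model=load_latexdet_model(),   # formula detector (ONNX)
    textdet_model=load_textdet_model(),     # PaddleOCR text detector
    textrec_model=load_textrec_model(),     # PaddleOCR text recognizer
    latexrec_model=load_model(),            # TexTeller formula recognizer
    tokenizer=load_tokenizer(),
)
print(md)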

View File

@@ -1,73 +1,10 @@
import re
-from .latex_formatter import format_latex
+from ..utils.latex import change_all
+from .format import format_latex
-def change(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r):
result = ""
i = 0
n = len(input_str)
while i < n:
if input_str[i : i + len(old_inst)] == old_inst:
# check if the old_inst is followed by old_surr_l
start = i + len(old_inst)
else:
result += input_str[i]
i += 1
continue
if start < n and input_str[start] == old_surr_l:
# found an old_inst followed by old_surr_l, now look for the matching old_surr_r
count = 1
j = start + 1
escaped = False
while j < n and count > 0:
if input_str[j] == '\\' and not escaped:
escaped = True
j += 1
continue
if input_str[j] == old_surr_r and not escaped:
count -= 1
if count == 0:
break
elif input_str[j] == old_surr_l and not escaped:
count += 1
escaped = False
j += 1
if count == 0:
assert j < n
assert input_str[start] == old_surr_l
assert input_str[j] == old_surr_r
inner_content = input_str[start + 1 : j]
# Replace the content with new pattern
result += new_inst + new_surr_l + inner_content + new_surr_r
i = j + 1
continue
else:
assert count >= 1
assert j == n
print("Warning: unbalanced surrogate pair in input string")
result += new_inst + new_surr_l
i = start + 1
continue
else:
result += input_str[i:start]
i = start
if old_inst != new_inst and (old_inst + old_surr_l) in result:
return change(result, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r)
else:
return result
def find_substring_positions(string, substring):
positions = [match.start() for match in re.finditer(re.escape(substring), string)]
return positions
-def rm_dollar_surr(content):
+def _rm_dollar_surr(content):
pattern = re.compile(r'\\[a-zA-Z]+\$.*?\$|\$.*?\$')
matches = pattern.findall(content)
@@ -79,19 +16,6 @@ def rm_dollar_surr(content):
return content
def change_all(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r):
pos = find_substring_positions(input_str, old_inst + old_surr_l)
res = list(input_str)
for p in pos[::-1]:
res[p:] = list(
change(
''.join(res[p:]), old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r
)
)
res = ''.join(res)
return res
def to_katex(formula: str) -> str:
res = formula
# remove mbox surrounding
@@ -182,13 +106,13 @@ def to_katex(formula: str) -> str:
res = re.sub(r'(\\text\{[^}]*\}\s*){2,}', merge_texts, res)
res = res.replace(r'\bf ', '')
-res = rm_dollar_surr(res)
+res = _rm_dollar_surr(res)
# remove extra spaces (keeping only one)
res = re.sub(r' +', ' ', res)
# format latex
res = res.strip()
-res, logs = format_latex(res)
+res = format_latex(res)
return res
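to_katex is likewise exported at package level; a minimal sketch (the input formula is only illustrative):

from texteller import to_katex

# Strips \mbox/\bf style wrappers and normalizes the result with format_latex
print(to_katex(r"\mbox{speed} = \frac{d}{t}"))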

66
texteller/api/load.py Normal file
View File

@@ -0,0 +1,66 @@
from pathlib import Path
import wget
from onnxruntime import InferenceSession
from transformers import RobertaTokenizerFast
from texteller.constants import LATEX_DET_MODEL_URL, TEXT_DET_MODEL_URL, TEXT_REC_MODEL_URL
from texteller.globals import Globals
from texteller.logger import get_logger
from texteller.models import TexTeller
from texteller.paddleocr import predict_det, predict_rec
from texteller.paddleocr.utility import parse_args
from texteller.utils import cuda_available, mkdir, resolve_path
from texteller.types import TexTellerModel
_logger = get_logger(__name__)
def load_model(model_dir: str | None = None, use_onnx: bool = False) -> TexTellerModel:
return TexTeller.from_pretrained(model_dir, use_onnx=use_onnx)
def load_tokenizer(tokenizer_dir: str | None = None) -> RobertaTokenizerFast:
return TexTeller.get_tokenizer(tokenizer_dir)
def load_latexdet_model() -> InferenceSession:
fpath = _maybe_download(LATEX_DET_MODEL_URL)
return InferenceSession(
resolve_path(fpath),
providers=["CUDAExecutionProvider" if cuda_available() else "CPUExecutionProvider"],
)
def load_textrec_model() -> predict_rec.TextRecognizer:
fpath = _maybe_download(TEXT_REC_MODEL_URL)
paddleocr_args = parse_args()
paddleocr_args.use_onnx = True
paddleocr_args.rec_model_dir = resolve_path(fpath)
paddleocr_args.use_gpu = cuda_available()
predictor = predict_rec.TextRecognizer(paddleocr_args)
return predictor
def load_textdet_model() -> predict_det.TextDetector:
fpath = _maybe_download(TEXT_DET_MODEL_URL)
paddleocr_args = parse_args()
paddleocr_args.use_onnx = True
paddleocr_args.det_model_dir = resolve_path(fpath)
paddleocr_args.use_gpu = cuda_available()
predictor = predict_det.TextDetector(paddleocr_args)
return predictor
def _maybe_download(url: str, dirpath: str | None = None, force: bool = False) -> Path:
if dirpath is None:
dirpath = Globals().cache_dir
mkdir(dirpath)
fname = Path(url).name
fpath = Path(dirpath) / fname
if not fpath.exists() or force:
_logger.info(f"Downloading {fname} from {url} to {fpath}")
wget.download(url, resolve_path(fpath))
return fpath

25
texteller/cli/__init__.py Normal file
View File

@@ -0,0 +1,25 @@
"""
CLI entry point for TexTeller.
"""
import time
import click
from texteller.cli.commands.inference import inference
from texteller.cli.commands.launch import launch
from texteller.cli.commands.web import web
@click.group()
def cli():
pass
cli.add_command(inference)
cli.add_command(web)
cli.add_command(launch)
if __name__ == "__main__":
cli()

View File

@@ -0,0 +1,3 @@
"""
CLI commands for TexTeller
"""

View File

@@ -0,0 +1,51 @@
"""
CLI command for formula inference from images.
"""
import click
from texteller.api import img2latex, load_model, load_tokenizer
@click.command()
@click.argument("image_path", type=click.Path(exists=True, file_okay=True, dir_okay=False))
@click.option(
"--model-path",
type=click.Path(exists=True, file_okay=False, dir_okay=True),
default=None,
help="Path to the model directory; if not provided, the model from the Hugging Face repo is used",
)
@click.option(
"--tokenizer-path",
type=click.Path(exists=True, file_okay=False, dir_okay=True),
default=None,
help="Path to the tokenizer directory; if not provided, the tokenizer from the Hugging Face repo is used",
)
@click.option(
"--output-format",
type=click.Choice(["latex", "katex"]),
default="katex",
help="Output format, either latex or katex",
)
@click.option(
"--keep-style",
is_flag=True,
default=False,
help="Whether to keep the style of the LaTeX (e.g. bold, italic, etc.)",
)
def inference(image_path, model_path, tokenizer_path, output_format, keep_style):
"""
CLI command for formula inference from images.
"""
model = load_model(model_dir=model_path)
tknz = load_tokenizer(tokenizer_dir=tokenizer_path)
pred = img2latex(
model=model,
tokenizer=tknz,
images=[image_path],
out_format=output_format,
keep_style=keep_style,
)[0]
click.echo(f"Predicted LaTeX: ```\n{pred}\n```")
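The command can also be exercised without a shell via Click's test runner, which is handy for unit tests; a sketch with a placeholder image path (the file must exist, since image_path is validated by click.Path):

from click.testing import CliRunner

from texteller.cli.commands.inference import inference

result = CliRunner().invoke(inference, ["formula.png", "--output-format", "katex"])
print(result.output)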

View File

@@ -0,0 +1,106 @@
"""
CLI commands for launching server.
"""
import sys
import time
import click
from ray import serve
from texteller.globals import Globals
from texteller.utils import get_device
@click.command()
@click.option(
"-ckpt",
"--checkpoint_dir",
type=click.Path(exists=True, file_okay=False, dir_okay=True),
default=None,
help="Path to the checkpoint directory; if not provided, the model from the Hugging Face repo is used",
)
@click.option(
"-tknz",
"--tokenizer_dir",
type=click.Path(exists=True, file_okay=False, dir_okay=True),
default=None,
help="Path to the tokenizer directory; if not provided, the tokenizer from the Hugging Face repo is used",
)
@click.option(
"-p",
"--port",
type=int,
default=8000,
help="Port to run the server on",
)
@click.option(
"--num-replicas",
type=int,
default=1,
help="Number of replicas to run the server on",
)
@click.option(
"--ncpu-per-replica",
type=float,
default=1.0,
help="Number of CPUs per replica",
)
@click.option(
"--ngpu-per-replica",
type=float,
default=1.0,
help="Number of GPUs per replica",
)
@click.option(
"--num-beams",
type=int,
default=1,
help="Number of beams to use",
)
@click.option(
"--use-onnx",
is_flag=True,
type=bool,
default=False,
help="Use ONNX runtime",
)
def launch(
checkpoint_dir,
tokenizer_dir,
port,
num_replicas,
ncpu_per_replica,
ngpu_per_replica,
num_beams,
use_onnx,
):
"""Launch the api server"""
device = get_device()
if ngpu_per_replica > 0 and not device.type == "cuda":
click.echo(
click.style(
f"Error: --ngpu-per-replica > 0 but detected device is {device.type}",
fg="red",
)
)
sys.exit(1)
Globals().num_replicas = num_replicas
Globals().ncpu_per_replica = ncpu_per_replica
Globals().ngpu_per_replica = ngpu_per_replica
from texteller.cli.commands.launch.server import Ingress, TexTellerServer
serve.start(http_options={"host": "0.0.0.0", "port": port})
rec_server = TexTellerServer.bind(
checkpoint_dir=checkpoint_dir,
tokenizer_dir=tokenizer_dir,
use_onnx=use_onnx,
num_beams=num_beams,
)
ingress = Ingress.bind(rec_server)
serve.run(ingress, route_prefix="/predict")
while True:
time.sleep(1)

View File

@@ -0,0 +1,69 @@
import numpy as np
import cv2
from starlette.requests import Request
from ray import serve
from ray.serve.handle import DeploymentHandle
from texteller.api import load_model, load_tokenizer, img2latex
from texteller.utils import get_device
from texteller.globals import Globals
from typing import Literal
@serve.deployment(
num_replicas=Globals().num_replicas,
ray_actor_options={
"num_cpus": Globals().ncpu_per_replica,
"num_gpus": Globals().ngpu_per_replica * 1.0 / 2,
},
)
class TexTellerServer:
def __init__(
self,
checkpoint_dir: str,
tokenizer_dir: str,
use_onnx: bool = False,
out_format: Literal["latex", "katex"] = "katex",
keep_style: bool = False,
num_beams: int = 1,
) -> None:
self.model = load_model(
model_dir=checkpoint_dir,
use_onnx=use_onnx,
)
self.tokenizer = load_tokenizer(tokenizer_dir=tokenizer_dir)
self.num_beams = num_beams
self.out_format = out_format
self.keep_style = keep_style
if not use_onnx:
self.model = self.model.to(get_device())
def predict(self, image_nparray: np.ndarray) -> str:
return img2latex(
model=self.model,
tokenizer=self.tokenizer,
images=[image_nparray],
device=get_device(),
out_format=self.out_format,
keep_style=self.keep_style,
num_beams=self.num_beams,
)[0]
@serve.deployment()
class Ingress:
def __init__(self, rec_server: DeploymentHandle) -> None:
self.texteller_server = rec_server
async def __call__(self, request: Request) -> str:
form = await request.form()
img_rb = await form["img"].read()
img_nparray = np.frombuffer(img_rb, np.uint8)
img_nparray = cv2.imdecode(img_nparray, cv2.IMREAD_COLOR)
img_nparray = cv2.cvtColor(img_nparray, cv2.COLOR_BGR2RGB)
pred = await self.texteller_server.predict.remote(img_nparray)
return pred
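Ingress reads the upload from a multipart form field named "img" and is served under the /predict route configured in the launch command, so the old example client only needs its URL updated; a sketch ("formula.png" is a placeholder):

import requests

# Default host/port from the launch command; the form field name must be "img"
with open("formula.png", "rb") as f:
    resp = requests.post("http://127.0.0.1:8000/predict", files={"img": f})
print(resp.text)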

View File

@@ -0,0 +1,9 @@
import os
import click
from pathlib import Path
@click.command()
def web():
"""Launch the web interface for TexTeller."""
os.system(f"streamlit run {Path(__file__).parent / 'streamlit_demo.py'}")

View File

@@ -0,0 +1,225 @@
import base64
import io
import os
import re
import shutil
import tempfile
import streamlit as st
from PIL import Image
from streamlit_paste_button import paste_image_button as pbutton
from texteller.api import (
img2latex,
load_latexdet_model,
load_model,
load_textdet_model,
load_textrec_model,
load_tokenizer,
paragraph2md,
)
from texteller.cli.commands.web.style import (
HEADER_HTML,
IMAGE_EMBED_HTML,
IMAGE_INFO_HTML,
SUCCESS_GIF_HTML,
)
from texteller.utils import str2device
st.set_page_config(page_title="TexTeller", page_icon="🧮")
@st.cache_resource
def get_texteller(use_onnx):
return load_model(use_onnx=use_onnx)
@st.cache_resource
def get_tokenizer():
return load_tokenizer()
@st.cache_resource
def get_latexdet_model():
return load_latexdet_model()
@st.cache_resource()
def get_textrec_model():
return load_textrec_model()
@st.cache_resource()
def get_textdet_model():
return load_textdet_model()
def get_image_base64(img_file):
buffered = io.BytesIO()
img_file.seek(0)
img = Image.open(img_file)
img.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode()
def on_file_upload():
st.session_state["UPLOADED_FILE_CHANGED"] = True
def change_side_bar():
st.session_state["CHANGE_SIDEBAR_FLAG"] = True
if "start" not in st.session_state:
st.session_state["start"] = 1
st.toast("Hooray!", icon="🎉")
if "UPLOADED_FILE_CHANGED" not in st.session_state:
st.session_state["UPLOADED_FILE_CHANGED"] = False
if "CHANGE_SIDEBAR_FLAG" not in st.session_state:
st.session_state["CHANGE_SIDEBAR_FLAG"] = False
if "INF_MODE" not in st.session_state:
st.session_state["INF_MODE"] = "Formula recognition"
# ====== <sidebar> ======
with st.sidebar:
num_beams = 1
st.markdown("# 🔨️ Config")
st.markdown("")
inf_mode = st.selectbox(
"Inference mode",
("Formula recognition", "Paragraph recognition"),
on_change=change_side_bar,
)
num_beams = st.number_input(
"Number of beams", min_value=1, max_value=20, step=1, on_change=change_side_bar
)
device = st.radio("device", ("cpu", "cuda", "mps"), on_change=change_side_bar)
st.markdown("## Speedup")
use_onnx = st.toggle("ONNX Runtime ")
# ====== </sidebar> ======
# ====== <page> ======
latexrec_model = get_texteller(use_onnx)
tokenizer = get_tokenizer()
if inf_mode == "Paragraph recognition":
latexdet_model = get_latexdet_model()
textrec_model = get_textrec_model()
textdet_model = get_textdet_model()
st.markdown(HEADER_HTML, unsafe_allow_html=True)
uploaded_file = st.file_uploader(" ", type=["jpg", "png"], on_change=on_file_upload)
paste_result = pbutton(
label="📋 Paste an image",
background_color="#5BBCFF",
hover_background_color="#3498db",
)
st.write("")
if st.session_state["CHANGE_SIDEBAR_FLAG"] is True:
st.session_state["CHANGE_SIDEBAR_FLAG"] = False
elif uploaded_file or paste_result.image_data is not None:
if st.session_state["UPLOADED_FILE_CHANGED"] is False and paste_result.image_data is not None:
uploaded_file = io.BytesIO()
paste_result.image_data.save(uploaded_file, format="PNG")
uploaded_file.seek(0)
if st.session_state["UPLOADED_FILE_CHANGED"] is True:
st.session_state["UPLOADED_FILE_CHANGED"] = False
img = Image.open(uploaded_file)
temp_dir = tempfile.mkdtemp()
png_fpath = os.path.join(temp_dir, "image.png")
img.save(png_fpath, "PNG")
with st.container(height=300):
img_base64 = get_image_base64(uploaded_file)
st.markdown(
IMAGE_EMBED_HTML.format(img_base64=img_base64),
unsafe_allow_html=True,
)
st.markdown(
IMAGE_INFO_HTML.format(img_height=img.height, img_width=img.width),
unsafe_allow_html=True,
)
st.write("")
with st.spinner("Predicting..."):
if inf_mode == "Formula recognition":
pred = img2latex(
model=latexrec_model,
tokenizer=tokenizer,
images=[png_fpath],
device=str2device(device),
out_format="katex",
num_beams=num_beams,
keep_style=False,
)[0]
else:
pred = paragraph2md(
img_path=png_fpath,
latexdet_model=latexdet_model,
textdet_model=textdet_model,
textrec_model=textrec_model,
latexrec_model=latexrec_model,
tokenizer=tokenizer,
device=str2device(device),
num_beams=num_beams,
)
st.success("Completed!", icon="")
# st.markdown(SUCCESS_GIF_HTML, unsafe_allow_html=True)
# st.text_area("Predicted LaTeX", pred, height=150)
if inf_mode == "Formula recognition":
st.code(pred, language="latex")
elif inf_mode == "Paragraph recognition":
st.code(pred, language="markdown")
else:
raise ValueError(f"Invalid inference mode: {inf_mode}")
if inf_mode == "Formula recognition":
st.latex(pred)
elif inf_mode == "Paragraph recognition":
mixed_res = re.split(r"(\$\$.*?\$\$)", pred, flags=re.DOTALL)
for text in mixed_res:
if text.startswith("$$") and text.endswith("$$"):
st.latex(text.strip("$$"))
else:
st.markdown(text)
st.write("")
st.write("")
with st.expander(":star2: :gray[Tips for better results]"):
st.markdown("""
* :mag_right: Use a clear and high-resolution image.
* :scissors: Crop images as accurately as possible.
* :jigsaw: Split large multi-line formulas into smaller ones.
* :page_facing_up: Use images with **white background and black text** as much as possible.
* :book: Use a font with good readability.
""")
shutil.rmtree(temp_dir)
paste_result.image_data = None
# ====== </page> ======

View File

@@ -0,0 +1,55 @@
from texteller.utils import lines_dedent
HEADER_HTML = lines_dedent("""
<h1 style="color: black; text-align: center;">
<img src="https://raw.githubusercontent.com/OleehyO/TexTeller/main/assets/fire.svg" width="100">
𝚃𝚎𝚡𝚃𝚎𝚕𝚕𝚎𝚛
<img src="https://raw.githubusercontent.com/OleehyO/TexTeller/main/assets/fire.svg" width="100">
</h1>
""")
SUCCESS_GIF_HTML = lines_dedent("""
<h1 style="color: black; text-align: center;">
<img src="https://slackmojis.com/emojis/90621-clapclap-e/download" width="50">
<img src="https://slackmojis.com/emojis/90621-clapclap-e/download" width="50">
<img src="https://slackmojis.com/emojis/90621-clapclap-e/download" width="50">
</h1>
""")
FAIL_GIF_HTML = lines_dedent("""
<h1 style="color: black; text-align: center;">
<img src="https://slackmojis.com/emojis/51439-allthethings_intensifies/download">
<img src="https://slackmojis.com/emojis/51439-allthethings_intensifies/download">
<img src="https://slackmojis.com/emojis/51439-allthethings_intensifies/download">
</h1>
""")
IMAGE_EMBED_HTML = lines_dedent("""
<style>
.centered-container {{
text-align: center;
}}
.centered-image {{
display: block;
margin-left: auto;
margin-right: auto;
max-height: 350px;
max-width: 100%;
}}
</style>
<div class="centered-container">
<img src="data:image/png;base64,{img_base64}" class="centered-image" alt="Input image">
</div>
""")
IMAGE_INFO_HTML = lines_dedent("""
<style>
.centered-container {{
text-align: center;
}}
</style>
<div class="centered-container">
<p style="color:gray;">Input image ({img_height}✖️{img_width})</p>
</div>
""")

View File

@@ -1,12 +0,0 @@
import requests
rec_server_url = "http://127.0.0.1:8000/frec"
det_server_url = "http://127.0.0.1:8000/fdet"
img_path = "/your/image/path/"
with open(img_path, 'rb') as img:
files = {'img': img}
response = requests.post(rec_server_url, files=files)
# response = requests.post(det_server_url, files=files)
print(response.text)

View File

@@ -21,3 +21,13 @@ MIN_RESIZE_RATIO = 0.75
# Minimum height and width for input image for TexTeller
MIN_HEIGHT = 12
MIN_WIDTH = 30
LATEX_DET_MODEL_URL = (
"https://huggingface.co/TonyLee1256/texteller_det/resolve/main/rtdetr_r50vd_6x_coco.onnx"
)
TEXT_REC_MODEL_URL = (
"https://huggingface.co/OleehyO/paddleocrv4.onnx/resolve/main/ch_PP-OCRv4_server_rec.onnx"
)
TEXT_DET_MODEL_URL = (
"https://huggingface.co/OleehyO/paddleocrv4.onnx/resolve/main/ch_PP-OCRv4_det.onnx"
)

41
texteller/globals.py Normal file
View File

@@ -0,0 +1,41 @@
import logging
from pathlib import Path
class Globals:
"""
Singleton class for managing global variables with predefined and dynamic attributes.
Usage Example:
>>> # 1. Access predefined variable (with default value)
>>> print(Globals().repo_name) # Output: OleehyO/TexTeller
>>> # 2. Modify predefined variable
>>> Globals().repo_name = "NewRepo/NewProject"
>>> print(Globals().repo_name) # Output: NewRepo/NewProject
>>> # 3. Dynamically add new variable
>>> Globals().new_var = "hello"
>>> print(Globals().new_var) # Output: hello
>>> # 4. View all variables
>>> print(Globals()) # Output: <Globals: {'repo_name': ..., 'new_var': ...}>
"""
_instance = None
_initialized = False
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if not self._initialized:
self.repo_name = "OleehyO/TexTeller"
self.logging_level = logging.INFO
self.cache_dir = Path("~/.cache/texteller").expanduser().resolve()
self.__class__._initialized = True
def __repr__(self):
return f"<Globals: {self.__dict__}>"

View File

@@ -1,96 +0,0 @@
import os
import argparse
import glob
import subprocess
import onnxruntime
from pathlib import Path
from models.det_model.inference import PredictConfig, predict_image
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--infer_cfg", type=str, help="infer_cfg.yml", default="./models/det_model/model/infer_cfg.yml"
)
parser.add_argument(
'--onnx_file',
type=str,
help="onnx model file path",
default="./models/det_model/model/rtdetr_r50vd_6x_coco.onnx",
)
parser.add_argument("--image_dir", type=str, default='./testImgs')
parser.add_argument("--image_file", type=str)
parser.add_argument("--imgsave_dir", type=str, default="./detect_results")
parser.add_argument(
'--use_gpu', action='store_true', help='Whether to use GPU for inference', default=True
)
def get_test_images(infer_dir, infer_img):
"""
Get image path list in TEST mode
"""
assert (
infer_img is not None or infer_dir is not None
), "--image_file or --image_dir should be set"
assert infer_img is None or os.path.isfile(infer_img), "{} is not a file".format(infer_img)
assert infer_dir is None or os.path.isdir(infer_dir), "{} is not a directory".format(infer_dir)
# infer_img has a higher priority
if infer_img and os.path.isfile(infer_img):
return [infer_img]
images = set()
infer_dir = os.path.abspath(infer_dir)
assert os.path.isdir(infer_dir), "infer_dir {} is not a directory".format(infer_dir)
exts = ['jpg', 'jpeg', 'png', 'bmp']
exts += [ext.upper() for ext in exts]
for ext in exts:
images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
images = list(images)
assert len(images) > 0, "no image found in {}".format(infer_dir)
print("Found {} inference images in total.".format(len(images)))
return images
def download_file(url, filename):
print(f"Downloading {filename}...")
subprocess.run(["wget", "-q", "--show-progress", "-O", filename, url], check=True)
print("Download complete.")
if __name__ == '__main__':
cur_path = os.getcwd()
script_dirpath = Path(__file__).resolve().parent
os.chdir(script_dirpath)
FLAGS = parser.parse_args()
if not os.path.exists(FLAGS.infer_cfg):
infer_cfg_url = "https://huggingface.co/TonyLee1256/texteller_det/resolve/main/infer_cfg.yml?download=true"
download_file(infer_cfg_url, FLAGS.infer_cfg)
if not os.path.exists(FLAGS.onnx_file):
onnx_file_url = "https://huggingface.co/TonyLee1256/texteller_det/resolve/main/rtdetr_r50vd_6x_coco.onnx?download=true"
download_file(onnx_file_url, FLAGS.onnx_file)
# load image list
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
if FLAGS.use_gpu:
predictor = onnxruntime.InferenceSession(
FLAGS.onnx_file, providers=['CUDAExecutionProvider']
)
else:
predictor = onnxruntime.InferenceSession(
FLAGS.onnx_file, providers=['CPUExecutionProvider']
)
# load infer config
infer_config = PredictConfig(FLAGS.infer_cfg)
predict_image(FLAGS.imgsave_dir, infer_config, predictor, img_list)
os.chdir(cur_path)

View File

@@ -1,81 +0,0 @@
import os
import argparse
import cv2 as cv
from pathlib import Path
from onnxruntime import InferenceSession
from models.thrid_party.paddleocr.infer import predict_det, predict_rec
from models.thrid_party.paddleocr.infer import utility
from models.utils import mix_inference
from models.ocr_model.utils.to_katex import to_katex
from models.ocr_model.utils.inference import inference as latex_inference
from models.ocr_model.model.TexTeller import TexTeller
from models.det_model.inference import PredictConfig
if __name__ == '__main__':
os.chdir(Path(__file__).resolve().parent)
parser = argparse.ArgumentParser()
parser.add_argument('-img', type=str, required=True, help='path to the input image')
parser.add_argument(
'--inference-mode',
type=str,
default='cpu',
help='Inference mode, select one of cpu, cuda, or mps',
)
parser.add_argument(
'--num-beam', type=int, default=1, help='number of beam search for decoding'
)
parser.add_argument('-mix', action='store_true', help='use mix mode')
args = parser.parse_args()
# You can use your own checkpoint and tokenizer path.
print('Loading model and tokenizer...')
latex_rec_model = TexTeller.from_pretrained()
tokenizer = TexTeller.get_tokenizer()
print('Model and tokenizer loaded.')
img_path = args.img
img = cv.imread(img_path)
print('Inference...')
if not args.mix:
res = latex_inference(latex_rec_model, tokenizer, [img], args.inference_mode, args.num_beam)
res = to_katex(res[0])
print(res)
else:
infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
use_gpu = args.inference_mode == 'cuda'
SIZE_LIMIT = 20 * 1024 * 1024
det_model_dir = "./models/thrid_party/paddleocr/checkpoints/det/default_model.onnx"
rec_model_dir = "./models/thrid_party/paddleocr/checkpoints/rec/default_model.onnx"
# The CPU inference of the detection model will be faster than the GPU inference (in onnxruntime)
det_use_gpu = False
rec_use_gpu = use_gpu and not (os.path.getsize(rec_model_dir) < SIZE_LIMIT)
paddleocr_args = utility.parse_args()
paddleocr_args.use_onnx = True
paddleocr_args.det_model_dir = det_model_dir
paddleocr_args.rec_model_dir = rec_model_dir
paddleocr_args.use_gpu = det_use_gpu
detector = predict_det.TextDetector(paddleocr_args)
paddleocr_args.use_gpu = rec_use_gpu
recognizer = predict_rec.TextRecognizer(paddleocr_args)
lang_ocr_models = [detector, recognizer]
latex_rec_models = [latex_rec_model, tokenizer]
res = mix_inference(
img_path,
infer_config,
latex_det_model,
lang_ocr_models,
latex_rec_models,
args.inference_mode,
args.num_beam,
)
print(res)

96
texteller/logger.py Normal file
View File

@@ -0,0 +1,96 @@
import inspect
import logging
import os
from datetime import datetime
from logging import Logger
import colorama
from colorama import Fore, Style
from texteller.globals import Globals
# Initialize colorama for colored console output
colorama.init(autoreset=True)
TEMPLATE = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
class ColoredFormatter(logging.Formatter):
"""Custom formatter to add colors based on log level."""
FORMATS = { # noqa: E501
logging.DEBUG: Fore.LIGHTBLACK_EX + TEMPLATE + Style.RESET_ALL,
logging.INFO: Fore.WHITE + TEMPLATE + Style.RESET_ALL,
logging.WARNING: Fore.YELLOW + TEMPLATE + Style.RESET_ALL,
logging.ERROR: Fore.RED + TEMPLATE + Style.RESET_ALL,
logging.CRITICAL: Fore.RED + Style.BRIGHT + TEMPLATE + Style.RESET_ALL,
} # noqa: E501
def format(self, record):
log_fmt = self.FORMATS.get(record.levelno, self.FORMATS[logging.INFO])
formatter = logging.Formatter(log_fmt, datefmt="%Y-%m-%d %H:%M:%S")
return formatter.format(record)
def get_logger(name: str | None = None, use_file_handler: bool = False) -> Logger:
"""
Creates and configures a logger named after `name` if provided, otherwise after the caller's module.
If the caller's module path has more than two components, only the first two are used.
Args:
name (str, optional): Custom logger name. If None, derives from caller's module.
use_file_handler (bool, optional): Whether to use a file handler. Defaults to False.
Returns:
Logger: Configured logger with colored console output and file handler.
"""
# If name is not provided, derive it from the caller's module
if name is None:
# Get the caller's stack frame
frame = inspect.stack()[1]
module = inspect.getmodule(frame[0])
if module and module.__name__:
module_name = module.__name__
# Split module name and take first two components if too long
parts = module_name.split(".")
if len(parts) > 2:
name = ".".join(parts[:2])
else:
name = module_name
else:
name = "root"
# Create or get logger
logger = logging.getLogger(name)
# Prevent duplicate handlers
if logger.handlers:
return logger
# Set logger level
logger.setLevel(Globals().logging_level)
# Create console handler with colored formatter
console_handler = logging.StreamHandler()
console_handler.setLevel(Globals().logging_level)
console_formatter = ColoredFormatter()
console_handler.setFormatter(console_formatter)
logger.addHandler(console_handler)
# Create file handler
if use_file_handler:
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"{datetime.now().strftime('%Y%m%d')}.log")
file_handler = logging.FileHandler(log_file)
file_handler.setLevel(Globals().logging_level)
# File formatter (no colors)
file_formatter = logging.Formatter(TEMPLATE, datefmt="%Y-%m-%d %H:%M:%S")
file_handler.setFormatter(file_formatter)
logger.addHandler(file_handler)
# Prevent logger from propagating to root logger
logger.propagate = False
return logger

View File

@@ -0,0 +1,3 @@
from .texteller import TexTeller
__all__ = ['TexTeller']

View File

@@ -1,226 +0,0 @@
import os
import time
import yaml
import numpy as np
import cv2
from tqdm import tqdm
from typing import List
from .preprocess import Compose
from .Bbox import Bbox
# Global dictionary
SUPPORT_MODELS = {
'YOLO',
'PPYOLOE',
'RCNN',
'SSD',
'Face',
'FCOS',
'SOLOv2',
'TTFNet',
'S2ANet',
'JDE',
'FairMOT',
'DeepSORT',
'GFL',
'PicoDet',
'CenterNet',
'TOOD',
'RetinaNet',
'StrongBaseline',
'STGCN',
'YOLOX',
'HRNet',
'DETR',
}
class PredictConfig(object):
"""set config of preprocess, postprocess and visualize
Args:
infer_config (str): path of infer_cfg.yml
"""
def __init__(self, infer_config):
# parsing Yaml config for Preprocess
with open(infer_config) as f:
yml_conf = yaml.safe_load(f)
self.check_model(yml_conf)
self.arch = yml_conf['arch']
self.preprocess_infos = yml_conf['Preprocess']
self.min_subgraph_size = yml_conf['min_subgraph_size']
self.label_list = yml_conf['label_list']
self.use_dynamic_shape = yml_conf['use_dynamic_shape']
self.draw_threshold = yml_conf.get("draw_threshold", 0.5)
self.mask = yml_conf.get("mask", False)
self.tracker = yml_conf.get("tracker", None)
self.nms = yml_conf.get("NMS", None)
self.fpn_stride = yml_conf.get("fpn_stride", None)
color_pool = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]
self.colors = {
label: color_pool[i % len(color_pool)] for i, label in enumerate(self.label_list)
}
if self.arch == 'RCNN' and yml_conf.get('export_onnx', False):
print('The RCNN export model is used for ONNX and it only supports batch_size = 1')
self.print_config()
def check_model(self, yml_conf):
"""
Raises:
ValueError: loaded model not in supported model type
"""
for support_model in SUPPORT_MODELS:
if support_model in yml_conf['arch']:
return True
raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf['arch'], SUPPORT_MODELS))
def print_config(self):
print('----------- Model Configuration -----------')
print('%s: %s' % ('Model Arch', self.arch))
print('%s: ' % ('Transform Order'))
for op_info in self.preprocess_infos:
print('--%s: %s' % ('transform op', op_info['type']))
print('--------------------------------------------')
def draw_bbox(image, outputs, infer_config):
for output in outputs:
cls_id, score, xmin, ymin, xmax, ymax = output
if score > infer_config.draw_threshold:
label = infer_config.label_list[int(cls_id)]
color = infer_config.colors[label]
cv2.rectangle(image, (int(xmin), int(ymin)), (int(xmax), int(ymax)), color, 2)
cv2.putText(
image,
"{}: {:.2f}".format(label, score),
(int(xmin), int(ymin - 5)),
cv2.FONT_HERSHEY_SIMPLEX,
0.5,
color,
2,
)
return image
def predict_image(imgsave_dir, infer_config, predictor, img_list):
# load preprocess transforms
transforms = Compose(infer_config.preprocess_infos)
errImgList = []
# Check and create subimg_save_dir if not exist
subimg_save_dir = os.path.join(imgsave_dir, 'subimages')
os.makedirs(subimg_save_dir, exist_ok=True)
first_image_skipped = False
total_time = 0
num_images = 0
# predict image
for img_path in tqdm(img_list):
img = cv2.imread(img_path)
if img is None:
print(f"Warning: Could not read image {img_path}. Skipping...")
errImgList.append(img_path)
continue
inputs = transforms(img_path)
inputs_name = [var.name for var in predictor.get_inputs()]
inputs = {k: inputs[k][None,] for k in inputs_name}
# Start timing
start_time = time.time()
outputs = predictor.run(output_names=None, input_feed=inputs)
# Stop timing
end_time = time.time()
inference_time = end_time - start_time
if not first_image_skipped:
first_image_skipped = True
else:
total_time += inference_time
num_images += 1
print(
f"ONNXRuntime predict time for {os.path.basename(img_path)}: {inference_time:.4f} seconds"
)
print("ONNXRuntime predict: ")
if infer_config.arch in ["HRNet"]:
print(np.array(outputs[0]))
else:
bboxes = np.array(outputs[0])
for bbox in bboxes:
if bbox[0] > -1 and bbox[1] > infer_config.draw_threshold:
print(f"{int(bbox[0])} {bbox[1]} " f"{bbox[2]} {bbox[3]} {bbox[4]} {bbox[5]}")
# Save the subimages (crop from the original image)
subimg_counter = 1
for output in np.array(outputs[0]):
cls_id, score, xmin, ymin, xmax, ymax = output
if score > infer_config.draw_threshold:
label = infer_config.label_list[int(cls_id)]
subimg = img[int(max(ymin, 0)) : int(ymax), int(max(xmin, 0)) : int(xmax)]
if len(subimg) == 0:
continue
subimg_filename = f"{os.path.splitext(os.path.basename(img_path))[0]}_{label}_{xmin:.2f}_{ymin:.2f}_{xmax:.2f}_{ymax:.2f}.jpg"
subimg_path = os.path.join(subimg_save_dir, subimg_filename)
cv2.imwrite(subimg_path, subimg)
subimg_counter += 1
# Draw bounding boxes and save the image with bounding boxes
img_with_mask = img.copy()
for output in np.array(outputs[0]):
cls_id, score, xmin, ymin, xmax, ymax = output
if score > infer_config.draw_threshold:
cv2.rectangle(
img_with_mask,
(int(xmin), int(ymin)),
(int(xmax), int(ymax)),
(255, 255, 255),
-1,
) # mask the detected region with white
img_with_bbox = draw_bbox(img, np.array(outputs[0]), infer_config)
output_dir = imgsave_dir
os.makedirs(output_dir, exist_ok=True)
draw_box_dir = os.path.join(output_dir, 'draw_box')
mask_white_dir = os.path.join(output_dir, 'mask_white')
os.makedirs(draw_box_dir, exist_ok=True)
os.makedirs(mask_white_dir, exist_ok=True)
output_file_mask = os.path.join(mask_white_dir, os.path.basename(img_path))
output_file_bbox = os.path.join(draw_box_dir, os.path.basename(img_path))
cv2.imwrite(output_file_mask, img_with_mask)
cv2.imwrite(output_file_bbox, img_with_bbox)
avg_time_per_image = total_time / num_images if num_images > 0 else 0
print(f"Total inference time for {num_images} images: {total_time:.4f} seconds")
print(f"Average time per image: {avg_time_per_image:.4f} seconds")
print("ErrorImgs:")
print(errImgList)
def predict(img_path: str, predictor, infer_config) -> List[Bbox]:
transforms = Compose(infer_config.preprocess_infos)
inputs = transforms(img_path)
inputs_name = [var.name for var in predictor.get_inputs()]
inputs = {k: inputs[k][None,] for k in inputs_name}
outputs = predictor.run(output_names=None, input_feed=inputs)[0]
res = []
for output in outputs:
cls_name = infer_config.label_list[int(output[0])]
score = output[1]
xmin = int(max(output[2], 0))
ymin = int(max(output[3], 0))
xmax = int(output[4])
ymax = int(output[5])
if score > infer_config.draw_threshold:
res.append(Bbox(xmin, ymin, ymax - ymin, xmax - xmin, cls_name, score))
return res

View File

@@ -1,27 +0,0 @@
mode: paddle
draw_threshold: 0.5
metric: COCO
use_dynamic_shape: false
arch: DETR
min_subgraph_size: 3
Preprocess:
- interp: 2
keep_ratio: false
target_size:
- 1600
- 1600
type: Resize
- mean:
- 0.0
- 0.0
- 0.0
norm_type: none
std:
- 1.0
- 1.0
- 1.0
type: NormalizeImage
- type: Permute
label_list:
- isolated
- embedding

View File

@@ -1,485 +0,0 @@
import numpy as np
import cv2
import copy
def decode_image(img_path):
if isinstance(img_path, str):
with open(img_path, 'rb') as f:
im_read = f.read()
data = np.frombuffer(im_read, dtype='uint8')
else:
assert isinstance(img_path, np.ndarray)
data = img_path
im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
img_info = {
"im_shape": np.array(im.shape[:2], dtype=np.float32),
"scale_factor": np.array([1.0, 1.0], dtype=np.float32),
}
return im, img_info
class Resize(object):
"""resize image by target_size and max_size
Args:
target_size (int): the target size of image
keep_ratio (bool): whether keep_ratio or not, default true
interp (int): method of resize
"""
def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
if isinstance(target_size, int):
target_size = [target_size, target_size]
self.target_size = target_size
self.keep_ratio = keep_ratio
self.interp = interp
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
assert len(self.target_size) == 2
assert self.target_size[0] > 0 and self.target_size[1] > 0
im_channel = im.shape[2]
im_scale_y, im_scale_x = self.generate_scale(im)
im = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=self.interp)
im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
im_info['scale_factor'] = np.array([im_scale_y, im_scale_x]).astype('float32')
return im, im_info
def generate_scale(self, im):
"""
Args:
im (np.ndarray): image (np.ndarray)
Returns:
im_scale_x: the resize ratio of X
im_scale_y: the resize ratio of Y
"""
origin_shape = im.shape[:2]
im_c = im.shape[2]
if self.keep_ratio:
im_size_min = np.min(origin_shape)
im_size_max = np.max(origin_shape)
target_size_min = np.min(self.target_size)
target_size_max = np.max(self.target_size)
im_scale = float(target_size_min) / float(im_size_min)
if np.round(im_scale * im_size_max) > target_size_max:
im_scale = float(target_size_max) / float(im_size_max)
im_scale_x = im_scale
im_scale_y = im_scale
else:
resize_h, resize_w = self.target_size
im_scale_y = resize_h / float(origin_shape[0])
im_scale_x = resize_w / float(origin_shape[1])
return im_scale_y, im_scale_x
class NormalizeImage(object):
"""normalize image
Args:
mean (list): im - mean
std (list): im / std
is_scale (bool): whether need im / 255
norm_type (str): type in ['mean_std', 'none']
"""
def __init__(self, mean, std, is_scale=True, norm_type='mean_std'):
self.mean = mean
self.std = std
self.is_scale = is_scale
self.norm_type = norm_type
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im = im.astype(np.float32, copy=False)
if self.is_scale:
scale = 1.0 / 255.0
im *= scale
if self.norm_type == 'mean_std':
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
im -= mean
im /= std
return im, im_info
class Permute(object):
"""permute image
Args:
to_bgr (bool): whether convert RGB to BGR
channel_first (bool): whether convert HWC to CHW
"""
def __init__(
self,
):
super(Permute, self).__init__()
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im = im.transpose((2, 0, 1)).copy()
return im, im_info
class PadStride(object):
"""padding image for model with FPN, instead PadBatch(pad_to_stride) in original config
Args:
stride (bool): model with FPN need image shape % stride == 0
"""
def __init__(self, stride=0):
self.coarsest_stride = stride
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
coarsest_stride = self.coarsest_stride
if coarsest_stride <= 0:
return im, im_info
im_c, im_h, im_w = im.shape
pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
padding_im[:, :im_h, :im_w] = im
return padding_im, im_info
class LetterBoxResize(object):
def __init__(self, target_size):
"""
Resize image to target size, convert normalized xywh to pixel xyxy
format ([x_center, y_center, width, height] -> [x0, y0, x1, y1]).
Args:
target_size (int|list): image target size.
"""
super(LetterBoxResize, self).__init__()
if isinstance(target_size, int):
target_size = [target_size, target_size]
self.target_size = target_size
def letterbox(self, img, height, width, color=(127.5, 127.5, 127.5)):
# letterbox: resize a rectangular image to a padded rectangular
shape = img.shape[:2] # [height, width]
ratio_h = float(height) / shape[0]
ratio_w = float(width) / shape[1]
ratio = min(ratio_h, ratio_w)
new_shape = (round(shape[1] * ratio), round(shape[0] * ratio)) # [width, height]
padw = (width - new_shape[0]) / 2
padh = (height - new_shape[1]) / 2
top, bottom = round(padh - 0.1), round(padh + 0.1)
left, right = round(padw - 0.1), round(padw + 0.1)
img = cv2.resize(img, new_shape, interpolation=cv2.INTER_AREA) # resized, no border
img = cv2.copyMakeBorder(
img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
) # padded rectangular
return img, ratio, padw, padh
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
assert len(self.target_size) == 2
assert self.target_size[0] > 0 and self.target_size[1] > 0
height, width = self.target_size
h, w = im.shape[:2]
im, ratio, padw, padh = self.letterbox(im, height=height, width=width)
new_shape = [round(h * ratio), round(w * ratio)]
im_info['im_shape'] = np.array(new_shape, dtype=np.float32)
im_info['scale_factor'] = np.array([ratio, ratio], dtype=np.float32)
return im, im_info
class Pad(object):
def __init__(self, size, fill_value=[114.0, 114.0, 114.0]):
"""
Pad image to a specified size.
Args:
size (list[int]): image target size
fill_value (list[float]): rgb value of pad area, default (114.0, 114.0, 114.0)
"""
super(Pad, self).__init__()
if isinstance(size, int):
size = [size, size]
self.size = size
self.fill_value = fill_value
def __call__(self, im, im_info):
im_h, im_w = im.shape[:2]
h, w = self.size
if h == im_h and w == im_w:
im = im.astype(np.float32)
return im, im_info
canvas = np.ones((h, w, 3), dtype=np.float32)
canvas *= np.array(self.fill_value, dtype=np.float32)
canvas[0:im_h, 0:im_w, :] = im.astype(np.float32)
im = canvas
return im, im_info
def rotate_point(pt, angle_rad):
"""Rotate a point by an angle.
Args:
pt (list[float]): 2 dimensional point to be rotated
angle_rad (float): rotation angle by radian
Returns:
list[float]: Rotated point.
"""
assert len(pt) == 2
sn, cs = np.sin(angle_rad), np.cos(angle_rad)
new_x = pt[0] * cs - pt[1] * sn
new_y = pt[0] * sn + pt[1] * cs
rotated_pt = [new_x, new_y]
return rotated_pt
def _get_3rd_point(a, b):
"""To calculate the affine matrix, three pairs of points are required. This
function is used to get the 3rd point, given 2D points a & b.
The 3rd point is defined by rotating vector `a - b` by 90 degrees
anticlockwise, using b as the rotation center.
Args:
a (np.ndarray): point(x,y)
b (np.ndarray): point(x,y)
Returns:
np.ndarray: The 3rd point.
"""
assert len(a) == 2
assert len(b) == 2
direction = a - b
third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32)
return third_pt
def get_affine_transform(center, input_size, rot, output_size, shift=(0.0, 0.0), inv=False):
"""Get the affine transform matrix, given the center/scale/rot/output_size.
Args:
center (np.ndarray[2, ]): Center of the bounding box (x, y).
input_size (np.ndarray[2, ] | int): Scale of the bounding box
wrt [width, height].
rot (float): Rotation angle (degree).
output_size (np.ndarray[2, ]): Size of the destination heatmaps.
shift (0-100%): Shift translation ratio wrt the width/height.
Default (0., 0.).
inv (bool): Option to inverse the affine transform direction.
(inv=False: src->dst or inv=True: dst->src)
Returns:
np.ndarray: The transform matrix.
"""
assert len(center) == 2
assert len(output_size) == 2
assert len(shift) == 2
if not isinstance(input_size, (np.ndarray, list)):
input_size = np.array([input_size, input_size], dtype=np.float32)
scale_tmp = input_size
shift = np.array(shift)
src_w = scale_tmp[0]
dst_w = output_size[0]
dst_h = output_size[1]
rot_rad = np.pi * rot / 180
src_dir = rotate_point([0.0, src_w * -0.5], rot_rad)
dst_dir = np.array([0.0, dst_w * -0.5])
src = np.zeros((3, 2), dtype=np.float32)
src[0, :] = center + scale_tmp * shift
src[1, :] = center + src_dir + scale_tmp * shift
src[2, :] = _get_3rd_point(src[0, :], src[1, :])
dst = np.zeros((3, 2), dtype=np.float32)
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
if inv:
trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
else:
trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
return trans
class WarpAffine(object):
"""Warp affine the image"""
def __init__(self, keep_res=False, pad=31, input_h=512, input_w=512, scale=0.4, shift=0.1):
self.keep_res = keep_res
self.pad = pad
self.input_h = input_h
self.input_w = input_w
self.scale = scale
self.shift = shift
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
img = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
h, w = img.shape[:2]
if self.keep_res:
input_h = (h | self.pad) + 1
input_w = (w | self.pad) + 1
s = np.array([input_w, input_h], dtype=np.float32)
c = np.array([w // 2, h // 2], dtype=np.float32)
else:
s = max(h, w) * 1.0
input_h, input_w = self.input_h, self.input_w
c = np.array([w / 2.0, h / 2.0], dtype=np.float32)
trans_input = get_affine_transform(c, s, 0, [input_w, input_h])
img = cv2.resize(img, (w, h))
inp = cv2.warpAffine(img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR)
return inp, im_info
# keypoint preprocess
def get_warp_matrix(theta, size_input, size_dst, size_target):
"""This code is based on
https://github.com/open-mmlab/mmpose/blob/master/mmpose/core/post_processing/post_transforms.py
Calculate the transformation matrix under the constraint of unbiased.
Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased
Data Processing for Human Pose Estimation (CVPR 2020).
Args:
theta (float): Rotation angle in degrees.
size_input (np.ndarray): Size of input image [w, h].
size_dst (np.ndarray): Size of output image [w, h].
size_target (np.ndarray): Size of ROI in input plane [w, h].
Returns:
matrix (np.ndarray): A matrix for transformation.
"""
theta = np.deg2rad(theta)
matrix = np.zeros((2, 3), dtype=np.float32)
scale_x = size_dst[0] / size_target[0]
scale_y = size_dst[1] / size_target[1]
matrix[0, 0] = np.cos(theta) * scale_x
matrix[0, 1] = -np.sin(theta) * scale_x
matrix[0, 2] = scale_x * (
-0.5 * size_input[0] * np.cos(theta)
+ 0.5 * size_input[1] * np.sin(theta)
+ 0.5 * size_target[0]
)
matrix[1, 0] = np.sin(theta) * scale_y
matrix[1, 1] = np.cos(theta) * scale_y
matrix[1, 2] = scale_y * (
-0.5 * size_input[0] * np.sin(theta)
- 0.5 * size_input[1] * np.cos(theta)
+ 0.5 * size_target[1]
)
return matrix
class TopDownEvalAffine(object):
"""apply affine transform to image and coords
Args:
trainsize (list): [w, h], the standard size used to train
use_udp (bool): whether to use Unbiased Data Processing.
records (dict): the dict containing the image and coords
Returns:
records (dict): the image and coords after being transformed
"""
def __init__(self, trainsize, use_udp=False):
self.trainsize = trainsize
self.use_udp = use_udp
def __call__(self, image, im_info):
rot = 0
imshape = im_info['im_shape'][::-1]
center = im_info['center'] if 'center' in im_info else imshape / 2.0
scale = im_info['scale'] if 'scale' in im_info else imshape
if self.use_udp:
trans = get_warp_matrix(
rot, center * 2.0, [self.trainsize[0] - 1.0, self.trainsize[1] - 1.0], scale
)
image = cv2.warpAffine(
image,
trans,
(int(self.trainsize[0]), int(self.trainsize[1])),
flags=cv2.INTER_LINEAR,
)
else:
trans = get_affine_transform(center, scale, rot, self.trainsize)
image = cv2.warpAffine(
image,
trans,
(int(self.trainsize[0]), int(self.trainsize[1])),
flags=cv2.INTER_LINEAR,
)
return image, im_info
class Compose:
def __init__(self, transforms):
self.transforms = []
for op_info in transforms:
new_op_info = op_info.copy()
op_type = new_op_info.pop('type')
self.transforms.append(eval(op_type)(**new_op_info))
def __call__(self, img_path):
img, im_info = decode_image(img_path)
for t in self.transforms:
img, im_info = t(img, im_info)
inputs = copy.deepcopy(im_info)
inputs['image'] = img
return inputs
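
A small usage sketch (the file name is a placeholder): build the pipeline from the Preprocess entries of infer_cfg.yml shown above and run it on one page image.

preprocess_ops = [
    {"type": "Resize", "target_size": [1600, 1600], "keep_ratio": False, "interp": 2},
    {"type": "NormalizeImage", "mean": [0.0, 0.0, 0.0], "std": [1.0, 1.0, 1.0], "norm_type": "none"},
    {"type": "Permute"},
]
pipeline = Compose(preprocess_ops)
inputs = pipeline("page.png")  # dict with "image" (CHW float32), "im_shape", "scale_factor"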

View File

@@ -1,43 +0,0 @@
from pathlib import Path
from ...globals import VOCAB_SIZE, FIXED_IMG_SIZE, IMG_CHANNELS, MAX_TOKEN_SIZE
from transformers import RobertaTokenizerFast, VisionEncoderDecoderModel, VisionEncoderDecoderConfig
class TexTeller(VisionEncoderDecoderModel):
REPO_NAME = 'OleehyO/TexTeller'
def __init__(self):
config = VisionEncoderDecoderConfig.from_pretrained(
Path(__file__).resolve().parent / "config.json"
)
config.encoder.image_size = FIXED_IMG_SIZE
config.encoder.num_channels = IMG_CHANNELS
config.decoder.vocab_size = VOCAB_SIZE
config.decoder.max_position_embeddings = MAX_TOKEN_SIZE
super().__init__(config=config)
@classmethod
def from_pretrained(cls, model_path: str = None, use_onnx=False, onnx_provider=None):
if model_path is None or model_path == 'default':
if not use_onnx:
return VisionEncoderDecoderModel.from_pretrained(cls.REPO_NAME)
else:
from optimum.onnxruntime import ORTModelForVision2Seq
use_gpu = True if onnx_provider == 'cuda' else False
return ORTModelForVision2Seq.from_pretrained(
cls.REPO_NAME,
provider="CUDAExecutionProvider" if use_gpu else "CPUExecutionProvider",
)
model_path = Path(model_path).resolve()
return VisionEncoderDecoderModel.from_pretrained(str(model_path))
@classmethod
def get_tokenizer(cls, tokenizer_path: str = None) -> RobertaTokenizerFast:
if tokenizer_path is None or tokenizer_path == 'default':
return RobertaTokenizerFast.from_pretrained(cls.REPO_NAME)
tokenizer_path = Path(tokenizer_path).resolve()
return RobertaTokenizerFast.from_pretrained(str(tokenizer_path))

View File

@@ -1,168 +0,0 @@
{
"_name_or_path": "OleehyO/TexTeller",
"architectures": [
"VisionEncoderDecoderModel"
],
"decoder": {
"_name_or_path": "",
"activation_dropout": 0.0,
"activation_function": "gelu",
"add_cross_attention": true,
"architectures": null,
"attention_dropout": 0.0,
"bad_words_ids": null,
"begin_suppress_tokens": null,
"bos_token_id": 0,
"chunk_size_feed_forward": 0,
"classifier_dropout": 0.0,
"cross_attention_hidden_size": 768,
"d_model": 1024,
"decoder_attention_heads": 16,
"decoder_ffn_dim": 4096,
"decoder_layerdrop": 0.0,
"decoder_layers": 12,
"decoder_start_token_id": 2,
"diversity_penalty": 0.0,
"do_sample": false,
"dropout": 0.1,
"early_stopping": false,
"encoder_no_repeat_ngram_size": 0,
"eos_token_id": 2,
"exponential_decay_length_penalty": null,
"finetuning_task": null,
"forced_bos_token_id": null,
"forced_eos_token_id": null,
"id2label": {
"0": "LABEL_0",
"1": "LABEL_1"
},
"init_std": 0.02,
"is_decoder": true,
"is_encoder_decoder": false,
"label2id": {
"LABEL_0": 0,
"LABEL_1": 1
},
"layernorm_embedding": true,
"length_penalty": 1.0,
"max_length": 20,
"max_position_embeddings": 1024,
"min_length": 0,
"model_type": "trocr",
"no_repeat_ngram_size": 0,
"num_beam_groups": 1,
"num_beams": 1,
"num_return_sequences": 1,
"output_attentions": false,
"output_hidden_states": false,
"output_scores": false,
"pad_token_id": 1,
"prefix": null,
"problem_type": null,
"pruned_heads": {},
"remove_invalid_values": false,
"repetition_penalty": 1.0,
"return_dict": true,
"return_dict_in_generate": false,
"scale_embedding": false,
"sep_token_id": null,
"suppress_tokens": null,
"task_specific_params": null,
"temperature": 1.0,
"tf_legacy_loss": false,
"tie_encoder_decoder": false,
"tie_word_embeddings": true,
"tokenizer_class": null,
"top_k": 50,
"top_p": 1.0,
"torch_dtype": null,
"torchscript": false,
"typical_p": 1.0,
"use_bfloat16": false,
"use_cache": false,
"use_learned_position_embeddings": true,
"vocab_size": 15000
},
"encoder": {
"_name_or_path": "",
"add_cross_attention": false,
"architectures": null,
"attention_probs_dropout_prob": 0.0,
"bad_words_ids": null,
"begin_suppress_tokens": null,
"bos_token_id": null,
"chunk_size_feed_forward": 0,
"cross_attention_hidden_size": null,
"decoder_start_token_id": null,
"diversity_penalty": 0.0,
"do_sample": false,
"early_stopping": false,
"encoder_no_repeat_ngram_size": 0,
"encoder_stride": 16,
"eos_token_id": null,
"exponential_decay_length_penalty": null,
"finetuning_task": null,
"forced_bos_token_id": null,
"forced_eos_token_id": null,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.0,
"hidden_size": 768,
"id2label": {
"0": "LABEL_0",
"1": "LABEL_1"
},
"image_size": 448,
"initializer_range": 0.02,
"intermediate_size": 3072,
"is_decoder": false,
"is_encoder_decoder": false,
"label2id": {
"LABEL_0": 0,
"LABEL_1": 1
},
"layer_norm_eps": 1e-12,
"length_penalty": 1.0,
"max_length": 20,
"min_length": 0,
"model_type": "vit",
"no_repeat_ngram_size": 0,
"num_attention_heads": 12,
"num_beam_groups": 1,
"num_beams": 1,
"num_channels": 1,
"num_hidden_layers": 12,
"num_return_sequences": 1,
"output_attentions": false,
"output_hidden_states": false,
"output_scores": false,
"pad_token_id": null,
"patch_size": 16,
"prefix": null,
"problem_type": null,
"pruned_heads": {},
"qkv_bias": false,
"remove_invalid_values": false,
"repetition_penalty": 1.0,
"return_dict": true,
"return_dict_in_generate": false,
"sep_token_id": null,
"suppress_tokens": null,
"task_specific_params": null,
"temperature": 1.0,
"tf_legacy_loss": false,
"tie_encoder_decoder": false,
"tie_word_embeddings": true,
"tokenizer_class": null,
"top_k": 50,
"top_p": 1.0,
"torch_dtype": null,
"torchscript": false,
"typical_p": 1.0,
"use_bfloat16": false
},
"is_encoder_decoder": true,
"model_type": "vision-encoder-decoder",
"tie_word_embeddings": false,
"transformers_version": "4.41.2",
"use_cache": true
}

Binary files not shown: 35 deleted PNG images (1.8–12 KiB each), presumably the 0.png–34.png examples referenced in the dataset metadata below.

View File

@@ -1,35 +0,0 @@
{"file_name": "0.png", "latex_formula": "\\[\\mathbb{C}^{4}\\stackrel{{\\pi_{1}}}{{\\longleftarrow}}\\mathcal{ F}\\stackrel{{\\pi_{2}}}{{\\rightarrow}}\\mathcal{PT},\\]"}
{"file_name": "1.png", "latex_formula": "\\[W^{*}_{Z}(x_{1},x_{2})=W_{f\\lrcorner Z}(y_{1},y_{2})=\\mathcal{P}\\exp\\left( \\int_{\\gamma}A_{\\mu}dx^{\\mu}\\right).\\]"}
{"file_name": "2.png", "latex_formula": "\\[G=W^{*}_{Z}(q,p)=\\tilde{H}H^{-1}\\]"}
{"file_name": "3.png", "latex_formula": "\\[H=W^{*}_{Z}(p,x),\\ \\ \\tilde{H}=W^{*}_{Z}(q,x).\\]"}
{"file_name": "4.png", "latex_formula": "\\[v\\cdot f^{*}A|_{x}=(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)},\\quad x\\in Z, \\ v\\in T_{x}Z.\\]"}
{"file_name": "5.png", "latex_formula": "\\[(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)}=v^{\\alpha\\dot{\\alpha}}\\Big{(} \\frac{\\partial y^{\\beta\\dot{\\beta}}}{\\partial x^{\\alpha\\dot{\\alpha}}}A_{\\beta \\dot{\\beta}}\\Big{)}\\Big{|}_{f\\lrcorner Z(x)},\\ x\\in Z,\\ v\\in T_{x}Z,\\]"}
{"file_name": "6.png", "latex_formula": "\\[\\{T_{i},T_{j}\\}=\\{\\tilde{T}^{i},\\tilde{T}^{j}\\}=0,\\ \\ \\{T_{i},\\tilde{T}^{j}\\}=2i \\delta^{j}_{i}D,\\]"}
{"file_name": "7.png", "latex_formula": "\\[(\\partial_{s},q_{i},\\tilde{q}^{k})\\rightarrow(D,M^{j}_{i}T_{j},\\tilde{M}^{k}_ {l}\\tilde{T}^{l}),\\]"}
{"file_name": "8.png", "latex_formula": "\\[M^{i}_{j}\\tilde{M}^{j}_{k}=\\delta^{i}_{k}.\\]"}
{"file_name": "9.png", "latex_formula": "\\[Q_{i\\alpha}=q_{i\\alpha}+\\omega_{i\\alpha},\\ \\tilde{Q}^{i}_{\\dot{\\alpha}}=q^{i}_{ \\dot{\\alpha}}+\\tilde{\\omega}^{i}_{\\dot{\\alpha}},\\ D_{\\alpha\\dot{\\alpha}}= \\partial_{\\alpha\\dot{\\alpha}}+A_{\\alpha\\dot{\\alpha}}.\\]"}
{"file_name": "10.png", "latex_formula": "\\[\\hat{f}(g,\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{j})=(f(g),[V^{-1}]^ {\\alpha}_{\\beta}\\theta^{i\\beta},[\\tilde{V}^{-1}]^{\\dot{\\alpha}}_{\\dot{\\beta}} \\tilde{\\theta}^{\\dot{\\beta}}_{j}),\\ g\\in{\\cal G},\\]"}
{"file_name": "11.png", "latex_formula": "\\[v^{\\beta\\dot{\\beta}}V^{\\alpha}_{\\beta}\\tilde{V}^{\\dot{\\alpha}}_{\\dot{\\beta}} =((f\\lrcorner L_{0})_{*}v)^{\\alpha\\dot{\\alpha}},\\]"}
{"file_name": "12.png", "latex_formula": "\\[\\omega_{i\\alpha}=\\tilde{\\theta}^{\\dot{\\alpha}}_{i}h_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\ \\ \\tilde{\\omega}^{i}_{\\alpha}=\\theta^{i\\alpha}\\tilde{h}_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\]"}
{"file_name": "13.png", "latex_formula": "\\[\\begin{split}&\\lambda^{\\alpha}\\hat{f}^{*}\\omega_{i\\alpha}(z)= \\tilde{\\theta}^{\\dot{\\beta}}_{i}\\lambda^{\\alpha}\\left(V^{\\beta}_{\\alpha}h_{ \\beta\\dot{\\beta}}(x^{\\prime},\\tau^{\\prime})\\right),\\\\ &\\tilde{\\lambda}^{\\dot{\\alpha}}\\hat{f}^{*}\\tilde{\\omega}^{i}_{ \\dot{\\alpha}}(z)=\\theta^{i\\beta}\\tilde{\\lambda}^{\\dot{\\alpha}}\\left(\\tilde{V}^ {\\dot{\\beta}}_{\\dot{\\alpha}}\\tilde{h}_{\\beta\\dot{\\beta}}(x^{\\prime},\\tau^{ \\prime})\\right),\\end{split}\\]"}
{"file_name": "14.png", "latex_formula": "\\[A_{\\alpha\\dot{\\alpha}}=A_{\\alpha\\dot{\\alpha}}(x^{\\beta\\dot{\\beta}},\\tau^{ \\beta\\dot{\\beta}})\\]"}
{"file_name": "15.png", "latex_formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}D_{\\alpha\\dot{\\alpha}}\\]"}
{"file_name": "16.png", "latex_formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}\\]"}
{"file_name": "17.png", "latex_formula": "\\[[v_{1}\\cdot D^{*},v_{2}\\cdot D^{*}]=0\\]"}
{"file_name": "18.png", "latex_formula": "\\[\\Phi_{A}=(\\omega_{i\\alpha},\\tilde{\\omega}^{i}_{\\dot{\\alpha}},A_{\\alpha\\dot{ \\alpha}})\\]"}
{"file_name": "19.png", "latex_formula": "\\[\\hat{f}:{\\cal F}^{6|4N}\\rightarrow{\\cal F}^{6|4N}\\]"}
{"file_name": "20.png", "latex_formula": "\\[\\sigma=(s,\\xi^{i},\\tilde{\\xi}_{j})\\in\\mathbb{C}^{1|2N}\\]"}
{"file_name": "21.png", "latex_formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}(h_{\\alpha\\dot{\\alpha}}+\\tilde{h}_{\\alpha\\dot{\\alpha} })=0\\]"}
{"file_name": "22.png", "latex_formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}\\rightarrow[V^{-1}]^{\\alpha}_{\\beta}[\\tilde{V}^{-1}]^{ \\dot{\\alpha}}_{\\dot{\\beta}}\\tau^{\\beta\\dot{\\beta}}\\]"}
{"file_name": "23.png", "latex_formula": "\\[\\tau^{\\beta\\dot{\\beta}}=\\sum_{i}\\theta^{i\\beta}\\tilde{\\theta}^{\\dot{\\beta}}_{i}\\]"}
{"file_name": "24.png", "latex_formula": "\\[\\theta^{i\\alpha}\\omega_{i\\alpha}+\\tilde{\\theta}^{i}_{\\dot{\\alpha}}\\tilde{ \\omega}^{\\dot{\\alpha}}_{i}=0\\]"}
{"file_name": "25.png", "latex_formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{Q}^{i}_{\\dot{\\alpha}}\\]"}
{"file_name": "26.png", "latex_formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{q}^{i}_{\\dot{\\alpha}}\\]"}
{"file_name": "27.png", "latex_formula": "\\[\\tilde{\\lambda}^{\\dot{\\alpha}}f^{*}A_{\\alpha\\dot{\\alpha}}=H^{-1}\\tilde{ \\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}H\\]"}
{"file_name": "28.png", "latex_formula": "\\[\\tilde{q}^{i}=\\partial_{\\tilde{\\xi}_{i}}+i\\xi^{i}\\partial_{s}\\]"}
{"file_name": "29.png", "latex_formula": "\\[\\tilde{q}^{i}_{\\dot{\\alpha}}=\\frac{\\partial}{\\partial\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}}+i\\theta^{i\\alpha}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
{"file_name": "30.png", "latex_formula": "\\[f\\lrcorner L(z)=\\pi_{1}\\circ f(z,\\lambda,\\tilde{\\lambda})\\ \\forall z\\in L\\]"}
{"file_name": "31.png", "latex_formula": "\\[q_{i\\alpha}=\\frac{\\partial}{\\partial\\theta^{i\\alpha}}+i\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
{"file_name": "32.png", "latex_formula": "\\[q_{i}=\\partial_{\\xi^{i}}+i\\tilde{\\xi}_{i}\\partial_{s}\\]"}
{"file_name": "33.png", "latex_formula": "\\[v^{\\alpha\\dot{\\alpha}}=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\]"}
{"file_name": "34.png", "latex_formula": "\\[z^{A}=(x^{\\alpha\\dot{\\alpha}},\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{ j})\\]"}

View File

@@ -1,114 +0,0 @@
import os
from functools import partial
from pathlib import Path
from datasets import load_dataset
from transformers import (
Trainer,
TrainingArguments,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
GenerationConfig,
)
from .training_args import CONFIG
from ..model.TexTeller import TexTeller
from ..utils.functional import (
tokenize_fn,
collate_fn,
img_train_transform,
img_inf_transform,
filter_fn,
)
from ..utils.metrics import bleu_metric
from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT
def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
training_args = TrainingArguments(**CONFIG)
trainer = Trainer(
model,
training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
data_collator=collate_fn_with_tokenizer,
)
trainer.train(resume_from_checkpoint=None)
def evaluate(model, tokenizer, eval_dataset, collate_fn):
eval_config = CONFIG.copy()
eval_config['predict_with_generate'] = True
generate_config = GenerationConfig(
max_new_tokens=MAX_TOKEN_SIZE,
num_beams=1,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
bos_token_id=tokenizer.bos_token_id,
)
eval_config['generation_config'] = generate_config
seq2seq_config = Seq2SeqTrainingArguments(**eval_config)
trainer = Seq2SeqTrainer(
model,
seq2seq_config,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
data_collator=collate_fn,
compute_metrics=partial(bleu_metric, tokenizer=tokenizer),
)
eval_res = trainer.evaluate()
print(eval_res)
if __name__ == '__main__':
script_dirpath = Path(__file__).resolve().parent
os.chdir(script_dirpath)
# dataset = load_dataset(str(Path('./dataset/loader.py').resolve()))['train']
dataset = load_dataset("imagefolder", data_dir=str(script_dirpath / 'dataset'))['train']
dataset = dataset.filter(
lambda x: x['image'].height > MIN_HEIGHT and x['image'].width > MIN_WIDTH
)
dataset = dataset.shuffle(seed=42)
dataset = dataset.flatten_indices()
tokenizer = TexTeller.get_tokenizer()
# If you want to use your own tokenizer, please modify the path to your tokenizer
# tokenizer = TexTeller.get_tokenizer('/path/to/your/tokenizer')
filter_fn_with_tokenizer = partial(filter_fn, tokenizer=tokenizer)
dataset = dataset.filter(filter_fn_with_tokenizer, num_proc=8)
map_fn = partial(tokenize_fn, tokenizer=tokenizer)
tokenized_dataset = dataset.map(
map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8
)
# Split dataset into train and eval, ratio 9:1
split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
train_dataset = train_dataset.with_transform(img_train_transform)
eval_dataset = eval_dataset.with_transform(img_inf_transform)
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
# Train from scratch
model = TexTeller()
# or train from TexTeller pre-trained model: model = TexTeller.from_pretrained()
# If you want to train from a pre-trained model, please modify the path to your pre-trained checkpoint
# e.g.
# model = TexTeller.from_pretrained(
#     '/path/to/your/model_checkpoint'
# )
enable_train = True
enable_evaluate = False
if enable_train:
train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
if enable_evaluate and len(eval_dataset) > 0:
evaluate(model, tokenizer, eval_dataset, collate_fn_with_tokenizer)

View File

@@ -1,31 +0,0 @@
CONFIG = {
"seed": 42, # Random seed for reproducibility
"use_cpu": False, # Whether to use CPU (it's easier to debug with CPU when starting to test the code)
"learning_rate": 5e-5, # Learning rate
"num_train_epochs": 10, # Total number of training epochs
"per_device_train_batch_size": 4, # Batch size per GPU for training
"per_device_eval_batch_size": 8, # Batch size per GPU for evaluation
"output_dir": "train_result", # Output directory
"overwrite_output_dir": False, # If the output directory exists, do not delete its content
"report_to": ["tensorboard"], # Report logs to TensorBoard
"save_strategy": "steps", # Strategy to save checkpoints
"save_steps": 500, # Interval of steps to save checkpoints, can be int or a float (0~1), when float it represents the ratio of total training steps (e.g., can set to 1.0 / 2000)
"save_total_limit": 5, # Maximum number of models to save. The oldest models will be deleted if this number is exceeded
"logging_strategy": "steps", # Log every certain number of steps
"logging_steps": 500, # Number of steps between each log
"logging_nan_inf_filter": False, # Record logs for loss=nan or inf
"optim": "adamw_torch", # Optimizer
"lr_scheduler_type": "cosine", # Learning rate scheduler
"warmup_ratio": 0.1, # Ratio of warmup steps in total training steps (e.g., for 1000 steps, the first 100 steps gradually increase lr from 0 to the set lr)
"max_grad_norm": 1.0, # For gradient clipping, ensure the norm of the gradients does not exceed 1.0 (default 1.0)
"fp16": False, # Whether to use 16-bit floating point for training (generally not recommended, as loss can easily explode)
"bf16": False, # Whether to use Brain Floating Point (bfloat16) for training (recommended if architecture supports it)
"gradient_accumulation_steps": 1, # Gradient accumulation steps, consider this parameter to achieve large batch size effects when batch size cannot be large
"jit_mode_eval": False, # Whether to use PyTorch jit trace during eval (can speed up the model, but the model must be static, otherwise will throw errors)
"torch_compile": False, # Whether to use torch.compile to compile the model (for better training and inference performance)
"dataloader_pin_memory": True, # Can speed up data transfer between CPU and GPU
"dataloader_num_workers": 1, # Default is not to use multiprocessing for data loading, usually set to 4*number of GPUs used
"evaluation_strategy": "steps", # Evaluation strategy, can be "steps" or "epoch"
"eval_steps": 500, # If evaluation_strategy="step"
"remove_unused_columns": False, # Don't change this unless you really know what you are doing.
}

View File

@@ -1,60 +0,0 @@
import torch
from transformers import DataCollatorForLanguageModeling
from typing import List, Dict, Any
from .transforms import train_transform, inference_transform
from ...globals import MIN_HEIGHT, MIN_WIDTH, MAX_TOKEN_SIZE
def left_move(x: torch.Tensor, pad_val):
assert len(x.shape) == 2, 'x should be 2-dimensional'
lefted_x = torch.ones_like(x)
lefted_x[:, :-1] = x[:, 1:]
lefted_x[:, -1] = pad_val
return lefted_x
def tokenize_fn(samples: Dict[str, List[Any]], tokenizer=None) -> Dict[str, List[Any]]:
assert tokenizer is not None, 'tokenizer should not be None'
tokenized_formula = tokenizer(samples['latex_formula'], return_special_tokens_mask=True)
tokenized_formula['pixel_values'] = samples['image']
return tokenized_formula
def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[Any]]:
assert tokenizer is not None, 'tokenizer should not be None'
pixel_values = [dic.pop('pixel_values') for dic in samples]
clm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
batch = clm_collator(samples)
batch['pixel_values'] = pixel_values
batch['decoder_input_ids'] = batch.pop('input_ids')
batch['decoder_attention_mask'] = batch.pop('attention_mask')
# Left-shift labels and decoder_attention_mask
batch['labels'] = left_move(batch['labels'], -100)
# Convert the list of Images into a single tensor with shape (B, C, H, W)
batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
return batch
def img_train_transform(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
processed_img = train_transform(samples['pixel_values'])
samples['pixel_values'] = processed_img
return samples
def img_inf_transform(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
processed_img = inference_transform(samples['pixel_values'])
samples['pixel_values'] = processed_img
return samples
def filter_fn(sample, tokenizer=None) -> bool:
return (
sample['image'].height > MIN_HEIGHT
and sample['image'].width > MIN_WIDTH
and len(tokenizer(sample['latex_formula'])['input_ids']) < MAX_TOKEN_SIZE - 10
)
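
For clarity, a tiny worked example of left_move (values invented): the labels are shifted one position to the left and the vacated last column is filled with the pad value, which is how the labels end up aligned with next-token predictions.

import torch

x = torch.tensor([[2, 5, 7, 1]])
print(left_move(x, pad_val=-100))  # tensor([[   5,    7,    1, -100]])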

View File

@@ -1,26 +0,0 @@
import cv2
import numpy as np
from typing import List
def convert2rgb(image_paths: List[str]) -> List[np.ndarray]:
processed_images = []
for path in image_paths:
image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
if image is None:
print(f"Image at {path} could not be read.")
continue
if image.dtype == np.uint16:
print(f'Converting {path} to 8-bit, image may be lossy.')
image = cv2.convertScaleAbs(image, alpha=(255.0 / 65535.0))
channels = 1 if len(image.shape) == 2 else image.shape[2]
if channels == 4:
image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
elif channels == 1:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
elif channels == 3:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
processed_images.append(image)
return processed_images

View File

@@ -1,25 +0,0 @@
import evaluate
import numpy as np
import os
from pathlib import Path
from typing import Dict
from transformers import EvalPrediction, RobertaTokenizer
def bleu_metric(eval_preds: EvalPrediction, tokenizer: RobertaTokenizer) -> Dict:
cur_dir = Path(os.getcwd())
os.chdir(Path(__file__).resolve().parent)
metric = evaluate.load(
'google_bleu'
) # Will download the metric from huggingface if not already downloaded
os.chdir(cur_dir)
logits, labels = eval_preds.predictions, eval_preds.label_ids
preds = logits
labels = np.where(labels == -100, 1, labels)
preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
return metric.compute(predictions=preds, references=labels)

View File

@@ -1,152 +0,0 @@
from augraphy import *
import random
def ocr_augmentation_pipeline():
pre_phase = []
ink_phase = [
InkColorSwap(
ink_swap_color="random",
ink_swap_sequence_number_range=(5, 10),
ink_swap_min_width_range=(2, 3),
ink_swap_max_width_range=(100, 120),
ink_swap_min_height_range=(2, 3),
ink_swap_max_height_range=(100, 120),
ink_swap_min_area_range=(10, 20),
ink_swap_max_area_range=(400, 500),
# p=0.2
p=0.4,
),
LinesDegradation(
line_roi=(0.0, 0.0, 1.0, 1.0),
line_gradient_range=(32, 255),
line_gradient_direction=(0, 2),
line_split_probability=(0.2, 0.4),
line_replacement_value=(250, 255),
line_min_length=(30, 40),
line_long_to_short_ratio=(5, 7),
line_replacement_probability=(0.4, 0.5),
line_replacement_thickness=(1, 3),
# p=0.2
p=0.4,
),
# ============================
OneOf(
[
Dithering(
dither="floyd-steinberg",
order=(3, 5),
),
InkBleed(
intensity_range=(0.1, 0.2),
kernel_size=random.choice([(7, 7), (5, 5), (3, 3)]),
severity=(0.4, 0.6),
),
],
# p=0.2
p=0.4,
),
# ============================
# ============================
InkShifter(
text_shift_scale_range=(18, 27),
text_shift_factor_range=(1, 4),
text_fade_range=(0, 2),
blur_kernel_size=(5, 5),
blur_sigma=0,
noise_type="perlin",
# p=0.2
p=0.4,
),
# ============================
]
paper_phase = [
NoiseTexturize( # tested
sigma_range=(3, 10),
turbulence_range=(2, 5),
texture_width_range=(300, 500),
texture_height_range=(300, 500),
# p=0.2
p=0.4,
),
BrightnessTexturize( # tested
texturize_range=(0.9, 0.99),
deviation=0.03,
# p=0.2
p=0.4,
),
]
post_phase = [
ColorShift( # tested
color_shift_offset_x_range=(3, 5),
color_shift_offset_y_range=(3, 5),
color_shift_iterations=(2, 3),
color_shift_brightness_range=(0.9, 1.1),
color_shift_gaussian_kernel_range=(3, 3),
# p=0.2
p=0.4,
),
DirtyDrum( # tested
line_width_range=(1, 6),
line_concentration=random.uniform(0.05, 0.15),
direction=random.randint(0, 2),
noise_intensity=random.uniform(0.6, 0.95),
noise_value=(64, 224),
ksize=random.choice([(3, 3), (5, 5), (7, 7)]),
sigmaX=0,
# p=0.2
p=0.4,
),
# =====================================
OneOf(
[
LightingGradient(
light_position=None,
direction=None,
max_brightness=255,
min_brightness=0,
mode="gaussian",
linear_decay_rate=None,
transparency=None,
),
Brightness(
brightness_range=(0.9, 1.1),
min_brightness=0,
min_brightness_value=(120, 150),
),
Gamma(
gamma_range=(0.9, 1.1),
),
],
# p=0.2
p=0.4,
),
# =====================================
# =====================================
OneOf(
[
SubtleNoise(
subtle_range=random.randint(5, 10),
),
Jpeg(
quality_range=(70, 95),
),
],
# p=0.2
p=0.4,
),
# =====================================
]
pipeline = AugraphyPipeline(
ink_phase=ink_phase,
paper_phase=paper_phase,
post_phase=post_phase,
pre_phase=pre_phase,
log=False,
)
return pipeline
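
A minimal application sketch (the image path is a placeholder); the returned AugraphyPipeline is callable on a NumPy image, which is exactly how train_pipeline is used in transforms.py below.

import cv2

pipeline = ocr_augmentation_pipeline()
img = cv2.cvtColor(cv2.imread("formula.png"), cv2.COLOR_BGR2RGB)
augmented = pipeline(img)  # returns an augmented copy of the image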

View File

@@ -1,177 +0,0 @@
import torch
import random
import numpy as np
import cv2
from torchvision.transforms import v2
from typing import List, Union
from PIL import Image
from collections import Counter
from ...globals import (
IMG_CHANNELS,
FIXED_IMG_SIZE,
IMAGE_MEAN,
IMAGE_STD,
MAX_RESIZE_RATIO,
MIN_RESIZE_RATIO,
)
from .ocr_aug import ocr_augmentation_pipeline
# train_pipeline = default_augraphy_pipeline(scan_only=True)
train_pipeline = ocr_augmentation_pipeline()
general_transform_pipeline = v2.Compose(
[
v2.ToImage(),
v2.ToDtype(torch.uint8, scale=True), # optional, most inputs are already uint8 at this point
v2.Grayscale(),
v2.Resize(
size=FIXED_IMG_SIZE - 1,
interpolation=v2.InterpolationMode.BICUBIC,
max_size=FIXED_IMG_SIZE,
antialias=True,
),
v2.ToDtype(torch.float32, scale=True), # Normalize expects float input
v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]),
# v2.ToPILImage()
]
)
def trim_white_border(image: np.ndarray):
if len(image.shape) != 3 or image.shape[2] != 3:
raise ValueError("Image is not in RGB format or channel is not in third dimension")
if image.dtype != np.uint8:
raise ValueError(f"Image should stored in uint8")
corners = [tuple(image[0, 0]), tuple(image[0, -1]), tuple(image[-1, 0]), tuple(image[-1, -1])]
bg_color = Counter(corners).most_common(1)[0][0]
bg_color_np = np.array(bg_color, dtype=np.uint8)
h, w = image.shape[:2]
bg = np.full((h, w, 3), bg_color_np, dtype=np.uint8)
diff = cv2.absdiff(image, bg)
mask = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY)
threshold = 15
_, diff = cv2.threshold(mask, threshold, 255, cv2.THRESH_BINARY)
x, y, w, h = cv2.boundingRect(diff)
trimmed_image = image[y : y + h, x : x + w]
return trimmed_image
def add_white_border(image: np.ndarray, max_size: int) -> np.ndarray:
randi = [random.randint(0, max_size) for _ in range(4)]
pad_height_size = randi[1] + randi[3]
pad_width_size = randi[0] + randi[2]
if pad_height_size + image.shape[0] < 30:
compensate_height = int((30 - (pad_height_size + image.shape[0])) * 0.5) + 1
randi[1] += compensate_height
randi[3] += compensate_height
if pad_width_size + image.shape[1] < 30:
compensate_width = int((30 - (pad_width_size + image.shape[1])) * 0.5) + 1
randi[0] += compensate_width
randi[2] += compensate_width
return v2.functional.pad(
torch.from_numpy(image).permute(2, 0, 1),
padding=randi,
padding_mode='constant',
fill=(255, 255, 255),
)
def padding(images: List[torch.Tensor], required_size: int) -> List[torch.Tensor]:
images = [
v2.functional.pad(
img, padding=[0, 0, required_size - img.shape[2], required_size - img.shape[1]]
)
for img in images
]
return images
def random_resize(images: List[np.ndarray], minr: float, maxr: float) -> List[np.ndarray]:
if len(images[0].shape) != 3 or images[0].shape[2] != 3:
raise ValueError("Image is not in RGB format or channel is not in third dimension")
ratios = [random.uniform(minr, maxr) for _ in range(len(images))]
return [
cv2.resize(
img, (int(img.shape[1] * r), int(img.shape[0] * r)), interpolation=cv2.INTER_LANCZOS4
) # anti-aliasing
for img, r in zip(images, ratios)
]
def rotate(image: np.ndarray, min_angle: int, max_angle: int) -> np.ndarray:
# Get the center of the image to define the point of rotation
image_center = tuple(np.array(image.shape[1::-1]) / 2)
# Generate a random angle within the specified range
angle = random.randint(min_angle, max_angle)
# Get the rotation matrix for rotating the image around its center
rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0)
# Determine the size of the rotated image
cos = np.abs(rotation_mat[0, 0])
sin = np.abs(rotation_mat[0, 1])
new_width = int((image.shape[0] * sin) + (image.shape[1] * cos))
new_height = int((image.shape[0] * cos) + (image.shape[1] * sin))
# Adjust the rotation matrix to take into account translation
rotation_mat[0, 2] += (new_width / 2) - image_center[0]
rotation_mat[1, 2] += (new_height / 2) - image_center[1]
# Rotate the image with the specified border color (white in this case)
rotated_image = cv2.warpAffine(
image, rotation_mat, (new_width, new_height), borderValue=(255, 255, 255)
)
return rotated_image
def ocr_aug(image: np.ndarray) -> np.ndarray:
if random.random() < 0.2:
image = rotate(image, -5, 5)
image = add_white_border(image, max_size=25).permute(1, 2, 0).numpy()
image = train_pipeline(image)
return image
def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
assert IMG_CHANNELS == 1, "Only support grayscale images for now"
images = [np.array(img.convert('RGB')) for img in images]
# random resize first
images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
images = [trim_white_border(image) for image in images]
# OCR augmentation
images = [ocr_aug(image) for image in images]
# general transform pipeline
images = [general_transform_pipeline(image) for image in images]
# padding to fixed size
images = padding(images, FIXED_IMG_SIZE)
return images
def inference_transform(images: List[Union[np.ndarray, Image.Image]]) -> List[torch.Tensor]:
assert IMG_CHANNELS == 1, "Only support grayscale images for now"
images = [
np.array(img.convert('RGB')) if isinstance(img, Image.Image) else img for img in images
]
images = [trim_white_border(image) for image in images]
# general transform pipeline
images = [general_transform_pipeline(image) for image in images] # imgs: List[PIL.Image.Image]
# padding to fixed size
images = padding(images, FIXED_IMG_SIZE)
return images
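
A short usage sketch (the file name is a placeholder): preprocess one formula image for the recognizer and stack it into a batch, mirroring what the inference helpers do.

import torch
from PIL import Image

img = Image.open("formula.png")
batch = torch.stack(inference_transform([img]))  # shape (1, 1, FIXED_IMG_SIZE, FIXED_IMG_SIZE)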

View File

@@ -0,0 +1,48 @@
from pathlib import Path
from transformers import RobertaTokenizerFast, VisionEncoderDecoderConfig, VisionEncoderDecoderModel
from texteller.constants import (
FIXED_IMG_SIZE,
IMG_CHANNELS,
MAX_TOKEN_SIZE,
VOCAB_SIZE,
)
from texteller.globals import Globals
from texteller.types import TexTellerModel
from texteller.utils import cuda_available
class TexTeller(VisionEncoderDecoderModel):
def __init__(self):
config = VisionEncoderDecoderConfig.from_pretrained(Globals().repo_name)
config.encoder.image_size = FIXED_IMG_SIZE
config.encoder.num_channels = IMG_CHANNELS
config.decoder.vocab_size = VOCAB_SIZE
config.decoder.max_position_embeddings = MAX_TOKEN_SIZE
super().__init__(config=config)
@classmethod
def from_pretrained(cls, model_dir: str | None = None, use_onnx=False) -> TexTellerModel:
if model_dir is None or model_dir == Globals().repo_name:
if not use_onnx:
return VisionEncoderDecoderModel.from_pretrained(Globals().repo_name)
else:
from optimum.onnxruntime import ORTModelForVision2Seq
return ORTModelForVision2Seq.from_pretrained(
Globals().repo_name,
provider="CUDAExecutionProvider"
if cuda_available()
else "CPUExecutionProvider",
)
model_dir = Path(model_dir).resolve()
return VisionEncoderDecoderModel.from_pretrained(str(model_dir))
@classmethod
def get_tokenizer(cls, tokenizer_dir: str = None) -> RobertaTokenizerFast:
if tokenizer_dir is None or tokenizer_dir == Globals().repo_name:
return RobertaTokenizerFast.from_pretrained(Globals().repo_name)
tokenizer_dir = Path(tokenizer_dir).resolve()
return RobertaTokenizerFast.from_pretrained(str(tokenizer_dir))
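
A loading sketch against the class above (the local path is a placeholder): with no arguments both methods fall back to the published Globals().repo_name checkpoint.

model = TexTeller.from_pretrained()   # Hugging Face Hub weights
tokenizer = TexTeller.get_tokenizer()
# or a locally trained checkpoint directory, e.g.:
# model = TexTeller.from_pretrained("/path/to/train_result/checkpoint-XXXX")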

View File

@@ -1,24 +0,0 @@
import os
from pathlib import Path
from datasets import load_dataset
from ..ocr_model.model.TexTeller import TexTeller
from ..globals import VOCAB_SIZE
if __name__ == '__main__':
script_dirpath = Path(__file__).resolve().parent
os.chdir(script_dirpath)
tokenizer = TexTeller.get_tokenizer()
# Don't forget to config your dataset path in loader.py
dataset = load_dataset('../ocr_model/train/dataset/loader.py')['train']
new_tokenizer = tokenizer.train_new_from_iterator(
text_iterator=dataset['latex_formula'],
# If you want to use a different vocab size, **change VOCAB_SIZE from globals.py**
vocab_size=VOCAB_SIZE,
)
# Save the new tokenizer for later training and inference
new_tokenizer.save_pretrained('./your_dir_name')
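
Afterwards the retrained tokenizer can be loaded back for training or inference (directory name as above); keep VOCAB_SIZE in globals.py consistent with the value used here.

tokenizer = TexTeller.get_tokenizer('./your_dir_name')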

View File

@@ -1 +0,0 @@
from .mix_inference import mix_inference

View File

@@ -1,261 +0,0 @@
import re
import heapq
import cv2
import time
import numpy as np
from collections import Counter
from typing import List
from PIL import Image
from ..det_model.inference import predict as latex_det_predict
from ..det_model.Bbox import Bbox, draw_bboxes
from ..ocr_model.utils.inference import inference as latex_rec_predict
from ..ocr_model.utils.to_katex import to_katex, change_all
MAXV = 999999999
def mask_img(img, bboxes: List[Bbox], bg_color: np.ndarray) -> np.ndarray:
mask_img = img.copy()
for bbox in bboxes:
mask_img[bbox.p.y : bbox.p.y + bbox.h, bbox.p.x : bbox.p.x + bbox.w] = bg_color
return mask_img
def bbox_merge(sorted_bboxes: List[Bbox]) -> List[Bbox]:
if len(sorted_bboxes) == 0:
return []
bboxes = sorted_bboxes.copy()
guard = Bbox(MAXV, bboxes[-1].p.y, -1, -1, label="guard")
bboxes.append(guard)
res = []
prev = bboxes[0]
for curr in bboxes:
if prev.ur_point.x <= curr.p.x or not prev.same_row(curr):
res.append(prev)
prev = curr
else:
prev.w = max(prev.w, curr.ur_point.x - prev.p.x)
return res
def split_conflict(ocr_bboxes: List[Bbox], latex_bboxes: List[Bbox]) -> List[Bbox]:
if latex_bboxes == []:
return ocr_bboxes
if ocr_bboxes == [] or len(ocr_bboxes) == 1:
return ocr_bboxes
bboxes = sorted(ocr_bboxes + latex_bboxes)
# log results
for idx, bbox in enumerate(bboxes):
bbox.content = str(idx)
draw_bboxes(Image.fromarray(img), bboxes, name="before_split_confict.png")
assert len(bboxes) > 1
heapq.heapify(bboxes)
res = []
candidate = heapq.heappop(bboxes)
curr = heapq.heappop(bboxes)
idx = 0
while len(bboxes) > 0:
idx += 1
assert candidate.p.x <= curr.p.x or not candidate.same_row(curr)
if candidate.ur_point.x <= curr.p.x or not candidate.same_row(curr):
res.append(candidate)
candidate = curr
curr = heapq.heappop(bboxes)
elif candidate.ur_point.x < curr.ur_point.x:
assert not (candidate.label != "text" and curr.label != "text")
if candidate.label == "text" and curr.label == "text":
candidate.w = curr.ur_point.x - candidate.p.x
curr = heapq.heappop(bboxes)
elif candidate.label != curr.label:
if candidate.label == "text":
candidate.w = curr.p.x - candidate.p.x
res.append(candidate)
candidate = curr
curr = heapq.heappop(bboxes)
else:
curr.w = curr.ur_point.x - candidate.ur_point.x
curr.p.x = candidate.ur_point.x
heapq.heappush(bboxes, curr)
curr = heapq.heappop(bboxes)
elif candidate.ur_point.x >= curr.ur_point.x:
assert not (candidate.label != "text" and curr.label != "text")
if candidate.label == "text":
assert curr.label != "text"
heapq.heappush(
bboxes,
Bbox(
curr.ur_point.x,
candidate.p.y,
candidate.h,
candidate.ur_point.x - curr.ur_point.x,
label="text",
confidence=candidate.confidence,
content=None,
),
)
candidate.w = curr.p.x - candidate.p.x
res.append(candidate)
candidate = curr
curr = heapq.heappop(bboxes)
else:
assert curr.label == "text"
curr = heapq.heappop(bboxes)
else:
assert False
res.append(candidate)
res.append(curr)
# log results
for idx, bbox in enumerate(res):
bbox.content = str(idx)
draw_bboxes(Image.fromarray(img), res, name="after_split_confict.png")
return res
def slice_from_image(img: np.ndarray, ocr_bboxes: List[Bbox]) -> List[np.ndarray]:
sliced_imgs = []
for bbox in ocr_bboxes:
x, y = int(bbox.p.x), int(bbox.p.y)
w, h = int(bbox.w), int(bbox.h)
sliced_img = img[y : y + h, x : x + w]
sliced_imgs.append(sliced_img)
return sliced_imgs
def mix_inference(
img_path: str,
infer_config,
latex_det_model,
lang_ocr_models,
latex_rec_models,
accelerator="cpu",
num_beams=1,
) -> str:
'''
Take an image mixing text and formulas and return a string in Markdown syntax
'''
global img
img = cv2.imread(img_path)
corners = [tuple(img[0, 0]), tuple(img[0, -1]), tuple(img[-1, 0]), tuple(img[-1, -1])]
bg_color = np.array(Counter(corners).most_common(1)[0][0])
start_time = time.time()
latex_bboxes = latex_det_predict(img_path, latex_det_model, infer_config)
end_time = time.time()
print(f"latex_det_model time: {end_time - start_time:.2f}s")
latex_bboxes = sorted(latex_bboxes)
# log results
draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(unmerged).png")
latex_bboxes = bbox_merge(latex_bboxes)
# log results
draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(merged).png")
masked_img = mask_img(img, latex_bboxes, bg_color)
det_model, rec_model = lang_ocr_models
start_time = time.time()
det_prediction, _ = det_model(masked_img)
end_time = time.time()
print(f"ocr_det_model time: {end_time - start_time:.2f}s")
ocr_bboxes = [
Bbox(
p[0][0],
p[0][1],
p[3][1] - p[0][1],
p[1][0] - p[0][0],
label="text",
confidence=None,
content=None,
)
for p in det_prediction
]
# log results
draw_bboxes(Image.fromarray(img), ocr_bboxes, name="ocr_bboxes(unmerged).png")
ocr_bboxes = sorted(ocr_bboxes)
ocr_bboxes = bbox_merge(ocr_bboxes)
# log results
draw_bboxes(Image.fromarray(img), ocr_bboxes, name="ocr_bboxes(merged).png")
ocr_bboxes = split_conflict(ocr_bboxes, latex_bboxes)
ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes))
sliced_imgs: List[np.ndarray] = slice_from_image(img, ocr_bboxes)
start_time = time.time()
rec_predictions, _ = rec_model(sliced_imgs)
end_time = time.time()
print(f"ocr_rec_model time: {end_time - start_time:.2f}s")
assert len(rec_predictions) == len(ocr_bboxes)
for content, bbox in zip(rec_predictions, ocr_bboxes):
bbox.content = content[0]
latex_imgs = []
for bbox in latex_bboxes:
latex_imgs.append(img[bbox.p.y : bbox.p.y + bbox.h, bbox.p.x : bbox.p.x + bbox.w])
start_time = time.time()
latex_rec_res = latex_rec_predict(
*latex_rec_models, latex_imgs, accelerator, num_beams, max_tokens=800
)
end_time = time.time()
print(f"latex_rec_model time: {end_time - start_time:.2f}s")
for bbox, content in zip(latex_bboxes, latex_rec_res):
bbox.content = to_katex(content)
if bbox.label == "embedding":
bbox.content = " $" + bbox.content + "$ "
elif bbox.label == "isolated":
bbox.content = '\n\n' + r"$$" + bbox.content + r"$$" + '\n\n'
bboxes = sorted(ocr_bboxes + latex_bboxes)
if bboxes == []:
return ""
md = ""
prev = Bbox(bboxes[0].p.x, bboxes[0].p.y, -1, -1, label="guard")
for curr in bboxes:
# Add the formula number back to the isolated formula
if prev.label == "isolated" and curr.label == "text" and prev.same_row(curr):
curr.content = curr.content.strip()
if curr.content.startswith('(') and curr.content.endswith(')'):
curr.content = curr.content[1:-1]
if re.search(r'\\tag\{.*\}$', md[:-4]) is not None:
# in case of multiple tag
md = md[:-5] + f', {curr.content}' + '}' + md[-4:]
else:
md = md[:-4] + f'\\tag{{{curr.content}}}' + md[-4:]
continue
if not prev.same_row(curr):
md += " "
if curr.label == "embedding":
# remove the bold effect from inline formulas
curr.content = change_all(curr.content, r'\bm', r' ', r'{', r'}', r'', r' ')
curr.content = change_all(curr.content, r'\boldsymbol', r' ', r'{', r'}', r'', r' ')
curr.content = change_all(curr.content, r'\textit', r' ', r'{', r'}', r'', r' ')
curr.content = change_all(curr.content, r'\textbf', r' ', r'{', r'}', r'', r' ')
curr.content = change_all(curr.content, r'\mathbf', r' ', r'{', r'}', r'', r' ')
# change split environment into aligned
curr.content = curr.content.replace(r'\begin{split}', r'\begin{aligned}')
curr.content = curr.content.replace(r'\end{split}', r'\end{aligned}')
# remove extra spaces (keeping only one)
curr.content = re.sub(r' +', ' ', curr.content)
assert curr.content.startswith(' $') and curr.content.endswith('$ ')
curr.content = ' $' + curr.content[2:-2].strip() + '$ '
md += curr.content
prev = curr
return md.strip()

View File

@@ -81,7 +81,7 @@ class BaseRecLabelDecode(object):
word_list = []
word_col_list = []
state_list = []
- valid_col = np.where(selection == True)[0]
+ valid_col = np.where(selection)[0]
for c_i, char in enumerate(text):
if "\u4e00" <= char <= "\u9fff":

View File

@@ -12,25 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
- import sys
- __dir__ = os.path.dirname(os.path.abspath(__file__))
- sys.path.append(__dir__)
- sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../..")))
os.environ["FLAGS_allocator_strategy"] = "auto_growth"
- import sys
import time
- import cv2
import numpy as np
- # import tools.infer.utility as utility
- import utility
- from DBPostProcess import DBPostProcess
- from operators import DetResizeForTest, KeepKeys, NormalizeImage, ToCHWImage
- from utility import get_logger
+ from .DBPostProcess import DBPostProcess
+ from .operators import DetResizeForTest, KeepKeys, NormalizeImage, ToCHWImage
+ from .utility import create_predictor, get_logger
def transform(data, ops=None):
@@ -82,7 +73,7 @@ class TextDetector(object):
self.input_tensor,
self.output_tensors,
self.config,
- ) = utility.create_predictor(args, "det", logger)
+ ) = create_predictor(args, "det", logger)
assert self.use_onnx
if self.use_onnx:

View File

@@ -1,155 +0,0 @@
import sys
import argparse
import tempfile
import time
import numpy as np
import cv2
from pathlib import Path
from starlette.requests import Request
from ray import serve
from ray.serve.handle import DeploymentHandle
from onnxruntime import InferenceSession
from texteller.models.ocr_model.utils.inference import inference as rec_inference
from texteller.models.det_model.inference import predict as det_inference
from texteller.models.ocr_model.model.TexTeller import TexTeller
from texteller.models.det_model.inference import PredictConfig
from texteller.models.ocr_model.utils.to_katex import to_katex
PYTHON_VERSION = str(sys.version_info.major) + '.' + str(sys.version_info.minor)
LIBPATH = Path(sys.executable).parent.parent / 'lib' / ('python' + PYTHON_VERSION) / 'site-packages'
CUDNNPATH = LIBPATH / 'nvidia' / 'cudnn' / 'lib'
parser = argparse.ArgumentParser()
parser.add_argument('-ckpt', '--checkpoint_dir', type=str)
parser.add_argument('-tknz', '--tokenizer_dir', type=str)
parser.add_argument('-port', '--server_port', type=int, default=8000)
parser.add_argument('--num_replicas', type=int, default=1)
parser.add_argument('--ncpu_per_replica', type=float, default=1.0)
parser.add_argument('--ngpu_per_replica', type=float, default=0.0)
parser.add_argument('--inference-mode', type=str, default='cpu')
parser.add_argument('--num_beams', type=int, default=1)
parser.add_argument('-onnx', action='store_true', help='using onnx runtime')
args = parser.parse_args()
if args.ngpu_per_replica > 0 and not args.inference_mode == 'cuda':
raise ValueError("--inference-mode must be cuda or mps if ngpu_per_replica > 0")
@serve.deployment(
num_replicas=args.num_replicas,
ray_actor_options={
"num_cpus": args.ncpu_per_replica,
"num_gpus": args.ngpu_per_replica * 1.0 / 2,
},
)
class TexTellerRecServer:
def __init__(
self,
checkpoint_path: str,
tokenizer_path: str,
inf_mode: str = 'cpu',
use_onnx: bool = False,
num_beams: int = 1,
) -> None:
self.model = TexTeller.from_pretrained(
checkpoint_path, use_onnx=use_onnx, onnx_provider=inf_mode
)
self.tokenizer = TexTeller.get_tokenizer(tokenizer_path)
self.inf_mode = inf_mode
self.num_beams = num_beams
if not use_onnx:
self.model = self.model.to(inf_mode) if inf_mode != 'cpu' else self.model
def predict(self, image_nparray) -> str:
return to_katex(
rec_inference(
self.model,
self.tokenizer,
[image_nparray],
accelerator=self.inf_mode,
num_beams=self.num_beams,
)[0]
)
@serve.deployment(
num_replicas=args.num_replicas,
ray_actor_options={
"num_cpus": args.ncpu_per_replica,
"num_gpus": args.ngpu_per_replica * 1.0 / 2,
"runtime_env": {"env_vars": {"LD_LIBRARY_PATH": f"{str(CUDNNPATH)}/:$LD_LIBRARY_PATH"}},
},
)
class TexTellerDetServer:
def __init__(self, inf_mode='cpu'):
self.infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
self.latex_det_model = InferenceSession(
"./models/det_model/model/rtdetr_r50vd_6x_coco.onnx",
providers=['CUDAExecutionProvider'] if inf_mode == 'cuda' else ['CPUExecutionProvider'],
)
async def predict(self, image_nparray) -> str:
with tempfile.TemporaryDirectory() as temp_dir:
img_path = f"{temp_dir}/temp_image.jpg"
cv2.imwrite(img_path, image_nparray)
latex_bboxes = det_inference(img_path, self.latex_det_model, self.infer_config)
return latex_bboxes
@serve.deployment()
class Ingress:
def __init__(self, det_server: DeploymentHandle, rec_server: DeploymentHandle) -> None:
self.det_server = det_server
self.texteller_server = rec_server
async def __call__(self, request: Request) -> str:
request_path = request.url.path
form = await request.form()
img_rb = await form['img'].read()
img_nparray = np.frombuffer(img_rb, np.uint8)
img_nparray = cv2.imdecode(img_nparray, cv2.IMREAD_COLOR)
img_nparray = cv2.cvtColor(img_nparray, cv2.COLOR_BGR2RGB)
if request_path.startswith("/fdet"):
if self.det_server is None:
return "[ERROR] rtdetr_r50vd_6x_coco.onnx not found."
pred = await self.det_server.predict.remote(img_nparray)
return pred
elif request_path.startswith("/frec"):
pred = await self.texteller_server.predict.remote(img_nparray)
return pred
else:
return "[ERROR] Invalid request path"
if __name__ == '__main__':
ckpt_dir = args.checkpoint_dir
tknz_dir = args.tokenizer_dir
serve.start(http_options={"host": "0.0.0.0", "port": args.server_port})
rec_server = TexTellerRecServer.bind(
ckpt_dir,
tknz_dir,
inf_mode=args.inference_mode,
use_onnx=args.onnx,
num_beams=args.num_beams,
)
det_server = None
if Path('./models/det_model/model/rtdetr_r50vd_6x_coco.onnx').exists():
det_server = TexTellerDetServer.bind(args.inference_mode)
ingress = Ingress.bind(det_server, rec_server)
# ingress_handle = serve.run(ingress, route_prefix="/predict")
ingress_handle = serve.run(ingress, route_prefix="/")
while True:
time.sleep(1)
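For reference, a minimal client sketch for the Ray Serve app removed above: it posts an image as the multipart form field `img` and the Ingress deployment routes by the `/frec` (recognition) or `/fdet` (detection) path prefix. The image path and host are hypothetical; the port matches the argparse default.

```python
import requests

# Hypothetical local image; the server reads the multipart field named "img".
with open("formula.png", "rb") as f:
    resp = requests.post("http://localhost:8000/frec", files={"img": f})
print(resp.text)  # KaTeX-compatible LaTeX from the recognition deployment
```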

View File

@@ -1,9 +0,0 @@
@echo off
SETLOCAL ENABLEEXTENSIONS
set CHECKPOINT_DIR=default
set TOKENIZER_DIR=default
streamlit run web.py
ENDLOCAL

View File

@@ -1,7 +0,0 @@
#!/usr/bin/env bash
set -exu
export CHECKPOINT_DIR="default"
export TOKENIZER_DIR="default"
streamlit run web.py

View File

@@ -1,14 +0,0 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
gpu_ids: all
num_processes: 1
machine_rank: 0
main_training_function: main
num_machines: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@@ -0,0 +1,12 @@
from typing import TypeAlias
from optimum.onnxruntime import ORTModelForVision2Seq
from transformers import VisionEncoderDecoderModel
from .bbox import Bbox
TexTellerModel: TypeAlias = VisionEncoderDecoderModel | ORTModelForVision2Seq
__all__ = ["Bbox", "TexTellerModel"]
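A small sketch of how the new alias is meant to be used in type hints; the helper function below is hypothetical.

```python
from texteller.types import Bbox, TexTellerModel

def describe(model: TexTellerModel, boxes: list[Bbox]) -> str:
    # TexTellerModel covers both the PyTorch VisionEncoderDecoderModel
    # and the ONNX Runtime (ORTModelForVision2Seq) variant.
    return f"{type(model).__name__} with {len(boxes)} detected regions"
```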

View File

@@ -1,10 +1,3 @@
-import os
-from PIL import Image, ImageDraw
-from typing import List
-from pathlib import Path
 class Point:
     def __init__(self, x: int, y: int):
         self.x = int(x)
@@ -51,9 +44,9 @@ class Bbox:
         return 1.0 * abs(self.p.y - other.p.y) / max(self.h, other.h) < self.THREADHOLD
     def __lt__(self, other) -> bool:
-        '''
+        """
         from top to bottom, from left to right
-        '''
+        """
         if not self.same_row(other):
             return self.p.y < other.p.y
         else:
@@ -61,29 +54,3 @@ class Bbox:
     def __repr__(self) -> str:
         return f"Bbox(upper_left_point={self.p}, h={self.h}, w={self.w}), label={self.label}, confident={self.confidence}, content={self.content})"
def draw_bboxes(img: Image.Image, bboxes: List[Bbox], name="annotated_image.png"):
curr_work_dir = Path(os.getcwd())
log_dir = curr_work_dir / "logs"
log_dir.mkdir(exist_ok=True)
drawer = ImageDraw.Draw(img)
for bbox in bboxes:
# Calculate the coordinates for the rectangle to be drawn
left = bbox.p.x
top = bbox.p.y
right = bbox.p.x + bbox.w
bottom = bbox.p.y + bbox.h
# Draw the rectangle on the image
drawer.rectangle([left, top, right, bottom], outline="green", width=1)
# Optionally, add text label if it exists
if bbox.label:
drawer.text((left, top), bbox.label, fill="blue")
if bbox.content:
drawer.text((left, bottom - 10), bbox.content[:10], fill="red")
# Save the image with drawn rectangles
img.save(log_dir / name)

View File

@@ -0,0 +1,26 @@
from .device import get_device, cuda_available, mps_available, str2device
from .image import readimgs, transform
from .latex import change_all, remove_style, add_newlines
from .path import mkdir, resolve_path
from .misc import lines_dedent
from .bbox import mask_img, bbox_merge, split_conflict, slice_from_image, draw_bboxes
__all__ = [
"get_device",
"cuda_available",
"mps_available",
"str2device",
"readimgs",
"transform",
"change_all",
"remove_style",
"add_newlines",
"mkdir",
"resolve_path",
"lines_dedent",
"mask_img",
"bbox_merge",
"split_conflict",
"slice_from_image",
"draw_bboxes",
]

142
texteller/utils/bbox.py Normal file
View File

@@ -0,0 +1,142 @@
import heapq
import os
from pathlib import Path
import numpy as np
from PIL import Image, ImageDraw
from texteller.types import Bbox
_MAXV = 999999999
def mask_img(img, bboxes: list[Bbox], bg_color: np.ndarray) -> np.ndarray:
mask_img = img.copy()
for bbox in bboxes:
mask_img[bbox.p.y : bbox.p.y + bbox.h, bbox.p.x : bbox.p.x + bbox.w] = bg_color
return mask_img
def bbox_merge(sorted_bboxes: list[Bbox]) -> list[Bbox]:
if len(sorted_bboxes) == 0:
return []
bboxes = sorted_bboxes.copy()
guard = Bbox(_MAXV, bboxes[-1].p.y, -1, -1, label="guard")
bboxes.append(guard)
res = []
prev = bboxes[0]
for curr in bboxes:
if prev.ur_point.x <= curr.p.x or not prev.same_row(curr):
res.append(prev)
prev = curr
else:
prev.w = max(prev.w, curr.ur_point.x - prev.p.x)
return res
def split_conflict(ocr_bboxes: list[Bbox], latex_bboxes: list[Bbox]) -> list[Bbox]:
if latex_bboxes == []:
return ocr_bboxes
if ocr_bboxes == [] or len(ocr_bboxes) == 1:
return ocr_bboxes
bboxes = sorted(ocr_bboxes + latex_bboxes)
assert len(bboxes) > 1
heapq.heapify(bboxes)
res = []
candidate = heapq.heappop(bboxes)
curr = heapq.heappop(bboxes)
idx = 0
while len(bboxes) > 0:
idx += 1
assert candidate.p.x <= curr.p.x or not candidate.same_row(curr)
if candidate.ur_point.x <= curr.p.x or not candidate.same_row(curr):
res.append(candidate)
candidate = curr
curr = heapq.heappop(bboxes)
elif candidate.ur_point.x < curr.ur_point.x:
assert not (candidate.label != "text" and curr.label != "text")
if candidate.label == "text" and curr.label == "text":
candidate.w = curr.ur_point.x - candidate.p.x
curr = heapq.heappop(bboxes)
elif candidate.label != curr.label:
if candidate.label == "text":
candidate.w = curr.p.x - candidate.p.x
res.append(candidate)
candidate = curr
curr = heapq.heappop(bboxes)
else:
curr.w = curr.ur_point.x - candidate.ur_point.x
curr.p.x = candidate.ur_point.x
heapq.heappush(bboxes, curr)
curr = heapq.heappop(bboxes)
elif candidate.ur_point.x >= curr.ur_point.x:
assert not (candidate.label != "text" and curr.label != "text")
if candidate.label == "text":
assert curr.label != "text"
heapq.heappush(
bboxes,
Bbox(
curr.ur_point.x,
candidate.p.y,
candidate.h,
candidate.ur_point.x - curr.ur_point.x,
label="text",
confidence=candidate.confidence,
content=None,
),
)
candidate.w = curr.p.x - candidate.p.x
res.append(candidate)
candidate = curr
curr = heapq.heappop(bboxes)
else:
assert curr.label == "text"
curr = heapq.heappop(bboxes)
else:
assert False
res.append(candidate)
res.append(curr)
return res
def slice_from_image(img: np.ndarray, ocr_bboxes: list[Bbox]) -> list[np.ndarray]:
sliced_imgs = []
for bbox in ocr_bboxes:
x, y = int(bbox.p.x), int(bbox.p.y)
w, h = int(bbox.w), int(bbox.h)
sliced_img = img[y : y + h, x : x + w]
sliced_imgs.append(sliced_img)
return sliced_imgs
def draw_bboxes(img: Image.Image, bboxes: list[Bbox], name="annotated_image.png"):
curr_work_dir = Path(os.getcwd())
log_dir = curr_work_dir / "logs"
log_dir.mkdir(exist_ok=True)
drawer = ImageDraw.Draw(img)
for bbox in bboxes:
# Calculate the coordinates for the rectangle to be drawn
left = bbox.p.x
top = bbox.p.y
right = bbox.p.x + bbox.w
bottom = bbox.p.y + bbox.h
# Draw the rectangle on the image
drawer.rectangle([left, top, right, bottom], outline="green", width=1)
# Optionally, add text label if it exists
if bbox.label:
drawer.text((left, top), bbox.label, fill="blue")
if bbox.content:
drawer.text((left, bottom - 10), bbox.content[:10], fill="red")
# Save the image with drawn rectangles
img.save(log_dir / name)
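A minimal sketch of the new bbox helpers, assuming the Bbox constructor takes (x, y, h, w) positionally as in the calls above; the coordinates are hypothetical.

```python
import numpy as np
from texteller.types import Bbox
from texteller.utils import bbox_merge, mask_img

boxes = [
    Bbox(10, 20, 30, 100, label="text"),
    Bbox(90, 22, 30, 80, label="text"),   # overlaps the first box on the same row
]
merged = bbox_merge(sorted(boxes))        # same-row overlaps are collapsed into one box
img = np.full((200, 400, 3), 255, dtype=np.uint8)
masked = mask_img(img, merged, bg_color=np.array([255, 255, 255], dtype=np.uint8))
```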

41
texteller/utils/device.py Normal file
View File

@@ -0,0 +1,41 @@
from typing import Literal
import torch
def str2device(device_str: Literal["cpu", "cuda", "mps"]) -> torch.device:
if device_str == "cpu":
return torch.device("cpu")
elif device_str == "cuda":
return torch.device("cuda")
elif device_str == "mps":
return torch.device("mps")
else:
raise ValueError(f"Invalid device: {device_str}")
def get_device(device_index: int = None) -> torch.device:
"""
Automatically detect the best available device for inference.
Args:
device_index: The index of GPU device to use if multiple are available.
Defaults to None, which uses the first available GPU.
Returns:
torch.device: Selected device for model inference.
"""
if cuda_available():
return str2device("cuda")
elif mps_available():
return str2device("mps")
else:
return str2device("cpu")
def cuda_available() -> bool:
return torch.cuda.is_available()
def mps_available() -> bool:
return torch.backends.mps.is_available()
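A short usage sketch for the device helpers:

```python
import torch
from texteller.utils import get_device

device = get_device()              # prefers CUDA, then MPS, then falls back to CPU
x = torch.zeros(1, device=device)  # tensors and models can be placed on the selected device
```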

121
texteller/utils/image.py Normal file
View File

@@ -0,0 +1,121 @@
from collections import Counter
from typing import List, Union
import cv2
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import v2
from texteller.constants import (
FIXED_IMG_SIZE,
IMG_CHANNELS,
IMAGE_MEAN,
IMAGE_STD,
)
from texteller.logger import get_logger
_logger = get_logger()
def readimgs(image_paths: list[str]) -> list[np.ndarray]:
"""
Read and preprocess a list of images from their file paths.
This function reads each image from the provided paths, handles different
bit depths (converting 16-bit to 8-bit if necessary), and normalizes color
channels to RGB format regardless of the original color space (BGR, BGRA,
or grayscale).
Args:
image_paths (list[str]): A list of file paths to the images to be read.
Returns:
list[np.ndarray]: A list of NumPy arrays containing the preprocessed images
in RGB format.
Raises:
ValueError: If any image cannot be read.
"""
processed_images = []
for path in image_paths:
image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
if image is None:
raise ValueError(f"Image at {path} could not be read.")
if image.dtype == np.uint16:
_logger.warning(f'Converting {path} to 8-bit, image may be lossy.')
image = cv2.convertScaleAbs(image, alpha=(255.0 / 65535.0))
channels = 1 if len(image.shape) == 2 else image.shape[2]
if channels == 4:
image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
elif channels == 1:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
elif channels == 3:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
processed_images.append(image)
return processed_images
def trim_white_border(image: np.ndarray) -> np.ndarray:
if len(image.shape) != 3 or image.shape[2] != 3:
raise ValueError("Image is not in RGB format or channel is not in third dimension")
if image.dtype != np.uint8:
raise ValueError(f"Image should stored in uint8")
corners = [tuple(image[0, 0]), tuple(image[0, -1]), tuple(image[-1, 0]), tuple(image[-1, -1])]
bg_color = Counter(corners).most_common(1)[0][0]
bg_color_np = np.array(bg_color, dtype=np.uint8)
h, w = image.shape[:2]
bg = np.full((h, w, 3), bg_color_np, dtype=np.uint8)
diff = cv2.absdiff(image, bg)
mask = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY)
threshold = 15
_, diff = cv2.threshold(mask, threshold, 255, cv2.THRESH_BINARY)
x, y, w, h = cv2.boundingRect(diff)
trimmed_image = image[y : y + h, x : x + w]
return trimmed_image
def padding(images: List[torch.Tensor], required_size: int) -> List[torch.Tensor]:
images = [
v2.functional.pad(
img, padding=[0, 0, required_size - img.shape[2], required_size - img.shape[1]]
)
for img in images
]
return images
def transform(images: List[Union[np.ndarray, Image.Image]]) -> List[torch.Tensor]:
general_transform_pipeline = v2.Compose(
[
v2.ToImage(),
v2.ToDtype(torch.uint8, scale=True),
v2.Grayscale(),
v2.Resize(
size=FIXED_IMG_SIZE - 1,
interpolation=v2.InterpolationMode.BICUBIC,
max_size=FIXED_IMG_SIZE,
antialias=True,
),
v2.ToDtype(torch.float32, scale=True), # Normalize expects float input
v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]),
]
)
assert IMG_CHANNELS == 1, "Only support grayscale images for now"
images = [
np.array(img.convert('RGB')) if isinstance(img, Image.Image) else img for img in images
]
images = [trim_white_border(image) for image in images]
images = [general_transform_pipeline(image) for image in images]
images = padding(images, FIXED_IMG_SIZE)
return images
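A minimal sketch of the image pipeline above (the file name is hypothetical): images are read as RGB, white borders trimmed, then converted to grayscale, resized, normalized, and padded into a square batch.

```python
import torch
from texteller.utils import readimgs, transform

imgs = readimgs(["example_formula.png"])  # hypothetical path; returns RGB numpy arrays
tensors = transform(imgs)                 # trim, grayscale, resize, normalize, pad
pixel_values = torch.stack(tensors)       # (batch, 1, FIXED_IMG_SIZE, FIXED_IMG_SIZE)
```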

128
texteller/utils/latex.py Normal file
View File

@@ -0,0 +1,128 @@
import re
def _change(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r):
result = ""
i = 0
n = len(input_str)
while i < n:
if input_str[i : i + len(old_inst)] == old_inst:
# check if the old_inst is followed by old_surr_l
start = i + len(old_inst)
else:
result += input_str[i]
i += 1
continue
if start < n and input_str[start] == old_surr_l:
# found an old_inst followed by old_surr_l, now look for the matching old_surr_r
count = 1
j = start + 1
escaped = False
while j < n and count > 0:
if input_str[j] == '\\' and not escaped:
escaped = True
j += 1
continue
if input_str[j] == old_surr_r and not escaped:
count -= 1
if count == 0:
break
elif input_str[j] == old_surr_l and not escaped:
count += 1
escaped = False
j += 1
if count == 0:
assert j < n
assert input_str[start] == old_surr_l
assert input_str[j] == old_surr_r
inner_content = input_str[start + 1 : j]
# Replace the content with new pattern
result += new_inst + new_surr_l + inner_content + new_surr_r
i = j + 1
continue
else:
assert count >= 1
assert j == n
print("Warning: unbalanced surrogate pair in input string")
result += new_inst + new_surr_l
i = start + 1
continue
else:
result += input_str[i:start]
i = start
if old_inst != new_inst and (old_inst + old_surr_l) in result:
return _change(result, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r)
else:
return result
def _find_substring_positions(string, substring):
positions = [match.start() for match in re.finditer(re.escape(substring), string)]
return positions
def change_all(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r):
pos = _find_substring_positions(input_str, old_inst + old_surr_l)
res = list(input_str)
for p in pos[::-1]:
res[p:] = list(
_change(
''.join(res[p:]), old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r
)
)
res = ''.join(res)
return res
def remove_style(input_str: str) -> str:
input_str = change_all(input_str, r"\bm", r" ", r"{", r"}", r"", r" ")
input_str = change_all(input_str, r"\boldsymbol", r" ", r"{", r"}", r"", r" ")
input_str = change_all(input_str, r"\textit", r" ", r"{", r"}", r"", r" ")
input_str = change_all(input_str, r"\textbf", r" ", r"{", r"}", r"", r" ")
input_str = change_all(input_str, r"\textbf", r" ", r"{", r"}", r"", r" ")
input_str = change_all(input_str, r"\mathbf", r" ", r"{", r"}", r"", r" ")
output_str = input_str.strip()
return output_str
def add_newlines(latex_str: str) -> str:
"""
Adds newlines to a LaTeX string based on specific patterns, ensuring no
duplicate newlines are added around begin/end environments.
- After \\ (if not already followed by newline)
- Before \\begin{...} (if not already preceded by newline)
- After \\begin{...} (if not already followed by newline)
- Before \\end{...} (if not already preceded by newline)
- After \\end{...} (if not already followed by newline)
Args:
latex_str: The input LaTeX string.
Returns:
The LaTeX string with added newlines, avoiding duplicates.
"""
processed_str = latex_str
# 1. Replace whitespace around \begin{...} with \n...\n
# \s* matches zero or more whitespace characters (space, tab, newline)
# Captures the \begin{...} part in group 1 (\g<1>)
processed_str = re.sub(r"\s*(\\begin\{[^}]*\})\s*", r"\n\g<1>\n", processed_str)
# 2. Replace whitespace around \end{...} with \n...\n
# Same logic as for \begin
processed_str = re.sub(r"\s*(\\end\{[^}]*\})\s*", r"\n\g<1>\n", processed_str)
# 3. Add newline after \\ (if not already followed by newline)
processed_str = re.sub(r"\\\\(?!\n| )|\\\\ ", r"\\\\\n", processed_str)
# 4. Cleanup: Collapse multiple consecutive newlines into a single newline.
# This handles cases where the replacements above might have created \n\n.
processed_str = re.sub(r'\n{2,}', '\n', processed_str)
# Remove leading/trailing whitespace (including potential single newlines
# at the very start/end resulting from the replacements) from the entire result.
return processed_str.strip()
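Worked examples for the helpers above (outputs traced by hand from the code, so treat them as approximate):

```python
from texteller.utils import add_newlines, change_all, remove_style

# Rewrite every \frac{...}{...} as \dfrac{...}{...}, keeping the brace-delimited bodies.
change_all(r"\frac{a}{b} + \frac{c}{d}", r"\frac", r"\dfrac", "{", "}", "{", "}")
# -> "\dfrac{a}{b} + \dfrac{c}{d}"

# Strip styling commands such as \textbf{...} and \bm{...}, keeping their contents.
remove_style(r"\textbf{A} = \bm{x}")
# -> roughly "A  =  x" (exact spacing depends on the substitutions above)

# Break one-line LaTeX around environments and after \\ for readability.
print(add_newlines(r"\begin{align} a &= b \\ c &= d \end{align}"))
```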

5
texteller/utils/misc.py Normal file
View File

@@ -0,0 +1,5 @@
from textwrap import dedent
def lines_dedent(s: str) -> str:
return dedent(s).strip()

52
texteller/utils/path.py Normal file
View File

@@ -0,0 +1,52 @@
from pathlib import Path
from typing import Literal
from texteller.logger import get_logger
_logger = get_logger(__name__)
def resolve_path(path: str | Path) -> str:
if isinstance(path, str):
path = Path(path)
return str(path.expanduser().resolve())
def touch(path: str | Path) -> None:
if isinstance(path, str):
path = Path(path)
path.touch(exist_ok=True)
def mkdir(path: str | Path) -> None:
if isinstance(path, str):
path = Path(path)
path.mkdir(parents=True, exist_ok=True)
def rmfile(path: str | Path) -> None:
if isinstance(path, str):
path = Path(path)
path.unlink(missing_ok=False)
def rmdir(path: str | Path, mode: Literal["empty", "recursive"] = "empty") -> None:
"""Remove a directory.
Args:
path: Path to directory to remove
mode: "empty" to only remove empty directories, "all" to recursively remove all contents
"""
if isinstance(path, str):
path = Path(path)
if mode == "empty":
path.rmdir()
_logger.info(f"Removed empty directory: {path}")
elif mode == "recursive":
import shutil
shutil.rmtree(path)
_logger.info(f"Recursively removed directory and all contents: {path}")
else:
raise ValueError(f"Invalid mode: {mode}. Must be 'empty' or 'all'")

Some files were not shown because too many files have changed in this diff.