[refactor] Init
texteller/__init__.py (new file)
@@ -0,0 +1 @@
+from texteller.api import *
texteller/api/__init__.py (new file)
@@ -0,0 +1,24 @@
+from .detection import latex_detect
+from .format import format_latex
+from .inference import img2latex, paragraph2md
+from .katex import to_katex
+from .load import (
+    load_latexdet_model,
+    load_model,
+    load_textdet_model,
+    load_textrec_model,
+    load_tokenizer,
+)
+
+__all__ = [
+    "to_katex",
+    "format_latex",
+    "img2latex",
+    "paragraph2md",
+    "load_model",
+    "load_tokenizer",
+    "load_latexdet_model",
+    "load_textrec_model",
+    "load_textdet_model",
+    "latex_detect",
+]
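For orientation, the block above is the entire public surface of the package; a minimal sketch of the intended call flow (the image path is a placeholder):

```python
from texteller import img2latex, load_model, load_tokenizer

model = load_model(use_onnx=False)  # with no directory given, weights come from the Hugging Face repo
tokenizer = load_tokenizer()
print(img2latex(model, tokenizer, ["formula.png"], out_format="katex")[0])
```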
texteller/api/criterias/__init__.py (new file)
@@ -0,0 +1,4 @@
+from .ngram import DetectRepeatingNgramCriteria
+
+
+__all__ = ["DetectRepeatingNgramCriteria"]
@@ -1,16 +1,8 @@
 import torch
-import numpy as np
-
-from transformers import RobertaTokenizerFast, GenerationConfig, StoppingCriteria
-from typing import List, Union
-
-from .transforms import inference_transform
-from .helpers import convert2rgb
-from ..model.TexTeller import TexTeller
-from ...globals import MAX_TOKEN_SIZE
+from transformers import StoppingCriteria


-class EfficientDetectRepeatingNgramCriteria(StoppingCriteria):
+class DetectRepeatingNgramCriteria(StoppingCriteria):
     """
     Stops generation efficiently if any n-gram repeats.

@@ -69,48 +61,3 @@ class EfficientDetectRepeatingNgramCriteria(StoppingCriteria):
         # It's a new n-gram, add it to the set and continue
         self.seen_ngrams.add(last_ngram_tuple)
         return False  # Continue generation
-
-
-def inference(
-    model: TexTeller,
-    tokenizer: RobertaTokenizerFast,
-    imgs: Union[List[str], List[np.ndarray]],
-    accelerator: str = 'cpu',
-    num_beams: int = 1,
-    max_tokens=None,
-) -> List[str]:
-    if imgs == []:
-        return []
-    if hasattr(model, 'eval'):
-        # not onnx session, turn model.eval()
-        model.eval()
-    if isinstance(imgs[0], str):
-        imgs = convert2rgb(imgs)
-    else:  # already numpy array (rgb format)
-        assert isinstance(imgs[0], np.ndarray)
-        imgs = imgs
-    imgs = inference_transform(imgs)
-    pixel_values = torch.stack(imgs)
-
-    if hasattr(model, 'eval'):
-        # not onnx session, move weights to device
-        model = model.to(accelerator)
-        pixel_values = pixel_values.to(accelerator)
-
-    generate_config = GenerationConfig(
-        max_new_tokens=MAX_TOKEN_SIZE if max_tokens is None else max_tokens,
-        num_beams=num_beams,
-        do_sample=False,
-        pad_token_id=tokenizer.pad_token_id,
-        eos_token_id=tokenizer.eos_token_id,
-        bos_token_id=tokenizer.bos_token_id,
-        # no_repeat_ngram_size=10,
-    )
-    pred = model.generate(
-        pixel_values.to(model.device),
-        generation_config=generate_config,
-        # stopping_criteria=[EfficientDetectRepeatingNgramCriteria(20)],
-    )
-
-    res = tokenizer.batch_decode(pred, skip_special_tokens=True)
-    return res
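`DetectRepeatingNgramCriteria` is no longer wired in anywhere by default; the removed `inference` helper only referenced it in a commented-out line. If it is wanted, the usual Hugging Face route would look roughly like the sketch below, where the n-gram size 20 comes from the old commented-out call, `model`, `pixel_values`, and `generate_config` are assumed from the removed function, and the constructor argument is an assumption about the class's signature:

```python
from transformers import StoppingCriteriaList

pred = model.generate(
    pixel_values.to(model.device),
    generation_config=generate_config,
    # assumed signature: stop as soon as any 20-gram repeats in the decoded sequence
    stopping_criteria=StoppingCriteriaList([DetectRepeatingNgramCriteria(20)]),
)
```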
texteller/api/detection/__init__.py (new file)
@@ -0,0 +1,3 @@
+from .detect import latex_detect
+
+__all__ = ["latex_detect"]
texteller/api/detection/detect.py (new file)
@@ -0,0 +1,48 @@
+from typing import List
+
+from onnxruntime import InferenceSession
+
+from texteller.types import Bbox
+
+from .preprocess import Compose
+
+_config = {
+    "mode": "paddle",
+    "draw_threshold": 0.5,
+    "metric": "COCO",
+    "use_dynamic_shape": False,
+    "arch": "DETR",
+    "min_subgraph_size": 3,
+    "preprocess": [
+        {"interp": 2, "keep_ratio": False, "target_size": [1600, 1600], "type": "Resize"},
+        {
+            "mean": [0.0, 0.0, 0.0],
+            "norm_type": "none",
+            "std": [1.0, 1.0, 1.0],
+            "type": "NormalizeImage",
+        },
+        {"type": "Permute"},
+    ],
+    "label_list": ["isolated", "embedding"],
+}
+
+
+def latex_detect(img_path: str, predictor: InferenceSession) -> List[Bbox]:
+    transforms = Compose(_config["preprocess"])
+    inputs = transforms(img_path)
+    inputs_name = [var.name for var in predictor.get_inputs()]
+    inputs = {k: inputs[k][None,] for k in inputs_name}
+
+    outputs = predictor.run(output_names=None, input_feed=inputs)[0]
+    res = []
+    for output in outputs:
+        cls_name = _config["label_list"][int(output[0])]
+        score = output[1]
+        xmin = int(max(output[2], 0))
+        ymin = int(max(output[3], 0))
+        xmax = int(output[4])
+        ymax = int(output[5])
+        if score > 0.5:
+            res.append(Bbox(xmin, ymin, ymax - ymin, xmax - xmin, cls_name, score))
+
+    return res
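A small sketch of calling the detector directly ("page.png" is a placeholder path); `label` and `confidence` are the `Bbox` fields filled in above:

```python
from texteller import latex_detect, load_latexdet_model

detector = load_latexdet_model()        # downloads the RT-DETR ONNX model on first use
for bbox in latex_detect("page.png", detector):
    print(bbox.label, bbox.confidence)  # "isolated" or "embedding", score above 0.5
```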
texteller/api/detection/preprocess.py (new file)
@@ -0,0 +1,161 @@
+import copy
+
+import cv2
+import numpy as np
+
+
+def decode_image(img_path):
+    if isinstance(img_path, str):
+        with open(img_path, "rb") as f:
+            im_read = f.read()
+        data = np.frombuffer(im_read, dtype="uint8")
+    else:
+        assert isinstance(img_path, np.ndarray)
+        data = img_path
+
+    im = cv2.imdecode(data, 1)  # BGR mode, but need RGB mode
+    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+    img_info = {
+        "im_shape": np.array(im.shape[:2], dtype=np.float32),
+        "scale_factor": np.array([1.0, 1.0], dtype=np.float32),
+    }
+    return im, img_info
+
+
+class Resize(object):
+    """resize image by target_size and max_size
+    Args:
+        target_size (int): the target size of image
+        keep_ratio (bool): whether keep_ratio or not, default true
+        interp (int): method of resize
+    """
+
+    def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
+        if isinstance(target_size, int):
+            target_size = [target_size, target_size]
+        self.target_size = target_size
+        self.keep_ratio = keep_ratio
+        self.interp = interp
+
+    def __call__(self, im, im_info):
+        """
+        Args:
+            im (np.ndarray): image (np.ndarray)
+            im_info (dict): info of image
+        Returns:
+            im (np.ndarray): processed image (np.ndarray)
+            im_info (dict): info of processed image
+        """
+        assert len(self.target_size) == 2
+        assert self.target_size[0] > 0 and self.target_size[1] > 0
+        im_channel = im.shape[2]
+        im_scale_y, im_scale_x = self.generate_scale(im)
+        im = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=self.interp)
+        im_info["im_shape"] = np.array(im.shape[:2]).astype("float32")
+        im_info["scale_factor"] = np.array([im_scale_y, im_scale_x]).astype("float32")
+        return im, im_info
+
+    def generate_scale(self, im):
+        """
+        Args:
+            im (np.ndarray): image (np.ndarray)
+        Returns:
+            im_scale_x: the resize ratio of X
+            im_scale_y: the resize ratio of Y
+        """
+        origin_shape = im.shape[:2]
+        im_c = im.shape[2]
+        if self.keep_ratio:
+            im_size_min = np.min(origin_shape)
+            im_size_max = np.max(origin_shape)
+            target_size_min = np.min(self.target_size)
+            target_size_max = np.max(self.target_size)
+            im_scale = float(target_size_min) / float(im_size_min)
+            if np.round(im_scale * im_size_max) > target_size_max:
+                im_scale = float(target_size_max) / float(im_size_max)
+            im_scale_x = im_scale
+            im_scale_y = im_scale
+        else:
+            resize_h, resize_w = self.target_size
+            im_scale_y = resize_h / float(origin_shape[0])
+            im_scale_x = resize_w / float(origin_shape[1])
+        return im_scale_y, im_scale_x
+
+
+class NormalizeImage(object):
+    """normalize image
+    Args:
+        mean (list): im - mean
+        std (list): im / std
+        is_scale (bool): whether need im / 255
+        norm_type (str): type in ['mean_std', 'none']
+    """
+
+    def __init__(self, mean, std, is_scale=True, norm_type="mean_std"):
+        self.mean = mean
+        self.std = std
+        self.is_scale = is_scale
+        self.norm_type = norm_type
+
+    def __call__(self, im, im_info):
+        """
+        Args:
+            im (np.ndarray): image (np.ndarray)
+            im_info (dict): info of image
+        Returns:
+            im (np.ndarray): processed image (np.ndarray)
+            im_info (dict): info of processed image
+        """
+        im = im.astype(np.float32, copy=False)
+        if self.is_scale:
+            scale = 1.0 / 255.0
+            im *= scale
+
+        if self.norm_type == "mean_std":
+            mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
+            std = np.array(self.std)[np.newaxis, np.newaxis, :]
+            im -= mean
+            im /= std
+        return im, im_info
+
+
+class Permute(object):
+    """permute image
+    Args:
+        to_bgr (bool): whether convert RGB to BGR
+        channel_first (bool): whether convert HWC to CHW
+    """
+
+    def __init__(
+        self,
+    ):
+        super(Permute, self).__init__()
+
+    def __call__(self, im, im_info):
+        """
+        Args:
+            im (np.ndarray): image (np.ndarray)
+            im_info (dict): info of image
+        Returns:
+            im (np.ndarray): processed image (np.ndarray)
+            im_info (dict): info of processed image
+        """
+        im = im.transpose((2, 0, 1)).copy()
+        return im, im_info
+
+
+class Compose:
+    def __init__(self, transforms):
+        self.transforms = []
+        for op_info in transforms:
+            new_op_info = op_info.copy()
+            op_type = new_op_info.pop("type")
+            self.transforms.append(eval(op_type)(**new_op_info))
+
+    def __call__(self, img_path):
+        img, im_info = decode_image(img_path)
+        for t in self.transforms:
+            img, im_info = t(img, im_info)
+        inputs = copy.deepcopy(im_info)
+        inputs["image"] = img
+        return inputs
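`Compose` resolves each `type` name via `eval` against the classes defined in this module; a short sketch using the same operator list that `latex_detect` passes in ("page.png" is a placeholder path):

```python
from texteller.api.detection.preprocess import Compose

pipeline = Compose([
    {"interp": 2, "keep_ratio": False, "target_size": [1600, 1600], "type": "Resize"},
    {"mean": [0.0, 0.0, 0.0], "norm_type": "none", "std": [1.0, 1.0, 1.0], "type": "NormalizeImage"},
    {"type": "Permute"},
])
inputs = pipeline("page.png")
print(inputs["image"].shape)  # (3, 1600, 1600): CHW after Resize + Permute
print(inputs["im_shape"], inputs["scale_factor"])
```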
@@ -5,9 +5,8 @@ Based on the Rust implementation at https://github.com/WGUNDERWOOD/tex-fmt
 """

 import re
-import argparse
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Dict, Set
+from typing import List, Optional, Tuple

 # Constants
 LINE_END = "\n"
@@ -49,7 +48,7 @@ RE_SPLITTING_SHARED_LINE_CAPTURE = re.compile(f"(?P<prev>\\S.*?)(?P<env>{SPLITTI

 @dataclass
 class Args:
-    """Command line arguments and configuration."""
+    """Formatter configuration."""

     tabchar: str = " "
     tabsize: int = 4
@@ -542,13 +541,29 @@
     return state.indent.actual == 0


-def format_latex(
-    old_text: str, file: str = "input.tex", args: Optional[Args] = None
-) -> Tuple[str, List[Log]]:
-    """Central function to format a LaTeX string."""
-    if args is None:
-        args = Args()
+def format_latex(text: str) -> str:
+    """Format LaTeX text with default formatting options.
+
+    This is the main API function for formatting LaTeX text.
+    It uses pre-defined default values for all formatting parameters.
+
+    Args:
+        text: LaTeX text to format
+
+    Returns:
+        Formatted LaTeX text
+    """
+    # Use default configuration
+    args = Args()
+    file = "input.tex"
+
+    # Format and return only the text
+    formatted_text, _ = _format_latex(text, file, args)
+    return formatted_text.strip()
+
+
+def _format_latex(old_text: str, file: str, args: Args) -> Tuple[str, List[Log]]:
+    """Internal function to format a LaTeX string."""
     logs = []
     logs.append(Log(level="INFO", file=file, message="Formatting started."))
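The public entry point now hides files, `Args`, and logs entirely; a minimal sketch of the new call shape (the exact output depends on the formatter's defaults):

```python
from texteller import format_latex

# Old: formatted, logs = format_latex(text, "input.tex", Args())
formatted = format_latex(r"\begin{align}x &= 1 \\ y &= 2\end{align}")
print(formatted)
```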
@@ -636,63 +651,3 @@ def format_latex(
     logs.append(Log(level="INFO", file=file, message="Formatting complete."))

     return new_text, logs
-
-
-def main():
-    """Command-line entry point."""
-    parser = argparse.ArgumentParser(description="Format LaTeX files")
-    parser.add_argument("file", help="LaTeX file to format")
-    parser.add_argument(
-        "--tabchar",
-        choices=["space", "tab"],
-        default="space",
-        help="Character to use for indentation",
-    )
-    parser.add_argument("--tabsize", type=int, default=4, help="Number of spaces per indent level")
-    parser.add_argument("--wrap", action="store_true", help="Enable line wrapping")
-    parser.add_argument("--wraplen", type=int, default=80, help="Maximum line length")
-    parser.add_argument(
-        "--wrapmin", type=int, default=40, help="Minimum line length before wrapping"
-    )
-    parser.add_argument(
-        "--lists", nargs="+", default=[], help="Additional environments to indent as lists"
-    )
-    parser.add_argument("--verbose", "-v", action="count", default=0, help="Increase verbosity")
-    parser.add_argument("--output", "-o", help="Output file (default: overwrite input)")
-
-    args_parsed = parser.parse_args()
-
-    # Convert command line args to our Args class
-    args = Args(
-        tabchar="\t" if args_parsed.tabchar == "tab" else " ",
-        tabsize=args_parsed.tabsize,
-        wrap=args_parsed.wrap,
-        wraplen=args_parsed.wraplen,
-        wrapmin=args_parsed.wrapmin,
-        lists=args_parsed.lists,
-        verbosity=args_parsed.verbose,
-    )
-
-    # Read input file
-    with open(args_parsed.file, "r", encoding="utf-8") as f:
-        text = f.read()
-
-    # Format the text
-    formatted_text, logs = format_latex(text, args_parsed.file, args)
-
-    # Print logs if verbose
-    if args.verbosity > 0:
-        for log in logs:
-            if log.linum_new is not None:
-                print(f"{log.level} {log.file}:{log.linum_new}:{log.linum_old}: {log.message}")
-            else:
-                print(f"{log.level} {log.file}: {log.message}")
-
-    # Write output
-    output_file = args_parsed.output or args_parsed.file
-    with open(output_file, "w", encoding="utf-8") as f:
-        f.write(formatted_text)
-
-
-if __name__ == "__main__":
-    main()
texteller/api/inference.py (new file)
@@ -0,0 +1,241 @@
+import re
+import time
+from collections import Counter
+from typing import Literal
+
+import cv2
+import numpy as np
+import torch
+from onnxruntime import InferenceSession
+from optimum.onnxruntime import ORTModelForVision2Seq
+from transformers import GenerationConfig, RobertaTokenizerFast
+
+from texteller.constants import MAX_TOKEN_SIZE
+from texteller.logger import get_logger
+from texteller.paddleocr import predict_det, predict_rec
+from texteller.types import Bbox, TexTellerModel
+from texteller.utils import (
+    bbox_merge,
+    get_device,
+    mask_img,
+    readimgs,
+    remove_style,
+    slice_from_image,
+    split_conflict,
+    transform,
+    add_newlines,
+)
+
+from .detection import latex_detect
+from .format import format_latex
+from .katex import to_katex
+
+_logger = get_logger()
+
+
+def img2latex(
+    model: TexTellerModel,
+    tokenizer: RobertaTokenizerFast,
+    images: list[str] | list[np.ndarray],
+    device: torch.device | None = None,
+    out_format: Literal["latex", "katex"] = "latex",
+    keep_style: bool = False,
+    max_tokens: int = MAX_TOKEN_SIZE,
+    num_beams: int = 1,
+    no_repeat_ngram_size: int = 0,
+) -> list[str]:
+    """
+    Convert images to LaTeX or KaTeX formatted strings.
+
+    Args:
+        model: The TexTeller or ORTModelForVision2Seq model instance
+        tokenizer: The tokenizer for the model
+        images: List of image paths or numpy arrays (RGB format)
+        device: The torch device to use (defaults to available GPU or CPU)
+        out_format: Output format, either "latex" or "katex"
+        keep_style: Whether to keep the style of the LaTeX
+        max_tokens: Maximum number of tokens to generate
+        num_beams: Number of beams for beam search
+        no_repeat_ngram_size: Size of n-grams to prevent repetition
+
+    Returns:
+        List of LaTeX or KaTeX strings corresponding to each input image
+
+    Example usage:
+        >>> import torch
+        >>> from texteller import load_model, load_tokenizer, img2latex
+
+        >>> model = load_model(model_dir=None, use_onnx=False)
+        >>> tokenizer = load_tokenizer(tokenizer_dir=None)
+        >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        >>> res = img2latex(model, tokenizer, ["path/to/image.png"], device=device, out_format="katex")
+    """
+    assert isinstance(images, list)
+    assert len(images) > 0
+
+    if device is None:
+        device = get_device()
+
+    if device.type != model.device.type:
+        if isinstance(model, ORTModelForVision2Seq):
+            _logger.warning(
+                f"Onnxruntime device mismatch: detected {str(device)} but model is on {str(model.device)}, using {str(model.device)} instead"
+            )
+        else:
+            model = model.to(device=device)
+
+    if isinstance(images[0], str):
+        images = readimgs(images)
+    else:  # already numpy array (rgb format)
+        assert isinstance(images[0], np.ndarray)
+        images = images
+
+    images = transform(images)
+    pixel_values = torch.stack(images)
+
+    generate_config = GenerationConfig(
+        max_new_tokens=max_tokens,
+        num_beams=num_beams,
+        do_sample=False,
+        pad_token_id=tokenizer.pad_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+        bos_token_id=tokenizer.bos_token_id,
+        no_repeat_ngram_size=no_repeat_ngram_size,
+    )
+    pred = model.generate(
+        pixel_values.to(model.device),
+        generation_config=generate_config,
+    )
+
+    res = tokenizer.batch_decode(pred, skip_special_tokens=True)
+
+    if out_format == "katex":
+        res = [to_katex(r) for r in res]
+
+    if not keep_style:
+        res = [remove_style(r) for r in res]
+
+    res = [format_latex(r) for r in res]
+    res = [add_newlines(r) for r in res]
+    return res
+
+
+def paragraph2md(
+    img_path: str,
+    latexdet_model: InferenceSession,
+    textdet_model: predict_det.TextDetector,
+    textrec_model: predict_rec.TextRecognizer,
+    latexrec_model: TexTellerModel,
+    tokenizer: RobertaTokenizerFast,
+    device: torch.device | None = None,
+    num_beams=1,
+) -> str:
+    """
+    Input a mixed image of formula text and output str (in markdown syntax)
+    """
+    img = cv2.imread(img_path)
+    corners = [tuple(img[0, 0]), tuple(img[0, -1]), tuple(img[-1, 0]), tuple(img[-1, -1])]
+    bg_color = np.array(Counter(corners).most_common(1)[0][0])
+
+    start_time = time.time()
+    latex_bboxes = latex_detect(img_path, latexdet_model)
+    end_time = time.time()
+    _logger.info(f"latex_det_model time: {end_time - start_time:.2f}s")
+    latex_bboxes = sorted(latex_bboxes)
+    latex_bboxes = bbox_merge(latex_bboxes)
+    masked_img = mask_img(img, latex_bboxes, bg_color)
+
+    start_time = time.time()
+    det_prediction, _ = textdet_model(masked_img)
+    end_time = time.time()
+    _logger.info(f"ocr_det_model time: {end_time - start_time:.2f}s")
+    ocr_bboxes = [
+        Bbox(
+            p[0][0],
+            p[0][1],
+            p[3][1] - p[0][1],
+            p[1][0] - p[0][0],
+            label="text",
+            confidence=None,
+            content=None,
+        )
+        for p in det_prediction
+    ]
+
+    ocr_bboxes = sorted(ocr_bboxes)
+    ocr_bboxes = bbox_merge(ocr_bboxes)
+    ocr_bboxes = split_conflict(ocr_bboxes, latex_bboxes)
+    ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes))
+
+    sliced_imgs: list[np.ndarray] = slice_from_image(img, ocr_bboxes)
+    start_time = time.time()
+    rec_predictions, _ = textrec_model(sliced_imgs)
+    end_time = time.time()
+    _logger.info(f"ocr_rec_model time: {end_time - start_time:.2f}s")
+
+    assert len(rec_predictions) == len(ocr_bboxes)
+    for content, bbox in zip(rec_predictions, ocr_bboxes):
+        bbox.content = content[0]
+
+    latex_imgs = []
+    for bbox in latex_bboxes:
+        latex_imgs.append(img[bbox.p.y : bbox.p.y + bbox.h, bbox.p.x : bbox.p.x + bbox.w])
+    start_time = time.time()
+    latex_rec_res = img2latex(
+        model=latexrec_model,
+        tokenizer=tokenizer,
+        images=latex_imgs,
+        num_beams=num_beams,
+        out_format="katex",
+        device=device,
+        keep_style=False,
+    )
+    end_time = time.time()
+    _logger.info(f"latex_rec_model time: {end_time - start_time:.2f}s")
+
+    for bbox, content in zip(latex_bboxes, latex_rec_res):
+        if bbox.label == "embedding":
+            bbox.content = " $" + content + "$ "
+        elif bbox.label == "isolated":
+            bbox.content = "\n\n" + r"$$" + content + r"$$" + "\n\n"
+
+    bboxes = sorted(ocr_bboxes + latex_bboxes)
+    if bboxes == []:
+        return ""
+
+    md = ""
+    prev = Bbox(bboxes[0].p.x, bboxes[0].p.y, -1, -1, label="guard")
+    for curr in bboxes:
+        # Add the formula number back to the isolated formula
+        if prev.label == "isolated" and curr.label == "text" and prev.same_row(curr):
+            curr.content = curr.content.strip()
+            if curr.content.startswith("(") and curr.content.endswith(")"):
+                curr.content = curr.content[1:-1]
+
+            if re.search(r"\\tag\{.*\}$", md[:-4]) is not None:
+                # in case of multiple tag
+                md = md[:-5] + f", {curr.content}" + "}" + md[-4:]
+            else:
+                md = md[:-4] + f"\\tag{{{curr.content}}}" + md[-4:]
+            continue
+
+        if not prev.same_row(curr):
+            md += " "
+
+        if curr.label == "embedding":
+            # remove the bold effect from inline formulas
+            curr.content = remove_style(curr.content)
+
+            # change split environment into aligned
+            curr.content = curr.content.replace(r"\begin{split}", r"\begin{aligned}")
+            curr.content = curr.content.replace(r"\end{split}", r"\end{aligned}")
+
+            # remove extra spaces (keeping only one)
+            curr.content = re.sub(r" +", " ", curr.content)
+            assert curr.content.startswith("$") and curr.content.endswith("$")
+            curr.content = " $" + curr.content.strip("$") + "$ "
+        md += curr.content
+        prev = curr
+
+    return md.strip()
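`img2latex` already carries a doctest-style example; the analogous sketch for `paragraph2md`, with all five models loaded through the new loaders (the image path is a placeholder), would be:

```python
from texteller import (
    load_latexdet_model,
    load_model,
    load_textdet_model,
    load_textrec_model,
    load_tokenizer,
    paragraph2md,
)

md = paragraph2md(
    img_path="mixed_page.png",  # placeholder path
    latexdet_model=load_latexdet_model(),
    textdet_model=load_textdet_model(),
    textrec_model=load_textrec_model(),
    latexrec_model=load_model(),
    tokenizer=load_tokenizer(),
)
print(md)  # markdown with $...$ inline and $$...$$ display formulas
```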
@@ -1,73 +1,10 @@
 import re
-from .latex_formatter import format_latex
+
+from ..utils.latex import change_all
+from .format import format_latex


-def change(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r):
-    result = ""
-    i = 0
-    n = len(input_str)
-
-    while i < n:
-        if input_str[i : i + len(old_inst)] == old_inst:
-            # check if the old_inst is followed by old_surr_l
-            start = i + len(old_inst)
-        else:
-            result += input_str[i]
-            i += 1
-            continue
-
-        if start < n and input_str[start] == old_surr_l:
-            # found an old_inst followed by old_surr_l, now look for the matching old_surr_r
-            count = 1
-            j = start + 1
-            escaped = False
-            while j < n and count > 0:
-                if input_str[j] == '\\' and not escaped:
-                    escaped = True
-                    j += 1
-                    continue
-                if input_str[j] == old_surr_r and not escaped:
-                    count -= 1
-                    if count == 0:
-                        break
-                elif input_str[j] == old_surr_l and not escaped:
-                    count += 1
-                escaped = False
-                j += 1
-
-            if count == 0:
-                assert j < n
-                assert input_str[start] == old_surr_l
-                assert input_str[j] == old_surr_r
-                inner_content = input_str[start + 1 : j]
-                # Replace the content with new pattern
-                result += new_inst + new_surr_l + inner_content + new_surr_r
-                i = j + 1
-                continue
-            else:
-                assert count >= 1
-                assert j == n
-                print("Warning: unbalanced surrogate pair in input string")
-                result += new_inst + new_surr_l
-                i = start + 1
-                continue
-        else:
-            result += input_str[i:start]
-            i = start
-
-    if old_inst != new_inst and (old_inst + old_surr_l) in result:
-        return change(result, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r)
-    else:
-        return result
-
-
-def find_substring_positions(string, substring):
-    positions = [match.start() for match in re.finditer(re.escape(substring), string)]
-    return positions
-
-
-def rm_dollar_surr(content):
+def _rm_dollar_surr(content):
     pattern = re.compile(r'\\[a-zA-Z]+\$.*?\$|\$.*?\$')
     matches = pattern.findall(content)

@@ -79,19 +16,6 @@ def rm_dollar_surr(content):
     return content


-def change_all(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r):
-    pos = find_substring_positions(input_str, old_inst + old_surr_l)
-    res = list(input_str)
-    for p in pos[::-1]:
-        res[p:] = list(
-            change(
-                ''.join(res[p:]), old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r
-            )
-        )
-    res = ''.join(res)
-    return res
-
-
 def to_katex(formula: str) -> str:
     res = formula
     # remove mbox surrounding
@@ -182,13 +106,13 @@ def to_katex(formula: str) -> str:
     res = re.sub(r'(\\text\{[^}]*\}\s*){2,}', merge_texts, res)

     res = res.replace(r'\bf ', '')
-    res = rm_dollar_surr(res)
+    res = _rm_dollar_surr(res)

     # remove extra spaces (keeping only one)
     res = re.sub(r' +', ' ', res)

     # format latex
     res = res.strip()
-    res, logs = format_latex(res)
+    res = format_latex(res)

     return res
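Since `format_latex` no longer returns logs, `to_katex` now yields the converted string directly; a one-line sketch (output depends on the conversion rules above):

```python
from texteller import to_katex

print(to_katex(r"\mbox{speed} = {\bf v}"))
```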
texteller/api/load.py (new file)
@@ -0,0 +1,66 @@
+from pathlib import Path
+
+import wget
+from onnxruntime import InferenceSession
+from transformers import RobertaTokenizerFast
+
+from texteller.constants import LATEX_DET_MODEL_URL, TEXT_DET_MODEL_URL, TEXT_REC_MODEL_URL
+from texteller.globals import Globals
+from texteller.logger import get_logger
+from texteller.models import TexTeller
+from texteller.paddleocr import predict_det, predict_rec
+from texteller.paddleocr.utility import parse_args
+from texteller.utils import cuda_available, mkdir, resolve_path
+from texteller.types import TexTellerModel
+
+_logger = get_logger(__name__)
+
+
+def load_model(model_dir: str | None = None, use_onnx: bool = False) -> TexTellerModel:
+    return TexTeller.from_pretrained(model_dir, use_onnx=use_onnx)
+
+
+def load_tokenizer(tokenizer_dir: str | None = None) -> RobertaTokenizerFast:
+    return TexTeller.get_tokenizer(tokenizer_dir)
+
+
+def load_latexdet_model() -> InferenceSession:
+    fpath = _maybe_download(LATEX_DET_MODEL_URL)
+    return InferenceSession(
+        resolve_path(fpath),
+        providers=["CUDAExecutionProvider" if cuda_available() else "CPUExecutionProvider"],
+    )
+
+
+def load_textrec_model() -> predict_rec.TextRecognizer:
+    fpath = _maybe_download(TEXT_REC_MODEL_URL)
+    paddleocr_args = parse_args()
+    paddleocr_args.use_onnx = True
+    paddleocr_args.rec_model_dir = resolve_path(fpath)
+    paddleocr_args.use_gpu = cuda_available()
+    predictor = predict_rec.TextRecognizer(paddleocr_args)
+    return predictor
+
+
+def load_textdet_model() -> predict_det.TextDetector:
+    fpath = _maybe_download(TEXT_DET_MODEL_URL)
+    paddleocr_args = parse_args()
+    paddleocr_args.use_onnx = True
+    paddleocr_args.det_model_dir = resolve_path(fpath)
+    paddleocr_args.use_gpu = cuda_available()
+    predictor = predict_det.TextDetector(paddleocr_args)
+    return predictor
+
+
+def _maybe_download(url: str, dirpath: str | None = None, force: bool = False) -> Path:
+    if dirpath is None:
+        dirpath = Globals().cache_dir
+    mkdir(dirpath)
+
+    fname = Path(url).name
+    fpath = Path(dirpath) / fname
+    if not fpath.exists() or force:
+        _logger.info(f"Downloading {fname} from {url} to {fpath}")
+        wget.download(url, resolve_path(fpath))
+
+    return fpath
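`_maybe_download` caches under `Globals().cache_dir` (`~/.cache/texteller` by default), so repeated loader calls reuse the same files; a sketch:

```python
from texteller import load_latexdet_model, load_textdet_model

detector = load_latexdet_model()  # first call downloads rtdetr_r50vd_6x_coco.onnx
detector = load_latexdet_model()  # later calls reuse ~/.cache/texteller/rtdetr_r50vd_6x_coco.onnx
text_det = load_textdet_model()   # same caching pattern for the PaddleOCR detector
```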
texteller/cli/__init__.py (new file)
@@ -0,0 +1,25 @@
+"""
+CLI entry point for TexTeller.
+"""
+
+import time
+
+import click
+
+from texteller.cli.commands.inference import inference
+from texteller.cli.commands.launch import launch
+from texteller.cli.commands.web import web
+
+
+@click.group()
+def cli():
+    pass
+
+
+cli.add_command(inference)
+cli.add_command(web)
+cli.add_command(launch)
+
+
+if __name__ == "__main__":
+    cli()
texteller/cli/commands/__init__.py (new file)
@@ -0,0 +1,3 @@
+"""
+CLI commands for TexTeller
+"""
texteller/cli/commands/inference.py (new file)
@@ -0,0 +1,51 @@
+"""
+CLI command for formula inference from images.
+"""
+
+import click
+
+from texteller.api import img2latex, load_model, load_tokenizer
+
+
+@click.command()
+@click.argument("image_path", type=click.Path(exists=True, file_okay=True, dir_okay=False))
+@click.option(
+    "--model-path",
+    type=click.Path(exists=True, file_okay=False, dir_okay=True),
+    default=None,
+    help="Path to the model directory; if not provided, the model from the Hugging Face repo is used",
+)
+@click.option(
+    "--tokenizer-path",
+    type=click.Path(exists=True, file_okay=False, dir_okay=True),
+    default=None,
+    help="Path to the tokenizer directory; if not provided, the tokenizer from the Hugging Face repo is used",
+)
+@click.option(
+    "--output-format",
+    type=click.Choice(["latex", "katex"]),
+    default="katex",
+    help="Output format, either latex or katex",
+)
+@click.option(
+    "--keep-style",
+    is_flag=True,
+    default=False,
+    help="Whether to keep the style of the LaTeX (e.g. bold, italic, etc.)",
+)
+def inference(image_path, model_path, tokenizer_path, output_format, keep_style):
+    """
+    CLI command for formula inference from images.
+    """
+    model = load_model(model_dir=model_path)
+    tknz = load_tokenizer(tokenizer_dir=tokenizer_path)
+
+    pred = img2latex(
+        model=model,
+        tokenizer=tknz,
+        images=[image_path],
+        out_format=output_format,
+        keep_style=keep_style,
+    )[0]
+
+    click.echo(f"Predicted LaTeX: ```\n{pred}\n```")
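A hedged smoke-test sketch for this command using click's built-in test runner ("formula.png" is a placeholder that must exist, since the argument is declared with `exists=True`):

```python
from click.testing import CliRunner

from texteller.cli.commands.inference import inference

result = CliRunner().invoke(inference, ["formula.png", "--output-format", "latex"])
print(result.output)
```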
texteller/cli/commands/launch/__init__.py (new file)
@@ -0,0 +1,106 @@
+"""
+CLI commands for launching server.
+"""
+
+import sys
+import time
+
+import click
+from ray import serve
+
+from texteller.globals import Globals
+from texteller.utils import get_device
+
+
+@click.command()
+@click.option(
+    "-ckpt",
+    "--checkpoint_dir",
+    type=click.Path(exists=True, file_okay=False, dir_okay=True),
+    default=None,
+    help="Path to the checkpoint directory; if not provided, the model from the Hugging Face repo is used",
+)
+@click.option(
+    "-tknz",
+    "--tokenizer_dir",
+    type=click.Path(exists=True, file_okay=False, dir_okay=True),
+    default=None,
+    help="Path to the tokenizer directory; if not provided, the tokenizer from the Hugging Face repo is used",
+)
+@click.option(
+    "-p",
+    "--port",
+    type=int,
+    default=8000,
+    help="Port to run the server on",
+)
+@click.option(
+    "--num-replicas",
+    type=int,
+    default=1,
+    help="Number of replicas to run the server with",
+)
+@click.option(
+    "--ncpu-per-replica",
+    type=float,
+    default=1.0,
+    help="Number of CPUs per replica",
+)
+@click.option(
+    "--ngpu-per-replica",
+    type=float,
+    default=1.0,
+    help="Number of GPUs per replica",
+)
+@click.option(
+    "--num-beams",
+    type=int,
+    default=1,
+    help="Number of beams to use",
+)
+@click.option(
+    "--use-onnx",
+    is_flag=True,
+    type=bool,
+    default=False,
+    help="Use ONNX runtime",
+)
+def launch(
+    checkpoint_dir,
+    tokenizer_dir,
+    port,
+    num_replicas,
+    ncpu_per_replica,
+    ngpu_per_replica,
+    num_beams,
+    use_onnx,
+):
+    """Launch the api server"""
+    device = get_device()
+    if ngpu_per_replica > 0 and not device.type == "cuda":
+        click.echo(
+            click.style(
+                f"Error: --ngpu-per-replica > 0 but detected device is {device.type}",
+                fg="red",
+            )
+        )
+        sys.exit(1)
+
+    Globals().num_replicas = num_replicas
+    Globals().ncpu_per_replica = ncpu_per_replica
+    Globals().ngpu_per_replica = ngpu_per_replica
+    from texteller.cli.commands.launch.server import Ingress, TexTellerServer
+
+    serve.start(http_options={"host": "0.0.0.0", "port": port})
+    rec_server = TexTellerServer.bind(
+        checkpoint_dir=checkpoint_dir,
+        tokenizer_dir=tokenizer_dir,
+        use_onnx=use_onnx,
+        num_beams=num_beams,
+    )
+    ingress = Ingress.bind(rec_server)
+
+    serve.run(ingress, route_prefix="/predict")
+
+    while True:
+        time.sleep(1)
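One ordering detail in `launch`: `server.py` is imported only after the `Globals()` fields are assigned, because the `@serve.deployment(...)` arguments on `TexTellerServer` (next file) are evaluated at import time. Condensed:

```python
# Deferred import: the decorator in server.py reads Globals() at import time,
# so the replica settings must be populated first.
Globals().num_replicas = num_replicas
from texteller.cli.commands.launch.server import Ingress, TexTellerServer
```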
texteller/cli/commands/launch/server.py (new file)
@@ -0,0 +1,69 @@
+import numpy as np
+import cv2
+
+from starlette.requests import Request
+from ray import serve
+from ray.serve.handle import DeploymentHandle
+
+from texteller.api import load_model, load_tokenizer, img2latex
+from texteller.utils import get_device
+from texteller.globals import Globals
+from typing import Literal
+
+
+@serve.deployment(
+    num_replicas=Globals().num_replicas,
+    ray_actor_options={
+        "num_cpus": Globals().ncpu_per_replica,
+        "num_gpus": Globals().ngpu_per_replica * 1.0 / 2,
+    },
+)
+class TexTellerServer:
+    def __init__(
+        self,
+        checkpoint_dir: str,
+        tokenizer_dir: str,
+        use_onnx: bool = False,
+        out_format: Literal["latex", "katex"] = "katex",
+        keep_style: bool = False,
+        num_beams: int = 1,
+    ) -> None:
+        self.model = load_model(
+            model_dir=checkpoint_dir,
+            use_onnx=use_onnx,
+        )
+        self.tokenizer = load_tokenizer(tokenizer_dir=tokenizer_dir)
+        self.num_beams = num_beams
+        self.out_format = out_format
+        self.keep_style = keep_style
+
+        if not use_onnx:
+            self.model = self.model.to(get_device())
+
+    def predict(self, image_nparray: np.ndarray) -> str:
+        return img2latex(
+            model=self.model,
+            tokenizer=self.tokenizer,
+            images=[image_nparray],
+            device=get_device(),
+            out_format=self.out_format,
+            keep_style=self.keep_style,
+            num_beams=self.num_beams,
+        )[0]
+
+
+@serve.deployment()
+class Ingress:
+    def __init__(self, rec_server: DeploymentHandle) -> None:
+        self.texteller_server = rec_server
+
+    async def __call__(self, request: Request) -> str:
+        form = await request.form()
+        img_rb = await form["img"].read()
+
+        img_nparray = np.frombuffer(img_rb, np.uint8)
+        img_nparray = cv2.imdecode(img_nparray, cv2.IMREAD_COLOR)
+        img_nparray = cv2.cvtColor(img_nparray, cv2.COLOR_BGR2RGB)
+
+        pred = await self.texteller_server.predict.remote(img_nparray)
+        return pred
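The standalone client script removed later in this diff posted to `/frec`; with `serve.run(ingress, route_prefix="/predict")` and the `form["img"]` field read in `Ingress`, an updated client sketch (placeholder image path) would be:

```python
import requests

with open("formula.png", "rb") as f:
    resp = requests.post("http://127.0.0.1:8000/predict", files={"img": f})
print(resp.text)
```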
texteller/cli/commands/web/__init__.py (new file)
@@ -0,0 +1,9 @@
+import os
+import click
+from pathlib import Path
+
+
+@click.command()
+def web():
+    """Launch the web interface for TexTeller."""
+    os.system(f"streamlit run {Path(__file__).parent / 'streamlit_demo.py'}")
texteller/cli/commands/web/streamlit_demo.py (new file)
@@ -0,0 +1,225 @@
+import base64
+import io
+import os
+import re
+import shutil
+import tempfile
+
+import streamlit as st
+from PIL import Image
+from streamlit_paste_button import paste_image_button as pbutton
+
+from texteller.api import (
+    img2latex,
+    load_latexdet_model,
+    load_model,
+    load_textdet_model,
+    load_textrec_model,
+    load_tokenizer,
+    paragraph2md,
+)
+from texteller.cli.commands.web.style import (
+    HEADER_HTML,
+    IMAGE_EMBED_HTML,
+    IMAGE_INFO_HTML,
+    SUCCESS_GIF_HTML,
+)
+from texteller.utils import str2device
+
+st.set_page_config(page_title="TexTeller", page_icon="🧮")
+
+
+@st.cache_resource
+def get_texteller(use_onnx):
+    return load_model(use_onnx=use_onnx)
+
+
+@st.cache_resource
+def get_tokenizer():
+    return load_tokenizer()
+
+
+@st.cache_resource
+def get_latexdet_model():
+    return load_latexdet_model()
+
+
+@st.cache_resource()
+def get_textrec_model():
+    return load_textrec_model()
+
+
+@st.cache_resource()
+def get_textdet_model():
+    return load_textdet_model()
+
+
+def get_image_base64(img_file):
+    buffered = io.BytesIO()
+    img_file.seek(0)
+    img = Image.open(img_file)
+    img.save(buffered, format="PNG")
+    return base64.b64encode(buffered.getvalue()).decode()
+
+
+def on_file_upload():
+    st.session_state["UPLOADED_FILE_CHANGED"] = True
+
+
+def change_side_bar():
+    st.session_state["CHANGE_SIDEBAR_FLAG"] = True
+
+
+if "start" not in st.session_state:
+    st.session_state["start"] = 1
+    st.toast("Hooray!", icon="🎉")
+
+if "UPLOADED_FILE_CHANGED" not in st.session_state:
+    st.session_state["UPLOADED_FILE_CHANGED"] = False
+
+if "CHANGE_SIDEBAR_FLAG" not in st.session_state:
+    st.session_state["CHANGE_SIDEBAR_FLAG"] = False
+
+if "INF_MODE" not in st.session_state:
+    st.session_state["INF_MODE"] = "Formula recognition"
+
+
+# ====== <sidebar> ======
+
+with st.sidebar:
+    num_beams = 1
+
+    st.markdown("# 🔨️ Config")
+    st.markdown("")
+
+    inf_mode = st.selectbox(
+        "Inference mode",
+        ("Formula recognition", "Paragraph recognition"),
+        on_change=change_side_bar,
+    )
+
+    num_beams = st.number_input(
+        "Number of beams", min_value=1, max_value=20, step=1, on_change=change_side_bar
+    )
+
+    device = st.radio("device", ("cpu", "cuda", "mps"), on_change=change_side_bar)
+
+    st.markdown("## Speedup")
+    use_onnx = st.toggle("ONNX Runtime ")
+
+
+# ====== </sidebar> ======
+
+
+# ====== <page> ======
+
+latexrec_model = get_texteller(use_onnx)
+tokenizer = get_tokenizer()
+
+if inf_mode == "Paragraph recognition":
+    latexdet_model = get_latexdet_model()
+    textrec_model = get_textrec_model()
+    textdet_model = get_textdet_model()
+
+st.markdown(HEADER_HTML, unsafe_allow_html=True)
+
+uploaded_file = st.file_uploader(" ", type=["jpg", "png"], on_change=on_file_upload)
+
+paste_result = pbutton(
+    label="📋 Paste an image",
+    background_color="#5BBCFF",
+    hover_background_color="#3498db",
+)
+st.write("")
+
+if st.session_state["CHANGE_SIDEBAR_FLAG"] is True:
+    st.session_state["CHANGE_SIDEBAR_FLAG"] = False
+elif uploaded_file or paste_result.image_data is not None:
+    if st.session_state["UPLOADED_FILE_CHANGED"] is False and paste_result.image_data is not None:
+        uploaded_file = io.BytesIO()
+        paste_result.image_data.save(uploaded_file, format="PNG")
+        uploaded_file.seek(0)
+
+    if st.session_state["UPLOADED_FILE_CHANGED"] is True:
+        st.session_state["UPLOADED_FILE_CHANGED"] = False
+
+    img = Image.open(uploaded_file)
+
+    temp_dir = tempfile.mkdtemp()
+    png_fpath = os.path.join(temp_dir, "image.png")
+    img.save(png_fpath, "PNG")
+
+    with st.container(height=300):
+        img_base64 = get_image_base64(uploaded_file)
+
+        st.markdown(
+            IMAGE_EMBED_HTML.format(img_base64=img_base64),
+            unsafe_allow_html=True,
+        )
+        st.markdown(
+            IMAGE_INFO_HTML.format(img_height=img.height, img_width=img.width),
+            unsafe_allow_html=True,
+        )
+
+    st.write("")
+
+    with st.spinner("Predicting..."):
+        if inf_mode == "Formula recognition":
+            pred = img2latex(
+                model=latexrec_model,
+                tokenizer=tokenizer,
+                images=[png_fpath],
+                device=str2device(device),
+                out_format="katex",
+                num_beams=num_beams,
+                keep_style=False,
+            )[0]
+        else:
+            pred = paragraph2md(
+                img_path=png_fpath,
+                latexdet_model=latexdet_model,
+                textdet_model=textdet_model,
+                textrec_model=textrec_model,
+                latexrec_model=latexrec_model,
+                tokenizer=tokenizer,
+                device=str2device(device),
+                num_beams=num_beams,
+            )
+
+    st.success("Completed!", icon="✅")
+    # st.markdown(SUCCESS_GIF_HTML, unsafe_allow_html=True)
+    # st.text_area("Predicted LaTeX", pred, height=150)
+    if inf_mode == "Formula recognition":
+        st.code(pred, language="latex")
+    elif inf_mode == "Paragraph recognition":
+        st.code(pred, language="markdown")
+    else:
+        raise ValueError(f"Invalid inference mode: {inf_mode}")
+
+    if inf_mode == "Formula recognition":
+        st.latex(pred)
+    elif inf_mode == "Paragraph recognition":
+        mixed_res = re.split(r"(\$\$.*?\$\$)", pred, flags=re.DOTALL)
+        for text in mixed_res:
+            if text.startswith("$$") and text.endswith("$$"):
+                st.latex(text.strip("$$"))
+            else:
+                st.markdown(text)
+
+    st.write("")
+    st.write("")
+
+    with st.expander(":star2: :gray[Tips for better results]"):
+        st.markdown("""
+            * :mag_right: Use a clear and high-resolution image.
+            * :scissors: Crop images as accurately as possible.
+            * :jigsaw: Split large multi-line formulas into smaller ones.
+            * :page_facing_up: Use images with **white background and black text** as much as possible.
+            * :book: Use a font with good readability.
+        """)
+    shutil.rmtree(temp_dir)
+
+    paste_result.image_data = None
+
+# ====== </page> ======
texteller/cli/commands/web/style.py (new file)
@@ -0,0 +1,55 @@
+from texteller.utils import lines_dedent
+
+
+HEADER_HTML = lines_dedent("""
+    <h1 style="color: black; text-align: center;">
+        <img src="https://raw.githubusercontent.com/OleehyO/TexTeller/main/assets/fire.svg" width="100">
+        𝚃𝚎𝚡𝚃𝚎𝚕𝚕𝚎𝚛
+        <img src="https://raw.githubusercontent.com/OleehyO/TexTeller/main/assets/fire.svg" width="100">
+    </h1>
+""")
+
+SUCCESS_GIF_HTML = lines_dedent("""
+    <h1 style="color: black; text-align: center;">
+        <img src="https://slackmojis.com/emojis/90621-clapclap-e/download" width="50">
+        <img src="https://slackmojis.com/emojis/90621-clapclap-e/download" width="50">
+        <img src="https://slackmojis.com/emojis/90621-clapclap-e/download" width="50">
+    </h1>
+""")
+
+FAIL_GIF_HTML = lines_dedent("""
+    <h1 style="color: black; text-align: center;">
+        <img src="https://slackmojis.com/emojis/51439-allthethings_intensifies/download">
+        <img src="https://slackmojis.com/emojis/51439-allthethings_intensifies/download">
+        <img src="https://slackmojis.com/emojis/51439-allthethings_intensifies/download">
+    </h1>
+""")
+
+IMAGE_EMBED_HTML = lines_dedent("""
+    <style>
+        .centered-container {{
+            text-align: center;
+        }}
+        .centered-image {{
+            display: block;
+            margin-left: auto;
+            margin-right: auto;
+            max-height: 350px;
+            max-width: 100%;
+        }}
+    </style>
+    <div class="centered-container">
+        <img src="data:image/png;base64,{img_base64}" class="centered-image" alt="Input image">
+    </div>
+""")
+
+IMAGE_INFO_HTML = lines_dedent("""
+    <style>
+        .centered-container {{
+            text-align: center;
+        }}
+    </style>
+    <div class="centered-container">
+        <p style="color:gray;">Input image ({img_height}✖️{img_width})</p>
+    </div>
+""")
@@ -1,12 +0,0 @@
-import requests
-
-rec_server_url = "http://127.0.0.1:8000/frec"
-det_server_url = "http://127.0.0.1:8000/fdet"
-
-img_path = "/your/image/path/"
-with open(img_path, 'rb') as img:
-    files = {'img': img}
-    response = requests.post(rec_server_url, files=files)
-    # response = requests.post(det_server_url, files=files)
-
-print(response.text)
@@ -21,3 +21,13 @@ MIN_RESIZE_RATIO = 0.75
 # Minimum height and width for input image for TexTeller
 MIN_HEIGHT = 12
 MIN_WIDTH = 30
+
+LATEX_DET_MODEL_URL = (
+    "https://huggingface.co/TonyLee1256/texteller_det/resolve/main/rtdetr_r50vd_6x_coco.onnx"
+)
+TEXT_REC_MODEL_URL = (
+    "https://huggingface.co/OleehyO/paddleocrv4.onnx/resolve/main/ch_PP-OCRv4_server_rec.onnx"
+)
+TEXT_DET_MODEL_URL = (
+    "https://huggingface.co/OleehyO/paddleocrv4.onnx/resolve/main/ch_PP-OCRv4_det.onnx"
+)
texteller/globals.py (new file)
@@ -0,0 +1,41 @@
+import logging
+from pathlib import Path
+
+
+class Globals:
+    """
+    Singleton class for managing global variables with predefined and dynamic attributes.
+
+    Usage Example:
+        >>> # 1. Access predefined variable (with default value)
+        >>> print(Globals().repo_name)  # Output: OleehyO/TexTeller
+
+        >>> # 2. Modify predefined variable
+        >>> Globals().repo_name = "NewRepo/NewProject"
+        >>> print(Globals().repo_name)  # Output: NewRepo/NewProject
+
+        >>> # 3. Dynamically add new variable
+        >>> Globals().new_var = "hello"
+        >>> print(Globals().new_var)  # Output: hello
+
+        >>> # 4. View all variables
+        >>> print(Globals())  # Output: <Globals: {'repo_name': ..., 'new_var': ...}>
+    """
+
+    _instance = None
+    _initialized = False
+
+    def __new__(cls):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def __init__(self):
+        if not self._initialized:
+            self.repo_name = "OleehyO/TexTeller"
+            self.logging_level = logging.INFO
+            self.cache_dir = Path("~/.cache/texteller").expanduser().resolve()
+            self.__class__._initialized = True
+
+    def __repr__(self):
+        return f"<Globals: {self.__dict__}>"
@@ -1,96 +0,0 @@
import os
import argparse
import glob
import subprocess

import onnxruntime
from pathlib import Path

from models.det_model.inference import PredictConfig, predict_image


parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--infer_cfg", type=str, help="infer_cfg.yml", default="./models/det_model/model/infer_cfg.yml"
)
parser.add_argument(
    '--onnx_file',
    type=str,
    help="onnx model file path",
    default="./models/det_model/model/rtdetr_r50vd_6x_coco.onnx",
)
parser.add_argument("--image_dir", type=str, default='./testImgs')
parser.add_argument("--image_file", type=str)
parser.add_argument("--imgsave_dir", type=str, default="./detect_results")
parser.add_argument(
    '--use_gpu', action='store_true', help='Whether to use GPU for inference', default=True
)


def get_test_images(infer_dir, infer_img):
    """
    Get image path list in TEST mode
    """
    assert (
        infer_img is not None or infer_dir is not None
    ), "--image_file or --image_dir should be set"
    assert infer_img is None or os.path.isfile(infer_img), "{} is not a file".format(infer_img)
    assert infer_dir is None or os.path.isdir(infer_dir), "{} is not a directory".format(infer_dir)

    # infer_img has a higher priority
    if infer_img and os.path.isfile(infer_img):
        return [infer_img]

    images = set()
    infer_dir = os.path.abspath(infer_dir)
    assert os.path.isdir(infer_dir), "infer_dir {} is not a directory".format(infer_dir)
    exts = ['jpg', 'jpeg', 'png', 'bmp']
    exts += [ext.upper() for ext in exts]
    for ext in exts:
        images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
    images = list(images)

    assert len(images) > 0, "no image found in {}".format(infer_dir)
    print("Found {} inference images in total.".format(len(images)))

    return images


def download_file(url, filename):
    print(f"Downloading {filename}...")
    subprocess.run(["wget", "-q", "--show-progress", "-O", filename, url], check=True)
    print("Download complete.")


if __name__ == '__main__':
    cur_path = os.getcwd()
    script_dirpath = Path(__file__).resolve().parent
    os.chdir(script_dirpath)

    FLAGS = parser.parse_args()

    if not os.path.exists(FLAGS.infer_cfg):
        infer_cfg_url = "https://huggingface.co/TonyLee1256/texteller_det/resolve/main/infer_cfg.yml?download=true"
        download_file(infer_cfg_url, FLAGS.infer_cfg)

    if not os.path.exists(FLAGS.onnx_file):
        onnx_file_url = "https://huggingface.co/TonyLee1256/texteller_det/resolve/main/rtdetr_r50vd_6x_coco.onnx?download=true"
        download_file(onnx_file_url, FLAGS.onnx_file)

    # load image list
    img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)

    if FLAGS.use_gpu:
        predictor = onnxruntime.InferenceSession(
            FLAGS.onnx_file, providers=['CUDAExecutionProvider']
        )
    else:
        predictor = onnxruntime.InferenceSession(
            FLAGS.onnx_file, providers=['CPUExecutionProvider']
        )
    # load infer config
    infer_config = PredictConfig(FLAGS.infer_cfg)

    predict_image(FLAGS.imgsave_dir, infer_config, predictor, img_list)

    os.chdir(cur_path)
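For reference, the deleted script's argparse surface implied invocations like the following (the file name is hypothetical, since the diff does not show the script's path):

python detect.py --image_dir ./testImgs --imgsave_dir ./detect_results
python detect.py --image_file ./testImgs/example.png   # --image_file takes priority over --image_dir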
@@ -1,81 +0,0 @@
import os
import argparse
import cv2 as cv

from pathlib import Path
from onnxruntime import InferenceSession
from models.thrid_party.paddleocr.infer import predict_det, predict_rec
from models.thrid_party.paddleocr.infer import utility

from models.utils import mix_inference
from models.ocr_model.utils.to_katex import to_katex
from models.ocr_model.utils.inference import inference as latex_inference

from models.ocr_model.model.TexTeller import TexTeller
from models.det_model.inference import PredictConfig


if __name__ == '__main__':
    os.chdir(Path(__file__).resolve().parent)
    parser = argparse.ArgumentParser()
    parser.add_argument('-img', type=str, required=True, help='path to the input image')
    parser.add_argument(
        '--inference-mode',
        type=str,
        default='cpu',
        help='Inference mode, select one of cpu, cuda, or mps',
    )
    parser.add_argument(
        '--num-beam', type=int, default=1, help='number of beam search for decoding'
    )
    parser.add_argument('-mix', action='store_true', help='use mix mode')

    args = parser.parse_args()

    # You can use your own checkpoint and tokenizer path.
    print('Loading model and tokenizer...')
    latex_rec_model = TexTeller.from_pretrained()
    tokenizer = TexTeller.get_tokenizer()
    print('Model and tokenizer loaded.')

    img_path = args.img
    img = cv.imread(img_path)
    print('Inference...')
    if not args.mix:
        res = latex_inference(latex_rec_model, tokenizer, [img], args.inference_mode, args.num_beam)
        res = to_katex(res[0])
        print(res)
    else:
        infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
        latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")

        use_gpu = args.inference_mode == 'cuda'
        SIZE_LIMIT = 20 * 1024 * 1024
        det_model_dir = "./models/thrid_party/paddleocr/checkpoints/det/default_model.onnx"
        rec_model_dir = "./models/thrid_party/paddleocr/checkpoints/rec/default_model.onnx"
        # The CPU inference of the detection model will be faster than the GPU inference (in onnxruntime)
        det_use_gpu = False
        rec_use_gpu = use_gpu and not (os.path.getsize(rec_model_dir) < SIZE_LIMIT)

        paddleocr_args = utility.parse_args()
        paddleocr_args.use_onnx = True
        paddleocr_args.det_model_dir = det_model_dir
        paddleocr_args.rec_model_dir = rec_model_dir

        paddleocr_args.use_gpu = det_use_gpu
        detector = predict_det.TextDetector(paddleocr_args)
        paddleocr_args.use_gpu = rec_use_gpu
        recognizer = predict_rec.TextRecognizer(paddleocr_args)

        lang_ocr_models = [detector, recognizer]
        latex_rec_models = [latex_rec_model, tokenizer]
        res = mix_inference(
            img_path,
            infer_config,
            latex_det_model,
            lang_ocr_models,
            latex_rec_models,
            args.inference_mode,
            args.num_beam,
        )
        print(res)
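Its command-line surface, for reference (script name hypothetical):

python inference.py -img ./test.png --inference-mode cuda --num-beam 3   # formula-only recognition
python inference.py -img ./page.png --inference-mode cuda -mix           # mixed text + formula mode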
96
texteller/logger.py
Normal file
@@ -0,0 +1,96 @@
import inspect
import logging
import os
from datetime import datetime
from logging import Logger

import colorama
from colorama import Fore, Style

from texteller.globals import Globals

# Initialize colorama for colored console output
colorama.init(autoreset=True)


TEMPLATE = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"


class ColoredFormatter(logging.Formatter):
    """Custom formatter to add colors based on log level."""

    FORMATS = {  # noqa: E501
        logging.DEBUG: Fore.LIGHTBLACK_EX + TEMPLATE + Style.RESET_ALL,
        logging.INFO: Fore.WHITE + TEMPLATE + Style.RESET_ALL,
        logging.WARNING: Fore.YELLOW + TEMPLATE + Style.RESET_ALL,
        logging.ERROR: Fore.RED + TEMPLATE + Style.RESET_ALL,
        logging.CRITICAL: Fore.RED + Style.BRIGHT + TEMPLATE + Style.RESET_ALL,
    }  # noqa: E501

    def format(self, record):
        log_fmt = self.FORMATS.get(record.levelno, self.FORMATS[logging.INFO])
        formatter = logging.Formatter(log_fmt, datefmt="%Y-%m-%d %H:%M:%S")
        return formatter.format(record)


def get_logger(name: str | None = None, use_file_handler: bool = False) -> Logger:
    """
    Creates and configures a logger named after `name` (if provided) or the caller's module.
    If the derived module path has more than two components, only the first two are kept.

    Args:
        name (str, optional): Custom logger name. If None, derives from the caller's module.
        use_file_handler (bool, optional): Whether to add a file handler. Defaults to False.

    Returns:
        Logger: Configured logger with colored console output and an optional file handler.
    """
    # If name is not provided, derive it from the caller's module
    if name is None:
        # Get the caller's stack frame
        frame = inspect.stack()[1]
        module = inspect.getmodule(frame[0])
        if module and module.__name__:
            module_name = module.__name__
            # Split module name and take first two components if too long
            parts = module_name.split(".")
            if len(parts) > 2:
                name = ".".join(parts[:2])
            else:
                name = module_name
        else:
            name = "root"

    # Create or get logger
    logger = logging.getLogger(name)

    # Prevent duplicate handlers
    if logger.handlers:
        return logger

    # Set logger level
    logger.setLevel(Globals().logging_level)

    # Create console handler with colored formatter
    console_handler = logging.StreamHandler()
    console_handler.setLevel(Globals().logging_level)
    console_formatter = ColoredFormatter()
    console_handler.setFormatter(console_formatter)
    logger.addHandler(console_handler)

    # Create file handler
    if use_file_handler:
        log_dir = "logs"
        os.makedirs(log_dir, exist_ok=True)
        log_file = os.path.join(log_dir, f"{datetime.now().strftime('%Y%m%d')}.log")
        file_handler = logging.FileHandler(log_file)
        file_handler.setLevel(Globals().logging_level)
        # File formatter (no colors)
        file_formatter = logging.Formatter(TEMPLATE, datefmt="%Y-%m-%d %H:%M:%S")
        file_handler.setFormatter(file_formatter)
        logger.addHandler(file_handler)

    # Prevent logger from propagating to root logger
    logger.propagate = False

    return logger
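A quick sketch of how the new logger is meant to be used (hypothetical call sites):

from texteller.logger import get_logger

logger = get_logger()                    # name derived from the caller's module (first two components)
logger.info("colored console output")

cli_logger = get_logger("texteller.cli", use_file_handler=True)  # also appends to logs/YYYYMMDD.log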
3
texteller/models/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .texteller import TexTeller

__all__ = ['TexTeller']
@@ -1,226 +0,0 @@
import os
import time
import yaml
import numpy as np
import cv2

from tqdm import tqdm
from typing import List
from .preprocess import Compose
from .Bbox import Bbox


# Global dictionary
SUPPORT_MODELS = {
    'YOLO',
    'PPYOLOE',
    'RCNN',
    'SSD',
    'Face',
    'FCOS',
    'SOLOv2',
    'TTFNet',
    'S2ANet',
    'JDE',
    'FairMOT',
    'DeepSORT',
    'GFL',
    'PicoDet',
    'CenterNet',
    'TOOD',
    'RetinaNet',
    'StrongBaseline',
    'STGCN',
    'YOLOX',
    'HRNet',
    'DETR',
}


class PredictConfig(object):
    """set config of preprocess, postprocess and visualize
    Args:
        infer_config (str): path of infer_cfg.yml
    """

    def __init__(self, infer_config):
        # parsing Yaml config for Preprocess
        with open(infer_config) as f:
            yml_conf = yaml.safe_load(f)
        self.check_model(yml_conf)
        self.arch = yml_conf['arch']
        self.preprocess_infos = yml_conf['Preprocess']
        self.min_subgraph_size = yml_conf['min_subgraph_size']
        self.label_list = yml_conf['label_list']
        self.use_dynamic_shape = yml_conf['use_dynamic_shape']
        self.draw_threshold = yml_conf.get("draw_threshold", 0.5)
        self.mask = yml_conf.get("mask", False)
        self.tracker = yml_conf.get("tracker", None)
        self.nms = yml_conf.get("NMS", None)
        self.fpn_stride = yml_conf.get("fpn_stride", None)

        color_pool = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]
        self.colors = {
            label: color_pool[i % len(color_pool)] for i, label in enumerate(self.label_list)
        }

        if self.arch == 'RCNN' and yml_conf.get('export_onnx', False):
            print('The RCNN export model is used for ONNX and it only supports batch_size = 1')
        self.print_config()

    def check_model(self, yml_conf):
        """
        Raises:
            ValueError: loaded model not in supported model type
        """
        for support_model in SUPPORT_MODELS:
            if support_model in yml_conf['arch']:
                return True
        raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf['arch'], SUPPORT_MODELS))

    def print_config(self):
        print('----------- Model Configuration -----------')
        print('%s: %s' % ('Model Arch', self.arch))
        print('%s: ' % ('Transform Order'))
        for op_info in self.preprocess_infos:
            print('--%s: %s' % ('transform op', op_info['type']))
        print('--------------------------------------------')


def draw_bbox(image, outputs, infer_config):
    for output in outputs:
        cls_id, score, xmin, ymin, xmax, ymax = output
        if score > infer_config.draw_threshold:
            label = infer_config.label_list[int(cls_id)]
            color = infer_config.colors[label]
            cv2.rectangle(image, (int(xmin), int(ymin)), (int(xmax), int(ymax)), color, 2)
            cv2.putText(
                image,
                "{}: {:.2f}".format(label, score),
                (int(xmin), int(ymin - 5)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                color,
                2,
            )
    return image


def predict_image(imgsave_dir, infer_config, predictor, img_list):
    # load preprocess transforms
    transforms = Compose(infer_config.preprocess_infos)
    errImgList = []

    # Check and create subimg_save_dir if not exist
    subimg_save_dir = os.path.join(imgsave_dir, 'subimages')
    os.makedirs(subimg_save_dir, exist_ok=True)

    first_image_skipped = False
    total_time = 0
    num_images = 0
    # predict image
    for img_path in tqdm(img_list):
        img = cv2.imread(img_path)
        if img is None:
            print(f"Warning: Could not read image {img_path}. Skipping...")
            errImgList.append(img_path)
            continue

        inputs = transforms(img_path)
        inputs_name = [var.name for var in predictor.get_inputs()]
        inputs = {k: inputs[k][None,] for k in inputs_name}

        # Start timing
        start_time = time.time()

        outputs = predictor.run(output_names=None, input_feed=inputs)

        # Stop timing
        end_time = time.time()
        inference_time = end_time - start_time
        if not first_image_skipped:
            first_image_skipped = True
        else:
            total_time += inference_time
            num_images += 1
        print(
            f"ONNXRuntime predict time for {os.path.basename(img_path)}: {inference_time:.4f} seconds"
        )

        print("ONNXRuntime predict: ")
        if infer_config.arch in ["HRNet"]:
            print(np.array(outputs[0]))
        else:
            bboxes = np.array(outputs[0])
            for bbox in bboxes:
                if bbox[0] > -1 and bbox[1] > infer_config.draw_threshold:
                    print(f"{int(bbox[0])} {bbox[1]} " f"{bbox[2]} {bbox[3]} {bbox[4]} {bbox[5]}")

        # Save the subimages (crop from the original image)
        subimg_counter = 1
        for output in np.array(outputs[0]):
            cls_id, score, xmin, ymin, xmax, ymax = output
            if score > infer_config.draw_threshold:
                label = infer_config.label_list[int(cls_id)]
                subimg = img[int(max(ymin, 0)) : int(ymax), int(max(xmin, 0)) : int(xmax)]
                if len(subimg) == 0:
                    continue

                subimg_filename = f"{os.path.splitext(os.path.basename(img_path))[0]}_{label}_{xmin:.2f}_{ymin:.2f}_{xmax:.2f}_{ymax:.2f}.jpg"
                subimg_path = os.path.join(subimg_save_dir, subimg_filename)
                cv2.imwrite(subimg_path, subimg)
                subimg_counter += 1

        # Draw bounding boxes and save the image with bounding boxes
        img_with_mask = img.copy()
        for output in np.array(outputs[0]):
            cls_id, score, xmin, ymin, xmax, ymax = output
            if score > infer_config.draw_threshold:
                cv2.rectangle(
                    img_with_mask,
                    (int(xmin), int(ymin)),
                    (int(xmax), int(ymax)),
                    (255, 255, 255),
                    -1,
                )  # mask the detected region with white

        img_with_bbox = draw_bbox(img, np.array(outputs[0]), infer_config)

        output_dir = imgsave_dir
        os.makedirs(output_dir, exist_ok=True)
        draw_box_dir = os.path.join(output_dir, 'draw_box')
        mask_white_dir = os.path.join(output_dir, 'mask_white')
        os.makedirs(draw_box_dir, exist_ok=True)
        os.makedirs(mask_white_dir, exist_ok=True)

        output_file_mask = os.path.join(mask_white_dir, os.path.basename(img_path))
        output_file_bbox = os.path.join(draw_box_dir, os.path.basename(img_path))
        cv2.imwrite(output_file_mask, img_with_mask)
        cv2.imwrite(output_file_bbox, img_with_bbox)

    avg_time_per_image = total_time / num_images if num_images > 0 else 0
    print(f"Total inference time for {num_images} images: {total_time:.4f} seconds")
    print(f"Average time per image: {avg_time_per_image:.4f} seconds")
    print("ErrorImgs:")
    print(errImgList)


def predict(img_path: str, predictor, infer_config) -> List[Bbox]:
    transforms = Compose(infer_config.preprocess_infos)
    inputs = transforms(img_path)
    inputs_name = [var.name for var in predictor.get_inputs()]
    inputs = {k: inputs[k][None,] for k in inputs_name}

    outputs = predictor.run(output_names=None, input_feed=inputs)[0]
    res = []
    for output in outputs:
        cls_name = infer_config.label_list[int(output[0])]
        score = output[1]
        xmin = int(max(output[2], 0))
        ymin = int(max(output[3], 0))
        xmax = int(output[4])
        ymax = int(output[5])
        if score > infer_config.draw_threshold:
            res.append(Bbox(xmin, ymin, ymax - ymin, xmax - xmin, cls_name, score))

    return res
@@ -1,27 +0,0 @@
mode: paddle
draw_threshold: 0.5
metric: COCO
use_dynamic_shape: false
arch: DETR
min_subgraph_size: 3
Preprocess:
- interp: 2
  keep_ratio: false
  target_size:
  - 1600
  - 1600
  type: Resize
- mean:
  - 0.0
  - 0.0
  - 0.0
  norm_type: none
  std:
  - 1.0
  - 1.0
  - 1.0
  type: NormalizeImage
- type: Permute
label_list:
- isolated
- embedding
@@ -1,485 +0,0 @@
import numpy as np
import cv2
import copy


def decode_image(img_path):
    if isinstance(img_path, str):
        with open(img_path, 'rb') as f:
            im_read = f.read()
        data = np.frombuffer(im_read, dtype='uint8')
    else:
        assert isinstance(img_path, np.ndarray)
        data = img_path

    im = cv2.imdecode(data, 1)  # BGR mode, but need RGB mode
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    img_info = {
        "im_shape": np.array(im.shape[:2], dtype=np.float32),
        "scale_factor": np.array([1.0, 1.0], dtype=np.float32),
    }
    return im, img_info


class Resize(object):
    """resize image by target_size and max_size
    Args:
        target_size (int): the target size of image
        keep_ratio (bool): whether to keep the aspect ratio, default True
        interp (int): method of resize
    """

    def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
        if isinstance(target_size, int):
            target_size = [target_size, target_size]
        self.target_size = target_size
        self.keep_ratio = keep_ratio
        self.interp = interp

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image (np.ndarray)
            im_info (dict): info of image
        Returns:
            im (np.ndarray): processed image (np.ndarray)
            im_info (dict): info of processed image
        """
        assert len(self.target_size) == 2
        assert self.target_size[0] > 0 and self.target_size[1] > 0
        im_channel = im.shape[2]
        im_scale_y, im_scale_x = self.generate_scale(im)
        im = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=self.interp)
        im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
        im_info['scale_factor'] = np.array([im_scale_y, im_scale_x]).astype('float32')
        return im, im_info

    def generate_scale(self, im):
        """
        Args:
            im (np.ndarray): image (np.ndarray)
        Returns:
            im_scale_x: the resize ratio of X
            im_scale_y: the resize ratio of Y
        """
        origin_shape = im.shape[:2]
        im_c = im.shape[2]
        if self.keep_ratio:
            im_size_min = np.min(origin_shape)
            im_size_max = np.max(origin_shape)
            target_size_min = np.min(self.target_size)
            target_size_max = np.max(self.target_size)
            im_scale = float(target_size_min) / float(im_size_min)
            if np.round(im_scale * im_size_max) > target_size_max:
                im_scale = float(target_size_max) / float(im_size_max)
            im_scale_x = im_scale
            im_scale_y = im_scale
        else:
            resize_h, resize_w = self.target_size
            im_scale_y = resize_h / float(origin_shape[0])
            im_scale_x = resize_w / float(origin_shape[1])
        return im_scale_y, im_scale_x


class NormalizeImage(object):
    """normalize image
    Args:
        mean (list): im - mean
        std (list): im / std
        is_scale (bool): whether need im / 255
        norm_type (str): type in ['mean_std', 'none']
    """

    def __init__(self, mean, std, is_scale=True, norm_type='mean_std'):
        self.mean = mean
        self.std = std
        self.is_scale = is_scale
        self.norm_type = norm_type

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image (np.ndarray)
            im_info (dict): info of image
        Returns:
            im (np.ndarray): processed image (np.ndarray)
            im_info (dict): info of processed image
        """
        im = im.astype(np.float32, copy=False)
        if self.is_scale:
            scale = 1.0 / 255.0
            im *= scale

        if self.norm_type == 'mean_std':
            mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
            std = np.array(self.std)[np.newaxis, np.newaxis, :]
            im -= mean
            im /= std
        return im, im_info


class Permute(object):
    """permute image
    Args:
        to_bgr (bool): whether convert RGB to BGR
        channel_first (bool): whether convert HWC to CHW
    """

    def __init__(
        self,
    ):
        super(Permute, self).__init__()

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image (np.ndarray)
            im_info (dict): info of image
        Returns:
            im (np.ndarray): processed image (np.ndarray)
            im_info (dict): info of processed image
        """
        im = im.transpose((2, 0, 1)).copy()
        return im, im_info


class PadStride(object):
    """padding image for model with FPN, instead of PadBatch(pad_to_stride) in the original config
    Args:
        stride (bool): model with FPN need image shape % stride == 0
    """

    def __init__(self, stride=0):
        self.coarsest_stride = stride

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image (np.ndarray)
            im_info (dict): info of image
        Returns:
            im (np.ndarray): processed image (np.ndarray)
            im_info (dict): info of processed image
        """
        coarsest_stride = self.coarsest_stride
        if coarsest_stride <= 0:
            return im, im_info
        im_c, im_h, im_w = im.shape
        pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
        pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
        padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
        padding_im[:, :im_h, :im_w] = im
        return padding_im, im_info


class LetterBoxResize(object):
    def __init__(self, target_size):
        """
        Resize image to target size, convert normalized xywh to pixel xyxy
        format ([x_center, y_center, width, height] -> [x0, y0, x1, y1]).
        Args:
            target_size (int|list): image target size.
        """
        super(LetterBoxResize, self).__init__()
        if isinstance(target_size, int):
            target_size = [target_size, target_size]
        self.target_size = target_size

    def letterbox(self, img, height, width, color=(127.5, 127.5, 127.5)):
        # letterbox: resize a rectangular image to a padded rectangular
        shape = img.shape[:2]  # [height, width]
        ratio_h = float(height) / shape[0]
        ratio_w = float(width) / shape[1]
        ratio = min(ratio_h, ratio_w)
        new_shape = (round(shape[1] * ratio), round(shape[0] * ratio))  # [width, height]
        padw = (width - new_shape[0]) / 2
        padh = (height - new_shape[1]) / 2
        top, bottom = round(padh - 0.1), round(padh + 0.1)
        left, right = round(padw - 0.1), round(padw + 0.1)

        img = cv2.resize(img, new_shape, interpolation=cv2.INTER_AREA)  # resized, no border
        img = cv2.copyMakeBorder(
            img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
        )  # padded rectangular
        return img, ratio, padw, padh

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image (np.ndarray)
            im_info (dict): info of image
        Returns:
            im (np.ndarray): processed image (np.ndarray)
            im_info (dict): info of processed image
        """
        assert len(self.target_size) == 2
        assert self.target_size[0] > 0 and self.target_size[1] > 0
        height, width = self.target_size
        h, w = im.shape[:2]
        im, ratio, padw, padh = self.letterbox(im, height=height, width=width)

        new_shape = [round(h * ratio), round(w * ratio)]
        im_info['im_shape'] = np.array(new_shape, dtype=np.float32)
        im_info['scale_factor'] = np.array([ratio, ratio], dtype=np.float32)
        return im, im_info


class Pad(object):
    def __init__(self, size, fill_value=[114.0, 114.0, 114.0]):
        """
        Pad image to a specified size.
        Args:
            size (list[int]): image target size
            fill_value (list[float]): rgb value of pad area, default (114.0, 114.0, 114.0)
        """
        super(Pad, self).__init__()
        if isinstance(size, int):
            size = [size, size]
        self.size = size
        self.fill_value = fill_value

    def __call__(self, im, im_info):
        im_h, im_w = im.shape[:2]
        h, w = self.size
        if h == im_h and w == im_w:
            im = im.astype(np.float32)
            return im, im_info

        canvas = np.ones((h, w, 3), dtype=np.float32)
        canvas *= np.array(self.fill_value, dtype=np.float32)
        canvas[0:im_h, 0:im_w, :] = im.astype(np.float32)
        im = canvas
        return im, im_info


def rotate_point(pt, angle_rad):
    """Rotate a point by an angle.

    Args:
        pt (list[float]): 2 dimensional point to be rotated
        angle_rad (float): rotation angle in radians

    Returns:
        list[float]: Rotated point.
    """
    assert len(pt) == 2
    sn, cs = np.sin(angle_rad), np.cos(angle_rad)
    new_x = pt[0] * cs - pt[1] * sn
    new_y = pt[0] * sn + pt[1] * cs
    rotated_pt = [new_x, new_y]

    return rotated_pt


def _get_3rd_point(a, b):
    """To calculate the affine matrix, three pairs of points are required. This
    function is used to get the 3rd point, given 2D points a & b.

    The 3rd point is defined by rotating vector `a - b` by 90 degrees
    anticlockwise, using b as the rotation center.

    Args:
        a (np.ndarray): point(x,y)
        b (np.ndarray): point(x,y)

    Returns:
        np.ndarray: The 3rd point.
    """
    assert len(a) == 2
    assert len(b) == 2
    direction = a - b
    third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32)

    return third_pt


def get_affine_transform(center, input_size, rot, output_size, shift=(0.0, 0.0), inv=False):
    """Get the affine transform matrix, given the center/scale/rot/output_size.

    Args:
        center (np.ndarray[2, ]): Center of the bounding box (x, y).
        input_size (np.ndarray[2, ]): Scale of the bounding box
            wrt [width, height].
        rot (float): Rotation angle (degree).
        output_size (np.ndarray[2, ]): Size of the destination heatmaps.
        shift (0-100%): Shift translation ratio wrt the width/height.
            Default (0., 0.).
        inv (bool): Option to inverse the affine transform direction.
            (inv=False: src->dst or inv=True: dst->src)

    Returns:
        np.ndarray: The transform matrix.
    """
    assert len(center) == 2
    assert len(output_size) == 2
    assert len(shift) == 2
    if not isinstance(input_size, (np.ndarray, list)):
        input_size = np.array([input_size, input_size], dtype=np.float32)
    scale_tmp = input_size

    shift = np.array(shift)
    src_w = scale_tmp[0]
    dst_w = output_size[0]
    dst_h = output_size[1]

    rot_rad = np.pi * rot / 180
    src_dir = rotate_point([0.0, src_w * -0.5], rot_rad)
    dst_dir = np.array([0.0, dst_w * -0.5])

    src = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale_tmp * shift
    src[1, :] = center + src_dir + scale_tmp * shift
    src[2, :] = _get_3rd_point(src[0, :], src[1, :])

    dst = np.zeros((3, 2), dtype=np.float32)
    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
    dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])

    if inv:
        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))

    return trans


class WarpAffine(object):
    """Warp affine the image"""

    def __init__(self, keep_res=False, pad=31, input_h=512, input_w=512, scale=0.4, shift=0.1):
        self.keep_res = keep_res
        self.pad = pad
        self.input_h = input_h
        self.input_w = input_w
        self.scale = scale
        self.shift = shift

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image (np.ndarray)
            im_info (dict): info of image
        Returns:
            im (np.ndarray): processed image (np.ndarray)
            im_info (dict): info of processed image
        """
        img = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)

        h, w = img.shape[:2]

        if self.keep_res:
            input_h = (h | self.pad) + 1
            input_w = (w | self.pad) + 1
            s = np.array([input_w, input_h], dtype=np.float32)
            c = np.array([w // 2, h // 2], dtype=np.float32)

        else:
            s = max(h, w) * 1.0
            input_h, input_w = self.input_h, self.input_w
            c = np.array([w / 2.0, h / 2.0], dtype=np.float32)

        trans_input = get_affine_transform(c, s, 0, [input_w, input_h])
        img = cv2.resize(img, (w, h))
        inp = cv2.warpAffine(img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR)
        return inp, im_info


# keypoint preprocess
def get_warp_matrix(theta, size_input, size_dst, size_target):
    """This code is based on
    https://github.com/open-mmlab/mmpose/blob/master/mmpose/core/post_processing/post_transforms.py

    Calculate the transformation matrix under the constraint of unbiased.
    Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased
    Data Processing for Human Pose Estimation (CVPR 2020).

    Args:
        theta (float): Rotation angle in degrees.
        size_input (np.ndarray): Size of input image [w, h].
        size_dst (np.ndarray): Size of output image [w, h].
        size_target (np.ndarray): Size of ROI in input plane [w, h].

    Returns:
        matrix (np.ndarray): A matrix for transformation.
    """
    theta = np.deg2rad(theta)
    matrix = np.zeros((2, 3), dtype=np.float32)
    scale_x = size_dst[0] / size_target[0]
    scale_y = size_dst[1] / size_target[1]
    matrix[0, 0] = np.cos(theta) * scale_x
    matrix[0, 1] = -np.sin(theta) * scale_x
    matrix[0, 2] = scale_x * (
        -0.5 * size_input[0] * np.cos(theta)
        + 0.5 * size_input[1] * np.sin(theta)
        + 0.5 * size_target[0]
    )
    matrix[1, 0] = np.sin(theta) * scale_y
    matrix[1, 1] = np.cos(theta) * scale_y
    matrix[1, 2] = scale_y * (
        -0.5 * size_input[0] * np.sin(theta)
        - 0.5 * size_input[1] * np.cos(theta)
        + 0.5 * size_target[1]
    )
    return matrix


class TopDownEvalAffine(object):
    """apply affine transform to image and coords

    Args:
        trainsize (list): [w, h], the standard size used to train
        use_udp (bool): whether to use Unbiased Data Processing.
        records (dict): the dict containing the image and coords

    Returns:
        records (dict): the dict containing the image and coords after the transform
    """

    def __init__(self, trainsize, use_udp=False):
        self.trainsize = trainsize
        self.use_udp = use_udp

    def __call__(self, image, im_info):
        rot = 0
        imshape = im_info['im_shape'][::-1]
        center = im_info['center'] if 'center' in im_info else imshape / 2.0
        scale = im_info['scale'] if 'scale' in im_info else imshape
        if self.use_udp:
            trans = get_warp_matrix(
                rot, center * 2.0, [self.trainsize[0] - 1.0, self.trainsize[1] - 1.0], scale
            )
            image = cv2.warpAffine(
                image,
                trans,
                (int(self.trainsize[0]), int(self.trainsize[1])),
                flags=cv2.INTER_LINEAR,
            )
        else:
            trans = get_affine_transform(center, scale, rot, self.trainsize)
            image = cv2.warpAffine(
                image,
                trans,
                (int(self.trainsize[0]), int(self.trainsize[1])),
                flags=cv2.INTER_LINEAR,
            )

        return image, im_info


class Compose:
    def __init__(self, transforms):
        self.transforms = []
        for op_info in transforms:
            new_op_info = op_info.copy()
            op_type = new_op_info.pop('type')
            self.transforms.append(eval(op_type)(**new_op_info))

    def __call__(self, img_path):
        img, im_info = decode_image(img_path)
        for t in self.transforms:
            img, im_info = t(img, im_info)
        inputs = copy.deepcopy(im_info)
        inputs['image'] = img
        return inputs
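`Compose` resolves each `type` in the YAML `Preprocess` list to the class of the same name in this module (via `eval`) and passes the remaining keys as constructor kwargs, so the `infer_cfg.yml` above maps directly onto three ops. A sketch of the equivalent hand-built pipeline, for illustration only:

ops = [
    Resize(interp=2, keep_ratio=False, target_size=[1600, 1600]),
    NormalizeImage(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0], norm_type='none'),
    Permute(),
]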
@@ -1,43 +0,0 @@
from pathlib import Path

from ...globals import VOCAB_SIZE, FIXED_IMG_SIZE, IMG_CHANNELS, MAX_TOKEN_SIZE

from transformers import RobertaTokenizerFast, VisionEncoderDecoderModel, VisionEncoderDecoderConfig


class TexTeller(VisionEncoderDecoderModel):
    REPO_NAME = 'OleehyO/TexTeller'

    def __init__(self):
        config = VisionEncoderDecoderConfig.from_pretrained(
            Path(__file__).resolve().parent / "config.json"
        )
        config.encoder.image_size = FIXED_IMG_SIZE
        config.encoder.num_channels = IMG_CHANNELS
        config.decoder.vocab_size = VOCAB_SIZE
        config.decoder.max_position_embeddings = MAX_TOKEN_SIZE

        super().__init__(config=config)

    @classmethod
    def from_pretrained(cls, model_path: str = None, use_onnx=False, onnx_provider=None):
        if model_path is None or model_path == 'default':
            if not use_onnx:
                return VisionEncoderDecoderModel.from_pretrained(cls.REPO_NAME)
            else:
                from optimum.onnxruntime import ORTModelForVision2Seq

                use_gpu = True if onnx_provider == 'cuda' else False
                return ORTModelForVision2Seq.from_pretrained(
                    cls.REPO_NAME,
                    provider="CUDAExecutionProvider" if use_gpu else "CPUExecutionProvider",
                )
        model_path = Path(model_path).resolve()
        return VisionEncoderDecoderModel.from_pretrained(str(model_path))

    @classmethod
    def get_tokenizer(cls, tokenizer_path: str = None) -> RobertaTokenizerFast:
        if tokenizer_path is None or tokenizer_path == 'default':
            return RobertaTokenizerFast.from_pretrained(cls.REPO_NAME)
        tokenizer_path = Path(tokenizer_path).resolve()
        return RobertaTokenizerFast.from_pretrained(str(tokenizer_path))
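`from_pretrained` covered three loading paths; a hedged usage sketch based on its signature:

model = TexTeller.from_pretrained()                                           # PyTorch weights from the hub
onnx_model = TexTeller.from_pretrained(use_onnx=True, onnx_provider='cuda')   # ORT session via optimum
local_model = TexTeller.from_pretrained('/path/to/checkpoint')                # local checkpoint
tokenizer = TexTeller.get_tokenizer()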
@@ -1,168 +0,0 @@
{
  "_name_or_path": "OleehyO/TexTeller",
  "architectures": [
    "VisionEncoderDecoderModel"
  ],
  "decoder": {
    "_name_or_path": "",
    "activation_dropout": 0.0,
    "activation_function": "gelu",
    "add_cross_attention": true,
    "architectures": null,
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": 0,
    "chunk_size_feed_forward": 0,
    "classifier_dropout": 0.0,
    "cross_attention_hidden_size": 768,
    "d_model": 1024,
    "decoder_attention_heads": 16,
    "decoder_ffn_dim": 4096,
    "decoder_layerdrop": 0.0,
    "decoder_layers": 12,
    "decoder_start_token_id": 2,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "dropout": 0.1,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": 2,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "init_std": 0.02,
    "is_decoder": true,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layernorm_embedding": true,
    "length_penalty": 1.0,
    "max_length": 20,
    "max_position_embeddings": 1024,
    "min_length": 0,
    "model_type": "trocr",
    "no_repeat_ngram_size": 0,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": 1,
    "prefix": null,
    "problem_type": null,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "scale_embedding": false,
    "sep_token_id": null,
    "suppress_tokens": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": null,
    "torchscript": false,
    "typical_p": 1.0,
    "use_bfloat16": false,
    "use_cache": false,
    "use_learned_position_embeddings": true,
    "vocab_size": 15000
  },
  "encoder": {
    "_name_or_path": "",
    "add_cross_attention": false,
    "architectures": null,
    "attention_probs_dropout_prob": 0.0,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "encoder_stride": 16,
    "eos_token_id": null,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.0,
    "hidden_size": 768,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "image_size": 448,
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-12,
    "length_penalty": 1.0,
    "max_length": 20,
    "min_length": 0,
    "model_type": "vit",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 12,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_channels": 1,
    "num_hidden_layers": 12,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": null,
    "patch_size": 16,
    "prefix": null,
    "problem_type": null,
    "pruned_heads": {},
    "qkv_bias": false,
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "suppress_tokens": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": null,
    "torchscript": false,
    "typical_p": 1.0,
    "use_bfloat16": false
  },
  "is_encoder_decoder": true,
  "model_type": "vision-encoder-decoder",
  "tie_word_embeddings": false,
  "transformers_version": "4.41.2",
  "use_cache": true
}
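For cross-reference, the fields that `TexTeller.__init__` overrode via the `globals` constants correspond to the values baked into this config (an inference from the two files shown, since the constants' definitions do not appear in this diff):

config.encoder.image_size = FIXED_IMG_SIZE                # 448 above
config.encoder.num_channels = IMG_CHANNELS                # 1 above
config.decoder.vocab_size = VOCAB_SIZE                    # 15000 above
config.decoder.max_position_embeddings = MAX_TOKEN_SIZE   # 1024 above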
[35 binary image files deleted (previews 1.8 KiB to 12 KiB; no pixel data is recoverable from the diff)]
@@ -1,35 +0,0 @@
{"file_name": "0.png", "latex_formula": "\\[\\mathbb{C}^{4}\\stackrel{{\\pi_{1}}}{{\\longleftarrow}}\\mathcal{ F}\\stackrel{{\\pi_{2}}}{{\\rightarrow}}\\mathcal{PT},\\]"}
{"file_name": "1.png", "latex_formula": "\\[W^{*}_{Z}(x_{1},x_{2})=W_{f\\lrcorner Z}(y_{1},y_{2})=\\mathcal{P}\\exp\\left( \\int_{\\gamma}A_{\\mu}dx^{\\mu}\\right).\\]"}
{"file_name": "2.png", "latex_formula": "\\[G=W^{*}_{Z}(q,p)=\\tilde{H}H^{-1}\\]"}
{"file_name": "3.png", "latex_formula": "\\[H=W^{*}_{Z}(p,x),\\ \\ \\tilde{H}=W^{*}_{Z}(q,x).\\]"}
{"file_name": "4.png", "latex_formula": "\\[v\\cdot f^{*}A|_{x}=(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)},\\quad x\\in Z, \\ v\\in T_{x}Z.\\]"}
{"file_name": "5.png", "latex_formula": "\\[(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)}=v^{\\alpha\\dot{\\alpha}}\\Big{(} \\frac{\\partial y^{\\beta\\dot{\\beta}}}{\\partial x^{\\alpha\\dot{\\alpha}}}A_{\\beta \\dot{\\beta}}\\Big{)}\\Big{|}_{f\\lrcorner Z(x)},\\ x\\in Z,\\ v\\in T_{x}Z,\\]"}
{"file_name": "6.png", "latex_formula": "\\[\\{T_{i},T_{j}\\}=\\{\\tilde{T}^{i},\\tilde{T}^{j}\\}=0,\\ \\ \\{T_{i},\\tilde{T}^{j}\\}=2i \\delta^{j}_{i}D,\\]"}
{"file_name": "7.png", "latex_formula": "\\[(\\partial_{s},q_{i},\\tilde{q}^{k})\\rightarrow(D,M^{j}_{i}T_{j},\\tilde{M}^{k}_ {l}\\tilde{T}^{l}),\\]"}
{"file_name": "8.png", "latex_formula": "\\[M^{i}_{j}\\tilde{M}^{j}_{k}=\\delta^{i}_{k}.\\]"}
{"file_name": "9.png", "latex_formula": "\\[Q_{i\\alpha}=q_{i\\alpha}+\\omega_{i\\alpha},\\ \\tilde{Q}^{i}_{\\dot{\\alpha}}=q^{i}_{ \\dot{\\alpha}}+\\tilde{\\omega}^{i}_{\\dot{\\alpha}},\\ D_{\\alpha\\dot{\\alpha}}= \\partial_{\\alpha\\dot{\\alpha}}+A_{\\alpha\\dot{\\alpha}}.\\]"}
{"file_name": "10.png", "latex_formula": "\\[\\hat{f}(g,\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{j})=(f(g),[V^{-1}]^ {\\alpha}_{\\beta}\\theta^{i\\beta},[\\tilde{V}^{-1}]^{\\dot{\\alpha}}_{\\dot{\\beta}} \\tilde{\\theta}^{\\dot{\\beta}}_{j}),\\ g\\in{\\cal G},\\]"}
{"file_name": "11.png", "latex_formula": "\\[v^{\\beta\\dot{\\beta}}V^{\\alpha}_{\\beta}\\tilde{V}^{\\dot{\\alpha}}_{\\dot{\\beta}} =((f\\lrcorner L_{0})_{*}v)^{\\alpha\\dot{\\alpha}},\\]"}
{"file_name": "12.png", "latex_formula": "\\[\\omega_{i\\alpha}=\\tilde{\\theta}^{\\dot{\\alpha}}_{i}h_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\ \\ \\tilde{\\omega}^{i}_{\\alpha}=\\theta^{i\\alpha}\\tilde{h}_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\]"}
{"file_name": "13.png", "latex_formula": "\\[\\begin{split}&\\lambda^{\\alpha}\\hat{f}^{*}\\omega_{i\\alpha}(z)= \\tilde{\\theta}^{\\dot{\\beta}}_{i}\\lambda^{\\alpha}\\left(V^{\\beta}_{\\alpha}h_{ \\beta\\dot{\\beta}}(x^{\\prime},\\tau^{\\prime})\\right),\\\\ &\\tilde{\\lambda}^{\\dot{\\alpha}}\\hat{f}^{*}\\tilde{\\omega}^{i}_{ \\dot{\\alpha}}(z)=\\theta^{i\\beta}\\tilde{\\lambda}^{\\dot{\\alpha}}\\left(\\tilde{V}^ {\\dot{\\beta}}_{\\dot{\\alpha}}\\tilde{h}_{\\beta\\dot{\\beta}}(x^{\\prime},\\tau^{ \\prime})\\right),\\end{split}\\]"}
{"file_name": "14.png", "latex_formula": "\\[A_{\\alpha\\dot{\\alpha}}=A_{\\alpha\\dot{\\alpha}}(x^{\\beta\\dot{\\beta}},\\tau^{ \\beta\\dot{\\beta}})\\]"}
{"file_name": "15.png", "latex_formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}D_{\\alpha\\dot{\\alpha}}\\]"}
{"file_name": "16.png", "latex_formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}\\]"}
{"file_name": "17.png", "latex_formula": "\\[[v_{1}\\cdot D^{*},v_{2}\\cdot D^{*}]=0\\]"}
{"file_name": "18.png", "latex_formula": "\\[\\Phi_{A}=(\\omega_{i\\alpha},\\tilde{\\omega}^{i}_{\\dot{\\alpha}},A_{\\alpha\\dot{ \\alpha}})\\]"}
{"file_name": "19.png", "latex_formula": "\\[\\hat{f}:{\\cal F}^{6|4N}\\rightarrow{\\cal F}^{6|4N}\\]"}
{"file_name": "20.png", "latex_formula": "\\[\\sigma=(s,\\xi^{i},\\tilde{\\xi}_{j})\\in\\mathbb{C}^{1|2N}\\]"}
{"file_name": "21.png", "latex_formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}(h_{\\alpha\\dot{\\alpha}}+\\tilde{h}_{\\alpha\\dot{\\alpha} })=0\\]"}
{"file_name": "22.png", "latex_formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}\\rightarrow[V^{-1}]^{\\alpha}_{\\beta}[\\tilde{V}^{-1}]^{ \\dot{\\alpha}}_{\\dot{\\beta}}\\tau^{\\beta\\dot{\\beta}}\\]"}
{"file_name": "23.png", "latex_formula": "\\[\\tau^{\\beta\\dot{\\beta}}=\\sum_{i}\\theta^{i\\beta}\\tilde{\\theta}^{\\dot{\\beta}}_{i}\\]"}
{"file_name": "24.png", "latex_formula": "\\[\\theta^{i\\alpha}\\omega_{i\\alpha}+\\tilde{\\theta}^{i}_{\\dot{\\alpha}}\\tilde{ \\omega}^{\\dot{\\alpha}}_{i}=0\\]"}
{"file_name": "25.png", "latex_formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{Q}^{i}_{\\dot{\\alpha}}\\]"}
{"file_name": "26.png", "latex_formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{q}^{i}_{\\dot{\\alpha}}\\]"}
{"file_name": "27.png", "latex_formula": "\\[\\tilde{\\lambda}^{\\dot{\\alpha}}f^{*}A_{\\alpha\\dot{\\alpha}}=H^{-1}\\tilde{ \\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}H\\]"}
{"file_name": "28.png", "latex_formula": "\\[\\tilde{q}^{i}=\\partial_{\\tilde{\\xi}_{i}}+i\\xi^{i}\\partial_{s}\\]"}
{"file_name": "29.png", "latex_formula": "\\[\\tilde{q}^{i}_{\\dot{\\alpha}}=\\frac{\\partial}{\\partial\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}}+i\\theta^{i\\alpha}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
{"file_name": "30.png", "latex_formula": "\\[f\\lrcorner L(z)=\\pi_{1}\\circ f(z,\\lambda,\\tilde{\\lambda})\\ \\forall z\\in L\\]"}
{"file_name": "31.png", "latex_formula": "\\[q_{i\\alpha}=\\frac{\\partial}{\\partial\\theta^{i\\alpha}}+i\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
{"file_name": "32.png", "latex_formula": "\\[q_{i}=\\partial_{\\xi^{i}}+i\\tilde{\\xi}_{i}\\partial_{s}\\]"}
{"file_name": "33.png", "latex_formula": "\\[v^{\\alpha\\dot{\\alpha}}=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\]"}
{"file_name": "34.png", "latex_formula": "\\[z^{A}=(x^{\\alpha\\dot{\\alpha}},\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{ j})\\]"}
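These rows follow the `datasets` imagefolder convention, a `metadata.jsonl` keyed by `file_name` next to the images, which is what lets the training script below attach each `latex_formula` label in a single call. A hedged sketch (directory name assumed):

from datasets import load_dataset

# assumes the JSONL is saved as dataset/metadata.jsonl alongside the PNG files
dataset = load_dataset("imagefolder", data_dir="dataset")["train"]
print(dataset[0]["latex_formula"])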
@@ -1,114 +0,0 @@
import os

from functools import partial
from pathlib import Path

from datasets import load_dataset
from transformers import (
    Trainer,
    TrainingArguments,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    GenerationConfig,
)

from .training_args import CONFIG
from ..model.TexTeller import TexTeller
from ..utils.functional import (
    tokenize_fn,
    collate_fn,
    img_train_transform,
    img_inf_transform,
    filter_fn,
)
from ..utils.metrics import bleu_metric
from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT


def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
    training_args = TrainingArguments(**CONFIG)
    trainer = Trainer(
        model,
        training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=collate_fn_with_tokenizer,
    )

    trainer.train(resume_from_checkpoint=None)


def evaluate(model, tokenizer, eval_dataset, collate_fn):
    eval_config = CONFIG.copy()
    eval_config['predict_with_generate'] = True
    generate_config = GenerationConfig(
        max_new_tokens=MAX_TOKEN_SIZE,
        num_beams=1,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
    )
    eval_config['generation_config'] = generate_config
    seq2seq_config = Seq2SeqTrainingArguments(**eval_config)

    trainer = Seq2SeqTrainer(
        model,
        seq2seq_config,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=collate_fn,
        compute_metrics=partial(bleu_metric, tokenizer=tokenizer),
    )

    eval_res = trainer.evaluate()
    print(eval_res)


if __name__ == '__main__':
    script_dirpath = Path(__file__).resolve().parent
    os.chdir(script_dirpath)

    # dataset = load_dataset(str(Path('./dataset/loader.py').resolve()))['train']
    dataset = load_dataset("imagefolder", data_dir=str(script_dirpath / 'dataset'))['train']
    dataset = dataset.filter(
        lambda x: x['image'].height > MIN_HEIGHT and x['image'].width > MIN_WIDTH
    )
    dataset = dataset.shuffle(seed=42)
    dataset = dataset.flatten_indices()

    tokenizer = TexTeller.get_tokenizer()
    # If you want to use your own tokenizer, modify the path accordingly:
    # tokenizer = TexTeller.get_tokenizer('/path/to/your/tokenizer')
    filter_fn_with_tokenizer = partial(filter_fn, tokenizer=tokenizer)
    dataset = dataset.filter(filter_fn_with_tokenizer, num_proc=8)

    map_fn = partial(tokenize_fn, tokenizer=tokenizer)
    tokenized_dataset = dataset.map(
        map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8
    )

    # Split dataset into train and eval, ratio 9:1
    split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
    train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
    train_dataset = train_dataset.with_transform(img_train_transform)
    eval_dataset = eval_dataset.with_transform(img_inf_transform)
    collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)

    # Train from scratch
    model = TexTeller()
    # or train from the TexTeller pre-trained model: model = TexTeller.from_pretrained()

    # If you want to train from a pre-trained model, modify the path to your checkpoint, e.g.
    # model = TexTeller.from_pretrained(
    #     '/path/to/your/model_checkpoint'
    # )

    enable_train = True
    enable_evaluate = False
    if enable_train:
        train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
    if enable_evaluate and len(eval_dataset) > 0:
        evaluate(model, tokenizer, eval_dataset, collate_fn_with_tokenizer)
@@ -1,31 +0,0 @@
CONFIG = {
    "seed": 42,  # Random seed for reproducibility
    "use_cpu": False,  # Whether to use CPU (it's easier to debug on CPU when first testing the code)
    "learning_rate": 5e-5,  # Learning rate
    "num_train_epochs": 10,  # Total number of training epochs
    "per_device_train_batch_size": 4,  # Batch size per GPU for training
    "per_device_eval_batch_size": 8,  # Batch size per GPU for evaluation
    "output_dir": "train_result",  # Output directory
    "overwrite_output_dir": False,  # If the output directory exists, do not delete its content
    "report_to": ["tensorboard"],  # Report logs to TensorBoard
    "save_strategy": "steps",  # Strategy to save checkpoints
    "save_steps": 500,  # Interval of steps between checkpoints; can be an int or a float in (0, 1), where a float is a ratio of total training steps (e.g., 1.0 / 2000)
    "save_total_limit": 5,  # Maximum number of checkpoints to keep; the oldest are deleted when this number is exceeded
    "logging_strategy": "steps",  # Log every certain number of steps
    "logging_steps": 500,  # Number of steps between each log
    "logging_nan_inf_filter": False,  # Record logs even when loss is nan or inf
    "optim": "adamw_torch",  # Optimizer
    "lr_scheduler_type": "cosine",  # Learning rate scheduler
    "warmup_ratio": 0.1,  # Ratio of warmup steps in total training steps (e.g., for 1000 steps, the first 100 steps ramp lr from 0 to the set lr)
    "max_grad_norm": 1.0,  # Gradient clipping: ensure the norm of the gradients does not exceed 1.0 (default 1.0)
    "fp16": False,  # Whether to train in 16-bit floating point (generally not recommended, as the loss can easily explode)
    "bf16": False,  # Whether to train in bfloat16 (recommended if the architecture supports it)
    "gradient_accumulation_steps": 1,  # Gradient accumulation steps; use this to emulate a large batch size when the per-device batch size cannot be large
    "jit_mode_eval": False,  # Whether to use PyTorch JIT trace during eval (can speed up the model, but the model must be static, otherwise it will throw errors)
    "torch_compile": False,  # Whether to use torch.compile to compile the model (for better training and inference performance)
    "dataloader_pin_memory": True,  # Can speed up data transfer between CPU and GPU
    "dataloader_num_workers": 1,  # Data-loading worker processes; usually set to 4 * number of GPUs used
    "evaluation_strategy": "steps",  # Evaluation strategy, can be "steps" or "epoch"
    "eval_steps": 500,  # Evaluation interval when evaluation_strategy="steps"
    "remove_unused_columns": False,  # Don't change this unless you really know what you are doing.
}
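For orientation, a minimal sketch of how this dict is consumed: copy it, override a few knobs, and unpack it into TrainingArguments, as the train script above does. The override values here are hypothetical examples, not tuned settings.

    from transformers import TrainingArguments

    from .training_args import CONFIG

    debug_config = CONFIG.copy()
    debug_config["use_cpu"] = True         # step through on CPU while debugging
    debug_config["num_train_epochs"] = 1   # quick smoke test
    training_args = TrainingArguments(**debug_config)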
@@ -1,60 +0,0 @@
import torch

from transformers import DataCollatorForLanguageModeling
from typing import List, Dict, Any
from .transforms import train_transform, inference_transform
from ...globals import MIN_HEIGHT, MIN_WIDTH, MAX_TOKEN_SIZE


def left_move(x: torch.Tensor, pad_val):
    assert len(x.shape) == 2, 'x should be 2-dimensional'
    lefted_x = torch.ones_like(x)
    lefted_x[:, :-1] = x[:, 1:]
    lefted_x[:, -1] = pad_val
    return lefted_x


def tokenize_fn(samples: Dict[str, List[Any]], tokenizer=None) -> Dict[str, List[Any]]:
    assert tokenizer is not None, 'tokenizer should not be None'
    tokenized_formula = tokenizer(samples['latex_formula'], return_special_tokens_mask=True)
    tokenized_formula['pixel_values'] = samples['image']
    return tokenized_formula


def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[Any]]:
    assert tokenizer is not None, 'tokenizer should not be None'
    pixel_values = [dic.pop('pixel_values') for dic in samples]

    clm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    batch = clm_collator(samples)
    batch['pixel_values'] = pixel_values
    batch['decoder_input_ids'] = batch.pop('input_ids')
    batch['decoder_attention_mask'] = batch.pop('attention_mask')

    # Left-shift the labels (and decoder_attention_mask)
    batch['labels'] = left_move(batch['labels'], -100)

    # Convert the list of images into one tensor with shape (B, C, H, W)
    batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
    return batch


def img_train_transform(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
    processed_img = train_transform(samples['pixel_values'])
    samples['pixel_values'] = processed_img
    return samples


def img_inf_transform(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
    processed_img = inference_transform(samples['pixel_values'])
    samples['pixel_values'] = processed_img
    return samples


def filter_fn(sample, tokenizer=None) -> bool:
    return (
        sample['image'].height > MIN_HEIGHT
        and sample['image'].width > MIN_WIDTH
        and len(tokenizer(sample['latex_formula'])['input_ids']) < MAX_TOKEN_SIZE - 10
    )
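A toy run of left_move, assuming the definitions above: each row shifts one position to the left and the freed slot is filled with pad_val, which is how the labels are aligned for next-token prediction.

    import torch

    x = torch.tensor([[1, 2, 3],
                      [4, 5, 6]])
    left_move(x, -100)
    # tensor([[   2,    3, -100],
    #         [   5,    6, -100]])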
@@ -1,26 +0,0 @@
import cv2
import numpy as np
from typing import List


def convert2rgb(image_paths: List[str]) -> List[np.ndarray]:
    processed_images = []
    for path in image_paths:
        image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if image is None:
            print(f"Image at {path} could not be read.")
            continue
        if image.dtype == np.uint16:
            print(f'Converting {path} to 8-bit, image may be lossy.')
            image = cv2.convertScaleAbs(image, alpha=(255.0 / 65535.0))

        channels = 1 if len(image.shape) == 2 else image.shape[2]
        if channels == 4:
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
        elif channels == 1:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif channels == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        processed_images.append(image)

    return processed_images
@@ -1,25 +0,0 @@
import evaluate
import numpy as np
import os

from pathlib import Path
from typing import Dict
from transformers import EvalPrediction, RobertaTokenizer


def bleu_metric(eval_preds: EvalPrediction, tokenizer: RobertaTokenizer) -> Dict:
    cur_dir = Path(os.getcwd())
    os.chdir(Path(__file__).resolve().parent)
    metric = evaluate.load(
        'google_bleu'
    )  # Will download the metric from huggingface if not already downloaded
    os.chdir(cur_dir)

    logits, labels = eval_preds.predictions, eval_preds.label_ids
    preds = logits

    labels = np.where(labels == -100, 1, labels)

    preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    return metric.compute(predictions=preds, references=labels)
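A standalone sanity check of the metric call, with hypothetical strings; an exact match should score 1.0.

    import evaluate

    metric = evaluate.load('google_bleu')
    res = metric.compute(
        predictions=[r"\frac { a } { b }"],
        references=[[r"\frac { a } { b }"]],
    )
    print(res["google_bleu"])  # 1.0 for an exact match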
@@ -1,152 +0,0 @@
from augraphy import *
import random


def ocr_augmentation_pipeline():
    pre_phase = []

    ink_phase = [
        InkColorSwap(
            ink_swap_color="random",
            ink_swap_sequence_number_range=(5, 10),
            ink_swap_min_width_range=(2, 3),
            ink_swap_max_width_range=(100, 120),
            ink_swap_min_height_range=(2, 3),
            ink_swap_max_height_range=(100, 120),
            ink_swap_min_area_range=(10, 20),
            ink_swap_max_area_range=(400, 500),
            # p=0.2
            p=0.4,
        ),
        LinesDegradation(
            line_roi=(0.0, 0.0, 1.0, 1.0),
            line_gradient_range=(32, 255),
            line_gradient_direction=(0, 2),
            line_split_probability=(0.2, 0.4),
            line_replacement_value=(250, 255),
            line_min_length=(30, 40),
            line_long_to_short_ratio=(5, 7),
            line_replacement_probability=(0.4, 0.5),
            line_replacement_thickness=(1, 3),
            # p=0.2
            p=0.4,
        ),
        # ============================
        OneOf(
            [
                Dithering(
                    dither="floyd-steinberg",
                    order=(3, 5),
                ),
                InkBleed(
                    intensity_range=(0.1, 0.2),
                    kernel_size=random.choice([(7, 7), (5, 5), (3, 3)]),
                    severity=(0.4, 0.6),
                ),
            ],
            # p=0.2
            p=0.4,
        ),
        # ============================
        # ============================
        InkShifter(
            text_shift_scale_range=(18, 27),
            text_shift_factor_range=(1, 4),
            text_fade_range=(0, 2),
            blur_kernel_size=(5, 5),
            blur_sigma=0,
            noise_type="perlin",
            # p=0.2
            p=0.4,
        ),
        # ============================
    ]

    paper_phase = [
        NoiseTexturize(  # tested
            sigma_range=(3, 10),
            turbulence_range=(2, 5),
            texture_width_range=(300, 500),
            texture_height_range=(300, 500),
            # p=0.2
            p=0.4,
        ),
        BrightnessTexturize(  # tested
            texturize_range=(0.9, 0.99),
            deviation=0.03,
            # p=0.2
            p=0.4,
        ),
    ]

    post_phase = [
        ColorShift(  # tested
            color_shift_offset_x_range=(3, 5),
            color_shift_offset_y_range=(3, 5),
            color_shift_iterations=(2, 3),
            color_shift_brightness_range=(0.9, 1.1),
            color_shift_gaussian_kernel_range=(3, 3),
            # p=0.2
            p=0.4,
        ),
        DirtyDrum(  # tested
            line_width_range=(1, 6),
            line_concentration=random.uniform(0.05, 0.15),
            direction=random.randint(0, 2),
            noise_intensity=random.uniform(0.6, 0.95),
            noise_value=(64, 224),
            ksize=random.choice([(3, 3), (5, 5), (7, 7)]),
            sigmaX=0,
            # p=0.2
            p=0.4,
        ),
        # =====================================
        OneOf(
            [
                LightingGradient(
                    light_position=None,
                    direction=None,
                    max_brightness=255,
                    min_brightness=0,
                    mode="gaussian",
                    linear_decay_rate=None,
                    transparency=None,
                ),
                Brightness(
                    brightness_range=(0.9, 1.1),
                    min_brightness=0,
                    min_brightness_value=(120, 150),
                ),
                Gamma(
                    gamma_range=(0.9, 1.1),
                ),
            ],
            # p=0.2
            p=0.4,
        ),
        # =====================================
        # =====================================
        OneOf(
            [
                SubtleNoise(
                    subtle_range=random.randint(5, 10),
                ),
                Jpeg(
                    quality_range=(70, 95),
                ),
            ],
            # p=0.2
            p=0.4,
        ),
        # =====================================
    ]

    pipeline = AugraphyPipeline(
        ink_phase=ink_phase,
        paper_phase=paper_phase,
        post_phase=post_phase,
        pre_phase=pre_phase,
        log=False,
    )

    return pipeline
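The returned AugraphyPipeline is callable on a uint8 RGB array, which is exactly how train_pipeline is applied in transforms.py below. A toy invocation:

    import numpy as np

    pipeline = ocr_augmentation_pipeline()
    img = np.full((64, 256, 3), 255, dtype=np.uint8)  # blank white canvas
    augmented = pipeline(img)  # augmented uint8 image of the same kind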
@@ -1,177 +0,0 @@
import torch
import random
import numpy as np
import cv2

from torchvision.transforms import v2
from typing import List, Union
from PIL import Image
from collections import Counter

from ...globals import (
    IMG_CHANNELS,
    FIXED_IMG_SIZE,
    IMAGE_MEAN,
    IMAGE_STD,
    MAX_RESIZE_RATIO,
    MIN_RESIZE_RATIO,
)
from .ocr_aug import ocr_augmentation_pipeline

# train_pipeline = default_augraphy_pipeline(scan_only=True)
train_pipeline = ocr_augmentation_pipeline()

general_transform_pipeline = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.uint8, scale=True),  # optional, most inputs are already uint8 at this point
        v2.Grayscale(),
        v2.Resize(
            size=FIXED_IMG_SIZE - 1,
            interpolation=v2.InterpolationMode.BICUBIC,
            max_size=FIXED_IMG_SIZE,
            antialias=True,
        ),
        v2.ToDtype(torch.float32, scale=True),  # Normalize expects float input
        v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]),
        # v2.ToPILImage()
    ]
)


def trim_white_border(image: np.ndarray):
    if len(image.shape) != 3 or image.shape[2] != 3:
        raise ValueError("Image is not in RGB format or channel is not in third dimension")

    if image.dtype != np.uint8:
        raise ValueError("Image should be stored in uint8")

    corners = [tuple(image[0, 0]), tuple(image[0, -1]), tuple(image[-1, 0]), tuple(image[-1, -1])]
    bg_color = Counter(corners).most_common(1)[0][0]
    bg_color_np = np.array(bg_color, dtype=np.uint8)

    h, w = image.shape[:2]
    bg = np.full((h, w, 3), bg_color_np, dtype=np.uint8)

    diff = cv2.absdiff(image, bg)
    mask = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY)

    threshold = 15
    _, diff = cv2.threshold(mask, threshold, 255, cv2.THRESH_BINARY)

    x, y, w, h = cv2.boundingRect(diff)

    trimmed_image = image[y : y + h, x : x + w]

    return trimmed_image


def add_white_border(image: np.ndarray, max_size: int) -> np.ndarray:
    randi = [random.randint(0, max_size) for _ in range(4)]
    pad_height_size = randi[1] + randi[3]
    pad_width_size = randi[0] + randi[2]
    if pad_height_size + image.shape[0] < 30:
        compensate_height = int((30 - (pad_height_size + image.shape[0])) * 0.5) + 1
        randi[1] += compensate_height
        randi[3] += compensate_height
    if pad_width_size + image.shape[1] < 30:
        compensate_width = int((30 - (pad_width_size + image.shape[1])) * 0.5) + 1
        randi[0] += compensate_width
        randi[2] += compensate_width
    return v2.functional.pad(
        torch.from_numpy(image).permute(2, 0, 1),
        padding=randi,
        padding_mode='constant',
        fill=(255, 255, 255),
    )


def padding(images: List[torch.Tensor], required_size: int) -> List[torch.Tensor]:
    images = [
        v2.functional.pad(
            img, padding=[0, 0, required_size - img.shape[2], required_size - img.shape[1]]
        )
        for img in images
    ]
    return images


def random_resize(images: List[np.ndarray], minr: float, maxr: float) -> List[np.ndarray]:
    if len(images[0].shape) != 3 or images[0].shape[2] != 3:
        raise ValueError("Image is not in RGB format or channel is not in third dimension")

    ratios = [random.uniform(minr, maxr) for _ in range(len(images))]
    return [
        cv2.resize(
            img, (int(img.shape[1] * r), int(img.shape[0] * r)), interpolation=cv2.INTER_LANCZOS4
        )  # anti-aliasing
        for img, r in zip(images, ratios)
    ]


def rotate(image: np.ndarray, min_angle: int, max_angle: int) -> np.ndarray:
    # Get the center of the image to define the point of rotation
    image_center = tuple(np.array(image.shape[1::-1]) / 2)

    # Generate a random angle within the specified range
    angle = random.randint(min_angle, max_angle)

    # Get the rotation matrix for rotating the image around its center
    rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0)

    # Determine the size of the rotated image
    cos = np.abs(rotation_mat[0, 0])
    sin = np.abs(rotation_mat[0, 1])
    new_width = int((image.shape[0] * sin) + (image.shape[1] * cos))
    new_height = int((image.shape[0] * cos) + (image.shape[1] * sin))

    # Adjust the rotation matrix to take into account translation
    rotation_mat[0, 2] += (new_width / 2) - image_center[0]
    rotation_mat[1, 2] += (new_height / 2) - image_center[1]

    # Rotate the image with the specified border color (white in this case)
    rotated_image = cv2.warpAffine(
        image, rotation_mat, (new_width, new_height), borderValue=(255, 255, 255)
    )

    return rotated_image


def ocr_aug(image: np.ndarray) -> np.ndarray:
    if random.random() < 0.2:
        image = rotate(image, -5, 5)
    image = add_white_border(image, max_size=25).permute(1, 2, 0).numpy()
    image = train_pipeline(image)
    return image


def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
    assert IMG_CHANNELS == 1, "Only support grayscale images for now"

    images = [np.array(img.convert('RGB')) for img in images]
    # random resize first
    images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
    images = [trim_white_border(image) for image in images]

    # OCR augmentation
    images = [ocr_aug(image) for image in images]

    # general transform pipeline
    images = [general_transform_pipeline(image) for image in images]
    # padding to fixed size
    images = padding(images, FIXED_IMG_SIZE)
    return images


def inference_transform(images: List[Union[np.ndarray, Image.Image]]) -> List[torch.Tensor]:
    assert IMG_CHANNELS == 1, "Only support grayscale images for now"
    images = [
        np.array(img.convert('RGB')) if isinstance(img, Image.Image) else img for img in images
    ]
    images = [trim_white_border(image) for image in images]
    # general transform pipeline
    images = [general_transform_pipeline(image) for image in images]
    # padding to fixed size
    images = padding(images, FIXED_IMG_SIZE)

    return images
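A short sketch of the inference path through these helpers, assuming a hypothetical image path; stacking the returned list reproduces the batch shape the model consumes.

    import torch
    from PIL import Image

    imgs = [Image.open('/path/to/formula.png')]  # hypothetical path
    tensors = inference_transform(imgs)          # fixed-size grayscale tensors
    pixel_values = torch.stack(tensors)          # (B, 1, FIXED_IMG_SIZE, FIXED_IMG_SIZE)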
48
texteller/models/texteller.py
Normal file
@@ -0,0 +1,48 @@
from pathlib import Path

from transformers import RobertaTokenizerFast, VisionEncoderDecoderConfig, VisionEncoderDecoderModel

from texteller.constants import (
    FIXED_IMG_SIZE,
    IMG_CHANNELS,
    MAX_TOKEN_SIZE,
    VOCAB_SIZE,
)
from texteller.globals import Globals
from texteller.types import TexTellerModel
from texteller.utils import cuda_available


class TexTeller(VisionEncoderDecoderModel):
    def __init__(self):
        config = VisionEncoderDecoderConfig.from_pretrained(Globals().repo_name)
        config.encoder.image_size = FIXED_IMG_SIZE
        config.encoder.num_channels = IMG_CHANNELS
        config.decoder.vocab_size = VOCAB_SIZE
        config.decoder.max_position_embeddings = MAX_TOKEN_SIZE

        super().__init__(config=config)

    @classmethod
    def from_pretrained(cls, model_dir: str | None = None, use_onnx=False) -> TexTellerModel:
        if model_dir is None or model_dir == Globals().repo_name:
            if not use_onnx:
                return VisionEncoderDecoderModel.from_pretrained(Globals().repo_name)
            else:
                from optimum.onnxruntime import ORTModelForVision2Seq

                return ORTModelForVision2Seq.from_pretrained(
                    Globals().repo_name,
                    provider="CUDAExecutionProvider"
                    if cuda_available()
                    else "CPUExecutionProvider",
                )
        model_dir = Path(model_dir).resolve()
        return VisionEncoderDecoderModel.from_pretrained(str(model_dir))

    @classmethod
    def get_tokenizer(cls, tokenizer_dir: str = None) -> RobertaTokenizerFast:
        if tokenizer_dir is None or tokenizer_dir == Globals().repo_name:
            return RobertaTokenizerFast.from_pretrained(Globals().repo_name)
        tokenizer_dir = Path(tokenizer_dir).resolve()
        return RobertaTokenizerFast.from_pretrained(str(tokenizer_dir))
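A minimal load sketch for the new module, assuming the default weights at Globals().repo_name are reachable:

    from texteller.models.texteller import TexTeller

    model = TexTeller.from_pretrained()    # default VisionEncoderDecoderModel weights
    tokenizer = TexTeller.get_tokenizer()  # matching RobertaTokenizerFast
    model.eval()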
@@ -1,24 +0,0 @@
import os
from pathlib import Path
from datasets import load_dataset
from ..ocr_model.model.TexTeller import TexTeller
from ..globals import VOCAB_SIZE


if __name__ == '__main__':
    script_dirpath = Path(__file__).resolve().parent
    os.chdir(script_dirpath)

    tokenizer = TexTeller.get_tokenizer()

    # Don't forget to configure your dataset path in loader.py
    dataset = load_dataset('../ocr_model/train/dataset/loader.py')['train']

    new_tokenizer = tokenizer.train_new_from_iterator(
        text_iterator=dataset['latex_formula'],
        # If you want to use a different vocab size, **change VOCAB_SIZE in globals.py**
        vocab_size=VOCAB_SIZE,
    )

    # Save the new tokenizer for later training and inference
    new_tokenizer.save_pretrained('./your_dir_name')
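Once saved, the retrained tokenizer can be loaded back through the same helper for training and inference:

    tokenizer = TexTeller.get_tokenizer('./your_dir_name')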
@@ -1 +0,0 @@
from .mix_inference import mix_inference
@@ -1,261 +0,0 @@
import re
import heapq
import cv2
import time
import numpy as np

from collections import Counter
from typing import List
from PIL import Image

from ..det_model.inference import predict as latex_det_predict
from ..det_model.Bbox import Bbox, draw_bboxes

from ..ocr_model.utils.inference import inference as latex_rec_predict
from ..ocr_model.utils.to_katex import to_katex, change_all

MAXV = 999999999


def mask_img(img, bboxes: List[Bbox], bg_color: np.ndarray) -> np.ndarray:
    mask_img = img.copy()
    for bbox in bboxes:
        mask_img[bbox.p.y : bbox.p.y + bbox.h, bbox.p.x : bbox.p.x + bbox.w] = bg_color
    return mask_img


def bbox_merge(sorted_bboxes: List[Bbox]) -> List[Bbox]:
    if len(sorted_bboxes) == 0:
        return []
    bboxes = sorted_bboxes.copy()
    guard = Bbox(MAXV, bboxes[-1].p.y, -1, -1, label="guard")
    bboxes.append(guard)
    res = []
    prev = bboxes[0]
    for curr in bboxes:
        if prev.ur_point.x <= curr.p.x or not prev.same_row(curr):
            res.append(prev)
            prev = curr
        else:
            prev.w = max(prev.w, curr.ur_point.x - prev.p.x)
    return res


def split_conflict(ocr_bboxes: List[Bbox], latex_bboxes: List[Bbox]) -> List[Bbox]:
    if latex_bboxes == []:
        return ocr_bboxes
    if ocr_bboxes == [] or len(ocr_bboxes) == 1:
        return ocr_bboxes

    bboxes = sorted(ocr_bboxes + latex_bboxes)

    # log results
    for idx, bbox in enumerate(bboxes):
        bbox.content = str(idx)
    draw_bboxes(Image.fromarray(img), bboxes, name="before_split_confict.png")

    assert len(bboxes) > 1

    heapq.heapify(bboxes)
    res = []
    candidate = heapq.heappop(bboxes)
    curr = heapq.heappop(bboxes)
    idx = 0
    while len(bboxes) > 0:
        idx += 1
        assert candidate.p.x <= curr.p.x or not candidate.same_row(curr)

        if candidate.ur_point.x <= curr.p.x or not candidate.same_row(curr):
            res.append(candidate)
            candidate = curr
            curr = heapq.heappop(bboxes)
        elif candidate.ur_point.x < curr.ur_point.x:
            assert not (candidate.label != "text" and curr.label != "text")
            if candidate.label == "text" and curr.label == "text":
                candidate.w = curr.ur_point.x - candidate.p.x
                curr = heapq.heappop(bboxes)
            elif candidate.label != curr.label:
                if candidate.label == "text":
                    candidate.w = curr.p.x - candidate.p.x
                    res.append(candidate)
                    candidate = curr
                    curr = heapq.heappop(bboxes)
                else:
                    curr.w = curr.ur_point.x - candidate.ur_point.x
                    curr.p.x = candidate.ur_point.x
                    heapq.heappush(bboxes, curr)
                    curr = heapq.heappop(bboxes)

        elif candidate.ur_point.x >= curr.ur_point.x:
            assert not (candidate.label != "text" and curr.label != "text")

            if candidate.label == "text":
                assert curr.label != "text"
                heapq.heappush(
                    bboxes,
                    Bbox(
                        curr.ur_point.x,
                        candidate.p.y,
                        candidate.h,
                        candidate.ur_point.x - curr.ur_point.x,
                        label="text",
                        confidence=candidate.confidence,
                        content=None,
                    ),
                )
                candidate.w = curr.p.x - candidate.p.x
                res.append(candidate)
                candidate = curr
                curr = heapq.heappop(bboxes)
            else:
                assert curr.label == "text"
                curr = heapq.heappop(bboxes)
        else:
            assert False
    res.append(candidate)
    res.append(curr)

    # log results
    for idx, bbox in enumerate(res):
        bbox.content = str(idx)
    draw_bboxes(Image.fromarray(img), res, name="after_split_confict.png")

    return res


def slice_from_image(img: np.ndarray, ocr_bboxes: List[Bbox]) -> List[np.ndarray]:
    sliced_imgs = []
    for bbox in ocr_bboxes:
        x, y = int(bbox.p.x), int(bbox.p.y)
        w, h = int(bbox.w), int(bbox.h)
        sliced_img = img[y : y + h, x : x + w]
        sliced_imgs.append(sliced_img)
    return sliced_imgs


def mix_inference(
    img_path: str,
    infer_config,
    latex_det_model,
    lang_ocr_models,
    latex_rec_models,
    accelerator="cpu",
    num_beams=1,
) -> str:
    '''
    Take a mixed image of formulas and text and output a str (in markdown syntax)
    '''
    global img
    img = cv2.imread(img_path)
    corners = [tuple(img[0, 0]), tuple(img[0, -1]), tuple(img[-1, 0]), tuple(img[-1, -1])]
    bg_color = np.array(Counter(corners).most_common(1)[0][0])

    start_time = time.time()
    latex_bboxes = latex_det_predict(img_path, latex_det_model, infer_config)
    end_time = time.time()
    print(f"latex_det_model time: {end_time - start_time:.2f}s")
    latex_bboxes = sorted(latex_bboxes)
    # log results
    draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(unmerged).png")
    latex_bboxes = bbox_merge(latex_bboxes)
    # log results
    draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(merged).png")
    masked_img = mask_img(img, latex_bboxes, bg_color)

    det_model, rec_model = lang_ocr_models
    start_time = time.time()
    det_prediction, _ = det_model(masked_img)
    end_time = time.time()
    print(f"ocr_det_model time: {end_time - start_time:.2f}s")
    ocr_bboxes = [
        Bbox(
            p[0][0],
            p[0][1],
            p[3][1] - p[0][1],
            p[1][0] - p[0][0],
            label="text",
            confidence=None,
            content=None,
        )
        for p in det_prediction
    ]
    # log results
    draw_bboxes(Image.fromarray(img), ocr_bboxes, name="ocr_bboxes(unmerged).png")

    ocr_bboxes = sorted(ocr_bboxes)
    ocr_bboxes = bbox_merge(ocr_bboxes)
    # log results
    draw_bboxes(Image.fromarray(img), ocr_bboxes, name="ocr_bboxes(merged).png")
    ocr_bboxes = split_conflict(ocr_bboxes, latex_bboxes)
    ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes))

    sliced_imgs: List[np.ndarray] = slice_from_image(img, ocr_bboxes)
    start_time = time.time()
    rec_predictions, _ = rec_model(sliced_imgs)
    end_time = time.time()
    print(f"ocr_rec_model time: {end_time - start_time:.2f}s")

    assert len(rec_predictions) == len(ocr_bboxes)
    for content, bbox in zip(rec_predictions, ocr_bboxes):
        bbox.content = content[0]

    latex_imgs = []
    for bbox in latex_bboxes:
        latex_imgs.append(img[bbox.p.y : bbox.p.y + bbox.h, bbox.p.x : bbox.p.x + bbox.w])
    start_time = time.time()
    latex_rec_res = latex_rec_predict(
        *latex_rec_models, latex_imgs, accelerator, num_beams, max_tokens=800
    )
    end_time = time.time()
    print(f"latex_rec_model time: {end_time - start_time:.2f}s")

    for bbox, content in zip(latex_bboxes, latex_rec_res):
        bbox.content = to_katex(content)
        if bbox.label == "embedding":
            bbox.content = " $" + bbox.content + "$ "
        elif bbox.label == "isolated":
            bbox.content = '\n\n' + r"$$" + bbox.content + r"$$" + '\n\n'

    bboxes = sorted(ocr_bboxes + latex_bboxes)
    if bboxes == []:
        return ""

    md = ""
    prev = Bbox(bboxes[0].p.x, bboxes[0].p.y, -1, -1, label="guard")
    for curr in bboxes:
        # Add the formula number back to the isolated formula
        if prev.label == "isolated" and curr.label == "text" and prev.same_row(curr):
            curr.content = curr.content.strip()
            if curr.content.startswith('(') and curr.content.endswith(')'):
                curr.content = curr.content[1:-1]

            if re.search(r'\\tag\{.*\}$', md[:-4]) is not None:
                # in case of multiple tags
                md = md[:-5] + f', {curr.content}' + '}' + md[-4:]
            else:
                md = md[:-4] + f'\\tag{{{curr.content}}}' + md[-4:]
            continue

        if not prev.same_row(curr):
            md += " "

        if curr.label == "embedding":
            # remove the bold effect from inline formulas
            curr.content = change_all(curr.content, r'\bm', r' ', r'{', r'}', r'', r' ')
            curr.content = change_all(curr.content, r'\boldsymbol', r' ', r'{', r'}', r'', r' ')
            curr.content = change_all(curr.content, r'\textit', r' ', r'{', r'}', r'', r' ')
            curr.content = change_all(curr.content, r'\textbf', r' ', r'{', r'}', r'', r' ')
            curr.content = change_all(curr.content, r'\textbf', r' ', r'{', r'}', r'', r' ')
            curr.content = change_all(curr.content, r'\mathbf', r' ', r'{', r'}', r'', r' ')

            # change split environment into aligned
            curr.content = curr.content.replace(r'\begin{split}', r'\begin{aligned}')
            curr.content = curr.content.replace(r'\end{split}', r'\end{aligned}')

            # remove extra spaces (keeping only one)
            curr.content = re.sub(r' +', ' ', curr.content)
            assert curr.content.startswith(' $') and curr.content.endswith('$ ')
            curr.content = ' $' + curr.content[2:-2].strip() + '$ '
        md += curr.content
        prev = curr
    return md.strip()
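A hedged wiring sketch for the deleted entry point; every model object below is a placeholder for whatever the det/ocr loaders in this repo actually return.

    # Hypothetical wiring: infer_config, latex_det_model, det_model, rec_model,
    # model and tokenizer all come from the project's own loaders.
    md = mix_inference(
        'page.png',                # hypothetical mixed text/formula image
        infer_config,              # detector PredictConfig
        latex_det_model,           # ONNX InferenceSession for formula detection
        (det_model, rec_model),    # text detection + recognition pair
        (model, tokenizer),        # TexTeller model + tokenizer
        accelerator='cpu',
        num_beams=1,
    )
    print(md)  # markdown with inline $...$ and isolated $$...$$ formulas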
@@ -81,7 +81,7 @@ class BaseRecLabelDecode(object):
         word_list = []
         word_col_list = []
         state_list = []
-        valid_col = np.where(selection == True)[0]
+        valid_col = np.where(selection)[0]
 
         for c_i, char in enumerate(text):
             if "\u4e00" <= char <= "\u9fff":
@@ -12,25 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-import sys
-
-__dir__ = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(__dir__)
-sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../..")))
 
 os.environ["FLAGS_allocator_strategy"] = "auto_growth"
 
-import sys
 import time
 
-import cv2
 import numpy as np
 
-# import tools.infer.utility as utility
-import utility
-from DBPostProcess import DBPostProcess
-from operators import DetResizeForTest, KeepKeys, NormalizeImage, ToCHWImage
-from utility import get_logger
+from .DBPostProcess import DBPostProcess
+from .operators import DetResizeForTest, KeepKeys, NormalizeImage, ToCHWImage
+from .utility import create_predictor, get_logger
 
 
 def transform(data, ops=None):
@@ -82,7 +73,7 @@ class TextDetector(object):
             self.input_tensor,
             self.output_tensors,
             self.config,
-        ) = utility.create_predictor(args, "det", logger)
+        ) = create_predictor(args, "det", logger)
 
         assert self.use_onnx
         if self.use_onnx:
@@ -1,155 +0,0 @@
import sys
import argparse
import tempfile
import time
import numpy as np
import cv2

from pathlib import Path
from starlette.requests import Request
from ray import serve
from ray.serve.handle import DeploymentHandle
from onnxruntime import InferenceSession

from texteller.models.ocr_model.utils.inference import inference as rec_inference
from texteller.models.det_model.inference import predict as det_inference
from texteller.models.ocr_model.model.TexTeller import TexTeller
from texteller.models.det_model.inference import PredictConfig
from texteller.models.ocr_model.utils.to_katex import to_katex


PYTHON_VERSION = str(sys.version_info.major) + '.' + str(sys.version_info.minor)
LIBPATH = Path(sys.executable).parent.parent / 'lib' / ('python' + PYTHON_VERSION) / 'site-packages'
CUDNNPATH = LIBPATH / 'nvidia' / 'cudnn' / 'lib'

parser = argparse.ArgumentParser()
parser.add_argument('-ckpt', '--checkpoint_dir', type=str)
parser.add_argument('-tknz', '--tokenizer_dir', type=str)
parser.add_argument('-port', '--server_port', type=int, default=8000)
parser.add_argument('--num_replicas', type=int, default=1)
parser.add_argument('--ncpu_per_replica', type=float, default=1.0)
parser.add_argument('--ngpu_per_replica', type=float, default=0.0)

parser.add_argument('--inference-mode', type=str, default='cpu')
parser.add_argument('--num_beams', type=int, default=1)
parser.add_argument('-onnx', action='store_true', help='using onnx runtime')

args = parser.parse_args()
if args.ngpu_per_replica > 0 and not args.inference_mode == 'cuda':
    raise ValueError("--inference-mode must be cuda or mps if ngpu_per_replica > 0")


@serve.deployment(
    num_replicas=args.num_replicas,
    ray_actor_options={
        "num_cpus": args.ncpu_per_replica,
        "num_gpus": args.ngpu_per_replica * 1.0 / 2,
    },
)
class TexTellerRecServer:
    def __init__(
        self,
        checkpoint_path: str,
        tokenizer_path: str,
        inf_mode: str = 'cpu',
        use_onnx: bool = False,
        num_beams: int = 1,
    ) -> None:
        self.model = TexTeller.from_pretrained(
            checkpoint_path, use_onnx=use_onnx, onnx_provider=inf_mode
        )
        self.tokenizer = TexTeller.get_tokenizer(tokenizer_path)
        self.inf_mode = inf_mode
        self.num_beams = num_beams

        if not use_onnx:
            self.model = self.model.to(inf_mode) if inf_mode != 'cpu' else self.model

    def predict(self, image_nparray) -> str:
        return to_katex(
            rec_inference(
                self.model,
                self.tokenizer,
                [image_nparray],
                accelerator=self.inf_mode,
                num_beams=self.num_beams,
            )[0]
        )


@serve.deployment(
    num_replicas=args.num_replicas,
    ray_actor_options={
        "num_cpus": args.ncpu_per_replica,
        "num_gpus": args.ngpu_per_replica * 1.0 / 2,
        "runtime_env": {"env_vars": {"LD_LIBRARY_PATH": f"{str(CUDNNPATH)}/:$LD_LIBRARY_PATH"}},
    },
)
class TexTellerDetServer:
    def __init__(self, inf_mode='cpu'):
        self.infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
        self.latex_det_model = InferenceSession(
            "./models/det_model/model/rtdetr_r50vd_6x_coco.onnx",
            providers=['CUDAExecutionProvider'] if inf_mode == 'cuda' else ['CPUExecutionProvider'],
        )

    async def predict(self, image_nparray) -> str:
        with tempfile.TemporaryDirectory() as temp_dir:
            img_path = f"{temp_dir}/temp_image.jpg"
            cv2.imwrite(img_path, image_nparray)

            latex_bboxes = det_inference(img_path, self.latex_det_model, self.infer_config)
            return latex_bboxes


@serve.deployment()
class Ingress:
    def __init__(self, det_server: DeploymentHandle, rec_server: DeploymentHandle) -> None:
        self.det_server = det_server
        self.texteller_server = rec_server

    async def __call__(self, request: Request) -> str:
        request_path = request.url.path
        form = await request.form()
        img_rb = await form['img'].read()

        img_nparray = np.frombuffer(img_rb, np.uint8)
        img_nparray = cv2.imdecode(img_nparray, cv2.IMREAD_COLOR)
        img_nparray = cv2.cvtColor(img_nparray, cv2.COLOR_BGR2RGB)

        if request_path.startswith("/fdet"):
            if self.det_server is None:
                return "[ERROR] rtdetr_r50vd_6x_coco.onnx not found."
            pred = await self.det_server.predict.remote(img_nparray)
            return pred

        elif request_path.startswith("/frec"):
            pred = await self.texteller_server.predict.remote(img_nparray)
            return pred

        else:
            return "[ERROR] Invalid request path"


if __name__ == '__main__':
    ckpt_dir = args.checkpoint_dir
    tknz_dir = args.tokenizer_dir

    serve.start(http_options={"host": "0.0.0.0", "port": args.server_port})
    rec_server = TexTellerRecServer.bind(
        ckpt_dir,
        tknz_dir,
        inf_mode=args.inference_mode,
        use_onnx=args.onnx,
        num_beams=args.num_beams,
    )
    det_server = None
    if Path('./models/det_model/model/rtdetr_r50vd_6x_coco.onnx').exists():
        det_server = TexTellerDetServer.bind(args.inference_mode)
    ingress = Ingress.bind(det_server, rec_server)

    # ingress_handle = serve.run(ingress, route_prefix="/predict")
    ingress_handle = serve.run(ingress, route_prefix="/")

    while True:
        time.sleep(1)
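For reference, a client call against this server's routes (/frec for formula recognition, /fdet for detection), matching the 'img' form field it reads; host, port, and file name are hypothetical.

    import requests

    with open('formula.png', 'rb') as f:
        resp = requests.post('http://localhost:8000/frec', files={'img': f})
    print(resp.text)  # KaTeX-compatible LaTeX via to_katex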
@@ -1,9 +0,0 @@
@echo off
SETLOCAL ENABLEEXTENSIONS

set CHECKPOINT_DIR=default
set TOKENIZER_DIR=default

streamlit run web.py

ENDLOCAL
@@ -1,7 +0,0 @@
#!/usr/bin/env bash
set -exu

export CHECKPOINT_DIR="default"
export TOKENIZER_DIR="default"

streamlit run web.py
@@ -1,14 +0,0 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
gpu_ids: all
num_processes: 1
machine_rank: 0
main_training_function: main
num_machines: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
12
texteller/types/__init__.py
Normal file
@@ -0,0 +1,12 @@
from typing import TypeAlias

from optimum.onnxruntime import ORTModelForVision2Seq
from transformers import VisionEncoderDecoderModel

from .bbox import Bbox


TexTellerModel: TypeAlias = VisionEncoderDecoderModel | ORTModelForVision2Seq


__all__ = ["Bbox", "TexTellerModel"]
@@ -1,10 +1,3 @@
-import os
-
-from PIL import Image, ImageDraw
-from typing import List
-from pathlib import Path
-
-
 class Point:
     def __init__(self, x: int, y: int):
         self.x = int(x)
@@ -51,9 +44,9 @@
         return 1.0 * abs(self.p.y - other.p.y) / max(self.h, other.h) < self.THREADHOLD
 
     def __lt__(self, other) -> bool:
-        '''
+        """
         from top to bottom, from left to right
-        '''
+        """
         if not self.same_row(other):
             return self.p.y < other.p.y
         else:
@@ -61,29 +54,3 @@
 
     def __repr__(self) -> str:
         return f"Bbox(upper_left_point={self.p}, h={self.h}, w={self.w}), label={self.label}, confident={self.confidence}, content={self.content})"
-
-
-def draw_bboxes(img: Image.Image, bboxes: List[Bbox], name="annotated_image.png"):
-    curr_work_dir = Path(os.getcwd())
-    log_dir = curr_work_dir / "logs"
-    log_dir.mkdir(exist_ok=True)
-    drawer = ImageDraw.Draw(img)
-    for bbox in bboxes:
-        # Calculate the coordinates for the rectangle to be drawn
-        left = bbox.p.x
-        top = bbox.p.y
-        right = bbox.p.x + bbox.w
-        bottom = bbox.p.y + bbox.h
-
-        # Draw the rectangle on the image
-        drawer.rectangle([left, top, right, bottom], outline="green", width=1)
-
-        # Optionally, add text label if it exists
-        if bbox.label:
-            drawer.text((left, top), bbox.label, fill="blue")
-
-        if bbox.content:
-            drawer.text((left, bottom - 10), bbox.content[:10], fill="red")
-
-    # Save the image with drawn rectangles
-    img.save(log_dir / name)
26
texteller/utils/__init__.py
Normal file
@@ -0,0 +1,26 @@
from .device import get_device, cuda_available, mps_available, str2device
from .image import readimgs, transform
from .latex import change_all, remove_style, add_newlines
from .path import mkdir, resolve_path
from .misc import lines_dedent
from .bbox import mask_img, bbox_merge, split_conflict, slice_from_image, draw_bboxes

__all__ = [
    "get_device",
    "cuda_available",
    "mps_available",
    "str2device",
    "readimgs",
    "transform",
    "change_all",
    "remove_style",
    "add_newlines",
    "mkdir",
    "resolve_path",
    "lines_dedent",
    "mask_img",
    "bbox_merge",
    "split_conflict",
    "slice_from_image",
    "draw_bboxes",
]
142
texteller/utils/bbox.py
Normal file
@@ -0,0 +1,142 @@
import heapq
import os
from pathlib import Path

import numpy as np
from PIL import Image, ImageDraw

from texteller.types import Bbox

_MAXV = 999999999


def mask_img(img, bboxes: list[Bbox], bg_color: np.ndarray) -> np.ndarray:
    mask_img = img.copy()
    for bbox in bboxes:
        mask_img[bbox.p.y : bbox.p.y + bbox.h, bbox.p.x : bbox.p.x + bbox.w] = bg_color
    return mask_img


def bbox_merge(sorted_bboxes: list[Bbox]) -> list[Bbox]:
    if len(sorted_bboxes) == 0:
        return []
    bboxes = sorted_bboxes.copy()
    guard = Bbox(_MAXV, bboxes[-1].p.y, -1, -1, label="guard")
    bboxes.append(guard)
    res = []
    prev = bboxes[0]
    for curr in bboxes:
        if prev.ur_point.x <= curr.p.x or not prev.same_row(curr):
            res.append(prev)
            prev = curr
        else:
            prev.w = max(prev.w, curr.ur_point.x - prev.p.x)
    return res


def split_conflict(ocr_bboxes: list[Bbox], latex_bboxes: list[Bbox]) -> list[Bbox]:
    if latex_bboxes == []:
        return ocr_bboxes
    if ocr_bboxes == [] or len(ocr_bboxes) == 1:
        return ocr_bboxes

    bboxes = sorted(ocr_bboxes + latex_bboxes)

    assert len(bboxes) > 1

    heapq.heapify(bboxes)
    res = []
    candidate = heapq.heappop(bboxes)
    curr = heapq.heappop(bboxes)
    idx = 0
    while len(bboxes) > 0:
        idx += 1
        assert candidate.p.x <= curr.p.x or not candidate.same_row(curr)

        if candidate.ur_point.x <= curr.p.x or not candidate.same_row(curr):
            res.append(candidate)
            candidate = curr
            curr = heapq.heappop(bboxes)
        elif candidate.ur_point.x < curr.ur_point.x:
            assert not (candidate.label != "text" and curr.label != "text")
            if candidate.label == "text" and curr.label == "text":
                candidate.w = curr.ur_point.x - candidate.p.x
                curr = heapq.heappop(bboxes)
            elif candidate.label != curr.label:
                if candidate.label == "text":
                    candidate.w = curr.p.x - candidate.p.x
                    res.append(candidate)
                    candidate = curr
                    curr = heapq.heappop(bboxes)
                else:
                    curr.w = curr.ur_point.x - candidate.ur_point.x
                    curr.p.x = candidate.ur_point.x
                    heapq.heappush(bboxes, curr)
                    curr = heapq.heappop(bboxes)

        elif candidate.ur_point.x >= curr.ur_point.x:
            assert not (candidate.label != "text" and curr.label != "text")

            if candidate.label == "text":
                assert curr.label != "text"
                heapq.heappush(
                    bboxes,
                    Bbox(
                        curr.ur_point.x,
                        candidate.p.y,
                        candidate.h,
                        candidate.ur_point.x - curr.ur_point.x,
                        label="text",
                        confidence=candidate.confidence,
                        content=None,
                    ),
                )
                candidate.w = curr.p.x - candidate.p.x
                res.append(candidate)
                candidate = curr
                curr = heapq.heappop(bboxes)
            else:
                assert curr.label == "text"
                curr = heapq.heappop(bboxes)
        else:
            assert False
    res.append(candidate)
    res.append(curr)

    return res


def slice_from_image(img: np.ndarray, ocr_bboxes: list[Bbox]) -> list[np.ndarray]:
    sliced_imgs = []
    for bbox in ocr_bboxes:
        x, y = int(bbox.p.x), int(bbox.p.y)
        w, h = int(bbox.w), int(bbox.h)
        sliced_img = img[y : y + h, x : x + w]
        sliced_imgs.append(sliced_img)
    return sliced_imgs


def draw_bboxes(img: Image.Image, bboxes: list[Bbox], name="annotated_image.png"):
    curr_work_dir = Path(os.getcwd())
    log_dir = curr_work_dir / "logs"
    log_dir.mkdir(exist_ok=True)
    drawer = ImageDraw.Draw(img)
    for bbox in bboxes:
        # Calculate the coordinates for the rectangle to be drawn
        left = bbox.p.x
        top = bbox.p.y
        right = bbox.p.x + bbox.w
        bottom = bbox.p.y + bbox.h

        # Draw the rectangle on the image
        drawer.rectangle([left, top, right, bottom], outline="green", width=1)

        # Optionally, add text label if it exists
        if bbox.label:
            drawer.text((left, top), bbox.label, fill="blue")

        if bbox.content:
            drawer.text((left, bottom - 10), bbox.content[:10], fill="red")

    # Save the image with drawn rectangles
    img.save(log_dir / name)
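A toy run of bbox_merge, assuming Bbox defaults its confidence/content arguments as the guard construction above suggests; the positional argument order follows that same construction (x, y, h, w).

    from texteller.types import Bbox
    from texteller.utils import bbox_merge

    a = Bbox(0, 0, 10, 50, label="text")   # spans x in [0, 50)
    b = Bbox(40, 0, 10, 30, label="text")  # same row, overlaps a
    merged = bbox_merge(sorted([a, b]))
    # one box remains, widened to cover both: merged[0].w == 70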
41
texteller/utils/device.py
Normal file
@@ -0,0 +1,41 @@
from typing import Literal

import torch


def str2device(device_str: Literal["cpu", "cuda", "mps"]) -> torch.device:
    if device_str == "cpu":
        return torch.device("cpu")
    elif device_str == "cuda":
        return torch.device("cuda")
    elif device_str == "mps":
        return torch.device("mps")
    else:
        raise ValueError(f"Invalid device: {device_str}")


def get_device(device_index: int = None) -> torch.device:
    """
    Automatically detect the best available device for inference.

    Args:
        device_index: The index of the GPU device to use if multiple are available.
            Defaults to None, which uses the first available GPU.

    Returns:
        torch.device: Selected device for model inference.
    """
    if cuda_available():
        return str2device("cuda")
    elif mps_available():
        return str2device("mps")
    else:
        return str2device("cpu")


def cuda_available() -> bool:
    return torch.cuda.is_available()


def mps_available() -> bool:
    return torch.backends.mps.is_available()
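Typical use of the device helpers:

    from texteller.utils import get_device

    device = get_device()  # prefers cuda, then mps, then cpu
    # model = model.to(device)  # assuming `model` is a loaded TexTeller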
121
texteller/utils/image.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
from collections import Counter
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from torchvision.transforms import v2
|
||||||
|
|
||||||
|
from texteller.constants import (
|
||||||
|
FIXED_IMG_SIZE,
|
||||||
|
IMG_CHANNELS,
|
||||||
|
IMAGE_MEAN,
|
||||||
|
IMAGE_STD,
|
||||||
|
)
|
||||||
|
from texteller.logger import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
_logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def readimgs(image_paths: list[str]) -> list[np.ndarray]:
|
||||||
|
"""
|
||||||
|
Read and preprocess a list of images from their file paths.
|
||||||
|
|
||||||
|
This function reads each image from the provided paths, handles different
|
||||||
|
bit depths (converting 16-bit to 8-bit if necessary), and normalizes color
|
||||||
|
channels to RGB format regardless of the original color space (BGR, BGRA,
|
||||||
|
or grayscale).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_paths (list[str]): A list of file paths to the images to be read.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[np.ndarray]: A list of NumPy arrays containing the preprocessed images
|
||||||
|
in RGB format. Images that could not be read are skipped.
|
||||||
|
"""
|
||||||
|
processed_images = []
|
||||||
|
for path in image_paths:
|
||||||
|
image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
||||||
|
if image is None:
|
||||||
|
raise ValueError(f"Image at {path} could not be read.")
|
||||||
|
if image.dtype == np.uint16:
|
||||||
|
_logger.warning(f'Converting {path} to 8-bit, image may be lossy.')
|
||||||
|
image = cv2.convertScaleAbs(image, alpha=(255.0 / 65535.0))
|
||||||
|
|
||||||
|
channels = 1 if len(image.shape) == 2 else image.shape[2]
|
||||||
|
if channels == 4:
|
||||||
|
image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
|
||||||
|
elif channels == 1:
|
||||||
|
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
||||||
|
elif channels == 3:
|
||||||
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||||
|
processed_images.append(image)
|
||||||
|
|
||||||
|
return processed_images
|
||||||
|
|
||||||
|
|
||||||
|
def trim_white_border(image: np.ndarray) -> np.ndarray:
    if len(image.shape) != 3 or image.shape[2] != 3:
        raise ValueError("Image is not in RGB format or channel is not in third dimension")

    if image.dtype != np.uint8:
        raise ValueError("Image should be stored as uint8")

    # Estimate the background color from the four corner pixels (majority vote)
    corners = [tuple(image[0, 0]), tuple(image[0, -1]), tuple(image[-1, 0]), tuple(image[-1, -1])]
    bg_color = Counter(corners).most_common(1)[0][0]
    bg_color_np = np.array(bg_color, dtype=np.uint8)

    h, w = image.shape[:2]
    bg = np.full((h, w, 3), bg_color_np, dtype=np.uint8)

    # Pixels differing from the background by more than the threshold are foreground
    diff = cv2.absdiff(image, bg)
    mask = cv2.cvtColor(diff, cv2.COLOR_RGB2GRAY)

    threshold = 15
    _, diff = cv2.threshold(mask, threshold, 255, cv2.THRESH_BINARY)

    x, y, w, h = cv2.boundingRect(diff)

    trimmed_image = image[y : y + h, x : x + w]

    return trimmed_image


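# A synthetic example of the corner-vote trimming above: a dark rectangle on a
# white canvas is cropped to its bounding box (import path assumed as before):
import numpy as np

from texteller.utils.image import trim_white_border

canvas = np.full((100, 100, 3), 255, dtype=np.uint8)  # white background
canvas[40:60, 30:70] = 0  # black rectangle

trimmed = trim_white_border(canvas)
assert trimmed.shape == (20, 40, 3)  # cropped to the rectangle's extent
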
def padding(images: List[torch.Tensor], required_size: int) -> List[torch.Tensor]:
    # Pad each image on the right and bottom only, so content stays anchored
    # at the top-left corner ([left, top, right, bottom] padding order).
    images = [
        v2.functional.pad(
            img, padding=[0, 0, required_size - img.shape[2], required_size - img.shape[1]]
        )
        for img in images
    ]
    return images


def transform(images: List[Union[np.ndarray, Image.Image]]) -> List[torch.Tensor]:
    general_transform_pipeline = v2.Compose(
        [
            v2.ToImage(),
            v2.ToDtype(torch.uint8, scale=True),
            v2.Grayscale(),
            v2.Resize(
                size=FIXED_IMG_SIZE - 1,
                interpolation=v2.InterpolationMode.BICUBIC,
                max_size=FIXED_IMG_SIZE,
                antialias=True,
            ),
            v2.ToDtype(torch.float32, scale=True),  # Normalize expects float input
            v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]),
        ]
    )

    assert IMG_CHANNELS == 1, "Only support grayscale images for now"
    images = [
        np.array(img.convert("RGB")) if isinstance(img, Image.Image) else img for img in images
    ]
    images = [trim_white_border(image) for image in images]
    images = [general_transform_pipeline(image) for image in images]
    images = padding(images, FIXED_IMG_SIZE)

    return images
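# End to end, transform yields normalized single-channel tensors padded to a
# fixed square, ready to stack into a batch. A sketch, assuming the import
# paths above:
import numpy as np
import torch

from texteller.constants import FIXED_IMG_SIZE
from texteller.utils.image import transform

img = np.full((32, 64, 3), 255, dtype=np.uint8)
img[8:24, 16:48] = 0  # non-white content so trimming keeps a region

batch = torch.stack(transform([img]))
assert batch.shape == (1, 1, FIXED_IMG_SIZE, FIXED_IMG_SIZE)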
128
texteller/utils/latex.py
Normal file
@@ -0,0 +1,128 @@
import re


def _change(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r):
    result = ""
    i = 0
    n = len(input_str)

    while i < n:
        if input_str[i : i + len(old_inst)] != old_inst:
            result += input_str[i]
            i += 1
            continue

        # Found old_inst; check whether it is followed by old_surr_l
        start = i + len(old_inst)
        if start < n and input_str[start] == old_surr_l:
            # Look for the matching old_surr_r, tracking nesting and escapes
            count = 1
            j = start + 1
            escaped = False
            while j < n and count > 0:
                if input_str[j] == "\\" and not escaped:
                    escaped = True
                    j += 1
                    continue
                if input_str[j] == old_surr_r and not escaped:
                    count -= 1
                    if count == 0:
                        break
                elif input_str[j] == old_surr_l and not escaped:
                    count += 1
                escaped = False
                j += 1

            if count == 0:
                assert j < n
                assert input_str[start] == old_surr_l
                assert input_str[j] == old_surr_r
                inner_content = input_str[start + 1 : j]
                # Replace the delimited content with the new pattern
                result += new_inst + new_surr_l + inner_content + new_surr_r
                i = j + 1
                continue
            else:
                assert count >= 1
                assert j == n
                print("Warning: unbalanced delimiters in input string")
                result += new_inst + new_surr_l
                i = start + 1
                continue
        else:
            result += input_str[i:start]
            i = start

    if old_inst != new_inst and (old_inst + old_surr_l) in result:
        return _change(result, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r)
    else:
        return result


def _find_substring_positions(string, substring):
    positions = [match.start() for match in re.finditer(re.escape(substring), string)]
    return positions


def change_all(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r):
    # Rewrite every occurrence of old_inst + old_surr_l from right to left so
    # earlier replacements do not shift the positions of later ones.
    pos = _find_substring_positions(input_str, old_inst + old_surr_l)
    res = list(input_str)
    for p in pos[::-1]:
        res[p:] = list(
            _change(
                "".join(res[p:]), old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r
            )
        )
    res = "".join(res)
    return res


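# change_all rewrites every `old_inst old_surr_l ... old_surr_r` span while
# keeping the inner content and respecting nested delimiters. A sketch,
# converting \dfrac to \frac (import path assumed):
from texteller.utils.latex import change_all

s = r"\dfrac{a}{b} + \dfrac{1}{\dfrac{x}{y}}"
out = change_all(s, r"\dfrac", r"\frac", "{", "}", "{", "}")
assert out == r"\frac{a}{b} + \frac{1}{\frac{x}{y}}"
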
def remove_style(input_str: str) -> str:
    # Strip common styling commands, keeping only their inner content
    input_str = change_all(input_str, r"\bm", r" ", r"{", r"}", r"", r" ")
    input_str = change_all(input_str, r"\boldsymbol", r" ", r"{", r"}", r"", r" ")
    input_str = change_all(input_str, r"\textit", r" ", r"{", r"}", r"", r" ")
    input_str = change_all(input_str, r"\textbf", r" ", r"{", r"}", r"", r" ")
    input_str = change_all(input_str, r"\mathbf", r" ", r"{", r"}", r"", r" ")
    output_str = input_str.strip()
    return output_str


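# remove_style strips the styling wrappers listed above but keeps their
# arguments; the space-padded replacements can leave extra whitespace, so the
# check below compares tokens rather than raw strings (import path assumed):
from texteller.utils.latex import remove_style

assert remove_style(r"\textbf{x} + \bm{\alpha}").split() == ["x", "+", r"\alpha"]
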
def add_newlines(latex_str: str) -> str:
    """
    Adds newlines to a LaTeX string based on specific patterns, ensuring no
    duplicate newlines are added around begin/end environments:

    - After \\ (if not already followed by a newline)
    - Before \\begin{...} (if not already preceded by a newline)
    - After \\begin{...} (if not already followed by a newline)
    - Before \\end{...} (if not already preceded by a newline)
    - After \\end{...} (if not already followed by a newline)

    Args:
        latex_str: The input LaTeX string.

    Returns:
        The LaTeX string with added newlines, avoiding duplicates.
    """
    processed_str = latex_str

    # 1. Replace whitespace around \begin{...} with \n...\n
    #    \s* matches zero or more whitespace characters (space, tab, newline);
    #    the \begin{...} part is captured in group 1 (\g<1>)
    processed_str = re.sub(r"\s*(\\begin\{[^}]*\})\s*", r"\n\g<1>\n", processed_str)

    # 2. Replace whitespace around \end{...} with \n...\n (same logic as for \begin)
    processed_str = re.sub(r"\s*(\\end\{[^}]*\})\s*", r"\n\g<1>\n", processed_str)

    # 3. Add a newline after \\ (if not already followed by one)
    processed_str = re.sub(r"\\\\(?!\n| )|\\\\ ", r"\\\\\n", processed_str)

    # 4. Cleanup: collapse multiple consecutive newlines into a single newline.
    #    This handles cases where the replacements above created \n\n.
    processed_str = re.sub(r"\n{2,}", "\n", processed_str)

    # Remove leading/trailing whitespace (including any single newlines at the
    # very start/end introduced by the replacements) from the entire result.
    return processed_str.strip()
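# A sketch of add_newlines on a small align environment (import path assumed):
from texteller.utils.latex import add_newlines

src = r"\begin{align} a \\ b \end{align}"
assert add_newlines(src) == "\\begin{align}\na \\\\\nb\n\\end{align}"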
5
texteller/utils/misc.py
Normal file
@@ -0,0 +1,5 @@
from textwrap import dedent


def lines_dedent(s: str) -> str:
    return dedent(s).strip()
52
texteller/utils/path.py
Normal file
@@ -0,0 +1,52 @@
from pathlib import Path
from typing import Literal

from texteller.logger import get_logger

_logger = get_logger(__name__)


def resolve_path(path: str | Path) -> str:
    if isinstance(path, str):
        path = Path(path)
    return str(path.expanduser().resolve())


def touch(path: str | Path) -> None:
    if isinstance(path, str):
        path = Path(path)
    path.touch(exist_ok=True)


def mkdir(path: str | Path) -> None:
    if isinstance(path, str):
        path = Path(path)
    path.mkdir(parents=True, exist_ok=True)


def rmfile(path: str | Path) -> None:
    if isinstance(path, str):
        path = Path(path)
    path.unlink(missing_ok=False)


def rmdir(path: str | Path, mode: Literal["empty", "recursive"] = "empty") -> None:
    """Remove a directory.

    Args:
        path: Path to the directory to remove
        mode: "empty" to only remove an empty directory, "recursive" to
            recursively remove the directory and all of its contents
    """
    if isinstance(path, str):
        path = Path(path)

    if mode == "empty":
        path.rmdir()
        _logger.info(f"Removed empty directory: {path}")
    elif mode == "recursive":
        import shutil

        shutil.rmtree(path)
        _logger.info(f"Recursively removed directory and all contents: {path}")
    else:
        raise ValueError(f"Invalid mode: {mode}. Must be 'empty' or 'recursive'")
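# A usage sketch tying the path helpers together (paths are illustrative):
from texteller.utils.path import mkdir, resolve_path, rmdir, touch

mkdir("/tmp/texteller_demo/sub")
touch("/tmp/texteller_demo/sub/marker.txt")
print(resolve_path("~"))  # expanded, absolute home path
rmdir("/tmp/texteller_demo", mode="recursive")  # removes the whole tree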