diff --git a/texteller/__init__.py b/texteller/__init__.py new file mode 100644 index 0000000..369cb6d --- /dev/null +++ b/texteller/__init__.py @@ -0,0 +1 @@ +from texteller.api import * diff --git a/texteller/api/__init__.py b/texteller/api/__init__.py new file mode 100644 index 0000000..a0dccbe --- /dev/null +++ b/texteller/api/__init__.py @@ -0,0 +1,24 @@ +from .detection import latex_detect +from .format import format_latex +from .inference import img2latex, paragraph2md +from .katex import to_katex +from .load import ( + load_latexdet_model, + load_model, + load_textdet_model, + load_textrec_model, + load_tokenizer, +) + +__all__ = [ + "to_katex", + "format_latex", + "img2latex", + "paragraph2md", + "load_model", + "load_tokenizer", + "load_latexdet_model", + "load_textrec_model", + "load_textdet_model", + "latex_detect", +] diff --git a/texteller/api/criterias/__init__.py b/texteller/api/criterias/__init__.py new file mode 100644 index 0000000..3d54112 --- /dev/null +++ b/texteller/api/criterias/__init__.py @@ -0,0 +1,4 @@ +from .ngram import DetectRepeatingNgramCriteria + + +__all__ = ["DetectRepeatingNgramCriteria"] diff --git a/texteller/models/ocr_model/utils/inference.py b/texteller/api/criterias/ngram.py similarity index 57% rename from texteller/models/ocr_model/utils/inference.py rename to texteller/api/criterias/ngram.py index e00ea12..5d68e57 100644 --- a/texteller/models/ocr_model/utils/inference.py +++ b/texteller/api/criterias/ngram.py @@ -1,16 +1,8 @@ import torch -import numpy as np - -from transformers import RobertaTokenizerFast, GenerationConfig, StoppingCriteria -from typing import List, Union - -from .transforms import inference_transform -from .helpers import convert2rgb -from ..model.TexTeller import TexTeller -from ...globals import MAX_TOKEN_SIZE +from transformers import StoppingCriteria -class EfficientDetectRepeatingNgramCriteria(StoppingCriteria): +class DetectRepeatingNgramCriteria(StoppingCriteria): """ Stops generation efficiently if any n-gram repeats. 
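
    Tracks every n-gram seen so far in a set, so each generation step costs a
    single set lookup instead of a rescan of the whole generated sequence.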
@@ -69,48 +61,3 @@ class EfficientDetectRepeatingNgramCriteria(StoppingCriteria): # It's a new n-gram, add it to the set and continue self.seen_ngrams.add(last_ngram_tuple) return False # Continue generation - - -def inference( - model: TexTeller, - tokenizer: RobertaTokenizerFast, - imgs: Union[List[str], List[np.ndarray]], - accelerator: str = 'cpu', - num_beams: int = 1, - max_tokens=None, -) -> List[str]: - if imgs == []: - return [] - if hasattr(model, 'eval'): - # not onnx session, turn model.eval() - model.eval() - if isinstance(imgs[0], str): - imgs = convert2rgb(imgs) - else: # already numpy array(rgb format) - assert isinstance(imgs[0], np.ndarray) - imgs = imgs - imgs = inference_transform(imgs) - pixel_values = torch.stack(imgs) - - if hasattr(model, 'eval'): - # not onnx session, move weights to device - model = model.to(accelerator) - pixel_values = pixel_values.to(accelerator) - - generate_config = GenerationConfig( - max_new_tokens=MAX_TOKEN_SIZE if max_tokens is None else max_tokens, - num_beams=num_beams, - do_sample=False, - pad_token_id=tokenizer.pad_token_id, - eos_token_id=tokenizer.eos_token_id, - bos_token_id=tokenizer.bos_token_id, - # no_repeat_ngram_size=10, - ) - pred = model.generate( - pixel_values.to(model.device), - generation_config=generate_config, - # stopping_criteria=[EfficientDetectRepeatingNgramCriteria(20)], - ) - - res = tokenizer.batch_decode(pred, skip_special_tokens=True) - return res diff --git a/texteller/api/detection/__init__.py b/texteller/api/detection/__init__.py new file mode 100644 index 0000000..3edad3d --- /dev/null +++ b/texteller/api/detection/__init__.py @@ -0,0 +1,3 @@ +from .detect import latex_detect + +__all__ = ["latex_detect"] diff --git a/texteller/api/detection/detect.py b/texteller/api/detection/detect.py new file mode 100644 index 0000000..184116d --- /dev/null +++ b/texteller/api/detection/detect.py @@ -0,0 +1,48 @@ +from typing import List + +from onnxruntime import InferenceSession + +from texteller.types import Bbox + +from .preprocess import Compose + +_config = { + "mode": "paddle", + "draw_threshold": 0.5, + "metric": "COCO", + "use_dynamic_shape": False, + "arch": "DETR", + "min_subgraph_size": 3, + "preprocess": [ + {"interp": 2, "keep_ratio": False, "target_size": [1600, 1600], "type": "Resize"}, + { + "mean": [0.0, 0.0, 0.0], + "norm_type": "none", + "std": [1.0, 1.0, 1.0], + "type": "NormalizeImage", + }, + {"type": "Permute"}, + ], + "label_list": ["isolated", "embedding"], +} + + +def latex_detect(img_path: str, predictor: InferenceSession) -> List[Bbox]: + transforms = Compose(_config["preprocess"]) + inputs = transforms(img_path) + inputs_name = [var.name for var in predictor.get_inputs()] + inputs = {k: inputs[k][None,] for k in inputs_name} + + outputs = predictor.run(output_names=None, input_feed=inputs)[0] + res = [] + for output in outputs: + cls_name = _config["label_list"][int(output[0])] + score = output[1] + xmin = int(max(output[2], 0)) + ymin = int(max(output[3], 0)) + xmax = int(output[4]) + ymax = int(output[5]) + if score > 0.5: + res.append(Bbox(xmin, ymin, ymax - ymin, xmax - xmin, cls_name, score)) + + return res diff --git a/texteller/api/detection/preprocess.py b/texteller/api/detection/preprocess.py new file mode 100644 index 0000000..5172ceb --- /dev/null +++ b/texteller/api/detection/preprocess.py @@ -0,0 +1,161 @@ +import copy + +import cv2 +import numpy as np + + +def decode_image(img_path): + if isinstance(img_path, str): + with open(img_path, "rb") as f: + im_read = f.read() + data 
= np.frombuffer(im_read, dtype="uint8") + else: + assert isinstance(img_path, np.ndarray) + data = img_path + + im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + img_info = { + "im_shape": np.array(im.shape[:2], dtype=np.float32), + "scale_factor": np.array([1.0, 1.0], dtype=np.float32), + } + return im, img_info + + +class Resize(object): + """resize image by target_size and max_size + Args: + target_size (int): the target size of image + keep_ratio (bool): whether keep_ratio or not, default true + interp (int): method of resize + """ + + def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR): + if isinstance(target_size, int): + target_size = [target_size, target_size] + self.target_size = target_size + self.keep_ratio = keep_ratio + self.interp = interp + + def __call__(self, im, im_info): + """ + Args: + im (np.ndarray): image (np.ndarray) + im_info (dict): info of image + Returns: + im (np.ndarray): processed image (np.ndarray) + im_info (dict): info of processed image + """ + assert len(self.target_size) == 2 + assert self.target_size[0] > 0 and self.target_size[1] > 0 + im_channel = im.shape[2] + im_scale_y, im_scale_x = self.generate_scale(im) + im = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=self.interp) + im_info["im_shape"] = np.array(im.shape[:2]).astype("float32") + im_info["scale_factor"] = np.array([im_scale_y, im_scale_x]).astype("float32") + return im, im_info + + def generate_scale(self, im): + """ + Args: + im (np.ndarray): image (np.ndarray) + Returns: + im_scale_x: the resize ratio of X + im_scale_y: the resize ratio of Y + """ + origin_shape = im.shape[:2] + im_c = im.shape[2] + if self.keep_ratio: + im_size_min = np.min(origin_shape) + im_size_max = np.max(origin_shape) + target_size_min = np.min(self.target_size) + target_size_max = np.max(self.target_size) + im_scale = float(target_size_min) / float(im_size_min) + if np.round(im_scale * im_size_max) > target_size_max: + im_scale = float(target_size_max) / float(im_size_max) + im_scale_x = im_scale + im_scale_y = im_scale + else: + resize_h, resize_w = self.target_size + im_scale_y = resize_h / float(origin_shape[0]) + im_scale_x = resize_w / float(origin_shape[1]) + return im_scale_y, im_scale_x + + +class NormalizeImage(object): + """normalize image + Args: + mean (list): im - mean + std (list): im / std + is_scale (bool): whether need im / 255 + norm_type (str): type in ['mean_std', 'none'] + """ + + def __init__(self, mean, std, is_scale=True, norm_type="mean_std"): + self.mean = mean + self.std = std + self.is_scale = is_scale + self.norm_type = norm_type + + def __call__(self, im, im_info): + """ + Args: + im (np.ndarray): image (np.ndarray) + im_info (dict): info of image + Returns: + im (np.ndarray): processed image (np.ndarray) + im_info (dict): info of processed image + """ + im = im.astype(np.float32, copy=False) + if self.is_scale: + scale = 1.0 / 255.0 + im *= scale + + if self.norm_type == "mean_std": + mean = np.array(self.mean)[np.newaxis, np.newaxis, :] + std = np.array(self.std)[np.newaxis, np.newaxis, :] + im -= mean + im /= std + return im, im_info + + +class Permute(object): + """permute image + Args: + to_bgr (bool): whether convert RGB to BGR + channel_first (bool): whether convert HWC to CHW + """ + + def __init__( + self, + ): + super(Permute, self).__init__() + + def __call__(self, im, im_info): + """ + Args: + im (np.ndarray): image (np.ndarray) + im_info (dict): info of image + Returns: + im 
(np.ndarray): processed image (np.ndarray) + im_info (dict): info of processed image + """ + im = im.transpose((2, 0, 1)).copy() + return im, im_info + + +class Compose: + def __init__(self, transforms): + self.transforms = [] + for op_info in transforms: + new_op_info = op_info.copy() + op_type = new_op_info.pop("type") + self.transforms.append(eval(op_type)(**new_op_info)) + + def __call__(self, img_path): + img, im_info = decode_image(img_path) + for t in self.transforms: + img, im_info = t(img, im_info) + inputs = copy.deepcopy(im_info) + inputs["image"] = img + return inputs diff --git a/texteller/models/ocr_model/utils/latex_formatter.py b/texteller/api/format.py similarity index 88% rename from texteller/models/ocr_model/utils/latex_formatter.py rename to texteller/api/format.py index 02988f2..0b76bc8 100644 --- a/texteller/models/ocr_model/utils/latex_formatter.py +++ b/texteller/api/format.py @@ -5,9 +5,8 @@ Based on the Rust implementation at https://github.com/WGUNDERWOOD/tex-fmt """ import re -import argparse from dataclasses import dataclass -from typing import List, Optional, Tuple, Dict, Set +from typing import List, Optional, Tuple # Constants LINE_END = "\n" @@ -49,7 +48,7 @@ RE_SPLITTING_SHARED_LINE_CAPTURE = re.compile(f"(?P\\S.*?)(?P{SPLITTI @dataclass class Args: - """Command line arguments and configuration.""" + """Formatter configuration.""" tabchar: str = " " tabsize: int = 4 @@ -542,13 +541,29 @@ def indents_return_to_zero(state: State) -> bool: return state.indent.actual == 0 -def format_latex( - old_text: str, file: str = "input.tex", args: Optional[Args] = None -) -> Tuple[str, List[Log]]: - """Central function to format a LaTeX string.""" - if args is None: - args = Args() +def format_latex(text: str) -> str: + """Format LaTeX text with default formatting options. + This is the main API function for formatting LaTeX text. + It uses pre-defined default values for all formatting parameters. 
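+    Internally this wraps _format_latex with a default Args() configuration.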
+ + Args: + text: LaTeX text to format + + Returns: + Formatted LaTeX text + """ + # Use default configuration + args = Args() + file = "input.tex" + + # Format and return only the text + formatted_text, _ = _format_latex(text, file, args) + return formatted_text.strip() + + +def _format_latex(old_text: str, file: str, args: Args) -> Tuple[str, List[Log]]: + """Internal function to format a LaTeX string.""" logs = [] logs.append(Log(level="INFO", file=file, message="Formatting started.")) @@ -636,63 +651,3 @@ def format_latex( logs.append(Log(level="INFO", file=file, message="Formatting complete.")) return new_text, logs - - -def main(): - """Command-line entry point.""" - parser = argparse.ArgumentParser(description="Format LaTeX files") - parser.add_argument("file", help="LaTeX file to format") - parser.add_argument( - "--tabchar", - choices=["space", "tab"], - default="space", - help="Character to use for indentation", - ) - parser.add_argument("--tabsize", type=int, default=4, help="Number of spaces per indent level") - parser.add_argument("--wrap", action="store_true", help="Enable line wrapping") - parser.add_argument("--wraplen", type=int, default=80, help="Maximum line length") - parser.add_argument( - "--wrapmin", type=int, default=40, help="Minimum line length before wrapping" - ) - parser.add_argument( - "--lists", nargs="+", default=[], help="Additional environments to indent as lists" - ) - parser.add_argument("--verbose", "-v", action="count", default=0, help="Increase verbosity") - parser.add_argument("--output", "-o", help="Output file (default: overwrite input)") - - args_parsed = parser.parse_args() - - # Convert command line args to our Args class - args = Args( - tabchar="\t" if args_parsed.tabchar == "tab" else " ", - tabsize=args_parsed.tabsize, - wrap=args_parsed.wrap, - wraplen=args_parsed.wraplen, - wrapmin=args_parsed.wrapmin, - lists=args_parsed.lists, - verbosity=args_parsed.verbose, - ) - - # Read input file - with open(args_parsed.file, "r", encoding="utf-8") as f: - text = f.read() - - # Format the text - formatted_text, logs = format_latex(text, args_parsed.file, args) - - # Print logs if verbose - if args.verbosity > 0: - for log in logs: - if log.linum_new is not None: - print(f"{log.level} {log.file}:{log.linum_new}:{log.linum_old}: {log.message}") - else: - print(f"{log.level} {log.file}: {log.message}") - - # Write output - output_file = args_parsed.output or args_parsed.file - with open(output_file, "w", encoding="utf-8") as f: - f.write(formatted_text) - - -if __name__ == "__main__": - main() diff --git a/texteller/api/inference.py b/texteller/api/inference.py new file mode 100644 index 0000000..68ef8be --- /dev/null +++ b/texteller/api/inference.py @@ -0,0 +1,241 @@ +import re +import time +from collections import Counter +from typing import Literal + +import cv2 +import numpy as np +import torch +from onnxruntime import InferenceSession +from optimum.onnxruntime import ORTModelForVision2Seq +from transformers import GenerationConfig, RobertaTokenizerFast + +from texteller.constants import MAX_TOKEN_SIZE +from texteller.logger import get_logger +from texteller.paddleocr import predict_det, predict_rec +from texteller.types import Bbox, TexTellerModel +from texteller.utils import ( + bbox_merge, + get_device, + mask_img, + readimgs, + remove_style, + slice_from_image, + split_conflict, + transform, + add_newlines, +) + +from .detection import latex_detect +from .format import format_latex +from .katex import to_katex + +_logger = get_logger() + + +def 
img2latex(
+    model: TexTellerModel,
+    tokenizer: RobertaTokenizerFast,
+    images: list[str] | list[np.ndarray],
+    device: torch.device | None = None,
+    out_format: Literal["latex", "katex"] = "latex",
+    keep_style: bool = False,
+    max_tokens: int = MAX_TOKEN_SIZE,
+    num_beams: int = 1,
+    no_repeat_ngram_size: int = 0,
+) -> list[str]:
+    """
+    Convert images to LaTeX or KaTeX formatted strings.
+
+    Args:
+        model: The TexTeller or ORTModelForVision2Seq model instance
+        tokenizer: The tokenizer for the model
+        images: List of image paths or numpy arrays (RGB format)
+        device: The torch device to use (defaults to available GPU or CPU)
+        out_format: Output format, either "latex" or "katex"
+        keep_style: Whether to keep the style of the LaTeX
+        max_tokens: Maximum number of tokens to generate
+        num_beams: Number of beams for beam search
+        no_repeat_ngram_size: Size of n-grams to prevent repetition
+
+    Returns:
+        List of LaTeX or KaTeX strings corresponding to each input image
+
+    Example usage:
+        >>> import torch
+        >>> from texteller import load_model, load_tokenizer, img2latex
+
+        >>> model = load_model(model_dir=None, use_onnx=False)
+        >>> tokenizer = load_tokenizer(tokenizer_dir=None)
+        >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        >>> res = img2latex(model, tokenizer, ["path/to/image.png"], device=device, out_format="katex")
+    """
+    assert isinstance(images, list)
+    assert len(images) > 0
+
+    if device is None:
+        device = get_device()
+
+    if device.type != model.device.type:
+        if isinstance(model, ORTModelForVision2Seq):
+            _logger.warning(
+                f"Onnxruntime device mismatch: detected {str(device)} but model is on {str(model.device)}, using {str(model.device)} instead"
+            )
+        else:
+            model = model.to(device=device)
+
+    if isinstance(images[0], str):
+        images = readimgs(images)
+    else:
+        # already numpy arrays in RGB format, use them as-is
+        assert isinstance(images[0], np.ndarray)
+
+    images = transform(images)
+    pixel_values = torch.stack(images)
+
+    generate_config = GenerationConfig(
+        max_new_tokens=max_tokens,
+        num_beams=num_beams,
+        do_sample=False,
+        pad_token_id=tokenizer.pad_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+        bos_token_id=tokenizer.bos_token_id,
+        no_repeat_ngram_size=no_repeat_ngram_size,
+    )
+    pred = model.generate(
+        pixel_values.to(model.device),
+        generation_config=generate_config,
+    )
+
+    res = tokenizer.batch_decode(pred, skip_special_tokens=True)
+
+    if out_format == "katex":
+        res = [to_katex(r) for r in res]
+
+    if not keep_style:
+        res = [remove_style(r) for r in res]
+
+    res = [format_latex(r) for r in res]
+    res = [add_newlines(r) for r in res]
+    return res
+
+
+def paragraph2md(
+    img_path: str,
+    latexdet_model: InferenceSession,
+    textdet_model: predict_det.TextDetector,
+    textrec_model: predict_rec.TextRecognizer,
+    latexrec_model: TexTellerModel,
+    tokenizer: RobertaTokenizerFast,
+    device: torch.device | None = None,
+    num_beams: int = 1,
+) -> str:
+    """
+    Convert an image that mixes text and formulas into a Markdown string.
+    """
+    img = cv2.imread(img_path)
+    corners = [tuple(img[0, 0]), tuple(img[0, -1]), tuple(img[-1, 0]), tuple(img[-1, -1])]
+    bg_color = np.array(Counter(corners).most_common(1)[0][0])
+
+    start_time = time.time()
+    latex_bboxes = latex_detect(img_path, latexdet_model)
+    end_time = time.time()
+    _logger.info(f"latex_det_model time: {end_time - start_time:.2f}s")
+    latex_bboxes = sorted(latex_bboxes)
+    latex_bboxes = bbox_merge(latex_bboxes)
+    masked_img = mask_img(img, latex_bboxes, bg_color)
+
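+    # Detect text lines on the masked copy so the OCR detector does not pick up
+    # the formula regions that latex_detect has already claimed.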
start_time = time.time() + det_prediction, _ = textdet_model(masked_img) + end_time = time.time() + _logger.info(f"ocr_det_model time: {end_time - start_time:.2f}s") + ocr_bboxes = [ + Bbox( + p[0][0], + p[0][1], + p[3][1] - p[0][1], + p[1][0] - p[0][0], + label="text", + confidence=None, + content=None, + ) + for p in det_prediction + ] + + ocr_bboxes = sorted(ocr_bboxes) + ocr_bboxes = bbox_merge(ocr_bboxes) + ocr_bboxes = split_conflict(ocr_bboxes, latex_bboxes) + ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes)) + + sliced_imgs: list[np.ndarray] = slice_from_image(img, ocr_bboxes) + start_time = time.time() + rec_predictions, _ = textrec_model(sliced_imgs) + end_time = time.time() + _logger.info(f"ocr_rec_model time: {end_time - start_time:.2f}s") + + assert len(rec_predictions) == len(ocr_bboxes) + for content, bbox in zip(rec_predictions, ocr_bboxes): + bbox.content = content[0] + + latex_imgs = [] + for bbox in latex_bboxes: + latex_imgs.append(img[bbox.p.y : bbox.p.y + bbox.h, bbox.p.x : bbox.p.x + bbox.w]) + start_time = time.time() + latex_rec_res = img2latex( + model=latexrec_model, + tokenizer=tokenizer, + images=latex_imgs, + num_beams=num_beams, + out_format="katex", + device=device, + keep_style=False, + ) + end_time = time.time() + _logger.info(f"latex_rec_model time: {end_time - start_time:.2f}s") + + for bbox, content in zip(latex_bboxes, latex_rec_res): + if bbox.label == "embedding": + bbox.content = " $" + content + "$ " + elif bbox.label == "isolated": + bbox.content = "\n\n" + r"$$" + content + r"$$" + "\n\n" + + bboxes = sorted(ocr_bboxes + latex_bboxes) + if bboxes == []: + return "" + + md = "" + prev = Bbox(bboxes[0].p.x, bboxes[0].p.y, -1, -1, label="guard") + for curr in bboxes: + # Add the formula number back to the isolated formula + if prev.label == "isolated" and curr.label == "text" and prev.same_row(curr): + curr.content = curr.content.strip() + if curr.content.startswith("(") and curr.content.endswith(")"): + curr.content = curr.content[1:-1] + + if re.search(r"\\tag\{.*\}$", md[:-4]) is not None: + # in case of multiple tag + md = md[:-5] + f", {curr.content}" + "}" + md[-4:] + else: + md = md[:-4] + f"\\tag{{{curr.content}}}" + md[-4:] + continue + + if not prev.same_row(curr): + md += " " + + if curr.label == "embedding": + # remove the bold effect from inline formulas + curr.content = remove_style(curr.content) + + # change split environment into aligned + curr.content = curr.content.replace(r"\begin{split}", r"\begin{aligned}") + curr.content = curr.content.replace(r"\end{split}", r"\end{aligned}") + + # remove extra spaces (keeping only one) + curr.content = re.sub(r" +", " ", curr.content) + assert curr.content.startswith("$") and curr.content.endswith("$") + curr.content = " $" + curr.content.strip("$") + "$ " + md += curr.content + prev = curr + + return md.strip() diff --git a/texteller/models/ocr_model/utils/to_katex.py b/texteller/api/katex.py similarity index 56% rename from texteller/models/ocr_model/utils/to_katex.py rename to texteller/api/katex.py index 3c54e4d..cc32b0e 100644 --- a/texteller/models/ocr_model/utils/to_katex.py +++ b/texteller/api/katex.py @@ -1,73 +1,10 @@ import re -from .latex_formatter import format_latex +from ..utils.latex import change_all +from .format import format_latex -def change(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r): - result = "" - i = 0 - n = len(input_str) - - while i < n: - if input_str[i : i + len(old_inst)] == old_inst: - # check if the old_inst 
is followed by old_surr_l - start = i + len(old_inst) - else: - result += input_str[i] - i += 1 - continue - - if start < n and input_str[start] == old_surr_l: - # found an old_inst followed by old_surr_l, now look for the matching old_surr_r - count = 1 - j = start + 1 - escaped = False - while j < n and count > 0: - if input_str[j] == '\\' and not escaped: - escaped = True - j += 1 - continue - if input_str[j] == old_surr_r and not escaped: - count -= 1 - if count == 0: - break - elif input_str[j] == old_surr_l and not escaped: - count += 1 - escaped = False - j += 1 - - if count == 0: - assert j < n - assert input_str[start] == old_surr_l - assert input_str[j] == old_surr_r - inner_content = input_str[start + 1 : j] - # Replace the content with new pattern - result += new_inst + new_surr_l + inner_content + new_surr_r - i = j + 1 - continue - else: - assert count >= 1 - assert j == n - print("Warning: unbalanced surrogate pair in input string") - result += new_inst + new_surr_l - i = start + 1 - continue - else: - result += input_str[i:start] - i = start - - if old_inst != new_inst and (old_inst + old_surr_l) in result: - return change(result, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r) - else: - return result - - -def find_substring_positions(string, substring): - positions = [match.start() for match in re.finditer(re.escape(substring), string)] - return positions - - -def rm_dollar_surr(content): +def _rm_dollar_surr(content): pattern = re.compile(r'\\[a-zA-Z]+\$.*?\$|\$.*?\$') matches = pattern.findall(content) @@ -79,19 +16,6 @@ def rm_dollar_surr(content): return content -def change_all(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r): - pos = find_substring_positions(input_str, old_inst + old_surr_l) - res = list(input_str) - for p in pos[::-1]: - res[p:] = list( - change( - ''.join(res[p:]), old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r - ) - ) - res = ''.join(res) - return res - - def to_katex(formula: str) -> str: res = formula # remove mbox surrounding @@ -182,13 +106,13 @@ def to_katex(formula: str) -> str: res = re.sub(r'(\\text\{[^}]*\}\s*){2,}', merge_texts, res) res = res.replace(r'\bf ', '') - res = rm_dollar_surr(res) + res = _rm_dollar_surr(res) # remove extra spaces (keeping only one) res = re.sub(r' +', ' ', res) # format latex res = res.strip() - res, logs = format_latex(res) + res = format_latex(res) return res diff --git a/texteller/api/load.py b/texteller/api/load.py new file mode 100644 index 0000000..de454bd --- /dev/null +++ b/texteller/api/load.py @@ -0,0 +1,66 @@ +from pathlib import Path + +import wget +from onnxruntime import InferenceSession +from transformers import RobertaTokenizerFast + +from texteller.constants import LATEX_DET_MODEL_URL, TEXT_DET_MODEL_URL, TEXT_REC_MODEL_URL +from texteller.globals import Globals +from texteller.logger import get_logger +from texteller.models import TexTeller +from texteller.paddleocr import predict_det, predict_rec +from texteller.paddleocr.utility import parse_args +from texteller.utils import cuda_available, mkdir, resolve_path +from texteller.types import TexTellerModel + +_logger = get_logger(__name__) + + +def load_model(model_dir: str | None = None, use_onnx: bool = False) -> TexTellerModel: + return TexTeller.from_pretrained(model_dir, use_onnx=use_onnx) + + +def load_tokenizer(tokenizer_dir: str | None = None) -> RobertaTokenizerFast: + return TexTeller.get_tokenizer(tokenizer_dir) + + +def load_latexdet_model() -> InferenceSession: 
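+    """Load the LaTeX detection ONNX model, downloading it to the cache directory on first use."""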
+    fpath = _maybe_download(LATEX_DET_MODEL_URL)
+    return InferenceSession(
+        resolve_path(fpath),
+        providers=["CUDAExecutionProvider" if cuda_available() else "CPUExecutionProvider"],
+    )
+
+
+def load_textrec_model() -> predict_rec.TextRecognizer:
+    fpath = _maybe_download(TEXT_REC_MODEL_URL)
+    paddleocr_args = parse_args()
+    paddleocr_args.use_onnx = True
+    paddleocr_args.rec_model_dir = resolve_path(fpath)
+    paddleocr_args.use_gpu = cuda_available()
+    predictor = predict_rec.TextRecognizer(paddleocr_args)
+    return predictor
+
+
+def load_textdet_model() -> predict_det.TextDetector:
+    fpath = _maybe_download(TEXT_DET_MODEL_URL)
+    paddleocr_args = parse_args()
+    paddleocr_args.use_onnx = True
+    paddleocr_args.det_model_dir = resolve_path(fpath)
+    paddleocr_args.use_gpu = cuda_available()
+    predictor = predict_det.TextDetector(paddleocr_args)
+    return predictor
+
+
+def _maybe_download(url: str, dirpath: str | None = None, force: bool = False) -> Path:
+    if dirpath is None:
+        dirpath = Globals().cache_dir
+    mkdir(dirpath)
+
+    fname = Path(url).name
+    fpath = Path(dirpath) / fname
+    if not fpath.exists() or force:
+        _logger.info(f"Downloading {fname} from {url} to {fpath}")
+        wget.download(url, resolve_path(fpath))
+
+    return fpath
diff --git a/texteller/cli/__init__.py b/texteller/cli/__init__.py
new file mode 100644
index 0000000..fe3da87
--- /dev/null
+++ b/texteller/cli/__init__.py
@@ -0,0 +1,23 @@
+"""
+CLI entry point for TexTeller.
+"""
+
+import click
+
+from texteller.cli.commands.inference import inference
+from texteller.cli.commands.launch import launch
+from texteller.cli.commands.web import web
+
+
+@click.group()
+def cli():
+    pass
+
+
+cli.add_command(inference)
+cli.add_command(web)
+cli.add_command(launch)
+
+
+if __name__ == "__main__":
+    cli()
diff --git a/texteller/cli/commands/__init__.py b/texteller/cli/commands/__init__.py
new file mode 100644
index 0000000..0d57cf5
--- /dev/null
+++ b/texteller/cli/commands/__init__.py
@@ -0,0 +1,3 @@
+"""
+CLI commands for TexTeller.
+"""
diff --git a/texteller/cli/commands/inference.py b/texteller/cli/commands/inference.py
new file mode 100644
index 0000000..5b3a34e
--- /dev/null
+++ b/texteller/cli/commands/inference.py
@@ -0,0 +1,51 @@
+"""
+CLI command for formula inference from images.
+"""
+
+import click
+
+from texteller.api import img2latex, load_model, load_tokenizer
+
+
+@click.command()
+@click.argument("image_path", type=click.Path(exists=True, file_okay=True, dir_okay=False))
+@click.option(
+    "--model-path",
+    type=click.Path(exists=True, file_okay=False, dir_okay=True),
+    default=None,
+    help="Path to the model directory; if not provided, the model is loaded from the Hugging Face repo",
+)
+@click.option(
+    "--tokenizer-path",
+    type=click.Path(exists=True, file_okay=False, dir_okay=True),
+    default=None,
+    help="Path to the tokenizer directory; if not provided, the tokenizer is loaded from the Hugging Face repo",
+)
+@click.option(
+    "--output-format",
+    type=click.Choice(["latex", "katex"]),
+    default="katex",
+    help="Output format, either latex or katex",
+)
+@click.option(
+    "--keep-style",
+    is_flag=True,
+    default=False,
+    help="Whether to keep the style of the LaTeX (e.g. bold, italic, etc.)",
+)
+def inference(image_path, model_path, tokenizer_path, output_format, keep_style):
+    """
+    CLI command for formula inference from images.
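+
+    Loads the model and tokenizer (from --model-path / --tokenizer-path or the
+    default Hugging Face checkpoints), runs img2latex, and echoes the prediction.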
+ """ + model = load_model(model_dir=model_path) + tknz = load_tokenizer(tokenizer_dir=tokenizer_path) + + pred = img2latex( + model=model, + tokenizer=tknz, + images=[image_path], + out_format=output_format, + keep_style=keep_style, + )[0] + + click.echo(f"Predicted LaTeX: ```\n{pred}\n```") diff --git a/texteller/cli/commands/launch/__init__.py b/texteller/cli/commands/launch/__init__.py new file mode 100644 index 0000000..13780db --- /dev/null +++ b/texteller/cli/commands/launch/__init__.py @@ -0,0 +1,106 @@ +""" +CLI commands for launching server. +""" + +import sys +import time + +import click +from ray import serve + +from texteller.globals import Globals +from texteller.utils import get_device + + +@click.command() +@click.option( + "-ckpt", + "--checkpoint_dir", + type=click.Path(exists=True, file_okay=False, dir_okay=True), + default=None, + help="Path to the checkpoint directory, if not provided, will use model from huggingface repo", +) +@click.option( + "-tknz", + "--tokenizer_dir", + type=click.Path(exists=True, file_okay=False, dir_okay=True), + default=None, + help="Path to the tokenizer directory, if not provided, will use tokenizer from huggingface repo", +) +@click.option( + "-p", + "--port", + type=int, + default=8000, + help="Port to run the server on", +) +@click.option( + "--num-replicas", + type=int, + default=1, + help="Number of replicas to run the server on", +) +@click.option( + "--ncpu-per-replica", + type=float, + default=1.0, + help="Number of CPUs per replica", +) +@click.option( + "--ngpu-per-replica", + type=float, + default=1.0, + help="Number of GPUs per replica", +) +@click.option( + "--num-beams", + type=int, + default=1, + help="Number of beams to use", +) +@click.option( + "--use-onnx", + is_flag=True, + type=bool, + default=False, + help="Use ONNX runtime", +) +def launch( + checkpoint_dir, + tokenizer_dir, + port, + num_replicas, + ncpu_per_replica, + ngpu_per_replica, + num_beams, + use_onnx, +): + """Launch the api server""" + device = get_device() + if ngpu_per_replica > 0 and not device.type == "cuda": + click.echo( + click.style( + f"Error: --ngpu-per-replica > 0 but detected device is {device.type}", + fg="red", + ) + ) + sys.exit(1) + + Globals().num_replicas = num_replicas + Globals().ncpu_per_replica = ncpu_per_replica + Globals().ngpu_per_replica = ngpu_per_replica + from texteller.cli.commands.launch.server import Ingress, TexTellerServer + + serve.start(http_options={"host": "0.0.0.0", "port": port}) + rec_server = TexTellerServer.bind( + checkpoint_dir=checkpoint_dir, + tokenizer_dir=tokenizer_dir, + use_onnx=use_onnx, + num_beams=num_beams, + ) + ingress = Ingress.bind(rec_server) + + serve.run(ingress, route_prefix="/predict") + + while True: + time.sleep(1) diff --git a/texteller/cli/commands/launch/server.py b/texteller/cli/commands/launch/server.py new file mode 100644 index 0000000..b31d191 --- /dev/null +++ b/texteller/cli/commands/launch/server.py @@ -0,0 +1,69 @@ +import numpy as np +import cv2 + +from starlette.requests import Request +from ray import serve +from ray.serve.handle import DeploymentHandle + +from texteller.api import load_model, load_tokenizer, img2latex +from texteller.utils import get_device +from texteller.globals import Globals +from typing import Literal + + +@serve.deployment( + num_replicas=Globals().num_replicas, + ray_actor_options={ + "num_cpus": Globals().ncpu_per_replica, + "num_gpus": Globals().ngpu_per_replica * 1.0 / 2, + }, +) +class TexTellerServer: + def __init__( + self, + checkpoint_dir: str, 
+ tokenizer_dir: str, + use_onnx: bool = False, + out_format: Literal["latex", "katex"] = "katex", + keep_style: bool = False, + num_beams: int = 1, + ) -> None: + self.model = load_model( + model_dir=checkpoint_dir, + use_onnx=use_onnx, + ) + self.tokenizer = load_tokenizer(tokenizer_dir=tokenizer_dir) + self.num_beams = num_beams + self.out_format = out_format + self.keep_style = keep_style + + if not use_onnx: + self.model = self.model.to(get_device()) + + def predict(self, image_nparray: np.ndarray) -> str: + return img2latex( + model=self.model, + tokenizer=self.tokenizer, + images=[image_nparray], + device=get_device(), + out_format=self.out_format, + keep_style=self.keep_style, + num_beams=self.num_beams, + )[0] + + +@serve.deployment() +class Ingress: + def __init__(self, rec_server: DeploymentHandle) -> None: + self.texteller_server = rec_server + + async def __call__(self, request: Request) -> str: + form = await request.form() + img_rb = await form["img"].read() + + img_nparray = np.frombuffer(img_rb, np.uint8) + img_nparray = cv2.imdecode(img_nparray, cv2.IMREAD_COLOR) + img_nparray = cv2.cvtColor(img_nparray, cv2.COLOR_BGR2RGB) + + pred = await self.texteller_server.predict.remote(img_nparray) + return pred diff --git a/texteller/cli/commands/web/__init__.py b/texteller/cli/commands/web/__init__.py new file mode 100644 index 0000000..9fab94e --- /dev/null +++ b/texteller/cli/commands/web/__init__.py @@ -0,0 +1,9 @@ +import os +import click +from pathlib import Path + + +@click.command() +def web(): + """Launch the web interface for TexTeller.""" + os.system(f"streamlit run {Path(__file__).parent / 'streamlit_demo.py'}") diff --git a/texteller/cli/commands/web/streamlit_demo.py b/texteller/cli/commands/web/streamlit_demo.py new file mode 100644 index 0000000..7361abb --- /dev/null +++ b/texteller/cli/commands/web/streamlit_demo.py @@ -0,0 +1,225 @@ +import base64 +import io +import os +import re +import shutil +import tempfile + +import streamlit as st +from PIL import Image +from streamlit_paste_button import paste_image_button as pbutton + +from texteller.api import ( + img2latex, + load_latexdet_model, + load_model, + load_textdet_model, + load_textrec_model, + load_tokenizer, + paragraph2md, +) +from texteller.cli.commands.web.style import ( + HEADER_HTML, + IMAGE_EMBED_HTML, + IMAGE_INFO_HTML, + SUCCESS_GIF_HTML, +) +from texteller.utils import str2device + +st.set_page_config(page_title="TexTeller", page_icon="🧮") + + +@st.cache_resource +def get_texteller(use_onnx): + return load_model(use_onnx=use_onnx) + + +@st.cache_resource +def get_tokenizer(): + return load_tokenizer() + + +@st.cache_resource +def get_latexdet_model(): + return load_latexdet_model() + + +@st.cache_resource() +def get_textrec_model(): + return load_textrec_model() + + +@st.cache_resource() +def get_textdet_model(): + return load_textdet_model() + + +def get_image_base64(img_file): + buffered = io.BytesIO() + img_file.seek(0) + img = Image.open(img_file) + img.save(buffered, format="PNG") + return base64.b64encode(buffered.getvalue()).decode() + + +def on_file_upload(): + st.session_state["UPLOADED_FILE_CHANGED"] = True + + +def change_side_bar(): + st.session_state["CHANGE_SIDEBAR_FLAG"] = True + + +if "start" not in st.session_state: + st.session_state["start"] = 1 + st.toast("Hooray!", icon="🎉") + +if "UPLOADED_FILE_CHANGED" not in st.session_state: + st.session_state["UPLOADED_FILE_CHANGED"] = False + +if "CHANGE_SIDEBAR_FLAG" not in st.session_state: + st.session_state["CHANGE_SIDEBAR_FLAG"] = 
False
+
+if "INF_MODE" not in st.session_state:
+    st.session_state["INF_MODE"] = "Formula recognition"
+
+
+# ====== ======
+
+with st.sidebar:
+    num_beams = 1
+
+    st.markdown("# 🔨️ Config")
+    st.markdown("")
+
+    inf_mode = st.selectbox(
+        "Inference mode",
+        ("Formula recognition", "Paragraph recognition"),
+        on_change=change_side_bar,
+    )
+
+    num_beams = st.number_input(
+        "Number of beams", min_value=1, max_value=20, step=1, on_change=change_side_bar
+    )
+
+    device = st.radio("device", ("cpu", "cuda", "mps"), on_change=change_side_bar)
+
+    st.markdown("## Speedup")
+    use_onnx = st.toggle("ONNX Runtime")
+
+
+# ====== ======
+
+
+# ====== ======
+
+latexrec_model = get_texteller(use_onnx)
+tokenizer = get_tokenizer()
+
+if inf_mode == "Paragraph recognition":
+    latexdet_model = get_latexdet_model()
+    textrec_model = get_textrec_model()
+    textdet_model = get_textdet_model()
+
+st.markdown(HEADER_HTML, unsafe_allow_html=True)
+
+uploaded_file = st.file_uploader(" ", type=["jpg", "png"], on_change=on_file_upload)
+
+paste_result = pbutton(
+    label="📋 Paste an image",
+    background_color="#5BBCFF",
+    hover_background_color="#3498db",
+)
+st.write("")
+
+if st.session_state["CHANGE_SIDEBAR_FLAG"] is True:
+    st.session_state["CHANGE_SIDEBAR_FLAG"] = False
+elif uploaded_file or paste_result.image_data is not None:
+    if st.session_state["UPLOADED_FILE_CHANGED"] is False and paste_result.image_data is not None:
+        uploaded_file = io.BytesIO()
+        paste_result.image_data.save(uploaded_file, format="PNG")
+        uploaded_file.seek(0)
+
+    if st.session_state["UPLOADED_FILE_CHANGED"] is True:
+        st.session_state["UPLOADED_FILE_CHANGED"] = False
+
+    img = Image.open(uploaded_file)
+
+    temp_dir = tempfile.mkdtemp()
+    png_fpath = os.path.join(temp_dir, "image.png")
+    img.save(png_fpath, "PNG")
+
+    with st.container(height=300):
+        img_base64 = get_image_base64(uploaded_file)
+
+        st.markdown(
+            IMAGE_EMBED_HTML.format(img_base64=img_base64),
+            unsafe_allow_html=True,
+        )
+
+        st.markdown(
+            IMAGE_INFO_HTML.format(img_height=img.height, img_width=img.width),
+            unsafe_allow_html=True,
+        )
+
+    st.write("")
+
+    with st.spinner("Predicting..."):
+        if inf_mode == "Formula recognition":
+            pred = img2latex(
+                model=latexrec_model,
+                tokenizer=tokenizer,
+                images=[png_fpath],
+                device=str2device(device),
+                out_format="katex",
+                num_beams=num_beams,
+                keep_style=False,
+            )[0]
+        else:
+            pred = paragraph2md(
+                img_path=png_fpath,
+                latexdet_model=latexdet_model,
+                textdet_model=textdet_model,
+                textrec_model=textrec_model,
+                latexrec_model=latexrec_model,
+                tokenizer=tokenizer,
+                device=str2device(device),
+                num_beams=num_beams,
+            )
+
+    st.success("Completed!", icon="✅")
+    # st.markdown(SUCCESS_GIF_HTML, unsafe_allow_html=True)
+    # st.text_area("Predicted LaTeX", pred, height=150)
+    if inf_mode == "Formula recognition":
+        st.code(pred, language="latex")
+    elif inf_mode == "Paragraph recognition":
+        st.code(pred, language="markdown")
+    else:
+        raise ValueError(f"Invalid inference mode: {inf_mode}")
+
+    if inf_mode == "Formula recognition":
+        st.latex(pred)
+    elif inf_mode == "Paragraph recognition":
+        mixed_res = re.split(r"(\$\$.*?\$\$)", pred, flags=re.DOTALL)
+        for text in mixed_res:
+            if text.startswith("$$") and text.endswith("$$"):
+                st.latex(text.strip("$$"))
+            else:
+                st.markdown(text)
+
+    st.write("")
+    st.write("")
+
+    with st.expander(":star2: :gray[Tips for better results]"):
+        st.markdown("""
+        * :mag_right: Use a clear and high-resolution image.
+        * :scissors: Crop images as accurately as possible.
+        * :jigsaw: Split large multi-line formulas into smaller ones.
+        * :page_facing_up: Use images with **white background and black text** as much as possible.
+        * :book: Use a font with good readability.
+        """)
+    shutil.rmtree(temp_dir)
+
+    paste_result.image_data = None
+
+# ====== ======
diff --git a/texteller/cli/commands/web/style.py b/texteller/cli/commands/web/style.py
new file mode 100644
index 0000000..7899167
--- /dev/null
+++ b/texteller/cli/commands/web/style.py
@@ -0,0 +1,55 @@
+from texteller.utils import lines_dedent
+
+
+HEADER_HTML = lines_dedent("""

+ + 𝚃𝚎𝚡𝚃𝚎𝚕𝚕𝚎𝚛 + +

+ """) + +SUCCESS_GIF_HTML = lines_dedent(""" +

+ + + +

+ """) + +FAIL_GIF_HTML = lines_dedent(""" +

+ + + +

+ """) + +IMAGE_EMBED_HTML = lines_dedent(""" + +
+ Input image +
+ """) + +IMAGE_INFO_HTML = lines_dedent(""" + +
+

Input image ({img_height}✖️{img_width})

+
+ """) diff --git a/texteller/client_demo.py b/texteller/client_demo.py deleted file mode 100644 index bfd8d95..0000000 --- a/texteller/client_demo.py +++ /dev/null @@ -1,12 +0,0 @@ -import requests - -rec_server_url = "http://127.0.0.1:8000/frec" -det_server_url = "http://127.0.0.1:8000/fdet" - -img_path = "/your/image/path/" -with open(img_path, 'rb') as img: - files = {'img': img} - response = requests.post(rec_server_url, files=files) - # response = requests.post(det_server_url, files=files) - -print(response.text) diff --git a/texteller/models/globals.py b/texteller/constants.py similarity index 59% rename from texteller/models/globals.py rename to texteller/constants.py index 8754d67..e2036e5 100644 --- a/texteller/models/globals.py +++ b/texteller/constants.py @@ -21,3 +21,13 @@ MIN_RESIZE_RATIO = 0.75 # Minimum height and width for input image for TexTeller MIN_HEIGHT = 12 MIN_WIDTH = 30 + +LATEX_DET_MODEL_URL = ( + "https://huggingface.co/TonyLee1256/texteller_det/resolve/main/rtdetr_r50vd_6x_coco.onnx" +) +TEXT_REC_MODEL_URL = ( + "https://huggingface.co/OleehyO/paddleocrv4.onnx/resolve/main/ch_PP-OCRv4_server_rec.onnx" +) +TEXT_DET_MODEL_URL = ( + "https://huggingface.co/OleehyO/paddleocrv4.onnx/resolve/main/ch_PP-OCRv4_det.onnx" +) diff --git a/texteller/globals.py b/texteller/globals.py new file mode 100644 index 0000000..a30cdb4 --- /dev/null +++ b/texteller/globals.py @@ -0,0 +1,41 @@ +import logging +from pathlib import Path + + +class Globals: + """ + Singleton class for managing global variables with predefined and dynamic attributes. + + Usage Example: + >>> # 1. Access predefined variable (with default value) + >>> print(Globals().repo_name) # Output: OleehyO/TexTeller + + >>> # 2. Modify predefined variable + >>> Globals().repo_name = "NewRepo/NewProject" + >>> print(Globals().repo_name) # Output: NewRepo/NewProject + + >>> # 3. Dynamically add new variable + >>> Globals().new_var = "hello" + >>> print(Globals().new_var) # Output: hello + + >>> # 4. 
View all variables + >>> print(Globals()) # Output: + """ + + _instance = None + _initialized = False + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self): + if not self._initialized: + self.repo_name = "OleehyO/TexTeller" + self.logging_level = logging.INFO + self.cache_dir = Path("~/.cache/texteller").expanduser().resolve() + self.__class__._initialized = True + + def __repr__(self): + return f"" diff --git a/texteller/infer_det.py b/texteller/infer_det.py deleted file mode 100644 index 2250ae3..0000000 --- a/texteller/infer_det.py +++ /dev/null @@ -1,96 +0,0 @@ -import os -import argparse -import glob -import subprocess - -import onnxruntime -from pathlib import Path - -from models.det_model.inference import PredictConfig, predict_image - - -parser = argparse.ArgumentParser(description=__doc__) -parser.add_argument( - "--infer_cfg", type=str, help="infer_cfg.yml", default="./models/det_model/model/infer_cfg.yml" -) -parser.add_argument( - '--onnx_file', - type=str, - help="onnx model file path", - default="./models/det_model/model/rtdetr_r50vd_6x_coco.onnx", -) -parser.add_argument("--image_dir", type=str, default='./testImgs') -parser.add_argument("--image_file", type=str) -parser.add_argument("--imgsave_dir", type=str, default="./detect_results") -parser.add_argument( - '--use_gpu', action='store_true', help='Whether to use GPU for inference', default=True -) - - -def get_test_images(infer_dir, infer_img): - """ - Get image path list in TEST mode - """ - assert ( - infer_img is not None or infer_dir is not None - ), "--image_file or --image_dir should be set" - assert infer_img is None or os.path.isfile(infer_img), "{} is not a file".format(infer_img) - assert infer_dir is None or os.path.isdir(infer_dir), "{} is not a directory".format(infer_dir) - - # infer_img has a higher priority - if infer_img and os.path.isfile(infer_img): - return [infer_img] - - images = set() - infer_dir = os.path.abspath(infer_dir) - assert os.path.isdir(infer_dir), "infer_dir {} is not a directory".format(infer_dir) - exts = ['jpg', 'jpeg', 'png', 'bmp'] - exts += [ext.upper() for ext in exts] - for ext in exts: - images.update(glob.glob('{}/*.{}'.format(infer_dir, ext))) - images = list(images) - - assert len(images) > 0, "no image found in {}".format(infer_dir) - print("Found {} inference images in total.".format(len(images))) - - return images - - -def download_file(url, filename): - print(f"Downloading {filename}...") - subprocess.run(["wget", "-q", "--show-progress", "-O", filename, url], check=True) - print("Download complete.") - - -if __name__ == '__main__': - cur_path = os.getcwd() - script_dirpath = Path(__file__).resolve().parent - os.chdir(script_dirpath) - - FLAGS = parser.parse_args() - - if not os.path.exists(FLAGS.infer_cfg): - infer_cfg_url = "https://huggingface.co/TonyLee1256/texteller_det/resolve/main/infer_cfg.yml?download=true" - download_file(infer_cfg_url, FLAGS.infer_cfg) - - if not os.path.exists(FLAGS.onnx_file): - onnx_file_url = "https://huggingface.co/TonyLee1256/texteller_det/resolve/main/rtdetr_r50vd_6x_coco.onnx?download=true" - download_file(onnx_file_url, FLAGS.onnx_file) - - # load image list - img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file) - - if FLAGS.use_gpu: - predictor = onnxruntime.InferenceSession( - FLAGS.onnx_file, providers=['CUDAExecutionProvider'] - ) - else: - predictor = onnxruntime.InferenceSession( - FLAGS.onnx_file, providers=['CPUExecutionProvider'] - ) - # 
load infer config - infer_config = PredictConfig(FLAGS.infer_cfg) - - predict_image(FLAGS.imgsave_dir, infer_config, predictor, img_list) - - os.chdir(cur_path) diff --git a/texteller/inference.py b/texteller/inference.py deleted file mode 100644 index f6cfe5b..0000000 --- a/texteller/inference.py +++ /dev/null @@ -1,81 +0,0 @@ -import os -import argparse -import cv2 as cv - -from pathlib import Path -from onnxruntime import InferenceSession -from models.thrid_party.paddleocr.infer import predict_det, predict_rec -from models.thrid_party.paddleocr.infer import utility - -from models.utils import mix_inference -from models.ocr_model.utils.to_katex import to_katex -from models.ocr_model.utils.inference import inference as latex_inference - -from models.ocr_model.model.TexTeller import TexTeller -from models.det_model.inference import PredictConfig - - -if __name__ == '__main__': - os.chdir(Path(__file__).resolve().parent) - parser = argparse.ArgumentParser() - parser.add_argument('-img', type=str, required=True, help='path to the input image') - parser.add_argument( - '--inference-mode', - type=str, - default='cpu', - help='Inference mode, select one of cpu, cuda, or mps', - ) - parser.add_argument( - '--num-beam', type=int, default=1, help='number of beam search for decoding' - ) - parser.add_argument('-mix', action='store_true', help='use mix mode') - - args = parser.parse_args() - - # You can use your own checkpoint and tokenizer path. - print('Loading model and tokenizer...') - latex_rec_model = TexTeller.from_pretrained() - tokenizer = TexTeller.get_tokenizer() - print('Model and tokenizer loaded.') - - img_path = args.img - img = cv.imread(img_path) - print('Inference...') - if not args.mix: - res = latex_inference(latex_rec_model, tokenizer, [img], args.inference_mode, args.num_beam) - res = to_katex(res[0]) - print(res) - else: - infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml") - latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx") - - use_gpu = args.inference_mode == 'cuda' - SIZE_LIMIT = 20 * 1024 * 1024 - det_model_dir = "./models/thrid_party/paddleocr/checkpoints/det/default_model.onnx" - rec_model_dir = "./models/thrid_party/paddleocr/checkpoints/rec/default_model.onnx" - # The CPU inference of the detection model will be faster than the GPU inference (in onnxruntime) - det_use_gpu = False - rec_use_gpu = use_gpu and not (os.path.getsize(rec_model_dir) < SIZE_LIMIT) - - paddleocr_args = utility.parse_args() - paddleocr_args.use_onnx = True - paddleocr_args.det_model_dir = det_model_dir - paddleocr_args.rec_model_dir = rec_model_dir - - paddleocr_args.use_gpu = det_use_gpu - detector = predict_det.TextDetector(paddleocr_args) - paddleocr_args.use_gpu = rec_use_gpu - recognizer = predict_rec.TextRecognizer(paddleocr_args) - - lang_ocr_models = [detector, recognizer] - latex_rec_models = [latex_rec_model, tokenizer] - res = mix_inference( - img_path, - infer_config, - latex_det_model, - lang_ocr_models, - latex_rec_models, - args.inference_mode, - args.num_beam, - ) - print(res) diff --git a/texteller/logger.py b/texteller/logger.py new file mode 100644 index 0000000..52a6e92 --- /dev/null +++ b/texteller/logger.py @@ -0,0 +1,96 @@ +import inspect +import logging +import os +from datetime import datetime +from logging import Logger + +import colorama +from colorama import Fore, Style + +from texteller.globals import Globals + +# Initialize colorama for colored console output +colorama.init(autoreset=True) + + +TEMPLATE = 
"%(asctime)s - %(name)s - %(levelname)s - %(message)s" + + +class ColoredFormatter(logging.Formatter): + """Custom formatter to add colors based on log level.""" + + FORMATS = { # noqa: E501 + logging.DEBUG: Fore.LIGHTBLACK_EX + TEMPLATE + Style.RESET_ALL, + logging.INFO: Fore.WHITE + TEMPLATE + Style.RESET_ALL, + logging.WARNING: Fore.YELLOW + TEMPLATE + Style.RESET_ALL, + logging.ERROR: Fore.RED + TEMPLATE + Style.RESET_ALL, + logging.CRITICAL: Fore.RED + Style.BRIGHT + TEMPLATE + Style.RESET_ALL, + } # noqa: E501 + + def format(self, record): + log_fmt = self.FORMATS.get(record.levelno, self.FORMATS[logging.INFO]) + formatter = logging.Formatter(log_fmt, datefmt="%Y-%m-%d %H:%M:%S") + return formatter.format(record) + + +def get_logger(name: str | None = None, use_file_handler: bool = False) -> Logger: + """ + Creates and configures a logger with the caller's module name (if provided) or the first two modules. + If the module name is too long, it takes the first two modules. + + Args: + name (str, optional): Custom logger name. If None, derives from caller's module. + use_file_handler (bool, optional): Whether to use a file handler. Defaults to False. + + Returns: + Logger: Configured logger with colored console output and file handler. + """ + # If name is not provided, derive it from the caller's module + if name is None: + # Get the caller's stack frame + frame = inspect.stack()[1] + module = inspect.getmodule(frame[0]) + if module and module.__name__: + module_name = module.__name__ + # Split module name and take first two components if too long + parts = module_name.split(".") + if len(parts) > 2: + name = ".".join(parts[:2]) + else: + name = module_name + else: + name = "root" + + # Create or get logger + logger = logging.getLogger(name) + + # Prevent duplicate handlers + if logger.handlers: + return logger + + # Set logger level + logger.setLevel(Globals().logging_level) + + # Create console handler with colored formatter + console_handler = logging.StreamHandler() + console_handler.setLevel(Globals().logging_level) + console_formatter = ColoredFormatter() + console_handler.setFormatter(console_formatter) + logger.addHandler(console_handler) + + # Create file handler + if use_file_handler: + log_dir = "logs" + os.makedirs(log_dir, exist_ok=True) + log_file = os.path.join(log_dir, f"{datetime.now().strftime('%Y%m%d')}.log") + file_handler = logging.FileHandler(log_file) + file_handler.setLevel(Globals().logging_level) + # File formatter (no colors) + file_formatter = logging.Formatter(TEMPLATE, datefmt="%Y-%m-%d %H:%M:%S") + file_handler.setFormatter(file_formatter) + logger.addHandler(file_handler) + + # Prevent logger from propagating to root logger + logger.propagate = False + + return logger diff --git a/texteller/models/__init__.py b/texteller/models/__init__.py new file mode 100644 index 0000000..1beda4e --- /dev/null +++ b/texteller/models/__init__.py @@ -0,0 +1,3 @@ +from .texteller import TexTeller + +__all__ = ['TexTeller'] diff --git a/texteller/models/det_model/inference.py b/texteller/models/det_model/inference.py deleted file mode 100644 index c866ae7..0000000 --- a/texteller/models/det_model/inference.py +++ /dev/null @@ -1,226 +0,0 @@ -import os -import time -import yaml -import numpy as np -import cv2 - -from tqdm import tqdm -from typing import List -from .preprocess import Compose -from .Bbox import Bbox - - -# Global dictionary -SUPPORT_MODELS = { - 'YOLO', - 'PPYOLOE', - 'RCNN', - 'SSD', - 'Face', - 'FCOS', - 'SOLOv2', - 'TTFNet', - 'S2ANet', - 'JDE', - 
'FairMOT', - 'DeepSORT', - 'GFL', - 'PicoDet', - 'CenterNet', - 'TOOD', - 'RetinaNet', - 'StrongBaseline', - 'STGCN', - 'YOLOX', - 'HRNet', - 'DETR', -} - - -class PredictConfig(object): - """set config of preprocess, postprocess and visualize - Args: - infer_config (str): path of infer_cfg.yml - """ - - def __init__(self, infer_config): - # parsing Yaml config for Preprocess - with open(infer_config) as f: - yml_conf = yaml.safe_load(f) - self.check_model(yml_conf) - self.arch = yml_conf['arch'] - self.preprocess_infos = yml_conf['Preprocess'] - self.min_subgraph_size = yml_conf['min_subgraph_size'] - self.label_list = yml_conf['label_list'] - self.use_dynamic_shape = yml_conf['use_dynamic_shape'] - self.draw_threshold = yml_conf.get("draw_threshold", 0.5) - self.mask = yml_conf.get("mask", False) - self.tracker = yml_conf.get("tracker", None) - self.nms = yml_conf.get("NMS", None) - self.fpn_stride = yml_conf.get("fpn_stride", None) - - color_pool = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)] - self.colors = { - label: color_pool[i % len(color_pool)] for i, label in enumerate(self.label_list) - } - - if self.arch == 'RCNN' and yml_conf.get('export_onnx', False): - print('The RCNN export model is used for ONNX and it only supports batch_size = 1') - self.print_config() - - def check_model(self, yml_conf): - """ - Raises: - ValueError: loaded model not in supported model type - """ - for support_model in SUPPORT_MODELS: - if support_model in yml_conf['arch']: - return True - raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf['arch'], SUPPORT_MODELS)) - - def print_config(self): - print('----------- Model Configuration -----------') - print('%s: %s' % ('Model Arch', self.arch)) - print('%s: ' % ('Transform Order')) - for op_info in self.preprocess_infos: - print('--%s: %s' % ('transform op', op_info['type'])) - print('--------------------------------------------') - - -def draw_bbox(image, outputs, infer_config): - for output in outputs: - cls_id, score, xmin, ymin, xmax, ymax = output - if score > infer_config.draw_threshold: - label = infer_config.label_list[int(cls_id)] - color = infer_config.colors[label] - cv2.rectangle(image, (int(xmin), int(ymin)), (int(xmax), int(ymax)), color, 2) - cv2.putText( - image, - "{}: {:.2f}".format(label, score), - (int(xmin), int(ymin - 5)), - cv2.FONT_HERSHEY_SIMPLEX, - 0.5, - color, - 2, - ) - return image - - -def predict_image(imgsave_dir, infer_config, predictor, img_list): - # load preprocess transforms - transforms = Compose(infer_config.preprocess_infos) - errImgList = [] - - # Check and create subimg_save_dir if not exist - subimg_save_dir = os.path.join(imgsave_dir, 'subimages') - os.makedirs(subimg_save_dir, exist_ok=True) - - first_image_skipped = False - total_time = 0 - num_images = 0 - # predict image - for img_path in tqdm(img_list): - img = cv2.imread(img_path) - if img is None: - print(f"Warning: Could not read image {img_path}. 
Skipping...") - errImgList.append(img_path) - continue - - inputs = transforms(img_path) - inputs_name = [var.name for var in predictor.get_inputs()] - inputs = {k: inputs[k][None,] for k in inputs_name} - - # Start timing - start_time = time.time() - - outputs = predictor.run(output_names=None, input_feed=inputs) - - # Stop timing - end_time = time.time() - inference_time = end_time - start_time - if not first_image_skipped: - first_image_skipped = True - else: - total_time += inference_time - num_images += 1 - print( - f"ONNXRuntime predict time for {os.path.basename(img_path)}: {inference_time:.4f} seconds" - ) - - print("ONNXRuntime predict: ") - if infer_config.arch in ["HRNet"]: - print(np.array(outputs[0])) - else: - bboxes = np.array(outputs[0]) - for bbox in bboxes: - if bbox[0] > -1 and bbox[1] > infer_config.draw_threshold: - print(f"{int(bbox[0])} {bbox[1]} " f"{bbox[2]} {bbox[3]} {bbox[4]} {bbox[5]}") - - # Save the subimages (crop from the original image) - subimg_counter = 1 - for output in np.array(outputs[0]): - cls_id, score, xmin, ymin, xmax, ymax = output - if score > infer_config.draw_threshold: - label = infer_config.label_list[int(cls_id)] - subimg = img[int(max(ymin, 0)) : int(ymax), int(max(xmin, 0)) : int(xmax)] - if len(subimg) == 0: - continue - - subimg_filename = f"{os.path.splitext(os.path.basename(img_path))[0]}_{label}_{xmin:.2f}_{ymin:.2f}_{xmax:.2f}_{ymax:.2f}.jpg" - subimg_path = os.path.join(subimg_save_dir, subimg_filename) - cv2.imwrite(subimg_path, subimg) - subimg_counter += 1 - - # Draw bounding boxes and save the image with bounding boxes - img_with_mask = img.copy() - for output in np.array(outputs[0]): - cls_id, score, xmin, ymin, xmax, ymax = output - if score > infer_config.draw_threshold: - cv2.rectangle( - img_with_mask, - (int(xmin), int(ymin)), - (int(xmax), int(ymax)), - (255, 255, 255), - -1, - ) # 盖白 - - img_with_bbox = draw_bbox(img, np.array(outputs[0]), infer_config) - - output_dir = imgsave_dir - os.makedirs(output_dir, exist_ok=True) - draw_box_dir = os.path.join(output_dir, 'draw_box') - mask_white_dir = os.path.join(output_dir, 'mask_white') - os.makedirs(draw_box_dir, exist_ok=True) - os.makedirs(mask_white_dir, exist_ok=True) - - output_file_mask = os.path.join(mask_white_dir, os.path.basename(img_path)) - output_file_bbox = os.path.join(draw_box_dir, os.path.basename(img_path)) - cv2.imwrite(output_file_mask, img_with_mask) - cv2.imwrite(output_file_bbox, img_with_bbox) - - avg_time_per_image = total_time / num_images if num_images > 0 else 0 - print(f"Total inference time for {num_images} images: {total_time:.4f} seconds") - print(f"Average time per image: {avg_time_per_image:.4f} seconds") - print("ErrorImgs:") - print(errImgList) - - -def predict(img_path: str, predictor, infer_config) -> List[Bbox]: - transforms = Compose(infer_config.preprocess_infos) - inputs = transforms(img_path) - inputs_name = [var.name for var in predictor.get_inputs()] - inputs = {k: inputs[k][None,] for k in inputs_name} - - outputs = predictor.run(output_names=None, input_feed=inputs)[0] - res = [] - for output in outputs: - cls_name = infer_config.label_list[int(output[0])] - score = output[1] - xmin = int(max(output[2], 0)) - ymin = int(max(output[3], 0)) - xmax = int(output[4]) - ymax = int(output[5]) - if score > infer_config.draw_threshold: - res.append(Bbox(xmin, ymin, ymax - ymin, xmax - xmin, cls_name, score)) - - return res diff --git a/texteller/models/det_model/model/infer_cfg.yml b/texteller/models/det_model/model/infer_cfg.yml 
deleted file mode 100644 index 09e6603..0000000 --- a/texteller/models/det_model/model/infer_cfg.yml +++ /dev/null @@ -1,27 +0,0 @@ -mode: paddle -draw_threshold: 0.5 -metric: COCO -use_dynamic_shape: false -arch: DETR -min_subgraph_size: 3 -Preprocess: -- interp: 2 - keep_ratio: false - target_size: - - 1600 - - 1600 - type: Resize -- mean: - - 0.0 - - 0.0 - - 0.0 - norm_type: none - std: - - 1.0 - - 1.0 - - 1.0 - type: NormalizeImage -- type: Permute -label_list: -- isolated -- embedding diff --git a/texteller/models/det_model/preprocess.py b/texteller/models/det_model/preprocess.py deleted file mode 100644 index 935a2ae..0000000 --- a/texteller/models/det_model/preprocess.py +++ /dev/null @@ -1,485 +0,0 @@ -import numpy as np -import cv2 -import copy - - -def decode_image(img_path): - if isinstance(img_path, str): - with open(img_path, 'rb') as f: - im_read = f.read() - data = np.frombuffer(im_read, dtype='uint8') - else: - assert isinstance(img_path, np.ndarray) - data = img_path - - im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode - im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) - img_info = { - "im_shape": np.array(im.shape[:2], dtype=np.float32), - "scale_factor": np.array([1.0, 1.0], dtype=np.float32), - } - return im, img_info - - -class Resize(object): - """resize image by target_size and max_size - Args: - target_size (int): the target size of image - keep_ratio (bool): whether keep_ratio or not, default true - interp (int): method of resize - """ - - def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR): - if isinstance(target_size, int): - target_size = [target_size, target_size] - self.target_size = target_size - self.keep_ratio = keep_ratio - self.interp = interp - - def __call__(self, im, im_info): - """ - Args: - im (np.ndarray): image (np.ndarray) - im_info (dict): info of image - Returns: - im (np.ndarray): processed image (np.ndarray) - im_info (dict): info of processed image - """ - assert len(self.target_size) == 2 - assert self.target_size[0] > 0 and self.target_size[1] > 0 - im_channel = im.shape[2] - im_scale_y, im_scale_x = self.generate_scale(im) - im = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=self.interp) - im_info['im_shape'] = np.array(im.shape[:2]).astype('float32') - im_info['scale_factor'] = np.array([im_scale_y, im_scale_x]).astype('float32') - return im, im_info - - def generate_scale(self, im): - """ - Args: - im (np.ndarray): image (np.ndarray) - Returns: - im_scale_x: the resize ratio of X - im_scale_y: the resize ratio of Y - """ - origin_shape = im.shape[:2] - im_c = im.shape[2] - if self.keep_ratio: - im_size_min = np.min(origin_shape) - im_size_max = np.max(origin_shape) - target_size_min = np.min(self.target_size) - target_size_max = np.max(self.target_size) - im_scale = float(target_size_min) / float(im_size_min) - if np.round(im_scale * im_size_max) > target_size_max: - im_scale = float(target_size_max) / float(im_size_max) - im_scale_x = im_scale - im_scale_y = im_scale - else: - resize_h, resize_w = self.target_size - im_scale_y = resize_h / float(origin_shape[0]) - im_scale_x = resize_w / float(origin_shape[1]) - return im_scale_y, im_scale_x - - -class NormalizeImage(object): - """normalize image - Args: - mean (list): im - mean - std (list): im / std - is_scale (bool): whether need im / 255 - norm_type (str): type in ['mean_std', 'none'] - """ - - def __init__(self, mean, std, is_scale=True, norm_type='mean_std'): - self.mean = mean - self.std = std - self.is_scale = is_scale - 
self.norm_type = norm_type - - def __call__(self, im, im_info): - """ - Args: - im (np.ndarray): image (np.ndarray) - im_info (dict): info of image - Returns: - im (np.ndarray): processed image (np.ndarray) - im_info (dict): info of processed image - """ - im = im.astype(np.float32, copy=False) - if self.is_scale: - scale = 1.0 / 255.0 - im *= scale - - if self.norm_type == 'mean_std': - mean = np.array(self.mean)[np.newaxis, np.newaxis, :] - std = np.array(self.std)[np.newaxis, np.newaxis, :] - im -= mean - im /= std - return im, im_info - - -class Permute(object): - """permute image - Args: - to_bgr (bool): whether convert RGB to BGR - channel_first (bool): whether convert HWC to CHW - """ - - def __init__( - self, - ): - super(Permute, self).__init__() - - def __call__(self, im, im_info): - """ - Args: - im (np.ndarray): image (np.ndarray) - im_info (dict): info of image - Returns: - im (np.ndarray): processed image (np.ndarray) - im_info (dict): info of processed image - """ - im = im.transpose((2, 0, 1)).copy() - return im, im_info - - -class PadStride(object): - """padding image for model with FPN, instead PadBatch(pad_to_stride) in original config - Args: - stride (bool): model with FPN need image shape % stride == 0 - """ - - def __init__(self, stride=0): - self.coarsest_stride = stride - - def __call__(self, im, im_info): - """ - Args: - im (np.ndarray): image (np.ndarray) - im_info (dict): info of image - Returns: - im (np.ndarray): processed image (np.ndarray) - im_info (dict): info of processed image - """ - coarsest_stride = self.coarsest_stride - if coarsest_stride <= 0: - return im, im_info - im_c, im_h, im_w = im.shape - pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride) - pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride) - padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32) - padding_im[:, :im_h, :im_w] = im - return padding_im, im_info - - -class LetterBoxResize(object): - def __init__(self, target_size): - """ - Resize image to target size, convert normalized xywh to pixel xyxy - format ([x_center, y_center, width, height] -> [x0, y0, x1, y1]). - Args: - target_size (int|list): image target size. 
- """ - super(LetterBoxResize, self).__init__() - if isinstance(target_size, int): - target_size = [target_size, target_size] - self.target_size = target_size - - def letterbox(self, img, height, width, color=(127.5, 127.5, 127.5)): - # letterbox: resize a rectangular image to a padded rectangular - shape = img.shape[:2] # [height, width] - ratio_h = float(height) / shape[0] - ratio_w = float(width) / shape[1] - ratio = min(ratio_h, ratio_w) - new_shape = (round(shape[1] * ratio), round(shape[0] * ratio)) # [width, height] - padw = (width - new_shape[0]) / 2 - padh = (height - new_shape[1]) / 2 - top, bottom = round(padh - 0.1), round(padh + 0.1) - left, right = round(padw - 0.1), round(padw + 0.1) - - img = cv2.resize(img, new_shape, interpolation=cv2.INTER_AREA) # resized, no border - img = cv2.copyMakeBorder( - img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color - ) # padded rectangular - return img, ratio, padw, padh - - def __call__(self, im, im_info): - """ - Args: - im (np.ndarray): image (np.ndarray) - im_info (dict): info of image - Returns: - im (np.ndarray): processed image (np.ndarray) - im_info (dict): info of processed image - """ - assert len(self.target_size) == 2 - assert self.target_size[0] > 0 and self.target_size[1] > 0 - height, width = self.target_size - h, w = im.shape[:2] - im, ratio, padw, padh = self.letterbox(im, height=height, width=width) - - new_shape = [round(h * ratio), round(w * ratio)] - im_info['im_shape'] = np.array(new_shape, dtype=np.float32) - im_info['scale_factor'] = np.array([ratio, ratio], dtype=np.float32) - return im, im_info - - -class Pad(object): - def __init__(self, size, fill_value=[114.0, 114.0, 114.0]): - """ - Pad image to a specified size. - Args: - size (list[int]): image target size - fill_value (list[float]): rgb value of pad area, default (114.0, 114.0, 114.0) - """ - super(Pad, self).__init__() - if isinstance(size, int): - size = [size, size] - self.size = size - self.fill_value = fill_value - - def __call__(self, im, im_info): - im_h, im_w = im.shape[:2] - h, w = self.size - if h == im_h and w == im_w: - im = im.astype(np.float32) - return im, im_info - - canvas = np.ones((h, w, 3), dtype=np.float32) - canvas *= np.array(self.fill_value, dtype=np.float32) - canvas[0:im_h, 0:im_w, :] = im.astype(np.float32) - im = canvas - return im, im_info - - -def rotate_point(pt, angle_rad): - """Rotate a point by an angle. - - Args: - pt (list[float]): 2 dimensional point to be rotated - angle_rad (float): rotation angle by radian - - Returns: - list[float]: Rotated point. - """ - assert len(pt) == 2 - sn, cs = np.sin(angle_rad), np.cos(angle_rad) - new_x = pt[0] * cs - pt[1] * sn - new_y = pt[0] * sn + pt[1] * cs - rotated_pt = [new_x, new_y] - - return rotated_pt - - -def _get_3rd_point(a, b): - """To calculate the affine matrix, three pairs of points are required. This - function is used to get the 3rd point, given 2D points a & b. - - The 3rd point is defined by rotating vector `a - b` by 90 degrees - anticlockwise, using b as the rotation center. - - Args: - a (np.ndarray): point(x,y) - b (np.ndarray): point(x,y) - - Returns: - np.ndarray: The 3rd point. - """ - assert len(a) == 2 - assert len(b) == 2 - direction = a - b - third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32) - - return third_pt - - -def get_affine_transform(center, input_size, rot, output_size, shift=(0.0, 0.0), inv=False): - """Get the affine transform matrix, given the center/scale/rot/output_size. 
- - Args: - center (np.ndarray[2, ]): Center of the bounding box (x, y). - scale (np.ndarray[2, ]): Scale of the bounding box - wrt [width, height]. - rot (float): Rotation angle (degree). - output_size (np.ndarray[2, ]): Size of the destination heatmaps. - shift (0-100%): Shift translation ratio wrt the width/height. - Default (0., 0.). - inv (bool): Option to inverse the affine transform direction. - (inv=False: src->dst or inv=True: dst->src) - - Returns: - np.ndarray: The transform matrix. - """ - assert len(center) == 2 - assert len(output_size) == 2 - assert len(shift) == 2 - if not isinstance(input_size, (np.ndarray, list)): - input_size = np.array([input_size, input_size], dtype=np.float32) - scale_tmp = input_size - - shift = np.array(shift) - src_w = scale_tmp[0] - dst_w = output_size[0] - dst_h = output_size[1] - - rot_rad = np.pi * rot / 180 - src_dir = rotate_point([0.0, src_w * -0.5], rot_rad) - dst_dir = np.array([0.0, dst_w * -0.5]) - - src = np.zeros((3, 2), dtype=np.float32) - src[0, :] = center + scale_tmp * shift - src[1, :] = center + src_dir + scale_tmp * shift - src[2, :] = _get_3rd_point(src[0, :], src[1, :]) - - dst = np.zeros((3, 2), dtype=np.float32) - dst[0, :] = [dst_w * 0.5, dst_h * 0.5] - dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir - dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :]) - - if inv: - trans = cv2.getAffineTransform(np.float32(dst), np.float32(src)) - else: - trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) - - return trans - - -class WarpAffine(object): - """Warp affine the image""" - - def __init__(self, keep_res=False, pad=31, input_h=512, input_w=512, scale=0.4, shift=0.1): - self.keep_res = keep_res - self.pad = pad - self.input_h = input_h - self.input_w = input_w - self.scale = scale - self.shift = shift - - def __call__(self, im, im_info): - """ - Args: - im (np.ndarray): image (np.ndarray) - im_info (dict): info of image - Returns: - im (np.ndarray): processed image (np.ndarray) - im_info (dict): info of processed image - """ - img = cv2.cvtColor(im, cv2.COLOR_RGB2BGR) - - h, w = img.shape[:2] - - if self.keep_res: - input_h = (h | self.pad) + 1 - input_w = (w | self.pad) + 1 - s = np.array([input_w, input_h], dtype=np.float32) - c = np.array([w // 2, h // 2], dtype=np.float32) - - else: - s = max(h, w) * 1.0 - input_h, input_w = self.input_h, self.input_w - c = np.array([w / 2.0, h / 2.0], dtype=np.float32) - - trans_input = get_affine_transform(c, s, 0, [input_w, input_h]) - img = cv2.resize(img, (w, h)) - inp = cv2.warpAffine(img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR) - return inp, im_info - - -# keypoint preprocess -def get_warp_matrix(theta, size_input, size_dst, size_target): - """This code is based on - https://github.com/open-mmlab/mmpose/blob/master/mmpose/core/post_processing/post_transforms.py - - Calculate the transformation matrix under the constraint of unbiased. - Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased - Data Processing for Human Pose Estimation (CVPR 2020). - - Args: - theta (float): Rotation angle in degrees. - size_input (np.ndarray): Size of input image [w, h]. - size_dst (np.ndarray): Size of output image [w, h]. - size_target (np.ndarray): Size of ROI in input plane [w, h]. - - Returns: - matrix (np.ndarray): A matrix for transformation. 
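A small sketch, not in the original file, of how WarpAffine drives get_affine_transform above: build the source-to-destination matrix around the image center and warp onto a fixed 512x512 canvas.

import cv2
import numpy as np

img = np.full((300, 400, 3), 255, dtype=np.uint8)  # stand-in RGB input
h, w = img.shape[:2]
c = np.array([w / 2.0, h / 2.0], dtype=np.float32)  # transform center
s = max(h, w) * 1.0  # square side length covering the whole image
trans = get_affine_transform(c, s, 0, [512, 512])  # 2x3 affine matrix
warped = cv2.warpAffine(img, trans, (512, 512), flags=cv2.INTER_LINEAR)  # (512, 512, 3)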
- """ - theta = np.deg2rad(theta) - matrix = np.zeros((2, 3), dtype=np.float32) - scale_x = size_dst[0] / size_target[0] - scale_y = size_dst[1] / size_target[1] - matrix[0, 0] = np.cos(theta) * scale_x - matrix[0, 1] = -np.sin(theta) * scale_x - matrix[0, 2] = scale_x * ( - -0.5 * size_input[0] * np.cos(theta) - + 0.5 * size_input[1] * np.sin(theta) - + 0.5 * size_target[0] - ) - matrix[1, 0] = np.sin(theta) * scale_y - matrix[1, 1] = np.cos(theta) * scale_y - matrix[1, 2] = scale_y * ( - -0.5 * size_input[0] * np.sin(theta) - - 0.5 * size_input[1] * np.cos(theta) - + 0.5 * size_target[1] - ) - return matrix - - -class TopDownEvalAffine(object): - """apply affine transform to image and coords - - Args: - trainsize (list): [w, h], the standard size used to train - use_udp (bool): whether to use Unbiased Data Processing. - records(dict): the dict contained the image and coords - - Returns: - records (dict): contain the image and coords after tranformed - - """ - - def __init__(self, trainsize, use_udp=False): - self.trainsize = trainsize - self.use_udp = use_udp - - def __call__(self, image, im_info): - rot = 0 - imshape = im_info['im_shape'][::-1] - center = im_info['center'] if 'center' in im_info else imshape / 2.0 - scale = im_info['scale'] if 'scale' in im_info else imshape - if self.use_udp: - trans = get_warp_matrix( - rot, center * 2.0, [self.trainsize[0] - 1.0, self.trainsize[1] - 1.0], scale - ) - image = cv2.warpAffine( - image, - trans, - (int(self.trainsize[0]), int(self.trainsize[1])), - flags=cv2.INTER_LINEAR, - ) - else: - trans = get_affine_transform(center, scale, rot, self.trainsize) - image = cv2.warpAffine( - image, - trans, - (int(self.trainsize[0]), int(self.trainsize[1])), - flags=cv2.INTER_LINEAR, - ) - - return image, im_info - - -class Compose: - def __init__(self, transforms): - self.transforms = [] - for op_info in transforms: - new_op_info = op_info.copy() - op_type = new_op_info.pop('type') - self.transforms.append(eval(op_type)(**new_op_info)) - - def __call__(self, img_path): - img, im_info = decode_image(img_path) - for t in self.transforms: - img, im_info = t(img, im_info) - inputs = copy.deepcopy(im_info) - inputs['image'] = img - return inputs diff --git a/texteller/models/ocr_model/model/TexTeller.py b/texteller/models/ocr_model/model/TexTeller.py deleted file mode 100644 index 4f916cd..0000000 --- a/texteller/models/ocr_model/model/TexTeller.py +++ /dev/null @@ -1,43 +0,0 @@ -from pathlib import Path - -from ...globals import VOCAB_SIZE, FIXED_IMG_SIZE, IMG_CHANNELS, MAX_TOKEN_SIZE - -from transformers import RobertaTokenizerFast, VisionEncoderDecoderModel, VisionEncoderDecoderConfig - - -class TexTeller(VisionEncoderDecoderModel): - REPO_NAME = 'OleehyO/TexTeller' - - def __init__(self): - config = VisionEncoderDecoderConfig.from_pretrained( - Path(__file__).resolve().parent / "config.json" - ) - config.encoder.image_size = FIXED_IMG_SIZE - config.encoder.num_channels = IMG_CHANNELS - config.decoder.vocab_size = VOCAB_SIZE - config.decoder.max_position_embeddings = MAX_TOKEN_SIZE - - super().__init__(config=config) - - @classmethod - def from_pretrained(cls, model_path: str = None, use_onnx=False, onnx_provider=None): - if model_path is None or model_path == 'default': - if not use_onnx: - return VisionEncoderDecoderModel.from_pretrained(cls.REPO_NAME) - else: - from optimum.onnxruntime import ORTModelForVision2Seq - - use_gpu = True if onnx_provider == 'cuda' else False - return ORTModelForVision2Seq.from_pretrained( - cls.REPO_NAME, - 
provider="CUDAExecutionProvider" if use_gpu else "CPUExecutionProvider", - ) - model_path = Path(model_path).resolve() - return VisionEncoderDecoderModel.from_pretrained(str(model_path)) - - @classmethod - def get_tokenizer(cls, tokenizer_path: str = None) -> RobertaTokenizerFast: - if tokenizer_path is None or tokenizer_path == 'default': - return RobertaTokenizerFast.from_pretrained(cls.REPO_NAME) - tokenizer_path = Path(tokenizer_path).resolve() - return RobertaTokenizerFast.from_pretrained(str(tokenizer_path)) diff --git a/texteller/models/ocr_model/model/config.json b/texteller/models/ocr_model/model/config.json deleted file mode 100644 index 45365ba..0000000 --- a/texteller/models/ocr_model/model/config.json +++ /dev/null @@ -1,168 +0,0 @@ -{ - "_name_or_path": "OleehyO/TexTeller", - "architectures": [ - "VisionEncoderDecoderModel" - ], - "decoder": { - "_name_or_path": "", - "activation_dropout": 0.0, - "activation_function": "gelu", - "add_cross_attention": true, - "architectures": null, - "attention_dropout": 0.0, - "bad_words_ids": null, - "begin_suppress_tokens": null, - "bos_token_id": 0, - "chunk_size_feed_forward": 0, - "classifier_dropout": 0.0, - "cross_attention_hidden_size": 768, - "d_model": 1024, - "decoder_attention_heads": 16, - "decoder_ffn_dim": 4096, - "decoder_layerdrop": 0.0, - "decoder_layers": 12, - "decoder_start_token_id": 2, - "diversity_penalty": 0.0, - "do_sample": false, - "dropout": 0.1, - "early_stopping": false, - "encoder_no_repeat_ngram_size": 0, - "eos_token_id": 2, - "exponential_decay_length_penalty": null, - "finetuning_task": null, - "forced_bos_token_id": null, - "forced_eos_token_id": null, - "id2label": { - "0": "LABEL_0", - "1": "LABEL_1" - }, - "init_std": 0.02, - "is_decoder": true, - "is_encoder_decoder": false, - "label2id": { - "LABEL_0": 0, - "LABEL_1": 1 - }, - "layernorm_embedding": true, - "length_penalty": 1.0, - "max_length": 20, - "max_position_embeddings": 1024, - "min_length": 0, - "model_type": "trocr", - "no_repeat_ngram_size": 0, - "num_beam_groups": 1, - "num_beams": 1, - "num_return_sequences": 1, - "output_attentions": false, - "output_hidden_states": false, - "output_scores": false, - "pad_token_id": 1, - "prefix": null, - "problem_type": null, - "pruned_heads": {}, - "remove_invalid_values": false, - "repetition_penalty": 1.0, - "return_dict": true, - "return_dict_in_generate": false, - "scale_embedding": false, - "sep_token_id": null, - "suppress_tokens": null, - "task_specific_params": null, - "temperature": 1.0, - "tf_legacy_loss": false, - "tie_encoder_decoder": false, - "tie_word_embeddings": true, - "tokenizer_class": null, - "top_k": 50, - "top_p": 1.0, - "torch_dtype": null, - "torchscript": false, - "typical_p": 1.0, - "use_bfloat16": false, - "use_cache": false, - "use_learned_position_embeddings": true, - "vocab_size": 15000 - }, - "encoder": { - "_name_or_path": "", - "add_cross_attention": false, - "architectures": null, - "attention_probs_dropout_prob": 0.0, - "bad_words_ids": null, - "begin_suppress_tokens": null, - "bos_token_id": null, - "chunk_size_feed_forward": 0, - "cross_attention_hidden_size": null, - "decoder_start_token_id": null, - "diversity_penalty": 0.0, - "do_sample": false, - "early_stopping": false, - "encoder_no_repeat_ngram_size": 0, - "encoder_stride": 16, - "eos_token_id": null, - "exponential_decay_length_penalty": null, - "finetuning_task": null, - "forced_bos_token_id": null, - "forced_eos_token_id": null, - "hidden_act": "gelu", - "hidden_dropout_prob": 0.0, - "hidden_size": 768, 
- "id2label": { - "0": "LABEL_0", - "1": "LABEL_1" - }, - "image_size": 448, - "initializer_range": 0.02, - "intermediate_size": 3072, - "is_decoder": false, - "is_encoder_decoder": false, - "label2id": { - "LABEL_0": 0, - "LABEL_1": 1 - }, - "layer_norm_eps": 1e-12, - "length_penalty": 1.0, - "max_length": 20, - "min_length": 0, - "model_type": "vit", - "no_repeat_ngram_size": 0, - "num_attention_heads": 12, - "num_beam_groups": 1, - "num_beams": 1, - "num_channels": 1, - "num_hidden_layers": 12, - "num_return_sequences": 1, - "output_attentions": false, - "output_hidden_states": false, - "output_scores": false, - "pad_token_id": null, - "patch_size": 16, - "prefix": null, - "problem_type": null, - "pruned_heads": {}, - "qkv_bias": false, - "remove_invalid_values": false, - "repetition_penalty": 1.0, - "return_dict": true, - "return_dict_in_generate": false, - "sep_token_id": null, - "suppress_tokens": null, - "task_specific_params": null, - "temperature": 1.0, - "tf_legacy_loss": false, - "tie_encoder_decoder": false, - "tie_word_embeddings": true, - "tokenizer_class": null, - "top_k": 50, - "top_p": 1.0, - "torch_dtype": null, - "torchscript": false, - "typical_p": 1.0, - "use_bfloat16": false - }, - "is_encoder_decoder": true, - "model_type": "vision-encoder-decoder", - "tie_word_embeddings": false, - "transformers_version": "4.41.2", - "use_cache": true -} diff --git a/texteller/models/ocr_model/train/dataset/train/0.png b/texteller/models/ocr_model/train/dataset/train/0.png deleted file mode 100644 index 9f27321..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/0.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/1.png b/texteller/models/ocr_model/train/dataset/train/1.png deleted file mode 100644 index bc65c5f..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/1.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/10.png b/texteller/models/ocr_model/train/dataset/train/10.png deleted file mode 100644 index b2306ab..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/10.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/11.png b/texteller/models/ocr_model/train/dataset/train/11.png deleted file mode 100644 index f8b20a1..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/11.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/12.png b/texteller/models/ocr_model/train/dataset/train/12.png deleted file mode 100644 index 5b3b285..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/12.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/13.png b/texteller/models/ocr_model/train/dataset/train/13.png deleted file mode 100644 index 692fcc2..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/13.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/14.png b/texteller/models/ocr_model/train/dataset/train/14.png deleted file mode 100644 index e7fe2fd..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/14.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/15.png b/texteller/models/ocr_model/train/dataset/train/15.png deleted file mode 100644 index fbbeb82..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/15.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/16.png 
b/texteller/models/ocr_model/train/dataset/train/16.png deleted file mode 100644 index be56e99..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/16.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/17.png b/texteller/models/ocr_model/train/dataset/train/17.png deleted file mode 100644 index 4f30cf1..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/17.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/18.png b/texteller/models/ocr_model/train/dataset/train/18.png deleted file mode 100644 index 8774d25..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/18.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/19.png b/texteller/models/ocr_model/train/dataset/train/19.png deleted file mode 100644 index 4d3daa5..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/19.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/2.png b/texteller/models/ocr_model/train/dataset/train/2.png deleted file mode 100644 index 8fe5dd9..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/2.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/20.png b/texteller/models/ocr_model/train/dataset/train/20.png deleted file mode 100644 index 45c400d..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/20.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/21.png b/texteller/models/ocr_model/train/dataset/train/21.png deleted file mode 100644 index 311c1fd..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/21.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/22.png b/texteller/models/ocr_model/train/dataset/train/22.png deleted file mode 100644 index 6273383..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/22.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/23.png b/texteller/models/ocr_model/train/dataset/train/23.png deleted file mode 100644 index 06dfcdb..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/23.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/24.png b/texteller/models/ocr_model/train/dataset/train/24.png deleted file mode 100644 index c718fd5..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/24.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/25.png b/texteller/models/ocr_model/train/dataset/train/25.png deleted file mode 100644 index b90ab45..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/25.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/26.png b/texteller/models/ocr_model/train/dataset/train/26.png deleted file mode 100644 index 087e6de..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/26.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/27.png b/texteller/models/ocr_model/train/dataset/train/27.png deleted file mode 100644 index 67f552c..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/27.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/28.png b/texteller/models/ocr_model/train/dataset/train/28.png deleted file mode 100644 index 3b29359..0000000 Binary files 
a/texteller/models/ocr_model/train/dataset/train/28.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/29.png b/texteller/models/ocr_model/train/dataset/train/29.png deleted file mode 100644 index 917e0ed..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/29.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/3.png b/texteller/models/ocr_model/train/dataset/train/3.png deleted file mode 100644 index 0354b68..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/3.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/30.png b/texteller/models/ocr_model/train/dataset/train/30.png deleted file mode 100644 index cb38168..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/30.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/31.png b/texteller/models/ocr_model/train/dataset/train/31.png deleted file mode 100644 index 973f951..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/31.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/32.png b/texteller/models/ocr_model/train/dataset/train/32.png deleted file mode 100644 index 7c019a5..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/32.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/33.png b/texteller/models/ocr_model/train/dataset/train/33.png deleted file mode 100644 index 172ff55..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/33.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/34.png b/texteller/models/ocr_model/train/dataset/train/34.png deleted file mode 100644 index 013c1cc..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/34.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/4.png b/texteller/models/ocr_model/train/dataset/train/4.png deleted file mode 100644 index b8b0e39..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/4.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/5.png b/texteller/models/ocr_model/train/dataset/train/5.png deleted file mode 100644 index db3af1f..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/5.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/6.png b/texteller/models/ocr_model/train/dataset/train/6.png deleted file mode 100644 index c171137..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/6.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/7.png b/texteller/models/ocr_model/train/dataset/train/7.png deleted file mode 100644 index 9c2f9a6..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/7.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/8.png b/texteller/models/ocr_model/train/dataset/train/8.png deleted file mode 100644 index 54e300a..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/8.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/9.png b/texteller/models/ocr_model/train/dataset/train/9.png deleted file mode 100644 index 9bf24fb..0000000 Binary files a/texteller/models/ocr_model/train/dataset/train/9.png and /dev/null differ diff --git a/texteller/models/ocr_model/train/dataset/train/metadata.jsonl 
b/texteller/models/ocr_model/train/dataset/train/metadata.jsonl deleted file mode 100644 index 23279de..0000000 --- a/texteller/models/ocr_model/train/dataset/train/metadata.jsonl +++ /dev/null @@ -1,35 +0,0 @@ -{"file_name": "0.png", "latex_formula": "\\[\\mathbb{C}^{4}\\stackrel{{\\pi_{1}}}{{\\longleftarrow}}\\mathcal{ F}\\stackrel{{\\pi_{2}}}{{\\rightarrow}}\\mathcal{PT},\\]"} -{"file_name": "1.png", "latex_formula": "\\[W^{*}_{Z}(x_{1},x_{2})=W_{f\\lrcorner Z}(y_{1},y_{2})=\\mathcal{P}\\exp\\left( \\int_{\\gamma}A_{\\mu}dx^{\\mu}\\right).\\]"} -{"file_name": "2.png", "latex_formula": "\\[G=W^{*}_{Z}(q,p)=\\tilde{H}H^{-1}\\]"} -{"file_name": "3.png", "latex_formula": "\\[H=W^{*}_{Z}(p,x),\\ \\ \\tilde{H}=W^{*}_{Z}(q,x).\\]"} -{"file_name": "4.png", "latex_formula": "\\[v\\cdot f^{*}A|_{x}=(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)},\\quad x\\in Z, \\ v\\in T_{x}Z.\\]"} -{"file_name": "5.png", "latex_formula": "\\[(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)}=v^{\\alpha\\dot{\\alpha}}\\Big{(} \\frac{\\partial y^{\\beta\\dot{\\beta}}}{\\partial x^{\\alpha\\dot{\\alpha}}}A_{\\beta \\dot{\\beta}}\\Big{)}\\Big{|}_{f\\lrcorner Z(x)},\\ x\\in Z,\\ v\\in T_{x}Z,\\]"} -{"file_name": "6.png", "latex_formula": "\\[\\{T_{i},T_{j}\\}=\\{\\tilde{T}^{i},\\tilde{T}^{j}\\}=0,\\ \\ \\{T_{i},\\tilde{T}^{j}\\}=2i \\delta^{j}_{i}D,\\]"} -{"file_name": "7.png", "latex_formula": "\\[(\\partial_{s},q_{i},\\tilde{q}^{k})\\rightarrow(D,M^{j}_{i}T_{j},\\tilde{M}^{k}_ {l}\\tilde{T}^{l}),\\]"} -{"file_name": "8.png", "latex_formula": "\\[M^{i}_{j}\\tilde{M}^{j}_{k}=\\delta^{i}_{k}.\\]"} -{"file_name": "9.png", "latex_formula": "\\[Q_{i\\alpha}=q_{i\\alpha}+\\omega_{i\\alpha},\\ \\tilde{Q}^{i}_{\\dot{\\alpha}}=q^{i}_{ \\dot{\\alpha}}+\\tilde{\\omega}^{i}_{\\dot{\\alpha}},\\ D_{\\alpha\\dot{\\alpha}}= \\partial_{\\alpha\\dot{\\alpha}}+A_{\\alpha\\dot{\\alpha}}.\\]"} -{"file_name": "10.png", "latex_formula": "\\[\\hat{f}(g,\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{j})=(f(g),[V^{-1}]^ {\\alpha}_{\\beta}\\theta^{i\\beta},[\\tilde{V}^{-1}]^{\\dot{\\alpha}}_{\\dot{\\beta}} \\tilde{\\theta}^{\\dot{\\beta}}_{j}),\\ g\\in{\\cal G},\\]"} -{"file_name": "11.png", "latex_formula": "\\[v^{\\beta\\dot{\\beta}}V^{\\alpha}_{\\beta}\\tilde{V}^{\\dot{\\alpha}}_{\\dot{\\beta}} =((f\\lrcorner L_{0})_{*}v)^{\\alpha\\dot{\\alpha}},\\]"} -{"file_name": "12.png", "latex_formula": "\\[\\omega_{i\\alpha}=\\tilde{\\theta}^{\\dot{\\alpha}}_{i}h_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\ \\ \\tilde{\\omega}^{i}_{\\alpha}=\\theta^{i\\alpha}\\tilde{h}_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\]"} -{"file_name": "13.png", "latex_formula": "\\[\\begin{split}&\\lambda^{\\alpha}\\hat{f}^{*}\\omega_{i\\alpha}(z)= \\tilde{\\theta}^{\\dot{\\beta}}_{i}\\lambda^{\\alpha}\\left(V^{\\beta}_{\\alpha}h_{ \\beta\\dot{\\beta}}(x^{\\prime},\\tau^{\\prime})\\right),\\\\ &\\tilde{\\lambda}^{\\dot{\\alpha}}\\hat{f}^{*}\\tilde{\\omega}^{i}_{ \\dot{\\alpha}}(z)=\\theta^{i\\beta}\\tilde{\\lambda}^{\\dot{\\alpha}}\\left(\\tilde{V}^ {\\dot{\\beta}}_{\\dot{\\alpha}}\\tilde{h}_{\\beta\\dot{\\beta}}(x^{\\prime},\\tau^{ \\prime})\\right),\\end{split}\\]"} -{"file_name": "14.png", "latex_formula": "\\[A_{\\alpha\\dot{\\alpha}}=A_{\\alpha\\dot{\\alpha}}(x^{\\beta\\dot{\\beta}},\\tau^{ \\beta\\dot{\\beta}})\\]"} -{"file_name": "15.png", "latex_formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}D_{\\alpha\\dot{\\alpha}}\\]"} -{"file_name": "16.png", "latex_formula": 
"\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}\\]"} -{"file_name": "17.png", "latex_formula": "\\[[v_{1}\\cdot D^{*},v_{2}\\cdot D^{*}]=0\\]"} -{"file_name": "18.png", "latex_formula": "\\[\\Phi_{A}=(\\omega_{i\\alpha},\\tilde{\\omega}^{i}_{\\dot{\\alpha}},A_{\\alpha\\dot{ \\alpha}})\\]"} -{"file_name": "19.png", "latex_formula": "\\[\\hat{f}:{\\cal F}^{6|4N}\\rightarrow{\\cal F}^{6|4N}\\]"} -{"file_name": "20.png", "latex_formula": "\\[\\sigma=(s,\\xi^{i},\\tilde{\\xi}_{j})\\in\\mathbb{C}^{1|2N}\\]"} -{"file_name": "21.png", "latex_formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}(h_{\\alpha\\dot{\\alpha}}+\\tilde{h}_{\\alpha\\dot{\\alpha} })=0\\]"} -{"file_name": "22.png", "latex_formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}\\rightarrow[V^{-1}]^{\\alpha}_{\\beta}[\\tilde{V}^{-1}]^{ \\dot{\\alpha}}_{\\dot{\\beta}}\\tau^{\\beta\\dot{\\beta}}\\]"} -{"file_name": "23.png", "latex_formula": "\\[\\tau^{\\beta\\dot{\\beta}}=\\sum_{i}\\theta^{i\\beta}\\tilde{\\theta}^{\\dot{\\beta}}_{i}\\]"} -{"file_name": "24.png", "latex_formula": "\\[\\theta^{i\\alpha}\\omega_{i\\alpha}+\\tilde{\\theta}^{i}_{\\dot{\\alpha}}\\tilde{ \\omega}^{\\dot{\\alpha}}_{i}=0\\]"} -{"file_name": "25.png", "latex_formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{Q}^{i}_{\\dot{\\alpha}}\\]"} -{"file_name": "26.png", "latex_formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{q}^{i}_{\\dot{\\alpha}}\\]"} -{"file_name": "27.png", "latex_formula": "\\[\\tilde{\\lambda}^{\\dot{\\alpha}}f^{*}A_{\\alpha\\dot{\\alpha}}=H^{-1}\\tilde{ \\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}H\\]"} -{"file_name": "28.png", "latex_formula": "\\[\\tilde{q}^{i}=\\partial_{\\tilde{\\xi}_{i}}+i\\xi^{i}\\partial_{s}\\]"} -{"file_name": "29.png", "latex_formula": "\\[\\tilde{q}^{i}_{\\dot{\\alpha}}=\\frac{\\partial}{\\partial\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}}+i\\theta^{i\\alpha}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"} -{"file_name": "30.png", "latex_formula": "\\[f\\lrcorner L(z)=\\pi_{1}\\circ f(z,\\lambda,\\tilde{\\lambda})\\ \\forall z\\in L\\]"} -{"file_name": "31.png", "latex_formula": "\\[q_{i\\alpha}=\\frac{\\partial}{\\partial\\theta^{i\\alpha}}+i\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"} -{"file_name": "32.png", "latex_formula": "\\[q_{i}=\\partial_{\\xi^{i}}+i\\tilde{\\xi}_{i}\\partial_{s}\\]"} -{"file_name": "33.png", "latex_formula": "\\[v^{\\alpha\\dot{\\alpha}}=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\]"} -{"file_name": "34.png", "latex_formula": "\\[z^{A}=(x^{\\alpha\\dot{\\alpha}},\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{ j})\\]"} diff --git a/texteller/models/ocr_model/train/train.py b/texteller/models/ocr_model/train/train.py deleted file mode 100644 index 80b58af..0000000 --- a/texteller/models/ocr_model/train/train.py +++ /dev/null @@ -1,114 +0,0 @@ -import os - -from functools import partial -from pathlib import Path - -from datasets import load_dataset -from transformers import ( - Trainer, - TrainingArguments, - Seq2SeqTrainer, - Seq2SeqTrainingArguments, - GenerationConfig, -) - -from .training_args import CONFIG -from ..model.TexTeller import TexTeller -from ..utils.functional import ( - tokenize_fn, - collate_fn, - img_train_transform, - img_inf_transform, - filter_fn, -) -from ..utils.metrics import bleu_metric -from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT - - -def train(model, tokenizer, train_dataset, eval_dataset, 
collate_fn_with_tokenizer): - training_args = TrainingArguments(**CONFIG) - trainer = Trainer( - model, - training_args, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - tokenizer=tokenizer, - data_collator=collate_fn_with_tokenizer, - ) - - trainer.train(resume_from_checkpoint=None) - - -def evaluate(model, tokenizer, eval_dataset, collate_fn): - eval_config = CONFIG.copy() - eval_config['predict_with_generate'] = True - generate_config = GenerationConfig( - max_new_tokens=MAX_TOKEN_SIZE, - num_beams=1, - do_sample=False, - pad_token_id=tokenizer.pad_token_id, - eos_token_id=tokenizer.eos_token_id, - bos_token_id=tokenizer.bos_token_id, - ) - eval_config['generation_config'] = generate_config - seq2seq_config = Seq2SeqTrainingArguments(**eval_config) - - trainer = Seq2SeqTrainer( - model, - seq2seq_config, - eval_dataset=eval_dataset, - tokenizer=tokenizer, - data_collator=collate_fn, - compute_metrics=partial(bleu_metric, tokenizer=tokenizer), - ) - - eval_res = trainer.evaluate() - print(eval_res) - - -if __name__ == '__main__': - script_dirpath = Path(__file__).resolve().parent - os.chdir(script_dirpath) - - # dataset = load_dataset(str(Path('./dataset/loader.py').resolve()))['train'] - dataset = load_dataset("imagefolder", data_dir=str(script_dirpath / 'dataset'))['train'] - dataset = dataset.filter( - lambda x: x['image'].height > MIN_HEIGHT and x['image'].width > MIN_WIDTH - ) - dataset = dataset.shuffle(seed=42) - dataset = dataset.flatten_indices() - - tokenizer = TexTeller.get_tokenizer() - # If you want use your own tokenizer, please modify the path to your tokenizer - # +tokenizer = TexTeller.get_tokenizer('/path/to/your/tokenizer') - filter_fn_with_tokenizer = partial(filter_fn, tokenizer=tokenizer) - dataset = dataset.filter(filter_fn_with_tokenizer, num_proc=8) - - map_fn = partial(tokenize_fn, tokenizer=tokenizer) - tokenized_dataset = dataset.map( - map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8 - ) - - # Split dataset into train and eval, ratio 9:1 - split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42) - train_dataset, eval_dataset = split_dataset['train'], split_dataset['test'] - train_dataset = train_dataset.with_transform(img_train_transform) - eval_dataset = eval_dataset.with_transform(img_inf_transform) - collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer) - - # Train from scratch - model = TexTeller() - # or train from TexTeller pre-trained model: model = TexTeller.from_pretrained() - - # If you want to train from pre-trained model, please modify the path to your pre-trained checkpoint - # +e.g. 
- # +model = TexTeller.from_pretrained( - # + '/path/to/your/model_checkpoint' - # +) - - enable_train = True - enable_evaluate = False - if enable_train: - train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer) - if enable_evaluate and len(eval_dataset) > 0: - evaluate(model, tokenizer, eval_dataset, collate_fn_with_tokenizer) diff --git a/texteller/models/ocr_model/train/training_args.py b/texteller/models/ocr_model/train/training_args.py deleted file mode 100644 index b377cab..0000000 --- a/texteller/models/ocr_model/train/training_args.py +++ /dev/null @@ -1,31 +0,0 @@ -CONFIG = { - "seed": 42, # Random seed for reproducibility - "use_cpu": False, # Whether to use CPU (it's easier to debug with CPU when starting to test the code) - "learning_rate": 5e-5, # Learning rate - "num_train_epochs": 10, # Total number of training epochs - "per_device_train_batch_size": 4, # Batch size per GPU for training - "per_device_eval_batch_size": 8, # Batch size per GPU for evaluation - "output_dir": "train_result", # Output directory - "overwrite_output_dir": False, # If the output directory exists, do not delete its content - "report_to": ["tensorboard"], # Report logs to TensorBoard - "save_strategy": "steps", # Strategy to save checkpoints - "save_steps": 500, # Interval of steps to save checkpoints, can be int or a float (0~1), when float it represents the ratio of total training steps (e.g., can set to 1.0 / 2000) - "save_total_limit": 5, # Maximum number of models to save. The oldest models will be deleted if this number is exceeded - "logging_strategy": "steps", # Log every certain number of steps - "logging_steps": 500, # Number of steps between each log - "logging_nan_inf_filter": False, # Record logs for loss=nan or inf - "optim": "adamw_torch", # Optimizer - "lr_scheduler_type": "cosine", # Learning rate scheduler - "warmup_ratio": 0.1, # Ratio of warmup steps in total training steps (e.g., for 1000 steps, the first 100 steps gradually increase lr from 0 to the set lr) - "max_grad_norm": 1.0, # For gradient clipping, ensure the norm of the gradients does not exceed 1.0 (default 1.0) - "fp16": False, # Whether to use 16-bit floating point for training (generally not recommended, as loss can easily explode) - "bf16": False, # Whether to use Brain Floating Point (bfloat16) for training (recommended if architecture supports it) - "gradient_accumulation_steps": 1, # Gradient accumulation steps, consider this parameter to achieve large batch size effects when batch size cannot be large - "jit_mode_eval": False, # Whether to use PyTorch jit trace during eval (can speed up the model, but the model must be static, otherwise will throw errors) - "torch_compile": False, # Whether to use torch.compile to compile the model (for better training and inference performance) - "dataloader_pin_memory": True, # Can speed up data transfer between CPU and GPU - "dataloader_num_workers": 1, # Default is not to use multiprocessing for data loading, usually set to 4*number of GPUs used - "evaluation_strategy": "steps", # Evaluation strategy, can be "steps" or "epoch" - "eval_steps": 500, # If evaluation_strategy="step" - "remove_unused_columns": False, # Don't change this unless you really know what you are doing. 
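One practical reading of the CONFIG above: the effective batch size per optimizer step is the product of three knobs (a sketch; num_gpus stands in for the actual device count):

num_gpus = 1
effective_batch = (
    CONFIG["per_device_train_batch_size"]    # 4
    * CONFIG["gradient_accumulation_steps"]  # 1
    * num_gpus
)  # -> 4 samples per optimizer step on a single GPU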
-} diff --git a/texteller/models/ocr_model/utils/functional.py b/texteller/models/ocr_model/utils/functional.py deleted file mode 100644 index aa3199e..0000000 --- a/texteller/models/ocr_model/utils/functional.py +++ /dev/null @@ -1,60 +0,0 @@ -import torch - -from transformers import DataCollatorForLanguageModeling -from typing import List, Dict, Any -from .transforms import train_transform, inference_transform -from ...globals import MIN_HEIGHT, MIN_WIDTH, MAX_TOKEN_SIZE - - -def left_move(x: torch.Tensor, pad_val): - assert len(x.shape) == 2, 'x should be 2-dimensional' - lefted_x = torch.ones_like(x) - lefted_x[:, :-1] = x[:, 1:] - lefted_x[:, -1] = pad_val - return lefted_x - - -def tokenize_fn(samples: Dict[str, List[Any]], tokenizer=None) -> Dict[str, List[Any]]: - assert tokenizer is not None, 'tokenizer should not be None' - tokenized_formula = tokenizer(samples['latex_formula'], return_special_tokens_mask=True) - tokenized_formula['pixel_values'] = samples['image'] - return tokenized_formula - - -def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[Any]]: - assert tokenizer is not None, 'tokenizer should not be None' - pixel_values = [dic.pop('pixel_values') for dic in samples] - - clm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) - - batch = clm_collator(samples) - batch['pixel_values'] = pixel_values - batch['decoder_input_ids'] = batch.pop('input_ids') - batch['decoder_attention_mask'] = batch.pop('attention_mask') - - # Left-shift labels and decoder_attention_mask - batch['labels'] = left_move(batch['labels'], -100) - - # Convert the list of images into a single tensor with shape (B, C, H, W) - batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0) - return batch - - -def img_train_transform(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: - processed_img = train_transform(samples['pixel_values']) - samples['pixel_values'] = processed_img - return samples - - -def img_inf_transform(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: - processed_img = inference_transform(samples['pixel_values']) - samples['pixel_values'] = processed_img - return samples - - -def filter_fn(sample, tokenizer=None) -> bool: - return ( - sample['image'].height > MIN_HEIGHT - and sample['image'].width > MIN_WIDTH - and len(tokenizer(sample['latex_formula'])['input_ids']) < MAX_TOKEN_SIZE - 10 - ) diff --git a/texteller/models/ocr_model/utils/helpers.py b/texteller/models/ocr_model/utils/helpers.py deleted file mode 100644 index 50e8bd0..0000000 --- a/texteller/models/ocr_model/utils/helpers.py +++ /dev/null @@ -1,26 +0,0 @@ -import cv2 -import numpy as np -from typing import List - - -def convert2rgb(image_paths: List[str]) -> List[np.ndarray]: - processed_images = [] - for path in image_paths: - image = cv2.imread(path, cv2.IMREAD_UNCHANGED) - if image is None: - print(f"Image at {path} could not be read.") - continue - if image.dtype == np.uint16: - print(f'Converting {path} to 8-bit, image may be lossy.') - image = cv2.convertScaleAbs(image, alpha=(255.0 / 65535.0)) - - channels = 1 if len(image.shape) == 2 else image.shape[2] - if channels == 4: - image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB) - elif channels == 1: - image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) - elif channels == 3: - image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - processed_images.append(image) - - return processed_images diff --git a/texteller/models/ocr_model/utils/metrics.py b/texteller/models/ocr_model/utils/metrics.py deleted file mode 100644 index 13dc972..0000000 --- 
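What left_move does to the labels produced by the collator above, as a tiny worked example (not in the source):

import torch

labels = torch.tensor([[2, 5, 9, 3]])
left_move(labels, pad_val=-100)  # tensor([[   5,    9,    3, -100]])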
a/texteller/models/ocr_model/utils/metrics.py +++ /dev/null @@ -1,25 +0,0 @@ -import evaluate -import numpy as np -import os - -from pathlib import Path -from typing import Dict -from transformers import EvalPrediction, RobertaTokenizer - - -def bleu_metric(eval_preds: EvalPrediction, tokenizer: RobertaTokenizer) -> Dict: - cur_dir = Path(os.getcwd()) - os.chdir(Path(__file__).resolve().parent) - metric = evaluate.load( - 'google_bleu' - ) # Will download the metric from huggingface if not already downloaded - os.chdir(cur_dir) - - logits, labels = eval_preds.predictions, eval_preds.label_ids - preds = logits - - labels = np.where(labels == -100, 1, labels) - - preds = tokenizer.batch_decode(preds, skip_special_tokens=True) - labels = tokenizer.batch_decode(labels, skip_special_tokens=True) - return metric.compute(predictions=preds, references=labels) diff --git a/texteller/models/ocr_model/utils/ocr_aug.py b/texteller/models/ocr_model/utils/ocr_aug.py deleted file mode 100644 index a232735..0000000 --- a/texteller/models/ocr_model/utils/ocr_aug.py +++ /dev/null @@ -1,152 +0,0 @@ -from augraphy import * -import random - - -def ocr_augmentation_pipeline(): - pre_phase = [] - - ink_phase = [ - InkColorSwap( - ink_swap_color="random", - ink_swap_sequence_number_range=(5, 10), - ink_swap_min_width_range=(2, 3), - ink_swap_max_width_range=(100, 120), - ink_swap_min_height_range=(2, 3), - ink_swap_max_height_range=(100, 120), - ink_swap_min_area_range=(10, 20), - ink_swap_max_area_range=(400, 500), - # p=0.2 - p=0.4, - ), - LinesDegradation( - line_roi=(0.0, 0.0, 1.0, 1.0), - line_gradient_range=(32, 255), - line_gradient_direction=(0, 2), - line_split_probability=(0.2, 0.4), - line_replacement_value=(250, 255), - line_min_length=(30, 40), - line_long_to_short_ratio=(5, 7), - line_replacement_probability=(0.4, 0.5), - line_replacement_thickness=(1, 3), - # p=0.2 - p=0.4, - ), - # ============================ - OneOf( - [ - Dithering( - dither="floyd-steinberg", - order=(3, 5), - ), - InkBleed( - intensity_range=(0.1, 0.2), - kernel_size=random.choice([(7, 7), (5, 5), (3, 3)]), - severity=(0.4, 0.6), - ), - ], - # p=0.2 - p=0.4, - ), - # ============================ - # ============================ - InkShifter( - text_shift_scale_range=(18, 27), - text_shift_factor_range=(1, 4), - text_fade_range=(0, 2), - blur_kernel_size=(5, 5), - blur_sigma=0, - noise_type="perlin", - # p=0.2 - p=0.4, - ), - # ============================ - ] - - paper_phase = [ - NoiseTexturize( # tested - sigma_range=(3, 10), - turbulence_range=(2, 5), - texture_width_range=(300, 500), - texture_height_range=(300, 500), - # p=0.2 - p=0.4, - ), - BrightnessTexturize( # tested - texturize_range=(0.9, 0.99), - deviation=0.03, - # p=0.2 - p=0.4, - ), - ] - - post_phase = [ - ColorShift( # tested - color_shift_offset_x_range=(3, 5), - color_shift_offset_y_range=(3, 5), - color_shift_iterations=(2, 3), - color_shift_brightness_range=(0.9, 1.1), - color_shift_gaussian_kernel_range=(3, 3), - # p=0.2 - p=0.4, - ), - DirtyDrum( # tested - line_width_range=(1, 6), - line_concentration=random.uniform(0.05, 0.15), - direction=random.randint(0, 2), - noise_intensity=random.uniform(0.6, 0.95), - noise_value=(64, 224), - ksize=random.choice([(3, 3), (5, 5), (7, 7)]), - sigmaX=0, - # p=0.2 - p=0.4, - ), - # ===================================== - OneOf( - [ - LightingGradient( - light_position=None, - direction=None, - max_brightness=255, - min_brightness=0, - mode="gaussian", - linear_decay_rate=None, - transparency=None, - ), - 
Brightness( - brightness_range=(0.9, 1.1), - min_brightness=0, - min_brightness_value=(120, 150), - ), - Gamma( - gamma_range=(0.9, 1.1), - ), - ], - # p=0.2 - p=0.4, - ), - # ===================================== - # ===================================== - OneOf( - [ - SubtleNoise( - subtle_range=random.randint(5, 10), - ), - Jpeg( - quality_range=(70, 95), - ), - ], - # p=0.2 - p=0.4, - ), - # ===================================== - ] - - pipeline = AugraphyPipeline( - ink_phase=ink_phase, - paper_phase=paper_phase, - post_phase=post_phase, - pre_phase=pre_phase, - log=False, - ) - - return pipeline diff --git a/texteller/models/ocr_model/utils/transforms.py b/texteller/models/ocr_model/utils/transforms.py deleted file mode 100644 index 7da2de0..0000000 --- a/texteller/models/ocr_model/utils/transforms.py +++ /dev/null @@ -1,177 +0,0 @@ -import torch -import random -import numpy as np -import cv2 - -from torchvision.transforms import v2 -from typing import List, Union -from PIL import Image -from collections import Counter - -from ...globals import ( - IMG_CHANNELS, - FIXED_IMG_SIZE, - IMAGE_MEAN, - IMAGE_STD, - MAX_RESIZE_RATIO, - MIN_RESIZE_RATIO, -) -from .ocr_aug import ocr_augmentation_pipeline - -# train_pipeline = default_augraphy_pipeline(scan_only=True) -train_pipeline = ocr_augmentation_pipeline() - -general_transform_pipeline = v2.Compose( - [ - v2.ToImage(), - v2.ToDtype(torch.uint8, scale=True), # optional, most inputs are already uint8 at this point - v2.Grayscale(), - v2.Resize( - size=FIXED_IMG_SIZE - 1, - interpolation=v2.InterpolationMode.BICUBIC, - max_size=FIXED_IMG_SIZE, - antialias=True, - ), - v2.ToDtype(torch.float32, scale=True), # Normalize expects float input - v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]), - # v2.ToPILImage() - ] -) - - -def trim_white_border(image: np.ndarray): - if len(image.shape) != 3 or image.shape[2] != 3: - raise ValueError("Image is not in RGB format or channel is not in third dimension") - - if image.dtype != np.uint8: - raise ValueError("Image should be stored in uint8") - - corners = [tuple(image[0, 0]), tuple(image[0, -1]), tuple(image[-1, 0]), tuple(image[-1, -1])] - bg_color = Counter(corners).most_common(1)[0][0] - bg_color_np = np.array(bg_color, dtype=np.uint8) - - h, w = image.shape[:2] - bg = np.full((h, w, 3), bg_color_np, dtype=np.uint8) - - diff = cv2.absdiff(image, bg) - mask = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY) - - threshold = 15 - _, diff = cv2.threshold(mask, threshold, 255, cv2.THRESH_BINARY) - - x, y, w, h = cv2.boundingRect(diff) - - trimmed_image = image[y : y + h, x : x + w] - - return trimmed_image - - -def add_white_border(image: np.ndarray, max_size: int) -> np.ndarray: - randi = [random.randint(0, max_size) for _ in range(4)] - pad_height_size = randi[1] + randi[3] - pad_width_size = randi[0] + randi[2] - if pad_height_size + image.shape[0] < 30: - compensate_height = int((30 - (pad_height_size + image.shape[0])) * 0.5) + 1 - randi[1] += compensate_height - randi[3] += compensate_height - if pad_width_size + image.shape[1] < 30: - compensate_width = int((30 - (pad_width_size + image.shape[1])) * 0.5) + 1 - randi[0] += compensate_width - randi[2] += compensate_width - return v2.functional.pad( - torch.from_numpy(image).permute(2, 0, 1), - padding=randi, - padding_mode='constant', - fill=(255, 255, 255), - ) - - -def padding(images: List[torch.Tensor], required_size: int) -> List[torch.Tensor]: - images = [ - v2.functional.pad( - img, padding=[0, 0, required_size - img.shape[2], required_size - 
img.shape[1]] - ) - for img in images - ] - return images - - -def random_resize(images: List[np.ndarray], minr: float, maxr: float) -> List[np.ndarray]: - if len(images[0].shape) != 3 or images[0].shape[2] != 3: - raise ValueError("Image is not in RGB format or channel is not in third dimension") - - ratios = [random.uniform(minr, maxr) for _ in range(len(images))] - return [ - cv2.resize( - img, (int(img.shape[1] * r), int(img.shape[0] * r)), interpolation=cv2.INTER_LANCZOS4 - ) # 抗锯齿 - for img, r in zip(images, ratios) - ] - - -def rotate(image: np.ndarray, min_angle: int, max_angle: int) -> np.ndarray: - # Get the center of the image to define the point of rotation - image_center = tuple(np.array(image.shape[1::-1]) / 2) - - # Generate a random angle within the specified range - angle = random.randint(min_angle, max_angle) - - # Get the rotation matrix for rotating the image around its center - rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0) - - # Determine the size of the rotated image - cos = np.abs(rotation_mat[0, 0]) - sin = np.abs(rotation_mat[0, 1]) - new_width = int((image.shape[0] * sin) + (image.shape[1] * cos)) - new_height = int((image.shape[0] * cos) + (image.shape[1] * sin)) - - # Adjust the rotation matrix to take into account translation - rotation_mat[0, 2] += (new_width / 2) - image_center[0] - rotation_mat[1, 2] += (new_height / 2) - image_center[1] - - # Rotate the image with the specified border color (white in this case) - rotated_image = cv2.warpAffine( - image, rotation_mat, (new_width, new_height), borderValue=(255, 255, 255) - ) - - return rotated_image - - -def ocr_aug(image: np.ndarray) -> np.ndarray: - if random.random() < 0.2: - image = rotate(image, -5, 5) - image = add_white_border(image, max_size=25).permute(1, 2, 0).numpy() - image = train_pipeline(image) - return image - - -def train_transform(images: List[Image.Image]) -> List[torch.Tensor]: - assert IMG_CHANNELS == 1, "Only support grayscale images for now" - - images = [np.array(img.convert('RGB')) for img in images] - # random resize first - images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO) - images = [trim_white_border(image) for image in images] - - # OCR augmentation - images = [ocr_aug(image) for image in images] - - # general transform pipeline - images = [general_transform_pipeline(image) for image in images] - # padding to fixed size - images = padding(images, FIXED_IMG_SIZE) - return images - - -def inference_transform(images: List[Union[np.ndarray, Image.Image]]) -> List[torch.Tensor]: - assert IMG_CHANNELS == 1, "Only support grayscale images for now" - images = [ - np.array(img.convert('RGB')) if isinstance(img, Image.Image) else img for img in images - ] - images = [trim_white_border(image) for image in images] - # general transform pipeline - images = [general_transform_pipeline(image) for image in images] # imgs: List[PIL.Image.Image] - # padding to fixed size - images = padding(images, FIXED_IMG_SIZE) - - return images diff --git a/texteller/models/texteller.py b/texteller/models/texteller.py new file mode 100644 index 0000000..e054212 --- /dev/null +++ b/texteller/models/texteller.py @@ -0,0 +1,48 @@ +from pathlib import Path + +from transformers import RobertaTokenizerFast, VisionEncoderDecoderConfig, VisionEncoderDecoderModel + +from texteller.constants import ( + FIXED_IMG_SIZE, + IMG_CHANNELS, + MAX_TOKEN_SIZE, + VOCAB_SIZE, +) +from texteller.globals import Globals +from texteller.types import TexTellerModel +from texteller.utils import 
cuda_available + + +class TexTeller(VisionEncoderDecoderModel): + def __init__(self): + config = VisionEncoderDecoderConfig.from_pretrained(Globals().repo_name) + config.encoder.image_size = FIXED_IMG_SIZE + config.encoder.num_channels = IMG_CHANNELS + config.decoder.vocab_size = VOCAB_SIZE + config.decoder.max_position_embeddings = MAX_TOKEN_SIZE + + super().__init__(config=config) + + @classmethod + def from_pretrained(cls, model_dir: str | None = None, use_onnx: bool = False) -> TexTellerModel: + if model_dir is None or model_dir == Globals().repo_name: + if not use_onnx: + return VisionEncoderDecoderModel.from_pretrained(Globals().repo_name) + else: + from optimum.onnxruntime import ORTModelForVision2Seq + + return ORTModelForVision2Seq.from_pretrained( + Globals().repo_name, + provider="CUDAExecutionProvider" + if cuda_available() + else "CPUExecutionProvider", + ) + model_dir = Path(model_dir).resolve() + return VisionEncoderDecoderModel.from_pretrained(str(model_dir)) + + @classmethod + def get_tokenizer(cls, tokenizer_dir: str | None = None) -> RobertaTokenizerFast: + if tokenizer_dir is None or tokenizer_dir == Globals().repo_name: + return RobertaTokenizerFast.from_pretrained(Globals().repo_name) + tokenizer_dir = Path(tokenizer_dir).resolve() + return RobertaTokenizerFast.from_pretrained(str(tokenizer_dir)) diff --git a/texteller/models/thrid_party/paddleocr/checkpoints/det/default_model.onnx b/texteller/models/thrid_party/paddleocr/checkpoints/det/default_model.onnx deleted file mode 100644 index 3239921..0000000 Binary files a/texteller/models/thrid_party/paddleocr/checkpoints/det/default_model.onnx and /dev/null differ diff --git a/texteller/models/thrid_party/paddleocr/checkpoints/rec/default_model.onnx b/texteller/models/thrid_party/paddleocr/checkpoints/rec/default_model.onnx deleted file mode 100644 index 273117b..0000000 Binary files a/texteller/models/thrid_party/paddleocr/checkpoints/rec/default_model.onnx and /dev/null differ diff --git a/texteller/models/tokenizer/train.py b/texteller/models/tokenizer/train.py deleted file mode 100644 index 80e5e0e..0000000 --- a/texteller/models/tokenizer/train.py +++ /dev/null @@ -1,24 +0,0 @@ -import os -from pathlib import Path -from datasets import load_dataset -from ..ocr_model.model.TexTeller import TexTeller -from ..globals import VOCAB_SIZE - - -if __name__ == '__main__': - script_dirpath = Path(__file__).resolve().parent - os.chdir(script_dirpath) - - tokenizer = TexTeller.get_tokenizer() - - # Don't forget to config your dataset path in loader.py - dataset = load_dataset('../ocr_model/train/dataset/loader.py')['train'] - - new_tokenizer = tokenizer.train_new_from_iterator( - text_iterator=dataset['latex_formula'], - # If you want to use a different vocab size, **change VOCAB_SIZE from globals.py** - vocab_size=VOCAB_SIZE, - ) - - # Save the new tokenizer for later training and inference - new_tokenizer.save_pretrained('./your_dir_name') diff --git a/texteller/models/utils/__init__.py b/texteller/models/utils/__init__.py deleted file mode 100644 index 3597062..0000000 --- a/texteller/models/utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .mix_inference import mix_inference diff --git a/texteller/models/utils/mix_inference.py b/texteller/models/utils/mix_inference.py deleted file mode 100644 index 0f8aa4f..0000000 --- a/texteller/models/utils/mix_inference.py +++ /dev/null @@ -1,261 +0,0 @@ -import re -import heapq -import cv2 -import time -import numpy as np - -from collections import Counter -from typing import List -from PIL import 
Image - -from ..det_model.inference import predict as latex_det_predict -from ..det_model.Bbox import Bbox, draw_bboxes - -from ..ocr_model.utils.inference import inference as latex_rec_predict -from ..ocr_model.utils.to_katex import to_katex, change_all - -MAXV = 999999999 - - -def mask_img(img, bboxes: List[Bbox], bg_color: np.ndarray) -> np.ndarray: - mask_img = img.copy() - for bbox in bboxes: - mask_img[bbox.p.y : bbox.p.y + bbox.h, bbox.p.x : bbox.p.x + bbox.w] = bg_color - return mask_img - - -def bbox_merge(sorted_bboxes: List[Bbox]) -> List[Bbox]: - if len(sorted_bboxes) == 0: - return [] - bboxes = sorted_bboxes.copy() - guard = Bbox(MAXV, bboxes[-1].p.y, -1, -1, label="guard") - bboxes.append(guard) - res = [] - prev = bboxes[0] - for curr in bboxes: - if prev.ur_point.x <= curr.p.x or not prev.same_row(curr): - res.append(prev) - prev = curr - else: - prev.w = max(prev.w, curr.ur_point.x - prev.p.x) - return res - - -def split_conflict(ocr_bboxes: List[Bbox], latex_bboxes: List[Bbox]) -> List[Bbox]: - if latex_bboxes == []: - return ocr_bboxes - if ocr_bboxes == [] or len(ocr_bboxes) == 1: - return ocr_bboxes - - bboxes = sorted(ocr_bboxes + latex_bboxes) - - # log results - for idx, bbox in enumerate(bboxes): - bbox.content = str(idx) - draw_bboxes(Image.fromarray(img), bboxes, name="before_split_confict.png") - - assert len(bboxes) > 1 - - heapq.heapify(bboxes) - res = [] - candidate = heapq.heappop(bboxes) - curr = heapq.heappop(bboxes) - idx = 0 - while len(bboxes) > 0: - idx += 1 - assert candidate.p.x <= curr.p.x or not candidate.same_row(curr) - - if candidate.ur_point.x <= curr.p.x or not candidate.same_row(curr): - res.append(candidate) - candidate = curr - curr = heapq.heappop(bboxes) - elif candidate.ur_point.x < curr.ur_point.x: - assert not (candidate.label != "text" and curr.label != "text") - if candidate.label == "text" and curr.label == "text": - candidate.w = curr.ur_point.x - candidate.p.x - curr = heapq.heappop(bboxes) - elif candidate.label != curr.label: - if candidate.label == "text": - candidate.w = curr.p.x - candidate.p.x - res.append(candidate) - candidate = curr - curr = heapq.heappop(bboxes) - else: - curr.w = curr.ur_point.x - candidate.ur_point.x - curr.p.x = candidate.ur_point.x - heapq.heappush(bboxes, curr) - curr = heapq.heappop(bboxes) - - elif candidate.ur_point.x >= curr.ur_point.x: - assert not (candidate.label != "text" and curr.label != "text") - - if candidate.label == "text": - assert curr.label != "text" - heapq.heappush( - bboxes, - Bbox( - curr.ur_point.x, - candidate.p.y, - candidate.h, - candidate.ur_point.x - curr.ur_point.x, - label="text", - confidence=candidate.confidence, - content=None, - ), - ) - candidate.w = curr.p.x - candidate.p.x - res.append(candidate) - candidate = curr - curr = heapq.heappop(bboxes) - else: - assert curr.label == "text" - curr = heapq.heappop(bboxes) - else: - assert False - res.append(candidate) - res.append(curr) - - # log results - for idx, bbox in enumerate(res): - bbox.content = str(idx) - draw_bboxes(Image.fromarray(img), res, name="after_split_confict.png") - - return res - - -def slice_from_image(img: np.ndarray, ocr_bboxes: List[Bbox]) -> List[np.ndarray]: - sliced_imgs = [] - for bbox in ocr_bboxes: - x, y = int(bbox.p.x), int(bbox.p.y) - w, h = int(bbox.w), int(bbox.h) - sliced_img = img[y : y + h, x : x + w] - sliced_imgs.append(sliced_img) - return sliced_imgs - - -def mix_inference( - img_path: str, - infer_config, - latex_det_model, - lang_ocr_models, - latex_rec_models, - 
accelerator="cpu", - num_beams=1, -) -> str: - ''' - Input a mixed image of formula text and output str (in markdown syntax) - ''' - global img - img = cv2.imread(img_path) - corners = [tuple(img[0, 0]), tuple(img[0, -1]), tuple(img[-1, 0]), tuple(img[-1, -1])] - bg_color = np.array(Counter(corners).most_common(1)[0][0]) - - start_time = time.time() - latex_bboxes = latex_det_predict(img_path, latex_det_model, infer_config) - end_time = time.time() - print(f"latex_det_model time: {end_time - start_time:.2f}s") - latex_bboxes = sorted(latex_bboxes) - # log results - draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(unmerged).png") - latex_bboxes = bbox_merge(latex_bboxes) - # log results - draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(merged).png") - masked_img = mask_img(img, latex_bboxes, bg_color) - - det_model, rec_model = lang_ocr_models - start_time = time.time() - det_prediction, _ = det_model(masked_img) - end_time = time.time() - print(f"ocr_det_model time: {end_time - start_time:.2f}s") - ocr_bboxes = [ - Bbox( - p[0][0], - p[0][1], - p[3][1] - p[0][1], - p[1][0] - p[0][0], - label="text", - confidence=None, - content=None, - ) - for p in det_prediction - ] - # log results - draw_bboxes(Image.fromarray(img), ocr_bboxes, name="ocr_bboxes(unmerged).png") - - ocr_bboxes = sorted(ocr_bboxes) - ocr_bboxes = bbox_merge(ocr_bboxes) - # log results - draw_bboxes(Image.fromarray(img), ocr_bboxes, name="ocr_bboxes(merged).png") - ocr_bboxes = split_conflict(ocr_bboxes, latex_bboxes) - ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes)) - - sliced_imgs: List[np.ndarray] = slice_from_image(img, ocr_bboxes) - start_time = time.time() - rec_predictions, _ = rec_model(sliced_imgs) - end_time = time.time() - print(f"ocr_rec_model time: {end_time - start_time:.2f}s") - - assert len(rec_predictions) == len(ocr_bboxes) - for content, bbox in zip(rec_predictions, ocr_bboxes): - bbox.content = content[0] - - latex_imgs = [] - for bbox in latex_bboxes: - latex_imgs.append(img[bbox.p.y : bbox.p.y + bbox.h, bbox.p.x : bbox.p.x + bbox.w]) - start_time = time.time() - latex_rec_res = latex_rec_predict( - *latex_rec_models, latex_imgs, accelerator, num_beams, max_tokens=800 - ) - end_time = time.time() - print(f"latex_rec_model time: {end_time - start_time:.2f}s") - - for bbox, content in zip(latex_bboxes, latex_rec_res): - bbox.content = to_katex(content) - if bbox.label == "embedding": - bbox.content = " $" + bbox.content + "$ " - elif bbox.label == "isolated": - bbox.content = '\n\n' + r"$$" + bbox.content + r"$$" + '\n\n' - - bboxes = sorted(ocr_bboxes + latex_bboxes) - if bboxes == []: - return "" - - md = "" - prev = Bbox(bboxes[0].p.x, bboxes[0].p.y, -1, -1, label="guard") - for curr in bboxes: - # Add the formula number back to the isolated formula - if prev.label == "isolated" and curr.label == "text" and prev.same_row(curr): - curr.content = curr.content.strip() - if curr.content.startswith('(') and curr.content.endswith(')'): - curr.content = curr.content[1:-1] - - if re.search(r'\\tag\{.*\}$', md[:-4]) is not None: - # in case of multiple tag - md = md[:-5] + f', {curr.content}' + '}' + md[-4:] - else: - md = md[:-4] + f'\\tag{{{curr.content}}}' + md[-4:] - continue - - if not prev.same_row(curr): - md += " " - - if curr.label == "embedding": - # remove the bold effect from inline formulas - curr.content = change_all(curr.content, r'\bm', r' ', r'{', r'}', r'', r' ') - curr.content = change_all(curr.content, r'\boldsymbol', r' ', r'{', r'}', 
r'', r' ') - curr.content = change_all(curr.content, r'\textit', r' ', r'{', r'}', r'', r' ') - curr.content = change_all(curr.content, r'\textbf', r' ', r'{', r'}', r'', r' ') - curr.content = change_all(curr.content, r'\textbf', r' ', r'{', r'}', r'', r' ') - curr.content = change_all(curr.content, r'\mathbf', r' ', r'{', r'}', r'', r' ') - - # change split environment into aligned - curr.content = curr.content.replace(r'\begin{split}', r'\begin{aligned}') - curr.content = curr.content.replace(r'\end{split}', r'\end{aligned}') - - # remove extra spaces (keeping only one) - curr.content = re.sub(r' +', ' ', curr.content) - assert curr.content.startswith(' $') and curr.content.endswith('$ ') - curr.content = ' $' + curr.content[2:-2].strip() + '$ ' - md += curr.content - prev = curr - return md.strip() diff --git a/texteller/models/thrid_party/paddleocr/infer/CTCLabelDecode.py b/texteller/paddleocr/CTCLabelDecode.py similarity index 99% rename from texteller/models/thrid_party/paddleocr/infer/CTCLabelDecode.py rename to texteller/paddleocr/CTCLabelDecode.py index de9a275..8e2b18b 100644 --- a/texteller/models/thrid_party/paddleocr/infer/CTCLabelDecode.py +++ b/texteller/paddleocr/CTCLabelDecode.py @@ -81,7 +81,7 @@ class BaseRecLabelDecode(object): word_list = [] word_col_list = [] state_list = [] - valid_col = np.where(selection == True)[0] + valid_col = np.where(selection)[0] for c_i, char in enumerate(text): if "\u4e00" <= char <= "\u9fff": diff --git a/texteller/models/thrid_party/paddleocr/infer/DBPostProcess.py b/texteller/paddleocr/DBPostProcess.py similarity index 100% rename from texteller/models/thrid_party/paddleocr/infer/DBPostProcess.py rename to texteller/paddleocr/DBPostProcess.py diff --git a/texteller/models/thrid_party/paddleocr/infer/operators.py b/texteller/paddleocr/operators.py similarity index 100% rename from texteller/models/thrid_party/paddleocr/infer/operators.py rename to texteller/paddleocr/operators.py diff --git a/texteller/models/thrid_party/paddleocr/infer/ppocr_keys_v1.txt b/texteller/paddleocr/ppocr_keys_v1.txt similarity index 100% rename from texteller/models/thrid_party/paddleocr/infer/ppocr_keys_v1.txt rename to texteller/paddleocr/ppocr_keys_v1.txt diff --git a/texteller/models/thrid_party/paddleocr/infer/predict_det.py b/texteller/paddleocr/predict_det.py similarity index 95% rename from texteller/models/thrid_party/paddleocr/infer/predict_det.py rename to texteller/paddleocr/predict_det.py index 284c673..07c5706 100755 --- a/texteller/models/thrid_party/paddleocr/infer/predict_det.py +++ b/texteller/paddleocr/predict_det.py @@ -12,25 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os -import sys - -__dir__ = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(__dir__) -sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../.."))) os.environ["FLAGS_allocator_strategy"] = "auto_growth" -import sys import time -import cv2 import numpy as np -# import tools.infer.utility as utility -import utility -from DBPostProcess import DBPostProcess -from operators import DetResizeForTest, KeepKeys, NormalizeImage, ToCHWImage -from utility import get_logger +from .DBPostProcess import DBPostProcess +from .operators import DetResizeForTest, KeepKeys, NormalizeImage, ToCHWImage +from .utility import create_predictor, get_logger def transform(data, ops=None): @@ -82,7 +73,7 @@ class TextDetector(object): self.input_tensor, self.output_tensors, self.config, - ) = utility.create_predictor(args, "det", logger) + ) = create_predictor(args, "det", logger) assert self.use_onnx if self.use_onnx: diff --git a/texteller/models/thrid_party/paddleocr/infer/predict_rec.py b/texteller/paddleocr/predict_rec.py similarity index 100% rename from texteller/models/thrid_party/paddleocr/infer/predict_rec.py rename to texteller/paddleocr/predict_rec.py diff --git a/texteller/models/thrid_party/paddleocr/infer/utility.py b/texteller/paddleocr/utility.py similarity index 100% rename from texteller/models/thrid_party/paddleocr/infer/utility.py rename to texteller/paddleocr/utility.py diff --git a/texteller/server.py b/texteller/server.py deleted file mode 100644 index d22a706..0000000 --- a/texteller/server.py +++ /dev/null @@ -1,155 +0,0 @@ -import sys -import argparse -import tempfile -import time -import numpy as np -import cv2 - -from pathlib import Path -from starlette.requests import Request -from ray import serve -from ray.serve.handle import DeploymentHandle -from onnxruntime import InferenceSession - -from texteller.models.ocr_model.utils.inference import inference as rec_inference -from texteller.models.det_model.inference import predict as det_inference -from texteller.models.ocr_model.model.TexTeller import TexTeller -from texteller.models.det_model.inference import PredictConfig -from texteller.models.ocr_model.utils.to_katex import to_katex - - -PYTHON_VERSION = str(sys.version_info.major) + '.' 
+ str(sys.version_info.minor) -LIBPATH = Path(sys.executable).parent.parent / 'lib' / ('python' + PYTHON_VERSION) / 'site-packages' -CUDNNPATH = LIBPATH / 'nvidia' / 'cudnn' / 'lib' - -parser = argparse.ArgumentParser() -parser.add_argument('-ckpt', '--checkpoint_dir', type=str) -parser.add_argument('-tknz', '--tokenizer_dir', type=str) -parser.add_argument('-port', '--server_port', type=int, default=8000) -parser.add_argument('--num_replicas', type=int, default=1) -parser.add_argument('--ncpu_per_replica', type=float, default=1.0) -parser.add_argument('--ngpu_per_replica', type=float, default=0.0) - -parser.add_argument('--inference-mode', type=str, default='cpu') -parser.add_argument('--num_beams', type=int, default=1) -parser.add_argument('-onnx', action='store_true', help='using onnx runtime') - -args = parser.parse_args() -if args.ngpu_per_replica > 0 and not args.inference_mode == 'cuda': - raise ValueError("--inference-mode must be cuda or mps if ngpu_per_replica > 0") - - -@serve.deployment( - num_replicas=args.num_replicas, - ray_actor_options={ - "num_cpus": args.ncpu_per_replica, - "num_gpus": args.ngpu_per_replica * 1.0 / 2, - }, -) -class TexTellerRecServer: - def __init__( - self, - checkpoint_path: str, - tokenizer_path: str, - inf_mode: str = 'cpu', - use_onnx: bool = False, - num_beams: int = 1, - ) -> None: - self.model = TexTeller.from_pretrained( - checkpoint_path, use_onnx=use_onnx, onnx_provider=inf_mode - ) - self.tokenizer = TexTeller.get_tokenizer(tokenizer_path) - self.inf_mode = inf_mode - self.num_beams = num_beams - - if not use_onnx: - self.model = self.model.to(inf_mode) if inf_mode != 'cpu' else self.model - - def predict(self, image_nparray) -> str: - return to_katex( - rec_inference( - self.model, - self.tokenizer, - [image_nparray], - accelerator=self.inf_mode, - num_beams=self.num_beams, - )[0] - ) - - -@serve.deployment( - num_replicas=args.num_replicas, - ray_actor_options={ - "num_cpus": args.ncpu_per_replica, - "num_gpus": args.ngpu_per_replica * 1.0 / 2, - "runtime_env": {"env_vars": {"LD_LIBRARY_PATH": f"{str(CUDNNPATH)}/:$LD_LIBRARY_PATH"}}, - }, -) -class TexTellerDetServer: - def __init__(self, inf_mode='cpu'): - self.infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml") - self.latex_det_model = InferenceSession( - "./models/det_model/model/rtdetr_r50vd_6x_coco.onnx", - providers=['CUDAExecutionProvider'] if inf_mode == 'cuda' else ['CPUExecutionProvider'], - ) - - async def predict(self, image_nparray) -> str: - with tempfile.TemporaryDirectory() as temp_dir: - img_path = f"{temp_dir}/temp_image.jpg" - cv2.imwrite(img_path, image_nparray) - - latex_bboxes = det_inference(img_path, self.latex_det_model, self.infer_config) - return latex_bboxes - - -@serve.deployment() -class Ingress: - def __init__(self, det_server: DeploymentHandle, rec_server: DeploymentHandle) -> None: - self.det_server = det_server - self.texteller_server = rec_server - - async def __call__(self, request: Request) -> str: - request_path = request.url.path - form = await request.form() - img_rb = await form['img'].read() - - img_nparray = np.frombuffer(img_rb, np.uint8) - img_nparray = cv2.imdecode(img_nparray, cv2.IMREAD_COLOR) - img_nparray = cv2.cvtColor(img_nparray, cv2.COLOR_BGR2RGB) - - if request_path.startswith("/fdet"): - if self.det_server is None: - return "[ERROR] rtdetr_r50vd_6x_coco.onnx not found." 
- pred = await self.det_server.predict.remote(img_nparray) - return pred - - elif request_path.startswith("/frec"): - pred = await self.texteller_server.predict.remote(img_nparray) - return pred - - else: - return "[ERROR] Invalid request path" - - -if __name__ == '__main__': - ckpt_dir = args.checkpoint_dir - tknz_dir = args.tokenizer_dir - - serve.start(http_options={"host": "0.0.0.0", "port": args.server_port}) - rec_server = TexTellerRecServer.bind( - ckpt_dir, - tknz_dir, - inf_mode=args.inference_mode, - use_onnx=args.onnx, - num_beams=args.num_beams, - ) - det_server = None - if Path('./models/det_model/model/rtdetr_r50vd_6x_coco.onnx').exists(): - det_server = TexTellerDetServer.bind(args.inference_mode) - ingress = Ingress.bind(det_server, rec_server) - - # ingress_handle = serve.run(ingress, route_prefix="/predict") - ingress_handle = serve.run(ingress, route_prefix="/") - - while True: - time.sleep(1) diff --git a/texteller/start_web.bat b/texteller/start_web.bat deleted file mode 100644 index e235cca..0000000 --- a/texteller/start_web.bat +++ /dev/null @@ -1,9 +0,0 @@ -@echo off -SETLOCAL ENABLEEXTENSIONS - -set CHECKPOINT_DIR=default -set TOKENIZER_DIR=default - -streamlit run web.py - -ENDLOCAL diff --git a/texteller/start_web.sh b/texteller/start_web.sh deleted file mode 100755 index 6ec8f7b..0000000 --- a/texteller/start_web.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/usr/bin/env bash -set -exu - -export CHECKPOINT_DIR="default" -export TOKENIZER_DIR="default" - -streamlit run web.py diff --git a/texteller/train_config.yaml b/texteller/train_config.yaml deleted file mode 100644 index 3197607..0000000 --- a/texteller/train_config.yaml +++ /dev/null @@ -1,14 +0,0 @@ -compute_environment: LOCAL_MACHINE -debug: false -distributed_type: MULTI_GPU -gpu_ids: all -num_processes: 1 -machine_rank: 0 -main_training_function: main -num_machines: 1 -rdzv_backend: static -same_network: true -tpu_env: [] -tpu_use_cluster: false -tpu_use_sudo: false -use_cpu: false diff --git a/texteller/types/__init__.py b/texteller/types/__init__.py new file mode 100644 index 0000000..a23110d --- /dev/null +++ b/texteller/types/__init__.py @@ -0,0 +1,12 @@ +from typing import TypeAlias + +from optimum.onnxruntime import ORTModelForVision2Seq +from transformers import VisionEncoderDecoderModel + +from .bbox import Bbox + + +TexTellerModel: TypeAlias = VisionEncoderDecoderModel | ORTModelForVision2Seq + + +__all__ = ["Bbox", "TexTellerModel"] diff --git a/texteller/models/det_model/Bbox.py b/texteller/types/bbox.py similarity index 62% rename from texteller/models/det_model/Bbox.py rename to texteller/types/bbox.py index 53d5735..e48981b 100644 --- a/texteller/models/det_model/Bbox.py +++ b/texteller/types/bbox.py @@ -1,10 +1,3 @@ -import os - -from PIL import Image, ImageDraw -from typing import List -from pathlib import Path - - class Point: def __init__(self, x: int, y: int): self.x = int(x) @@ -51,9 +44,9 @@ class Bbox: return 1.0 * abs(self.p.y - other.p.y) / max(self.h, other.h) < self.THREADHOLD def __lt__(self, other) -> bool: - ''' + """ from top to bottom, from left to right - ''' + """ if not self.same_row(other): return self.p.y < other.p.y else: @@ -61,29 +54,3 @@ class Bbox: def __repr__(self) -> str: return f"Bbox(upper_left_point={self.p}, h={self.h}, w={self.w}), label={self.label}, confident={self.confidence}, content={self.content})" - - -def draw_bboxes(img: Image.Image, bboxes: List[Bbox], name="annotated_image.png"): - curr_work_dir = Path(os.getcwd()) - log_dir = curr_work_dir / "logs" - 
log_dir.mkdir(exist_ok=True) - drawer = ImageDraw.Draw(img) - for bbox in bboxes: - # Calculate the coordinates for the rectangle to be drawn - left = bbox.p.x - top = bbox.p.y - right = bbox.p.x + bbox.w - bottom = bbox.p.y + bbox.h - - # Draw the rectangle on the image - drawer.rectangle([left, top, right, bottom], outline="green", width=1) - - # Optionally, add text label if it exists - if bbox.label: - drawer.text((left, top), bbox.label, fill="blue") - - if bbox.content: - drawer.text((left, bottom - 10), bbox.content[:10], fill="red") - - # Save the image with drawn rectangles - img.save(log_dir / name) diff --git a/texteller/utils/__init__.py b/texteller/utils/__init__.py new file mode 100644 index 0000000..be851b0 --- /dev/null +++ b/texteller/utils/__init__.py @@ -0,0 +1,26 @@ +from .device import get_device, cuda_available, mps_available, str2device +from .image import readimgs, transform +from .latex import change_all, remove_style, add_newlines +from .path import mkdir, resolve_path +from .misc import lines_dedent +from .bbox import mask_img, bbox_merge, split_conflict, slice_from_image, draw_bboxes + +__all__ = [ + "get_device", + "cuda_available", + "mps_available", + "str2device", + "readimgs", + "transform", + "change_all", + "remove_style", + "add_newlines", + "mkdir", + "resolve_path", + "lines_dedent", + "mask_img", + "bbox_merge", + "split_conflict", + "slice_from_image", + "draw_bboxes", +] diff --git a/texteller/utils/bbox.py b/texteller/utils/bbox.py new file mode 100644 index 0000000..951a024 --- /dev/null +++ b/texteller/utils/bbox.py @@ -0,0 +1,140 @@ +import heapq +import os +from pathlib import Path + +import numpy as np +from PIL import Image, ImageDraw + +from texteller.types import Bbox + +_MAXV = 999999999 + + +def mask_img(img, bboxes: list[Bbox], bg_color: np.ndarray) -> np.ndarray: + mask_img = img.copy() + for bbox in bboxes: + mask_img[bbox.p.y : bbox.p.y + bbox.h, bbox.p.x : bbox.p.x + bbox.w] = bg_color + return mask_img + + +def bbox_merge(sorted_bboxes: list[Bbox]) -> list[Bbox]: + if len(sorted_bboxes) == 0: + return [] + bboxes = sorted_bboxes.copy() + guard = Bbox(_MAXV, bboxes[-1].p.y, -1, -1, label="guard") + bboxes.append(guard) + res = [] + prev = bboxes[0] + for curr in bboxes: + if prev.ur_point.x <= curr.p.x or not prev.same_row(curr): + res.append(prev) + prev = curr + else: + prev.w = max(prev.w, curr.ur_point.x - prev.p.x) + return res + + +def split_conflict(ocr_bboxes: list[Bbox], latex_bboxes: list[Bbox]) -> list[Bbox]: + if latex_bboxes == []: + return ocr_bboxes + if ocr_bboxes == [] or len(ocr_bboxes) == 1: + return ocr_bboxes + + bboxes = sorted(ocr_bboxes + latex_bboxes) + + assert len(bboxes) > 1 + + heapq.heapify(bboxes) + res = [] + candidate = heapq.heappop(bboxes) + curr = heapq.heappop(bboxes) + while len(bboxes) > 0: + assert candidate.p.x <= curr.p.x or not candidate.same_row(curr) + + if candidate.ur_point.x <= curr.p.x or not candidate.same_row(curr): + res.append(candidate) + candidate = curr + curr = heapq.heappop(bboxes) + elif candidate.ur_point.x < curr.ur_point.x: + assert not (candidate.label != "text" and curr.label != "text") + if candidate.label == "text" and curr.label == "text": + candidate.w = curr.ur_point.x - candidate.p.x + curr = heapq.heappop(bboxes) + elif candidate.label != curr.label: + if candidate.label == "text": + candidate.w = curr.p.x - candidate.p.x + res.append(candidate) + candidate = curr + curr = heapq.heappop(bboxes) + else: + curr.w = curr.ur_point.x - 
candidate.ur_point.x + curr.p.x = candidate.ur_point.x + heapq.heappush(bboxes, curr) + curr = heapq.heappop(bboxes) + + elif candidate.ur_point.x >= curr.ur_point.x: + assert not (candidate.label != "text" and curr.label != "text") + + if candidate.label == "text": + assert curr.label != "text" + heapq.heappush( + bboxes, + Bbox( + curr.ur_point.x, + candidate.p.y, + candidate.h, + candidate.ur_point.x - curr.ur_point.x, + label="text", + confidence=candidate.confidence, + content=None, + ), + ) + candidate.w = curr.p.x - candidate.p.x + res.append(candidate) + candidate = curr + curr = heapq.heappop(bboxes) + else: + assert curr.label == "text" + curr = heapq.heappop(bboxes) + else: + assert False + res.append(candidate) + res.append(curr) + + return res + + +def slice_from_image(img: np.ndarray, ocr_bboxes: list[Bbox]) -> list[np.ndarray]: + sliced_imgs = [] + for bbox in ocr_bboxes: + x, y = int(bbox.p.x), int(bbox.p.y) + w, h = int(bbox.w), int(bbox.h) + sliced_img = img[y : y + h, x : x + w] + sliced_imgs.append(sliced_img) + return sliced_imgs + + +def draw_bboxes(img: Image.Image, bboxes: list[Bbox], name="annotated_image.png"): + curr_work_dir = Path(os.getcwd()) + log_dir = curr_work_dir / "logs" + log_dir.mkdir(exist_ok=True) + drawer = ImageDraw.Draw(img) + for bbox in bboxes: + # Calculate the coordinates for the rectangle to be drawn + left = bbox.p.x + top = bbox.p.y + right = bbox.p.x + bbox.w + bottom = bbox.p.y + bbox.h + + # Draw the rectangle on the image + drawer.rectangle([left, top, right, bottom], outline="green", width=1) + + # Optionally, add text label if it exists + if bbox.label: + drawer.text((left, top), bbox.label, fill="blue") + + if bbox.content: + drawer.text((left, bottom - 10), bbox.content[:10], fill="red") + + # Save the image with drawn rectangles + img.save(log_dir / name) diff --git a/texteller/utils/device.py b/texteller/utils/device.py new file mode 100644 index 0000000..e92434a --- /dev/null +++ b/texteller/utils/device.py @@ -0,0 +1,41 @@ +from typing import Literal + +import torch + + +def str2device(device_str: Literal["cpu", "cuda", "mps"]) -> torch.device: + if device_str == "cpu": + return torch.device("cpu") + elif device_str == "cuda": + return torch.device("cuda") + elif device_str == "mps": + return torch.device("mps") + else: + raise ValueError(f"Invalid device: {device_str}") + + +def get_device(device_index: int | None = None) -> torch.device: + """ + Automatically detect the best available device for inference. + + Args: + device_index: The index of GPU device to use if multiple are available. + Defaults to None, which uses the first available GPU. + + Returns: + torch.device: Selected device for model inference. 
+ """ + if cuda_available(): + return str2device("cuda") + elif mps_available(): + return str2device("mps") + else: + return str2device("cpu") + + +def cuda_available() -> bool: + return torch.cuda.is_available() + + +def mps_available() -> bool: + return torch.backends.mps.is_available() diff --git a/texteller/utils/image.py b/texteller/utils/image.py new file mode 100644 index 0000000..cc50a3c --- /dev/null +++ b/texteller/utils/image.py @@ -0,0 +1,121 @@ +from collections import Counter +from typing import List, Union + +import cv2 +import numpy as np +import torch +from PIL import Image +from torchvision.transforms import v2 + +from texteller.constants import ( + FIXED_IMG_SIZE, + IMG_CHANNELS, + IMAGE_MEAN, + IMAGE_STD, +) +from texteller.logger import get_logger + + +_logger = get_logger() + + +def readimgs(image_paths: list[str]) -> list[np.ndarray]: + """ + Read and preprocess a list of images from their file paths. + + This function reads each image from the provided paths, handles different + bit depths (converting 16-bit to 8-bit if necessary), and normalizes color + channels to RGB format regardless of the original color space (BGR, BGRA, + or grayscale). + + Args: + image_paths (list[str]): A list of file paths to the images to be read. + + Returns: + list[np.ndarray]: A list of NumPy arrays containing the preprocessed images + in RGB format. Images that could not be read are skipped. + """ + processed_images = [] + for path in image_paths: + image = cv2.imread(path, cv2.IMREAD_UNCHANGED) + if image is None: + raise ValueError(f"Image at {path} could not be read.") + if image.dtype == np.uint16: + _logger.warning(f'Converting {path} to 8-bit, image may be lossy.') + image = cv2.convertScaleAbs(image, alpha=(255.0 / 65535.0)) + + channels = 1 if len(image.shape) == 2 else image.shape[2] + if channels == 4: + image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB) + elif channels == 1: + image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) + elif channels == 3: + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + processed_images.append(image) + + return processed_images + + +def trim_white_border(image: np.ndarray) -> np.ndarray: + if len(image.shape) != 3 or image.shape[2] != 3: + raise ValueError("Image is not in RGB format or channel is not in third dimension") + + if image.dtype != np.uint8: + raise ValueError(f"Image should stored in uint8") + + corners = [tuple(image[0, 0]), tuple(image[0, -1]), tuple(image[-1, 0]), tuple(image[-1, -1])] + bg_color = Counter(corners).most_common(1)[0][0] + bg_color_np = np.array(bg_color, dtype=np.uint8) + + h, w = image.shape[:2] + bg = np.full((h, w, 3), bg_color_np, dtype=np.uint8) + + diff = cv2.absdiff(image, bg) + mask = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY) + + threshold = 15 + _, diff = cv2.threshold(mask, threshold, 255, cv2.THRESH_BINARY) + + x, y, w, h = cv2.boundingRect(diff) + + trimmed_image = image[y : y + h, x : x + w] + + return trimmed_image + + +def padding(images: List[torch.Tensor], required_size: int) -> List[torch.Tensor]: + images = [ + v2.functional.pad( + img, padding=[0, 0, required_size - img.shape[2], required_size - img.shape[1]] + ) + for img in images + ] + return images + + +def transform(images: List[Union[np.ndarray, Image.Image]]) -> List[torch.Tensor]: + general_transform_pipeline = v2.Compose( + [ + v2.ToImage(), + v2.ToDtype(torch.uint8, scale=True), + v2.Grayscale(), + v2.Resize( + size=FIXED_IMG_SIZE - 1, + interpolation=v2.InterpolationMode.BICUBIC, + max_size=FIXED_IMG_SIZE, + antialias=True, + ), + 
v2.ToDtype(torch.float32, scale=True), # Normalize expects float input + v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]), + ] + ) + + assert IMG_CHANNELS == 1, "Only grayscale images are supported for now" + images = [ + np.array(img.convert('RGB')) if isinstance(img, Image.Image) else img for img in images + ] + images = [trim_white_border(image) for image in images] + images = [general_transform_pipeline(image) for image in images] + images = padding(images, FIXED_IMG_SIZE) + + return images diff --git a/texteller/utils/latex.py b/texteller/utils/latex.py new file mode 100644 index 0000000..d778924 --- /dev/null +++ b/texteller/utils/latex.py @@ -0,0 +1,127 @@ +import re + + +def _change(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r): + result = "" + i = 0 + n = len(input_str) + + while i < n: + if input_str[i : i + len(old_inst)] == old_inst: + # check if the old_inst is followed by old_surr_l + start = i + len(old_inst) + else: + result += input_str[i] + i += 1 + continue + + if start < n and input_str[start] == old_surr_l: + # found an old_inst followed by old_surr_l, now look for the matching old_surr_r + count = 1 + j = start + 1 + escaped = False + while j < n and count > 0: + if input_str[j] == '\\' and not escaped: + escaped = True + j += 1 + continue + if input_str[j] == old_surr_r and not escaped: + count -= 1 + if count == 0: + break + elif input_str[j] == old_surr_l and not escaped: + count += 1 + escaped = False + j += 1 + + if count == 0: + assert j < n + assert input_str[start] == old_surr_l + assert input_str[j] == old_surr_r + inner_content = input_str[start + 1 : j] + # Replace the content with new pattern + result += new_inst + new_surr_l + inner_content + new_surr_r + i = j + 1 + continue + else: + assert count >= 1 + assert j == n + print("Warning: unbalanced delimiters in input string") + result += new_inst + new_surr_l + i = start + 1 + continue + else: + result += input_str[i:start] + i = start + + if old_inst != new_inst and (old_inst + old_surr_l) in result: + return _change(result, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r) + else: + return result + + +def _find_substring_positions(string, substring): + positions = [match.start() for match in re.finditer(re.escape(substring), string)] + return positions + + +def change_all(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r): + pos = _find_substring_positions(input_str, old_inst + old_surr_l) + res = list(input_str) + for p in pos[::-1]: + res[p:] = list( + _change( + ''.join(res[p:]), old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r + ) + ) + res = ''.join(res) + return res + + +def remove_style(input_str: str) -> str: + input_str = change_all(input_str, r"\bm", r" ", r"{", r"}", r"", r" ") + input_str = change_all(input_str, r"\boldsymbol", r" ", r"{", r"}", r"", r" ") + input_str = change_all(input_str, r"\textit", r" ", r"{", r"}", r"", r" ") + input_str = change_all(input_str, r"\textbf", r" ", r"{", r"}", r"", r" ") + input_str = change_all(input_str, r"\mathbf", r" ", r"{", r"}", r"", r" ") + output_str = input_str.strip() + return output_str + + +def add_newlines(latex_str: str) -> str: + """ + Adds newlines to a LaTeX string based on specific patterns, ensuring no + duplicate newlines are added around begin/end environments. 
+ - After \\ (if not already followed by newline) + - Before \\begin{...} (if not already preceded by newline) + - After \\begin{...} (if not already followed by newline) + - Before \\end{...} (if not already preceded by newline) + - After \\end{...} (if not already followed by newline) + + Args: + latex_str: The input LaTeX string. + + Returns: + The LaTeX string with added newlines, avoiding duplicates. + """ + processed_str = latex_str + + # 1. Replace whitespace around \begin{...} with \n...\n + # \s* matches zero or more whitespace characters (space, tab, newline) + # Captures the \begin{...} part in group 1 (\g<1>) + processed_str = re.sub(r"\s*(\\begin\{[^}]*\})\s*", r"\n\g<1>\n", processed_str) + + # 2. Replace whitespace around \end{...} with \n...\n + # Same logic as for \begin + processed_str = re.sub(r"\s*(\\end\{[^}]*\})\s*", r"\n\g<1>\n", processed_str) + + # 3. Add newline after \\ (if not already followed by newline) + processed_str = re.sub(r"\\\\(?!\n| )|\\\\ ", r"\\\\\n", processed_str) + + # 4. Cleanup: Collapse multiple consecutive newlines into a single newline. + # This handles cases where the replacements above might have created \n\n. + processed_str = re.sub(r'\n{2,}', '\n', processed_str) + + # Remove leading/trailing whitespace (including potential single newlines + # at the very start/end resulting from the replacements) from the entire result. + return processed_str.strip() diff --git a/texteller/utils/misc.py b/texteller/utils/misc.py new file mode 100644 index 0000000..a534242 --- /dev/null +++ b/texteller/utils/misc.py @@ -0,0 +1,5 @@ +from textwrap import dedent + + +def lines_dedent(s: str) -> str: + return dedent(s).strip() diff --git a/texteller/utils/path.py b/texteller/utils/path.py new file mode 100644 index 0000000..0416478 --- /dev/null +++ b/texteller/utils/path.py @@ -0,0 +1,52 @@ +from pathlib import Path +from typing import Literal + +from texteller.logger import get_logger + +_logger = get_logger(__name__) + + +def resolve_path(path: str | Path) -> str: + if isinstance(path, str): + path = Path(path) + return str(path.expanduser().resolve()) + + +def touch(path: str | Path) -> None: + if isinstance(path, str): + path = Path(path) + path.touch(exist_ok=True) + + +def mkdir(path: str | Path) -> None: + if isinstance(path, str): + path = Path(path) + path.mkdir(parents=True, exist_ok=True) + + +def rmfile(path: str | Path) -> None: + if isinstance(path, str): + path = Path(path) + path.unlink(missing_ok=False) + + +def rmdir(path: str | Path, mode: Literal["empty", "recursive"] = "empty") -> None: + """Remove a directory. + + Args: + path: Path to directory to remove + mode: "empty" to only remove empty directories, "recursive" to remove the directory and all its contents + """ + if isinstance(path, str): + path = Path(path) + + if mode == "empty": + path.rmdir() + _logger.info(f"Removed empty directory: {path}") + elif mode == "recursive": + import shutil + + shutil.rmtree(path) + _logger.info(f"Recursively removed directory and all contents: {path}") + else: + raise ValueError(f"Invalid mode: {mode}. 
Must be 'empty' or 'recursive'") diff --git a/texteller/web.py b/texteller/web.py deleted file mode 100644 index ff13a74..0000000 --- a/texteller/web.py +++ /dev/null @@ -1,275 +0,0 @@ -import os -import io -import re -import base64 -import tempfile -import shutil -import streamlit as st - -from PIL import Image -from streamlit_paste_button import paste_image_button as pbutton -from onnxruntime import InferenceSession -from texteller.models.thrid_party.paddleocr.infer import predict_det, predict_rec -from texteller.models.thrid_party.paddleocr.infer import utility - -from texteller.models.utils import mix_inference -from texteller.models.det_model.inference import PredictConfig - -from texteller.models.ocr_model.model.TexTeller import TexTeller -from texteller.models.ocr_model.utils.inference import inference as latex_recognition -from texteller.models.ocr_model.utils.to_katex import to_katex - - -st.set_page_config(page_title="TexTeller", page_icon="🧮") - -html_string = ''' -

- <h1 align="center">𝚃𝚎𝚡𝚃𝚎𝚕𝚕𝚎𝚛</h1> -''' - -suc_gif_html = ''' - <img src="..."> -''' - -fail_gif_html = ''' - <img src="..."> -

-''' - - -@st.cache_resource -def get_texteller(use_onnx, accelerator): - return TexTeller.from_pretrained( - os.environ['CHECKPOINT_DIR'], use_onnx=use_onnx, onnx_provider=accelerator - ) - - -@st.cache_resource -def get_tokenizer(): - return TexTeller.get_tokenizer(os.environ['TOKENIZER_DIR']) - - -@st.cache_resource -def get_det_models(accelerator): - infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml") - latex_det_model = InferenceSession( - "./models/det_model/model/rtdetr_r50vd_6x_coco.onnx", - providers=['CUDAExecutionProvider'] if accelerator == 'cuda' else ['CPUExecutionProvider'], - ) - return infer_config, latex_det_model - - -@st.cache_resource() -def get_ocr_models(accelerator): - use_gpu = accelerator == 'cuda' - - SIZE_LIMIT = 20 * 1024 * 1024 - det_model_dir = "./models/thrid_party/paddleocr/checkpoints/det/default_model.onnx" - rec_model_dir = "./models/thrid_party/paddleocr/checkpoints/rec/default_model.onnx" - # The CPU inference of the detection model will be faster than the GPU inference (in onnxruntime) - det_use_gpu = False - rec_use_gpu = use_gpu and not (os.path.getsize(rec_model_dir) < SIZE_LIMIT) - - paddleocr_args = utility.parse_args() - paddleocr_args.use_onnx = True - paddleocr_args.det_model_dir = det_model_dir - paddleocr_args.rec_model_dir = rec_model_dir - - paddleocr_args.use_gpu = det_use_gpu - detector = predict_det.TextDetector(paddleocr_args) - paddleocr_args.use_gpu = rec_use_gpu - recognizer = predict_rec.TextRecognizer(paddleocr_args) - return [detector, recognizer] - - -def get_image_base64(img_file): - buffered = io.BytesIO() - img_file.seek(0) - img = Image.open(img_file) - img.save(buffered, format="PNG") - return base64.b64encode(buffered.getvalue()).decode() - - -def on_file_upload(): - st.session_state["UPLOADED_FILE_CHANGED"] = True - - -def change_side_bar(): - st.session_state["CHANGE_SIDEBAR_FLAG"] = True - - -if "start" not in st.session_state: - st.session_state["start"] = 1 - st.toast('Hooray!', icon='🎉') - -if "UPLOADED_FILE_CHANGED" not in st.session_state: - st.session_state["UPLOADED_FILE_CHANGED"] = False - -if "CHANGE_SIDEBAR_FLAG" not in st.session_state: - st.session_state["CHANGE_SIDEBAR_FLAG"] = False - -if "INF_MODE" not in st.session_state: - st.session_state["INF_MODE"] = "Formula recognition" - - -############################## ############################## - -with st.sidebar: - num_beams = 1 - - st.markdown("# 🔨️ Config") - st.markdown("") - - inf_mode = st.selectbox( - "Inference mode", - ("Formula recognition", "Paragraph recognition"), - on_change=change_side_bar, - ) - - num_beams = st.number_input( - 'Number of beams', min_value=1, max_value=20, step=1, on_change=change_side_bar - ) - - device = st.radio("device", ("cpu", "cuda", "mps"), on_change=change_side_bar) - - st.markdown("## Seedup") - use_onnx = st.toggle("ONNX Runtime ") - - -############################## ############################## - - -################################ ################################ - -texteller = get_texteller(use_onnx, device) -tokenizer = get_tokenizer() -latex_rec_models = [texteller, tokenizer] - -if inf_mode == "Paragraph recognition": - infer_config, latex_det_model = get_det_models(device) - lang_ocr_models = get_ocr_models(device) - -st.markdown(html_string, unsafe_allow_html=True) - -uploaded_file = st.file_uploader(" ", type=['jpg', 'png'], on_change=on_file_upload) - -paste_result = pbutton( - label="📋 Paste an image", - background_color="#5BBCFF", - hover_background_color="#3498db", -) 
-st.write("") - -if st.session_state["CHANGE_SIDEBAR_FLAG"] is True: - st.session_state["CHANGE_SIDEBAR_FLAG"] = False -elif uploaded_file or paste_result.image_data is not None: - if st.session_state["UPLOADED_FILE_CHANGED"] is False and paste_result.image_data is not None: - uploaded_file = io.BytesIO() - paste_result.image_data.save(uploaded_file, format='PNG') - uploaded_file.seek(0) - - if st.session_state["UPLOADED_FILE_CHANGED"] is True: - st.session_state["UPLOADED_FILE_CHANGED"] = False - - img = Image.open(uploaded_file) - - temp_dir = tempfile.mkdtemp() - png_file_path = os.path.join(temp_dir, 'image.png') - img.save(png_file_path, 'PNG') - - with st.container(height=300): - img_base64 = get_image_base64(uploaded_file) - - st.markdown( - f""" - -
- <img src="data:image/png;base64,{img_base64}" alt="Input image"> -
- """, - unsafe_allow_html=True, - ) - st.markdown( - f""" - -
- <div align="center"> - Input image ({img.height}✖️{img.width}) - </div> -
- """, - unsafe_allow_html=True, - ) - - st.write("") - - with st.spinner("Predicting..."): - if inf_mode == "Formula recognition": - TexTeller_result = latex_recognition( - texteller, tokenizer, [png_file_path], accelerator=device, num_beams=num_beams - )[0] - katex_res = to_katex(TexTeller_result) - else: - katex_res = mix_inference( - png_file_path, - infer_config, - latex_det_model, - lang_ocr_models, - latex_rec_models, - device, - num_beams, - ) - - st.success('Completed!', icon="✅") - st.markdown(suc_gif_html, unsafe_allow_html=True) - st.text_area(":blue[*** 𝑃r𝑒d𝑖c𝑡e𝑑 𝑓o𝑟m𝑢l𝑎 ***]", katex_res, height=150) - - if inf_mode == "Formula recognition": - st.latex(katex_res) - elif inf_mode == "Paragraph recognition": - mixed_res = re.split(r'(\$\$.*?\$\$)', katex_res) - for text in mixed_res: - if text.startswith('$$') and text.endswith('$$'): - st.latex(text[2:-2]) - else: - st.markdown(text) - - st.write("") - st.write("") - - with st.expander(":star2: :gray[Tips for better results]"): - st.markdown(''' - * :mag_right: Use a clear and high-resolution image. - * :scissors: Crop images as accurately as possible. - * :jigsaw: Split large multi line formulas into smaller ones. - * :page_facing_up: Use images with **white background and black text** as much as possible. - * :book: Use a font with good readability. - ''') - shutil.rmtree(temp_dir) - - paste_result.image_data = None - -################################
################################
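With the old models/ocr_model/utils/inference.py entry point removed, the refactored pieces above compose into a short recognition loop. A minimal sketch, assuming the module layout introduced in this diff; the image path and the generation budget are illustrative placeholders, not values taken from the diff:

import torch

from texteller.models.texteller import TexTeller
from texteller.utils import readimgs, transform

# With no arguments, both loaders fall back to the Globals().repo_name checkpoint.
model = TexTeller.from_pretrained()
tokenizer = TexTeller.get_tokenizer()
model.eval()

# readimgs normalizes any input to 8-bit RGB; transform trims the white
# border, grayscales, resizes toward FIXED_IMG_SIZE, and pads to a square.
imgs = readimgs(["formula.png"])  # hypothetical example path
pixel_values = torch.stack(transform(imgs))

with torch.no_grad():
    pred = model.generate(pixel_values, max_new_tokens=512)  # illustrative budget
print(tokenizer.batch_decode(pred, skip_special_tokens=True)[0])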
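The latex helpers are easiest to review against concrete strings. A small sanity sketch of the intended behavior on toy inputs; the expected outputs below are worked out from the brace-matching and regex logic above, not taken from a test suite:

from texteller.utils import add_newlines, remove_style

# Style wrappers are unwrapped in place; padding spaces are inserted, not collapsed.
print(remove_style(r"\textbf{E} = \bm{m}c^2"))
# -> E  =  m c^2

# Environments and row separators each end up on their own line.
print(add_newlines(r"\begin{aligned}a \\ b\end{aligned}"))
# -> \begin{aligned}
#    a \\
#    b
#    \end{aligned}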
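Likewise, a toy check of the row-merge invariant in texteller/utils/bbox.py. This assumes, as the call sites above suggest, that Bbox takes (x, y, h, w) positionally with label as a keyword and defaults confidence and content to None:

from texteller.types import Bbox
from texteller.utils import bbox_merge

a = Bbox(0, 0, 20, 50, label="text")   # x, y, h, w
b = Bbox(40, 0, 20, 50, label="text")  # same row, overlapping a horizontally
merged = bbox_merge(sorted([a, b]))
print(len(merged), merged[0].w)  # 1 90: the overlapping boxes fuse into one span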