From 185b2e3db69d6e4ab715cf1558eaa80bc710ab1f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E4=B8=89=E6=B4=8B=E4=B8=89=E6=B4=8B?= <1258009915@qq.com>
Date: Sun, 21 Apr 2024 00:05:14 +0800
Subject: [PATCH] 1) Implement text-formula mixed recognition; 2) Refactor the
 project structure
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .gitignore                               |   8 +-
 requirements.txt                         |   2 +-
 src/gradio_web.py                        |  75 -------
 src/infer_det.py                         | 156 ++-------------
 src/inference.py                         |  37 ++--
 src/models/det_model/Bbox.py             |  85 ++++++++
 src/models/det_model/inference.py        | 161 ++++++++++++++
 src/models/det_model/model/infer_cfg.yml |   4 +-
 src/models/det_model/preprocess.py       |  11 +-
 src/models/ocr_model/utils/inference.py  |  19 +-
 src/models/ocr_model/utils/to_katex.py   | 110 ++++++++++
 src/models/utils/__init__.py             |   1 +
 src/models/utils/mix_inference.py        | 257 +++++++++++++++++++++++
 src/rec_infer_from_crop_imgs.py          |   4 +-
 src/server.py                            |   2 +-
 src/start_web.sh                         |   2 +-
 src/utils/__init__.py                    |   1 -
 src/utils/to_katex.py                    |  15 --
 src/web.py                               |  99 ++++++---
 19 files changed, 753 insertions(+), 296 deletions(-)
 delete mode 100644 src/gradio_web.py
 create mode 100644 src/models/det_model/Bbox.py
 create mode 100644 src/models/det_model/inference.py
 create mode 100644 src/models/ocr_model/utils/to_katex.py
 create mode 100644 src/models/utils/__init__.py
 create mode 100644 src/models/utils/mix_inference.py
 delete mode 100644 src/utils/__init__.py
 delete mode 100644 src/utils/to_katex.py

diff --git a/.gitignore b/.gitignore
index b35b345..f5ac228 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,10 +6,16 @@
 
 **/ckpt
 **/*cache
 **/.cache
+**/tmp
+**/log
 **/data
 **/logs
 **/tmp*
 **/data
 **/*cache
-**/ckpt
\ No newline at end of file
+**/ckpt
+
+**/*.bin
+**/*.safetensor
+**/*.onnx
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 9bd8876..f6bd23d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,4 +14,4 @@
 onnxruntime
 streamlit==1.30
 streamlit-paste-button
-easyocr
+surya-ocr
diff --git a/src/gradio_web.py b/src/gradio_web.py
deleted file mode 100644
index eef613f..0000000
--- a/src/gradio_web.py
+++ /dev/null
@@ -1,75 +0,0 @@
-import os
-
-import gradio as gr
-from models.ocr_model.utils.inference import inference
-from models.ocr_model.model.TexTeller import TexTeller
-from utils import to_katex
-from pathlib import Path
-
-
-# model = TexTeller.from_pretrained(os.environ['CHECKPOINT_DIR'])
-# tokenizer = TexTeller.get_tokenizer(os.environ['TOKENIZER_DIR'])
-
-
-css = """
-
-"""
-
-theme=gr.themes.Default(),
-
-def fn(img):
-    return img
-
-with gr.Blocks(
-    theme=theme,
-    css=css
-) as demo:
-    gr.HTML(f'''
-        {css}
-        <div>
-            𝚃𝚎𝚡𝚃𝚎𝚕𝚕𝚎𝚛
-        </div>
- ''') - - with gr.Row(equal_height=True): - input_img = gr.Image(type="pil", label="Input Image") - latex_img = gr.Image(label="Predicted Latex", show_label=False) - input_img.upload(fn, input_img, latex_img) - - gr.Markdown(r'$$\fcxrac{7}{10349}$$') - gr.Markdown('fooooooooooooooooooooooooooooo') - - -demo.launch() diff --git a/src/infer_det.py b/src/infer_det.py index 2410b81..a5047fc 100644 --- a/src/infer_det.py +++ b/src/infer_det.py @@ -1,35 +1,21 @@ import os -import yaml import argparse -import numpy as np import glob + from onnxruntime import InferenceSession -from tqdm import tqdm +from pathlib import Path +from models.det_model.inference import PredictConfig, predict_image -from models.det_model.preprocess import Compose -import cv2 - - -# Global dictionary -SUPPORT_MODELS = { - 'YOLO', 'PPYOLOE', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet', - 'S2ANet', 'JDE', 'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet', - 'TOOD', 'RetinaNet', 'StrongBaseline', 'STGCN', 'YOLOX', 'HRNet', - 'DETR' -} parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--infer_cfg", type=str, help="infer_cfg.yml", - default="./models/det_model/model/infer_cfg.yml" - ) + default="./models/det_model/model/infer_cfg.yml") parser.add_argument('--onnx_file', type=str, help="onnx model file path", - default="./models/det_model/model/rtdetr_r50vd_6x_coco.onnx" - ) + default="./models/det_model/model/rtdetr_r50vd_6x_coco.onnx") parser.add_argument("--image_dir", type=str) -parser.add_argument("--image_file", type=str, default='/data/ljm/TexTeller/src/Tr00_0001015-page02.jpg') -parser.add_argument("--imgsave_dir", type=str, - default="." - ) +parser.add_argument("--image_file", type=str, required=True) +parser.add_argument("--imgsave_dir", type=str, default="./detect_results") + def get_test_images(infer_dir, infer_img): """ @@ -62,125 +48,11 @@ def get_test_images(infer_dir, infer_img): return images -class PredictConfig(object): - """set config of preprocess, postprocess and visualize - Args: - infer_config (str): path of infer_cfg.yml - """ - - def __init__(self, infer_config): - # parsing Yaml config for Preprocess - with open(infer_config) as f: - yml_conf = yaml.safe_load(f) - self.check_model(yml_conf) - self.arch = yml_conf['arch'] - self.preprocess_infos = yml_conf['Preprocess'] - self.min_subgraph_size = yml_conf['min_subgraph_size'] - self.label_list = yml_conf['label_list'] - self.use_dynamic_shape = yml_conf['use_dynamic_shape'] - self.draw_threshold = yml_conf.get("draw_threshold", 0.5) - self.mask = yml_conf.get("mask", False) - self.tracker = yml_conf.get("tracker", None) - self.nms = yml_conf.get("NMS", None) - self.fpn_stride = yml_conf.get("fpn_stride", None) - - color_pool = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)] - self.colors = {label: color_pool[i % len(color_pool)] for i, label in enumerate(self.label_list)} - - if self.arch == 'RCNN' and yml_conf.get('export_onnx', False): - print( - 'The RCNN export model is used for ONNX and it only supports batch_size = 1' - ) - self.print_config() - - def check_model(self, yml_conf): - """ - Raises: - ValueError: loaded model not in supported model type - """ - for support_model in SUPPORT_MODELS: - if support_model in yml_conf['arch']: - return True - raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[ - 'arch'], SUPPORT_MODELS)) - - def print_config(self): - print('----------- Model Configuration -----------') - print('%s: %s' % ('Model Arch', self.arch)) - print('%s: ' % 
('Transform Order')) - for op_info in self.preprocess_infos: - print('--%s: %s' % ('transform op', op_info['type'])) - print('--------------------------------------------') - - -def draw_bbox(image, outputs, infer_config): - for output in outputs: - cls_id, score, xmin, ymin, xmax, ymax = output - if score > infer_config.draw_threshold: - label = infer_config.label_list[int(cls_id)] - color = infer_config.colors[label] - cv2.rectangle(image, (int(xmin), int(ymin)), (int(xmax), int(ymax)), color, 2) - cv2.putText(image, "{}: {:.2f}".format(label, score), - (int(xmin), int(ymin - 5)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2) - return image - - -def predict_image(infer_config, predictor, img_list): - # load preprocess transforms - transforms = Compose(infer_config.preprocess_infos) - errImgList = [] - - # Check and create subimg_save_dir if not exist - subimg_save_dir = os.path.join(FLAGS.imgsave_dir, 'subimages') - os.makedirs(subimg_save_dir, exist_ok=True) - - # predict image - for img_path in tqdm(img_list): - img = cv2.imread(img_path) - if img is None: - print(f"Warning: Could not read image {img_path}. Skipping...") - errImgList.append(img_path) - continue - - inputs = transforms(img_path) - inputs_name = [var.name for var in predictor.get_inputs()] - inputs = {k: inputs[k][None, ] for k in inputs_name} - - outputs = predictor.run(output_names=None, input_feed=inputs) - - print("ONNXRuntime predict: ") - if infer_config.arch in ["HRNet"]: - print(np.array(outputs[0])) - else: - bboxes = np.array(outputs[0]) - for bbox in bboxes: - if bbox[0] > -1 and bbox[1] > infer_config.draw_threshold: - print(f"{int(bbox[0])} {bbox[1]} " - f"{bbox[2]} {bbox[3]} {bbox[4]} {bbox[5]}") - - # Save the subimages (crop from the original image) - subimg_counter = 1 - for output in np.array(outputs[0]): - cls_id, score, xmin, ymin, xmax, ymax = output - if score > infer_config.draw_threshold: - label = infer_config.label_list[int(cls_id)] - subimg = img[int(ymin):int(ymax), int(xmin):int(xmax)] - subimg_filename = f"{os.path.splitext(os.path.basename(img_path))[0]}_{label}_{xmin:.2f}_{ymin:.2f}_{xmax:.2f}_{ymax:.2f}.jpg" - subimg_path = os.path.join(subimg_save_dir, subimg_filename) - cv2.imwrite(subimg_path, subimg) - subimg_counter += 1 - - # Draw bounding boxes and save the image with bounding boxes - img_with_bbox = draw_bbox(img, np.array(outputs[0]), infer_config) - output_dir = FLAGS.imgsave_dir - os.makedirs(output_dir, exist_ok=True) - output_file = os.path.join(output_dir, "output_" + os.path.basename(img_path)) - cv2.imwrite(output_file, img_with_bbox) - - print("ErrorImgs:") - print(errImgList) - if __name__ == '__main__': + cur_path = os.getcwd() + script_dirpath = Path(__file__).resolve().parent + os.chdir(script_dirpath) + FLAGS = parser.parse_args() # load image list img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file) @@ -189,4 +61,6 @@ if __name__ == '__main__': # load infer config infer_config = PredictConfig(FLAGS.infer_cfg) - predict_image(infer_config, predictor, img_list) + predict_image(FLAGS.imgsave_dir, infer_config, predictor, img_list) + + os.chdir(cur_path) diff --git a/src/inference.py b/src/inference.py index 127537e..b74a399 100644 --- a/src/inference.py +++ b/src/inference.py @@ -3,9 +3,18 @@ import argparse import cv2 as cv from pathlib import Path -from utils import to_katex +from onnxruntime import InferenceSession + +from models.utils import mix_inference +from models.ocr_model.utils.to_katex import to_katex from models.ocr_model.utils.inference import 
inference as latex_inference
+
 from models.ocr_model.model.TexTeller import TexTeller
+from models.det_model.inference import PredictConfig
+
+from surya.model.detection import segformer
+from surya.model.recognition.model import load_model
+from surya.model.recognition.processor import load_processor
 
 
 if __name__ == '__main__':
@@ -29,33 +38,35 @@
         default=1,
         help='number of beam search for decoding'
     )
-    # ================= new feature ==================
     parser.add_argument(
         '-mix',
-        type=str,
-        help='use mix mode, only Chinese and English are supported.'
+        action='store_true',
+        help='use mix mode'
    )
-    # ==================================================
     args = parser.parse_args()
 
     # You can use your own checkpoint and tokenizer path.
     print('Loading model and tokenizer...')
     latex_rec_model = TexTeller.from_pretrained()
-    latex_rec_model = TexTeller.from_pretrained()
     tokenizer = TexTeller.get_tokenizer()
     print('Model and tokenizer loaded.')
 
-    # img_path = [args.img]
-    img = cv.imread(args.img)
+    img_path = args.img
+    img = cv.imread(img_path)
 
     print('Inference...')
     if not args.mix:
         res = latex_inference(latex_rec_model, tokenizer, [img], args.inference_mode, args.num_beam)
         res = to_katex(res[0])
         print(res)
     else:
-        # latex_det_model = load_det_tex_model()
-        # lang_model = load_lang_models()...
-        ...
-        # res: str = mix_inference(latex_det_model, latex_rec_model, lang_model, img, args.cuda)
-        # print(res)
+        infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
+        latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco_IBEM_cnTextBook.onnx")
+
+        det_processor, det_model = segformer.load_processor(), segformer.load_model()
+        rec_model, rec_processor = load_model(), load_processor()
+        lang_ocr_models = [det_model, det_processor, rec_model, rec_processor]
+
+        latex_rec_models = [latex_rec_model, tokenizer]
+        res = mix_inference(img_path, "en", infer_config, latex_det_model, lang_ocr_models, latex_rec_models, args.inference_mode, args.num_beam)
+        print(res)
diff --git a/src/models/det_model/Bbox.py b/src/models/det_model/Bbox.py
new file mode 100644
index 0000000..93f4723
--- /dev/null
+++ b/src/models/det_model/Bbox.py
@@ -0,0 +1,85 @@
+from PIL import Image, ImageDraw
+from typing import List
+
+
+class Point:
+    def __init__(self, x: int, y: int):
+        self.x = int(x)
+        self.y = int(y)
+
+    def __repr__(self) -> str:
+        return f"Point(x={self.x}, y={self.y})"
+
+
+class Bbox:
+    THRESHOLD = 0.4
+
+    def __init__(self, x, y, h, w, label: str = None, confidence: float = 0, content: str = None):
+        self.p = Point(x, y)
+        self.h = int(h)
+        self.w = int(w)
+        self.label = label
+        self.confidence = confidence
+        self.content = content
+
+    @property
+    def ul_point(self) -> Point:
+        return self.p
+
+    @property
+    def ur_point(self) -> Point:
+        return Point(self.p.x + self.w, self.p.y)
+
+    @property
+    def ll_point(self) -> Point:
+        return Point(self.p.x, self.p.y + self.h)
+
+    @property
+    def lr_point(self) -> Point:
+        return Point(self.p.x + self.w, self.p.y + self.h)
+
+
+    def same_row(self, other) -> bool:
+        if (
+            (self.p.y >= other.p.y and self.ll_point.y <= other.ll_point.y)
+            or (self.p.y <= other.p.y and self.ll_point.y >= other.ll_point.y)
+        ):
+            return True
+        if self.ll_point.y <= other.p.y or self.p.y >= other.ll_point.y:
+            return False
+        return 1.0 * abs(self.p.y - other.p.y) / max(self.h, other.h) < self.THRESHOLD
+
+    def __lt__(self, other) -> bool:
+        '''
+        from top to bottom, from left to right
+        '''
+        if not self.same_row(other):
+            return self.p.y < other.p.y
+        else:
+            return self.p.x < other.p.x
+
+    def __repr__(self) -> str:
+        return f"Bbox(upper_left_point={self.p}, h={self.h}, w={self.w}, label={self.label}, confidence={self.confidence}, content={self.content})"
+
+
+def draw_bboxes(img: Image.Image, bboxes: List[Bbox], name="annotated_image.png"):
+    drawer = ImageDraw.Draw(img)
+    for bbox in bboxes:
+        # Calculate the coordinates for the rectangle to be drawn
+        left = bbox.p.x
+        top = bbox.p.y
+        right = bbox.p.x + bbox.w
+        bottom = bbox.p.y + bbox.h
+
+        # Draw the rectangle on the image
+        drawer.rectangle([left, top, right, bottom], outline="green", width=1)
+
+        # Optionally, add text label if it exists
+        if bbox.label:
+            drawer.text((left, top), bbox.label, fill="blue")
+
+        if bbox.content:
+            drawer.text((left, bottom - 10), bbox.content[:10], fill="red")
+
+    # Save the image with drawn rectangles
+    img.save(name)
\ No newline at end of file
diff --git a/src/models/det_model/inference.py b/src/models/det_model/inference.py
new file mode 100644
index 0000000..a142d16
--- /dev/null
+++ b/src/models/det_model/inference.py
@@ -0,0 +1,161 @@
+import os
+import yaml
+import numpy as np
+import cv2
+
+from tqdm import tqdm
+from typing import List
+from .preprocess import Compose
+from .Bbox import Bbox
+
+
+# Global dictionary
+SUPPORT_MODELS = {
+    'YOLO', 'PPYOLOE', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet',
+    'S2ANet', 'JDE', 'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet',
+    'TOOD', 'RetinaNet', 'StrongBaseline', 'STGCN', 'YOLOX', 'HRNet',
+    'DETR'
+}
+
+
+class PredictConfig(object):
+    """set config of preprocess, postprocess and visualize
+    Args:
+        infer_config (str): path of infer_cfg.yml
+    """
+
+    def __init__(self, infer_config):
+        # parsing Yaml config for Preprocess
+        with open(infer_config) as f:
+            yml_conf = yaml.safe_load(f)
+        self.check_model(yml_conf)
+        self.arch = yml_conf['arch']
+        self.preprocess_infos = yml_conf['Preprocess']
+        self.min_subgraph_size = yml_conf['min_subgraph_size']
+        self.label_list = yml_conf['label_list']
+        self.use_dynamic_shape = yml_conf['use_dynamic_shape']
+        self.draw_threshold = yml_conf.get("draw_threshold", 0.5)
+        self.mask = yml_conf.get("mask", False)
+        self.tracker = yml_conf.get("tracker", None)
+        self.nms = yml_conf.get("NMS", None)
+        self.fpn_stride = yml_conf.get("fpn_stride", None)
+
+        color_pool = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]
+        self.colors = {label: color_pool[i % len(color_pool)] for i, label in enumerate(self.label_list)}
+
+        if self.arch == 'RCNN' and yml_conf.get('export_onnx', False):
+            print(
+                'The RCNN export model is used for ONNX and it only supports batch_size = 1'
+            )
+        self.print_config()
+
+    def check_model(self, yml_conf):
+        """
+        Raises:
+            ValueError: loaded model not in supported model type
+        """
+        for support_model in SUPPORT_MODELS:
+            if support_model in yml_conf['arch']:
+                return True
+        raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[
+            'arch'], SUPPORT_MODELS))
+
+    def print_config(self):
+        print('----------- Model Configuration -----------')
+        print('%s: %s' % ('Model Arch', self.arch))
+        print('%s: ' % ('Transform Order'))
+        for op_info in self.preprocess_infos:
+            print('--%s: %s' % ('transform op', op_info['type']))
+        print('--------------------------------------------')
+
+
+def draw_bbox(image, outputs, infer_config):
+    for output in outputs:
+        cls_id, score, xmin, ymin, xmax, ymax = output
+        if score > infer_config.draw_threshold:
+            label = 
infer_config.label_list[int(cls_id)] + color = infer_config.colors[label] + cv2.rectangle(image, (int(xmin), int(ymin)), (int(xmax), int(ymax)), color, 2) + cv2.putText(image, "{}: {:.2f}".format(label, score), + (int(xmin), int(ymin - 5)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2) + return image + + +def predict_image(imgsave_dir, infer_config, predictor, img_list): + # load preprocess transforms + transforms = Compose(infer_config.preprocess_infos) + errImgList = [] + + # Check and create subimg_save_dir if not exist + subimg_save_dir = os.path.join(imgsave_dir, 'subimages') + os.makedirs(subimg_save_dir, exist_ok=True) + + # predict image + for img_path in tqdm(img_list): + img = cv2.imread(img_path) + if img is None: + print(f"Warning: Could not read image {img_path}. Skipping...") + errImgList.append(img_path) + continue + + inputs = transforms(img_path) + inputs_name = [var.name for var in predictor.get_inputs()] + inputs = {k: inputs[k][None, ] for k in inputs_name} + + outputs = predictor.run(output_names=None, input_feed=inputs) + + print("ONNXRuntime predict: ") + if infer_config.arch in ["HRNet"]: + print(np.array(outputs[0])) + else: + bboxes = np.array(outputs[0]) + for bbox in bboxes: + if bbox[0] > -1 and bbox[1] > infer_config.draw_threshold: + print(f"{int(bbox[0])} {bbox[1]} " + f"{bbox[2]} {bbox[3]} {bbox[4]} {bbox[5]}") + + # Save the subimages (crop from the original image) + subimg_counter = 1 + for output in np.array(outputs[0]): + cls_id, score, xmin, ymin, xmax, ymax = output + if score > infer_config.draw_threshold: + label = infer_config.label_list[int(cls_id)] + subimg = img[int(max(ymin, 0)):int(ymax), int(max(xmin, 0)):int(xmax)] + if len(subimg) == 0: + continue + + subimg_filename = f"{os.path.splitext(os.path.basename(img_path))[0]}_{label}_{xmin:.2f}_{ymin:.2f}_{xmax:.2f}_{ymax:.2f}.jpg" + subimg_path = os.path.join(subimg_save_dir, subimg_filename) + cv2.imwrite(subimg_path, subimg) + subimg_counter += 1 + + # Draw bounding boxes and save the image with bounding boxes + img_with_bbox = draw_bbox(img, np.array(outputs[0]), infer_config) + output_dir = imgsave_dir + os.makedirs(output_dir, exist_ok=True) + output_file = os.path.join(output_dir, "output_" + os.path.basename(img_path)) + cv2.imwrite(output_file, img_with_bbox) + + print("ErrorImgs:") + print(errImgList) + + +def predict(img_path: str, predictor, infer_config) -> List[Bbox]: + transforms = Compose(infer_config.preprocess_infos) + inputs = transforms(img_path) + inputs_name = [var.name for var in predictor.get_inputs()] + inputs = {k: inputs[k][None, ] for k in inputs_name} + + outputs = predictor.run(output_names=None, input_feed=inputs)[0] + res = [] + for output in outputs: + cls_name = infer_config.label_list[int(output[0])] + score = output[1] + xmin = int(max(output[2], 0)) + ymin = int(max(output[3], 0)) + xmax = int(output[4]) + ymax = int(output[5]) + if score > infer_config.draw_threshold: + res.append(Bbox(xmin, ymin, ymax - ymin, xmax - xmin, cls_name, score)) + + return res diff --git a/src/models/det_model/model/infer_cfg.yml b/src/models/det_model/model/infer_cfg.yml index 0c156fc..09e6603 100644 --- a/src/models/det_model/model/infer_cfg.yml +++ b/src/models/det_model/model/infer_cfg.yml @@ -8,8 +8,8 @@ Preprocess: - interp: 2 keep_ratio: false target_size: - - 640 - - 640 + - 1600 + - 1600 type: Resize - mean: - 0.0 diff --git a/src/models/det_model/preprocess.py b/src/models/det_model/preprocess.py index 3554b7f..6b72494 100644 --- a/src/models/det_model/preprocess.py +++ 
b/src/models/det_model/preprocess.py @@ -4,9 +4,14 @@ import copy def decode_image(img_path): - with open(img_path, 'rb') as f: - im_read = f.read() - data = np.frombuffer(im_read, dtype='uint8') + if isinstance(img_path, str): + with open(img_path, 'rb') as f: + im_read = f.read() + data = np.frombuffer(im_read, dtype='uint8') + else: + assert isinstance(img_path, np.ndarray) + data = img_path + im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) img_info = { diff --git a/src/models/ocr_model/utils/inference.py b/src/models/ocr_model/utils/inference.py index 92bc08c..10f1b0d 100644 --- a/src/models/ocr_model/utils/inference.py +++ b/src/models/ocr_model/utils/inference.py @@ -4,19 +4,22 @@ import numpy as np from transformers import RobertaTokenizerFast, GenerationConfig from typing import List, Union -from models.ocr_model.model.TexTeller import TexTeller -from models.ocr_model.utils.transforms import inference_transform -from models.ocr_model.utils.helpers import convert2rgb -from models.globals import MAX_TOKEN_SIZE +from .transforms import inference_transform +from .helpers import convert2rgb +from ..model.TexTeller import TexTeller +from ...globals import MAX_TOKEN_SIZE def inference( model: TexTeller, tokenizer: RobertaTokenizerFast, imgs: Union[List[str], List[np.ndarray]], - inf_mode: str = 'cpu', + accelerator: str = 'cpu', num_beams: int = 1, + max_tokens = None ) -> List[str]: + if imgs == []: + return [] model.eval() if isinstance(imgs[0], str): imgs = convert2rgb(imgs) @@ -26,11 +29,11 @@ def inference( imgs = inference_transform(imgs) pixel_values = torch.stack(imgs) - model = model.to(inf_mode) - pixel_values = pixel_values.to(inf_mode) + model = model.to(accelerator) + pixel_values = pixel_values.to(accelerator) generate_config = GenerationConfig( - max_new_tokens=MAX_TOKEN_SIZE, + max_new_tokens=MAX_TOKEN_SIZE if max_tokens is None else max_tokens, num_beams=num_beams, do_sample=False, pad_token_id=tokenizer.pad_token_id, diff --git a/src/models/ocr_model/utils/to_katex.py b/src/models/ocr_model/utils/to_katex.py new file mode 100644 index 0000000..5a0a95b --- /dev/null +++ b/src/models/ocr_model/utils/to_katex.py @@ -0,0 +1,110 @@ +import re + + +def change(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r): + result = "" + i = 0 + n = len(input_str) + + while i < n: + if input_str[i:i+len(old_inst)] == old_inst: + # check if the old_inst is followed by old_surr_l + start = i + len(old_inst) + else: + result += input_str[i] + i += 1 + continue + + if start < n and input_str[start] == old_surr_l: + # found an old_inst followed by old_surr_l, now look for the matching old_surr_r + count = 1 + j = start + 1 + escaped = False + while j < n and count > 0: + if input_str[j] == '\\' and not escaped: + escaped = True + j += 1 + continue + if input_str[j] == old_surr_r and not escaped: + count -= 1 + if count == 0: + break + elif input_str[j] == old_surr_l and not escaped: + count += 1 + escaped = False + j += 1 + + if count == 0: + assert j < n + assert input_str[start] == old_surr_l + assert input_str[j] == old_surr_r + inner_content = input_str[start + 1:j] + # Replace the content with new pattern + result += new_inst + new_surr_l + inner_content + new_surr_r + i = j + 1 + continue + else: + assert count > 1 + assert j == n + print("Warning: unbalanced surrogate pair in input string") + result += new_inst + new_surr_l + i = start + 1 + continue + else: + i = start + + if old_inst != new_inst and old_inst 
in result: + return change(result, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r) + else: + return result + + +def to_katex(formula: str) -> str: + res = formula + res = change(res, r'\mbox', r'', r'{', r'}', r'', r'') + + origin_instructions = [ + r'\Huge', + r'\huge', + r'\LARGE', + r'\Large', + r'\large', + r'\normalsize', + r'\small', + r'\footnotesize', + r'\scriptsize', + r'\tiny' + ] + for (old_ins, new_ins) in zip(origin_instructions, origin_instructions): + res = change(res, old_ins, new_ins, r'$', r'$', '{', '}') + res = change(res, r'\boldmath', r'\bm', r'$', r'$', r'{', r'}') + + origin_instructions = [ + r'\left', + r'\middle', + r'\right', + r'\big', + r'\Big', + r'\bigg', + r'\Bigg', + r'\bigl', + r'\Bigl', + r'\biggl', + r'\Biggl', + r'\bigm', + r'\Bigm', + r'\biggm', + r'\Biggm', + r'\bigr', + r'\Bigr', + r'\biggr', + r'\Biggr' + ] + for origin_ins in origin_instructions: + res = change(res, origin_ins, origin_ins, r'{', r'}', r'', r'') + + res = re.sub(r'\\\[(.*?)\\\]', r'\1\\newline', res) + + if res.endswith(r'\newline'): + res = res[:-8] + return res diff --git a/src/models/utils/__init__.py b/src/models/utils/__init__.py new file mode 100644 index 0000000..775dc11 --- /dev/null +++ b/src/models/utils/__init__.py @@ -0,0 +1 @@ +from .mix_inference import mix_inference \ No newline at end of file diff --git a/src/models/utils/mix_inference.py b/src/models/utils/mix_inference.py new file mode 100644 index 0000000..746237f --- /dev/null +++ b/src/models/utils/mix_inference.py @@ -0,0 +1,257 @@ +import re +import heapq +import cv2 +import numpy as np + +from onnxruntime import InferenceSession +from collections import Counter +from typing import List + +from PIL import Image +from surya.ocr import run_ocr +from surya.detection import batch_text_detection +from surya.input.processing import slice_polys_from_image, slice_bboxes_from_image +from surya.recognition import batch_recognition +from surya.model.detection import segformer +from surya.model.recognition.model import load_model +from surya.model.recognition.processor import load_processor + +from ..det_model.inference import PredictConfig +from ..det_model.inference import predict as latex_det_predict +from ..det_model.Bbox import Bbox, draw_bboxes + +from ..ocr_model.model.TexTeller import TexTeller +from ..ocr_model.utils.inference import inference as latex_rec_predict +from ..ocr_model.utils.to_katex import to_katex + +MAXV = 999999999 + + +def mask_img(img, bboxes: List[Bbox], bg_color: np.ndarray) -> np.ndarray: + mask_img = img.copy() + for bbox in bboxes: + mask_img[bbox.p.y:bbox.p.y + bbox.h, bbox.p.x:bbox.p.x + bbox.w] = bg_color + return mask_img + + +def bbox_merge(sorted_bboxes: List[Bbox]) -> List[Bbox]: + if (len(sorted_bboxes) == 0): + return [] + bboxes = sorted_bboxes.copy() + guard = Bbox(MAXV, bboxes[-1].p.y, -1, -1, label="guard") + bboxes.append(guard) + res = [] + prev = bboxes[0] + for curr in bboxes: + if prev.ur_point.x <= curr.p.x or not prev.same_row(curr): + res.append(prev) + prev = curr + else: + prev.w = max(prev.w, curr.ur_point.x - prev.p.x) + return res + + +def split_conflict(ocr_bboxes: List[Bbox], latex_bboxes: List[Bbox]) -> List[Bbox]: + if latex_bboxes == []: + return ocr_bboxes + if ocr_bboxes == [] or len(ocr_bboxes) == 1: + return ocr_bboxes + + bboxes = sorted(ocr_bboxes + latex_bboxes) + + ######## debug ######### + for idx, bbox in enumerate(bboxes): + bbox.content = str(idx) + draw_bboxes(Image.fromarray(img), bboxes, name="before_split_confict.png") + 
######## debug ########### + + assert len(bboxes) > 1 + + heapq.heapify(bboxes) + res = [] + candidate = heapq.heappop(bboxes) + curr = heapq.heappop(bboxes) + idx = 0 + while (len(bboxes) > 0): + idx += 1 + assert candidate.p.x < curr.p.x or not candidate.same_row(curr) + + if candidate.ur_point.x <= curr.p.x or not candidate.same_row(curr): + res.append(candidate) + candidate = curr + curr = heapq.heappop(bboxes) + elif candidate.ur_point.x < curr.ur_point.x: + assert not (candidate.label != "text" and curr.label != "text") + if candidate.label == "text" and curr.label == "text": + candidate.w = curr.ur_point.x - candidate.p.x + curr = heapq.heappop(bboxes) + elif candidate.label != curr.label: + if candidate.label == "text": + candidate.w = curr.p.x - candidate.p.x + res.append(candidate) + candidate = curr + curr = heapq.heappop(bboxes) + else: + curr.w = curr.ur_point.x - candidate.ur_point.x + curr.p.x = candidate.ur_point.x + heapq.heappush(bboxes, curr) + curr = heapq.heappop(bboxes) + + elif candidate.ur_point.x >= curr.ur_point.x: + assert not (candidate.label != "text" and curr.label != "text") + + if candidate.label == "text": + assert curr.label != "text" + heapq.heappush( + bboxes, + Bbox( + curr.ur_point.x, + candidate.p.y, + candidate.h, + candidate.ur_point.x - curr.ur_point.x, + label="text", + confidence=candidate.confidence, + content=None + ) + ) + candidate.w = curr.p.x - candidate.p.x + res.append(candidate) + candidate = curr + curr = heapq.heappop(bboxes) + else: + assert curr.label == "text" + curr = heapq.heappop(bboxes) + else: + assert False + res.append(candidate) + res.append(curr) + ######## debug ######### + for idx, bbox in enumerate(res): + bbox.content = str(idx) + draw_bboxes(Image.fromarray(img), res, name="after_split_confict.png") + ######## debug ########### + return res + + +def mix_inference( + img_path: str, + language: str, + infer_config, + latex_det_model, + + lang_ocr_models, + + latex_rec_models, + accelerator="cpu", + num_beams=1 +) -> str: + ''' + Input a mixed image of formula text and output str (in markdown syntax) + ''' + global img + img = cv2.imread(img_path) + corners = [tuple(img[0, 0]), tuple(img[0, -1]), + tuple(img[-1, 0]), tuple(img[-1, -1])] + bg_color = np.array(Counter(corners).most_common(1)[0][0]) + + latex_bboxes = latex_det_predict(img_path, latex_det_model, infer_config) + latex_bboxes = sorted(latex_bboxes) + draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(unmerged).png") + latex_bboxes = bbox_merge(latex_bboxes) + draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(merged).png") + masked_img = mask_img(img, latex_bboxes, bg_color) + + det_model, det_processor, rec_model, rec_processor = lang_ocr_models + images = [Image.fromarray(masked_img)] + det_prediction = batch_text_detection(images, det_model, det_processor)[0] + draw_bboxes(Image.fromarray(img), latex_bboxes, name="ocr_bboxes(unmerged).png") + + lang = [language] + slice_map = [] + all_slices = [] + all_langs = [] + ocr_bboxes = [ + Bbox( + p.bbox[0], p.bbox[1], p.bbox[3] - p.bbox[1], p.bbox[2] - p.bbox[0], + label="text", + confidence=p.confidence, + content=None + ) + for p in det_prediction.bboxes + ] + ocr_bboxes = sorted(ocr_bboxes) + ocr_bboxes = bbox_merge(ocr_bboxes) + draw_bboxes(Image.fromarray(img), ocr_bboxes, name="ocr_bboxes(merged).png") + ocr_bboxes = split_conflict(ocr_bboxes, latex_bboxes) + ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes)) + polygons = [ + [ + [bbox.ul_point.x, 
bbox.ul_point.y], + [bbox.ur_point.x, bbox.ur_point.y], + [bbox.lr_point.x, bbox.lr_point.y], + [bbox.ll_point.x, bbox.ll_point.y] + ] + for bbox in ocr_bboxes + ] + + slices = slice_polys_from_image(images[0], polygons) + slice_map.append(len(slices)) + all_slices.extend(slices) + all_langs.extend([lang] * len(slices)) + + rec_predictions, _ = batch_recognition(all_slices, all_langs, rec_model, rec_processor) + + assert len(rec_predictions) == len(ocr_bboxes) + for content, bbox in zip(rec_predictions, ocr_bboxes): + bbox.content = content + + latex_imgs =[] + for bbox in latex_bboxes: + latex_imgs.append(img[bbox.p.y:bbox.p.y + bbox.h, bbox.p.x:bbox.p.x + bbox.w]) + latex_rec_res = latex_rec_predict(*latex_rec_models, latex_imgs, accelerator, num_beams, max_tokens=200) + for bbox, content in zip(latex_bboxes, latex_rec_res): + bbox.content = to_katex(content) + if bbox.label == "embedding": + bbox.content = " $" + bbox.content + "$ " + elif bbox.label == "isolated": + bbox.content = '\n' + r"$$" + bbox.content + r"$$" + '\n' + + bboxes = sorted(ocr_bboxes + latex_bboxes) + if bboxes == []: + return "" + + md = "" + prev = Bbox(bboxes[0].p.x, bboxes[0].p.y, -1, -1, label="guard") + # prev = bboxes[0] + for curr in bboxes: + if not prev.same_row(curr): + md += "\n" + md += curr.content + if ( + prev.label == "isolated" + and curr.label == "text" + and bool(re.fullmatch(r"\([1-9]\d*?\)", curr.content)) + ): + md += '\n' + prev = curr + return md + + +if __name__ == '__main__': + img_path = "/Users/Leehy/Code/TexTeller/test3.png" + + # latex_det_model = InferenceSession("/Users/Leehy/Code/TexTeller/src/models/det_model/model/rtdetr_r50vd_6x_coco_trained_on_IBEM_en_papers.onnx") + latex_det_model = InferenceSession("/Users/Leehy/Code/TexTeller/src/models/det_model/model/rtdetr_r50vd_6x_coco.onnx") + infer_config = PredictConfig("/Users/Leehy/Code/TexTeller/src/models/det_model/model/infer_cfg.yml") + + det_processor, det_model = segformer.load_processor(), segformer.load_model() + rec_model, rec_processor = load_model(), load_processor() + lang_ocr_models = (det_model, det_processor, rec_model, rec_processor) + + texteller = TexTeller.from_pretrained() + tokenizer = TexTeller.get_tokenizer() + latex_rec_models = (texteller, tokenizer) + + res = mix_inference(img_path, "zh", infer_config, latex_det_model, lang_ocr_models, latex_rec_models) + print(res) + pause = 1 diff --git a/src/rec_infer_from_crop_imgs.py b/src/rec_infer_from_crop_imgs.py index 1cacbc6..73bfa73 100644 --- a/src/rec_infer_from_crop_imgs.py +++ b/src/rec_infer_from_crop_imgs.py @@ -2,7 +2,7 @@ import os import argparse import cv2 as cv from pathlib import Path -from utils import to_katex +from models.ocr_model.utils.to_katex import to_katex from models.ocr_model.utils.inference import inference as latex_inference from models.ocr_model.model.TexTeller import TexTeller @@ -46,7 +46,7 @@ if __name__ == '__main__': if img is not None: print(f'Inference for {filename}...') - res = latex_inference(latex_rec_model, tokenizer, [img], inf_mode=args.inference_mode, num_beams=args.num_beam) + res = latex_inference(latex_rec_model, tokenizer, [img], accelerator=args.inference_mode, num_beams=args.num_beam) res = to_katex(res[0]) # Save the recognition result to a text file diff --git a/src/server.py b/src/server.py index 520d908..52e0068 100644 --- a/src/server.py +++ b/src/server.py @@ -56,7 +56,7 @@ class TexTellerServer: def predict(self, image_nparray) -> str: return inference( self.model, self.tokenizer, [image_nparray], - 
inf_mode=self.inf_mode, num_beams=self.num_beams + accelerator=self.inf_mode, num_beams=self.num_beams )[0] diff --git a/src/start_web.sh b/src/start_web.sh index 41e6311..6ec8f7b 100755 --- a/src/start_web.sh +++ b/src/start_web.sh @@ -1,7 +1,7 @@ #!/usr/bin/env bash set -exu -export CHECKPOINT_DIR="/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-648000" +export CHECKPOINT_DIR="default" export TOKENIZER_DIR="default" streamlit run web.py diff --git a/src/utils/__init__.py b/src/utils/__init__.py deleted file mode 100644 index ba17b40..0000000 --- a/src/utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .to_katex import to_katex \ No newline at end of file diff --git a/src/utils/to_katex.py b/src/utils/to_katex.py deleted file mode 100644 index fe08a74..0000000 --- a/src/utils/to_katex.py +++ /dev/null @@ -1,15 +0,0 @@ -import numpy as np -import re - -def to_katex(formula: str) -> str: - res = formula - res = re.sub(r'\\mbox\{([^}]*)\}', r'\1', res) - res = re.sub(r'boldmath\$(.*?)\$', r'bm{\1}', res) - res = re.sub(r'\\\[(.*?)\\\]', r'\1\\newline', res) - - pattern = r'(\\(?:left|middle|right|big|Big|bigg|Bigg|bigl|Bigl|biggl|Biggl|bigm|Bigm|biggm|Biggm|bigr|Bigr|biggr|Biggr))\{([^}]*)\}' - replacement = r'\1\2' - res = re.sub(pattern, replacement, res) - if res.endswith(r'\newline'): - res = res[:-8] - return res diff --git a/src/web.py b/src/web.py index 379a609..d4a84f9 100644 --- a/src/web.py +++ b/src/web.py @@ -7,10 +7,18 @@ import streamlit as st from PIL import Image from streamlit_paste_button import paste_image_button as pbutton -from models.ocr_model.utils.inference import inference -from models.ocr_model.model.TexTeller import TexTeller -from utils import to_katex +from onnxruntime import InferenceSession +from models.utils import mix_inference +from models.det_model.inference import PredictConfig + +from models.ocr_model.model.TexTeller import TexTeller +from models.ocr_model.utils.inference import inference as latex_recognition +from models.ocr_model.utils.to_katex import to_katex + +from surya.model.detection import segformer +from surya.model.recognition.model import load_model +from surya.model.recognition.processor import load_processor st.set_page_config( page_title="TexTeller", @@ -42,13 +50,26 @@ fail_gif_html = ''' ''' @st.cache_resource -def get_model(): +def get_texteller(): return TexTeller.from_pretrained(os.environ['CHECKPOINT_DIR']) @st.cache_resource def get_tokenizer(): return TexTeller.get_tokenizer(os.environ['TOKENIZER_DIR']) +@st.cache_resource +def get_det_models(): + infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml") + latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx") + return infer_config, latex_det_model + +@st.cache_resource() +def get_ocr_models(): + det_processor, det_model = segformer.load_processor(), segformer.load_model() + rec_model, rec_processor = load_model(), load_processor() + lang_ocr_models = [det_model, det_processor, rec_model, rec_processor] + return lang_ocr_models + def get_image_base64(img_file): buffered = io.BytesIO() img_file.seek(0) @@ -62,9 +83,6 @@ def on_file_upload(): def change_side_bar(): st.session_state["CHANGE_SIDEBAR_FLAG"] = True -model = get_model() -tokenizer = get_tokenizer() - if "start" not in st.session_state: st.session_state["start"] = 1 st.toast('Hooray!', icon='🎉') @@ -75,31 +93,34 @@ if "UPLOADED_FILE_CHANGED" not in st.session_state: if "CHANGE_SIDEBAR_FLAG" not in st.session_state: 
st.session_state["CHANGE_SIDEBAR_FLAG"] = False +if "INF_MODE" not in st.session_state: + st.session_state["INF_MODE"] = "Formula only" + + # ============================ begin sidebar =============================== # with st.sidebar: num_beams = 1 - inf_mode = 'cpu' st.markdown("# 🔨️ Config") st.markdown("") - model_type = st.selectbox( - "Model type", - ("TexTeller", "None"), + inf_mode = st.selectbox( + "Inference mode", + ("Formula only", "Text formula mixed"), on_change=change_side_bar ) - if model_type == "TexTeller": - num_beams = st.number_input( - 'Number of beams', - min_value=1, - max_value=20, - step=1, - on_change=change_side_bar - ) - inf_mode = st.radio( - "Inference mode", + num_beams = st.number_input( + 'Number of beams', + min_value=1, + max_value=20, + step=1, + on_change=change_side_bar + ) + + accelerator = st.radio( + "Accelerator", ("cpu", "cuda", "mps"), on_change=change_side_bar ) @@ -107,9 +128,16 @@ with st.sidebar: # ============================ end sidebar =============================== # - # ============================ begin pages =============================== # +texteller = get_texteller() +tokenizer = get_tokenizer() +latex_rec_models = [texteller, tokenizer] + +if inf_mode == "Text formula mixed": + infer_config, latex_det_model = get_det_models() + lang_ocr_models = get_ocr_models() + st.markdown(html_string, unsafe_allow_html=True) uploaded_file = st.file_uploader( @@ -176,19 +204,26 @@ elif uploaded_file or paste_result.image_data is not None: st.write("") with st.spinner("Predicting..."): - uploaded_file.seek(0) - TexTeller_result = inference( - model, - tokenizer, - [png_file_path], - inf_mode=inf_mode, - num_beams=num_beams - )[0] + if inf_mode == "Formula only": + TexTeller_result = latex_recognition( + texteller, + tokenizer, + [png_file_path], + accelerator=accelerator, + num_beams=num_beams + )[0] + katex_res = to_katex(TexTeller_result) + else: + katex_res = mix_inference(png_file_path, "en", infer_config, latex_det_model, lang_ocr_models, latex_rec_models, accelerator, num_beams) + st.success('Completed!', icon="✅") st.markdown(suc_gif_html, unsafe_allow_html=True) - katex_res = to_katex(TexTeller_result) st.text_area(":blue[*** 𝑃r𝑒d𝑖c𝑡e𝑑 𝑓o𝑟m𝑢l𝑎 ***]", katex_res, height=150) - st.latex(katex_res) + + if inf_mode == "Formula only": + st.latex(katex_res) + elif inf_mode == "Text formula mixed": + st.markdown(katex_res) st.write("") st.write("")
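
Usage sketch (reviewer note, not part of the patch): the snippet below just re-wires the new `-mix` branch of src/inference.py so the whole mixed text-formula pipeline can be driven from a script. The image path and the "en" language code are illustrative assumptions; run it from src/ so the relative model paths resolve, and substitute whichever detector ONNX weights you actually ship.

    from onnxruntime import InferenceSession

    from models.utils import mix_inference
    from models.det_model.inference import PredictConfig
    from models.ocr_model.model.TexTeller import TexTeller
    from surya.model.detection import segformer
    from surya.model.recognition.model import load_model
    from surya.model.recognition.processor import load_processor

    # Formula detector: an RT-DETR checkpoint exported to ONNX, plus its
    # preprocessing config (inputs resized to 1600x1600 by this patch).
    infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
    latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")

    # surya text-line detector and recognizer handle the non-formula regions.
    det_processor, det_model = segformer.load_processor(), segformer.load_model()
    rec_model, rec_processor = load_model(), load_processor()
    lang_ocr_models = [det_model, det_processor, rec_model, rec_processor]

    # TexTeller recognizes each detected formula crop.
    latex_rec_models = [TexTeller.from_pretrained(), TexTeller.get_tokenizer()]

    # Result is markdown: recognized text with $...$ (embedding) and
    # $$...$$ (isolated) formulas spliced back in reading order.
    md = mix_inference("page.png", "en", infer_config, latex_det_model,
                       lang_ocr_models, latex_rec_models, accelerator="cpu", num_beams=1)
    print(md)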
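The reading order that mix_inference relies on comes from Bbox.same_row and Bbox.__lt__ in the new Bbox.py; a small self-contained illustration of those semantics (the coordinates are made up):

    from models.det_model.Bbox import Bbox

    a = Bbox(10, 100, 20, 50, label="text")       # arguments: x, y, h, w
    b = Bbox(70, 105, 20, 40, label="embedding")  # 5 px lower, same visual row
    c = Bbox(10, 140, 20, 50, label="text")       # clearly on the next row

    # same_row: vertical offset (5) / max height (20) = 0.25 < THRESHOLD (0.4)
    assert a.same_row(b)
    # __lt__ orders boxes top-to-bottom, then left-to-right within a row
    assert sorted([c, b, a]) == [a, b, c]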
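Finally, two before/after pairs for the new to_katex() normalizer, traced by hand from the rewrite rules in src/models/ocr_model/utils/to_katex.py (the inputs are made-up TexTeller-style outputs):

    from models.ocr_model.utils.to_katex import to_katex

    # \mbox{...} is unwrapped and \boldmath$...$ is rewritten to \bm{...}
    print(to_katex(r'\mbox{speed} = \boldmath$v$'))       # -> speed = \bm{v}

    # Braces around delimiters after \left/\right-family commands are stripped
    print(to_katex(r'\left{(} \frac{a}{b} \right{)}'))    # -> \left( \frac{a}{b} \right)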