From 185b2e3db69d6e4ab715cf1558eaa80bc710ab1f Mon Sep 17 00:00:00 2001
From: 三洋三洋 <1258009915@qq.com>
Date: Sun, 21 Apr 2024 00:05:14 +0800
Subject: [PATCH] 1) Implement mixed text-formula recognition; 2) Restructure the project layout
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
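Notes: the mixed mode chains three stages: an RT-DETR ONNX model detects
embedded/isolated formula regions, surya-ocr recognizes the surrounding text,
and TexTeller transcribes each formula crop; models/utils/mix_inference.py
stitches the results into Markdown. A minimal invocation sketch (assuming the
detector weights sit under src/models/det_model/model/ and the image flag is
spelled -img, matching the args.img usage in the diff below):

    cd src
    python inference.py -img page.png -mix
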
.gitignore | 8 +-
requirements.txt | 2 +-
src/gradio_web.py | 75 -------
src/infer_det.py | 156 ++------------
src/inference.py | 37 ++--
src/models/det_model/Bbox.py | 85 ++++++++
src/models/det_model/inference.py | 161 ++++++++++++++
src/models/det_model/model/infer_cfg.yml | 4 +-
src/models/det_model/preprocess.py | 11 +-
src/models/ocr_model/utils/inference.py | 19 +-
src/models/ocr_model/utils/to_katex.py | 110 ++++++++++
src/models/utils/__init__.py | 1 +
src/models/utils/mix_inference.py | 257 +++++++++++++++++++++++
src/rec_infer_from_crop_imgs.py | 4 +-
src/server.py | 2 +-
src/start_web.sh | 2 +-
src/utils/__init__.py | 1 -
src/utils/to_katex.py | 15 --
src/web.py | 99 ++++++---
19 files changed, 753 insertions(+), 296 deletions(-)
delete mode 100644 src/gradio_web.py
create mode 100644 src/models/det_model/Bbox.py
create mode 100644 src/models/det_model/inference.py
create mode 100644 src/models/ocr_model/utils/to_katex.py
create mode 100644 src/models/utils/__init__.py
create mode 100644 src/models/utils/mix_inference.py
delete mode 100644 src/utils/__init__.py
delete mode 100644 src/utils/to_katex.py
diff --git a/.gitignore b/.gitignore
index b35b345..f5ac228 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,10 +6,16 @@
**/ckpt
**/*cache
**/.cache
+**/tmp
+**/log
**/data
**/logs
**/tmp*
**/data
**/*cache
-**/ckpt
\ No newline at end of file
+**/ckpt
+
+**/*.bin
+**/*.safetensors
+**/*.onnx
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 9bd8876..f6bd23d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,4 +14,4 @@ onnxruntime
streamlit==1.30
streamlit-paste-button
-easyocr
+surya-ocr
diff --git a/src/gradio_web.py b/src/gradio_web.py
deleted file mode 100644
index eef613f..0000000
--- a/src/gradio_web.py
+++ /dev/null
@@ -1,75 +0,0 @@
-import os
-
-import gradio as gr
-from models.ocr_model.utils.inference import inference
-from models.ocr_model.model.TexTeller import TexTeller
-from utils import to_katex
-from pathlib import Path
-
-
-# model = TexTeller.from_pretrained(os.environ['CHECKPOINT_DIR'])
-# tokenizer = TexTeller.get_tokenizer(os.environ['TOKENIZER_DIR'])
-
-
-css = """
-
-"""
-
-theme=gr.themes.Default(),
-
-def fn(img):
- return img
-
-with gr.Blocks(
- theme=theme,
- css=css
-) as demo:
- gr.HTML(f'''
- {css}
-
-        <h1 style="text-align: center;">𝚃𝚎𝚡𝚃𝚎𝚕𝚕𝚎𝚛</h1>
-
- ''')
-
- with gr.Row(equal_height=True):
- input_img = gr.Image(type="pil", label="Input Image")
- latex_img = gr.Image(label="Predicted Latex", show_label=False)
- input_img.upload(fn, input_img, latex_img)
-
- gr.Markdown(r'$$\fcxrac{7}{10349}$$')
- gr.Markdown('fooooooooooooooooooooooooooooo')
-
-
-demo.launch()
diff --git a/src/infer_det.py b/src/infer_det.py
index 2410b81..a5047fc 100644
--- a/src/infer_det.py
+++ b/src/infer_det.py
@@ -1,35 +1,21 @@
import os
-import yaml
import argparse
-import numpy as np
import glob
+
from onnxruntime import InferenceSession
-from tqdm import tqdm
+from pathlib import Path
+from models.det_model.inference import PredictConfig, predict_image
-from models.det_model.preprocess import Compose
-import cv2
-
-
-# Global dictionary
-SUPPORT_MODELS = {
- 'YOLO', 'PPYOLOE', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet',
- 'S2ANet', 'JDE', 'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet',
- 'TOOD', 'RetinaNet', 'StrongBaseline', 'STGCN', 'YOLOX', 'HRNet',
- 'DETR'
-}
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--infer_cfg", type=str, help="infer_cfg.yml",
- default="./models/det_model/model/infer_cfg.yml"
- )
+ default="./models/det_model/model/infer_cfg.yml")
parser.add_argument('--onnx_file', type=str, help="onnx model file path",
- default="./models/det_model/model/rtdetr_r50vd_6x_coco.onnx"
- )
+ default="./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
parser.add_argument("--image_dir", type=str)
-parser.add_argument("--image_file", type=str, default='/data/ljm/TexTeller/src/Tr00_0001015-page02.jpg')
-parser.add_argument("--imgsave_dir", type=str,
- default="."
- )
+parser.add_argument("--image_file", type=str, required=True)
+parser.add_argument("--imgsave_dir", type=str, default="./detect_results")
+
def get_test_images(infer_dir, infer_img):
"""
@@ -62,125 +48,11 @@ def get_test_images(infer_dir, infer_img):
return images
-class PredictConfig(object):
- """set config of preprocess, postprocess and visualize
- Args:
- infer_config (str): path of infer_cfg.yml
- """
-
- def __init__(self, infer_config):
- # parsing Yaml config for Preprocess
- with open(infer_config) as f:
- yml_conf = yaml.safe_load(f)
- self.check_model(yml_conf)
- self.arch = yml_conf['arch']
- self.preprocess_infos = yml_conf['Preprocess']
- self.min_subgraph_size = yml_conf['min_subgraph_size']
- self.label_list = yml_conf['label_list']
- self.use_dynamic_shape = yml_conf['use_dynamic_shape']
- self.draw_threshold = yml_conf.get("draw_threshold", 0.5)
- self.mask = yml_conf.get("mask", False)
- self.tracker = yml_conf.get("tracker", None)
- self.nms = yml_conf.get("NMS", None)
- self.fpn_stride = yml_conf.get("fpn_stride", None)
-
- color_pool = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]
- self.colors = {label: color_pool[i % len(color_pool)] for i, label in enumerate(self.label_list)}
-
- if self.arch == 'RCNN' and yml_conf.get('export_onnx', False):
- print(
- 'The RCNN export model is used for ONNX and it only supports batch_size = 1'
- )
- self.print_config()
-
- def check_model(self, yml_conf):
- """
- Raises:
- ValueError: loaded model not in supported model type
- """
- for support_model in SUPPORT_MODELS:
- if support_model in yml_conf['arch']:
- return True
- raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[
- 'arch'], SUPPORT_MODELS))
-
- def print_config(self):
- print('----------- Model Configuration -----------')
- print('%s: %s' % ('Model Arch', self.arch))
- print('%s: ' % ('Transform Order'))
- for op_info in self.preprocess_infos:
- print('--%s: %s' % ('transform op', op_info['type']))
- print('--------------------------------------------')
-
-
-def draw_bbox(image, outputs, infer_config):
- for output in outputs:
- cls_id, score, xmin, ymin, xmax, ymax = output
- if score > infer_config.draw_threshold:
- label = infer_config.label_list[int(cls_id)]
- color = infer_config.colors[label]
- cv2.rectangle(image, (int(xmin), int(ymin)), (int(xmax), int(ymax)), color, 2)
- cv2.putText(image, "{}: {:.2f}".format(label, score),
- (int(xmin), int(ymin - 5)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
- return image
-
-
-def predict_image(infer_config, predictor, img_list):
- # load preprocess transforms
- transforms = Compose(infer_config.preprocess_infos)
- errImgList = []
-
- # Check and create subimg_save_dir if not exist
- subimg_save_dir = os.path.join(FLAGS.imgsave_dir, 'subimages')
- os.makedirs(subimg_save_dir, exist_ok=True)
-
- # predict image
- for img_path in tqdm(img_list):
- img = cv2.imread(img_path)
- if img is None:
- print(f"Warning: Could not read image {img_path}. Skipping...")
- errImgList.append(img_path)
- continue
-
- inputs = transforms(img_path)
- inputs_name = [var.name for var in predictor.get_inputs()]
- inputs = {k: inputs[k][None, ] for k in inputs_name}
-
- outputs = predictor.run(output_names=None, input_feed=inputs)
-
- print("ONNXRuntime predict: ")
- if infer_config.arch in ["HRNet"]:
- print(np.array(outputs[0]))
- else:
- bboxes = np.array(outputs[0])
- for bbox in bboxes:
- if bbox[0] > -1 and bbox[1] > infer_config.draw_threshold:
- print(f"{int(bbox[0])} {bbox[1]} "
- f"{bbox[2]} {bbox[3]} {bbox[4]} {bbox[5]}")
-
- # Save the subimages (crop from the original image)
- subimg_counter = 1
- for output in np.array(outputs[0]):
- cls_id, score, xmin, ymin, xmax, ymax = output
- if score > infer_config.draw_threshold:
- label = infer_config.label_list[int(cls_id)]
- subimg = img[int(ymin):int(ymax), int(xmin):int(xmax)]
- subimg_filename = f"{os.path.splitext(os.path.basename(img_path))[0]}_{label}_{xmin:.2f}_{ymin:.2f}_{xmax:.2f}_{ymax:.2f}.jpg"
- subimg_path = os.path.join(subimg_save_dir, subimg_filename)
- cv2.imwrite(subimg_path, subimg)
- subimg_counter += 1
-
- # Draw bounding boxes and save the image with bounding boxes
- img_with_bbox = draw_bbox(img, np.array(outputs[0]), infer_config)
- output_dir = FLAGS.imgsave_dir
- os.makedirs(output_dir, exist_ok=True)
- output_file = os.path.join(output_dir, "output_" + os.path.basename(img_path))
- cv2.imwrite(output_file, img_with_bbox)
-
- print("ErrorImgs:")
- print(errImgList)
-
if __name__ == '__main__':
+ cur_path = os.getcwd()
+ script_dirpath = Path(__file__).resolve().parent
+ os.chdir(script_dirpath)
+
FLAGS = parser.parse_args()
# load image list
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
@@ -189,4 +61,6 @@ if __name__ == '__main__':
# load infer config
infer_config = PredictConfig(FLAGS.infer_cfg)
- predict_image(infer_config, predictor, img_list)
+ predict_image(FLAGS.imgsave_dir, infer_config, predictor, img_list)
+
+ os.chdir(cur_path)
diff --git a/src/inference.py b/src/inference.py
index 127537e..b74a399 100644
--- a/src/inference.py
+++ b/src/inference.py
@@ -3,9 +3,18 @@ import argparse
import cv2 as cv
from pathlib import Path
-from utils import to_katex
+from onnxruntime import InferenceSession
+
+from models.utils import mix_inference
+from models.ocr_model.utils.to_katex import to_katex
from models.ocr_model.utils.inference import inference as latex_inference
+
from models.ocr_model.model.TexTeller import TexTeller
+from models.det_model.inference import PredictConfig
+
+from surya.model.detection import segformer
+from surya.model.recognition.model import load_model
+from surya.model.recognition.processor import load_processor
if __name__ == '__main__':
@@ -29,33 +38,35 @@ if __name__ == '__main__':
default=1,
help='number of beam search for decoding'
)
- # ================= new feature ==================
parser.add_argument(
'-mix',
- type=str,
- help='use mix mode, only Chinese and English are supported.'
+ action='store_true',
+    help='enable mixed text-formula recognition'
)
- # ==================================================
args = parser.parse_args()
# You can use your own checkpoint and tokenizer path.
print('Loading model and tokenizer...')
latex_rec_model = TexTeller.from_pretrained()
- latex_rec_model = TexTeller.from_pretrained()
tokenizer = TexTeller.get_tokenizer()
print('Model and tokenizer loaded.')
- # img_path = [args.img]
- img = cv.imread(args.img)
+ img_path = args.img
+ img = cv.imread(img_path)
print('Inference...')
if not args.mix:
res = latex_inference(latex_rec_model, tokenizer, [img], args.inference_mode, args.num_beam)
res = to_katex(res[0])
print(res)
else:
- # latex_det_model = load_det_tex_model()
- # lang_model = load_lang_models()...
- ...
- # res: str = mix_inference(latex_det_model, latex_rec_model, lang_model, img, args.cuda)
- # print(res)
+ infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
+ latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco_IBEM_cnTextBook.onnx")
+
+ det_processor, det_model = segformer.load_processor(), segformer.load_model()
+ rec_model, rec_processor = load_model(), load_processor()
+ lang_ocr_models = [det_model, det_processor, rec_model, rec_processor]
+
+ latex_rec_models = [latex_rec_model, tokenizer]
+        res = mix_inference(img_path, "en", infer_config, latex_det_model, lang_ocr_models, latex_rec_models, args.inference_mode, args.num_beam)
+ print(res)
diff --git a/src/models/det_model/Bbox.py b/src/models/det_model/Bbox.py
new file mode 100644
index 0000000..93f4723
--- /dev/null
+++ b/src/models/det_model/Bbox.py
@@ -0,0 +1,85 @@
+from PIL import Image, ImageDraw
+from typing import List
+
+
+class Point:
+ def __init__(self, x: int, y: int):
+ self.x = int(x)
+ self.y = int(y)
+
+ def __repr__(self) -> str:
+ return f"Point(x={self.x}, y={self.y})"
+
+
+class Bbox:
+    THRESHOLD = 0.4
+
+ def __init__(self, x, y, h, w, label: str = None, confidence: float = 0, content: str = None):
+ self.p = Point(x, y)
+ self.h = int(h)
+ self.w = int(w)
+ self.label = label
+ self.confidence = confidence
+ self.content = content
+
+ @property
+ def ul_point(self) -> Point:
+ return self.p
+
+ @property
+ def ur_point(self) -> Point:
+ return Point(self.p.x + self.w, self.p.y)
+
+ @property
+ def ll_point(self) -> Point:
+ return Point(self.p.x, self.p.y + self.h)
+
+ @property
+ def lr_point(self) -> Point:
+ return Point(self.p.x + self.w, self.p.y + self.h)
+
+
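+    # Same row if one box vertically contains the other; vertically disjoint
+    # boxes never share a row; otherwise the top-edge offset must stay under
+    # THRESHOLD of the taller box's height.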
+ def same_row(self, other) -> bool:
+ if (
+ (self.p.y >= other.p.y and self.ll_point.y <= other.ll_point.y)
+ or (self.p.y <= other.p.y and self.ll_point.y >= other.ll_point.y)
+ ):
+ return True
+ if self.ll_point.y <= other.p.y or self.p.y >= other.ll_point.y:
+ return False
+        return 1.0 * abs(self.p.y - other.p.y) / max(self.h, other.h) < self.THRESHOLD
+
+ def __lt__(self, other) -> bool:
+ '''
+ from top to bottom, from left to right
+ '''
+ if not self.same_row(other):
+ return self.p.y < other.p.y
+ else:
+ return self.p.x < other.p.x
+
+ def __repr__(self) -> str:
+ return f"Bbox(upper_left_point={self.p}, h={self.h}, w={self.w}), label={self.label}, confident={self.confidence}, content={self.content})"
+
+
+def draw_bboxes(img: Image.Image, bboxes: List[Bbox], name="annotated_image.png"):
+ drawer = ImageDraw.Draw(img)
+ for bbox in bboxes:
+ # Calculate the coordinates for the rectangle to be drawn
+ left = bbox.p.x
+ top = bbox.p.y
+ right = bbox.p.x + bbox.w
+ bottom = bbox.p.y + bbox.h
+
+ # Draw the rectangle on the image
+ drawer.rectangle([left, top, right, bottom], outline="green", width=1)
+
+ # Optionally, add text label if it exists
+ if bbox.label:
+ drawer.text((left, top), bbox.label, fill="blue")
+
+ if bbox.content:
+ drawer.text((left, bottom - 10), bbox.content[:10], fill="red")
+
+ # Save the image with drawn rectangles
+ img.save(name)
\ No newline at end of file
diff --git a/src/models/det_model/inference.py b/src/models/det_model/inference.py
new file mode 100644
index 0000000..a142d16
--- /dev/null
+++ b/src/models/det_model/inference.py
@@ -0,0 +1,161 @@
+import os
+import yaml
+import numpy as np
+import cv2
+
+from tqdm import tqdm
+from typing import List
+from .preprocess import Compose
+from .Bbox import Bbox
+
+
+# Global dictionary
+SUPPORT_MODELS = {
+ 'YOLO', 'PPYOLOE', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet',
+ 'S2ANet', 'JDE', 'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet',
+ 'TOOD', 'RetinaNet', 'StrongBaseline', 'STGCN', 'YOLOX', 'HRNet',
+ 'DETR'
+}
+
+
+class PredictConfig(object):
+ """set config of preprocess, postprocess and visualize
+ Args:
+ infer_config (str): path of infer_cfg.yml
+ """
+
+ def __init__(self, infer_config):
+ # parsing Yaml config for Preprocess
+ with open(infer_config) as f:
+ yml_conf = yaml.safe_load(f)
+ self.check_model(yml_conf)
+ self.arch = yml_conf['arch']
+ self.preprocess_infos = yml_conf['Preprocess']
+ self.min_subgraph_size = yml_conf['min_subgraph_size']
+ self.label_list = yml_conf['label_list']
+ self.use_dynamic_shape = yml_conf['use_dynamic_shape']
+ self.draw_threshold = yml_conf.get("draw_threshold", 0.5)
+ self.mask = yml_conf.get("mask", False)
+ self.tracker = yml_conf.get("tracker", None)
+ self.nms = yml_conf.get("NMS", None)
+ self.fpn_stride = yml_conf.get("fpn_stride", None)
+
+ color_pool = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]
+ self.colors = {label: color_pool[i % len(color_pool)] for i, label in enumerate(self.label_list)}
+
+ if self.arch == 'RCNN' and yml_conf.get('export_onnx', False):
+ print(
+ 'The RCNN export model is used for ONNX and it only supports batch_size = 1'
+ )
+ self.print_config()
+
+ def check_model(self, yml_conf):
+ """
+ Raises:
+ ValueError: loaded model not in supported model type
+ """
+ for support_model in SUPPORT_MODELS:
+ if support_model in yml_conf['arch']:
+ return True
+ raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[
+ 'arch'], SUPPORT_MODELS))
+
+ def print_config(self):
+ print('----------- Model Configuration -----------')
+ print('%s: %s' % ('Model Arch', self.arch))
+ print('%s: ' % ('Transform Order'))
+ for op_info in self.preprocess_infos:
+ print('--%s: %s' % ('transform op', op_info['type']))
+ print('--------------------------------------------')
+
+
+def draw_bbox(image, outputs, infer_config):
+ for output in outputs:
+ cls_id, score, xmin, ymin, xmax, ymax = output
+ if score > infer_config.draw_threshold:
+ label = infer_config.label_list[int(cls_id)]
+ color = infer_config.colors[label]
+ cv2.rectangle(image, (int(xmin), int(ymin)), (int(xmax), int(ymax)), color, 2)
+ cv2.putText(image, "{}: {:.2f}".format(label, score),
+ (int(xmin), int(ymin - 5)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
+ return image
+
+
+def predict_image(imgsave_dir, infer_config, predictor, img_list):
+ # load preprocess transforms
+ transforms = Compose(infer_config.preprocess_infos)
+ errImgList = []
+
+ # Check and create subimg_save_dir if not exist
+ subimg_save_dir = os.path.join(imgsave_dir, 'subimages')
+ os.makedirs(subimg_save_dir, exist_ok=True)
+
+ # predict image
+ for img_path in tqdm(img_list):
+ img = cv2.imread(img_path)
+ if img is None:
+ print(f"Warning: Could not read image {img_path}. Skipping...")
+ errImgList.append(img_path)
+ continue
+
+ inputs = transforms(img_path)
+ inputs_name = [var.name for var in predictor.get_inputs()]
+ inputs = {k: inputs[k][None, ] for k in inputs_name}
+
+ outputs = predictor.run(output_names=None, input_feed=inputs)
+
+ print("ONNXRuntime predict: ")
+ if infer_config.arch in ["HRNet"]:
+ print(np.array(outputs[0]))
+ else:
+ bboxes = np.array(outputs[0])
+ for bbox in bboxes:
+ if bbox[0] > -1 and bbox[1] > infer_config.draw_threshold:
+ print(f"{int(bbox[0])} {bbox[1]} "
+ f"{bbox[2]} {bbox[3]} {bbox[4]} {bbox[5]}")
+
+ # Save the subimages (crop from the original image)
+ subimg_counter = 1
+ for output in np.array(outputs[0]):
+ cls_id, score, xmin, ymin, xmax, ymax = output
+ if score > infer_config.draw_threshold:
+ label = infer_config.label_list[int(cls_id)]
+ subimg = img[int(max(ymin, 0)):int(ymax), int(max(xmin, 0)):int(xmax)]
+ if len(subimg) == 0:
+ continue
+
+ subimg_filename = f"{os.path.splitext(os.path.basename(img_path))[0]}_{label}_{xmin:.2f}_{ymin:.2f}_{xmax:.2f}_{ymax:.2f}.jpg"
+ subimg_path = os.path.join(subimg_save_dir, subimg_filename)
+ cv2.imwrite(subimg_path, subimg)
+ subimg_counter += 1
+
+ # Draw bounding boxes and save the image with bounding boxes
+ img_with_bbox = draw_bbox(img, np.array(outputs[0]), infer_config)
+ output_dir = imgsave_dir
+ os.makedirs(output_dir, exist_ok=True)
+ output_file = os.path.join(output_dir, "output_" + os.path.basename(img_path))
+ cv2.imwrite(output_file, img_with_bbox)
+
+ print("ErrorImgs:")
+ print(errImgList)
+
+
+def predict(img_path: str, predictor, infer_config) -> List[Bbox]:
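+    # Detector rows are [cls_id, score, xmin, ymin, xmax, ymax]; keep rows above
+    # the draw threshold and clamp negative coordinates to the image border.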
+ transforms = Compose(infer_config.preprocess_infos)
+ inputs = transforms(img_path)
+ inputs_name = [var.name for var in predictor.get_inputs()]
+ inputs = {k: inputs[k][None, ] for k in inputs_name}
+
+ outputs = predictor.run(output_names=None, input_feed=inputs)[0]
+ res = []
+ for output in outputs:
+ cls_name = infer_config.label_list[int(output[0])]
+ score = output[1]
+ xmin = int(max(output[2], 0))
+ ymin = int(max(output[3], 0))
+ xmax = int(output[4])
+ ymax = int(output[5])
+ if score > infer_config.draw_threshold:
+ res.append(Bbox(xmin, ymin, ymax - ymin, xmax - xmin, cls_name, score))
+
+ return res
diff --git a/src/models/det_model/model/infer_cfg.yml b/src/models/det_model/model/infer_cfg.yml
index 0c156fc..09e6603 100644
--- a/src/models/det_model/model/infer_cfg.yml
+++ b/src/models/det_model/model/infer_cfg.yml
@@ -8,8 +8,8 @@ Preprocess:
- interp: 2
keep_ratio: false
target_size:
- - 640
- - 640
+ - 1600
+ - 1600
type: Resize
- mean:
- 0.0
diff --git a/src/models/det_model/preprocess.py b/src/models/det_model/preprocess.py
index 3554b7f..6b72494 100644
--- a/src/models/det_model/preprocess.py
+++ b/src/models/det_model/preprocess.py
@@ -4,9 +4,14 @@ import copy
def decode_image(img_path):
- with open(img_path, 'rb') as f:
- im_read = f.read()
- data = np.frombuffer(im_read, dtype='uint8')
+ if isinstance(img_path, str):
+ with open(img_path, 'rb') as f:
+ im_read = f.read()
+ data = np.frombuffer(im_read, dtype='uint8')
+ else:
+ assert isinstance(img_path, np.ndarray)
+ data = img_path
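+        # Assumed contract: the ndarray is an encoded image byte buffer for
+        # cv2.imdecode below, not an already-decoded pixel array.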
+
im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
img_info = {
diff --git a/src/models/ocr_model/utils/inference.py b/src/models/ocr_model/utils/inference.py
index 92bc08c..10f1b0d 100644
--- a/src/models/ocr_model/utils/inference.py
+++ b/src/models/ocr_model/utils/inference.py
@@ -4,19 +4,22 @@ import numpy as np
from transformers import RobertaTokenizerFast, GenerationConfig
from typing import List, Union
-from models.ocr_model.model.TexTeller import TexTeller
-from models.ocr_model.utils.transforms import inference_transform
-from models.ocr_model.utils.helpers import convert2rgb
-from models.globals import MAX_TOKEN_SIZE
+from .transforms import inference_transform
+from .helpers import convert2rgb
+from ..model.TexTeller import TexTeller
+from ...globals import MAX_TOKEN_SIZE
def inference(
model: TexTeller,
tokenizer: RobertaTokenizerFast,
imgs: Union[List[str], List[np.ndarray]],
- inf_mode: str = 'cpu',
+ accelerator: str = 'cpu',
num_beams: int = 1,
+ max_tokens = None
) -> List[str]:
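+    # Short-circuit an empty batch so callers need not special-case it.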
+ if imgs == []:
+ return []
model.eval()
if isinstance(imgs[0], str):
imgs = convert2rgb(imgs)
@@ -26,11 +29,11 @@ def inference(
imgs = inference_transform(imgs)
pixel_values = torch.stack(imgs)
- model = model.to(inf_mode)
- pixel_values = pixel_values.to(inf_mode)
+ model = model.to(accelerator)
+ pixel_values = pixel_values.to(accelerator)
generate_config = GenerationConfig(
- max_new_tokens=MAX_TOKEN_SIZE,
+ max_new_tokens=MAX_TOKEN_SIZE if max_tokens is None else max_tokens,
num_beams=num_beams,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
diff --git a/src/models/ocr_model/utils/to_katex.py b/src/models/ocr_model/utils/to_katex.py
new file mode 100644
index 0000000..5a0a95b
--- /dev/null
+++ b/src/models/ocr_model/utils/to_katex.py
@@ -0,0 +1,110 @@
+import re
+
+
+def change(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r):
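+    # Rewrite every `old_inst old_surr_l ... old_surr_r` group into
+    # `new_inst new_surr_l ... new_surr_r`, respecting backslash escapes and
+    # nested delimiters; recurses while a renamed instruction still occurs.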
+ result = ""
+ i = 0
+ n = len(input_str)
+
+ while i < n:
+ if input_str[i:i+len(old_inst)] == old_inst:
+ # check if the old_inst is followed by old_surr_l
+ start = i + len(old_inst)
+ else:
+ result += input_str[i]
+ i += 1
+ continue
+
+ if start < n and input_str[start] == old_surr_l:
+ # found an old_inst followed by old_surr_l, now look for the matching old_surr_r
+ count = 1
+ j = start + 1
+ escaped = False
+ while j < n and count > 0:
+ if input_str[j] == '\\' and not escaped:
+ escaped = True
+ j += 1
+ continue
+ if input_str[j] == old_surr_r and not escaped:
+ count -= 1
+ if count == 0:
+ break
+ elif input_str[j] == old_surr_l and not escaped:
+ count += 1
+ escaped = False
+ j += 1
+
+ if count == 0:
+ assert j < n
+ assert input_str[start] == old_surr_l
+ assert input_str[j] == old_surr_r
+ inner_content = input_str[start + 1:j]
+ # Replace the content with new pattern
+ result += new_inst + new_surr_l + inner_content + new_surr_r
+ i = j + 1
+ continue
+ else:
+            assert count >= 1
+            assert j == n
+            print("Warning: unbalanced delimiters in input string")
+ result += new_inst + new_surr_l
+ i = start + 1
+ continue
+        else:
+            # Not followed by old_surr_l: keep the instruction unchanged.
+            result += input_str[i:start]
+            i = start
+
+ if old_inst != new_inst and old_inst in result:
+ return change(result, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r)
+ else:
+ return result
+
+
+def to_katex(formula: str) -> str:
+ res = formula
+ res = change(res, r'\mbox', r'', r'{', r'}', r'', r'')
+
+ origin_instructions = [
+ r'\Huge',
+ r'\huge',
+ r'\LARGE',
+ r'\Large',
+ r'\large',
+ r'\normalsize',
+ r'\small',
+ r'\footnotesize',
+ r'\scriptsize',
+ r'\tiny'
+ ]
+    for ins in origin_instructions:
+        res = change(res, ins, ins, r'$', r'$', '{', '}')
+ res = change(res, r'\boldmath', r'\bm', r'$', r'$', r'{', r'}')
+
+ origin_instructions = [
+ r'\left',
+ r'\middle',
+ r'\right',
+ r'\big',
+ r'\Big',
+ r'\bigg',
+ r'\Bigg',
+ r'\bigl',
+ r'\Bigl',
+ r'\biggl',
+ r'\Biggl',
+ r'\bigm',
+ r'\Bigm',
+ r'\biggm',
+ r'\Biggm',
+ r'\bigr',
+ r'\Bigr',
+ r'\biggr',
+ r'\Biggr'
+ ]
+ for origin_ins in origin_instructions:
+ res = change(res, origin_ins, origin_ins, r'{', r'}', r'', r'')
+
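+    # Replace \[ ... \] display math with its body followed by \newline.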
+ res = re.sub(r'\\\[(.*?)\\\]', r'\1\\newline', res)
+
+ if res.endswith(r'\newline'):
+ res = res[:-8]
+ return res
diff --git a/src/models/utils/__init__.py b/src/models/utils/__init__.py
new file mode 100644
index 0000000..775dc11
--- /dev/null
+++ b/src/models/utils/__init__.py
@@ -0,0 +1 @@
+from .mix_inference import mix_inference
\ No newline at end of file
diff --git a/src/models/utils/mix_inference.py b/src/models/utils/mix_inference.py
new file mode 100644
index 0000000..746237f
--- /dev/null
+++ b/src/models/utils/mix_inference.py
@@ -0,0 +1,257 @@
+import re
+import heapq
+import cv2
+import numpy as np
+
+from onnxruntime import InferenceSession
+from collections import Counter
+from typing import List
+
+from PIL import Image
+from surya.ocr import run_ocr
+from surya.detection import batch_text_detection
+from surya.input.processing import slice_polys_from_image, slice_bboxes_from_image
+from surya.recognition import batch_recognition
+from surya.model.detection import segformer
+from surya.model.recognition.model import load_model
+from surya.model.recognition.processor import load_processor
+
+from ..det_model.inference import PredictConfig
+from ..det_model.inference import predict as latex_det_predict
+from ..det_model.Bbox import Bbox, draw_bboxes
+
+from ..ocr_model.model.TexTeller import TexTeller
+from ..ocr_model.utils.inference import inference as latex_rec_predict
+from ..ocr_model.utils.to_katex import to_katex
+
+MAXV = 999999999
+
+
+def mask_img(img, bboxes: List[Bbox], bg_color: np.ndarray) -> np.ndarray:
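+    # Paint each detected formula region with the background color so the text
+    # OCR pass does not re-read formula pixels.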
+    masked = img.copy()
+    for bbox in bboxes:
+        masked[bbox.p.y:bbox.p.y + bbox.h, bbox.p.x:bbox.p.x + bbox.w] = bg_color
+    return masked
+
+
+def bbox_merge(sorted_bboxes: List[Bbox]) -> List[Bbox]:
+    if len(sorted_bboxes) == 0:
+ return []
+ bboxes = sorted_bboxes.copy()
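+    # Right-most sentinel on the last row guarantees the final box is flushed.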
+ guard = Bbox(MAXV, bboxes[-1].p.y, -1, -1, label="guard")
+ bboxes.append(guard)
+ res = []
+ prev = bboxes[0]
+ for curr in bboxes:
+ if prev.ur_point.x <= curr.p.x or not prev.same_row(curr):
+ res.append(prev)
+ prev = curr
+ else:
+ prev.w = max(prev.w, curr.ur_point.x - prev.p.x)
+ return res
+
+
+def split_conflict(ocr_bboxes: List[Bbox], latex_bboxes: List[Bbox]) -> List[Bbox]:
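+    # Sweep the row-sorted boxes with a min-heap, trimming or splitting text
+    # boxes that horizontally overlap formula boxes so each region is covered once.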
+ if latex_bboxes == []:
+ return ocr_bboxes
+ if ocr_bboxes == [] or len(ocr_bboxes) == 1:
+ return ocr_bboxes
+
+ bboxes = sorted(ocr_bboxes + latex_bboxes)
+
+ ######## debug #########
+ for idx, bbox in enumerate(bboxes):
+ bbox.content = str(idx)
+    draw_bboxes(Image.fromarray(img), bboxes, name="before_split_conflict.png")
+ ######## debug ###########
+
+ assert len(bboxes) > 1
+
+ heapq.heapify(bboxes)
+ res = []
+ candidate = heapq.heappop(bboxes)
+ curr = heapq.heappop(bboxes)
+ idx = 0
+ while (len(bboxes) > 0):
+ idx += 1
+ assert candidate.p.x < curr.p.x or not candidate.same_row(curr)
+
+ if candidate.ur_point.x <= curr.p.x or not candidate.same_row(curr):
+ res.append(candidate)
+ candidate = curr
+ curr = heapq.heappop(bboxes)
+ elif candidate.ur_point.x < curr.ur_point.x:
+ assert not (candidate.label != "text" and curr.label != "text")
+ if candidate.label == "text" and curr.label == "text":
+ candidate.w = curr.ur_point.x - candidate.p.x
+ curr = heapq.heappop(bboxes)
+ elif candidate.label != curr.label:
+ if candidate.label == "text":
+ candidate.w = curr.p.x - candidate.p.x
+ res.append(candidate)
+ candidate = curr
+ curr = heapq.heappop(bboxes)
+ else:
+ curr.w = curr.ur_point.x - candidate.ur_point.x
+ curr.p.x = candidate.ur_point.x
+ heapq.heappush(bboxes, curr)
+ curr = heapq.heappop(bboxes)
+
+ elif candidate.ur_point.x >= curr.ur_point.x:
+ assert not (candidate.label != "text" and curr.label != "text")
+
+ if candidate.label == "text":
+ assert curr.label != "text"
+ heapq.heappush(
+ bboxes,
+ Bbox(
+ curr.ur_point.x,
+ candidate.p.y,
+ candidate.h,
+ candidate.ur_point.x - curr.ur_point.x,
+ label="text",
+ confidence=candidate.confidence,
+ content=None
+ )
+ )
+ candidate.w = curr.p.x - candidate.p.x
+ res.append(candidate)
+ candidate = curr
+ curr = heapq.heappop(bboxes)
+ else:
+ assert curr.label == "text"
+ curr = heapq.heappop(bboxes)
+ else:
+ assert False
+ res.append(candidate)
+ res.append(curr)
+ ######## debug #########
+ for idx, bbox in enumerate(res):
+ bbox.content = str(idx)
+    draw_bboxes(Image.fromarray(img), res, name="after_split_conflict.png")
+ ######## debug ###########
+ return res
+
+
+def mix_inference(
+ img_path: str,
+ language: str,
+ infer_config,
+    latex_det_model,
+    lang_ocr_models,
+    latex_rec_models,
+ accelerator="cpu",
+ num_beams=1
+) -> str:
+ '''
+    Take an image that mixes text and formulas and return a Markdown string.
+ '''
+ global img
+ img = cv2.imread(img_path)
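+    # Estimate the page background color by majority vote over the four corners.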
+ corners = [tuple(img[0, 0]), tuple(img[0, -1]),
+ tuple(img[-1, 0]), tuple(img[-1, -1])]
+ bg_color = np.array(Counter(corners).most_common(1)[0][0])
+
+ latex_bboxes = latex_det_predict(img_path, latex_det_model, infer_config)
+ latex_bboxes = sorted(latex_bboxes)
+ draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(unmerged).png")
+ latex_bboxes = bbox_merge(latex_bboxes)
+ draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(merged).png")
+ masked_img = mask_img(img, latex_bboxes, bg_color)
+
+ det_model, det_processor, rec_model, rec_processor = lang_ocr_models
+ images = [Image.fromarray(masked_img)]
+ det_prediction = batch_text_detection(images, det_model, det_processor)[0]
+ draw_bboxes(Image.fromarray(img), latex_bboxes, name="ocr_bboxes(unmerged).png")
+
+ lang = [language]
+ slice_map = []
+ all_slices = []
+ all_langs = []
+ ocr_bboxes = [
+ Bbox(
+ p.bbox[0], p.bbox[1], p.bbox[3] - p.bbox[1], p.bbox[2] - p.bbox[0],
+ label="text",
+ confidence=p.confidence,
+ content=None
+ )
+ for p in det_prediction.bboxes
+ ]
+    ocr_bboxes = sorted(ocr_bboxes)
+    draw_bboxes(Image.fromarray(img), ocr_bboxes, name="ocr_bboxes(unmerged).png")
+    ocr_bboxes = bbox_merge(ocr_bboxes)
+ draw_bboxes(Image.fromarray(img), ocr_bboxes, name="ocr_bboxes(merged).png")
+ ocr_bboxes = split_conflict(ocr_bboxes, latex_bboxes)
+ ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes))
+ polygons = [
+ [
+ [bbox.ul_point.x, bbox.ul_point.y],
+ [bbox.ur_point.x, bbox.ur_point.y],
+ [bbox.lr_point.x, bbox.lr_point.y],
+ [bbox.ll_point.x, bbox.ll_point.y]
+ ]
+ for bbox in ocr_bboxes
+ ]
+
+ slices = slice_polys_from_image(images[0], polygons)
+ slice_map.append(len(slices))
+ all_slices.extend(slices)
+ all_langs.extend([lang] * len(slices))
+
+ rec_predictions, _ = batch_recognition(all_slices, all_langs, rec_model, rec_processor)
+
+ assert len(rec_predictions) == len(ocr_bboxes)
+ for content, bbox in zip(rec_predictions, ocr_bboxes):
+ bbox.content = content
+
+    latex_imgs = []
+ for bbox in latex_bboxes:
+ latex_imgs.append(img[bbox.p.y:bbox.p.y + bbox.h, bbox.p.x:bbox.p.x + bbox.w])
+ latex_rec_res = latex_rec_predict(*latex_rec_models, latex_imgs, accelerator, num_beams, max_tokens=200)
+ for bbox, content in zip(latex_bboxes, latex_rec_res):
+ bbox.content = to_katex(content)
+ if bbox.label == "embedding":
+ bbox.content = " $" + bbox.content + "$ "
+ elif bbox.label == "isolated":
+ bbox.content = '\n' + r"$$" + bbox.content + r"$$" + '\n'
+
+ bboxes = sorted(ocr_bboxes + latex_bboxes)
+ if bboxes == []:
+ return ""
+
+ md = ""
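+    # Seed `prev` with a zero-size guard on the first box's row so no leading
+    # newline is emitted; a break is added whenever the row changes.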
+ prev = Bbox(bboxes[0].p.x, bboxes[0].p.y, -1, -1, label="guard")
+ for curr in bboxes:
+ if not prev.same_row(curr):
+ md += "\n"
+ md += curr.content
+ if (
+ prev.label == "isolated"
+ and curr.label == "text"
+ and bool(re.fullmatch(r"\([1-9]\d*?\)", curr.content))
+ ):
+ md += '\n'
+ prev = curr
+ return md
+
+
+if __name__ == '__main__':
+ img_path = "/Users/Leehy/Code/TexTeller/test3.png"
+
+ # latex_det_model = InferenceSession("/Users/Leehy/Code/TexTeller/src/models/det_model/model/rtdetr_r50vd_6x_coco_trained_on_IBEM_en_papers.onnx")
+ latex_det_model = InferenceSession("/Users/Leehy/Code/TexTeller/src/models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
+ infer_config = PredictConfig("/Users/Leehy/Code/TexTeller/src/models/det_model/model/infer_cfg.yml")
+
+ det_processor, det_model = segformer.load_processor(), segformer.load_model()
+ rec_model, rec_processor = load_model(), load_processor()
+ lang_ocr_models = (det_model, det_processor, rec_model, rec_processor)
+
+ texteller = TexTeller.from_pretrained()
+ tokenizer = TexTeller.get_tokenizer()
+ latex_rec_models = (texteller, tokenizer)
+
+ res = mix_inference(img_path, "zh", infer_config, latex_det_model, lang_ocr_models, latex_rec_models)
+ print(res)
diff --git a/src/rec_infer_from_crop_imgs.py b/src/rec_infer_from_crop_imgs.py
index 1cacbc6..73bfa73 100644
--- a/src/rec_infer_from_crop_imgs.py
+++ b/src/rec_infer_from_crop_imgs.py
@@ -2,7 +2,7 @@ import os
import argparse
import cv2 as cv
from pathlib import Path
-from utils import to_katex
+from models.ocr_model.utils.to_katex import to_katex
from models.ocr_model.utils.inference import inference as latex_inference
from models.ocr_model.model.TexTeller import TexTeller
@@ -46,7 +46,7 @@ if __name__ == '__main__':
if img is not None:
print(f'Inference for {filename}...')
- res = latex_inference(latex_rec_model, tokenizer, [img], inf_mode=args.inference_mode, num_beams=args.num_beam)
+ res = latex_inference(latex_rec_model, tokenizer, [img], accelerator=args.inference_mode, num_beams=args.num_beam)
res = to_katex(res[0])
# Save the recognition result to a text file
diff --git a/src/server.py b/src/server.py
index 520d908..52e0068 100644
--- a/src/server.py
+++ b/src/server.py
@@ -56,7 +56,7 @@ class TexTellerServer:
def predict(self, image_nparray) -> str:
return inference(
self.model, self.tokenizer, [image_nparray],
- inf_mode=self.inf_mode, num_beams=self.num_beams
+ accelerator=self.inf_mode, num_beams=self.num_beams
)[0]
diff --git a/src/start_web.sh b/src/start_web.sh
index 41e6311..6ec8f7b 100755
--- a/src/start_web.sh
+++ b/src/start_web.sh
@@ -1,7 +1,7 @@
#!/usr/bin/env bash
set -exu
-export CHECKPOINT_DIR="/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-648000"
+export CHECKPOINT_DIR="default"
export TOKENIZER_DIR="default"
streamlit run web.py
diff --git a/src/utils/__init__.py b/src/utils/__init__.py
deleted file mode 100644
index ba17b40..0000000
--- a/src/utils/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .to_katex import to_katex
\ No newline at end of file
diff --git a/src/utils/to_katex.py b/src/utils/to_katex.py
deleted file mode 100644
index fe08a74..0000000
--- a/src/utils/to_katex.py
+++ /dev/null
@@ -1,15 +0,0 @@
-import numpy as np
-import re
-
-def to_katex(formula: str) -> str:
- res = formula
- res = re.sub(r'\\mbox\{([^}]*)\}', r'\1', res)
- res = re.sub(r'boldmath\$(.*?)\$', r'bm{\1}', res)
- res = re.sub(r'\\\[(.*?)\\\]', r'\1\\newline', res)
-
- pattern = r'(\\(?:left|middle|right|big|Big|bigg|Bigg|bigl|Bigl|biggl|Biggl|bigm|Bigm|biggm|Biggm|bigr|Bigr|biggr|Biggr))\{([^}]*)\}'
- replacement = r'\1\2'
- res = re.sub(pattern, replacement, res)
- if res.endswith(r'\newline'):
- res = res[:-8]
- return res
diff --git a/src/web.py b/src/web.py
index 379a609..d4a84f9 100644
--- a/src/web.py
+++ b/src/web.py
@@ -7,10 +7,18 @@ import streamlit as st
from PIL import Image
from streamlit_paste_button import paste_image_button as pbutton
-from models.ocr_model.utils.inference import inference
-from models.ocr_model.model.TexTeller import TexTeller
-from utils import to_katex
+from onnxruntime import InferenceSession
+from models.utils import mix_inference
+from models.det_model.inference import PredictConfig
+
+from models.ocr_model.model.TexTeller import TexTeller
+from models.ocr_model.utils.inference import inference as latex_recognition
+from models.ocr_model.utils.to_katex import to_katex
+
+from surya.model.detection import segformer
+from surya.model.recognition.model import load_model
+from surya.model.recognition.processor import load_processor
st.set_page_config(
page_title="TexTeller",
@@ -42,13 +50,26 @@ fail_gif_html = '''
'''
@st.cache_resource
-def get_model():
+def get_texteller():
return TexTeller.from_pretrained(os.environ['CHECKPOINT_DIR'])
@st.cache_resource
def get_tokenizer():
return TexTeller.get_tokenizer(os.environ['TOKENIZER_DIR'])
+@st.cache_resource
+def get_det_models():
+ infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
+ latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
+ return infer_config, latex_det_model
+
+@st.cache_resource
+def get_ocr_models():
+ det_processor, det_model = segformer.load_processor(), segformer.load_model()
+ rec_model, rec_processor = load_model(), load_processor()
+ lang_ocr_models = [det_model, det_processor, rec_model, rec_processor]
+ return lang_ocr_models
+
def get_image_base64(img_file):
buffered = io.BytesIO()
img_file.seek(0)
@@ -62,9 +83,6 @@ def on_file_upload():
def change_side_bar():
st.session_state["CHANGE_SIDEBAR_FLAG"] = True
-model = get_model()
-tokenizer = get_tokenizer()
-
if "start" not in st.session_state:
st.session_state["start"] = 1
st.toast('Hooray!', icon='🎉')
@@ -75,31 +93,34 @@ if "UPLOADED_FILE_CHANGED" not in st.session_state:
if "CHANGE_SIDEBAR_FLAG" not in st.session_state:
st.session_state["CHANGE_SIDEBAR_FLAG"] = False
+if "INF_MODE" not in st.session_state:
+ st.session_state["INF_MODE"] = "Formula only"
+
+
# ============================ begin sidebar =============================== #
with st.sidebar:
num_beams = 1
- inf_mode = 'cpu'
st.markdown("# 🔨️ Config")
st.markdown("")
- model_type = st.selectbox(
- "Model type",
- ("TexTeller", "None"),
+ inf_mode = st.selectbox(
+ "Inference mode",
+ ("Formula only", "Text formula mixed"),
on_change=change_side_bar
)
- if model_type == "TexTeller":
- num_beams = st.number_input(
- 'Number of beams',
- min_value=1,
- max_value=20,
- step=1,
- on_change=change_side_bar
- )
- inf_mode = st.radio(
- "Inference mode",
+ num_beams = st.number_input(
+ 'Number of beams',
+ min_value=1,
+ max_value=20,
+ step=1,
+ on_change=change_side_bar
+ )
+
+ accelerator = st.radio(
+ "Accelerator",
("cpu", "cuda", "mps"),
on_change=change_side_bar
)
@@ -107,9 +128,16 @@ with st.sidebar:
# ============================ end sidebar =============================== #
-
# ============================ begin pages =============================== #
+texteller = get_texteller()
+tokenizer = get_tokenizer()
+latex_rec_models = [texteller, tokenizer]
+
+if inf_mode == "Text formula mixed":
+ infer_config, latex_det_model = get_det_models()
+ lang_ocr_models = get_ocr_models()
+
st.markdown(html_string, unsafe_allow_html=True)
uploaded_file = st.file_uploader(
@@ -176,19 +204,26 @@ elif uploaded_file or paste_result.image_data is not None:
st.write("")
with st.spinner("Predicting..."):
- uploaded_file.seek(0)
- TexTeller_result = inference(
- model,
- tokenizer,
- [png_file_path],
- inf_mode=inf_mode,
- num_beams=num_beams
- )[0]
+ if inf_mode == "Formula only":
+ TexTeller_result = latex_recognition(
+ texteller,
+ tokenizer,
+ [png_file_path],
+ accelerator=accelerator,
+ num_beams=num_beams
+ )[0]
+ katex_res = to_katex(TexTeller_result)
+ else:
+ katex_res = mix_inference(png_file_path, "en", infer_config, latex_det_model, lang_ocr_models, latex_rec_models, accelerator, num_beams)
+
st.success('Completed!', icon="✅")
st.markdown(suc_gif_html, unsafe_allow_html=True)
- katex_res = to_katex(TexTeller_result)
st.text_area(":blue[*** 𝑃r𝑒d𝑖c𝑡e𝑑 𝑓o𝑟m𝑢l𝑎 ***]", katex_res, height=150)
- st.latex(katex_res)
+
+ if inf_mode == "Formula only":
+ st.latex(katex_res)
+ elif inf_mode == "Text formula mixed":
+ st.markdown(katex_res)
st.write("")
st.write("")