1) Implemented mixed text-formula recognition; 2) Refactored the project structure

Author: 三洋三洋
Date: 2024-04-21 00:05:14 +08:00
parent eab6e4c85d
commit 185b2e3db6
19 changed files with 753 additions and 296 deletions
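
In outline, the new mixed mode chains four stages: an RT-DETR ONNX detector finds formula regions, surya models detect and recognize the remaining text lines, TexTeller recognizes each formula crop, and mix_inference stitches everything back together as markdown. A minimal sketch of how the pieces introduced below fit together (the model paths and the "en" language code mirror this commit's defaults and should be treated as assumptions):

# Illustrative wiring of the new mixed text-formula pipeline; not part of
# the commit itself, and the paths/language code are assumptions.
from onnxruntime import InferenceSession
from models.utils import mix_inference
from models.det_model.inference import PredictConfig
from models.ocr_model.model.TexTeller import TexTeller
from surya.model.detection import segformer
from surya.model.recognition.model import load_model
from surya.model.recognition.processor import load_processor

infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")  # detector preprocess config
latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
det_processor, det_model = segformer.load_processor(), segformer.load_model()  # text detection
rec_model, rec_processor = load_model(), load_processor()                      # text recognition
lang_ocr_models = [det_model, det_processor, rec_model, rec_processor]
latex_rec_models = [TexTeller.from_pretrained(), TexTeller.get_tokenizer()]    # formula recognition

markdown = mix_inference("page.png", "en", infer_config, latex_det_model,
                         lang_ocr_models, latex_rec_models, "cpu", 1)
print(markdown)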

.gitignore vendored
View File

@@ -6,6 +6,8 @@
 **/ckpt
 **/*cache
 **/.cache
+**/tmp
+**/log
 **/data
 **/logs
@@ -13,3 +15,7 @@
 **/data
 **/*cache
 **/ckpt
+**/*.bin
+**/*.safetensor
+**/*.onnx

View File

@@ -14,4 +14,4 @@ onnxruntime
 streamlit==1.30
 streamlit-paste-button
-easyocr
+surya-ocr

View File

@@ -1,75 +0,0 @@
import os
import gradio as gr

from models.ocr_model.utils.inference import inference
from models.ocr_model.model.TexTeller import TexTeller
from utils import to_katex
from pathlib import Path

# model = TexTeller.from_pretrained(os.environ['CHECKPOINT_DIR'])
# tokenizer = TexTeller.get_tokenizer(os.environ['TOKENIZER_DIR'])

css = """
<style>
.container {
    display: flex;
    align-items: center;
    justify-content: center;
    font-size: 20px;
    font-family: 'Arial';
}
.container img {
    height: auto;
}
.text {
    margin: 0 15px;
}
h1 {
    text-align: center;
    font-size: 50px !important;
}
.markdown-style {
    color: #333;        /* adjust the color */
    line-height: 1.6;   /* line spacing */
    font-size: 50px;
}
.markdown-style h1, .markdown-style h2, .markdown-style h3 {
    color: #007BFF;     /* color for heading elements */
}
.markdown-style p {
    margin-bottom: 1em; /* paragraph spacing */
}
</style>
"""

theme = gr.themes.Default()

def fn(img):
    return img

with gr.Blocks(
    theme=theme,
    css=css
) as demo:
    gr.HTML(f'''
        {css}
        <div class="container">
            <img src="https://github.com/OleehyO/TexTeller/raw/main/assets/fire.svg" width="100">
            <h1> 𝚃𝚎𝚡𝚃𝚎𝚕𝚕𝚎𝚛 </h1>
            <img src="https://github.com/OleehyO/TexTeller/raw/main/assets/fire.svg" width="100">
        </div>
    ''')
    with gr.Row(equal_height=True):
        input_img = gr.Image(type="pil", label="Input Image")
        latex_img = gr.Image(label="Predicted Latex", show_label=False)
    input_img.upload(fn, input_img, latex_img)
    gr.Markdown(r'$$\frac{7}{10349}$$')
    gr.Markdown('fooooooooooooooooooooooooooooo')

demo.launch()

View File

@@ -1,35 +1,21 @@
 import os
-import yaml
 import argparse
-import numpy as np
 import glob
 from onnxruntime import InferenceSession
-from tqdm import tqdm
-from models.det_model.preprocess import Compose
-import cv2
-
-# Global dictionary
-SUPPORT_MODELS = {
-    'YOLO', 'PPYOLOE', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet',
-    'S2ANet', 'JDE', 'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet',
-    'TOOD', 'RetinaNet', 'StrongBaseline', 'STGCN', 'YOLOX', 'HRNet',
-    'DETR'
-}
+from pathlib import Path
+from models.det_model.inference import PredictConfig, predict_image

 parser = argparse.ArgumentParser(description=__doc__)
 parser.add_argument("--infer_cfg", type=str, help="infer_cfg.yml",
-                    default="./models/det_model/model/infer_cfg.yml"
-                    )
+                    default="./models/det_model/model/infer_cfg.yml")
 parser.add_argument('--onnx_file', type=str, help="onnx model file path",
-                    default="./models/det_model/model/rtdetr_r50vd_6x_coco.onnx"
-                    )
+                    default="./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
 parser.add_argument("--image_dir", type=str)
-parser.add_argument("--image_file", type=str, default='/data/ljm/TexTeller/src/Tr00_0001015-page02.jpg')
+parser.add_argument("--image_file", type=str, required=True)
-parser.add_argument("--imgsave_dir", type=str,
-                    default="."
-                    )
+parser.add_argument("--imgsave_dir", type=str, default="./detect_results")

 def get_test_images(infer_dir, infer_img):
     """
@@ -62,125 +48,11 @@ def get_test_images(infer_dir, infer_img):
     return images

-class PredictConfig(object):
-    """set config of preprocess, postprocess and visualize
-    Args:
-        infer_config (str): path of infer_cfg.yml
-    """
-    def __init__(self, infer_config):
-        # parsing Yaml config for Preprocess
-        with open(infer_config) as f:
-            yml_conf = yaml.safe_load(f)
-        self.check_model(yml_conf)
-        self.arch = yml_conf['arch']
-        self.preprocess_infos = yml_conf['Preprocess']
-        self.min_subgraph_size = yml_conf['min_subgraph_size']
-        self.label_list = yml_conf['label_list']
-        self.use_dynamic_shape = yml_conf['use_dynamic_shape']
-        self.draw_threshold = yml_conf.get("draw_threshold", 0.5)
-        self.mask = yml_conf.get("mask", False)
-        self.tracker = yml_conf.get("tracker", None)
-        self.nms = yml_conf.get("NMS", None)
-        self.fpn_stride = yml_conf.get("fpn_stride", None)
-        color_pool = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]
-        self.colors = {label: color_pool[i % len(color_pool)] for i, label in enumerate(self.label_list)}
-        if self.arch == 'RCNN' and yml_conf.get('export_onnx', False):
-            print(
-                'The RCNN export model is used for ONNX and it only supports batch_size = 1'
-            )
-        self.print_config()
-
-    def check_model(self, yml_conf):
-        """
-        Raises:
-            ValueError: loaded model not in supported model type
-        """
-        for support_model in SUPPORT_MODELS:
-            if support_model in yml_conf['arch']:
-                return True
-        raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[
-            'arch'], SUPPORT_MODELS))
-
-    def print_config(self):
-        print('----------- Model Configuration -----------')
-        print('%s: %s' % ('Model Arch', self.arch))
-        print('%s: ' % ('Transform Order'))
-        for op_info in self.preprocess_infos:
-            print('--%s: %s' % ('transform op', op_info['type']))
-        print('--------------------------------------------')
-
-def draw_bbox(image, outputs, infer_config):
-    for output in outputs:
-        cls_id, score, xmin, ymin, xmax, ymax = output
-        if score > infer_config.draw_threshold:
-            label = infer_config.label_list[int(cls_id)]
-            color = infer_config.colors[label]
-            cv2.rectangle(image, (int(xmin), int(ymin)), (int(xmax), int(ymax)), color, 2)
-            cv2.putText(image, "{}: {:.2f}".format(label, score),
-                        (int(xmin), int(ymin - 5)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
-    return image
-
-def predict_image(infer_config, predictor, img_list):
-    # load preprocess transforms
-    transforms = Compose(infer_config.preprocess_infos)
-    errImgList = []
-    # Check and create subimg_save_dir if not exist
-    subimg_save_dir = os.path.join(FLAGS.imgsave_dir, 'subimages')
-    os.makedirs(subimg_save_dir, exist_ok=True)
-    # predict image
-    for img_path in tqdm(img_list):
-        img = cv2.imread(img_path)
-        if img is None:
-            print(f"Warning: Could not read image {img_path}. Skipping...")
-            errImgList.append(img_path)
-            continue
-        inputs = transforms(img_path)
-        inputs_name = [var.name for var in predictor.get_inputs()]
-        inputs = {k: inputs[k][None, ] for k in inputs_name}
-        outputs = predictor.run(output_names=None, input_feed=inputs)
-        print("ONNXRuntime predict: ")
-        if infer_config.arch in ["HRNet"]:
-            print(np.array(outputs[0]))
-        else:
-            bboxes = np.array(outputs[0])
-            for bbox in bboxes:
-                if bbox[0] > -1 and bbox[1] > infer_config.draw_threshold:
-                    print(f"{int(bbox[0])} {bbox[1]} "
-                          f"{bbox[2]} {bbox[3]} {bbox[4]} {bbox[5]}")
-        # Save the subimages (crop from the original image)
-        subimg_counter = 1
-        for output in np.array(outputs[0]):
-            cls_id, score, xmin, ymin, xmax, ymax = output
-            if score > infer_config.draw_threshold:
-                label = infer_config.label_list[int(cls_id)]
-                subimg = img[int(ymin):int(ymax), int(xmin):int(xmax)]
-                subimg_filename = f"{os.path.splitext(os.path.basename(img_path))[0]}_{label}_{xmin:.2f}_{ymin:.2f}_{xmax:.2f}_{ymax:.2f}.jpg"
-                subimg_path = os.path.join(subimg_save_dir, subimg_filename)
-                cv2.imwrite(subimg_path, subimg)
-                subimg_counter += 1
-        # Draw bounding boxes and save the image with bounding boxes
-        img_with_bbox = draw_bbox(img, np.array(outputs[0]), infer_config)
-        output_dir = FLAGS.imgsave_dir
-        os.makedirs(output_dir, exist_ok=True)
-        output_file = os.path.join(output_dir, "output_" + os.path.basename(img_path))
-        cv2.imwrite(output_file, img_with_bbox)
-    print("ErrorImgs:")
-    print(errImgList)
-
 if __name__ == '__main__':
+    cur_path = os.getcwd()
+    script_dirpath = Path(__file__).resolve().parent
+    os.chdir(script_dirpath)
+
     FLAGS = parser.parse_args()
     # load image list
     img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
@@ -189,4 +61,6 @@ if __name__ == '__main__':
     # load infer config
     infer_config = PredictConfig(FLAGS.infer_cfg)

-    predict_image(infer_config, predictor, img_list)
+    predict_image(FLAGS.imgsave_dir, infer_config, predictor, img_list)
+
+    os.chdir(cur_path)

View File

@@ -3,9 +3,18 @@ import argparse
 import cv2 as cv
 from pathlib import Path

-from utils import to_katex
+from onnxruntime import InferenceSession
+from models.utils import mix_inference
+from models.ocr_model.utils.to_katex import to_katex
 from models.ocr_model.utils.inference import inference as latex_inference
 from models.ocr_model.model.TexTeller import TexTeller
+from models.det_model.inference import PredictConfig
+from surya.model.detection import segformer
+from surya.model.recognition.model import load_model
+from surya.model.recognition.processor import load_processor

 if __name__ == '__main__':
@@ -29,33 +38,35 @@ if __name__ == '__main__':
         default=1,
         help='number of beam search for decoding'
     )
-    # ================= new feature ==================
     parser.add_argument(
         '-mix',
-        type=str,
-        help='use mix mode, only Chinese and English are supported.'
+        action='store_true',
+        help='use mix mode'
     )
-    # ==================================================

     args = parser.parse_args()

     # You can use your own checkpoint and tokenizer path.
     print('Loading model and tokenizer...')
     latex_rec_model = TexTeller.from_pretrained()
     tokenizer = TexTeller.get_tokenizer()
     print('Model and tokenizer loaded.')

-    # img_path = [args.img]
-    img = cv.imread(args.img)
+    img_path = args.img
+    img = cv.imread(img_path)

     print('Inference...')
     if not args.mix:
         res = latex_inference(latex_rec_model, tokenizer, [img], args.inference_mode, args.num_beam)
         res = to_katex(res[0])
         print(res)
     else:
-        # latex_det_model = load_det_tex_model()
-        # lang_model = load_lang_models()...
-        ...
-        # res: str = mix_inference(latex_det_model, latex_rec_model, lang_model, img, args.cuda)
-        # print(res)
+        infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
+        latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco_IBEM_cnTextBook.onnx")
+        det_processor, det_model = segformer.load_processor(), segformer.load_model()
+        rec_model, rec_processor = load_model(), load_processor()
+        lang_ocr_models = [det_model, det_processor, rec_model, rec_processor]
+        latex_rec_models = [latex_rec_model, tokenizer]
+        res = mix_inference(img_path, "en", infer_config, latex_det_model, lang_ocr_models, latex_rec_models, args.inference_mode, args.num_beam)
+        print(res)

View File

@@ -0,0 +1,85 @@
from PIL import Image, ImageDraw
from typing import List

class Point:
    def __init__(self, x: int, y: int):
        self.x = int(x)
        self.y = int(y)

    def __repr__(self) -> str:
        return f"Point(x={self.x}, y={self.y})"

class Bbox:
    THREADHOLD = 0.4

    def __init__(self, x, y, h, w, label: str = None, confidence: float = 0, content: str = None):
        self.p = Point(x, y)
        self.h = int(h)
        self.w = int(w)
        self.label = label
        self.confidence = confidence
        self.content = content

    @property
    def ul_point(self) -> Point:
        return self.p

    @property
    def ur_point(self) -> Point:
        return Point(self.p.x + self.w, self.p.y)

    @property
    def ll_point(self) -> Point:
        return Point(self.p.x, self.p.y + self.h)

    @property
    def lr_point(self) -> Point:
        return Point(self.p.x + self.w, self.p.y + self.h)

    def same_row(self, other) -> bool:
        if (
            (self.p.y >= other.p.y and self.ll_point.y <= other.ll_point.y)
            or (self.p.y <= other.p.y and self.ll_point.y >= other.ll_point.y)
        ):
            return True
        if self.ll_point.y <= other.p.y or self.p.y >= other.ll_point.y:
            return False
        return 1.0 * abs(self.p.y - other.p.y) / max(self.h, other.h) < self.THREADHOLD

    def __lt__(self, other) -> bool:
        '''
        from top to bottom, from left to right
        '''
        if not self.same_row(other):
            return self.p.y < other.p.y
        else:
            return self.p.x < other.p.x

    def __repr__(self) -> str:
        return f"Bbox(upper_left_point={self.p}, h={self.h}, w={self.w}, label={self.label}, confidence={self.confidence}, content={self.content})"

def draw_bboxes(img: Image.Image, bboxes: List[Bbox], name="annotated_image.png"):
    drawer = ImageDraw.Draw(img)
    for bbox in bboxes:
        # Calculate the coordinates for the rectangle to be drawn
        left = bbox.p.x
        top = bbox.p.y
        right = bbox.p.x + bbox.w
        bottom = bbox.p.y + bbox.h
        # Draw the rectangle on the image
        drawer.rectangle([left, top, right, bottom], outline="green", width=1)
        # Optionally, add text label if it exists
        if bbox.label:
            drawer.text((left, top), bbox.label, fill="blue")
        if bbox.content:
            drawer.text((left, bottom - 10), bbox.content[:10], fill="red")
    # Save the image with drawn rectangles
    img.save(name)
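
To pin down the ordering contract (an illustrative trace with made-up coordinates, not part of the commit): __lt__ sorts boxes top-to-bottom, and left-to-right within a row, where same_row tolerates a vertical offset below 40% of the taller box (THREADHOLD = 0.4).

# Hypothetical boxes demonstrating Bbox.same_row and reading-order sorting.
a = Bbox(10, 100, 20, 50, label="text")       # row 1, left
b = Bbox(80, 104, 20, 40, label="embedding")  # row 1, right (4 px lower)
c = Bbox(10, 160, 20, 60, label="isolated")   # row 2

assert a.same_row(b)                   # offset ratio 4 / 20 = 0.2 < 0.4
assert not a.same_row(c)               # disjoint vertical ranges
assert sorted([c, b, a]) == [a, b, c]  # top-to-bottom, then left-to-right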

View File

@@ -0,0 +1,161 @@
import os
import yaml
import numpy as np
import cv2

from tqdm import tqdm
from typing import List
from .preprocess import Compose
from .Bbox import Bbox

# Global dictionary
SUPPORT_MODELS = {
    'YOLO', 'PPYOLOE', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet',
    'S2ANet', 'JDE', 'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet',
    'TOOD', 'RetinaNet', 'StrongBaseline', 'STGCN', 'YOLOX', 'HRNet',
    'DETR'
}

class PredictConfig(object):
    """set config of preprocess, postprocess and visualize
    Args:
        infer_config (str): path of infer_cfg.yml
    """
    def __init__(self, infer_config):
        # parsing Yaml config for Preprocess
        with open(infer_config) as f:
            yml_conf = yaml.safe_load(f)
        self.check_model(yml_conf)
        self.arch = yml_conf['arch']
        self.preprocess_infos = yml_conf['Preprocess']
        self.min_subgraph_size = yml_conf['min_subgraph_size']
        self.label_list = yml_conf['label_list']
        self.use_dynamic_shape = yml_conf['use_dynamic_shape']
        self.draw_threshold = yml_conf.get("draw_threshold", 0.5)
        self.mask = yml_conf.get("mask", False)
        self.tracker = yml_conf.get("tracker", None)
        self.nms = yml_conf.get("NMS", None)
        self.fpn_stride = yml_conf.get("fpn_stride", None)
        color_pool = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]
        self.colors = {label: color_pool[i % len(color_pool)] for i, label in enumerate(self.label_list)}
        if self.arch == 'RCNN' and yml_conf.get('export_onnx', False):
            print(
                'The RCNN export model is used for ONNX and it only supports batch_size = 1'
            )
        self.print_config()

    def check_model(self, yml_conf):
        """
        Raises:
            ValueError: loaded model not in supported model type
        """
        for support_model in SUPPORT_MODELS:
            if support_model in yml_conf['arch']:
                return True
        raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[
            'arch'], SUPPORT_MODELS))

    def print_config(self):
        print('----------- Model Configuration -----------')
        print('%s: %s' % ('Model Arch', self.arch))
        print('%s: ' % ('Transform Order'))
        for op_info in self.preprocess_infos:
            print('--%s: %s' % ('transform op', op_info['type']))
        print('--------------------------------------------')

def draw_bbox(image, outputs, infer_config):
    for output in outputs:
        cls_id, score, xmin, ymin, xmax, ymax = output
        if score > infer_config.draw_threshold:
            label = infer_config.label_list[int(cls_id)]
            color = infer_config.colors[label]
            cv2.rectangle(image, (int(xmin), int(ymin)), (int(xmax), int(ymax)), color, 2)
            cv2.putText(image, "{}: {:.2f}".format(label, score),
                        (int(xmin), int(ymin - 5)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
    return image

def predict_image(imgsave_dir, infer_config, predictor, img_list):
    # load preprocess transforms
    transforms = Compose(infer_config.preprocess_infos)
    errImgList = []
    # Check and create subimg_save_dir if not exist
    subimg_save_dir = os.path.join(imgsave_dir, 'subimages')
    os.makedirs(subimg_save_dir, exist_ok=True)
    # predict image
    for img_path in tqdm(img_list):
        img = cv2.imread(img_path)
        if img is None:
            print(f"Warning: Could not read image {img_path}. Skipping...")
            errImgList.append(img_path)
            continue
        inputs = transforms(img_path)
        inputs_name = [var.name for var in predictor.get_inputs()]
        inputs = {k: inputs[k][None, ] for k in inputs_name}
        outputs = predictor.run(output_names=None, input_feed=inputs)
        print("ONNXRuntime predict: ")
        if infer_config.arch in ["HRNet"]:
            print(np.array(outputs[0]))
        else:
            bboxes = np.array(outputs[0])
            for bbox in bboxes:
                if bbox[0] > -1 and bbox[1] > infer_config.draw_threshold:
                    print(f"{int(bbox[0])} {bbox[1]} "
                          f"{bbox[2]} {bbox[3]} {bbox[4]} {bbox[5]}")
        # Save the subimages (crop from the original image)
        subimg_counter = 1
        for output in np.array(outputs[0]):
            cls_id, score, xmin, ymin, xmax, ymax = output
            if score > infer_config.draw_threshold:
                label = infer_config.label_list[int(cls_id)]
                subimg = img[int(max(ymin, 0)):int(ymax), int(max(xmin, 0)):int(xmax)]
                if len(subimg) == 0:
                    continue
                subimg_filename = f"{os.path.splitext(os.path.basename(img_path))[0]}_{label}_{xmin:.2f}_{ymin:.2f}_{xmax:.2f}_{ymax:.2f}.jpg"
                subimg_path = os.path.join(subimg_save_dir, subimg_filename)
                cv2.imwrite(subimg_path, subimg)
                subimg_counter += 1
        # Draw bounding boxes and save the image with bounding boxes
        img_with_bbox = draw_bbox(img, np.array(outputs[0]), infer_config)
        output_dir = imgsave_dir
        os.makedirs(output_dir, exist_ok=True)
        output_file = os.path.join(output_dir, "output_" + os.path.basename(img_path))
        cv2.imwrite(output_file, img_with_bbox)
    print("ErrorImgs:")
    print(errImgList)

def predict(img_path: str, predictor, infer_config) -> List[Bbox]:
    transforms = Compose(infer_config.preprocess_infos)
    inputs = transforms(img_path)
    inputs_name = [var.name for var in predictor.get_inputs()]
    inputs = {k: inputs[k][None, ] for k in inputs_name}
    outputs = predictor.run(output_names=None, input_feed=inputs)[0]

    res = []
    for output in outputs:
        cls_name = infer_config.label_list[int(output[0])]
        score = output[1]
        xmin = int(max(output[2], 0))
        ymin = int(max(output[3], 0))
        xmax = int(output[4])
        ymax = int(output[5])
        if score > infer_config.draw_threshold:
            res.append(Bbox(xmin, ymin, ymax - ymin, xmax - xmin, cls_name, score))
    return res

View File

@@ -8,8 +8,8 @@ Preprocess:
 - interp: 2
   keep_ratio: false
   target_size:
-  - 640
-  - 640
+  - 1600
+  - 1600
   type: Resize
 - mean:
   - 0.0

View File

@@ -4,9 +4,14 @@ import copy
 def decode_image(img_path):
-    with open(img_path, 'rb') as f:
-        im_read = f.read()
-        data = np.frombuffer(im_read, dtype='uint8')
+    if isinstance(img_path, str):
+        with open(img_path, 'rb') as f:
+            im_read = f.read()
+            data = np.frombuffer(im_read, dtype='uint8')
+    else:
+        assert isinstance(img_path, np.ndarray)
+        data = img_path
     im = cv2.imdecode(data, 1)  # BGR mode, but need RGB mode
     im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
     img_info = {
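
A quick illustration of the branch just added (hypothetical paths; the import location is assumed): decode_image now accepts either a file path or a uint8 ndarray, and because the ndarray branch still runs through cv2.imdecode, it expects encoded image bytes rather than an already-decoded pixel array.

import numpy as np
from models.det_model.preprocess import decode_image  # assumed import path

decode_image("page.jpg")                      # str branch: open, read, decode
raw = np.fromfile("page.jpg", dtype="uint8")  # encoded JPEG bytes as a uint8 array
decode_image(raw)                             # ndarray branch: fed to cv2.imdecode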

View File

@@ -4,19 +4,22 @@ import numpy as np
 from transformers import RobertaTokenizerFast, GenerationConfig
 from typing import List, Union

-from models.ocr_model.model.TexTeller import TexTeller
-from models.ocr_model.utils.transforms import inference_transform
-from models.ocr_model.utils.helpers import convert2rgb
-from models.globals import MAX_TOKEN_SIZE
+from .transforms import inference_transform
+from .helpers import convert2rgb
+from ..model.TexTeller import TexTeller
+from ...globals import MAX_TOKEN_SIZE

 def inference(
     model: TexTeller,
     tokenizer: RobertaTokenizerFast,
     imgs: Union[List[str], List[np.ndarray]],
-    inf_mode: str = 'cpu',
+    accelerator: str = 'cpu',
     num_beams: int = 1,
+    max_tokens = None
 ) -> List[str]:
+    if imgs == []:
+        return []
     model.eval()
     if isinstance(imgs[0], str):
         imgs = convert2rgb(imgs)
@@ -26,11 +29,11 @@ def inference(
     imgs = inference_transform(imgs)
     pixel_values = torch.stack(imgs)

-    model = model.to(inf_mode)
-    pixel_values = pixel_values.to(inf_mode)
+    model = model.to(accelerator)
+    pixel_values = pixel_values.to(accelerator)

     generate_config = GenerationConfig(
-        max_new_tokens=MAX_TOKEN_SIZE,
+        max_new_tokens=MAX_TOKEN_SIZE if max_tokens is None else max_tokens,
         num_beams=num_beams,
         do_sample=False,
         pad_token_id=tokenizer.pad_token_id,
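
The renamed accelerator parameter and the new max_tokens cap ripple through the call sites below. A minimal usage sketch (the default checkpoint/tokenizer and the image file name are assumptions):

from models.ocr_model.model.TexTeller import TexTeller
from models.ocr_model.utils.inference import inference

model = TexTeller.from_pretrained()
tokenizer = TexTeller.get_tokenizer()

assert inference(model, tokenizer, []) == []        # new guard: empty in, empty out
res = inference(model, tokenizer, ["formula.png"],  # hypothetical image path
                accelerator="cpu", num_beams=1,
                max_tokens=200)                     # cap decoding at 200 new tokens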

View File

@@ -0,0 +1,110 @@
import re

def change(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r):
    result = ""
    i = 0
    n = len(input_str)
    while i < n:
        if input_str[i:i + len(old_inst)] == old_inst:
            # check if the old_inst is followed by old_surr_l
            start = i + len(old_inst)
        else:
            result += input_str[i]
            i += 1
            continue

        if start < n and input_str[start] == old_surr_l:
            # found an old_inst followed by old_surr_l, now look for the matching old_surr_r
            count = 1
            j = start + 1
            escaped = False
            while j < n and count > 0:
                if input_str[j] == '\\' and not escaped:
                    escaped = True
                    j += 1
                    continue
                if input_str[j] == old_surr_r and not escaped:
                    count -= 1
                    if count == 0:
                        break
                elif input_str[j] == old_surr_l and not escaped:
                    count += 1
                escaped = False
                j += 1

            if count == 0:
                assert j < n
                assert input_str[start] == old_surr_l
                assert input_str[j] == old_surr_r
                inner_content = input_str[start + 1:j]
                # Replace the content with new pattern
                result += new_inst + new_surr_l + inner_content + new_surr_r
                i = j + 1
                continue
            else:
                assert count >= 1
                assert j == n
                print("Warning: unbalanced surrogate pair in input string")
                result += new_inst + new_surr_l
                i = start + 1
                continue
        else:
            i = start

    if old_inst != new_inst and old_inst in result:
        return change(result, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r)
    else:
        return result

def to_katex(formula: str) -> str:
    res = formula
    res = change(res, r'\mbox', r'', r'{', r'}', r'', r'')

    origin_instructions = [
        r'\Huge',
        r'\huge',
        r'\LARGE',
        r'\Large',
        r'\large',
        r'\normalsize',
        r'\small',
        r'\footnotesize',
        r'\scriptsize',
        r'\tiny'
    ]
    for (old_ins, new_ins) in zip(origin_instructions, origin_instructions):
        res = change(res, old_ins, new_ins, r'$', r'$', '{', '}')
    res = change(res, r'\boldmath', r'\bm', r'$', r'$', r'{', r'}')

    origin_instructions = [
        r'\left',
        r'\middle',
        r'\right',
        r'\big',
        r'\Big',
        r'\bigg',
        r'\Bigg',
        r'\bigl',
        r'\Bigl',
        r'\biggl',
        r'\Biggl',
        r'\bigm',
        r'\Bigm',
        r'\biggm',
        r'\Biggm',
        r'\bigr',
        r'\Bigr',
        r'\biggr',
        r'\Biggr'
    ]
    for origin_ins in origin_instructions:
        res = change(res, origin_ins, origin_ins, r'{', r'}', r'', r'')

    res = re.sub(r'\\\[(.*?)\\\]', r'\1\\newline', res)
    if res.endswith(r'\newline'):
        res = res[:-8]
    return res
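
A few input/output pairs, derived by tracing the rewriter above rather than from any project documentation, to show what the brace-matching change() pass does:

# Traced behavior of to_katex on small inputs (illustrative only):
assert to_katex(r'\mbox{speed}') == 'speed'            # \mbox{...} unwrapped
assert to_katex(r'\boldmath$v$') == r'\bm{v}'          # \boldmath$...$ -> \bm{...}
assert to_katex(r'\Big{(}x\Big{)}') == r'\Big(x\Big)'  # braces after \Big dropped
assert to_katex(r'\[x^2\]') == 'x^2'                   # \[...\] -> \newline, then stripped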

View File

@@ -0,0 +1 @@
from .mix_inference import mix_inference

View File

@@ -0,0 +1,257 @@
import re
import heapq
import cv2
import numpy as np

from onnxruntime import InferenceSession
from collections import Counter
from typing import List
from PIL import Image

from surya.ocr import run_ocr
from surya.detection import batch_text_detection
from surya.input.processing import slice_polys_from_image, slice_bboxes_from_image
from surya.recognition import batch_recognition
from surya.model.detection import segformer
from surya.model.recognition.model import load_model
from surya.model.recognition.processor import load_processor

from ..det_model.inference import PredictConfig
from ..det_model.inference import predict as latex_det_predict
from ..det_model.Bbox import Bbox, draw_bboxes
from ..ocr_model.model.TexTeller import TexTeller
from ..ocr_model.utils.inference import inference as latex_rec_predict
from ..ocr_model.utils.to_katex import to_katex

MAXV = 999999999

def mask_img(img, bboxes: List[Bbox], bg_color: np.ndarray) -> np.ndarray:
    mask_img = img.copy()
    for bbox in bboxes:
        mask_img[bbox.p.y:bbox.p.y + bbox.h, bbox.p.x:bbox.p.x + bbox.w] = bg_color
    return mask_img

def bbox_merge(sorted_bboxes: List[Bbox]) -> List[Bbox]:
    if len(sorted_bboxes) == 0:
        return []
    bboxes = sorted_bboxes.copy()
    guard = Bbox(MAXV, bboxes[-1].p.y, -1, -1, label="guard")
    bboxes.append(guard)
    res = []
    prev = bboxes[0]
    for curr in bboxes:
        if prev.ur_point.x <= curr.p.x or not prev.same_row(curr):
            res.append(prev)
            prev = curr
        else:
            prev.w = max(prev.w, curr.ur_point.x - prev.p.x)
    return res

def split_conflict(ocr_bboxes: List[Bbox], latex_bboxes: List[Bbox]) -> List[Bbox]:
    if latex_bboxes == []:
        return ocr_bboxes
    if ocr_bboxes == [] or len(ocr_bboxes) == 1:
        return ocr_bboxes

    bboxes = sorted(ocr_bboxes + latex_bboxes)

    ######## debug #########
    for idx, bbox in enumerate(bboxes):
        bbox.content = str(idx)
    draw_bboxes(Image.fromarray(img), bboxes, name="before_split_confict.png")
    ######## debug #########

    assert len(bboxes) > 1
    heapq.heapify(bboxes)
    res = []
    candidate = heapq.heappop(bboxes)
    curr = heapq.heappop(bboxes)
    idx = 0
    while len(bboxes) > 0:
        idx += 1
        assert candidate.p.x < curr.p.x or not candidate.same_row(curr)
        if candidate.ur_point.x <= curr.p.x or not candidate.same_row(curr):
            res.append(candidate)
            candidate = curr
            curr = heapq.heappop(bboxes)
        elif candidate.ur_point.x < curr.ur_point.x:
            assert not (candidate.label != "text" and curr.label != "text")
            if candidate.label == "text" and curr.label == "text":
                candidate.w = curr.ur_point.x - candidate.p.x
                curr = heapq.heappop(bboxes)
            elif candidate.label != curr.label:
                if candidate.label == "text":
                    candidate.w = curr.p.x - candidate.p.x
                    res.append(candidate)
                    candidate = curr
                    curr = heapq.heappop(bboxes)
                else:
                    curr.w = curr.ur_point.x - candidate.ur_point.x
                    curr.p.x = candidate.ur_point.x
                    heapq.heappush(bboxes, curr)
                    curr = heapq.heappop(bboxes)
        elif candidate.ur_point.x >= curr.ur_point.x:
            assert not (candidate.label != "text" and curr.label != "text")
            if candidate.label == "text":
                assert curr.label != "text"
                heapq.heappush(
                    bboxes,
                    Bbox(
                        curr.ur_point.x,
                        candidate.p.y,
                        candidate.h,
                        candidate.ur_point.x - curr.ur_point.x,
                        label="text",
                        confidence=candidate.confidence,
                        content=None
                    )
                )
                candidate.w = curr.p.x - candidate.p.x
                res.append(candidate)
                candidate = curr
                curr = heapq.heappop(bboxes)
            else:
                assert curr.label == "text"
                curr = heapq.heappop(bboxes)
        else:
            assert False
    res.append(candidate)
    res.append(curr)

    ######## debug #########
    for idx, bbox in enumerate(res):
        bbox.content = str(idx)
    draw_bboxes(Image.fromarray(img), res, name="after_split_confict.png")
    ######## debug #########

    return res

def mix_inference(
    img_path: str,
    language: str,
    infer_config,
    latex_det_model,
    lang_ocr_models,
    latex_rec_models,
    accelerator="cpu",
    num_beams=1
) -> str:
    '''
    Input a mixed image of formula text and output str (in markdown syntax)
    '''
    global img
    img = cv2.imread(img_path)
    corners = [tuple(img[0, 0]), tuple(img[0, -1]),
               tuple(img[-1, 0]), tuple(img[-1, -1])]
    bg_color = np.array(Counter(corners).most_common(1)[0][0])

    latex_bboxes = latex_det_predict(img_path, latex_det_model, infer_config)
    latex_bboxes = sorted(latex_bboxes)
    draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(unmerged).png")
    latex_bboxes = bbox_merge(latex_bboxes)
    draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(merged).png")
    masked_img = mask_img(img, latex_bboxes, bg_color)

    det_model, det_processor, rec_model, rec_processor = lang_ocr_models
    images = [Image.fromarray(masked_img)]
    det_prediction = batch_text_detection(images, det_model, det_processor)[0]
    draw_bboxes(Image.fromarray(img), latex_bboxes, name="ocr_bboxes(unmerged).png")

    lang = [language]
    slice_map = []
    all_slices = []
    all_langs = []

    ocr_bboxes = [
        Bbox(
            p.bbox[0], p.bbox[1], p.bbox[3] - p.bbox[1], p.bbox[2] - p.bbox[0],
            label="text",
            confidence=p.confidence,
            content=None
        )
        for p in det_prediction.bboxes
    ]
    ocr_bboxes = sorted(ocr_bboxes)
    ocr_bboxes = bbox_merge(ocr_bboxes)
    draw_bboxes(Image.fromarray(img), ocr_bboxes, name="ocr_bboxes(merged).png")
    ocr_bboxes = split_conflict(ocr_bboxes, latex_bboxes)
    ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes))

    polygons = [
        [
            [bbox.ul_point.x, bbox.ul_point.y],
            [bbox.ur_point.x, bbox.ur_point.y],
            [bbox.lr_point.x, bbox.lr_point.y],
            [bbox.ll_point.x, bbox.ll_point.y]
        ]
        for bbox in ocr_bboxes
    ]
    slices = slice_polys_from_image(images[0], polygons)
    slice_map.append(len(slices))
    all_slices.extend(slices)
    all_langs.extend([lang] * len(slices))

    rec_predictions, _ = batch_recognition(all_slices, all_langs, rec_model, rec_processor)
    assert len(rec_predictions) == len(ocr_bboxes)
    for content, bbox in zip(rec_predictions, ocr_bboxes):
        bbox.content = content

    latex_imgs = []
    for bbox in latex_bboxes:
        latex_imgs.append(img[bbox.p.y:bbox.p.y + bbox.h, bbox.p.x:bbox.p.x + bbox.w])
    latex_rec_res = latex_rec_predict(*latex_rec_models, latex_imgs, accelerator, num_beams, max_tokens=200)
    for bbox, content in zip(latex_bboxes, latex_rec_res):
        bbox.content = to_katex(content)
        if bbox.label == "embedding":
            bbox.content = " $" + bbox.content + "$ "
        elif bbox.label == "isolated":
            bbox.content = '\n' + r"$$" + bbox.content + r"$$" + '\n'

    bboxes = sorted(ocr_bboxes + latex_bboxes)
    if bboxes == []:
        return ""
    md = ""
    prev = Bbox(bboxes[0].p.x, bboxes[0].p.y, -1, -1, label="guard")
    # prev = bboxes[0]
    for curr in bboxes:
        if not prev.same_row(curr):
            md += "\n"
        md += curr.content
        if (
            prev.label == "isolated"
            and curr.label == "text"
            and bool(re.fullmatch(r"\([1-9]\d*?\)", curr.content))
        ):
            md += '\n'
        prev = curr
    return md

if __name__ == '__main__':
    img_path = "/Users/Leehy/Code/TexTeller/test3.png"
    # latex_det_model = InferenceSession("/Users/Leehy/Code/TexTeller/src/models/det_model/model/rtdetr_r50vd_6x_coco_trained_on_IBEM_en_papers.onnx")
    latex_det_model = InferenceSession("/Users/Leehy/Code/TexTeller/src/models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
    infer_config = PredictConfig("/Users/Leehy/Code/TexTeller/src/models/det_model/model/infer_cfg.yml")

    det_processor, det_model = segformer.load_processor(), segformer.load_model()
    rec_model, rec_processor = load_model(), load_processor()
    lang_ocr_models = (det_model, det_processor, rec_model, rec_processor)

    texteller = TexTeller.from_pretrained()
    tokenizer = TexTeller.get_tokenizer()
    latex_rec_models = (texteller, tokenizer)

    res = mix_inference(img_path, "zh", infer_config, latex_det_model, lang_ocr_models, latex_rec_models)
    print(res)
    pause = 1

View File

@@ -2,7 +2,7 @@ import os
 import argparse
 import cv2 as cv
 from pathlib import Path
-from utils import to_katex
+from models.ocr_model.utils.to_katex import to_katex
 from models.ocr_model.utils.inference import inference as latex_inference
 from models.ocr_model.model.TexTeller import TexTeller
@@ -46,7 +46,7 @@ if __name__ == '__main__':
         if img is not None:
             print(f'Inference for {filename}...')
-            res = latex_inference(latex_rec_model, tokenizer, [img], inf_mode=args.inference_mode, num_beams=args.num_beam)
+            res = latex_inference(latex_rec_model, tokenizer, [img], accelerator=args.inference_mode, num_beams=args.num_beam)
             res = to_katex(res[0])

             # Save the recognition result to a text file

View File

@@ -56,7 +56,7 @@ class TexTellerServer:
     def predict(self, image_nparray) -> str:
         return inference(
             self.model, self.tokenizer, [image_nparray],
-            inf_mode=self.inf_mode, num_beams=self.num_beams
+            accelerator=self.inf_mode, num_beams=self.num_beams
         )[0]

View File

@@ -1,7 +1,7 @@
 #!/usr/bin/env bash
 set -exu

-export CHECKPOINT_DIR="/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-648000"
+export CHECKPOINT_DIR="default"
 export TOKENIZER_DIR="default"

 streamlit run web.py

View File

@@ -1 +0,0 @@
from .to_katex import to_katex

View File

@@ -1,15 +0,0 @@
import numpy as np
import re

def to_katex(formula: str) -> str:
    res = formula
    res = re.sub(r'\\mbox\{([^}]*)\}', r'\1', res)
    res = re.sub(r'boldmath\$(.*?)\$', r'bm{\1}', res)
    res = re.sub(r'\\\[(.*?)\\\]', r'\1\\newline', res)
    pattern = r'(\\(?:left|middle|right|big|Big|bigg|Bigg|bigl|Bigl|biggl|Biggl|bigm|Bigm|biggm|Biggm|bigr|Bigr|biggr|Biggr))\{([^}]*)\}'
    replacement = r'\1\2'
    res = re.sub(pattern, replacement, res)
    if res.endswith(r'\newline'):
        res = res[:-8]
    return res

View File

@@ -7,10 +7,18 @@ import streamlit as st
 from PIL import Image
 from streamlit_paste_button import paste_image_button as pbutton

-from models.ocr_model.utils.inference import inference
-from models.ocr_model.model.TexTeller import TexTeller
-from utils import to_katex
+from onnxruntime import InferenceSession
+from models.utils import mix_inference
+from models.det_model.inference import PredictConfig
+from models.ocr_model.model.TexTeller import TexTeller
+from models.ocr_model.utils.inference import inference as latex_recognition
+from models.ocr_model.utils.to_katex import to_katex
+from surya.model.detection import segformer
+from surya.model.recognition.model import load_model
+from surya.model.recognition.processor import load_processor

 st.set_page_config(
     page_title="TexTeller",
@@ -42,13 +50,26 @@ fail_gif_html = '''
 '''

 @st.cache_resource
-def get_model():
+def get_texteller():
     return TexTeller.from_pretrained(os.environ['CHECKPOINT_DIR'])

 @st.cache_resource
 def get_tokenizer():
     return TexTeller.get_tokenizer(os.environ['TOKENIZER_DIR'])

+@st.cache_resource
+def get_det_models():
+    infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
+    latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
+    return infer_config, latex_det_model
+
+@st.cache_resource()
+def get_ocr_models():
+    det_processor, det_model = segformer.load_processor(), segformer.load_model()
+    rec_model, rec_processor = load_model(), load_processor()
+    lang_ocr_models = [det_model, det_processor, rec_model, rec_processor]
+    return lang_ocr_models
+
 def get_image_base64(img_file):
     buffered = io.BytesIO()
     img_file.seek(0)
@@ -62,9 +83,6 @@ def on_file_upload():
 def change_side_bar():
     st.session_state["CHANGE_SIDEBAR_FLAG"] = True

-model = get_model()
-tokenizer = get_tokenizer()
-
 if "start" not in st.session_state:
     st.session_state["start"] = 1
     st.toast('Hooray!', icon='🎉')
@@ -75,31 +93,34 @@ if "UPLOADED_FILE_CHANGED" not in st.session_state:
 if "CHANGE_SIDEBAR_FLAG" not in st.session_state:
     st.session_state["CHANGE_SIDEBAR_FLAG"] = False
+if "INF_MODE" not in st.session_state:
+    st.session_state["INF_MODE"] = "Formula only"

 # ============================ begin sidebar =============================== #
 with st.sidebar:
     num_beams = 1
-    inf_mode = 'cpu'

     st.markdown("# 🔨️ Config")
     st.markdown("")

-    model_type = st.selectbox(
-        "Model type",
-        ("TexTeller", "None"),
+    inf_mode = st.selectbox(
+        "Inference mode",
+        ("Formula only", "Text formula mixed"),
         on_change=change_side_bar
     )
-    if model_type == "TexTeller":
-        num_beams = st.number_input(
-            'Number of beams',
-            min_value=1,
-            max_value=20,
-            step=1,
-            on_change=change_side_bar
-        )
+    num_beams = st.number_input(
+        'Number of beams',
+        min_value=1,
+        max_value=20,
+        step=1,
+        on_change=change_side_bar
+    )

-    inf_mode = st.radio(
-        "Inference mode",
+    accelerator = st.radio(
+        "Accelerator",
         ("cpu", "cuda", "mps"),
         on_change=change_side_bar
     )
@@ -107,9 +128,16 @@ with st.sidebar:
 # ============================ end sidebar =============================== #

 # ============================ begin pages =============================== #
+texteller = get_texteller()
+tokenizer = get_tokenizer()
+latex_rec_models = [texteller, tokenizer]
+
+if inf_mode == "Text formula mixed":
+    infer_config, latex_det_model = get_det_models()
+    lang_ocr_models = get_ocr_models()
+
 st.markdown(html_string, unsafe_allow_html=True)

 uploaded_file = st.file_uploader(
@@ -176,19 +204,26 @@ elif uploaded_file or paste_result.image_data is not None:
         st.write("")
         with st.spinner("Predicting..."):
-            uploaded_file.seek(0)
-            TexTeller_result = inference(
-                model,
-                tokenizer,
-                [png_file_path],
-                inf_mode=inf_mode,
-                num_beams=num_beams
-            )[0]
+            if inf_mode == "Formula only":
+                TexTeller_result = latex_recognition(
+                    texteller,
+                    tokenizer,
+                    [png_file_path],
+                    accelerator=accelerator,
+                    num_beams=num_beams
+                )[0]
+                katex_res = to_katex(TexTeller_result)
+            else:
+                katex_res = mix_inference(png_file_path, "en", infer_config, latex_det_model, lang_ocr_models, latex_rec_models, accelerator, num_beams)

             st.success('Completed!', icon="")
             st.markdown(suc_gif_html, unsafe_allow_html=True)
-            katex_res = to_katex(TexTeller_result)
             st.text_area(":blue[*** 𝑃r𝑒d𝑖c𝑡e𝑑 𝑓o𝑟m𝑢l𝑎 ***]", katex_res, height=150)
-            st.latex(katex_res)
+            if inf_mode == "Formula only":
+                st.latex(katex_res)
+            elif inf_mode == "Text formula mixed":
+                st.markdown(katex_res)

             st.write("")
             st.write("")