1) 实现了文本-公式混排识别; 2) 重构了项目结构

This commit is contained in:
三洋三洋
2024-04-21 00:05:14 +08:00
parent eab6e4c85d
commit 185b2e3db6
19 changed files with 753 additions and 296 deletions

6
.gitignore vendored
View File

@@ -6,6 +6,8 @@
**/ckpt
**/*cache
**/.cache
**/tmp
**/log
**/data
**/logs
@@ -13,3 +15,7 @@
**/data
**/*cache
**/ckpt
**/*.bin
**/*.safetensor
**/*.onnx

View File

@@ -14,4 +14,4 @@ onnxruntime
streamlit==1.30
streamlit-paste-button
easyocr
surya-ocr

View File

@@ -1,75 +0,0 @@
import os
import gradio as gr
from models.ocr_model.utils.inference import inference
from models.ocr_model.model.TexTeller import TexTeller
from utils import to_katex
from pathlib import Path
# model = TexTeller.from_pretrained(os.environ['CHECKPOINT_DIR'])
# tokenizer = TexTeller.get_tokenizer(os.environ['TOKENIZER_DIR'])
css = """
<style>
.container {
display: flex;
align-items: center;
justify-content: center;
font-size: 20px;
font-family: 'Arial';
}
.container img {
height: auto;
}
.text {
margin: 0 15px;
}
h1 {
text-align: center;
font-size: 50px !important;
}
.markdown-style {
color: #333; /* 调整颜色 */
line-height: 1.6; /* 行间距 */
font-size: 50px;
}
.markdown-style h1, .markdown-style h2, .markdown-style h3 {
color: #007BFF; /* 为标题元素指定颜色 */
}
.markdown-style p {
margin-bottom: 1em; /* 段落间距 */
}
</style>
"""
theme=gr.themes.Default(),
def fn(img):
return img
with gr.Blocks(
theme=theme,
css=css
) as demo:
gr.HTML(f'''
{css}
<div class="container">
<img src="https://github.com/OleehyO/TexTeller/raw/main/assets/fire.svg" width="100">
<h1> 𝚃𝚎𝚡𝚃𝚎𝚕𝚕𝚎𝚛 </h1>
<img src="https://github.com/OleehyO/TexTeller/raw/main/assets/fire.svg" width="100">
</div>
''')
with gr.Row(equal_height=True):
input_img = gr.Image(type="pil", label="Input Image")
latex_img = gr.Image(label="Predicted Latex", show_label=False)
input_img.upload(fn, input_img, latex_img)
gr.Markdown(r'$$\fcxrac{7}{10349}$$')
gr.Markdown('fooooooooooooooooooooooooooooo')
demo.launch()

View File

@@ -1,35 +1,21 @@
import os
import yaml
import argparse
import numpy as np
import glob
from onnxruntime import InferenceSession
from tqdm import tqdm
from pathlib import Path
from models.det_model.inference import PredictConfig, predict_image
from models.det_model.preprocess import Compose
import cv2
# Global dictionary
SUPPORT_MODELS = {
'YOLO', 'PPYOLOE', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet',
'S2ANet', 'JDE', 'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet',
'TOOD', 'RetinaNet', 'StrongBaseline', 'STGCN', 'YOLOX', 'HRNet',
'DETR'
}
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--infer_cfg", type=str, help="infer_cfg.yml",
default="./models/det_model/model/infer_cfg.yml"
)
default="./models/det_model/model/infer_cfg.yml")
parser.add_argument('--onnx_file', type=str, help="onnx model file path",
default="./models/det_model/model/rtdetr_r50vd_6x_coco.onnx"
)
default="./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
parser.add_argument("--image_dir", type=str)
parser.add_argument("--image_file", type=str, default='/data/ljm/TexTeller/src/Tr00_0001015-page02.jpg')
parser.add_argument("--imgsave_dir", type=str,
default="."
)
parser.add_argument("--image_file", type=str, required=True)
parser.add_argument("--imgsave_dir", type=str, default="./detect_results")
def get_test_images(infer_dir, infer_img):
"""
@@ -62,125 +48,11 @@ def get_test_images(infer_dir, infer_img):
return images
class PredictConfig(object):
"""set config of preprocess, postprocess and visualize
Args:
infer_config (str): path of infer_cfg.yml
"""
def __init__(self, infer_config):
# parsing Yaml config for Preprocess
with open(infer_config) as f:
yml_conf = yaml.safe_load(f)
self.check_model(yml_conf)
self.arch = yml_conf['arch']
self.preprocess_infos = yml_conf['Preprocess']
self.min_subgraph_size = yml_conf['min_subgraph_size']
self.label_list = yml_conf['label_list']
self.use_dynamic_shape = yml_conf['use_dynamic_shape']
self.draw_threshold = yml_conf.get("draw_threshold", 0.5)
self.mask = yml_conf.get("mask", False)
self.tracker = yml_conf.get("tracker", None)
self.nms = yml_conf.get("NMS", None)
self.fpn_stride = yml_conf.get("fpn_stride", None)
color_pool = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]
self.colors = {label: color_pool[i % len(color_pool)] for i, label in enumerate(self.label_list)}
if self.arch == 'RCNN' and yml_conf.get('export_onnx', False):
print(
'The RCNN export model is used for ONNX and it only supports batch_size = 1'
)
self.print_config()
def check_model(self, yml_conf):
"""
Raises:
ValueError: loaded model not in supported model type
"""
for support_model in SUPPORT_MODELS:
if support_model in yml_conf['arch']:
return True
raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[
'arch'], SUPPORT_MODELS))
def print_config(self):
print('----------- Model Configuration -----------')
print('%s: %s' % ('Model Arch', self.arch))
print('%s: ' % ('Transform Order'))
for op_info in self.preprocess_infos:
print('--%s: %s' % ('transform op', op_info['type']))
print('--------------------------------------------')
def draw_bbox(image, outputs, infer_config):
for output in outputs:
cls_id, score, xmin, ymin, xmax, ymax = output
if score > infer_config.draw_threshold:
label = infer_config.label_list[int(cls_id)]
color = infer_config.colors[label]
cv2.rectangle(image, (int(xmin), int(ymin)), (int(xmax), int(ymax)), color, 2)
cv2.putText(image, "{}: {:.2f}".format(label, score),
(int(xmin), int(ymin - 5)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
return image
def predict_image(infer_config, predictor, img_list):
# load preprocess transforms
transforms = Compose(infer_config.preprocess_infos)
errImgList = []
# Check and create subimg_save_dir if not exist
subimg_save_dir = os.path.join(FLAGS.imgsave_dir, 'subimages')
os.makedirs(subimg_save_dir, exist_ok=True)
# predict image
for img_path in tqdm(img_list):
img = cv2.imread(img_path)
if img is None:
print(f"Warning: Could not read image {img_path}. Skipping...")
errImgList.append(img_path)
continue
inputs = transforms(img_path)
inputs_name = [var.name for var in predictor.get_inputs()]
inputs = {k: inputs[k][None, ] for k in inputs_name}
outputs = predictor.run(output_names=None, input_feed=inputs)
print("ONNXRuntime predict: ")
if infer_config.arch in ["HRNet"]:
print(np.array(outputs[0]))
else:
bboxes = np.array(outputs[0])
for bbox in bboxes:
if bbox[0] > -1 and bbox[1] > infer_config.draw_threshold:
print(f"{int(bbox[0])} {bbox[1]} "
f"{bbox[2]} {bbox[3]} {bbox[4]} {bbox[5]}")
# Save the subimages (crop from the original image)
subimg_counter = 1
for output in np.array(outputs[0]):
cls_id, score, xmin, ymin, xmax, ymax = output
if score > infer_config.draw_threshold:
label = infer_config.label_list[int(cls_id)]
subimg = img[int(ymin):int(ymax), int(xmin):int(xmax)]
subimg_filename = f"{os.path.splitext(os.path.basename(img_path))[0]}_{label}_{xmin:.2f}_{ymin:.2f}_{xmax:.2f}_{ymax:.2f}.jpg"
subimg_path = os.path.join(subimg_save_dir, subimg_filename)
cv2.imwrite(subimg_path, subimg)
subimg_counter += 1
# Draw bounding boxes and save the image with bounding boxes
img_with_bbox = draw_bbox(img, np.array(outputs[0]), infer_config)
output_dir = FLAGS.imgsave_dir
os.makedirs(output_dir, exist_ok=True)
output_file = os.path.join(output_dir, "output_" + os.path.basename(img_path))
cv2.imwrite(output_file, img_with_bbox)
print("ErrorImgs:")
print(errImgList)
if __name__ == '__main__':
cur_path = os.getcwd()
script_dirpath = Path(__file__).resolve().parent
os.chdir(script_dirpath)
FLAGS = parser.parse_args()
# load image list
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
@@ -189,4 +61,6 @@ if __name__ == '__main__':
# load infer config
infer_config = PredictConfig(FLAGS.infer_cfg)
predict_image(infer_config, predictor, img_list)
predict_image(FLAGS.imgsave_dir, infer_config, predictor, img_list)
os.chdir(cur_path)

View File

@@ -3,9 +3,18 @@ import argparse
import cv2 as cv
from pathlib import Path
from utils import to_katex
from onnxruntime import InferenceSession
from models.utils import mix_inference
from models.ocr_model.utils.to_katex import to_katex
from models.ocr_model.utils.inference import inference as latex_inference
from models.ocr_model.model.TexTeller import TexTeller
from models.det_model.inference import PredictConfig
from surya.model.detection import segformer
from surya.model.recognition.model import load_model
from surya.model.recognition.processor import load_processor
if __name__ == '__main__':
@@ -29,33 +38,35 @@ if __name__ == '__main__':
default=1,
help='number of beam search for decoding'
)
# ================= new feature ==================
parser.add_argument(
'-mix',
type=str,
help='use mix mode, only Chinese and English are supported.'
action='store_true',
help='use mix mode'
)
# ==================================================
args = parser.parse_args()
# You can use your own checkpoint and tokenizer path.
print('Loading model and tokenizer...')
latex_rec_model = TexTeller.from_pretrained()
latex_rec_model = TexTeller.from_pretrained()
tokenizer = TexTeller.get_tokenizer()
print('Model and tokenizer loaded.')
# img_path = [args.img]
img = cv.imread(args.img)
img_path = args.img
img = cv.imread(img_path)
print('Inference...')
if not args.mix:
res = latex_inference(latex_rec_model, tokenizer, [img], args.inference_mode, args.num_beam)
res = to_katex(res[0])
print(res)
else:
# latex_det_model = load_det_tex_model()
# lang_model = load_lang_models()...
...
# res: str = mix_inference(latex_det_model, latex_rec_model, lang_model, img, args.cuda)
# print(res)
infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco_IBEM_cnTextBook.onnx")
det_processor, det_model = segformer.load_processor(), segformer.load_model()
rec_model, rec_processor = load_model(), load_processor()
lang_ocr_models = [det_model, det_processor, rec_model, rec_processor]
latex_rec_models = [latex_rec_model, tokenizer]
res = mix_inference(img_path, infer_config, latex_det_model, lang_ocr_models, latex_rec_models, args.inference_mode, args.num_beam)
print(res)

View File

@@ -0,0 +1,85 @@
from PIL import Image, ImageDraw
from typing import List
class Point:
def __init__(self, x: int, y: int):
self.x = int(x)
self.y = int(y)
def __repr__(self) -> str:
return f"Point(x={self.x}, y={self.y})"
class Bbox:
THREADHOLD = 0.4
def __init__(self, x, y, h, w, label: str = None, confidence: float = 0, content: str = None):
self.p = Point(x, y)
self.h = int(h)
self.w = int(w)
self.label = label
self.confidence = confidence
self.content = content
@property
def ul_point(self) -> Point:
return self.p
@property
def ur_point(self) -> Point:
return Point(self.p.x + self.w, self.p.y)
@property
def ll_point(self) -> Point:
return Point(self.p.x, self.p.y + self.h)
@property
def lr_point(self) -> Point:
return Point(self.p.x + self.w, self.p.y + self.h)
def same_row(self, other) -> bool:
if (
(self.p.y >= other.p.y and self.ll_point.y <= other.ll_point.y)
or (self.p.y <= other.p.y and self.ll_point.y >= other.ll_point.y)
):
return True
if self.ll_point.y <= other.p.y or self.p.y >= other.ll_point.y:
return False
return 1.0 * abs(self.p.y - other.p.y) / max(self.h, other.h) < self.THREADHOLD
def __lt__(self, other) -> bool:
'''
from top to bottom, from left to right
'''
if not self.same_row(other):
return self.p.y < other.p.y
else:
return self.p.x < other.p.x
def __repr__(self) -> str:
return f"Bbox(upper_left_point={self.p}, h={self.h}, w={self.w}), label={self.label}, confident={self.confidence}, content={self.content})"
def draw_bboxes(img: Image.Image, bboxes: List[Bbox], name="annotated_image.png"):
drawer = ImageDraw.Draw(img)
for bbox in bboxes:
# Calculate the coordinates for the rectangle to be drawn
left = bbox.p.x
top = bbox.p.y
right = bbox.p.x + bbox.w
bottom = bbox.p.y + bbox.h
# Draw the rectangle on the image
drawer.rectangle([left, top, right, bottom], outline="green", width=1)
# Optionally, add text label if it exists
if bbox.label:
drawer.text((left, top), bbox.label, fill="blue")
if bbox.content:
drawer.text((left, bottom - 10), bbox.content[:10], fill="red")
# Save the image with drawn rectangles
img.save(name)

View File

@@ -0,0 +1,161 @@
import os
import yaml
import numpy as np
import cv2
from tqdm import tqdm
from typing import List
from .preprocess import Compose
from .Bbox import Bbox
# Global dictionary
SUPPORT_MODELS = {
'YOLO', 'PPYOLOE', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet',
'S2ANet', 'JDE', 'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet',
'TOOD', 'RetinaNet', 'StrongBaseline', 'STGCN', 'YOLOX', 'HRNet',
'DETR'
}
class PredictConfig(object):
"""set config of preprocess, postprocess and visualize
Args:
infer_config (str): path of infer_cfg.yml
"""
def __init__(self, infer_config):
# parsing Yaml config for Preprocess
with open(infer_config) as f:
yml_conf = yaml.safe_load(f)
self.check_model(yml_conf)
self.arch = yml_conf['arch']
self.preprocess_infos = yml_conf['Preprocess']
self.min_subgraph_size = yml_conf['min_subgraph_size']
self.label_list = yml_conf['label_list']
self.use_dynamic_shape = yml_conf['use_dynamic_shape']
self.draw_threshold = yml_conf.get("draw_threshold", 0.5)
self.mask = yml_conf.get("mask", False)
self.tracker = yml_conf.get("tracker", None)
self.nms = yml_conf.get("NMS", None)
self.fpn_stride = yml_conf.get("fpn_stride", None)
color_pool = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]
self.colors = {label: color_pool[i % len(color_pool)] for i, label in enumerate(self.label_list)}
if self.arch == 'RCNN' and yml_conf.get('export_onnx', False):
print(
'The RCNN export model is used for ONNX and it only supports batch_size = 1'
)
self.print_config()
def check_model(self, yml_conf):
"""
Raises:
ValueError: loaded model not in supported model type
"""
for support_model in SUPPORT_MODELS:
if support_model in yml_conf['arch']:
return True
raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[
'arch'], SUPPORT_MODELS))
def print_config(self):
print('----------- Model Configuration -----------')
print('%s: %s' % ('Model Arch', self.arch))
print('%s: ' % ('Transform Order'))
for op_info in self.preprocess_infos:
print('--%s: %s' % ('transform op', op_info['type']))
print('--------------------------------------------')
def draw_bbox(image, outputs, infer_config):
for output in outputs:
cls_id, score, xmin, ymin, xmax, ymax = output
if score > infer_config.draw_threshold:
label = infer_config.label_list[int(cls_id)]
color = infer_config.colors[label]
cv2.rectangle(image, (int(xmin), int(ymin)), (int(xmax), int(ymax)), color, 2)
cv2.putText(image, "{}: {:.2f}".format(label, score),
(int(xmin), int(ymin - 5)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
return image
def predict_image(imgsave_dir, infer_config, predictor, img_list):
# load preprocess transforms
transforms = Compose(infer_config.preprocess_infos)
errImgList = []
# Check and create subimg_save_dir if not exist
subimg_save_dir = os.path.join(imgsave_dir, 'subimages')
os.makedirs(subimg_save_dir, exist_ok=True)
# predict image
for img_path in tqdm(img_list):
img = cv2.imread(img_path)
if img is None:
print(f"Warning: Could not read image {img_path}. Skipping...")
errImgList.append(img_path)
continue
inputs = transforms(img_path)
inputs_name = [var.name for var in predictor.get_inputs()]
inputs = {k: inputs[k][None, ] for k in inputs_name}
outputs = predictor.run(output_names=None, input_feed=inputs)
print("ONNXRuntime predict: ")
if infer_config.arch in ["HRNet"]:
print(np.array(outputs[0]))
else:
bboxes = np.array(outputs[0])
for bbox in bboxes:
if bbox[0] > -1 and bbox[1] > infer_config.draw_threshold:
print(f"{int(bbox[0])} {bbox[1]} "
f"{bbox[2]} {bbox[3]} {bbox[4]} {bbox[5]}")
# Save the subimages (crop from the original image)
subimg_counter = 1
for output in np.array(outputs[0]):
cls_id, score, xmin, ymin, xmax, ymax = output
if score > infer_config.draw_threshold:
label = infer_config.label_list[int(cls_id)]
subimg = img[int(max(ymin, 0)):int(ymax), int(max(xmin, 0)):int(xmax)]
if len(subimg) == 0:
continue
subimg_filename = f"{os.path.splitext(os.path.basename(img_path))[0]}_{label}_{xmin:.2f}_{ymin:.2f}_{xmax:.2f}_{ymax:.2f}.jpg"
subimg_path = os.path.join(subimg_save_dir, subimg_filename)
cv2.imwrite(subimg_path, subimg)
subimg_counter += 1
# Draw bounding boxes and save the image with bounding boxes
img_with_bbox = draw_bbox(img, np.array(outputs[0]), infer_config)
output_dir = imgsave_dir
os.makedirs(output_dir, exist_ok=True)
output_file = os.path.join(output_dir, "output_" + os.path.basename(img_path))
cv2.imwrite(output_file, img_with_bbox)
print("ErrorImgs:")
print(errImgList)
def predict(img_path: str, predictor, infer_config) -> List[Bbox]:
transforms = Compose(infer_config.preprocess_infos)
inputs = transforms(img_path)
inputs_name = [var.name for var in predictor.get_inputs()]
inputs = {k: inputs[k][None, ] for k in inputs_name}
outputs = predictor.run(output_names=None, input_feed=inputs)[0]
res = []
for output in outputs:
cls_name = infer_config.label_list[int(output[0])]
score = output[1]
xmin = int(max(output[2], 0))
ymin = int(max(output[3], 0))
xmax = int(output[4])
ymax = int(output[5])
if score > infer_config.draw_threshold:
res.append(Bbox(xmin, ymin, ymax - ymin, xmax - xmin, cls_name, score))
return res

View File

@@ -8,8 +8,8 @@ Preprocess:
- interp: 2
keep_ratio: false
target_size:
- 640
- 640
- 1600
- 1600
type: Resize
- mean:
- 0.0

View File

@@ -4,9 +4,14 @@ import copy
def decode_image(img_path):
with open(img_path, 'rb') as f:
im_read = f.read()
data = np.frombuffer(im_read, dtype='uint8')
if isinstance(img_path, str):
with open(img_path, 'rb') as f:
im_read = f.read()
data = np.frombuffer(im_read, dtype='uint8')
else:
assert isinstance(img_path, np.ndarray)
data = img_path
im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
img_info = {

View File

@@ -4,19 +4,22 @@ import numpy as np
from transformers import RobertaTokenizerFast, GenerationConfig
from typing import List, Union
from models.ocr_model.model.TexTeller import TexTeller
from models.ocr_model.utils.transforms import inference_transform
from models.ocr_model.utils.helpers import convert2rgb
from models.globals import MAX_TOKEN_SIZE
from .transforms import inference_transform
from .helpers import convert2rgb
from ..model.TexTeller import TexTeller
from ...globals import MAX_TOKEN_SIZE
def inference(
model: TexTeller,
tokenizer: RobertaTokenizerFast,
imgs: Union[List[str], List[np.ndarray]],
inf_mode: str = 'cpu',
accelerator: str = 'cpu',
num_beams: int = 1,
max_tokens = None
) -> List[str]:
if imgs == []:
return []
model.eval()
if isinstance(imgs[0], str):
imgs = convert2rgb(imgs)
@@ -26,11 +29,11 @@ def inference(
imgs = inference_transform(imgs)
pixel_values = torch.stack(imgs)
model = model.to(inf_mode)
pixel_values = pixel_values.to(inf_mode)
model = model.to(accelerator)
pixel_values = pixel_values.to(accelerator)
generate_config = GenerationConfig(
max_new_tokens=MAX_TOKEN_SIZE,
max_new_tokens=MAX_TOKEN_SIZE if max_tokens is None else max_tokens,
num_beams=num_beams,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,

View File

@@ -0,0 +1,110 @@
import re
def change(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r):
result = ""
i = 0
n = len(input_str)
while i < n:
if input_str[i:i+len(old_inst)] == old_inst:
# check if the old_inst is followed by old_surr_l
start = i + len(old_inst)
else:
result += input_str[i]
i += 1
continue
if start < n and input_str[start] == old_surr_l:
# found an old_inst followed by old_surr_l, now look for the matching old_surr_r
count = 1
j = start + 1
escaped = False
while j < n and count > 0:
if input_str[j] == '\\' and not escaped:
escaped = True
j += 1
continue
if input_str[j] == old_surr_r and not escaped:
count -= 1
if count == 0:
break
elif input_str[j] == old_surr_l and not escaped:
count += 1
escaped = False
j += 1
if count == 0:
assert j < n
assert input_str[start] == old_surr_l
assert input_str[j] == old_surr_r
inner_content = input_str[start + 1:j]
# Replace the content with new pattern
result += new_inst + new_surr_l + inner_content + new_surr_r
i = j + 1
continue
else:
assert count > 1
assert j == n
print("Warning: unbalanced surrogate pair in input string")
result += new_inst + new_surr_l
i = start + 1
continue
else:
i = start
if old_inst != new_inst and old_inst in result:
return change(result, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, new_surr_r)
else:
return result
def to_katex(formula: str) -> str:
res = formula
res = change(res, r'\mbox', r'', r'{', r'}', r'', r'')
origin_instructions = [
r'\Huge',
r'\huge',
r'\LARGE',
r'\Large',
r'\large',
r'\normalsize',
r'\small',
r'\footnotesize',
r'\scriptsize',
r'\tiny'
]
for (old_ins, new_ins) in zip(origin_instructions, origin_instructions):
res = change(res, old_ins, new_ins, r'$', r'$', '{', '}')
res = change(res, r'\boldmath', r'\bm', r'$', r'$', r'{', r'}')
origin_instructions = [
r'\left',
r'\middle',
r'\right',
r'\big',
r'\Big',
r'\bigg',
r'\Bigg',
r'\bigl',
r'\Bigl',
r'\biggl',
r'\Biggl',
r'\bigm',
r'\Bigm',
r'\biggm',
r'\Biggm',
r'\bigr',
r'\Bigr',
r'\biggr',
r'\Biggr'
]
for origin_ins in origin_instructions:
res = change(res, origin_ins, origin_ins, r'{', r'}', r'', r'')
res = re.sub(r'\\\[(.*?)\\\]', r'\1\\newline', res)
if res.endswith(r'\newline'):
res = res[:-8]
return res

View File

@@ -0,0 +1 @@
from .mix_inference import mix_inference

View File

@@ -0,0 +1,257 @@
import re
import heapq
import cv2
import numpy as np
from onnxruntime import InferenceSession
from collections import Counter
from typing import List
from PIL import Image
from surya.ocr import run_ocr
from surya.detection import batch_text_detection
from surya.input.processing import slice_polys_from_image, slice_bboxes_from_image
from surya.recognition import batch_recognition
from surya.model.detection import segformer
from surya.model.recognition.model import load_model
from surya.model.recognition.processor import load_processor
from ..det_model.inference import PredictConfig
from ..det_model.inference import predict as latex_det_predict
from ..det_model.Bbox import Bbox, draw_bboxes
from ..ocr_model.model.TexTeller import TexTeller
from ..ocr_model.utils.inference import inference as latex_rec_predict
from ..ocr_model.utils.to_katex import to_katex
MAXV = 999999999
def mask_img(img, bboxes: List[Bbox], bg_color: np.ndarray) -> np.ndarray:
mask_img = img.copy()
for bbox in bboxes:
mask_img[bbox.p.y:bbox.p.y + bbox.h, bbox.p.x:bbox.p.x + bbox.w] = bg_color
return mask_img
def bbox_merge(sorted_bboxes: List[Bbox]) -> List[Bbox]:
if (len(sorted_bboxes) == 0):
return []
bboxes = sorted_bboxes.copy()
guard = Bbox(MAXV, bboxes[-1].p.y, -1, -1, label="guard")
bboxes.append(guard)
res = []
prev = bboxes[0]
for curr in bboxes:
if prev.ur_point.x <= curr.p.x or not prev.same_row(curr):
res.append(prev)
prev = curr
else:
prev.w = max(prev.w, curr.ur_point.x - prev.p.x)
return res
def split_conflict(ocr_bboxes: List[Bbox], latex_bboxes: List[Bbox]) -> List[Bbox]:
if latex_bboxes == []:
return ocr_bboxes
if ocr_bboxes == [] or len(ocr_bboxes) == 1:
return ocr_bboxes
bboxes = sorted(ocr_bboxes + latex_bboxes)
######## debug #########
for idx, bbox in enumerate(bboxes):
bbox.content = str(idx)
draw_bboxes(Image.fromarray(img), bboxes, name="before_split_confict.png")
######## debug ###########
assert len(bboxes) > 1
heapq.heapify(bboxes)
res = []
candidate = heapq.heappop(bboxes)
curr = heapq.heappop(bboxes)
idx = 0
while (len(bboxes) > 0):
idx += 1
assert candidate.p.x < curr.p.x or not candidate.same_row(curr)
if candidate.ur_point.x <= curr.p.x or not candidate.same_row(curr):
res.append(candidate)
candidate = curr
curr = heapq.heappop(bboxes)
elif candidate.ur_point.x < curr.ur_point.x:
assert not (candidate.label != "text" and curr.label != "text")
if candidate.label == "text" and curr.label == "text":
candidate.w = curr.ur_point.x - candidate.p.x
curr = heapq.heappop(bboxes)
elif candidate.label != curr.label:
if candidate.label == "text":
candidate.w = curr.p.x - candidate.p.x
res.append(candidate)
candidate = curr
curr = heapq.heappop(bboxes)
else:
curr.w = curr.ur_point.x - candidate.ur_point.x
curr.p.x = candidate.ur_point.x
heapq.heappush(bboxes, curr)
curr = heapq.heappop(bboxes)
elif candidate.ur_point.x >= curr.ur_point.x:
assert not (candidate.label != "text" and curr.label != "text")
if candidate.label == "text":
assert curr.label != "text"
heapq.heappush(
bboxes,
Bbox(
curr.ur_point.x,
candidate.p.y,
candidate.h,
candidate.ur_point.x - curr.ur_point.x,
label="text",
confidence=candidate.confidence,
content=None
)
)
candidate.w = curr.p.x - candidate.p.x
res.append(candidate)
candidate = curr
curr = heapq.heappop(bboxes)
else:
assert curr.label == "text"
curr = heapq.heappop(bboxes)
else:
assert False
res.append(candidate)
res.append(curr)
######## debug #########
for idx, bbox in enumerate(res):
bbox.content = str(idx)
draw_bboxes(Image.fromarray(img), res, name="after_split_confict.png")
######## debug ###########
return res
def mix_inference(
img_path: str,
language: str,
infer_config,
latex_det_model,
lang_ocr_models,
latex_rec_models,
accelerator="cpu",
num_beams=1
) -> str:
'''
Input a mixed image of formula text and output str (in markdown syntax)
'''
global img
img = cv2.imread(img_path)
corners = [tuple(img[0, 0]), tuple(img[0, -1]),
tuple(img[-1, 0]), tuple(img[-1, -1])]
bg_color = np.array(Counter(corners).most_common(1)[0][0])
latex_bboxes = latex_det_predict(img_path, latex_det_model, infer_config)
latex_bboxes = sorted(latex_bboxes)
draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(unmerged).png")
latex_bboxes = bbox_merge(latex_bboxes)
draw_bboxes(Image.fromarray(img), latex_bboxes, name="latex_bboxes(merged).png")
masked_img = mask_img(img, latex_bboxes, bg_color)
det_model, det_processor, rec_model, rec_processor = lang_ocr_models
images = [Image.fromarray(masked_img)]
det_prediction = batch_text_detection(images, det_model, det_processor)[0]
draw_bboxes(Image.fromarray(img), latex_bboxes, name="ocr_bboxes(unmerged).png")
lang = [language]
slice_map = []
all_slices = []
all_langs = []
ocr_bboxes = [
Bbox(
p.bbox[0], p.bbox[1], p.bbox[3] - p.bbox[1], p.bbox[2] - p.bbox[0],
label="text",
confidence=p.confidence,
content=None
)
for p in det_prediction.bboxes
]
ocr_bboxes = sorted(ocr_bboxes)
ocr_bboxes = bbox_merge(ocr_bboxes)
draw_bboxes(Image.fromarray(img), ocr_bboxes, name="ocr_bboxes(merged).png")
ocr_bboxes = split_conflict(ocr_bboxes, latex_bboxes)
ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes))
polygons = [
[
[bbox.ul_point.x, bbox.ul_point.y],
[bbox.ur_point.x, bbox.ur_point.y],
[bbox.lr_point.x, bbox.lr_point.y],
[bbox.ll_point.x, bbox.ll_point.y]
]
for bbox in ocr_bboxes
]
slices = slice_polys_from_image(images[0], polygons)
slice_map.append(len(slices))
all_slices.extend(slices)
all_langs.extend([lang] * len(slices))
rec_predictions, _ = batch_recognition(all_slices, all_langs, rec_model, rec_processor)
assert len(rec_predictions) == len(ocr_bboxes)
for content, bbox in zip(rec_predictions, ocr_bboxes):
bbox.content = content
latex_imgs =[]
for bbox in latex_bboxes:
latex_imgs.append(img[bbox.p.y:bbox.p.y + bbox.h, bbox.p.x:bbox.p.x + bbox.w])
latex_rec_res = latex_rec_predict(*latex_rec_models, latex_imgs, accelerator, num_beams, max_tokens=200)
for bbox, content in zip(latex_bboxes, latex_rec_res):
bbox.content = to_katex(content)
if bbox.label == "embedding":
bbox.content = " $" + bbox.content + "$ "
elif bbox.label == "isolated":
bbox.content = '\n' + r"$$" + bbox.content + r"$$" + '\n'
bboxes = sorted(ocr_bboxes + latex_bboxes)
if bboxes == []:
return ""
md = ""
prev = Bbox(bboxes[0].p.x, bboxes[0].p.y, -1, -1, label="guard")
# prev = bboxes[0]
for curr in bboxes:
if not prev.same_row(curr):
md += "\n"
md += curr.content
if (
prev.label == "isolated"
and curr.label == "text"
and bool(re.fullmatch(r"\([1-9]\d*?\)", curr.content))
):
md += '\n'
prev = curr
return md
if __name__ == '__main__':
img_path = "/Users/Leehy/Code/TexTeller/test3.png"
# latex_det_model = InferenceSession("/Users/Leehy/Code/TexTeller/src/models/det_model/model/rtdetr_r50vd_6x_coco_trained_on_IBEM_en_papers.onnx")
latex_det_model = InferenceSession("/Users/Leehy/Code/TexTeller/src/models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
infer_config = PredictConfig("/Users/Leehy/Code/TexTeller/src/models/det_model/model/infer_cfg.yml")
det_processor, det_model = segformer.load_processor(), segformer.load_model()
rec_model, rec_processor = load_model(), load_processor()
lang_ocr_models = (det_model, det_processor, rec_model, rec_processor)
texteller = TexTeller.from_pretrained()
tokenizer = TexTeller.get_tokenizer()
latex_rec_models = (texteller, tokenizer)
res = mix_inference(img_path, "zh", infer_config, latex_det_model, lang_ocr_models, latex_rec_models)
print(res)
pause = 1

View File

@@ -2,7 +2,7 @@ import os
import argparse
import cv2 as cv
from pathlib import Path
from utils import to_katex
from models.ocr_model.utils.to_katex import to_katex
from models.ocr_model.utils.inference import inference as latex_inference
from models.ocr_model.model.TexTeller import TexTeller
@@ -46,7 +46,7 @@ if __name__ == '__main__':
if img is not None:
print(f'Inference for {filename}...')
res = latex_inference(latex_rec_model, tokenizer, [img], inf_mode=args.inference_mode, num_beams=args.num_beam)
res = latex_inference(latex_rec_model, tokenizer, [img], accelerator=args.inference_mode, num_beams=args.num_beam)
res = to_katex(res[0])
# Save the recognition result to a text file

View File

@@ -56,7 +56,7 @@ class TexTellerServer:
def predict(self, image_nparray) -> str:
return inference(
self.model, self.tokenizer, [image_nparray],
inf_mode=self.inf_mode, num_beams=self.num_beams
accelerator=self.inf_mode, num_beams=self.num_beams
)[0]

View File

@@ -1,7 +1,7 @@
#!/usr/bin/env bash
set -exu
export CHECKPOINT_DIR="/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-648000"
export CHECKPOINT_DIR="default"
export TOKENIZER_DIR="default"
streamlit run web.py

View File

@@ -1 +0,0 @@
from .to_katex import to_katex

View File

@@ -1,15 +0,0 @@
import numpy as np
import re
def to_katex(formula: str) -> str:
res = formula
res = re.sub(r'\\mbox\{([^}]*)\}', r'\1', res)
res = re.sub(r'boldmath\$(.*?)\$', r'bm{\1}', res)
res = re.sub(r'\\\[(.*?)\\\]', r'\1\\newline', res)
pattern = r'(\\(?:left|middle|right|big|Big|bigg|Bigg|bigl|Bigl|biggl|Biggl|bigm|Bigm|biggm|Biggm|bigr|Bigr|biggr|Biggr))\{([^}]*)\}'
replacement = r'\1\2'
res = re.sub(pattern, replacement, res)
if res.endswith(r'\newline'):
res = res[:-8]
return res

View File

@@ -7,10 +7,18 @@ import streamlit as st
from PIL import Image
from streamlit_paste_button import paste_image_button as pbutton
from models.ocr_model.utils.inference import inference
from models.ocr_model.model.TexTeller import TexTeller
from utils import to_katex
from onnxruntime import InferenceSession
from models.utils import mix_inference
from models.det_model.inference import PredictConfig
from models.ocr_model.model.TexTeller import TexTeller
from models.ocr_model.utils.inference import inference as latex_recognition
from models.ocr_model.utils.to_katex import to_katex
from surya.model.detection import segformer
from surya.model.recognition.model import load_model
from surya.model.recognition.processor import load_processor
st.set_page_config(
page_title="TexTeller",
@@ -42,13 +50,26 @@ fail_gif_html = '''
'''
@st.cache_resource
def get_model():
def get_texteller():
return TexTeller.from_pretrained(os.environ['CHECKPOINT_DIR'])
@st.cache_resource
def get_tokenizer():
return TexTeller.get_tokenizer(os.environ['TOKENIZER_DIR'])
@st.cache_resource
def get_det_models():
infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
return infer_config, latex_det_model
@st.cache_resource()
def get_ocr_models():
det_processor, det_model = segformer.load_processor(), segformer.load_model()
rec_model, rec_processor = load_model(), load_processor()
lang_ocr_models = [det_model, det_processor, rec_model, rec_processor]
return lang_ocr_models
def get_image_base64(img_file):
buffered = io.BytesIO()
img_file.seek(0)
@@ -62,9 +83,6 @@ def on_file_upload():
def change_side_bar():
st.session_state["CHANGE_SIDEBAR_FLAG"] = True
model = get_model()
tokenizer = get_tokenizer()
if "start" not in st.session_state:
st.session_state["start"] = 1
st.toast('Hooray!', icon='🎉')
@@ -75,31 +93,34 @@ if "UPLOADED_FILE_CHANGED" not in st.session_state:
if "CHANGE_SIDEBAR_FLAG" not in st.session_state:
st.session_state["CHANGE_SIDEBAR_FLAG"] = False
if "INF_MODE" not in st.session_state:
st.session_state["INF_MODE"] = "Formula only"
# ============================ begin sidebar =============================== #
with st.sidebar:
num_beams = 1
inf_mode = 'cpu'
st.markdown("# 🔨️ Config")
st.markdown("")
model_type = st.selectbox(
"Model type",
("TexTeller", "None"),
inf_mode = st.selectbox(
"Inference mode",
("Formula only", "Text formula mixed"),
on_change=change_side_bar
)
if model_type == "TexTeller":
num_beams = st.number_input(
'Number of beams',
min_value=1,
max_value=20,
step=1,
on_change=change_side_bar
)
inf_mode = st.radio(
"Inference mode",
num_beams = st.number_input(
'Number of beams',
min_value=1,
max_value=20,
step=1,
on_change=change_side_bar
)
accelerator = st.radio(
"Accelerator",
("cpu", "cuda", "mps"),
on_change=change_side_bar
)
@@ -107,9 +128,16 @@ with st.sidebar:
# ============================ end sidebar =============================== #
# ============================ begin pages =============================== #
texteller = get_texteller()
tokenizer = get_tokenizer()
latex_rec_models = [texteller, tokenizer]
if inf_mode == "Text formula mixed":
infer_config, latex_det_model = get_det_models()
lang_ocr_models = get_ocr_models()
st.markdown(html_string, unsafe_allow_html=True)
uploaded_file = st.file_uploader(
@@ -176,19 +204,26 @@ elif uploaded_file or paste_result.image_data is not None:
st.write("")
with st.spinner("Predicting..."):
uploaded_file.seek(0)
TexTeller_result = inference(
model,
tokenizer,
[png_file_path],
inf_mode=inf_mode,
num_beams=num_beams
)[0]
if inf_mode == "Formula only":
TexTeller_result = latex_recognition(
texteller,
tokenizer,
[png_file_path],
accelerator=accelerator,
num_beams=num_beams
)[0]
katex_res = to_katex(TexTeller_result)
else:
katex_res = mix_inference(png_file_path, "en", infer_config, latex_det_model, lang_ocr_models, latex_rec_models, accelerator, num_beams)
st.success('Completed!', icon="")
st.markdown(suc_gif_html, unsafe_allow_html=True)
katex_res = to_katex(TexTeller_result)
st.text_area(":blue[*** 𝑃r𝑒d𝑖c𝑡e𝑑 𝑓o𝑟m𝑢l𝑎 ***]", katex_res, height=150)
st.latex(katex_res)
if inf_mode == "Formula only":
st.latex(katex_res)
elif inf_mode == "Text formula mixed":
st.markdown(katex_res)
st.write("")
st.write("")