diff --git a/src/inference.py b/src/inference.py index c1a65d2..07a0cae 100644 --- a/src/inference.py +++ b/src/inference.py @@ -1,11 +1,11 @@ import os -import sys import argparse import cv2 as cv from pathlib import Path from onnxruntime import InferenceSession -from paddleocr import PaddleOCR +from models.thrid_party.paddleocr.infer import predict_det, predict_rec +from models.thrid_party.paddleocr.infer import utility from models.utils import mix_inference from models.ocr_model.utils.to_katex import to_katex @@ -41,19 +41,8 @@ if __name__ == '__main__': action='store_true', help='use mix mode' ) - parser.add_argument( - '-lang', - type=str, - default='None' - ) args = parser.parse_args() - if args.mix and args.lang == "None": - print("When -mix is set, -lang must be set (support: ['zh', 'en'])") - sys.exit(-1) - elif args.mix and args.lang not in ['zh', 'en']: - print(f"language support: ['zh', 'en'] (invalid: {args.lang})") - sys.exit(-1) # You can use your own checkpoint and tokenizer path. print('Loading model and tokenizer...') @@ -73,20 +62,24 @@ if __name__ == '__main__': latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx") use_gpu = args.inference_mode == 'cuda' - text_ocr_model = PaddleOCR( - use_angle_cls=False, lang='ch', use_gpu=use_gpu, - det_model_dir="./models/text_ocr_model/infer_models/ch_PP-OCRv4_det_server_infer", - rec_model_dir="./models/text_ocr_model/infer_models/ch_PP-OCRv4_rec_server_infer", - det_limit_type='max', - det_limit_side_len=1280, - use_dilation=True, - det_db_score_mode="slow", - ) # need to run only once to load model into memory + SIZE_LIMIT = 20 * 1024 * 1024 + det_model_dir = "./models/thrid_party/paddleocr/checkpoints/det/default_model.onnx" + rec_model_dir = "./models/thrid_party/paddleocr/checkpoints/rec/default_model.onnx" + # The CPU inference of the detection model will be faster than the GPU inference (in onnxruntime) + det_use_gpu = False + rec_use_gpu = use_gpu and not (os.path.getsize(rec_model_dir) < SIZE_LIMIT) - detector = text_ocr_model.text_detector - recognizer = text_ocr_model.text_recognizer + paddleocr_args = utility.parse_args() + paddleocr_args.use_onnx = True + paddleocr_args.det_model_dir = det_model_dir + paddleocr_args.rec_model_dir = rec_model_dir + + paddleocr_args.use_gpu = det_use_gpu + detector = predict_det.TextDetector(paddleocr_args) + paddleocr_args.use_gpu = rec_use_gpu + recognizer = predict_rec.TextRecognizer(paddleocr_args) lang_ocr_models = [detector, recognizer] latex_rec_models = [latex_rec_model, tokenizer] - res = mix_inference(img_path, args.lang , infer_config, latex_det_model, lang_ocr_models, latex_rec_models, args.inference_mode, args.num_beam) + res = mix_inference(img_path, infer_config, latex_det_model, lang_ocr_models, latex_rec_models, args.inference_mode, args.num_beam) print(res) diff --git a/src/models/utils/mix_inference.py b/src/models/utils/mix_inference.py index 398bf6c..f063197 100644 --- a/src/models/utils/mix_inference.py +++ b/src/models/utils/mix_inference.py @@ -1,14 +1,13 @@ import re import heapq import cv2 +import time import numpy as np from collections import Counter from typing import List from PIL import Image -from paddleocr.ppocr.utils.utility import alpha_to_color - from ..det_model.inference import predict as latex_det_predict from ..det_model.Bbox import Bbox, draw_bboxes @@ -64,7 +63,7 @@ def split_conflict(ocr_bboxes: List[Bbox], latex_bboxes: List[Bbox]) -> List[Bbo idx = 0 while (len(bboxes) > 0): idx += 1 - assert candidate.p.x < curr.p.x or not candidate.same_row(curr) + assert candidate.p.x <= curr.p.x or not candidate.same_row(curr) if candidate.ur_point.x <= curr.p.x or not candidate.same_row(curr): res.append(candidate) @@ -134,14 +133,8 @@ def slice_from_image(img: np.ndarray, ocr_bboxes: List[Bbox]) -> List[np.ndarray return sliced_imgs -def preprocess_image(_image): - _image = alpha_to_color(_image, (255, 255, 255)) - return _image - - def mix_inference( img_path: str, - language: str, infer_config, latex_det_model, @@ -156,7 +149,6 @@ def mix_inference( ''' global img img = cv2.imread(img_path) - img = alpha_to_color(img, (255, 255, 255)) corners = [tuple(img[0, 0]), tuple(img[0, -1]), tuple(img[-1, 0]), tuple(img[-1, -1])] bg_color = np.array(Counter(corners).most_common(1)[0][0]) @@ -172,9 +164,6 @@ def mix_inference( det_model, rec_model = lang_ocr_models det_prediction, _ = det_model(masked_img) - # log results - draw_bboxes(Image.fromarray(img), latex_bboxes, name="ocr_bboxes(unmerged).png") - ocr_bboxes = [ Bbox( p[0][0], p[0][1], p[3][1]-p[0][1], p[1][0]-p[0][0], @@ -184,8 +173,12 @@ def mix_inference( ) for p in det_prediction ] + # log results + draw_bboxes(Image.fromarray(img), ocr_bboxes, name="ocr_bboxes(unmerged).png") + ocr_bboxes = sorted(ocr_bboxes) ocr_bboxes = bbox_merge(ocr_bboxes) + # log results draw_bboxes(Image.fromarray(img), ocr_bboxes, name="ocr_bboxes(merged).png") ocr_bboxes = split_conflict(ocr_bboxes, latex_bboxes) ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes)) @@ -193,7 +186,6 @@ def mix_inference( sliced_imgs: List[np.ndarray] = slice_from_image(img, ocr_bboxes) rec_predictions, _ = rec_model(sliced_imgs) - assert len(rec_predictions) == len(ocr_bboxes) for content, bbox in zip(rec_predictions, ocr_bboxes): bbox.content = content[0] @@ -202,6 +194,7 @@ def mix_inference( for bbox in latex_bboxes: latex_imgs.append(img[bbox.p.y:bbox.p.y + bbox.h, bbox.p.x:bbox.p.x + bbox.w]) latex_rec_res = latex_rec_predict(*latex_rec_models, latex_imgs, accelerator, num_beams, max_tokens=200) + for bbox, content in zip(latex_bboxes, latex_rec_res): bbox.content = to_katex(content) if bbox.label == "embedding":