Using paddleocr with onnxruntime

Deleted the test-timing code.
This commit is contained in:
三洋三洋
2024-05-27 17:05:24 +00:00
parent 85d558f772
commit 9b11689f22
2 changed files with 25 additions and 39 deletions

View File

@@ -1,11 +1,11 @@
import os import os
import sys
import argparse import argparse
import cv2 as cv import cv2 as cv
from pathlib import Path from pathlib import Path
from onnxruntime import InferenceSession from onnxruntime import InferenceSession
from paddleocr import PaddleOCR from models.thrid_party.paddleocr.infer import predict_det, predict_rec
from models.thrid_party.paddleocr.infer import utility
from models.utils import mix_inference from models.utils import mix_inference
from models.ocr_model.utils.to_katex import to_katex from models.ocr_model.utils.to_katex import to_katex
@@ -41,19 +41,8 @@ if __name__ == '__main__':
action='store_true', action='store_true',
help='use mix mode' help='use mix mode'
) )
parser.add_argument(
'-lang',
type=str,
default='None'
)
args = parser.parse_args() args = parser.parse_args()
if args.mix and args.lang == "None":
print("When -mix is set, -lang must be set (support: ['zh', 'en'])")
sys.exit(-1)
elif args.mix and args.lang not in ['zh', 'en']:
print(f"language support: ['zh', 'en'] (invalid: {args.lang})")
sys.exit(-1)
# You can use your own checkpoint and tokenizer path. # You can use your own checkpoint and tokenizer path.
print('Loading model and tokenizer...') print('Loading model and tokenizer...')
@@ -73,20 +62,24 @@ if __name__ == '__main__':
latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx") latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
use_gpu = args.inference_mode == 'cuda' use_gpu = args.inference_mode == 'cuda'
text_ocr_model = PaddleOCR( SIZE_LIMIT = 20 * 1024 * 1024
use_angle_cls=False, lang='ch', use_gpu=use_gpu, det_model_dir = "./models/thrid_party/paddleocr/checkpoints/det/default_model.onnx"
det_model_dir="./models/text_ocr_model/infer_models/ch_PP-OCRv4_det_server_infer", rec_model_dir = "./models/thrid_party/paddleocr/checkpoints/rec/default_model.onnx"
rec_model_dir="./models/text_ocr_model/infer_models/ch_PP-OCRv4_rec_server_infer", # The CPU inference of the detection model will be faster than the GPU inference (in onnxruntime)
det_limit_type='max', det_use_gpu = False
det_limit_side_len=1280, rec_use_gpu = use_gpu and not (os.path.getsize(rec_model_dir) < SIZE_LIMIT)
use_dilation=True,
det_db_score_mode="slow",
) # need to run only once to load model into memory
detector = text_ocr_model.text_detector paddleocr_args = utility.parse_args()
recognizer = text_ocr_model.text_recognizer paddleocr_args.use_onnx = True
paddleocr_args.det_model_dir = det_model_dir
paddleocr_args.rec_model_dir = rec_model_dir
paddleocr_args.use_gpu = det_use_gpu
detector = predict_det.TextDetector(paddleocr_args)
paddleocr_args.use_gpu = rec_use_gpu
recognizer = predict_rec.TextRecognizer(paddleocr_args)
lang_ocr_models = [detector, recognizer] lang_ocr_models = [detector, recognizer]
latex_rec_models = [latex_rec_model, tokenizer] latex_rec_models = [latex_rec_model, tokenizer]
res = mix_inference(img_path, args.lang , infer_config, latex_det_model, lang_ocr_models, latex_rec_models, args.inference_mode, args.num_beam) res = mix_inference(img_path, infer_config, latex_det_model, lang_ocr_models, latex_rec_models, args.inference_mode, args.num_beam)
print(res) print(res)

View File

@@ -1,14 +1,13 @@
import re import re
import heapq import heapq
import cv2 import cv2
import time
import numpy as np import numpy as np
from collections import Counter from collections import Counter
from typing import List from typing import List
from PIL import Image from PIL import Image
from paddleocr.ppocr.utils.utility import alpha_to_color
from ..det_model.inference import predict as latex_det_predict from ..det_model.inference import predict as latex_det_predict
from ..det_model.Bbox import Bbox, draw_bboxes from ..det_model.Bbox import Bbox, draw_bboxes
@@ -64,7 +63,7 @@ def split_conflict(ocr_bboxes: List[Bbox], latex_bboxes: List[Bbox]) -> List[Bbo
idx = 0 idx = 0
while (len(bboxes) > 0): while (len(bboxes) > 0):
idx += 1 idx += 1
assert candidate.p.x < curr.p.x or not candidate.same_row(curr) assert candidate.p.x <= curr.p.x or not candidate.same_row(curr)
if candidate.ur_point.x <= curr.p.x or not candidate.same_row(curr): if candidate.ur_point.x <= curr.p.x or not candidate.same_row(curr):
res.append(candidate) res.append(candidate)
@@ -134,14 +133,8 @@ def slice_from_image(img: np.ndarray, ocr_bboxes: List[Bbox]) -> List[np.ndarray
return sliced_imgs return sliced_imgs
def preprocess_image(_image):
_image = alpha_to_color(_image, (255, 255, 255))
return _image
def mix_inference( def mix_inference(
img_path: str, img_path: str,
language: str,
infer_config, infer_config,
latex_det_model, latex_det_model,
@@ -156,7 +149,6 @@ def mix_inference(
''' '''
global img global img
img = cv2.imread(img_path) img = cv2.imread(img_path)
img = alpha_to_color(img, (255, 255, 255))
corners = [tuple(img[0, 0]), tuple(img[0, -1]), corners = [tuple(img[0, 0]), tuple(img[0, -1]),
tuple(img[-1, 0]), tuple(img[-1, -1])] tuple(img[-1, 0]), tuple(img[-1, -1])]
bg_color = np.array(Counter(corners).most_common(1)[0][0]) bg_color = np.array(Counter(corners).most_common(1)[0][0])
@@ -172,9 +164,6 @@ def mix_inference(
det_model, rec_model = lang_ocr_models det_model, rec_model = lang_ocr_models
det_prediction, _ = det_model(masked_img) det_prediction, _ = det_model(masked_img)
# log results
draw_bboxes(Image.fromarray(img), latex_bboxes, name="ocr_bboxes(unmerged).png")
ocr_bboxes = [ ocr_bboxes = [
Bbox( Bbox(
p[0][0], p[0][1], p[3][1]-p[0][1], p[1][0]-p[0][0], p[0][0], p[0][1], p[3][1]-p[0][1], p[1][0]-p[0][0],
@@ -184,8 +173,12 @@ def mix_inference(
) )
for p in det_prediction for p in det_prediction
] ]
# log results
draw_bboxes(Image.fromarray(img), ocr_bboxes, name="ocr_bboxes(unmerged).png")
ocr_bboxes = sorted(ocr_bboxes) ocr_bboxes = sorted(ocr_bboxes)
ocr_bboxes = bbox_merge(ocr_bboxes) ocr_bboxes = bbox_merge(ocr_bboxes)
# log results
draw_bboxes(Image.fromarray(img), ocr_bboxes, name="ocr_bboxes(merged).png") draw_bboxes(Image.fromarray(img), ocr_bboxes, name="ocr_bboxes(merged).png")
ocr_bboxes = split_conflict(ocr_bboxes, latex_bboxes) ocr_bboxes = split_conflict(ocr_bboxes, latex_bboxes)
ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes)) ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes))
@@ -193,7 +186,6 @@ def mix_inference(
sliced_imgs: List[np.ndarray] = slice_from_image(img, ocr_bboxes) sliced_imgs: List[np.ndarray] = slice_from_image(img, ocr_bboxes)
rec_predictions, _ = rec_model(sliced_imgs) rec_predictions, _ = rec_model(sliced_imgs)
assert len(rec_predictions) == len(ocr_bboxes) assert len(rec_predictions) == len(ocr_bboxes)
for content, bbox in zip(rec_predictions, ocr_bboxes): for content, bbox in zip(rec_predictions, ocr_bboxes):
bbox.content = content[0] bbox.content = content[0]
@@ -202,6 +194,7 @@ def mix_inference(
for bbox in latex_bboxes: for bbox in latex_bboxes:
latex_imgs.append(img[bbox.p.y:bbox.p.y + bbox.h, bbox.p.x:bbox.p.x + bbox.w]) latex_imgs.append(img[bbox.p.y:bbox.p.y + bbox.h, bbox.p.x:bbox.p.x + bbox.w])
latex_rec_res = latex_rec_predict(*latex_rec_models, latex_imgs, accelerator, num_beams, max_tokens=200) latex_rec_res = latex_rec_predict(*latex_rec_models, latex_imgs, accelerator, num_beams, max_tokens=200)
for bbox, content in zip(latex_bboxes, latex_rec_res): for bbox, content in zip(latex_bboxes, latex_rec_res):
bbox.content = to_katex(content) bbox.content = to_katex(content)
if bbox.label == "embedding": if bbox.label == "embedding":