Using paddleocr with onnxruntime
Deleted the code for test time.
This commit is contained in:
@@ -1,11 +1,11 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import argparse
|
import argparse
|
||||||
import cv2 as cv
|
import cv2 as cv
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from onnxruntime import InferenceSession
|
from onnxruntime import InferenceSession
|
||||||
from paddleocr import PaddleOCR
|
from models.thrid_party.paddleocr.infer import predict_det, predict_rec
|
||||||
|
from models.thrid_party.paddleocr.infer import utility
|
||||||
|
|
||||||
from models.utils import mix_inference
|
from models.utils import mix_inference
|
||||||
from models.ocr_model.utils.to_katex import to_katex
|
from models.ocr_model.utils.to_katex import to_katex
|
||||||
@@ -41,19 +41,8 @@ if __name__ == '__main__':
|
|||||||
action='store_true',
|
action='store_true',
|
||||||
help='use mix mode'
|
help='use mix mode'
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
'-lang',
|
|
||||||
type=str,
|
|
||||||
default='None'
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if args.mix and args.lang == "None":
|
|
||||||
print("When -mix is set, -lang must be set (support: ['zh', 'en'])")
|
|
||||||
sys.exit(-1)
|
|
||||||
elif args.mix and args.lang not in ['zh', 'en']:
|
|
||||||
print(f"language support: ['zh', 'en'] (invalid: {args.lang})")
|
|
||||||
sys.exit(-1)
|
|
||||||
|
|
||||||
# You can use your own checkpoint and tokenizer path.
|
# You can use your own checkpoint and tokenizer path.
|
||||||
print('Loading model and tokenizer...')
|
print('Loading model and tokenizer...')
|
||||||
@@ -73,20 +62,24 @@ if __name__ == '__main__':
|
|||||||
latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
|
latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
|
||||||
|
|
||||||
use_gpu = args.inference_mode == 'cuda'
|
use_gpu = args.inference_mode == 'cuda'
|
||||||
text_ocr_model = PaddleOCR(
|
SIZE_LIMIT = 20 * 1024 * 1024
|
||||||
use_angle_cls=False, lang='ch', use_gpu=use_gpu,
|
det_model_dir = "./models/thrid_party/paddleocr/checkpoints/det/default_model.onnx"
|
||||||
det_model_dir="./models/text_ocr_model/infer_models/ch_PP-OCRv4_det_server_infer",
|
rec_model_dir = "./models/thrid_party/paddleocr/checkpoints/rec/default_model.onnx"
|
||||||
rec_model_dir="./models/text_ocr_model/infer_models/ch_PP-OCRv4_rec_server_infer",
|
# The CPU inference of the detection model will be faster than the GPU inference (in onnxruntime)
|
||||||
det_limit_type='max',
|
det_use_gpu = False
|
||||||
det_limit_side_len=1280,
|
rec_use_gpu = use_gpu and not (os.path.getsize(rec_model_dir) < SIZE_LIMIT)
|
||||||
use_dilation=True,
|
|
||||||
det_db_score_mode="slow",
|
|
||||||
) # need to run only once to load model into memory
|
|
||||||
|
|
||||||
detector = text_ocr_model.text_detector
|
paddleocr_args = utility.parse_args()
|
||||||
recognizer = text_ocr_model.text_recognizer
|
paddleocr_args.use_onnx = True
|
||||||
|
paddleocr_args.det_model_dir = det_model_dir
|
||||||
|
paddleocr_args.rec_model_dir = rec_model_dir
|
||||||
|
|
||||||
|
paddleocr_args.use_gpu = det_use_gpu
|
||||||
|
detector = predict_det.TextDetector(paddleocr_args)
|
||||||
|
paddleocr_args.use_gpu = rec_use_gpu
|
||||||
|
recognizer = predict_rec.TextRecognizer(paddleocr_args)
|
||||||
|
|
||||||
lang_ocr_models = [detector, recognizer]
|
lang_ocr_models = [detector, recognizer]
|
||||||
latex_rec_models = [latex_rec_model, tokenizer]
|
latex_rec_models = [latex_rec_model, tokenizer]
|
||||||
res = mix_inference(img_path, args.lang , infer_config, latex_det_model, lang_ocr_models, latex_rec_models, args.inference_mode, args.num_beam)
|
res = mix_inference(img_path, infer_config, latex_det_model, lang_ocr_models, latex_rec_models, args.inference_mode, args.num_beam)
|
||||||
print(res)
|
print(res)
|
||||||
|
|||||||
@@ -1,14 +1,13 @@
|
|||||||
import re
|
import re
|
||||||
import heapq
|
import heapq
|
||||||
import cv2
|
import cv2
|
||||||
|
import time
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from typing import List
|
from typing import List
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from paddleocr.ppocr.utils.utility import alpha_to_color
|
|
||||||
|
|
||||||
from ..det_model.inference import predict as latex_det_predict
|
from ..det_model.inference import predict as latex_det_predict
|
||||||
from ..det_model.Bbox import Bbox, draw_bboxes
|
from ..det_model.Bbox import Bbox, draw_bboxes
|
||||||
|
|
||||||
@@ -64,7 +63,7 @@ def split_conflict(ocr_bboxes: List[Bbox], latex_bboxes: List[Bbox]) -> List[Bbo
|
|||||||
idx = 0
|
idx = 0
|
||||||
while (len(bboxes) > 0):
|
while (len(bboxes) > 0):
|
||||||
idx += 1
|
idx += 1
|
||||||
assert candidate.p.x < curr.p.x or not candidate.same_row(curr)
|
assert candidate.p.x <= curr.p.x or not candidate.same_row(curr)
|
||||||
|
|
||||||
if candidate.ur_point.x <= curr.p.x or not candidate.same_row(curr):
|
if candidate.ur_point.x <= curr.p.x or not candidate.same_row(curr):
|
||||||
res.append(candidate)
|
res.append(candidate)
|
||||||
@@ -134,14 +133,8 @@ def slice_from_image(img: np.ndarray, ocr_bboxes: List[Bbox]) -> List[np.ndarray
|
|||||||
return sliced_imgs
|
return sliced_imgs
|
||||||
|
|
||||||
|
|
||||||
def preprocess_image(_image):
|
|
||||||
_image = alpha_to_color(_image, (255, 255, 255))
|
|
||||||
return _image
|
|
||||||
|
|
||||||
|
|
||||||
def mix_inference(
|
def mix_inference(
|
||||||
img_path: str,
|
img_path: str,
|
||||||
language: str,
|
|
||||||
infer_config,
|
infer_config,
|
||||||
latex_det_model,
|
latex_det_model,
|
||||||
|
|
||||||
@@ -156,7 +149,6 @@ def mix_inference(
|
|||||||
'''
|
'''
|
||||||
global img
|
global img
|
||||||
img = cv2.imread(img_path)
|
img = cv2.imread(img_path)
|
||||||
img = alpha_to_color(img, (255, 255, 255))
|
|
||||||
corners = [tuple(img[0, 0]), tuple(img[0, -1]),
|
corners = [tuple(img[0, 0]), tuple(img[0, -1]),
|
||||||
tuple(img[-1, 0]), tuple(img[-1, -1])]
|
tuple(img[-1, 0]), tuple(img[-1, -1])]
|
||||||
bg_color = np.array(Counter(corners).most_common(1)[0][0])
|
bg_color = np.array(Counter(corners).most_common(1)[0][0])
|
||||||
@@ -172,9 +164,6 @@ def mix_inference(
|
|||||||
|
|
||||||
det_model, rec_model = lang_ocr_models
|
det_model, rec_model = lang_ocr_models
|
||||||
det_prediction, _ = det_model(masked_img)
|
det_prediction, _ = det_model(masked_img)
|
||||||
# log results
|
|
||||||
draw_bboxes(Image.fromarray(img), latex_bboxes, name="ocr_bboxes(unmerged).png")
|
|
||||||
|
|
||||||
ocr_bboxes = [
|
ocr_bboxes = [
|
||||||
Bbox(
|
Bbox(
|
||||||
p[0][0], p[0][1], p[3][1]-p[0][1], p[1][0]-p[0][0],
|
p[0][0], p[0][1], p[3][1]-p[0][1], p[1][0]-p[0][0],
|
||||||
@@ -184,8 +173,12 @@ def mix_inference(
|
|||||||
)
|
)
|
||||||
for p in det_prediction
|
for p in det_prediction
|
||||||
]
|
]
|
||||||
|
# log results
|
||||||
|
draw_bboxes(Image.fromarray(img), ocr_bboxes, name="ocr_bboxes(unmerged).png")
|
||||||
|
|
||||||
ocr_bboxes = sorted(ocr_bboxes)
|
ocr_bboxes = sorted(ocr_bboxes)
|
||||||
ocr_bboxes = bbox_merge(ocr_bboxes)
|
ocr_bboxes = bbox_merge(ocr_bboxes)
|
||||||
|
# log results
|
||||||
draw_bboxes(Image.fromarray(img), ocr_bboxes, name="ocr_bboxes(merged).png")
|
draw_bboxes(Image.fromarray(img), ocr_bboxes, name="ocr_bboxes(merged).png")
|
||||||
ocr_bboxes = split_conflict(ocr_bboxes, latex_bboxes)
|
ocr_bboxes = split_conflict(ocr_bboxes, latex_bboxes)
|
||||||
ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes))
|
ocr_bboxes = list(filter(lambda x: x.label == "text", ocr_bboxes))
|
||||||
@@ -193,7 +186,6 @@ def mix_inference(
|
|||||||
sliced_imgs: List[np.ndarray] = slice_from_image(img, ocr_bboxes)
|
sliced_imgs: List[np.ndarray] = slice_from_image(img, ocr_bboxes)
|
||||||
rec_predictions, _ = rec_model(sliced_imgs)
|
rec_predictions, _ = rec_model(sliced_imgs)
|
||||||
|
|
||||||
|
|
||||||
assert len(rec_predictions) == len(ocr_bboxes)
|
assert len(rec_predictions) == len(ocr_bboxes)
|
||||||
for content, bbox in zip(rec_predictions, ocr_bboxes):
|
for content, bbox in zip(rec_predictions, ocr_bboxes):
|
||||||
bbox.content = content[0]
|
bbox.content = content[0]
|
||||||
@@ -202,6 +194,7 @@ def mix_inference(
|
|||||||
for bbox in latex_bboxes:
|
for bbox in latex_bboxes:
|
||||||
latex_imgs.append(img[bbox.p.y:bbox.p.y + bbox.h, bbox.p.x:bbox.p.x + bbox.w])
|
latex_imgs.append(img[bbox.p.y:bbox.p.y + bbox.h, bbox.p.x:bbox.p.x + bbox.w])
|
||||||
latex_rec_res = latex_rec_predict(*latex_rec_models, latex_imgs, accelerator, num_beams, max_tokens=200)
|
latex_rec_res = latex_rec_predict(*latex_rec_models, latex_imgs, accelerator, num_beams, max_tokens=200)
|
||||||
|
|
||||||
for bbox, content in zip(latex_bboxes, latex_rec_res):
|
for bbox, content in zip(latex_bboxes, latex_rec_res):
|
||||||
bbox.content = to_katex(content)
|
bbox.content = to_katex(content)
|
||||||
if bbox.label == "embedding":
|
if bbox.label == "embedding":
|
||||||
|
|||||||
Reference in New Issue
Block a user