Update inference.py

替换文本OCR模型为paddleocr
This commit is contained in:
TonyLee1256
2024-05-09 00:22:01 +08:00
committed by GitHub
parent fe7e4a7af0
commit bd2aaa3e00

View File

@@ -5,6 +5,7 @@ import cv2 as cv
from pathlib import Path
from onnxruntime import InferenceSession
from paddleocr import PaddleOCR
from models.utils import mix_inference
from models.ocr_model.utils.to_katex import to_katex
@@ -13,10 +14,6 @@ from models.ocr_model.utils.inference import inference as latex_inference
from models.ocr_model.model.TexTeller import TexTeller
from models.det_model.inference import PredictConfig
from surya.model.detection import segformer
from surya.model.recognition.model import load_model
from surya.model.recognition.processor import load_processor
if __name__ == '__main__':
os.chdir(Path(__file__).resolve().parent)
@@ -44,7 +41,6 @@ if __name__ == '__main__':
action='store_true',
help='use mix mode'
)
parser.add_argument(
'-lang',
type=str,
@@ -76,10 +72,21 @@ if __name__ == '__main__':
infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
det_processor, det_model = segformer.load_processor(), segformer.load_model()
rec_model, rec_processor = load_model(), load_processor()
lang_ocr_models = [det_model, det_processor, rec_model, rec_processor]
use_gpu = args.inference_mode == 'cuda'
text_ocr_model = PaddleOCR(
use_angle_cls=False, lang='ch', use_gpu=use_gpu,
det_model_dir="./models/text_ocr_model/infer_models/ch_PP-OCRv4_det_server_infer",
rec_model_dir="./models/text_ocr_model/infer_models/ch_PP-OCRv4_rec_server_infer",
det_limit_type='max',
det_limit_side_len=1280,
use_dilation=True,
det_db_score_mode="slow",
) # need to run only once to load model into memory
detector = text_ocr_model.text_detector
recognizer = text_ocr_model.text_recognizer
lang_ocr_models = [detector, recognizer]
latex_rec_models = [latex_rec_model, tokenizer]
res = mix_inference(img_path, args.lang , infer_config, latex_det_model, lang_ocr_models, latex_rec_models, args.inference_mode, args.num_beam)
print(res)