Update inference.py
替换文本OCR模型为paddleocr (Replace the text OCR model with PaddleOCR)
This commit is contained in:
@@ -5,6 +5,7 @@ import cv2 as cv
|
||||
|
||||
from pathlib import Path
|
||||
from onnxruntime import InferenceSession
|
||||
from paddleocr import PaddleOCR
|
||||
|
||||
from models.utils import mix_inference
|
||||
from models.ocr_model.utils.to_katex import to_katex
|
||||
@@ -13,10 +14,6 @@ from models.ocr_model.utils.inference import inference as latex_inference
|
||||
from models.ocr_model.model.TexTeller import TexTeller
|
||||
from models.det_model.inference import PredictConfig
|
||||
|
||||
from surya.model.detection import segformer
|
||||
from surya.model.recognition.model import load_model
|
||||
from surya.model.recognition.processor import load_processor
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
os.chdir(Path(__file__).resolve().parent)
|
||||
@@ -44,7 +41,6 @@ if __name__ == '__main__':
|
||||
action='store_true',
|
||||
help='use mix mode'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-lang',
|
||||
type=str,
|
||||
@@ -76,10 +72,21 @@ if __name__ == '__main__':
|
||||
infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
|
||||
latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
|
||||
|
||||
det_processor, det_model = segformer.load_processor(), segformer.load_model()
|
||||
rec_model, rec_processor = load_model(), load_processor()
|
||||
lang_ocr_models = [det_model, det_processor, rec_model, rec_processor]
|
||||
use_gpu = args.inference_mode == 'cuda'
|
||||
text_ocr_model = PaddleOCR(
|
||||
use_angle_cls=False, lang='ch', use_gpu=use_gpu,
|
||||
det_model_dir="./models/text_ocr_model/infer_models/ch_PP-OCRv4_det_server_infer",
|
||||
rec_model_dir="./models/text_ocr_model/infer_models/ch_PP-OCRv4_rec_server_infer",
|
||||
det_limit_type='max',
|
||||
det_limit_side_len=1280,
|
||||
use_dilation=True,
|
||||
det_db_score_mode="slow",
|
||||
) # need to run only once to load model into memory
|
||||
|
||||
detector = text_ocr_model.text_detector
|
||||
recognizer = text_ocr_model.text_recognizer
|
||||
|
||||
lang_ocr_models = [detector, recognizer]
|
||||
latex_rec_models = [latex_rec_model, tokenizer]
|
||||
res = mix_inference(img_path, args.lang , infer_config, latex_det_model, lang_ocr_models, latex_rec_models, args.inference_mode, args.num_beam)
|
||||
print(res)
|
||||
|
||||
Reference in New Issue
Block a user