Update inference.py
Replace the text OCR model with paddleocr
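The text-OCR stage now runs on paddleocr's PP-OCRv4 detection and recognition models instead of surya. For a sense of the dependency being introduced, here is a minimal end-to-end sketch using the same configuration as the diff below; the image path is a placeholder, the local model directories are the ones the commit points at, and the result layout is paddleocr 2.x's as I understand it, so treat the details as assumptions rather than part of this change.

from paddleocr import PaddleOCR

# Same settings as the diff below; the PP-OCRv4 "server" checkpoints are expected
# under models/text_ocr_model/infer_models.
ocr = PaddleOCR(
    use_angle_cls=False, lang='ch', use_gpu=False,
    det_model_dir="./models/text_ocr_model/infer_models/ch_PP-OCRv4_det_server_infer",
    rec_model_dir="./models/text_ocr_model/infer_models/ch_PP-OCRv4_rec_server_infer",
)

# One-call detection + recognition (image path is a placeholder).
result = ocr.ocr("example.png", cls=False)
for box, (text, score) in result[0]:
    print(text, score)

The commit itself bypasses ocr() and pulls out the standalone detector and recognizer, as shown in the last hunk.
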
@@ -5,6 +5,7 @@ import cv2 as cv
 from pathlib import Path
 from onnxruntime import InferenceSession
+from paddleocr import PaddleOCR
 
 from models.utils import mix_inference
 from models.ocr_model.utils.to_katex import to_katex
@@ -13,10 +14,6 @@ from models.ocr_model.utils.inference import inference as latex_inference
 from models.ocr_model.model.TexTeller import TexTeller
 from models.det_model.inference import PredictConfig
 
-from surya.model.detection import segformer
-from surya.model.recognition.model import load_model
-from surya.model.recognition.processor import load_processor
-
 
 if __name__ == '__main__':
     os.chdir(Path(__file__).resolve().parent)
@@ -44,7 +41,6 @@ if __name__ == '__main__':
         action='store_true',
         help='use mix mode'
     )
-
    parser.add_argument(
        '-lang',
        type=str,
@@ -76,10 +72,21 @@ if __name__ == '__main__':
     infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
     latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
 
-    det_processor, det_model = segformer.load_processor(), segformer.load_model()
-    rec_model, rec_processor = load_model(), load_processor()
-    lang_ocr_models = [det_model, det_processor, rec_model, rec_processor]
+    use_gpu = args.inference_mode == 'cuda'
+    text_ocr_model = PaddleOCR(
+        use_angle_cls=False, lang='ch', use_gpu=use_gpu,
+        det_model_dir="./models/text_ocr_model/infer_models/ch_PP-OCRv4_det_server_infer",
+        rec_model_dir="./models/text_ocr_model/infer_models/ch_PP-OCRv4_rec_server_infer",
+        det_limit_type='max',
+        det_limit_side_len=1280,
+        use_dilation=True,
+        det_db_score_mode="slow",
+    ) # need to run only once to load model into memory
+
+    detector = text_ocr_model.text_detector
+    recognizer = text_ocr_model.text_recognizer
 
+    lang_ocr_models = [detector, recognizer]
     latex_rec_models = [latex_rec_model, tokenizer]
     res = mix_inference(img_path, args.lang , infer_config, latex_det_model, lang_ocr_models, latex_rec_models, args.inference_mode, args.num_beam)
     print(res)
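What mix_inference now receives in lang_ocr_models is the pair of standalone callables behind the PaddleOCR pipeline rather than surya's model and processor objects. The sketch below shows how such a detector/recognizer pair behaves when driven directly; it assumes the call conventions of paddleocr 2.x's TextDetector and TextRecognizer (boxes plus elapsed time, and (text, score) pairs plus elapsed time), and the axis-aligned cropping and image path are illustrative only, not a description of what mix_inference does.

import cv2 as cv
from paddleocr import PaddleOCR

# Hypothetical standalone use of the two callables the commit extracts.
text_ocr_model = PaddleOCR(use_angle_cls=False, lang='ch', use_gpu=False)
detector = text_ocr_model.text_detector      # standalone text detector
recognizer = text_ocr_model.text_recognizer  # standalone text recognizer

img = cv.imread("example.png")  # placeholder image path

# Detection: quadrilateral boxes (N x 4 x 2) plus elapsed time, assuming the
# TextDetector call convention of paddleocr 2.x.
boxes, _ = detector(img)

# Crop each detected region with a simple axis-aligned bounding box and
# recognize the crops in one batch.
crops = []
for box in boxes:
    x0, y0 = int(box[:, 0].min()), int(box[:, 1].min())
    x1, y1 = int(box[:, 0].max()), int(box[:, 1].max())
    crops.append(img[y0:y1, x0:x1])

if crops:
    # Recognition: a list of (text, confidence) pairs plus elapsed time.
    texts, _ = recognizer(crops)
    for text, score in texts:
        print(text, score)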