From 7b2b947c47dc5f77449fdb39d995fbd22a4ea2aa Mon Sep 17 00:00:00 2001
From: TonyLee1256 <163754792+TonyLee1256@users.noreply.github.com>
Date: Tue, 7 May 2024 13:19:43 +0800
Subject: [PATCH] bugfix inference.py

---
 src/inference.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/src/inference.py b/src/inference.py
index b74a399..4022f6c 100644
--- a/src/inference.py
+++ b/src/inference.py
@@ -1,4 +1,5 @@
 import os
+import sys
 import argparse
 
 import cv2 as cv
@@ -44,8 +45,17 @@ if __name__ == '__main__':
         help='use mix mode'
     )
 
+    parser.add_argument(
+        '-lang',
+        type=str,
+        default='None'
+    )
+
     args = parser.parse_args()
-
+    if args.mix and args.lang == "None":
+        print("When -mix is set, -lang must be set (support: ['zh', 'en'])")
+        sys.exit(-1)
+
     # You can use your own checkpoint and tokenizer path.
     print('Loading model and tokenizer...')
     latex_rec_model = TexTeller.from_pretrained()
@@ -61,12 +71,12 @@ if __name__ == '__main__':
         print(res)
     else:
         infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
-        latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco_IBEM_cnTextBook.onnx")
+        latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
 
         det_processor, det_model = segformer.load_processor(), segformer.load_model()
         rec_model, rec_processor = load_model(), load_processor()
 
         lang_ocr_models = [det_model, det_processor, rec_model, rec_processor]
         latex_rec_models = [latex_rec_model, tokenizer]
-        res = mix_inference(img_path, infer_config, latex_det_model, lang_ocr_models, latex_rec_models, args.inference_mode, args.num_beam)
+        res = mix_inference(img_path, args.lang , infer_config, latex_det_model, lang_ocr_models, latex_rec_models, args.inference_mode, args.num_beam)
         print(res)