# TexTeller/src/inference.py — command-line inference entry point.
import argparse
import os
import sys
from pathlib import Path

import cv2 as cv
from onnxruntime import InferenceSession
from paddleocr import PaddleOCR

from models.det_model.inference import PredictConfig
from models.ocr_model.model.TexTeller import TexTeller
from models.ocr_model.utils.inference import inference as latex_inference
from models.ocr_model.utils.to_katex import to_katex
from models.utils import mix_inference

if __name__ == '__main__':
    # All model/config paths below are relative to this file, so anchor the
    # working directory here regardless of where the script was launched from.
    os.chdir(Path(__file__).resolve().parent)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-img',
        type=str,
        required=True,
        help='path to the input image'
    )
    parser.add_argument(
        '--inference-mode',
        type=str,
        default='cpu',
        help='Inference mode, select one of cpu, cuda, or mps'
    )
    parser.add_argument(
        '--num-beam',
        type=int,
        default=1,
        help='number of beam search for decoding'
    )
    parserser = None  # noqa: E501  (placeholder removed below)
    parser.add_argument(
        '-mix',
        action='store_true',
        help='use mix mode'
    )
    # NOTE: the string 'None' (not Python None) is the sentinel for "unset",
    # checked explicitly in the validation below.
    parser.add_argument(
        '-lang',
        type=str,
        default='None'
    )
    args = parser.parse_args()

    # -mix requires a language; only 'zh' and 'en' are supported.
    if args.mix and args.lang == "None":
        print("When -mix is set, -lang must be set (support: ['zh', 'en'])")
        sys.exit(-1)
    elif args.mix and args.lang not in ['zh', 'en']:
        print(f"language support: ['zh', 'en'] (invalid: {args.lang})")
        sys.exit(-1)

    # You can use your own checkpoint and tokenizer path.
    print('Loading model and tokenizer...')
    latex_rec_model = TexTeller.from_pretrained()
    tokenizer = TexTeller.get_tokenizer()
    print('Model and tokenizer loaded.')

    img_path = args.img
    img = cv.imread(img_path)

    print('Inference...')
    if not args.mix:
        # Pure formula recognition: image -> LaTeX -> normalized KaTeX.
        res = latex_inference(latex_rec_model, tokenizer, [img], args.inference_mode, args.num_beam)
        res = to_katex(res[0])
        print(res)
    else:
        # Mixed text + formula mode: a detection model (RT-DETR, ONNX) locates
        # formula regions, PaddleOCR handles surrounding text, and the LaTeX
        # recognizer transcribes the detected formulas.
        infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
        latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")
        use_gpu = args.inference_mode == 'cuda'
        text_ocr_model = PaddleOCR(
            use_angle_cls=False, lang='ch', use_gpu=use_gpu,
            det_model_dir="./models/text_ocr_model/infer_models/ch_PP-OCRv4_det_server_infer",
            rec_model_dir="./models/text_ocr_model/infer_models/ch_PP-OCRv4_rec_server_infer",
            det_limit_type='max',
            det_limit_side_len=1280,
            use_dilation=True,
            det_db_score_mode="slow",
        )  # need to run only once to load model into memory
        detector = text_ocr_model.text_detector
        recognizer = text_ocr_model.text_recognizer
        lang_ocr_models = [detector, recognizer]
        latex_rec_models = [latex_rec_model, tokenizer]
        res = mix_inference(img_path, args.lang, infer_config, latex_det_model, lang_ocr_models, latex_rec_models, args.inference_mode, args.num_beam)
        print(res)