2024-02-11 08:06:50 +00:00
|
|
|
import os
|
2024-05-07 13:19:43 +08:00
|
|
|
import sys
|
2024-02-11 08:06:50 +00:00
|
|
|
import argparse
|
2024-04-05 07:25:06 +00:00
|
|
|
import cv2 as cv
|
2024-02-11 08:06:50 +00:00
|
|
|
|
|
|
|
|
from pathlib import Path
|
2024-04-21 00:05:14 +08:00
|
|
|
from onnxruntime import InferenceSession
|
2024-05-09 00:22:01 +08:00
|
|
|
from paddleocr import PaddleOCR
|
2024-04-21 00:05:14 +08:00
|
|
|
|
|
|
|
|
from models.utils import mix_inference
|
|
|
|
|
from models.ocr_model.utils.to_katex import to_katex
|
2024-04-05 07:25:06 +00:00
|
|
|
from models.ocr_model.utils.inference import inference as latex_inference
|
2024-04-21 00:05:14 +08:00
|
|
|
|
2024-02-11 08:06:50 +00:00
|
|
|
from models.ocr_model.model.TexTeller import TexTeller
|
2024-04-21 00:05:14 +08:00
|
|
|
from models.det_model.inference import PredictConfig
|
|
|
|
|
|
2024-02-11 08:06:50 +00:00
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: recognise the LaTeX formula in a single image,
    # or, with -mix, run the mixed text + formula pipeline on it.

    # Work from the directory that contains this script so that the relative
    # model paths used below ("./models/...") resolve.
    # NOTE(review): as a consequence a relative -img path is also interpreted
    # relative to this directory rather than the caller's cwd — confirm intended.
    os.chdir(Path(__file__).resolve().parent)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-img',
        type=str,
        required=True,
        help='path to the input image'
    )
    parser.add_argument(
        '--inference-mode',
        type=str,
        default='cpu',
        help='Inference mode, select one of cpu, cuda, or mps'
    )
    parser.add_argument(
        '--num-beam',
        type=int,
        default=1,
        help='number of beam search for decoding'
    )
    parser.add_argument(
        '-mix',
        action='store_true',
        help='use mix mode'
    )
    parser.add_argument(
        '-lang',
        type=str,
        default='None'
    )

    args = parser.parse_args()

    # -lang is only validated in mix mode; the string 'None' is the
    # "not provided" sentinel used by the argument's default.
    if args.mix and args.lang == "None":
        print("When -mix is set, -lang must be set (support: ['zh', 'en'])")
        sys.exit(-1)
    elif args.mix and args.lang not in ['zh', 'en']:
        print(f"language support: ['zh', 'en'] (invalid: {args.lang})")
        sys.exit(-1)

    # Read the image before loading any model: cv.imread returns None (it does
    # not raise) when the file is missing or cannot be decoded, so check it
    # here and fail fast with a clear message instead of passing None on to
    # the inference code.
    img_path = args.img
    img = cv.imread(img_path)
    if img is None:
        print(f"Failed to read image: {img_path}")
        sys.exit(-1)

    # You can use your own checkpoint and tokenizer path.
    print('Loading model and tokenizer...')
    latex_rec_model = TexTeller.from_pretrained()
    tokenizer = TexTeller.get_tokenizer()
    print('Model and tokenizer loaded.')

    print('Inference...')
    if not args.mix:
        # Formula-only mode: run the recogniser on the single image and
        # convert the first (only) result with to_katex before printing.
        res = latex_inference(latex_rec_model, tokenizer, [img], args.inference_mode, args.num_beam)
        res = to_katex(res[0])
        print(res)
    else:
        # Mix mode: formula detector + PaddleOCR text detector/recogniser +
        # the LaTeX recogniser, all handed to mix_inference.
        infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
        latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")

        # PaddleOCR is only asked to use the GPU for the 'cuda' inference mode.
        use_gpu = args.inference_mode == 'cuda'
        text_ocr_model = PaddleOCR(
            use_angle_cls=False, lang='ch', use_gpu=use_gpu,
            det_model_dir="./models/text_ocr_model/infer_models/ch_PP-OCRv4_det_server_infer",
            rec_model_dir="./models/text_ocr_model/infer_models/ch_PP-OCRv4_rec_server_infer",
            det_limit_type='max',
            det_limit_side_len=1280,
            use_dilation=True,
            det_db_score_mode="slow",
        )  # need to run only once to load model into memory

        detector = text_ocr_model.text_detector
        recognizer = text_ocr_model.text_recognizer

        lang_ocr_models = [detector, recognizer]
        latex_rec_models = [latex_rec_model, tokenizer]
        # mix_inference receives the image path, not the decoded image.
        res = mix_inference(img_path, args.lang, infer_config, latex_det_model, lang_ocr_models, latex_rec_models, args.inference_mode, args.num_beam)
        print(res)
|