import os
import argparse
import cv2 as cv

from pathlib import Path
from onnxruntime import InferenceSession
from models.thrid_party.paddleocr.infer import predict_det, predict_rec
from models.thrid_party.paddleocr.infer import utility

from models.utils import mix_inference
from models.ocr_model.utils.to_katex import to_katex
from models.ocr_model.utils.inference import inference as latex_inference

from models.ocr_model.model.TexTeller import TexTeller
from models.det_model.inference import PredictConfig

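# Usage (assuming this file is named inference.py):
#   python inference.py -img path/to/image.png [--inference-mode cuda] [--num-beam 3] [-mix]
# Recognizes the LaTeX in an image, either with the TexTeller recognizer alone
# or, in -mix mode, combined with a formula-detection model and PaddleOCR text
# detection/recognition for images that mix plain text with formulas.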
if __name__ == '__main__':
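    # All model paths below are relative to this file's directory.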
    os.chdir(Path(__file__).resolve().parent)
    parser = argparse.ArgumentParser()
    parser.add_argument('-img', type=str, required=True, help='path to the input image')
    parser.add_argument(
        '--inference-mode',
        type=str,
        default='cpu',
        help='inference mode: one of cpu, cuda, or mps',
    )
    parser.add_argument(
        '--num-beam', type=int, default=1, help='number of beams for beam-search decoding'
    )
    parser.add_argument('-mix', action='store_true', help='run in mixed text+formula mode')

    args = parser.parse_args()

    # You can use your own checkpoint and tokenizer path.
    print('Loading model and tokenizer...')
    latex_rec_model = TexTeller.from_pretrained()
    tokenizer = TexTeller.get_tokenizer()
    print('Model and tokenizer loaded.')

    img_path = args.img
    img = cv.imread(img_path)
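    # cv.imread returns None (rather than raising) when the file is missing or
    # unreadable, so fail fast with a clear error.
    if img is None:
        raise FileNotFoundError(f'Cannot read image: {img_path}')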
    print('Inference...')
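    # Formula-only mode: treat the whole image as a single formula.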
    if not args.mix:
        res = latex_inference(latex_rec_model, tokenizer, [img], args.inference_mode, args.num_beam)
        res = to_katex(res[0])
        print(res)
    else:
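        # Mixed mode: detect formula regions with an RT-DETR ONNX model, then
        # OCR the surrounding text with PaddleOCR and recognize each detected
        # formula with TexTeller.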
        infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
        latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")

        use_gpu = args.inference_mode == 'cuda'
        SIZE_LIMIT = 20 * 1024 * 1024
        det_model_dir = "./models/thrid_party/paddleocr/checkpoints/det/default_model.onnx"
        rec_model_dir = "./models/thrid_party/paddleocr/checkpoints/rec/default_model.onnx"
        # In onnxruntime, the detection model runs faster on CPU than on GPU.
        det_use_gpu = False
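        # Run the recognizer on GPU only when requested and its ONNX file is
        # at least 20 MiB (SIZE_LIMIT).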
        rec_use_gpu = use_gpu and os.path.getsize(rec_model_dir) >= SIZE_LIMIT

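        # Route PaddleOCR's detector and recognizer through onnxruntime.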
        paddleocr_args = utility.parse_args()
        paddleocr_args.use_onnx = True
        paddleocr_args.det_model_dir = det_model_dir
        paddleocr_args.rec_model_dir = rec_model_dir

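        # The args object is read at construction time, so the detector and
        # recognizer can be built with different use_gpu settings.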
        paddleocr_args.use_gpu = det_use_gpu
        detector = predict_det.TextDetector(paddleocr_args)
        paddleocr_args.use_gpu = rec_use_gpu
        recognizer = predict_rec.TextRecognizer(paddleocr_args)

        lang_ocr_models = [detector, recognizer]
        latex_rec_models = [latex_rec_model, tokenizer]
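        # Run detection, text OCR, and formula recognition, merged into a
        # single result string.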
        res = mix_inference(
            img_path,
            infer_config,
            latex_det_model,
            lang_ocr_models,
            latex_rec_models,
            args.inference_mode,
            args.num_beam,
        )
        print(res)