TexTeller/texteller/inference.py

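"""Command-line inference entry point for TexTeller.

By default, the script recognizes the input image as a single LaTeX formula and
prints a KaTeX-compatible result. With -mix, it additionally runs the formula
detection model and PaddleOCR text detection/recognition to handle images that
contain both text and formulas.
"""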
import os
import argparse
import cv2 as cv
from pathlib import Path
from onnxruntime import InferenceSession
from models.thrid_party.paddleocr.infer import predict_det, predict_rec
from models.thrid_party.paddleocr.infer import utility
from models.utils import mix_inference
from models.ocr_model.utils.to_katex import to_katex
from models.ocr_model.utils.inference import inference as latex_inference
from models.ocr_model.model.TexTeller import TexTeller
from models.det_model.inference import PredictConfig
if __name__ == '__main__':
    os.chdir(Path(__file__).resolve().parent)
    parser = argparse.ArgumentParser()
    parser.add_argument('-img', type=str, required=True, help='path to the input image')
    parser.add_argument(
        '--inference-mode',
        type=str,
        default='cpu',
        help='Inference mode, select one of cpu, cuda, or mps',
    )
    parser.add_argument(
        '--num-beam', type=int, default=1, help='number of beams for beam-search decoding'
    )
    parser.add_argument('-mix', action='store_true', help='use mix mode (text + formula recognition)')
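
    # Example invocations (image paths are illustrative):
    #   python inference.py -img formula.png
    #   python inference.py -img document_page.png -mix --inference-mode cuda --num-beam 3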
    args = parser.parse_args()
    # You can use your own checkpoint and tokenizer paths.
    print('Loading model and tokenizer...')
    latex_rec_model = TexTeller.from_pretrained()
    tokenizer = TexTeller.get_tokenizer()
    print('Model and tokenizer loaded.')

    img_path = args.img
    img = cv.imread(img_path)
    print('Inference...')
    if not args.mix:
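        # Formula-only mode: recognize the whole image as a single LaTeX formula.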
        res = latex_inference(latex_rec_model, tokenizer, [img], args.inference_mode, args.num_beam)
        res = to_katex(res[0])
        print(res)
    else:
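        # Mixed mode: detect formula regions with the ONNX detection model,
        # recognize the surrounding text with PaddleOCR, and merge the results
        # with mix_inference.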
        infer_config = PredictConfig("./models/det_model/model/infer_cfg.yml")
        latex_det_model = InferenceSession("./models/det_model/model/rtdetr_r50vd_6x_coco.onnx")

        use_gpu = args.inference_mode == 'cuda'
        SIZE_LIMIT = 20 * 1024 * 1024
        det_model_dir = "./models/thrid_party/paddleocr/checkpoints/det/default_model.onnx"
        rec_model_dir = "./models/thrid_party/paddleocr/checkpoints/rec/default_model.onnx"
        # In onnxruntime, CPU inference of the detection model is faster than GPU inference.
        det_use_gpu = False
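        # Use the GPU for the text recognizer only when the ONNX model is at
        # least SIZE_LIMIT bytes (smaller models are presumably faster on CPU).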
        rec_use_gpu = use_gpu and not (os.path.getsize(rec_model_dir) < SIZE_LIMIT)
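
        # Build the PaddleOCR text detector and recognizer from the ONNX checkpoints;
        # the two share the same parsed args except for use_gpu.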
        paddleocr_args = utility.parse_args()
        paddleocr_args.use_onnx = True
        paddleocr_args.det_model_dir = det_model_dir
        paddleocr_args.rec_model_dir = rec_model_dir

        paddleocr_args.use_gpu = det_use_gpu
        detector = predict_det.TextDetector(paddleocr_args)
        paddleocr_args.use_gpu = rec_use_gpu
        recognizer = predict_rec.TextRecognizer(paddleocr_args)

        lang_ocr_models = [detector, recognizer]
        latex_rec_models = [latex_rec_model, tokenizer]

        res = mix_inference(
            img_path,
            infer_config,
            latex_det_model,
            lang_ocr_models,
            latex_rec_models,
            args.inference_mode,
            args.num_beam,
        )
        print(res)