From e16f46e85679d5b364ad01985469ccf23ce34032 Mon Sep 17 00:00:00 2001
From: 三洋三洋 <1258009915@qq.com>
Date: Fri, 5 Apr 2024 07:25:06 +0000
Subject: [PATCH] Update the v3 inference.py template (adds support for
 natural-scene and mixed text-and-formula recognition)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/inference.py                        | 27 ++++++++++++++++++++-----
 src/models/ocr_model/utils/inference.py | 10 ++++-----
 src/utils.py                            | 19 +++++++++++++++++
 3 files changed, 46 insertions(+), 10 deletions(-)
 create mode 100644 src/utils.py

diff --git a/src/inference.py b/src/inference.py
index c6d6f61..99a2992 100644
--- a/src/inference.py
+++ b/src/inference.py
@@ -1,9 +1,11 @@
 import os
 import argparse
+import cv2 as cv
 from pathlib import Path
 
-from models.ocr_model.utils.inference import inference
+from models.ocr_model.utils.inference import inference as latex_inference
 from models.ocr_model.model.TexTeller import TexTeller
+from utils import load_det_tex_model, load_lang_models
 
 
 if __name__ == '__main__':
@@ -21,16 +23,31 @@ if __name__ == '__main__':
         action='store_true',
         help='use cuda or not'
     )
+    # ================= new feature ==================
+    parser.add_argument(
+        '-mix',
+        type=str,
+        help='use mix mode, only Chinese and English are supported.'
+    )
+    # ==================================================
     args = parser.parse_args()
 
     # You can use your own checkpoint and tokenizer path.
     print('Loading model and tokenizer...')
-    model = TexTeller.from_pretrained()
+    latex_rec_model = TexTeller.from_pretrained()
     tokenizer = TexTeller.get_tokenizer()
     print('Model and tokenizer loaded.')
 
-    img_path = [args.img]
+    # img_path = [args.img]
+    img = cv.imread(args.img)
     print('Inference...')
-    res = inference(model, tokenizer, img_path, args.cuda)
-    print(res[0])
+    if not args.mix:
+        res = latex_inference(latex_rec_model, tokenizer, [img], args.cuda)
+        print(res[0])
+    else:
+        # latex_det_model = load_det_tex_model()
+        # lang_model = load_lang_models()...
+        ...
+        # res: str = mix_inference(latex_det_model, latex_rec_model, lang_model, img, args.cuda)
+        # print(res)
 
diff --git a/src/models/ocr_model/utils/inference.py b/src/models/ocr_model/utils/inference.py
index cc34101..fcff742 100644
--- a/src/models/ocr_model/utils/inference.py
+++ b/src/models/ocr_model/utils/inference.py
@@ -13,16 +13,16 @@ from models.globals import MAX_TOKEN_SIZE
 
 def inference(
     model: TexTeller,
     tokenizer: RobertaTokenizerFast,
-    imgs_path: Union[List[str], List[np.ndarray]],
+    imgs: Union[List[str], List[np.ndarray]],
     use_cuda: bool,
     num_beams: int = 1,
 ) -> List[str]:
     model.eval()
-    if isinstance(imgs_path[0], str):
-        imgs = convert2rgb(imgs_path)
+    if isinstance(imgs[0], str):
+        imgs = convert2rgb(imgs)
     else:  # already numpy array(rgb format)
-        assert isinstance(imgs_path[0], np.ndarray)
-        imgs = imgs_path
+        assert isinstance(imgs[0], np.ndarray)
+        imgs = imgs
     imgs = inference_transform(imgs)
     pixel_values = torch.stack(imgs)
diff --git a/src/utils.py b/src/utils.py
new file mode 100644
index 0000000..e6e4a9d
--- /dev/null
+++ b/src/utils.py
@@ -0,0 +1,19 @@
+import numpy as np
+
+from models.ocr_model.utils.inference import inference as latex_inference
+
+
+def load_lang_models(language: str):
+    ...
+    # language: 'ch' or 'en'
+    # return det_model, rec_model (or model)
+
+
+def load_det_tex_model():
+    ...
+    # return the loaded latex detection model
+
+
+def mix_inference(latex_det_model, latex_rec_model, lang_model, img: np.ndarray, use_cuda: bool) -> str:
+    ...
+    # latex_inference(...)
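
Note on the stubs in src/utils.py: the patch deliberately leaves load_lang_models,
load_det_tex_model, and mix_inference as `...` placeholders, and the mix branch in
src/inference.py stays commented out until they are implemented. The sketch below
is one possible shape for mix_inference, not part of the patch: the detector's
.predict() method returning (x, y, w, h) boxes and the language model's
.recognize() method are assumed interfaces, while the tokenizer is fetched with
TexTeller.get_tokenizer() exactly as src/inference.py already does. It also
converts BGR to RGB first, since cv.imread() yields BGR but the recognizer's
inference() expects RGB arrays.

    import cv2 as cv
    import numpy as np

    from models.ocr_model.model.TexTeller import TexTeller
    from models.ocr_model.utils.inference import inference as latex_inference


    def mix_inference(latex_det_model, latex_rec_model, lang_model,
                      img: np.ndarray, use_cuda: bool) -> str:
        # cv.imread() returns BGR; the LaTeX recognizer expects RGB arrays.
        rgb = cv.cvtColor(img, cv.COLOR_BGR2RGB)

        # Assumed interface: the detector returns formula boxes as (x, y, w, h).
        boxes = latex_det_model.predict(rgb)

        # Recognize all formula crops in one batch with the existing recognizer.
        tokenizer = TexTeller.get_tokenizer()
        crops = [rgb[y:y + h, x:x + w] for (x, y, w, h) in boxes]
        formulas = latex_inference(latex_rec_model, tokenizer, crops, use_cuda) if crops else []

        # White out the formula regions so the text recognizer sees only prose.
        masked = rgb.copy()
        for (x, y, w, h) in boxes:
            masked[y:y + h, x:x + w] = 255

        # Assumed interface: the language model returns the remaining plain text.
        text = lang_model.recognize(masked)

        # Naive merge: plain text first, then each formula wrapped in $...$.
        # A real implementation would interleave results by reading order.
        return text + ''.join(f' ${f}$' for f in formulas)

load_lang_models('ch') / load_lang_models('en') would then return whatever object
provides that .recognize() interface (for example a wrapper around an off-the-shelf
text OCR toolkit); the patch leaves the choice of toolkit open.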