diff --git a/src/inference.py b/src/inference.py
index c6d6f61..99a2992 100644
--- a/src/inference.py
+++ b/src/inference.py
@@ -1,9 +1,11 @@
 import os
 import argparse
+import cv2 as cv
 from pathlib import Path
 
-from models.ocr_model.utils.inference import inference
+from models.ocr_model.utils.inference import inference as latex_inference
 from models.ocr_model.model.TexTeller import TexTeller
+from utils import load_det_tex_model, load_lang_models
 
 
 if __name__ == '__main__':
@@ -21,16 +23,31 @@ if __name__ == '__main__':
         action='store_true',
         help='use cuda or not'
     )
+    # ================= new feature ==================
+    parser.add_argument(
+        '-mix',
+        type=str,
+        help="use mixed text/formula mode; takes the document language, only Chinese ('ch') and English ('en') are supported."
+    )
+    # ==================================================
     args = parser.parse_args()
 
     # You can use your own checkpoint and tokenizer path.
     print('Loading model and tokenizer...')
-    model = TexTeller.from_pretrained()
+    latex_rec_model = TexTeller.from_pretrained()
     tokenizer = TexTeller.get_tokenizer()
     print('Model and tokenizer loaded.')
 
-    img_path = [args.img]
+    img = cv.imread(args.img)
+    img = cv.cvtColor(img, cv.COLOR_BGR2RGB)  # inference() expects RGB arrays, cv.imread returns BGR
     print('Inference...')
-    res = inference(model, tokenizer, img_path, args.cuda)
-    print(res[0])
+    if not args.mix:
+        res = latex_inference(latex_rec_model, tokenizer, [img], args.cuda)
+        print(res[0])
+    else:
+        # latex_det_model = load_det_tex_model()
+        # lang_model = load_lang_models()...
+        ...
+        # res: str = mix_inference(latex_det_model, latex_rec_model, lang_model, img, args.cuda)
+        # print(res)
 
diff --git a/src/models/ocr_model/utils/inference.py b/src/models/ocr_model/utils/inference.py
index cc34101..fcff742 100644
--- a/src/models/ocr_model/utils/inference.py
+++ b/src/models/ocr_model/utils/inference.py
@@ -13,16 +13,15 @@ from models.globals import MAX_TOKEN_SIZE
 def inference(
     model: TexTeller,
     tokenizer: RobertaTokenizerFast,
-    imgs_path: Union[List[str], List[np.ndarray]],
+    imgs: Union[List[str], List[np.ndarray]],
     use_cuda: bool,
     num_beams: int = 1,
 ) -> List[str]:
     model.eval()
-    if isinstance(imgs_path[0], str):
-        imgs = convert2rgb(imgs_path)
+    if isinstance(imgs[0], str):
+        imgs = convert2rgb(imgs)
     else: # already numpy array(rgb format)
-        assert isinstance(imgs_path[0], np.ndarray)
-        imgs = imgs_path
+        assert isinstance(imgs[0], np.ndarray)
     imgs = inference_transform(imgs)
     pixel_values = torch.stack(imgs)
 
diff --git a/src/utils.py b/src/utils.py
new file mode 100644
index 0000000..e6e4a9d
--- /dev/null
+++ b/src/utils.py
@@ -0,0 +1,19 @@
+import numpy as np
+
+from models.ocr_model.utils.inference import inference as latex_inference
+
+
+def load_lang_models(language: str):
+    # language: 'ch' or 'en'
+    # return det_model, rec_model (or model)
+    ...
+
+
+def load_det_tex_model():
+    # return the loaded latex detection model
+    ...
+
+
+def mix_inference(latex_det_model, latex_rec_model, lang_model, img: np.ndarray, use_cuda: bool) -> str:
+    # latex_inference(...)
+    ...
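
Note on the src/utils.py stubs: mix_inference is left as an ellipsis above. A minimal sketch of how the pieces could compose is given below. The detect() and recognize() methods and the (x, y, w, h) box format are assumptions made for illustration only; none of these interfaces are defined in this diff. Only TexTeller.get_tokenizer() and the latex_inference signature come from the code above.

from typing import List, Tuple

import numpy as np

from models.ocr_model.model.TexTeller import TexTeller
from models.ocr_model.utils.inference import inference as latex_inference


def mix_inference(latex_det_model, latex_rec_model, lang_model,
                  img: np.ndarray, use_cuda: bool) -> str:
    tokenizer = TexTeller.get_tokenizer()

    # 1. Find formula regions (assumed interface: a list of (x, y, w, h) boxes).
    formula_boxes: List[Tuple[int, int, int, int]] = latex_det_model.detect(img)

    # 2. Crop each formula and batch the crops through the LaTeX recognition model.
    crops = [img[y:y + h, x:x + w] for x, y, w, h in formula_boxes]
    formulas = latex_inference(latex_rec_model, tokenizer, crops, use_cuda) if crops else []

    # 3. OCR the ordinary text (assumed interface: (box, text) pairs covering
    #    the non-formula regions).
    text_lines = lang_model.recognize(img)

    # 4. Merge everything in rough reading order (top-to-bottom, then
    #    left-to-right) and wrap formulas in $...$ delimiters.
    pieces = [(y, x, f'${latex}$')
              for (x, y, _, _), latex in zip(formula_boxes, formulas)]
    pieces += [(y, x, text) for (x, y, _, _), text in text_lines]
    pieces.sort()
    return ' '.join(text for _, _, text in pieces)

Once the stubs are filled in, mixed mode would be invoked as something like "python inference.py -img doc.png -mix ch"; the image flag name is inferred from args.img and may differ from the actual definition.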