修改了v3(支持自然场景、混合文字场景识别)版本的inference.py模版

This commit is contained in:
三洋三洋
2024-04-05 07:25:06 +00:00
parent 5b730329b4
commit 34ac31504a
3 changed files with 46 additions and 10 deletions

View File

@@ -1,9 +1,11 @@
import os
import argparse
import cv2 as cv
from pathlib import Path
from models.ocr_model.utils.inference import inference
from models.ocr_model.utils.inference import inference as latex_inference
from models.ocr_model.model.TexTeller import TexTeller
from utils import load_det_tex_model, load_lang_models
if __name__ == '__main__':
@@ -21,16 +23,31 @@ if __name__ == '__main__':
action='store_true',
help='use cuda or not'
)
# ================= new feature ==================
parser.add_argument(
'-mix',
type=str,
help='use mix mode, only Chinese and English are supported.'
)
# ==================================================
args = parser.parse_args()
# You can use your own checkpoint and tokenizer path.
print('Loading model and tokenizer...')
model = TexTeller.from_pretrained()
latex_rec_model = TexTeller.from_pretrained()
tokenizer = TexTeller.get_tokenizer()
print('Model and tokenizer loaded.')
img_path = [args.img]
# img_path = [args.img]
img = cv.imread(args.img)
print('Inference...')
res = inference(model, tokenizer, img_path, args.cuda)
print(res[0])
if not args.mix:
res = latex_inference(latex_rec_model, tokenizer, [img], args.cuda)
print(res[0])
else:
# latex_det_model = load_det_tex_model()
# lang_model = load_lang_models()...
...
# res: str = mix_inference(latex_det_model, latex_rec_model, lang_model, img, args.cuda)
# print(res)

View File

@@ -13,16 +13,16 @@ from models.globals import MAX_TOKEN_SIZE
def inference(
model: TexTeller,
tokenizer: RobertaTokenizerFast,
imgs_path: Union[List[str], List[np.ndarray]],
imgs: Union[List[str], List[np.ndarray]],
use_cuda: bool,
num_beams: int = 1,
) -> List[str]:
model.eval()
if isinstance(imgs_path[0], str):
imgs = convert2rgb(imgs_path)
if isinstance(imgs[0], str):
imgs = convert2rgb(imgs)
else: # already numpy array(rgb format)
assert isinstance(imgs_path[0], np.ndarray)
imgs = imgs_path
assert isinstance(imgs[0], np.ndarray)
imgs = imgs
imgs = inference_transform(imgs)
pixel_values = torch.stack(imgs)

19
src/utils.py Normal file
View File

@@ -0,0 +1,19 @@
import numpy as np
from models.ocr_model.utils.inference import inference as latex_inference
def load_lang_models(language: str):
    """Load the text model(s) for the given language (unimplemented template).

    This is a placeholder to be filled in by the user. Until then it raises
    instead of silently returning ``None``, so a caller that forgets to
    implement it fails here rather than later with an unrelated error.

    Args:
        language: Language selector; per the template notes, ``'ch'`` or
            ``'en'``.

    Returns:
        Intended (per the template notes): ``det_model, rec_model``, or a
        single model. Nothing is returned by this placeholder.

    Raises:
        NotImplementedError: Always, until an implementation is provided.
    """
    raise NotImplementedError(
        f'load_lang_models is a template placeholder; implement model loading '
        f'for language={language!r}.'
    )
def load_det_tex_model():
    """Load the LaTeX detection model (unimplemented template).

    This is a placeholder to be filled in by the user. It raises instead of
    silently returning ``None`` so that an unimplemented loader is reported
    at the call site.

    Returns:
        Intended (per the template notes): the loaded LaTeX detection model.
        Nothing is returned by this placeholder.

    Raises:
        NotImplementedError: Always, until an implementation is provided.
    """
    raise NotImplementedError(
        'load_det_tex_model is a template placeholder; implement loading of '
        'the latex detection model.'
    )
def mix_inference(latex_det_model, latex_rec_model, lang_model, img: np.ndarray, use_cuda: bool) -> str:
    """Recognize an image containing mixed text and formulas (unimplemented template).

    This is a placeholder to be filled in by the user. The signature promises
    a ``str``; rather than break that promise by returning ``None``, the
    placeholder raises until it is implemented.

    Args:
        latex_det_model: LaTeX detection model.
        latex_rec_model: LaTeX recognition model.
        lang_model: Text model(s) for the chosen language.
        img: Input image as a numpy array.
        use_cuda: Whether to run on CUDA.

    Returns:
        Intended: the recognized content as a single string. Nothing is
        returned by this placeholder.

    Raises:
        NotImplementedError: Always, until an implementation is provided.
    """
    # NOTE(review): the template hints that the implementation is expected to
    # call latex_inference(...) for the formula regions -- confirm when
    # implementing.
    raise NotImplementedError(
        'mix_inference is a template placeholder; implement mixed '
        'text/formula recognition.'
    )