Updated the v3 inference.py template (adds support for natural-scene and mixed-text scene recognition)
src/inference.py

@@ -1,9 +1,11 @@
 import os
 import argparse
 
+import cv2 as cv
 from pathlib import Path
-from models.ocr_model.utils.inference import inference
+from models.ocr_model.utils.inference import inference as latex_inference
 from models.ocr_model.model.TexTeller import TexTeller
 
+from utils import load_det_tex_model, load_lang_models
 
 if __name__ == '__main__':
@@ -21,16 +23,31 @@ if __name__ == '__main__':
         action='store_true',
         help='use cuda or not'
     )
+    # ================= new feature ==================
+    parser.add_argument(
+        '-mix',
+        type=str,
+        help='use mix mode, only Chinese and English are supported.'
+    )
+    # ==================================================
 
     args = parser.parse_args()
 
     # You can use your own checkpoint and tokenizer path.
     print('Loading model and tokenizer...')
-    model = TexTeller.from_pretrained()
+    latex_rec_model = TexTeller.from_pretrained()
     tokenizer = TexTeller.get_tokenizer()
     print('Model and tokenizer loaded.')
 
-    img_path = [args.img]
+    # img_path = [args.img]
+    img = cv.imread(args.img)
     print('Inference...')
-    res = inference(model, tokenizer, img_path, args.cuda)
-    print(res[0])
+    if not args.mix:
+        res = latex_inference(latex_rec_model, tokenizer, [img], args.cuda)
+        print(res[0])
+    else:
+        # latex_det_model = load_det_tex_model()
+        # lang_model = load_lang_models()...
+        ...
+        # res: str = mix_inference(latex_det_model, latex_rec_model, lang_model, img, args.cuda)
+        # print(res)
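
One detail worth calling out: `-mix` is declared with `type=str` rather than `action='store_true'`, so it carries a string value (a language code) and defaults to None when omitted; that default is what makes `if not args.mix` select the plain LaTeX branch. A minimal, self-contained sketch of those semantics (the parser mirrors the diff; `sample.png` is a placeholder path):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-img', type=str)
parser.add_argument('-cuda', action='store_true', help='use cuda or not')
parser.add_argument(
    '-mix',
    type=str,
    help='use mix mode, only Chinese and English are supported.'
)

# Flag omitted: argparse leaves args.mix as None, so `not args.mix`
# is True and the script takes the plain LaTeX-recognition branch.
args = parser.parse_args(['-img', 'sample.png'])
assert args.mix is None

# Flag given a language code: the mixed-content branch runs instead.
args = parser.parse_args(['-img', 'sample.png', '-mix', 'ch'])
assert args.mix == 'ch'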
src/models/ocr_model/utils/inference.py

@@ -13,16 +13,16 @@ from models.globals import MAX_TOKEN_SIZE
 def inference(
     model: TexTeller,
     tokenizer: RobertaTokenizerFast,
-    imgs_path: Union[List[str], List[np.ndarray]],
+    imgs: Union[List[str], List[np.ndarray]],
     use_cuda: bool,
     num_beams: int = 1,
 ) -> List[str]:
     model.eval()
-    if isinstance(imgs_path[0], str):
-        imgs = convert2rgb(imgs_path)
+    if isinstance(imgs[0], str):
+        imgs = convert2rgb(imgs)
     else:  # already numpy array(rgb format)
-        assert isinstance(imgs_path[0], np.ndarray)
-        imgs = imgs_path
+        assert isinstance(imgs[0], np.ndarray)
+        imgs = imgs
     imgs = inference_transform(imgs)
     pixel_values = torch.stack(imgs)
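
With the parameter renamed from `imgs_path` to `imgs`, the function now advertises that it accepts either file paths or in-memory arrays. One caveat for the array form, per the `else` branch's comment: the arrays must already be RGB, while `cv.imread` (as used by the script above) returns BGR. A hedged usage sketch, assuming the repo's `src` directory is on the path and `formula.png` is a placeholder image:

import cv2 as cv

from models.ocr_model.model.TexTeller import TexTeller
from models.ocr_model.utils.inference import inference as latex_inference

latex_rec_model = TexTeller.from_pretrained()
tokenizer = TexTeller.get_tokenizer()

# 1) a list of file paths: convert2rgb() loads and converts internally
res = latex_inference(latex_rec_model, tokenizer, ['formula.png'], use_cuda=False)

# 2) a list of numpy arrays: the caller owns the channel order;
#    cv.imread gives BGR, hence the explicit conversion
img = cv.cvtColor(cv.imread('formula.png'), cv.COLOR_BGR2RGB)
res = latex_inference(latex_rec_model, tokenizer, [img], use_cuda=False)
print(res[0])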
src/utils.py (new file, 19 lines)

@@ -0,0 +1,19 @@
+import numpy as np
+
+from models.ocr_model.utils.inference import inference as latex_inference
+
+
+def load_lang_models(language: str):
+    ...
+    # language: 'ch' or 'en'
+    # return det_model, rec_model (or model)
+
+
+def load_det_tex_model():
+    ...
+    # return the loaded latex detection model
+
+
+def mix_inference(latex_det_model, latex_rec_model, lang_model, img: np.ndarray, use_cuda: bool) -> str:
+    ...
+    # latex_inference(...)
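
The three stubs sketch the intended mix-mode pipeline: detect formula regions (load_det_tex_model), recognize them with TexTeller (latex_inference), OCR the surrounding Chinese/English text (load_lang_models), and merge everything back in reading order. That merge step is the one piece that can be illustrated without committing to a particular detector or OCR backend; below is a minimal, self-contained sketch (merge_fragments and its inputs are hypothetical, not part of this commit):

# Each fragment is (recognized_text, (x0, y0, x1, y1)) in image coordinates,
# coming either from latex_rec_model (formula crops) or lang_model (text crops).
def merge_fragments(fragments):
    # reading order: top-to-bottom by vertical center, then left-to-right
    def key(fragment):
        _, (x0, y0, x1, y1) = fragment
        return ((y0 + y1) / 2, x0)
    return ' '.join(text for text, _ in sorted(fragments, key=key))

fragments = [
    ('$E=mc^2$', (120, 10, 200, 40)),    # formula crop
    ('mass-energy:', (10, 12, 110, 38)),  # text crop
]
print(merge_fragments(fragments))  # -> 'mass-energy: $E=mc^2$'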