修改了v3(支持自然场景、混合文字场景识别)版本的inference.py模版

This commit is contained in:
三洋三洋
2024-04-05 07:25:06 +00:00
parent 5b730329b4
commit 34ac31504a
3 changed files with 46 additions and 10 deletions

View File

@@ -13,16 +13,16 @@ from models.globals import MAX_TOKEN_SIZE
def inference(
model: TexTeller,
tokenizer: RobertaTokenizerFast,
imgs_path: Union[List[str], List[np.ndarray]],
imgs: Union[List[str], List[np.ndarray]],
use_cuda: bool,
num_beams: int = 1,
) -> List[str]:
model.eval()
if isinstance(imgs_path[0], str):
imgs = convert2rgb(imgs_path)
if isinstance(imgs[0], str):
imgs = convert2rgb(imgs)
else: # already numpy array(rgb format)
assert isinstance(imgs_path[0], np.ndarray)
imgs = imgs_path
assert isinstance(imgs[0], np.ndarray)
imgs = imgs
imgs = inference_transform(imgs)
pixel_values = torch.stack(imgs)