TexTeller/texteller/models/ocr_model/utils/inference.py
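"""Batched inference helper for the TexTeller OCR model."""
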
import torch
import numpy as np
from transformers import RobertaTokenizerFast, GenerationConfig
from typing import List, Optional, Union
from .transforms import inference_transform
from .helpers import convert2rgb
from ..model.TexTeller import TexTeller
from ...globals import MAX_TOKEN_SIZE
def inference(
    model: TexTeller,
    tokenizer: RobertaTokenizerFast,
    imgs: Union[List[str], List[np.ndarray]],
    accelerator: str = 'cpu',
    num_beams: int = 1,
    max_tokens: Optional[int] = None,
) -> List[str]:
    """Run batched OCR inference and return the decoded prediction strings."""
    if not imgs:
        return []
    if hasattr(model, 'eval'):
        # Not an ONNX session: switch the torch model to eval mode.
        model.eval()

    if isinstance(imgs[0], str):
        # Image paths: load from disk and convert to RGB.
        imgs = convert2rgb(imgs)
    else:
        # Already numpy arrays in RGB format.
        assert isinstance(imgs[0], np.ndarray)
    imgs = inference_transform(imgs)
    pixel_values = torch.stack(imgs)
    if hasattr(model, 'eval'):
        # Not an ONNX session: move the weights to the target device.
        model = model.to(accelerator)
    pixel_values = pixel_values.to(accelerator)
    generate_config = GenerationConfig(
        max_new_tokens=MAX_TOKEN_SIZE if max_tokens is None else max_tokens,
        num_beams=num_beams,
        do_sample=False,  # deterministic greedy/beam search, no sampling
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
    )
    pred = model.generate(pixel_values.to(model.device), generation_config=generate_config)
    res = tokenizer.batch_decode(pred, skip_special_tokens=True)
    return res
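
# A minimal usage sketch, assuming the relative imports above resolve and that
# TexTeller exposes a from_pretrained() loader; the checkpoint and tokenizer
# paths below are hypothetical placeholders.
if __name__ == '__main__':
    model = TexTeller.from_pretrained('/path/to/checkpoint')                # hypothetical path
    tokenizer = RobertaTokenizerFast.from_pretrained('/path/to/tokenizer')  # hypothetical path
    preds = inference(model, tokenizer, ['formula.png'], accelerator='cpu', num_beams=1)
    print(preds[0])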