2024-02-11 08:06:50 +00:00
|
|
|
import torch
|
2024-02-27 07:13:36 +00:00
|
|
|
import numpy as np
|
2024-02-11 08:06:50 +00:00
|
|
|
|
|
|
|
|
from transformers import RobertaTokenizerFast, GenerationConfig
|
2024-02-27 07:13:36 +00:00
|
|
|
from typing import List, Union
|
2024-02-11 08:06:50 +00:00
|
|
|
|
2024-04-21 00:05:14 +08:00
|
|
|
from .transforms import inference_transform
|
|
|
|
|
from .helpers import convert2rgb
|
|
|
|
|
from ..model.TexTeller import TexTeller
|
|
|
|
|
from ...globals import MAX_TOKEN_SIZE
|
2024-02-11 08:06:50 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def inference(
    model: TexTeller,
    tokenizer: RobertaTokenizerFast,
    imgs: Union[List[str], List[np.ndarray]],
    accelerator: str = 'cpu',
    num_beams: int = 1,
    max_tokens: Union[int, None] = None,
) -> List[str]:
    """Run generation on a batch of images and decode the predictions to text.

    Args:
        model: Object exposing ``generate`` and ``device``. When it also has
            an ``eval`` attribute it is switched to eval mode and moved to
            ``accelerator``; objects without ``eval`` (ONNX sessions,
            according to the original comments) are used as they are.
        tokenizer: Tokenizer supplying the pad/eos/bos token ids for
            generation and used to decode the generated ids.
        imgs: Either a sequence of strings, which is passed through
            ``convert2rgb`` (presumably image paths -- confirm against that
            helper), or a sequence of ``np.ndarray`` images that are
            expected to already be in RGB format.
        accelerator: Device string the model is moved to, e.g. ``'cpu'``.
        num_beams: Beam count forwarded to the generation config.
        max_tokens: Limit on newly generated tokens; ``MAX_TOKEN_SIZE`` is
            used when this is ``None``.

    Returns:
        One decoded string per input image, with special tokens removed.
        An empty input yields an empty list.

    Raises:
        TypeError: If the first element of ``imgs`` is neither ``str`` nor
            ``np.ndarray``.
    """
    # len() instead of `== []` so that any empty sequence (tuple, ndarray, ...)
    # short-circuits rather than failing on `imgs[0]` below.
    if len(imgs) == 0:
        return []

    # Only the first element is inspected; the batch is assumed homogeneous.
    if isinstance(imgs[0], str):
        imgs = convert2rgb(imgs)
    elif not isinstance(imgs[0], np.ndarray):
        # Explicit raise instead of `assert`, which disappears under `python -O`.
        raise TypeError(
            f"imgs must contain str or np.ndarray elements, got {type(imgs[0]).__name__}"
        )

    # Objects without `eval` are not torch modules (ONNX sessions, per the
    # original comments): they have no eval mode and their weights are not moved.
    if hasattr(model, 'eval'):
        model.eval()
        model = model.to(accelerator)

    pixel_values = torch.stack(inference_transform(imgs))

    generate_config = GenerationConfig(
        max_new_tokens=MAX_TOKEN_SIZE if max_tokens is None else max_tokens,
        num_beams=num_beams,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
    )

    # One transfer straight to the model's device is enough: for a torch model
    # that device is `accelerator` after the `.to()` above, so a separate hop
    # through `accelerator` first would only add a redundant copy.
    pred = model.generate(pixel_values.to(model.device), generation_config=generate_config)
    return tokenizer.batch_decode(pred, skip_special_tokens=True)
|