updated API usage (supports remote calls)
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from transformers import RobertaTokenizerFast, GenerationConfig
|
||||
from typing import List
|
||||
from typing import List, Union
|
||||
|
||||
from models.ocr_model.model.TexTeller import TexTeller
|
||||
from models.ocr_model.utils.transforms import inference_transform
|
||||
@@ -12,12 +13,15 @@ from models.globals import MAX_TOKEN_SIZE
|
||||
def inference(
|
||||
model: TexTeller,
|
||||
tokenizer: RobertaTokenizerFast,
|
||||
imgs_path: List[str],
|
||||
imgs_path: Union[List[str], List[np.ndarray]],
|
||||
use_cuda: bool,
|
||||
num_beams: int = 1,
|
||||
) -> List[str]:
|
||||
model.eval()
|
||||
imgs = convert2rgb(imgs_path)
|
||||
if isinstance(imgs_path[0], str):
|
||||
imgs = convert2rgb(imgs_path)
|
||||
else: # already numpy array(rgb format)
|
||||
imgs = imgs_path
|
||||
imgs = inference_transform(imgs)
|
||||
pixel_values = torch.stack(imgs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user