updated API usage (supports remote calls)

This commit is contained in:
三洋三洋
2024-02-27 07:13:36 +00:00
parent b4537944d0
commit 3527a4af47
3 changed files with 22 additions and 10 deletions

View File

@@ -1,7 +1,8 @@
import torch
import numpy as np
from transformers import RobertaTokenizerFast, GenerationConfig
from typing import List
from typing import List, Union
from models.ocr_model.model.TexTeller import TexTeller
from models.ocr_model.utils.transforms import inference_transform
@@ -12,12 +13,15 @@ from models.globals import MAX_TOKEN_SIZE
def inference(
model: TexTeller,
tokenizer: RobertaTokenizerFast,
imgs_path: List[str],
imgs_path: Union[List[str], List[np.ndarray]],
use_cuda: bool,
num_beams: int = 1,
) -> List[str]:
model.eval()
imgs = convert2rgb(imgs_path)
if isinstance(imgs_path[0], str):
imgs = convert2rgb(imgs_path)
else: # already numpy array(rgb format)
imgs = imgs_path
imgs = inference_transform(imgs)
pixel_values = torch.stack(imgs)