diff --git a/src/models/ocr_model/utils/transforms.py b/src/models/ocr_model/utils/transforms.py index c717945..f63af07 100644 --- a/src/models/ocr_model/utils/transforms.py +++ b/src/models/ocr_model/utils/transforms.py @@ -4,7 +4,7 @@ import numpy as np import cv2 from torchvision.transforms import v2 -from typing import List +from typing import List, Union from PIL import Image from collections import Counter @@ -190,8 +190,9 @@ def train_transform(images: List[Image.Image]) -> List[torch.Tensor]: return images -def inference_transform(images: List[np.ndarray]) -> List[torch.Tensor]: +def inference_transform(images: List[Union[np.ndarray, Image.Image]]) -> List[torch.Tensor]: assert IMG_CHANNELS == 1 , "Only support grayscale images for now" + images = [np.array(img.convert('RGB')) if isinstance(img, Image.Image) else img for img in images] # 裁剪掉白边 images = [trim_white_border(image) for image in images] # general transform pipeline