From 1538cb73f81024c32299e4b4edbcf846ad1ff5f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=89=E6=B4=8B=E4=B8=89=E6=B4=8B?= <1258009915@qq.com> Date: Wed, 10 Apr 2024 17:06:44 +0000 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E4=BA=86transforms.py?= =?UTF-8?q?=E4=B8=ADinference=5Ftransform=E7=9A=84bug:=20=E5=9C=A8?= =?UTF-8?q?=E8=AE=AD=E7=BB=83=E7=9A=84eval=E9=98=B6=E6=AE=B5=E6=B2=A1?= =?UTF-8?q?=E6=9C=89=E6=8A=8Apng=E5=9B=BE=E7=89=87=E8=BD=AC=E5=8C=96?= =?UTF-8?q?=E4=B8=BAnp.ndarray?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/models/ocr_model/utils/transforms.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/models/ocr_model/utils/transforms.py b/src/models/ocr_model/utils/transforms.py index c717945..f63af07 100644 --- a/src/models/ocr_model/utils/transforms.py +++ b/src/models/ocr_model/utils/transforms.py @@ -4,7 +4,7 @@ import numpy as np import cv2 from torchvision.transforms import v2 -from typing import List +from typing import List, Union from PIL import Image from collections import Counter @@ -190,8 +190,9 @@ def train_transform(images: List[Image.Image]) -> List[torch.Tensor]: return images -def inference_transform(images: List[np.ndarray]) -> List[torch.Tensor]: +def inference_transform(images: List[Union[np.ndarray, Image.Image]]) -> List[torch.Tensor]: assert IMG_CHANNELS == 1 , "Only support grayscale images for now" + images = [np.array(img.convert('RGB')) if isinstance(img, Image.Image) else img for img in images] # 裁剪掉白边 images = [trim_white_border(image) for image in images] # general transform pipeline