修改了transforms.py中inference_transform的bug: 在训练的eval阶段没有把png图片转化为np.ndarray

This commit is contained in:
三洋三洋
2024-04-10 17:06:44 +00:00
parent 762012be1f
commit 1538cb73f8

View File

@@ -4,7 +4,7 @@ import numpy as np
import cv2 import cv2
from torchvision.transforms import v2 from torchvision.transforms import v2
from typing import List from typing import List, Union
from PIL import Image from PIL import Image
from collections import Counter from collections import Counter
@@ -190,8 +190,9 @@ def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
return images return images
def inference_transform(images: List[np.ndarray]) -> List[torch.Tensor]: def inference_transform(images: List[Union[np.ndarray, Image.Image]]) -> List[torch.Tensor]:
assert IMG_CHANNELS == 1 , "Only support grayscale images for now" assert IMG_CHANNELS == 1 , "Only support grayscale images for now"
images = [np.array(img.convert('RGB')) if isinstance(img, Image.Image) else img for img in images]
# 裁剪掉白边 # 裁剪掉白边
images = [trim_white_border(image) for image in images] images = [trim_white_border(image) for image in images]
# general transform pipeline # general transform pipeline