修改了transforms.py中inference_transform的bug: 在训练的eval阶段没有把png图片转化为np.ndarray

This commit is contained in:
三洋三洋
2024-04-10 17:06:44 +00:00
parent 762012be1f
commit 1538cb73f8

View File

@@ -4,7 +4,7 @@ import numpy as np
import cv2
from torchvision.transforms import v2
from typing import List
from typing import List, Union
from PIL import Image
from collections import Counter
@@ -190,8 +190,9 @@ def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
return images
def inference_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
def inference_transform(images: List[Union[np.ndarray, Image.Image]]) -> List[torch.Tensor]:
assert IMG_CHANNELS == 1 , "Only support grayscale images for now"
images = [np.array(img.convert('RGB')) if isinstance(img, Image.Image) else img for img in images]
# 裁剪掉白边
images = [trim_white_border(image) for image in images]
# general transform pipeline