修改了transforms.py中inference_transform的bug: 在训练的eval阶段没有把png图片转化为np.ndarray
This commit is contained in:
@@ -4,7 +4,7 @@ import numpy as np
|
|||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
from torchvision.transforms import v2
|
from torchvision.transforms import v2
|
||||||
from typing import List
|
from typing import List, Union
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
|
|
||||||
@@ -190,8 +190,9 @@ def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
|
|||||||
return images
|
return images
|
||||||
|
|
||||||
|
|
||||||
def inference_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
|
def inference_transform(images: List[Union[np.ndarray, Image.Image]]) -> List[torch.Tensor]:
|
||||||
assert IMG_CHANNELS == 1 , "Only support grayscale images for now"
|
assert IMG_CHANNELS == 1 , "Only support grayscale images for now"
|
||||||
|
images = [np.array(img.convert('RGB')) if isinstance(img, Image.Image) else img for img in images]
|
||||||
# 裁剪掉白边
|
# 裁剪掉白边
|
||||||
images = [trim_white_border(image) for image in images]
|
images = [trim_white_border(image) for image in images]
|
||||||
# general transform pipeline
|
# general transform pipeline
|
||||||
|
|||||||
Reference in New Issue
Block a user