把代码修改成了接受输入为png的图片

This commit is contained in:
三洋三洋
2024-02-03 09:40:13 +00:00
parent 274fd6cdda
commit b5dbf64716

View File

@@ -5,6 +5,7 @@ import cv2
from torchvision.transforms import v2
from typing import List, Union
from PIL import Image
from ....globals import (
OCR_IMG_CHANNELS,
@@ -37,12 +38,12 @@ general_transform_pipeline = v2.Compose([
])
def trim_white_border(image: Union[np.ndarray, List[List[List]]]):
def trim_white_border(image: np.ndarray):
# image是一个3维的ndarrayRGB格式维度分布为[H, W, C](通道维在第三维上)
# 检查images中的第一个元素是否是嵌套的列表结构
if isinstance(image, list):
image = np.array(image, dtype=np.uint8)
# # 检查images中的第一个元素是否是嵌套的列表结构
# if isinstance(image, list):
# image = np.array(image, dtype=np.uint8)
# 检查图像是否为RGB格式同时检查通道维是不是在第三维上
if len(image.shape) != 3 or image.shape[2] != 3:
@@ -85,16 +86,16 @@ def padding(images: List[torch.Tensor], required_size: int):
def random_resize(
images: Union[List[np.ndarray], List[List[List[List]]]],
images: List[np.ndarray],
minr: float,
maxr: float
) -> List[np.ndarray]:
# np.ndarray的格式3维RGB格式维度分布为[H, W, C](通道维在第三维上)
# 检查images中的第一个元素是否是嵌套的列表结构
if isinstance(images[0], list):
# 将嵌套的列表结构转换为np.ndarray
images = [np.array(img, dtype=np.uint8) for img in images]
# # 检查images中的第一个元素是否是嵌套的列表结构
# if isinstance(images[0], list):
# # 将嵌套的列表结构转换为np.ndarray
# images = [np.array(img, dtype=np.uint8) for img in images]
if len(images[0].shape) != 3 or images[0].shape[2] != 3:
raise ValueError("Image is not in RGB format or channel is not in third dimension")
@@ -116,11 +117,12 @@ def general_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
return images
def train_transform(images: List[List[List[List]]]) -> List[torch.Tensor]:
def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
assert OCR_IMG_CHANNELS == 1 , "Only support grayscale images for now"
assert OCR_FIX_SIZE == True, "Only support fixed size images for now"
# random resize first
images = [np.array(img.convert('RGB')) for img in images]
images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
return general_transform(images)