把代码修改成了接受输入为png的图片
This commit is contained in:
@@ -5,6 +5,7 @@ import cv2
|
||||
|
||||
from torchvision.transforms import v2
|
||||
from typing import List, Union
|
||||
from PIL import Image
|
||||
|
||||
from ....globals import (
|
||||
OCR_IMG_CHANNELS,
|
||||
@@ -37,12 +38,12 @@ general_transform_pipeline = v2.Compose([
|
||||
])
|
||||
|
||||
|
||||
def trim_white_border(image: Union[np.ndarray, List[List[List]]]):
|
||||
def trim_white_border(image: np.ndarray):
|
||||
# image是一个3维的ndarray,RGB格式,维度分布为[H, W, C](通道维在第三维上)
|
||||
|
||||
# 检查images中的第一个元素是否是嵌套的列表结构
|
||||
if isinstance(image, list):
|
||||
image = np.array(image, dtype=np.uint8)
|
||||
# # 检查images中的第一个元素是否是嵌套的列表结构
|
||||
# if isinstance(image, list):
|
||||
# image = np.array(image, dtype=np.uint8)
|
||||
|
||||
# 检查图像是否为RGB格式,同时检查通道维是不是在第三维上
|
||||
if len(image.shape) != 3 or image.shape[2] != 3:
|
||||
@@ -85,16 +86,16 @@ def padding(images: List[torch.Tensor], required_size: int):
|
||||
|
||||
|
||||
def random_resize(
|
||||
images: Union[List[np.ndarray], List[List[List[List]]]],
|
||||
images: List[np.ndarray],
|
||||
minr: float,
|
||||
maxr: float
|
||||
) -> List[np.ndarray]:
|
||||
# np.ndarray的格式:3维,RGB格式,维度分布为[H, W, C](通道维在第三维上)
|
||||
|
||||
# 检查images中的第一个元素是否是嵌套的列表结构
|
||||
if isinstance(images[0], list):
|
||||
# 将嵌套的列表结构转换为np.ndarray
|
||||
images = [np.array(img, dtype=np.uint8) for img in images]
|
||||
# # 检查images中的第一个元素是否是嵌套的列表结构
|
||||
# if isinstance(images[0], list):
|
||||
# # 将嵌套的列表结构转换为np.ndarray
|
||||
# images = [np.array(img, dtype=np.uint8) for img in images]
|
||||
|
||||
if len(images[0].shape) != 3 or images[0].shape[2] != 3:
|
||||
raise ValueError("Image is not in RGB format or channel is not in third dimension")
|
||||
@@ -116,11 +117,12 @@ def general_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
|
||||
return images
|
||||
|
||||
|
||||
def train_transform(images: List[List[List[List]]]) -> List[torch.Tensor]:
|
||||
def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
|
||||
assert OCR_IMG_CHANNELS == 1 , "Only support grayscale images for now"
|
||||
assert OCR_FIX_SIZE == True, "Only support fixed size images for now"
|
||||
|
||||
# random resize first
|
||||
images = [np.array(img.convert('RGB')) for img in images]
|
||||
images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
|
||||
return general_transform(images)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user