把代码修改成了接受 PNG 图片作为输入
This commit is contained in:
@@ -5,6 +5,7 @@ import cv2
|
|||||||
|
|
||||||
from torchvision.transforms import v2
|
from torchvision.transforms import v2
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
from ....globals import (
|
from ....globals import (
|
||||||
OCR_IMG_CHANNELS,
|
OCR_IMG_CHANNELS,
|
||||||
@@ -37,12 +38,12 @@ general_transform_pipeline = v2.Compose([
|
|||||||
])
|
])
|
||||||
|
|
||||||
|
|
||||||
def trim_white_border(image: Union[np.ndarray, List[List[List]]]):
|
def trim_white_border(image: np.ndarray):
|
||||||
# image是一个3维的ndarray,RGB格式,维度分布为[H, W, C](通道维在第三维上)
|
# image是一个3维的ndarray,RGB格式,维度分布为[H, W, C](通道维在第三维上)
|
||||||
|
|
||||||
# 检查images中的第一个元素是否是嵌套的列表结构
|
# # 检查images中的第一个元素是否是嵌套的列表结构
|
||||||
if isinstance(image, list):
|
# if isinstance(image, list):
|
||||||
image = np.array(image, dtype=np.uint8)
|
# image = np.array(image, dtype=np.uint8)
|
||||||
|
|
||||||
# 检查图像是否为RGB格式,同时检查通道维是不是在第三维上
|
# 检查图像是否为RGB格式,同时检查通道维是不是在第三维上
|
||||||
if len(image.shape) != 3 or image.shape[2] != 3:
|
if len(image.shape) != 3 or image.shape[2] != 3:
|
||||||
@@ -85,16 +86,16 @@ def padding(images: List[torch.Tensor], required_size: int):
|
|||||||
|
|
||||||
|
|
||||||
def random_resize(
|
def random_resize(
|
||||||
images: Union[List[np.ndarray], List[List[List[List]]]],
|
images: List[np.ndarray],
|
||||||
minr: float,
|
minr: float,
|
||||||
maxr: float
|
maxr: float
|
||||||
) -> List[np.ndarray]:
|
) -> List[np.ndarray]:
|
||||||
# np.ndarray的格式:3维,RGB格式,维度分布为[H, W, C](通道维在第三维上)
|
# np.ndarray的格式:3维,RGB格式,维度分布为[H, W, C](通道维在第三维上)
|
||||||
|
|
||||||
# 检查images中的第一个元素是否是嵌套的列表结构
|
# # 检查images中的第一个元素是否是嵌套的列表结构
|
||||||
if isinstance(images[0], list):
|
# if isinstance(images[0], list):
|
||||||
# 将嵌套的列表结构转换为np.ndarray
|
# # 将嵌套的列表结构转换为np.ndarray
|
||||||
images = [np.array(img, dtype=np.uint8) for img in images]
|
# images = [np.array(img, dtype=np.uint8) for img in images]
|
||||||
|
|
||||||
if len(images[0].shape) != 3 or images[0].shape[2] != 3:
|
if len(images[0].shape) != 3 or images[0].shape[2] != 3:
|
||||||
raise ValueError("Image is not in RGB format or channel is not in third dimension")
|
raise ValueError("Image is not in RGB format or channel is not in third dimension")
|
||||||
@@ -116,11 +117,12 @@ def general_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
|
|||||||
return images
|
return images
|
||||||
|
|
||||||
|
|
||||||
def train_transform(images: List[List[List[List]]]) -> List[torch.Tensor]:
|
def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
|
||||||
assert OCR_IMG_CHANNELS == 1 , "Only support grayscale images for now"
|
assert OCR_IMG_CHANNELS == 1 , "Only support grayscale images for now"
|
||||||
assert OCR_FIX_SIZE == True, "Only support fixed size images for now"
|
assert OCR_FIX_SIZE == True, "Only support fixed size images for now"
|
||||||
|
|
||||||
# random resize first
|
# random resize first
|
||||||
|
images = [np.array(img.convert('RGB')) for img in images]
|
||||||
images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
|
images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
|
||||||
return general_transform(images)
|
return general_transform(images)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user