diff --git a/src/models/ocr_model/utils/transforms.py b/src/models/ocr_model/utils/transforms.py index b5b833f..26056f7 100644 --- a/src/models/ocr_model/utils/transforms.py +++ b/src/models/ocr_model/utils/transforms.py @@ -5,6 +5,7 @@ import cv2 from torchvision.transforms import v2 from typing import List, Union +from PIL import Image from ....globals import ( OCR_IMG_CHANNELS, @@ -37,12 +38,12 @@ general_transform_pipeline = v2.Compose([ ]) -def trim_white_border(image: Union[np.ndarray, List[List[List]]]): +def trim_white_border(image: np.ndarray): # image是一个3维的ndarray,RGB格式,维度分布为[H, W, C](通道维在第三维上) - # 检查images中的第一个元素是否是嵌套的列表结构 - if isinstance(image, list): - image = np.array(image, dtype=np.uint8) + # # 检查images中的第一个元素是否是嵌套的列表结构 + # if isinstance(image, list): + # image = np.array(image, dtype=np.uint8) # 检查图像是否为RGB格式,同时检查通道维是不是在第三维上 if len(image.shape) != 3 or image.shape[2] != 3: @@ -85,16 +86,16 @@ def padding(images: List[torch.Tensor], required_size: int): def random_resize( - images: Union[List[np.ndarray], List[List[List[List]]]], + images: List[np.ndarray], minr: float, maxr: float ) -> List[np.ndarray]: # np.ndarray的格式:3维,RGB格式,维度分布为[H, W, C](通道维在第三维上) - # 检查images中的第一个元素是否是嵌套的列表结构 - if isinstance(images[0], list): - # 将嵌套的列表结构转换为np.ndarray - images = [np.array(img, dtype=np.uint8) for img in images] + # # 检查images中的第一个元素是否是嵌套的列表结构 + # if isinstance(images[0], list): + # # 将嵌套的列表结构转换为np.ndarray + # images = [np.array(img, dtype=np.uint8) for img in images] if len(images[0].shape) != 3 or images[0].shape[2] != 3: raise ValueError("Image is not in RGB format or channel is not in third dimension") @@ -116,11 +117,12 @@ def general_transform(images: List[np.ndarray]) -> List[torch.Tensor]: return images -def train_transform(images: List[List[List[List]]]) -> List[torch.Tensor]: +def train_transform(images: List[Image.Image]) -> List[torch.Tensor]: assert OCR_IMG_CHANNELS == 1 , "Only support grayscale images for now" assert OCR_FIX_SIZE == True, "Only support fixed size images for now" # random resize first + images = [np.array(img.convert('RGB')) for img in images] images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO) return general_transform(images)