From b5dbf64716319b416cb0c5c90b3e9d01aa5ade42 Mon Sep 17 00:00:00 2001
From: 三洋三洋 <1258009915@qq.com>
Date: Sat, 3 Feb 2024 09:40:13 +0000
Subject: [PATCH] Change the code to accept PNG images as input
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/models/ocr_model/utils/transforms.py | 22 ++++++++++++----------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/src/models/ocr_model/utils/transforms.py b/src/models/ocr_model/utils/transforms.py
index b5b833f..26056f7 100644
--- a/src/models/ocr_model/utils/transforms.py
+++ b/src/models/ocr_model/utils/transforms.py
@@ -5,6 +5,7 @@ import cv2
 from torchvision.transforms import v2
 from typing import List, Union
 
+from PIL import Image
 from ....globals import (
     OCR_IMG_CHANNELS,
@@ -37,12 +38,12 @@ general_transform_pipeline = v2.Compose([
 ])
 
 
-def trim_white_border(image: Union[np.ndarray, List[List[List]]]):
+def trim_white_border(image: np.ndarray):
     # image is a 3-D ndarray in RGB format, laid out as [H, W, C] (channel dimension last)
 
-    # Check whether the input is a nested list structure
-    if isinstance(image, list):
-        image = np.array(image, dtype=np.uint8)
+    # # Check whether the input is a nested list structure
+    # if isinstance(image, list):
+    #     image = np.array(image, dtype=np.uint8)
 
     # Check that the image is in RGB format and that the channel dimension is last
     if len(image.shape) != 3 or image.shape[2] != 3:
@@ -85,16 +86,16 @@ def padding(images: List[torch.Tensor], required_size: int):
 
 
 def random_resize(
-    images: Union[List[np.ndarray], List[List[List[List]]]],
+    images: List[np.ndarray],
     minr: float,
     maxr: float
 ) -> List[np.ndarray]:
     # np.ndarray layout: 3-D, RGB format, [H, W, C] (channel dimension last)
 
-    # Check whether the first element of images is a nested list structure
-    if isinstance(images[0], list):
-        # Convert the nested lists to np.ndarray
-        images = [np.array(img, dtype=np.uint8) for img in images]
+    # # Check whether the first element of images is a nested list structure
+    # if isinstance(images[0], list):
+    #     # Convert the nested lists to np.ndarray
+    #     images = [np.array(img, dtype=np.uint8) for img in images]
 
     if len(images[0].shape) != 3 or images[0].shape[2] != 3:
         raise ValueError("Image is not in RGB format or channel is not in third dimension")
@@ -116,11 +117,12 @@ def general_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
     return images
 
 
-def train_transform(images: List[List[List[List]]]) -> List[torch.Tensor]:
+def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
     assert OCR_IMG_CHANNELS == 1 , "Only support grayscale images for now"
     assert OCR_FIX_SIZE == True, "Only support fixed size images for now"
 
     # random resize first
+    images = [np.array(img.convert('RGB')) for img in images]
     images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
 
     return general_transform(images)
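
Note (not part of the patch): a minimal usage sketch of the updated entry point, assuming the module is importable as src.models.ocr_model.utils.transforms; the import path and "sample.png" are illustrative placeholders, not taken from the repository.

    from PIL import Image
    from src.models.ocr_model.utils.transforms import train_transform

    # train_transform now takes a list of PIL images (e.g. loaded from PNG files);
    # it converts each one to an RGB ndarray, applies the random resize and the
    # general transform pipeline, and returns a list of torch.Tensor.
    images = [Image.open("sample.png")]   # placeholder path
    tensors = train_transform(images)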