From 2d6c46b88d26fb5a177aaaaa114350479c578fc4 Mon Sep 17 00:00:00 2001
From: 三洋三洋 <1258009915@qq.com>
Date: Mon, 4 Mar 2024 05:35:59 +0000
Subject: [PATCH] Fix up training and add data augmentation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/models/ocr_model/utils/functional.py |  2 +-
 src/models/ocr_model/utils/transforms.py | 48 ++++++++++++++++--------
 2 files changed, 33 insertions(+), 17 deletions(-)

diff --git a/src/models/ocr_model/utils/functional.py b/src/models/ocr_model/utils/functional.py
index 92f40c0..969a493 100644
--- a/src/models/ocr_model/utils/functional.py
+++ b/src/models/ocr_model/utils/functional.py
@@ -21,7 +21,7 @@ def left_move(x: torch.Tensor, pad_val):
 def tokenize_fn(samples: Dict[str, List[Any]], tokenizer=None) -> Dict[str, List[Any]]:
     assert tokenizer is not None, 'tokenizer should not be None'
     tokenized_formula = tokenizer(samples['latex_formula'], return_special_tokens_mask=True)
-    tokenized_formula['pixel_values'] = [np.array(sample) for sample in samples['image']]
+    tokenized_formula['pixel_values'] = samples['image']
     return tokenized_formula
 
 
diff --git a/src/models/ocr_model/utils/transforms.py b/src/models/ocr_model/utils/transforms.py
index 199fa33..bdb5ef9 100644
--- a/src/models/ocr_model/utils/transforms.py
+++ b/src/models/ocr_model/utils/transforms.py
@@ -5,6 +5,7 @@ import cv2
 from torchvision.transforms import v2
 from typing import List, Union
 
+from augraphy import default_augraphy_pipeline
 from PIL import Image
 
 from ...globals import (
@@ -15,12 +16,13 @@ from ...globals import (
     MAX_RESIZE_RATIO,
     MIN_RESIZE_RATIO
 )
 
+train_pipeline = default_augraphy_pipeline()
 general_transform_pipeline = v2.Compose([
     v2.ToImage(),  # Convert to tensor, only needed if you had a PIL image
-    #+ returns a List of torchvision.Image whose length is the batch size
-    #+ so the whole Compose pipeline also outputs a List of torchvision.Image
-    #+ note: it does not return one stacked torchvision.Image; the batch dimension stays outside
+    #+ returns a List of torchvision.Image whose length is the batch size
+    #+ so the whole Compose pipeline also outputs a List of torchvision.Image
+    #+ note: it does not return one stacked torchvision.Image; the batch dimension stays outside
 
     v2.ToDtype(torch.uint8, scale=True),  # optional, most input are already uint8 at this point
     v2.Grayscale(),  # convert to grayscale (task-dependent)
@@ -74,7 +76,15 @@ def trim_white_border(image: np.ndarray):
     return trimmed_image
 
 
-def padding(images: List[torch.Tensor], required_size: int):
+def add_white_border(image: np.ndarray, max_size: int) -> np.ndarray:
+    top, bottom, left, right = [random.randint(0, max_size) for _ in range(4)]
+    return cv2.copyMakeBorder(  # fill with white; torchvision's pad on np arrays would not work and defaults to black
+        image, top, bottom, left, right,
+        cv2.BORDER_CONSTANT, value=(255, 255, 255)
+    )
+
+
+def padding(images: List[torch.Tensor], required_size: int) -> List[torch.Tensor]:
     images = [
         v2.functional.pad(
             img,
@@ -107,9 +117,19 @@ def random_resize(
     ]
 
 
-def general_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
+def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
+    assert OCR_IMG_CHANNELS == 1 , "Only support grayscale images for now"
+    assert OCR_FIX_SIZE == True, "Only support fixed size images for now"
+
+    images = [np.array(img.convert('RGB')) for img in images]
+    # random resize first
+    images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
     # trim white borders
     images = [trim_white_border(image) for image in images]
+    # add a random white border back
+    images = [add_white_border(image, max_size=35) for image in images]
+    # data augmentation (augraphy)
+    images = [train_pipeline(image) for image in images]
     # general transform pipeline
     images = general_transform_pipeline(images)  # imgs: List[PIL.Image.Image]
     # padding to fixed size
@@ -117,21 +137,17 @@ def general_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
     return images
 
 
-def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
-    assert OCR_IMG_CHANNELS == 1 , "Only support grayscale images for now"
-    assert OCR_FIX_SIZE == True, "Only support fixed size images for now"
-
-    # random resize first
-    images = [np.array(img.convert('RGB')) for img in images]
-    images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
-    return general_transform(images)
-
-
 def inference_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
     assert OCR_IMG_CHANNELS == 1 , "Only support grayscale images for now"
     assert OCR_FIX_SIZE == True, "Only support fixed size images for now"
+    # trim white borders
+    images = [trim_white_border(image) for image in images]
+    # general transform pipeline
+    images = general_transform_pipeline(images)  # imgs: List[PIL.Image.Image]
+    # padding to fixed size
+    images = padding(images, OCR_IMG_SIZE)
 
-    return general_transform(images)
+    return images
 
 
 if __name__ == '__main__':
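
For reference, the training-time flow after this patch is: random_resize ->
trim_white_border -> add_white_border -> train_pipeline (augraphy) ->
general_transform_pipeline -> padding. The sketch below exercises the new
border-and-augment steps on a synthetic image. It is a minimal illustration, not
repo code: the synthetic input and the simplified trim/border helpers are
stand-ins for the repo's implementations, and calling train_pipeline(image)
directly assumes augraphy pipeline objects are callable, which is exactly how the
patched train_transform uses them.

    import random

    import cv2
    import numpy as np
    from augraphy import default_augraphy_pipeline

    train_pipeline = default_augraphy_pipeline()  # same default pipeline the patch builds

    # Synthetic "formula": black text on a white canvas with wide margins.
    image = np.full((120, 400, 3), 255, dtype=np.uint8)
    cv2.putText(image, "E = mc^2", (60, 70), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 0, 0), 2)

    def trim_white_border(img: np.ndarray) -> np.ndarray:
        # Simplified stand-in: crop to the bounding box of all non-white pixels.
        ys, xs = np.where(np.any(img < 255, axis=2))
        return img[ys.min():ys.max() + 1, xs.min():xs.max() + 1]

    def add_white_border(img: np.ndarray, max_size: int) -> np.ndarray:
        # Same idea as the patch's add_white_border: a random white margin per side.
        top, bottom, left, right = (random.randint(0, max_size) for _ in range(4))
        return np.pad(img, ((top, bottom), (left, right), (0, 0)), constant_values=255)

    image = trim_white_border(image)
    image = add_white_border(image, max_size=35)  # same max_size as train_transform
    image = train_pipeline(image)                 # augraphy document-noise augmentation
    print(image.shape, image.dtype)

Design note on the functional.py hunk: tokenize_fn now stores samples['image']
unconverted, so pixel_values holds PIL images and the np.array(img.convert('RGB'))
call at the top of the new train_transform becomes the single PIL-to-numpy
conversion point.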