Wrote the v3 version of the training code (v3 adds natural-scene training augmentation)
```diff
@@ -4,7 +4,7 @@ import numpy as np
 import cv2
 from torchvision.transforms import v2
-from typing import List, Union
+from typing import List
 from PIL import Image
 
 from ...globals import (
```
```diff
@@ -77,7 +77,6 @@ def trim_white_border(image: np.ndarray):
     return trimmed_image
 
 
-# BUGY CODE
 def add_white_border(image: np.ndarray, max_size: int) -> np.ndarray:
     randi = [random.randint(0, max_size) for _ in range(4)]
     pad_height_size = randi[1] + randi[3]
```
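The hunk truncates the body of `add_white_border`. A minimal sketch consistent with the visible lines (four random pad widths, `pad_height_size = randi[1] + randi[3]`, and a tensor return that `ocr_aug` later permutes back to HWC) might look like this; the left/top/right/bottom ordering of `randi` is an assumption:

```python
import random

import numpy as np
import torch
import torch.nn.functional as F


def add_white_border(image: np.ndarray, max_size: int) -> torch.Tensor:
    # Four independent pad widths; assumed order: left, top, right, bottom
    randi = [random.randint(0, max_size) for _ in range(4)]
    tensor = torch.from_numpy(image).permute(2, 0, 1)  # HWC -> CHW
    # F.pad on a CHW tensor takes (left, right, top, bottom) for the last
    # two dims; value=255 pads a uint8 image with white
    return F.pad(tensor, (randi[0], randi[2], randi[1], randi[3]), value=255)
```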
```diff
@@ -131,9 +130,38 @@ def random_resize(
 ]
 
 
+def rotate(image: np.ndarray, min_angle: int, max_angle: int) -> np.ndarray:
+    # Get the center of the image to define the point of rotation
+    image_center = tuple(np.array(image.shape[1::-1]) / 2)
+
+    # Generate a random angle within the specified range
+    angle = random.randint(min_angle, max_angle)
+
+    # Get the rotation matrix for rotating the image around its center
+    rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0)
+
+    # Determine the size of the rotated image
+    cos = np.abs(rotation_mat[0, 0])
+    sin = np.abs(rotation_mat[0, 1])
+    new_width = int((image.shape[0] * sin) + (image.shape[1] * cos))
+    new_height = int((image.shape[0] * cos) + (image.shape[1] * sin))
+
+    # Adjust the rotation matrix to take into account translation
+    rotation_mat[0, 2] += (new_width / 2) - image_center[0]
+    rotation_mat[1, 2] += (new_height / 2) - image_center[1]
+
+    # Rotate the image with the specified border color (white in this case)
+    rotated_image = cv2.warpAffine(image, rotation_mat, (new_width, new_height), borderValue=(255, 255, 255))
+
+    return rotated_image
+
+
 def ocr_aug(image: np.ndarray) -> np.ndarray:
     # Apply a random rotation with 20% probability
     if random.random() < 0.2:
         image = rotate(image, -5, 5)
     # Add a white border
-    image = add_white_border(image, max_size=35).permute(1, 2, 0).numpy()
+    image = add_white_border(image, max_size=25).permute(1, 2, 0).numpy()
+    # Data augmentation
+    image = train_pipeline(image)
+    return image
```
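`ocr_aug` ends by calling `train_pipeline`, which is defined elsewhere in the file and not shown in this diff. A hypothetical stand-in built from the `torchvision.transforms.v2` API imported above, purely to illustrate the kind of natural-scene augmentation the commit message describes:

```python
import torch
from torchvision.transforms import v2

# Hypothetical stand-in for the module's train_pipeline; the real definition
# lives outside this hunk.
train_pipeline = v2.Compose([
    v2.ToImage(),                                  # HWC ndarray -> CHW image tensor
    v2.ColorJitter(brightness=0.3, contrast=0.3),  # lighting variation
    v2.RandomApply([v2.GaussianBlur(kernel_size=3)], p=0.3),  # occasional mild blur
    v2.ToDtype(torch.float32, scale=True),         # uint8 [0, 255] -> float [0, 1]
])
```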
```diff
@@ -149,10 +177,7 @@ def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
     # Trim the white border
     images = [trim_white_border(image) for image in images]
 
-    # Add a white border
-    # images = [add_white_border(image, max_size=35) for image in images]
-    # Data augmentation
-    # images = [train_pipeline(image.permute(1, 2, 0).numpy()) for image in images]
+    # OCR augmentation
+    images = [ocr_aug(image) for image in images]
 
     # general transform pipeline
```
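`train_transform` first calls `trim_white_border`, whose definition sits above this diff. A plausible sketch of that helper, assuming it crops to the bounding box of all non-white pixels:

```python
import cv2
import numpy as np

# Hypothetical reconstruction of trim_white_border; the real body is not
# shown in the diff.
def trim_white_border(image: np.ndarray) -> np.ndarray:
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    coords = cv2.findNonZero(255 - gray)  # non-white pixels become non-zero
    if coords is None:
        return image                      # fully white image: nothing to trim
    x, y, w, h = cv2.boundingRect(coords)
    return image[y:y + h, x:x + w]
```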
```diff
@@ -165,10 +190,11 @@ def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
 def inference_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
     assert OCR_IMG_CHANNELS == 1, "Only support grayscale images for now"
     assert OCR_FIX_SIZE == True, "Only support fixed size images for now"
+    images = [np.array(img.convert('RGB')) for img in images]
     # Trim the white border
     images = [trim_white_border(image) for image in images]
     # general transform pipeline
-    images = general_transform_pipeline(images)  # imgs: List[PIL.Image.Image]
+    images = [general_transform_pipeline(image) for image in images]  # imgs: List[PIL.Image.Image]
     # padding to fixed size
     images = padding(images, OCR_IMG_SIZE)
 
```
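`padding` is also defined outside the diff. Given that `inference_transform` pads every image to the fixed `OCR_IMG_SIZE`, a minimal sketch (assuming CHW float tensors and right/bottom padding with white) could be:

```python
from typing import List, Tuple

import torch
import torch.nn.functional as F

# Hypothetical sketch of the padding() helper used above.
def padding(images: List[torch.Tensor], size: Tuple[int, int]) -> torch.Tensor:
    target_h, target_w = size
    padded = [
        # Pad right and bottom up to the target size; 1.0 is white for
        # float tensors scaled to [0, 1]
        F.pad(img, (0, target_w - img.shape[-1], 0, target_h - img.shape[-2]), value=1.0)
        for img in images
    ]
    return torch.stack(padded)
```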
```diff
@@ -190,8 +216,6 @@ if __name__ == '__main__':
     ]
     imgs_path = [str(img_path) for img_path in imgs_path]
     imgs = convert2rgb(imgs_path)
-    # res = train_transform(imgs)
-    # res = [v2.functional.to_pil_image(img) for img in res]
     res = random_resize(imgs, 0.5, 1.5)
     pause = 1
```
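`random_resize` is only visible through its hunk header and this call site. A sketch matching `random_resize(imgs, 0.5, 1.5)`, assuming each image gets an independent scale factor drawn from the given range:

```python
import random
from typing import List

import cv2
import numpy as np

# Hypothetical sketch of random_resize; the signature is inferred from the
# call random_resize(imgs, 0.5, 1.5).
def random_resize(images: List[np.ndarray], min_ratio: float, max_ratio: float) -> List[np.ndarray]:
    out = []
    for img in images:
        r = random.uniform(min_ratio, max_ratio)
        h, w = img.shape[:2]
        # cv2.resize takes (width, height); guard against zero-sized outputs
        out.append(cv2.resize(img, (max(1, int(w * r)), max(1, int(h * r)))))
    return out
```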