merge v3_nature_scence

This commit is contained in:
三洋三洋
2024-03-28 13:44:32 +00:00
42 changed files with 125335 additions and 959 deletions

View File

@@ -7,47 +7,96 @@ from torchvision.transforms import v2
from typing import List
from PIL import Image
from models.globals import (
from ...globals import (
IMG_CHANNELS,
FIXED_IMG_SIZE,
IMAGE_MEAN, IMAGE_STD,
MAX_RESIZE_RATIO, MIN_RESIZE_RATIO
)
from .ocr_aug import ocr_augmentation_pipeline
# train_pipeline = default_augraphy_pipeline(scan_only=True)
train_pipeline = ocr_augmentation_pipeline()
general_transform_pipeline = v2.Compose([
v2.ToImage(),
v2.ToDtype(torch.uint8, scale=True),
v2.Grayscale(),
v2.Resize(
size=FIXED_IMG_SIZE - 1,
v2.ToImage(), # Convert to tensor, only needed if you had a PIL image
#+返回一个List of torchvision.Imagelist的长度就是batch_size
#+因此在整个Compose pipeline的最后输出的也是一个List of torchvision.Image
#+注意不是返回一整个torchvision.Imagebatch_size的维度是拿出来的
v2.ToDtype(torch.uint8, scale=True), # optional, most input are already uint8 at this point
v2.Grayscale(), # 转灰度图(视具体任务而定)
v2.Resize( # 固定resize到一个正方形上
size=FIXED_IMG_SIZE - 1, # size必须小于max_size
interpolation=v2.InterpolationMode.BICUBIC,
max_size=FIXED_IMG_SIZE,
antialias=True
),
v2.ToDtype(torch.float32, scale=True),
v2.ToDtype(torch.float32, scale=True), # Normalize expects float input
v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]),
# v2.ToPILImage() # 用于观察转换后的结果是否正确debug用
])
def trim_white_border(image: np.ndarray):
# image是一个3维的ndarrayRGB格式维度分布为[H, W, C](通道维在第三维上)
# # 检查images中的第一个元素是否是嵌套的列表结构
# if isinstance(image, list):
# image = np.array(image, dtype=np.uint8)
# 检查图像是否为RGB格式同时检查通道维是不是在第三维上
if len(image.shape) != 3 or image.shape[2] != 3:
raise ValueError("Image is not in RGB format or channel is not in third dimension")
# 检查图片是否使用 uint8 类型
if image.dtype != np.uint8:
raise ValueError(f"Image should stored in uint8")
# 创建与原图像同样大小的纯白背景图像
h, w = image.shape[:2]
bg = np.full((h, w, 3), 255, dtype=np.uint8)
# 计算差异
diff = cv2.absdiff(image, bg)
# 只要差值大于1就全部转化为255
_, diff = cv2.threshold(diff, 1, 255, cv2.THRESH_BINARY)
# 把差值转灰度图
gray_diff = cv2.cvtColor(diff, cv2.COLOR_RGB2GRAY)
# 计算图像中非零像素点的最小外接矩阵
x, y, w, h = cv2.boundingRect(gray_diff)
# 裁剪图像
trimmed_image = image[y:y+h, x:x+w]
return trimmed_image
def padding(images: List[torch.Tensor], required_size: int):
def add_white_border(image: np.ndarray, max_size: int) -> np.ndarray:
randi = [random.randint(0, max_size) for _ in range(4)]
pad_height_size = randi[1] + randi[3]
pad_width_size = randi[0] + randi[2]
if (pad_height_size + image.shape[0] < 30):
compensate_height = int((30 - (pad_height_size + image.shape[0])) * 0.5) + 1
randi[1] += compensate_height
randi[3] += compensate_height
if (pad_width_size + image.shape[1] < 30):
compensate_width = int((30 - (pad_width_size + image.shape[1])) * 0.5) + 1
randi[0] += compensate_width
randi[2] += compensate_width
return v2.functional.pad(
torch.from_numpy(image).permute(2, 0, 1),
padding=randi,
padding_mode='constant',
fill=(255, 255, 255)
)
def padding(images: List[torch.Tensor], required_size: int) -> List[torch.Tensor]:
images = [
v2.functional.pad(
img,
@@ -63,6 +112,13 @@ def random_resize(
minr: float,
maxr: float
) -> List[np.ndarray]:
# np.ndarray的格式3维RGB格式维度分布为[H, W, C](通道维在第三维上)
# # 检查images中的第一个元素是否是嵌套的列表结构
# if isinstance(images[0], list):
# # 将嵌套的列表结构转换为np.ndarray
# images = [np.array(img, dtype=np.uint8) for img in images]
if len(images[0].shape) != 3 or images[0].shape[2] != 3:
raise ValueError("Image is not in RGB format or channel is not in third dimension")
@@ -73,18 +129,90 @@ def random_resize(
]
def general_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
def rotate(image: np.ndarray, min_angle: int, max_angle: int) -> np.ndarray:
# Get the center of the image to define the point of rotation
image_center = tuple(np.array(image.shape[1::-1]) / 2)
# Generate a random angle within the specified range
angle = random.randint(min_angle, max_angle)
# Get the rotation matrix for rotating the image around its center
rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0)
# Determine the size of the rotated image
cos = np.abs(rotation_mat[0, 0])
sin = np.abs(rotation_mat[0, 1])
new_width = int((image.shape[0] * sin) + (image.shape[1] * cos))
new_height = int((image.shape[0] * cos) + (image.shape[1] * sin))
# Adjust the rotation matrix to take into account translation
rotation_mat[0, 2] += (new_width / 2) - image_center[0]
rotation_mat[1, 2] += (new_height / 2) - image_center[1]
# Rotate the image with the specified border color (white in this case)
rotated_image = cv2.warpAffine(image, rotation_mat, (new_width, new_height), borderValue=(255, 255, 255))
return rotated_image
def ocr_aug(image: np.ndarray) -> np.ndarray:
# 20%的概率进行随机旋转
if random.random() < 0.2:
image = rotate(image, -5, 5)
# 增加白边
image = add_white_border(image, max_size=25).permute(1, 2, 0).numpy()
# 数据增强
image = train_pipeline(image)
return image
def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
assert IMG_CHANNELS == 1 , "Only support grayscale images for now"
images = [np.array(img.convert('RGB')) for img in images]
# random resize first
images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
# 裁剪掉白边
images = [trim_white_border(image) for image in images]
images = general_transform_pipeline(images)
# OCR augmentation
images = [ocr_aug(image) for image in images]
# general transform pipeline
images = [general_transform_pipeline(image) for image in images]
# padding to fixed size
images = padding(images, FIXED_IMG_SIZE)
return images
def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
images = [np.array(img.convert('RGB')) for img in images]
images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
return general_transform(images)
def inference_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
return general_transform(images)
assert IMG_CHANNELS == 1 , "Only support grayscale images for now"
images = [np.array(img.convert('RGB')) for img in images]
# 裁剪掉白边
images = [trim_white_border(image) for image in images]
# general transform pipeline
images = [general_transform_pipeline(image) for image in images] # imgs: List[PIL.Image.Image]
# padding to fixed size
images = padding(images, FIXED_IMG_SIZE)
return images
if __name__ == '__main__':
from pathlib import Path
from .helpers import convert2rgb
base_dir = Path('/home/lhy/code/TeXify/src/models/ocr_model/model')
imgs_path = [
base_dir / '1.jpg',
base_dir / '2.jpg',
base_dir / '3.jpg',
base_dir / '4.jpg',
base_dir / '5.jpg',
base_dir / '6.jpg',
base_dir / '7.jpg',
]
imgs_path = [str(img_path) for img_path in imgs_path]
imgs = convert2rgb(imgs_path)
res = random_resize(imgs, 0.5, 1.5)
pause = 1