[chore] exclude paddleocr directory from pre-commit hooks
This commit is contained in:
177
texteller/models/ocr_model/utils/transforms.py
Normal file
177
texteller/models/ocr_model/utils/transforms.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import torch
|
||||
import random
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
from torchvision.transforms import v2
|
||||
from typing import List, Union
|
||||
from PIL import Image
|
||||
from collections import Counter
|
||||
|
||||
from ...globals import (
|
||||
IMG_CHANNELS,
|
||||
FIXED_IMG_SIZE,
|
||||
IMAGE_MEAN,
|
||||
IMAGE_STD,
|
||||
MAX_RESIZE_RATIO,
|
||||
MIN_RESIZE_RATIO,
|
||||
)
|
||||
from .ocr_aug import ocr_augmentation_pipeline
|
||||
|
||||
# train_pipeline = default_augraphy_pipeline(scan_only=True)
|
||||
train_pipeline = ocr_augmentation_pipeline()
|
||||
|
||||
general_transform_pipeline = v2.Compose(
|
||||
[
|
||||
v2.ToImage(),
|
||||
v2.ToDtype(torch.uint8, scale=True), # optional, most input are already uint8 at this point
|
||||
v2.Grayscale(),
|
||||
v2.Resize(
|
||||
size=FIXED_IMG_SIZE - 1,
|
||||
interpolation=v2.InterpolationMode.BICUBIC,
|
||||
max_size=FIXED_IMG_SIZE,
|
||||
antialias=True,
|
||||
),
|
||||
v2.ToDtype(torch.float32, scale=True), # Normalize expects float input
|
||||
v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]),
|
||||
# v2.ToPILImage()
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def trim_white_border(image: np.ndarray):
|
||||
if len(image.shape) != 3 or image.shape[2] != 3:
|
||||
raise ValueError("Image is not in RGB format or channel is not in third dimension")
|
||||
|
||||
if image.dtype != np.uint8:
|
||||
raise ValueError(f"Image should stored in uint8")
|
||||
|
||||
corners = [tuple(image[0, 0]), tuple(image[0, -1]), tuple(image[-1, 0]), tuple(image[-1, -1])]
|
||||
bg_color = Counter(corners).most_common(1)[0][0]
|
||||
bg_color_np = np.array(bg_color, dtype=np.uint8)
|
||||
|
||||
h, w = image.shape[:2]
|
||||
bg = np.full((h, w, 3), bg_color_np, dtype=np.uint8)
|
||||
|
||||
diff = cv2.absdiff(image, bg)
|
||||
mask = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
threshold = 15
|
||||
_, diff = cv2.threshold(mask, threshold, 255, cv2.THRESH_BINARY)
|
||||
|
||||
x, y, w, h = cv2.boundingRect(diff)
|
||||
|
||||
trimmed_image = image[y : y + h, x : x + w]
|
||||
|
||||
return trimmed_image
|
||||
|
||||
|
||||
def add_white_border(image: np.ndarray, max_size: int) -> np.ndarray:
|
||||
randi = [random.randint(0, max_size) for _ in range(4)]
|
||||
pad_height_size = randi[1] + randi[3]
|
||||
pad_width_size = randi[0] + randi[2]
|
||||
if pad_height_size + image.shape[0] < 30:
|
||||
compensate_height = int((30 - (pad_height_size + image.shape[0])) * 0.5) + 1
|
||||
randi[1] += compensate_height
|
||||
randi[3] += compensate_height
|
||||
if pad_width_size + image.shape[1] < 30:
|
||||
compensate_width = int((30 - (pad_width_size + image.shape[1])) * 0.5) + 1
|
||||
randi[0] += compensate_width
|
||||
randi[2] += compensate_width
|
||||
return v2.functional.pad(
|
||||
torch.from_numpy(image).permute(2, 0, 1),
|
||||
padding=randi,
|
||||
padding_mode='constant',
|
||||
fill=(255, 255, 255),
|
||||
)
|
||||
|
||||
|
||||
def padding(images: List[torch.Tensor], required_size: int) -> List[torch.Tensor]:
|
||||
images = [
|
||||
v2.functional.pad(
|
||||
img, padding=[0, 0, required_size - img.shape[2], required_size - img.shape[1]]
|
||||
)
|
||||
for img in images
|
||||
]
|
||||
return images
|
||||
|
||||
|
||||
def random_resize(images: List[np.ndarray], minr: float, maxr: float) -> List[np.ndarray]:
|
||||
if len(images[0].shape) != 3 or images[0].shape[2] != 3:
|
||||
raise ValueError("Image is not in RGB format or channel is not in third dimension")
|
||||
|
||||
ratios = [random.uniform(minr, maxr) for _ in range(len(images))]
|
||||
return [
|
||||
cv2.resize(
|
||||
img, (int(img.shape[1] * r), int(img.shape[0] * r)), interpolation=cv2.INTER_LANCZOS4
|
||||
) # 抗锯齿
|
||||
for img, r in zip(images, ratios)
|
||||
]
|
||||
|
||||
|
||||
def rotate(image: np.ndarray, min_angle: int, max_angle: int) -> np.ndarray:
|
||||
# Get the center of the image to define the point of rotation
|
||||
image_center = tuple(np.array(image.shape[1::-1]) / 2)
|
||||
|
||||
# Generate a random angle within the specified range
|
||||
angle = random.randint(min_angle, max_angle)
|
||||
|
||||
# Get the rotation matrix for rotating the image around its center
|
||||
rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0)
|
||||
|
||||
# Determine the size of the rotated image
|
||||
cos = np.abs(rotation_mat[0, 0])
|
||||
sin = np.abs(rotation_mat[0, 1])
|
||||
new_width = int((image.shape[0] * sin) + (image.shape[1] * cos))
|
||||
new_height = int((image.shape[0] * cos) + (image.shape[1] * sin))
|
||||
|
||||
# Adjust the rotation matrix to take into account translation
|
||||
rotation_mat[0, 2] += (new_width / 2) - image_center[0]
|
||||
rotation_mat[1, 2] += (new_height / 2) - image_center[1]
|
||||
|
||||
# Rotate the image with the specified border color (white in this case)
|
||||
rotated_image = cv2.warpAffine(
|
||||
image, rotation_mat, (new_width, new_height), borderValue=(255, 255, 255)
|
||||
)
|
||||
|
||||
return rotated_image
|
||||
|
||||
|
||||
def ocr_aug(image: np.ndarray) -> np.ndarray:
|
||||
if random.random() < 0.2:
|
||||
image = rotate(image, -5, 5)
|
||||
image = add_white_border(image, max_size=25).permute(1, 2, 0).numpy()
|
||||
image = train_pipeline(image)
|
||||
return image
|
||||
|
||||
|
||||
def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
|
||||
assert IMG_CHANNELS == 1, "Only support grayscale images for now"
|
||||
|
||||
images = [np.array(img.convert('RGB')) for img in images]
|
||||
# random resize first
|
||||
images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
|
||||
images = [trim_white_border(image) for image in images]
|
||||
|
||||
# OCR augmentation
|
||||
images = [ocr_aug(image) for image in images]
|
||||
|
||||
# general transform pipeline
|
||||
images = [general_transform_pipeline(image) for image in images]
|
||||
# padding to fixed size
|
||||
images = padding(images, FIXED_IMG_SIZE)
|
||||
return images
|
||||
|
||||
|
||||
def inference_transform(images: List[Union[np.ndarray, Image.Image]]) -> List[torch.Tensor]:
|
||||
assert IMG_CHANNELS == 1, "Only support grayscale images for now"
|
||||
images = [
|
||||
np.array(img.convert('RGB')) if isinstance(img, Image.Image) else img for img in images
|
||||
]
|
||||
images = [trim_white_border(image) for image in images]
|
||||
# general transform pipeline
|
||||
images = [general_transform_pipeline(image) for image in images] # imgs: List[PIL.Image.Image]
|
||||
# padding to fixed size
|
||||
images = padding(images, FIXED_IMG_SIZE)
|
||||
|
||||
return images
|
||||
Reference in New Issue
Block a user