|
|
|
from collections import Counter
|
|
|
|
|
from typing import List, Union
|
|
|
|
|
|
|
|
|
|
import cv2
|
|
|
|
|
import numpy as np
|
|
|
|
|
import torch
|
|
|
|
|
from PIL import Image
|
|
|
|
|
from torchvision.transforms import v2
|
|
|
|
|
|
|
|
|
|
from texteller.constants import (
|
|
|
|
|
FIXED_IMG_SIZE,
|
|
|
|
|
IMG_CHANNELS,
|
|
|
|
|
IMAGE_MEAN,
|
|
|
|
|
IMAGE_STD,
|
|
|
|
|
)
|
|
|
|
|
from texteller.logger import get_logger
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_logger = get_logger()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def readimgs(image_paths: list[str]) -> list[np.ndarray]:
    """
    Read a list of images from disk and normalize them to 8-bit RGB.

    Each image is loaded with its original channels and depth
    (``cv2.IMREAD_UNCHANGED``), down-converted from 16-bit to 8-bit when
    necessary, and converted to RGB from BGR, BGRA, or grayscale.

    Args:
        image_paths (list[str]): File paths of the images to read.

    Returns:
        list[np.ndarray]: uint8 NumPy arrays in RGB format (H, W, 3),
        in the same order as ``image_paths``.

    Raises:
        ValueError: If an image cannot be read, or has an unsupported
            number of channels.
    """
    processed_images = []
    for path in image_paths:
        image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if image is None:
            raise ValueError(f"Image at {path} could not be read.")

        if image.dtype == np.uint16:
            # cv2 loads 16-bit sources as uint16; rescale into the uint8 range.
            _logger.warning(f"Converting {path} to 8-bit, image may be lossy.")
            image = cv2.convertScaleAbs(image, alpha=(255.0 / 65535.0))

        channels = 1 if len(image.shape) == 2 else image.shape[2]
        if channels == 4:
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
        elif channels == 1:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif channels == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        else:
            # Previously an unexpected layout (e.g. a 2-channel image) passed
            # through unconverted; fail loudly instead of returning non-RGB data.
            raise ValueError(f"Image at {path} has unsupported channel count: {channels}")

        processed_images.append(image)

    return processed_images
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def trim_white_border(image: np.ndarray) -> np.ndarray:
    """
    Crop away the uniform background border of an RGB image.

    The background color is estimated as the most common of the four corner
    pixels; pixels whose gray-level difference from it exceeds a small
    threshold are treated as foreground, and the image is cropped to their
    bounding box.

    Args:
        image (np.ndarray): uint8 RGB image of shape (H, W, 3).

    Returns:
        np.ndarray: The cropped image (a view into the input). If every
        pixel matches the background, the input is returned unchanged.

    Raises:
        ValueError: If the image is not 3-channel or not uint8.
    """
    if len(image.shape) != 3 or image.shape[2] != 3:
        raise ValueError("Image is not in RGB format or channel is not in third dimension")

    if image.dtype != np.uint8:
        raise ValueError("Image should stored in uint8")

    corners = [tuple(image[0, 0]), tuple(image[0, -1]), tuple(image[-1, 0]), tuple(image[-1, -1])]
    bg_color = Counter(corners).most_common(1)[0][0]
    bg_color_np = np.array(bg_color, dtype=np.uint8)

    h, w = image.shape[:2]
    bg = np.full((h, w, 3), bg_color_np, dtype=np.uint8)

    diff = cv2.absdiff(image, bg)
    # The image is RGB, so use RGB2GRAY: the previous BGR2GRAY applied the
    # luminance weights to the wrong channels, slightly skewing the mask.
    mask = cv2.cvtColor(diff, cv2.COLOR_RGB2GRAY)

    # Gray deviations up to this value are still treated as background noise.
    threshold = 15
    _, mask = cv2.threshold(mask, threshold, 255, cv2.THRESH_BINARY)

    x, y, w, h = cv2.boundingRect(mask)
    if w == 0 or h == 0:
        # No foreground pixel found (all-background image); cropping would
        # produce an empty array, so return the input as-is.
        return image

    return image[y : y + h, x : x + w]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def padding(images: List[torch.Tensor], required_size: int) -> List[torch.Tensor]:
    """
    Pad each tensor on its right and bottom edges so both spatial
    dimensions equal ``required_size``.

    Args:
        images (List[torch.Tensor]): Tensors of shape (C, H, W);
            assumes H, W <= required_size — TODO confirm with callers.
        required_size (int): Target height and width after padding.

    Returns:
        List[torch.Tensor]: Tensors padded to (C, required_size, required_size).
    """
    padded = []
    for img in images:
        # v2 pad order is [left, top, right, bottom]; only grow right/bottom.
        pad_right = required_size - img.shape[2]
        pad_bottom = required_size - img.shape[1]
        padded.append(v2.functional.pad(img, padding=[0, 0, pad_right, pad_bottom]))
    return padded
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Built once at import time: the Compose object is stateless, so rebuilding it
# on every transform() call (as before) was wasted work.
_GENERAL_TRANSFORM_PIPELINE = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.uint8, scale=True),
        v2.Grayscale(),
        v2.Resize(
            size=FIXED_IMG_SIZE - 1,
            interpolation=v2.InterpolationMode.BICUBIC,
            max_size=FIXED_IMG_SIZE,
            antialias=True,
        ),
        v2.ToDtype(torch.float32, scale=True),  # Normalize expects float input
        v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]),
    ]
)


def transform(images: List[Union[np.ndarray, Image.Image]]) -> List[torch.Tensor]:
    """
    Preprocess images into model-ready grayscale tensors.

    Per image: convert PIL inputs to RGB ndarrays, trim the uniform
    background border, then grayscale / resize / normalize, and finally
    pad every tensor to a fixed square size.

    Args:
        images (List[Union[np.ndarray, Image.Image]]): RGB ndarrays of
            shape (H, W, 3) and/or PIL images, in any mix.

    Returns:
        List[torch.Tensor]: float32 tensors, each padded to
        (1, FIXED_IMG_SIZE, FIXED_IMG_SIZE).
    """
    assert IMG_CHANNELS == 1, "Only support grayscale images for now"

    arrays = [
        np.array(img.convert("RGB")) if isinstance(img, Image.Image) else img
        for img in images
    ]
    trimmed = [trim_white_border(arr) for arr in arrays]
    tensors = [_GENERAL_TRANSFORM_PIPELINE(arr) for arr in trimmed]
    return padding(tensors, FIXED_IMG_SIZE)
|