Initial Commit

三洋三洋
2024-01-31 10:11:07 +00:00
parent b7bf5c444f
commit 1fba652766
7 changed files with 193 additions and 89 deletions


@@ -1,67 +1,155 @@
 import torch
+import random
+import numpy as np
+import cv2
 from torchvision.transforms import v2
 from PIL import ImageChops, Image
-from typing import List
+from typing import List, Union
-from ....globals import OCR_IMG_CHANNELS, OCR_IMG_SIZE, OCR_FIX_SIZE, IMAGE_MEAN, IMAGE_STD
+from ....globals import (
+    OCR_IMG_CHANNELS,
+    OCR_IMG_SIZE,
+    OCR_FIX_SIZE,
+    IMAGE_MEAN, IMAGE_STD,
+    MAX_RESIZE_RATIO, MIN_RESIZE_RATIO
+)
-def trim_white_border(image: Image.Image):
-    if image.mode == 'RGB':
-        bg_color = (255, 255, 255)
-    elif image.mode == 'L':
-        bg_color = 255
-    else:
-        raise ValueError("Only support RGB or L mode")
-    # Create a white background the same size as the image
-    bg = Image.new(image.mode, image.size, bg_color)
-    # Difference between the image and the background: border regions that match
-    # the background color come out black in the difference image
-    diff = ImageChops.difference(image, bg)
-    # Suppress faint differences so non-background regions stand out; this helps
-    # locate the bounding box, though the parameters may need tuning per image
-    diff = ImageChops.add(diff, diff, 2.0, -100)
-    # Bounding box of the non-black region; crop the image to it if one is found
-    bbox = diff.getbbox()
-    return image.crop(bbox) if bbox else image
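
Reviewer note, not part of the commit: ImageChops.add(a, b, scale, offset) computes (a + b) / scale + offset with clamping to [0, 255], so the call above reduces to diff - 100 and simply discards differences weaker than 100. A one-pixel check:

from PIL import Image, ImageChops

a = Image.new('L', (1, 1), 130)
out = ImageChops.add(a, a, 2.0, -100)
print(out.getpixel((0, 0)))  # 30, i.e. (130 + 130) / 2.0 - 100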
general_transform_pipeline = v2.Compose([
v2.ToImage(), # Convert to tensor, only needed if you had a PIL image
#+返回一个List of torchvision.Imagelist的长度就是batch_size
#+因此在整个Compose pipeline的最后输出的也是一个List of torchvision.Image
#+注意不是返回一整个torchvision.Imagebatch_size的维度是拿出来的
v2.ToDtype(torch.uint8, scale=True), # optional, most input are already uint8 at this point
v2.Grayscale(), # 转灰度图(视具体任务而定)
v2.Resize( # 固定resize到一个正方形上
size=OCR_IMG_SIZE - 1, # size必须小于max_size
interpolation=v2.InterpolationMode.BICUBIC,
max_size=OCR_IMG_SIZE,
antialias=True
),
v2.ToDtype(torch.float32, scale=True), # Normalize expects float input
v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]),
# v2.ToPILImage() # 用于观察转换后的结果是否正确debug用
])
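
A minimal sketch, not part of the commit, of what this pipeline does to one wide image. The constants are placeholders standing in for the project's globals:

import torch
from PIL import Image
from torchvision.transforms import v2

OCR_IMG_SIZE, IMAGE_MEAN, IMAGE_STD = 448, 0.5, 0.5  # assumed values, not the real globals

pipeline = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.uint8, scale=True),
    v2.Grayscale(),
    v2.Resize(size=OCR_IMG_SIZE - 1, interpolation=v2.InterpolationMode.BICUBIC,
              max_size=OCR_IMG_SIZE, antialias=True),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]),
])

img = Image.new('RGB', (1200, 300), 'white')  # a wide strip, like a formula crop
out = pipeline([img])[0]                      # list in, list out
print(out.shape)                              # torch.Size([1, 112, 448])

The long edge is capped at max_size and the aspect ratio is preserved, so the output is generally not square yet; the padding step below completes the square.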
-def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
-    assert OCR_IMG_CHANNELS == 1, "Only support grayscale images for now"
-    assert OCR_FIX_SIZE == True, "Only support fixed size images for now"
-    images = [trim_white_border(image) for image in images]
-    transforms = v2.Compose([
-        v2.ToImage(),  # Convert to tensor; only needed for PIL input (list in, list out)
-        v2.ToDtype(torch.uint8, scale=True),  # optional; most inputs are already uint8 here
-        v2.Grayscale(),  # convert to grayscale (task-dependent)
-        v2.Resize(  # resize to fit the fixed square
-            size=OCR_IMG_SIZE - 1,  # size must be strictly less than max_size
-            interpolation=v2.InterpolationMode.BICUBIC,
-            max_size=OCR_IMG_SIZE,
-            antialias=True
-        ),
-        v2.ToDtype(torch.float32, scale=True),  # Normalize expects float input
-        v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]),
-        # v2.ToPILImage()  # for eyeballing the transformed result when debugging
-    ])
-    images = transforms(images)
+def trim_white_border(image: Union[np.ndarray, List[List[List]]]):
+    # image is a 3-D ndarray in RGB format, laid out as [H, W, C]
+    # (channel dimension in third place)
+    # Convert a nested-list input to an ndarray
+    if isinstance(image, list):
+        image = np.array(image, dtype=np.uint8)
+    # Check the image is RGB with the channel dimension in third place
+    if len(image.shape) != 3 or image.shape[2] != 3:
+        raise ValueError("Image is not in RGB format or channel is not in third dimension")
+    # Check the image is stored as uint8
+    if image.dtype != np.uint8:
+        raise ValueError("Image should be stored as uint8")
+    # Create a pure-white background the same size as the image
+    h, w = image.shape[:2]
+    bg = np.full((h, w, 3), 255, dtype=np.uint8)
+    # Compute the difference from the background
+    diff = cv2.absdiff(image, bg)
+    # Any difference greater than 1 becomes 255
+    _, diff = cv2.threshold(diff, 1, 255, cv2.THRESH_BINARY)
+    # Convert the difference map to grayscale
+    gray_diff = cv2.cvtColor(diff, cv2.COLOR_RGB2GRAY)
+    # Minimal bounding rectangle of the non-zero pixels
+    x, y, w, h = cv2.boundingRect(gray_diff)
+    # Crop the image to it
+    trimmed_image = image[y:y+h, x:x+w]
+    return trimmed_image
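
A quick self-check of the cv2-based trimming, not from the commit: a dark box on a white canvas should be cropped back to exactly the box.

import numpy as np

canvas = np.full((100, 200, 3), 255, dtype=np.uint8)  # white canvas, H=100, W=200
canvas[30:60, 50:120] = 0                             # 30x70 dark box
assert trim_white_border(canvas).shape == (30, 70, 3)

One caveat: for an all-white input, cv2.boundingRect of the all-zero mask returns (0, 0, 0, 0) and the crop is empty; the old PIL version fell back to the original image in that case.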
 def padding(images: List[torch.Tensor], required_size: int):
     images = [
         v2.functional.pad(
             img,
-            padding=[0, 0, OCR_IMG_SIZE - img.shape[2], OCR_IMG_SIZE - img.shape[1]]
+            padding=[0, 0, required_size - img.shape[2], required_size - img.shape[1]]
         )
         for img in images
     ]
     return images
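
The four-element padding list is [left, top, right, bottom], so each image is anchored at the top-left corner and only the right and bottom edges are filled. A sketch with an assumed 448-pixel target:

import torch

t = torch.zeros(1, 100, 60)  # [C, H, W]
out = padding([t], 448)[0]   # 448 is illustrative, not the project's real size
print(out.shape)             # torch.Size([1, 448, 448])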
+def random_resize(
+    images: Union[List[np.ndarray], List[List[List[List]]]],
+    minr: float,
+    maxr: float
+) -> List[np.ndarray]:
+    # np.ndarray format: 3-D, RGB, laid out as [H, W, C] (channel dimension in third place)
+    # Check whether the first element of images is a nested list
+    if isinstance(images[0], list):
+        # Convert the nested lists to np.ndarray
+        images = [np.array(img, dtype=np.uint8) for img in images]
+    if len(images[0].shape) != 3 or images[0].shape[2] != 3:
+        raise ValueError("Image is not in RGB format or channel is not in third dimension")
+    ratios = [random.uniform(minr, maxr) for _ in range(len(images))]
+    return [
+        # cv2.resize takes (width, height); LANCZOS4 for anti-aliasing
+        cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)),
+                   interpolation=cv2.INTER_LANCZOS4)
+        for img, r in zip(images, ratios)
+    ]
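
Behavior sketch, not in the commit: each image draws its own ratio, so a batch comes back at mixed sizes.

import numpy as np

imgs = [np.full((80, 120, 3), 200, dtype=np.uint8) for _ in range(3)]
out = random_resize(imgs, 0.5, 1.5)
print([im.shape for im in out])  # e.g. [(40, 60, 3), (96, 144, 3), (81, 122, 3)]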
+def general_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
+    # Trim the white border
+    images = [trim_white_border(image) for image in images]
+    # Run the shared transform pipeline
+    images = general_transform_pipeline(images)  # list in, list of tensors out
+    # Pad each tensor to the fixed square size
+    images = padding(images, OCR_IMG_SIZE)
+    return images
+def train_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
+    assert OCR_IMG_CHANNELS == 1, "Only support grayscale images for now"
+    assert OCR_FIX_SIZE == True, "Only support fixed size images for now"
+    # Random resize first (currently disabled)
+    # images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
+    return general_transform(images)
 def inference_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
     assert OCR_IMG_CHANNELS == 1, "Only support grayscale images for now"
     assert OCR_FIX_SIZE == True, "Only support fixed size images for now"
-    return train_transform(images)
+    return general_transform(images)
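
An end-to-end sketch under the same assumptions (single grayscale channel, fixed size). The dark band is fake content so the trim step has something to keep:

import numpy as np
import torch

batch = []
for _ in range(2):
    img = np.full((200, 640, 3), 255, dtype=np.uint8)  # white page
    img[80:120, 100:540] = 0                           # fake glyphs
    batch.append(img)

tensors = inference_transform(batch)  # List[torch.Tensor], each [1, OCR_IMG_SIZE, OCR_IMG_SIZE]
model_input = torch.stack(tensors)    # [B, 1, OCR_IMG_SIZE, OCR_IMG_SIZE]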
+if __name__ == '__main__':
+    from pathlib import Path
+    from .helpers import convert2rgb
+
+    base_dir = Path('/home/lhy/code/TeXify/src/models/ocr_model/model')
+    imgs_path = [
+        base_dir / '1.jpg',
+        base_dir / '2.jpg',
+        base_dir / '3.jpg',
+        base_dir / '4.jpg',
+        base_dir / '5.jpg',
+        base_dir / '6.jpg',
+        base_dir / '7.jpg',
+    ]
+    imgs_path = [str(img_path) for img_path in imgs_path]
+    imgs = convert2rgb(imgs_path)
+    # res = train_transform(imgs)
+    # res = [v2.functional.to_pil_image(img) for img in res]
+    res = random_resize(imgs, 0.5, 1.5)
+    pause = 1  # breakpoint anchor for inspecting `res` in a debugger