1) 加入了推理代码; 2) 整理了其他代码
This commit is contained in:
@@ -1,9 +1,8 @@
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
from torchvision.transforms import v2
|
||||
from PIL import ImageChops, Image
|
||||
from typing import Any, Dict, List
|
||||
from typing import List
|
||||
|
||||
from ....globals import OCR_IMG_CHANNELS, OCR_IMG_SIZE, OCR_FIX_SIZE, IMAGE_MEAN, IMAGE_STD
|
||||
|
||||
@@ -11,12 +10,10 @@ from ....globals import OCR_IMG_CHANNELS, OCR_IMG_SIZE, OCR_FIX_SIZE, IMAGE_MEAN
|
||||
def trim_white_border(image: Image.Image):
|
||||
if image.mode == 'RGB':
|
||||
bg_color = (255, 255, 255)
|
||||
elif image.mode == 'RGBA':
|
||||
bg_color = (255, 255, 255, 255)
|
||||
elif image.mode == 'L':
|
||||
bg_color = 255
|
||||
else:
|
||||
raise ValueError("Unsupported image mode")
|
||||
raise ValueError("Only support RGB or L mode")
|
||||
# 创建一个与图片一样大小的白色背景
|
||||
bg = Image.new(image.mode, image.size, bg_color)
|
||||
# 计算原图像与背景图像的差异。如果原图像在边框区域与左上角像素颜色相同,那么这些区域在差异图像中将是黑色的。
|
||||
@@ -25,8 +22,7 @@ def trim_white_border(image: Image.Image):
|
||||
diff = ImageChops.add(diff, diff, 2.0, -100)
|
||||
# 找到差异图像中非黑色区域的边界框。如果找到,原图将根据这个边界框被裁剪。
|
||||
bbox = diff.getbbox()
|
||||
if bbox:
|
||||
return image.crop(bbox)
|
||||
return image.crop(bbox) if bbox else image
|
||||
|
||||
|
||||
def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
|
||||
|
||||
Reference in New Issue
Block a user