76 lines
2.1 KiB
Python
76 lines
2.1 KiB
Python
import torch
|
|
from torchvision.transforms import v2
|
|
|
|
from PIL import Image, ImageChops
|
|
from ....globals import (
|
|
IMAGE_MEAN, IMAGE_STD,
|
|
LABEL_RATIO,
|
|
RESIZER_IMG_SIZE,
|
|
NUM_CHANNELS
|
|
)
|
|
|
|
from typing import (
|
|
Any,
|
|
List,
|
|
Dict,
|
|
)
|
|
|
|
|
|
def trim_white_border(image: Image.Image) -> Image.Image:
    """Crop away the near-white margin surrounding the content of ``image``.

    The image is compared against a solid white canvas of the same mode and
    size; the bounding box of the pixels that differ noticeably from white is
    used as the crop rectangle.

    Args:
        image: A PIL image in mode ``'RGB'``, ``'RGBA'`` or ``'L'``.

    Returns:
        The cropped image, or ``image`` itself when no pixel differs enough
        from white (e.g. a blank page), so callers always get an image back.

    Raises:
        ValueError: If ``image.mode`` is not one of the supported modes.
    """
    white_by_mode = {
        'RGB': (255, 255, 255),
        'RGBA': (255, 255, 255, 255),
        'L': 255,
    }
    if image.mode not in white_by_mode:
        raise ValueError("Unsupported image mode")

    bg = Image.new(image.mode, image.size, white_by_mode[image.mode])
    diff = ImageChops.difference(image, bg)
    # (diff + diff) / 2.0 - 100 == diff - 100, clipped at 0: differences of
    # 100 or less are zeroed so faint noise near white counts as background.
    diff = ImageChops.add(diff, diff, 2.0, -100)
    # NOTE(review): recent Pillow releases default getbbox() to
    # alpha_only=True for images with an alpha channel -- confirm that 'RGBA'
    # inputs are still trimmed on colour content as intended.
    bbox = diff.getbbox()
    if bbox is None:
        # Nothing stands out from the white background; previously this fell
        # through and returned None, which broke callers reading img.height.
        return image
    return image.crop(bbox)
|
|
|
|
|
|
def preprocess_fn(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
    """Turn a batch of raw images into padded tensors plus scalar labels.

    Each image under ``samples['pixel_values']`` is trimmed with
    ``trim_white_border``; its label is the trimmed height multiplied by
    ``LABEL_RATIO``. The trimmed images then go through a torchvision v2
    pipeline (to tensor, grayscale, bicubic resize bounded by
    ``RESIZER_IMG_SIZE``, float conversion, normalisation) and are finally
    padded up to ``RESIZER_IMG_SIZE`` along both spatial axes.

    Args:
        samples: Batch dict; only the ``'pixel_values'`` entry is read.

    Returns:
        A dict with ``'pixel_values'`` (list of padded tensors) and
        ``'labels'`` (list of floats), in the same order as the input.
    """
    trimmed = [trim_white_border(raw) for raw in samples['pixel_values']]
    labels = [float(pic.height * LABEL_RATIO) for pic in trimmed]

    assert NUM_CHANNELS == 1, "Only support grayscale images"

    resize_step = v2.Resize(
        size=RESIZER_IMG_SIZE - 1,  # size must be strictly smaller than max_size
        interpolation=v2.InterpolationMode.BICUBIC,
        max_size=RESIZER_IMG_SIZE,
        antialias=True,
    )
    pipeline = v2.Compose([
        v2.ToImage(),
        v2.ToDtype(torch.uint8, scale=True),
        v2.Grayscale(),
        resize_step,
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]),
    ])
    tensors = pipeline(trimmed)

    padded = []
    for tensor in tensors:
        # Padding is [left, top, right, bottom]: only the right and bottom
        # edges grow, by whatever is missing to reach RESIZER_IMG_SIZE.
        extra_right = RESIZER_IMG_SIZE - tensor.shape[2]
        extra_bottom = RESIZER_IMG_SIZE - tensor.shape[1]
        padded.append(
            v2.functional.pad(tensor, padding=[0, 0, extra_right, extra_bottom])
        )

    return {'pixel_values': padded, 'labels': labels}
|
|
|
|
|
|
if __name__ == "__main__":  # unit test
    # Manual smoke test: load the dataset, attach the preprocessing transform
    # and pull a small batch so the result can be inspected in a debugger.
    import datasets

    dataset_script = "/home/lhy/code/TeXify/src/models/resizer/train/dataset/dataset.py"
    data = datasets.load_dataset(dataset_script)
    data = data.shuffle(seed=42).with_transform(preprocess_fn)

    train_data = data['train']
    test_data = data['test']

    # Slicing triggers preprocess_fn on the first ten training samples.
    inpu = train_data[:10]
    pause = 1  # convenient line to set a breakpoint on
|