优化了transform.py中的trim_white_border

This commit is contained in:
三洋三洋
2024-04-10 16:09:13 +00:00
parent 1589fb3217
commit 762012be1f
2 changed files with 13 additions and 8 deletions

View File

@@ -37,8 +37,8 @@ def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokeniz
data_collator=collate_fn_with_tokenizer,
)
trainer.train(resume_from_checkpoint=None)
# trainer.train(resume_from_checkpoint='/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv2/checkpoint-288000')
# trainer.train(resume_from_checkpoint=None)
trainer.train(resume_from_checkpoint='/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-440000')
def evaluate(model, tokenizer, eval_dataset, collate_fn):

View File

@@ -6,6 +6,7 @@ import cv2
from torchvision.transforms import v2
from typing import List
from PIL import Image
from collections import Counter
from ...globals import (
IMG_CHANNELS,
@@ -55,20 +56,24 @@ def trim_white_border(image: np.ndarray):
if image.dtype != np.uint8:
raise ValueError(f"Image should stored in uint8")
corners = [tuple(image[0, 0]), tuple(image[0, -1]),
tuple(image[-1, 0]), tuple(image[-1, -1])]
bg_color = Counter(corners).most_common(1)[0][0]
bg_color_np = np.array(bg_color, dtype=np.uint8)
# 创建与原图像同样大小的纯白背景图像
h, w = image.shape[:2]
bg = np.full((h, w, 3), 255, dtype=np.uint8)
bg = np.full((h, w, 3), bg_color_np, dtype=np.uint8)
# 计算差异
diff = cv2.absdiff(image, bg)
mask = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY)
# 只要差值大于1就全部转化为255
_, diff = cv2.threshold(diff, 1, 255, cv2.THRESH_BINARY)
threshold = 15 # 接近背景色的也裁剪掉
_, diff = cv2.threshold(mask, threshold, 255, cv2.THRESH_BINARY)
# 把差值转灰度图
gray_diff = cv2.cvtColor(diff, cv2.COLOR_RGB2GRAY)
# 计算图像中非零像素点的最小外接矩阵
x, y, w, h = cv2.boundingRect(gray_diff)
x, y, w, h = cv2.boundingRect(diff)
# 裁剪图像
trimmed_image = image[y:y+h, x:x+w]