优化了transform.py中的trim_white_border

This commit is contained in:
三洋三洋
2024-04-10 16:09:13 +00:00
parent aaee57acd2
commit 5c58b88c96
2 changed files with 13 additions and 8 deletions

View File

@@ -37,8 +37,8 @@ def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokeniz
data_collator=collate_fn_with_tokenizer, data_collator=collate_fn_with_tokenizer,
) )
trainer.train(resume_from_checkpoint=None) # trainer.train(resume_from_checkpoint=None)
# trainer.train(resume_from_checkpoint='/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv2/checkpoint-288000') trainer.train(resume_from_checkpoint='/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-440000')
def evaluate(model, tokenizer, eval_dataset, collate_fn): def evaluate(model, tokenizer, eval_dataset, collate_fn):

View File

@@ -6,6 +6,7 @@ import cv2
from torchvision.transforms import v2 from torchvision.transforms import v2
from typing import List from typing import List
from PIL import Image from PIL import Image
from collections import Counter
from ...globals import ( from ...globals import (
IMG_CHANNELS, IMG_CHANNELS,
@@ -55,20 +56,24 @@ def trim_white_border(image: np.ndarray):
if image.dtype != np.uint8: if image.dtype != np.uint8:
raise ValueError(f"Image should stored in uint8") raise ValueError(f"Image should stored in uint8")
corners = [tuple(image[0, 0]), tuple(image[0, -1]),
tuple(image[-1, 0]), tuple(image[-1, -1])]
bg_color = Counter(corners).most_common(1)[0][0]
bg_color_np = np.array(bg_color, dtype=np.uint8)
# 创建与原图像同样大小的纯白背景图像 # 创建与原图像同样大小的纯白背景图像
h, w = image.shape[:2] h, w = image.shape[:2]
bg = np.full((h, w, 3), 255, dtype=np.uint8) bg = np.full((h, w, 3), bg_color_np, dtype=np.uint8)
# 计算差异 # 计算差异
diff = cv2.absdiff(image, bg) diff = cv2.absdiff(image, bg)
mask = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY)
# 只要差值大于1就全部转化为255 threshold = 15 # 接近背景色的也裁剪掉
_, diff = cv2.threshold(diff, 1, 255, cv2.THRESH_BINARY) _, diff = cv2.threshold(mask, threshold, 255, cv2.THRESH_BINARY)
# 把差值转灰度图
gray_diff = cv2.cvtColor(diff, cv2.COLOR_RGB2GRAY)
# 计算图像中非零像素点的最小外接矩阵 # 计算图像中非零像素点的最小外接矩阵
x, y, w, h = cv2.boundingRect(gray_diff) x, y, w, h = cv2.boundingRect(diff)
# 裁剪图像 # 裁剪图像
trimmed_image = image[y:y+h, x:x+w] trimmed_image = image[y:y+h, x:x+w]