优化了transform.py中的trim_white_border
This commit is contained in:
@@ -37,8 +37,8 @@ def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokeniz
|
||||
data_collator=collate_fn_with_tokenizer,
|
||||
)
|
||||
|
||||
trainer.train(resume_from_checkpoint=None)
|
||||
# trainer.train(resume_from_checkpoint='/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv2/checkpoint-288000')
|
||||
# trainer.train(resume_from_checkpoint=None)
|
||||
trainer.train(resume_from_checkpoint='/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-440000')
|
||||
|
||||
|
||||
def evaluate(model, tokenizer, eval_dataset, collate_fn):
|
||||
|
||||
@@ -6,6 +6,7 @@ import cv2
|
||||
from torchvision.transforms import v2
|
||||
from typing import List
|
||||
from PIL import Image
|
||||
from collections import Counter
|
||||
|
||||
from ...globals import (
|
||||
IMG_CHANNELS,
|
||||
@@ -55,20 +56,24 @@ def trim_white_border(image: np.ndarray):
|
||||
if image.dtype != np.uint8:
|
||||
raise ValueError(f"Image should stored in uint8")
|
||||
|
||||
corners = [tuple(image[0, 0]), tuple(image[0, -1]),
|
||||
tuple(image[-1, 0]), tuple(image[-1, -1])]
|
||||
bg_color = Counter(corners).most_common(1)[0][0]
|
||||
bg_color_np = np.array(bg_color, dtype=np.uint8)
|
||||
|
||||
# 创建与原图像同样大小的纯白背景图像
|
||||
h, w = image.shape[:2]
|
||||
bg = np.full((h, w, 3), 255, dtype=np.uint8)
|
||||
bg = np.full((h, w, 3), bg_color_np, dtype=np.uint8)
|
||||
|
||||
# 计算差异
|
||||
diff = cv2.absdiff(image, bg)
|
||||
mask = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# 只要差值大于1,就全部转化为255
|
||||
_, diff = cv2.threshold(diff, 1, 255, cv2.THRESH_BINARY)
|
||||
threshold = 15 # 接近背景色的也裁剪掉
|
||||
_, diff = cv2.threshold(mask, threshold, 255, cv2.THRESH_BINARY)
|
||||
|
||||
# 把差值转灰度图
|
||||
gray_diff = cv2.cvtColor(diff, cv2.COLOR_RGB2GRAY)
|
||||
# 计算图像中非零像素点的最小外接矩阵
|
||||
x, y, w, h = cv2.boundingRect(gray_diff)
|
||||
x, y, w, h = cv2.boundingRect(diff)
|
||||
|
||||
# 裁剪图像
|
||||
trimmed_image = image[y:y+h, x:x+w]
|
||||
|
||||
Reference in New Issue
Block a user