From 762012be1f2267b8dc475b32015c39282ee6dd4d Mon Sep 17 00:00:00 2001
From: 三洋三洋 <1258009915@qq.com>
Date: Wed, 10 Apr 2024 16:09:13 +0000
Subject: [PATCH] Optimized trim_white_border in transforms.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/models/ocr_model/train/train.py      |  4 ++--
 src/models/ocr_model/utils/transforms.py | 17 +++++++++++------
 2 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/src/models/ocr_model/train/train.py b/src/models/ocr_model/train/train.py
index 49c0c64..25671de 100644
--- a/src/models/ocr_model/train/train.py
+++ b/src/models/ocr_model/train/train.py
@@ -37,8 +37,8 @@ def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokeniz
         data_collator=collate_fn_with_tokenizer,
     )
 
-    trainer.train(resume_from_checkpoint=None)
-    # trainer.train(resume_from_checkpoint='/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv2/checkpoint-288000')
+    # trainer.train(resume_from_checkpoint=None)
+    trainer.train(resume_from_checkpoint='/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-440000')
 
 
 def evaluate(model, tokenizer, eval_dataset, collate_fn):
diff --git a/src/models/ocr_model/utils/transforms.py b/src/models/ocr_model/utils/transforms.py
index 7c014ae..c717945 100644
--- a/src/models/ocr_model/utils/transforms.py
+++ b/src/models/ocr_model/utils/transforms.py
@@ -6,6 +6,7 @@ import cv2
 from torchvision.transforms import v2
 from typing import List
 from PIL import Image
+from collections import Counter
 
 from ...globals import (
     IMG_CHANNELS,
@@ -55,20 +56,24 @@ def trim_white_border(image: np.ndarray):
     if image.dtype != np.uint8:
         raise ValueError(f"Image should be stored in uint8")
 
+    corners = [tuple(image[0, 0]), tuple(image[0, -1]),
+               tuple(image[-1, 0]), tuple(image[-1, -1])]
+    bg_color = Counter(corners).most_common(1)[0][0]
+    bg_color_np = np.array(bg_color, dtype=np.uint8)
+
     # Create a solid white background image the same size as the original
     h, w = image.shape[:2]
-    bg = np.full((h, w, 3), 255, dtype=np.uint8)
+    bg = np.full((h, w, 3), bg_color_np, dtype=np.uint8)
 
     # Compute the difference from the background
     diff = cv2.absdiff(image, bg)
+    mask = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY)
 
-    # Any difference greater than 1 is set to 255
-    _, diff = cv2.threshold(diff, 1, 255, cv2.THRESH_BINARY)
+    threshold = 15  # also crop pixels that are merely close to the background color
+    _, diff = cv2.threshold(mask, threshold, 255, cv2.THRESH_BINARY)
 
-    # Convert the difference to a grayscale image
-    gray_diff = cv2.cvtColor(diff, cv2.COLOR_RGB2GRAY)
     # Compute the minimal bounding rectangle of the non-zero pixels
-    x, y, w, h = cv2.boundingRect(gray_diff)
+    x, y, w, h = cv2.boundingRect(diff)
 
     # Crop the image
     trimmed_image = image[y:y+h, x:x+w]
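
For reference, below is a minimal standalone sketch of the cropping logic this patch introduces, assuming an RGB uint8 H x W x 3 array as trim_white_border expects. The name trim_border_sketch and the self-check at the bottom are illustrative only and not part of the repository; the threshold of 15 mirrors the patch, everything else is a simplified rewrite rather than the repo's exact code.

from collections import Counter

import cv2
import numpy as np


def trim_border_sketch(image: np.ndarray, threshold: int = 15) -> np.ndarray:
    # Estimate the background color as the most common of the four corner pixels.
    corners = [tuple(image[0, 0]), tuple(image[0, -1]),
               tuple(image[-1, 0]), tuple(image[-1, -1])]
    bg_color = np.array(Counter(corners).most_common(1)[0][0], dtype=np.uint8)

    # Build a background canvas of that color and measure per-pixel deviation.
    h, w = image.shape[:2]
    bg = np.full((h, w, 3), bg_color, dtype=np.uint8)
    diff = cv2.absdiff(image, bg)
    mask = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY)

    # Pixels deviating from the background by more than `threshold` are kept;
    # everything close to the background color is treated as border.
    _, mask = cv2.threshold(mask, threshold, 255, cv2.THRESH_BINARY)

    # Crop to the bounding box of the remaining non-zero pixels.
    x, y, w, h = cv2.boundingRect(mask)
    return image[y:y + h, x:x + w]


if __name__ == "__main__":
    # Tiny self-check: a light gray canvas with a dark rectangle in the middle.
    canvas = np.full((100, 120, 3), 240, dtype=np.uint8)
    canvas[30:60, 40:80] = 0
    print(trim_border_sketch(canvas).shape)  # expected: (30, 40, 3)

Compared with the previous version, which always compared against a pure white background with a threshold of 1, estimating the background color from the corners lets images on off-white or gray backgrounds be trimmed correctly, and the larger threshold also discards pixels that are only slightly different from the background.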