Optimized trim_white_border in transform.py
@@ -37,8 +37,8 @@ def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokeniz
         data_collator=collate_fn_with_tokenizer,
     )
 
-    trainer.train(resume_from_checkpoint=None)
-    # trainer.train(resume_from_checkpoint='/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv2/checkpoint-288000')
+    # trainer.train(resume_from_checkpoint=None)
+    trainer.train(resume_from_checkpoint='/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-440000')
 
 
 def evaluate(model, tokenizer, eval_dataset, collate_fn):
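The hunk above swaps which trainer.train call is active, so training now resumes from the TexTellerv3 checkpoint-440000 directory instead of starting from scratch. A minimal sketch of that pattern, assuming the Hugging Face transformers Trainer API; the output_dir, batch size, and checkpoint_dir below are placeholders, not values from this repository:

from transformers import Trainer, TrainingArguments

def launch(model, train_dataset, collate_fn, checkpoint_dir=None):
    # Placeholder arguments; checkpoint-* folders are written under output_dir.
    args = TrainingArguments(
        output_dir="train_result",
        per_device_train_batch_size=8,
        num_train_epochs=1,
    )
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        data_collator=collate_fn,
    )
    # None starts training from scratch; a checkpoint path restores the model
    # weights, optimizer, scheduler, and RNG state before continuing; passing
    # True resumes from the latest checkpoint found in output_dir.
    trainer.train(resume_from_checkpoint=checkpoint_dir)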
@@ -6,6 +6,7 @@ import cv2
 from torchvision.transforms import v2
 from typing import List
 from PIL import Image
+from collections import Counter
 
 from ...globals import (
     IMG_CHANNELS,
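The new import supports a majority vote over the four corner pixels in trim_white_border below. A tiny standalone illustration of collections.Counter.most_common (the colour values are made up):

from collections import Counter

# Three white corners outvote one slightly darker corner, so a single
# stray pixel cannot flip the detected background colour.
corners = [(255, 255, 255), (255, 255, 255), (250, 250, 250), (255, 255, 255)]
bg_color = Counter(corners).most_common(1)[0][0]
print(bg_color)  # (255, 255, 255)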
@@ -55,20 +56,24 @@ def trim_white_border(image: np.ndarray):
     if image.dtype != np.uint8:
         raise ValueError(f"Image should stored in uint8")
 
+    corners = [tuple(image[0, 0]), tuple(image[0, -1]),
+               tuple(image[-1, 0]), tuple(image[-1, -1])]
+    bg_color = Counter(corners).most_common(1)[0][0]
+    bg_color_np = np.array(bg_color, dtype=np.uint8)
+
     # Create a background image of the same size as the original
     h, w = image.shape[:2]
-    bg = np.full((h, w, 3), 255, dtype=np.uint8)
+    bg = np.full((h, w, 3), bg_color_np, dtype=np.uint8)
 
     # Compute the difference
     diff = cv2.absdiff(image, bg)
+    mask = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY)
 
-    # Any difference greater than 1 becomes 255
-    _, diff = cv2.threshold(diff, 1, 255, cv2.THRESH_BINARY)
+    threshold = 15  # pixels close to the background colour are cropped away too
+    _, diff = cv2.threshold(mask, threshold, 255, cv2.THRESH_BINARY)
 
-    # Convert the difference to grayscale
-    gray_diff = cv2.cvtColor(diff, cv2.COLOR_RGB2GRAY)
     # Compute the minimal bounding rectangle of the non-zero pixels
-    x, y, w, h = cv2.boundingRect(gray_diff)
+    x, y, w, h = cv2.boundingRect(diff)
 
     # Crop the image
     trimmed_image = image[y:y+h, x:x+w]
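Read together, the new trim_white_border estimates the background colour from the four corner pixels instead of assuming pure white, thresholds the greyscale difference against that background at 15, and crops to the bounding rectangle of everything that differs from it. A sketch of the whole function as it should read after this commit, reconstructed from the hunk above; the docstring and the final return are assumptions, since the diff ends at the crop:

import cv2
import numpy as np
from collections import Counter

def trim_white_border(image: np.ndarray) -> np.ndarray:
    """Crop the border whose colour matches the dominant corner colour."""
    if image.dtype != np.uint8:
        raise ValueError("Image should be stored in uint8")

    # Vote on the background colour using the four corner pixels.
    corners = [tuple(image[0, 0]), tuple(image[0, -1]),
               tuple(image[-1, 0]), tuple(image[-1, -1])]
    bg_color = Counter(corners).most_common(1)[0][0]
    bg_color_np = np.array(bg_color, dtype=np.uint8)

    # Solid background of the same size, filled with that colour.
    h, w = image.shape[:2]
    bg = np.full((h, w, 3), bg_color_np, dtype=np.uint8)

    # Greyscale distance from the background; anything within the
    # threshold of the background colour is treated as border.
    diff = cv2.absdiff(image, bg)
    mask = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(mask, 15, 255, cv2.THRESH_BINARY)

    # Bounding rectangle of the non-background pixels, then crop.
    x, y, w, h = cv2.boundingRect(binary)
    return image[y:y+h, x:x+w]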