From 274fd6cdda3c9b9ddee140b38ed6f09e6c7b042d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=89=E6=B4=8B=E4=B8=89=E6=B4=8B?= <1258009915@qq.com> Date: Fri, 2 Feb 2024 05:40:26 +0000 Subject: [PATCH] =?UTF-8?q?=E5=8A=A0=E5=85=A5=E4=BA=86=E8=BE=93=E5=85=A5?= =?UTF-8?q?=E5=9B=BE=E7=89=87=E7=9A=84=E6=9C=80=E5=B0=8F=E5=AE=BD=E5=92=8C?= =?UTF-8?q?=E9=AB=98=E7=9A=84=E8=BF=87=E6=BB=A4=EF=BC=8C=E9=98=B2=E6=AD=A2?= =?UTF-8?q?=E6=B3=A8=E5=85=A5=E5=9E=83=E5=9C=BE=E6=95=B0=E6=8D=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/globals.py | 4 ++++ src/models/ocr_model/train/train.py | 3 ++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/globals.py b/src/globals.py index 96afec9..664af8d 100644 --- a/src/globals.py +++ b/src/globals.py @@ -37,6 +37,10 @@ MAX_TOKEN_SIZE = 512 # 模型最长的embedding长度被设置成了512, MAX_RESIZE_RATIO = 1.15 MIN_RESIZE_RATIO = 0.75 +# ocr模型输入的图片要求的最低宽和高(过滤垃圾数据) +MIN_HEIGHT = 12 +MIN_WIDTH = 30 + # ============================================================================= # diff --git a/src/models/ocr_model/train/train.py b/src/models/ocr_model/train/train.py index 01d6b6e..d1039ed 100644 --- a/src/models/ocr_model/train/train.py +++ b/src/models/ocr_model/train/train.py @@ -11,7 +11,7 @@ from .training_args import CONFIG from ..model.TexTeller import TexTeller from ..utils.functional import tokenize_fn, collate_fn, img_transform_fn from ..utils.metrics import bleu_metric -from ....globals import MAX_TOKEN_SIZE +from ....globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer): @@ -68,6 +68,7 @@ if __name__ == '__main__': '/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py', 'cleaned_formulas' )['train'] + dataset = dataset.filter(lambda x: x['image'].height > MIN_HEIGHT and x['image'].width > MIN_WIDTH) dataset = dataset.shuffle(seed=42) dataset = dataset.flatten_indices()