加入了输入图片的最小宽和高的过滤,防止注入垃圾数据

This commit is contained in:
三洋三洋
2024-02-02 05:40:26 +00:00
parent ab1a05bf32
commit 274fd6cdda
2 changed files with 6 additions and 1 deletions

View File

@@ -37,6 +37,10 @@ MAX_TOKEN_SIZE = 512 # 模型最长的embedding长度被设置成了512
MAX_RESIZE_RATIO = 1.15
MIN_RESIZE_RATIO = 0.75
# ocr模型输入的图片要求的最低宽和高(过滤垃圾数据)
MIN_HEIGHT = 12
MIN_WIDTH = 30
# ============================================================================= #

View File

@@ -11,7 +11,7 @@ from .training_args import CONFIG
from ..model.TexTeller import TexTeller
from ..utils.functional import tokenize_fn, collate_fn, img_transform_fn
from ..utils.metrics import bleu_metric
from ....globals import MAX_TOKEN_SIZE
from ....globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT
def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
@@ -68,6 +68,7 @@ if __name__ == '__main__':
'/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
'cleaned_formulas'
)['train']
dataset = dataset.filter(lambda x: x['image'].height > MIN_HEIGHT and x['image'].width > MIN_WIDTH)
dataset = dataset.shuffle(seed=42)
dataset = dataset.flatten_indices()