加入了输入图片的最小宽和高的过滤,防止注入垃圾数据
This commit is contained in:
@@ -37,6 +37,10 @@ MAX_TOKEN_SIZE = 512 # 模型最长的embedding长度被设置成了512,
|
||||
MAX_RESIZE_RATIO = 1.15
|
||||
MIN_RESIZE_RATIO = 0.75
|
||||
|
||||
# ocr模型输入的图片要求的最低宽和高(过滤垃圾数据)
|
||||
MIN_HEIGHT = 12
|
||||
MIN_WIDTH = 30
|
||||
|
||||
# ============================================================================= #
|
||||
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from .training_args import CONFIG
|
||||
from ..model.TexTeller import TexTeller
|
||||
from ..utils.functional import tokenize_fn, collate_fn, img_transform_fn
|
||||
from ..utils.metrics import bleu_metric
|
||||
from ....globals import MAX_TOKEN_SIZE
|
||||
from ....globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT
|
||||
|
||||
|
||||
def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
|
||||
@@ -68,6 +68,7 @@ if __name__ == '__main__':
|
||||
'/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
|
||||
'cleaned_formulas'
|
||||
)['train']
|
||||
dataset = dataset.filter(lambda x: x['image'].height > MIN_HEIGHT and x['image'].width > MIN_WIDTH)
|
||||
dataset = dataset.shuffle(seed=42)
|
||||
dataset = dataset.flatten_indices()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user