From 0a51bde1c505bb63834a99c5beb04c3bee969c57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=89=E6=B4=8B=E4=B8=89=E6=B4=8B?= <1258009915@qq.com> Date: Sun, 12 May 2024 07:49:04 +0000 Subject: [PATCH] bugfix: missing filter_fn and inference/train transform --- src/models/ocr_model/train/train.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/models/ocr_model/train/train.py b/src/models/ocr_model/train/train.py index 0e2979d..9d37f44 100644 --- a/src/models/ocr_model/train/train.py +++ b/src/models/ocr_model/train/train.py @@ -14,7 +14,7 @@ from transformers import ( from .training_args import CONFIG from ..model.TexTeller import TexTeller -from ..utils.functional import tokenize_fn, collate_fn, img_transform_fn +from ..utils.functional import tokenize_fn, collate_fn, img_train_transform, img_inf_transform, filter_fn from ..utils.metrics import bleu_metric from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT @@ -75,14 +75,20 @@ if __name__ == '__main__': tokenizer = TexTeller.get_tokenizer() # If you want use your own tokenizer, please modify the path to your tokenizer #+tokenizer = TexTeller.get_tokenizer('/path/to/your/tokenizer') + filter_fn_with_tokenizer = partial(filter_fn, tokenizer=tokenizer) + dataset = dataset.filter( + filter_fn_with_tokenizer, + num_proc=8 + ) map_fn = partial(tokenize_fn, tokenizer=tokenizer) tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8) - tokenized_dataset = tokenized_dataset.with_transform(img_transform_fn) # Split dataset into train and eval, ratio 9:1 split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42) train_dataset, eval_dataset = split_dataset['train'], split_dataset['test'] + train_dataset = train_dataset.with_transform(img_train_transform) + eval_dataset = eval_dataset.with_transform(img_inf_transform) collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer) # Train from scratch