diff --git a/src/models/ocr_model/train/train.py b/src/models/ocr_model/train/train.py
index 0e2979d..9d37f44 100644
--- a/src/models/ocr_model/train/train.py
+++ b/src/models/ocr_model/train/train.py
@@ -14,7 +14,7 @@ from transformers import (
 
 from .training_args import CONFIG
 from ..model.TexTeller import TexTeller
-from ..utils.functional import tokenize_fn, collate_fn, img_transform_fn
+from ..utils.functional import tokenize_fn, collate_fn, img_train_transform, img_inf_transform, filter_fn
 from ..utils.metrics import bleu_metric
 from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT
 
@@ -75,14 +75,20 @@ if __name__ == '__main__':
     tokenizer = TexTeller.get_tokenizer()
     # If you want use your own tokenizer, please modify the path to your tokenizer
     #+tokenizer = TexTeller.get_tokenizer('/path/to/your/tokenizer')
+    filter_fn_with_tokenizer = partial(filter_fn, tokenizer=tokenizer)
+    dataset = dataset.filter(
+        filter_fn_with_tokenizer,
+        num_proc=8
+    )
 
     map_fn = partial(tokenize_fn, tokenizer=tokenizer)
     tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8)
-    tokenized_dataset = tokenized_dataset.with_transform(img_transform_fn)
 
     # Split dataset into train and eval, ratio 9:1
     split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
     train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
+    train_dataset = train_dataset.with_transform(img_train_transform)
+    eval_dataset = eval_dataset.with_transform(img_inf_transform)
 
     collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
     # Train from scratch
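
For context, here is a minimal sketch of what the newly imported `filter_fn` might look like in `..utils.functional`, inferred from the `MAX_TOKEN_SIZE`, `MIN_WIDTH`, and `MIN_HEIGHT` imports already present in `train.py`. The column names (`image`, `latex_formula`) and the exact checks are assumptions for illustration, not the repository's actual implementation:

```python
# Hypothetical sketch of filter_fn (assumed, not the repo's actual code).
# Drops samples whose formula would exceed the model's max token length,
# or whose image is too small to be useful for training.
from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT


def filter_fn(example: dict, tokenizer) -> bool:
    # Assumed dataset columns: 'image' (a PIL.Image) and 'latex_formula' (str).
    token_count = len(tokenizer(example['latex_formula'])['input_ids'])
    image = example['image']
    return (
        token_count <= MAX_TOKEN_SIZE
        and image.width >= MIN_WIDTH
        and image.height >= MIN_HEIGHT
    )
```

Two design choices in the diff are worth noting. First, `dataset.filter(...)` runs before `dataset.map(...)`, so oversized or undersized samples never reach tokenization. Second, replacing the single `img_transform_fn` with separate `img_train_transform` and `img_inf_transform`, applied after `train_test_split`, suggests that random augmentation is now confined to the training split while the eval split gets a deterministic inference-style pipeline, keeping eval metrics free of augmentation noise.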