bugfix: missing filter_fn and inference/train transform

三洋
2024-05-12 07:49:04 +00:00
parent 249a4d5a5f
commit 0a51bde1c5

@@ -14,7 +14,7 @@ from transformers import (
 from .training_args import CONFIG
 from ..model.TexTeller import TexTeller
-from ..utils.functional import tokenize_fn, collate_fn, img_transform_fn
+from ..utils.functional import tokenize_fn, collate_fn, img_train_transform, img_inf_transform, filter_fn
 from ..utils.metrics import bleu_metric
 from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT
@@ -75,14 +75,20 @@ if __name__ == '__main__':
     tokenizer = TexTeller.get_tokenizer()
     # If you want use your own tokenizer, please modify the path to your tokenizer
     # tokenizer = TexTeller.get_tokenizer('/path/to/your/tokenizer')
+    filter_fn_with_tokenizer = partial(filter_fn, tokenizer=tokenizer)
+    dataset = dataset.filter(
+        filter_fn_with_tokenizer,
+        num_proc=8
+    )
     map_fn = partial(tokenize_fn, tokenizer=tokenizer)
     tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8)
-    tokenized_dataset = tokenized_dataset.with_transform(img_transform_fn)
     # Split dataset into train and eval, ratio 9:1
     split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
     train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
+    train_dataset = train_dataset.with_transform(img_train_transform)
+    eval_dataset = eval_dataset.with_transform(img_inf_transform)
     collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
     # Train from scratch
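
For reference, filter_fn is imported together with MAX_TOKEN_SIZE, MIN_WIDTH, and MIN_HEIGHT, which suggests it drops samples whose image is too small or whose formula tokenizes past the model's limit. A minimal sketch of that idea, not the repo's actual implementation in ..utils.functional (the 'image'/'latex' column names and the constant values are assumptions):

from PIL import Image

# Assumed stand-ins for the values imported from ...globals
MAX_TOKEN_SIZE = 1024
MIN_WIDTH, MIN_HEIGHT = 32, 32

def filter_fn(example: dict, tokenizer) -> bool:
    # Reject images below the minimum usable size.
    img: Image.Image = example['image']  # column name is an assumption
    if img.width < MIN_WIDTH or img.height < MIN_HEIGHT:
        return False
    # Reject formulas whose tokenized length exceeds the model's limit.
    return len(tokenizer(example['latex'])['input_ids']) <= MAX_TOKEN_SIZE

dataset.filter(partial(filter_fn, tokenizer=tokenizer), num_proc=8) then evaluates this predicate across 8 worker processes, as in the hunk above.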
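Splitting the old img_transform_fn into img_train_transform and img_inf_transform is the other half of the fix: with_transform applies its function lazily per batch, so the training split can receive stochastic augmentation while the eval split stays deterministic. A sketch of that shape (the torchvision pipeline and the 'image'/'pixel_values' column names are illustrative assumptions, not the actual functions in ..utils.functional):

import torchvision.transforms as T

# Hypothetical pipelines: stochastic augmentation for training,
# a deterministic pipeline for evaluation/inference.
_train_pipeline = T.Compose([
    T.RandomApply([T.GaussianBlur(kernel_size=3)], p=0.3),
    T.ToTensor(),
])
_inf_pipeline = T.Compose([T.ToTensor()])

def img_train_transform(batch: dict) -> dict:
    # datasets' with_transform hands the function a batch dict on access.
    batch['pixel_values'] = [_train_pipeline(img) for img in batch['image']]
    return batch

def img_inf_transform(batch: dict) -> dict:
    batch['pixel_values'] = [_inf_pipeline(img) for img in batch['image']]
    return batch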