写好了v3版本的训练代码(v3版本加入了自然场景训练增强)
This commit is contained in:
@@ -8,14 +8,14 @@ from transformers import Trainer, TrainingArguments, Seq2SeqTrainer, Seq2SeqTrai
|
||||
|
||||
from .training_args import CONFIG
|
||||
from ..model.TexTeller import TexTeller
|
||||
from ..utils.functional import tokenize_fn, collate_fn, img_transform_fn, filter_fn
|
||||
from ..utils.functional import tokenize_fn, collate_fn, img_train_transform, img_inf_transform, filter_fn
|
||||
from ..utils.metrics import bleu_metric
|
||||
from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT
|
||||
|
||||
|
||||
def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
|
||||
training_args = TrainingArguments(**CONFIG)
|
||||
debug_mode = True
|
||||
debug_mode = False
|
||||
if debug_mode:
|
||||
training_args.auto_find_batch_size = False
|
||||
training_args.num_train_epochs = 2
|
||||
@@ -88,16 +88,20 @@ if __name__ == '__main__':
|
||||
|
||||
map_fn = partial(tokenize_fn, tokenizer=tokenizer)
|
||||
tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8, load_from_cache_file=True)
|
||||
tokenized_dataset = tokenized_dataset.with_transform(img_transform_fn)
|
||||
|
||||
split_dataset = tokenized_dataset.train_test_split(test_size=0.005, seed=42)
|
||||
train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
|
||||
|
||||
train_dataset = train_dataset.with_transform(img_train_transform)
|
||||
eval_dataset = eval_dataset.with_transform(img_inf_transform)
|
||||
|
||||
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
|
||||
# model = TexTeller()
|
||||
model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/model/ckpt')
|
||||
|
||||
# ================= debug =======================
|
||||
foo = train_dataset[:3]
|
||||
# foo = train_dataset[:50]
|
||||
# bar = eval_dataset[:50]
|
||||
# ================= debug =======================
|
||||
|
||||
enable_train = True
|
||||
|
||||
Reference in New Issue
Block a user