2024-03-25 16:35:34 +08:00
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
from functools import partial
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
from datasets import load_dataset
|
2024-01-30 08:36:23 +00:00
|
|
|
from transformers import Trainer, TrainingArguments, Seq2SeqTrainer, Seq2SeqTrainingArguments, GenerationConfig
|
2024-03-25 16:35:34 +08:00
|
|
|
|
|
|
|
|
from .training_args import CONFIG
|
|
|
|
|
from ..model.TexTeller import TexTeller
|
2024-03-28 10:19:40 +00:00
|
|
|
from ..utils.functional import tokenize_fn, collate_fn, img_train_transform, img_inf_transform, filter_fn
|
2024-03-25 16:35:34 +08:00
|
|
|
from ..utils.metrics import bleu_metric
|
2024-03-28 13:44:32 +00:00
|
|
|
from ...globals import MAX_TOKEN_SIZE
|
2024-03-25 16:35:34 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
# Checkpoint that training resumes from when the caller does not specify one.
# Kept identical to the path this function previously hard-coded so behaviour is unchanged.
_DEFAULT_RESUME_CHECKPOINT = '/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-644000'


def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer,
          *, resume_from_checkpoint=_DEFAULT_RESUME_CHECKPOINT, debug_mode=False):
    """Train ``model`` with a Hugging Face ``Trainer`` configured from ``CONFIG``.

    Args:
        model: Model to train; passed positionally to ``Trainer``.
        tokenizer: Tokenizer passed to ``Trainer`` via ``tokenizer=``.
        train_dataset: Dataset used for training.
        eval_dataset: Dataset used for evaluation during training.
        collate_fn_with_tokenizer: Callable used as the ``Trainer`` data collator.
        resume_from_checkpoint: Checkpoint to resume from, forwarded to
            ``Trainer.train``. Pass ``None`` to start from the model's current
            weights without restoring trainer state. Defaults to the checkpoint
            this function previously hard-coded.
        debug_mode: When True, override ``CONFIG`` for a short smoke-test run:
            2 epochs, tiny fixed batch sizes, no JIT/compile, one dataloader
            worker.
    """
    training_args = TrainingArguments(**CONFIG)

    if debug_mode:
        # A fixed, tiny batch size replaces auto-search so runs are quick and reproducible.
        training_args.auto_find_batch_size = False
        training_args.num_train_epochs = 2
        training_args.per_device_train_batch_size = 2
        training_args.per_device_eval_batch_size = 2 * training_args.per_device_train_batch_size
        # Disable JIT and torch.compile to avoid long compilation times while debugging.
        training_args.jit_mode_eval = False
        training_args.torch_compile = False
        training_args.dataloader_num_workers = 1

    trainer = Trainer(
        model,
        training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=collate_fn_with_tokenizer,
    )

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
2024-03-25 16:35:34 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate(model, tokenizer, eval_dataset, collate_fn):
    """Evaluate ``model`` with greedy generation and the BLEU metric.

    Builds ``Seq2SeqTrainingArguments`` from a copy of ``CONFIG`` with
    ``predict_with_generate`` enabled, runs ``Seq2SeqTrainer.evaluate`` on
    ``eval_dataset`` and prints the resulting metrics.

    Args:
        model: Model to evaluate.
        tokenizer: Tokenizer providing the pad/eos/bos token ids and used by
            ``bleu_metric``.
        eval_dataset: Dataset to evaluate on.
        collate_fn: Data collator for the trainer. Note that this parameter
            shadows the module-level ``collate_fn`` import inside this function.

    Returns:
        The metrics returned by ``Seq2SeqTrainer.evaluate()``. They are also
        printed, as before.
    """
    # Shallow copy: only top-level keys are overwritten below, so the shared CONFIG stays untouched.
    eval_config = CONFIG.copy()
    eval_config['predict_with_generate'] = True
    eval_config['generation_config'] = GenerationConfig(
        # NOTE(review): the -100 margin below MAX_TOKEN_SIZE is unexplained here — TODO confirm intent.
        max_length=MAX_TOKEN_SIZE - 100,
        # Greedy decoding: a single beam with no sampling.
        num_beams=1,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
    )
    eval_config['auto_find_batch_size'] = False
    seq2seq_config = Seq2SeqTrainingArguments(**eval_config)

    trainer = Seq2SeqTrainer(
        model,
        seq2seq_config,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=collate_fn,
        compute_metrics=partial(bleu_metric, tokenizer=tokenizer),
    )

    res = trainer.evaluate()
    print(res)
    return res
|
2024-03-25 16:35:34 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Run from the script's own directory so relative paths resolve there,
    # and always restore the caller's working directory, even on failure.
    cur_path = os.getcwd()
    script_dirpath = Path(__file__).resolve().parent
    os.chdir(script_dirpath)
    try:
        dataset = load_dataset(
            '/home/lhy/code/TexTeller/src/models/ocr_model/train/data/loader.py'
        )['train']
        tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TexTeller/src/models/tokenizer/roberta-tokenizer-7Mformulas')

        # Drop samples rejected by filter_fn (which receives the tokenizer), then shuffle deterministically.
        filter_fn_with_tokenizer = partial(filter_fn, tokenizer=tokenizer)
        dataset = dataset.filter(filter_fn_with_tokenizer, num_proc=16)
        dataset = dataset.shuffle(seed=42)
        # NOTE(review): presumably materialises the shuffled order before the heavy map() — confirm.
        dataset = dataset.flatten_indices()

        # Tokenize in batches and keep only the columns produced by tokenize_fn.
        map_fn = partial(tokenize_fn, tokenizer=tokenizer)
        tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8, load_from_cache_file=True)

        # Hold out 0.5% of the data for evaluation.
        split_dataset = tokenized_dataset.train_test_split(test_size=0.005, seed=42)
        train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']

        # Image transforms are applied lazily on access: the train set gets its own transform, eval uses the inference one.
        train_dataset = train_dataset.with_transform(img_train_transform)
        eval_dataset = eval_dataset.with_transform(img_inf_transform)

        collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)

        model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-644000')

        enable_train = True
        enable_evaluate = True
        if enable_train:
            train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
        if enable_evaluate:
            evaluate(model, tokenizer, eval_dataset, collate_fn_with_tokenizer)
    finally:
        os.chdir(cur_path)
|