完成了TexTellerv2的训练(不支持自然场景)

This commit is contained in:
三洋三洋
2024-03-13 02:21:02 +00:00
parent 93979bddf6
commit a42df1510f
2 changed files with 14 additions and 15 deletions

View File

@@ -1,5 +1,4 @@
import os import os
import numpy as np
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
@@ -16,17 +15,17 @@ from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT
def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer): def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
training_args = TrainingArguments(**CONFIG) training_args = TrainingArguments(**CONFIG)
debug_mode = True debug_mode = False
if debug_mode: if debug_mode:
training_args.auto_find_batch_size = False training_args.auto_find_batch_size = False
training_args.num_train_epochs = 2 training_args.num_train_epochs = 2
# training_args.per_device_train_batch_size = 3
training_args.per_device_train_batch_size = 2 training_args.per_device_train_batch_size = 2
training_args.per_device_eval_batch_size = 2 * training_args.per_device_train_batch_size training_args.per_device_eval_batch_size = 2 * training_args.per_device_train_batch_size
training_args.jit_mode_eval = False training_args.jit_mode_eval = False
training_args.torch_compile = False training_args.torch_compile = False
training_args.dataloader_num_workers = 1 training_args.dataloader_num_workers = 1
trainer = Trainer( trainer = Trainer(
model, model,
training_args, training_args,
@@ -38,7 +37,8 @@ def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokeniz
data_collator=collate_fn_with_tokenizer, data_collator=collate_fn_with_tokenizer,
) )
trainer.train(resume_from_checkpoint=None) # trainer.train(resume_from_checkpoint=None)
trainer.train(resume_from_checkpoint='/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv2/checkpoint-64000')
def evaluate(model, tokenizer, eval_dataset, collate_fn): def evaluate(model, tokenizer, eval_dataset, collate_fn):
@@ -93,8 +93,8 @@ if __name__ == '__main__':
split_dataset = tokenized_dataset.train_test_split(test_size=0.005, seed=42) split_dataset = tokenized_dataset.train_test_split(test_size=0.005, seed=42)
train_dataset, eval_dataset = split_dataset['train'], split_dataset['test'] train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer) collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
model = TexTeller() # model = TexTeller()
# model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/train_with_random_resize/checkpoint-80000') model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv2/checkpoint-64000')
enable_train = True enable_train = True
enable_evaluate = True enable_evaluate = True
@@ -103,5 +103,4 @@ if __name__ == '__main__':
if enable_evaluate: if enable_evaluate:
evaluate(model, tokenizer, eval_dataset, collate_fn_with_tokenizer) evaluate(model, tokenizer, eval_dataset, collate_fn_with_tokenizer)
os.chdir(cur_path) os.chdir(cur_path)

View File

@@ -16,16 +16,16 @@ CONFIG = {
#+通常与eval_steps一致 #+通常与eval_steps一致
"logging_nan_inf_filter": False, # 对loss=nan或inf进行记录 "logging_nan_inf_filter": False, # 对loss=nan或inf进行记录
"num_train_epochs": 10, # 总的训练轮数 "num_train_epochs": 2, # 总的训练轮数
# "max_steps": 3, # 训练的最大步骤数。如果设置了这个参数, # "max_steps": 3, # 训练的最大步骤数。如果设置了这个参数,
#+那么num_train_epochs将被忽略通常用于调试 #+那么num_train_epochs将被忽略通常用于调试
# "label_names": ['your_label_name'], # 指定data_loader中的标签名如果不指定则默认为'labels' # "label_names": ['your_label_name'], # 指定data_loader中的标签名如果不指定则默认为'labels'
"per_device_train_batch_size": 64, # 每个GPU的batch size "per_device_train_batch_size": 3, # 每个GPU的batch size
"per_device_eval_batch_size": 16, # 每个GPU的evaluation batch size "per_device_eval_batch_size": 6, # 每个GPU的evaluation batch size
# "auto_find_batch_size": True, # 自动搜索合适的batch size指数decay
"auto_find_batch_size": True, # 自动搜索合适的batch size指数decay "auto_find_batch_size": True, # 自动搜索合适的batch size指数decay
# "auto_find_batch_size": False, # 自动搜索合适的batch size指数decay
"optim": "adamw_torch", # 还提供了很多AdamW的变体相较于经典的AdamW更加高效 "optim": "adamw_torch", # 还提供了很多AdamW的变体相较于经典的AdamW更加高效
#+当设置了optim后就不需要在Trainer中传入optimizer #+当设置了optim后就不需要在Trainer中传入optimizer