Completed training of TexTellerv2 (natural scenes not supported)
@@ -1,5 +1,4 @@
 import os
-import numpy as np
 
 from functools import partial
 from pathlib import Path
@@ -16,17 +15,17 @@ from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT
 
 def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
     training_args = TrainingArguments(**CONFIG)
-    debug_mode = True
+    debug_mode = False
     if debug_mode:
         training_args.auto_find_batch_size = False
         training_args.num_train_epochs = 2
+        # training_args.per_device_train_batch_size = 3
         training_args.per_device_train_batch_size = 2
         training_args.per_device_eval_batch_size = 2 * training_args.per_device_train_batch_size
         training_args.jit_mode_eval = False
         training_args.torch_compile = False
         training_args.dataloader_num_workers = 1
 
 
     trainer = Trainer(
         model,
         training_args,
@@ -38,7 +37,8 @@ def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokeniz
         data_collator=collate_fn_with_tokenizer,
     )
 
-    trainer.train(resume_from_checkpoint=None)
+    # trainer.train(resume_from_checkpoint=None)
+    trainer.train(resume_from_checkpoint='/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv2/checkpoint-64000')
 
 
 def evaluate(model, tokenizer, eval_dataset, collate_fn):
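For context on the resume change above: pointing resume_from_checkpoint at a saved checkpoint directory makes the Hugging Face Trainer restore state and continue rather than start fresh. A minimal sketch of that flow; the output path and argument values here are placeholders, not the repo's full CONFIG:

# Minimal sketch: resume training from an existing checkpoint directory.
from transformers import Trainer, TrainingArguments

def run_training(model, train_dataset, eval_dataset, collate_fn, resume_dir=None):
    args = TrainingArguments(
        output_dir="train_result/TexTellerv2",  # checkpoints land in output_dir/checkpoint-<step>
        num_train_epochs=2,
        per_device_train_batch_size=3,
        per_device_eval_batch_size=6,
    )
    trainer = Trainer(
        model,
        args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=collate_fn,
    )
    # With a path, Trainer restores model weights, optimizer and scheduler
    # state, and the global step, so training continues where it stopped.
    trainer.train(resume_from_checkpoint=resume_dir)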
@@ -93,15 +93,14 @@ if __name__ == '__main__':
     split_dataset = tokenized_dataset.train_test_split(test_size=0.005, seed=42)
     train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
     collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
-    model = TexTeller()
-    # model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/train_with_random_resize/checkpoint-80000')
+    # model = TexTeller()
+    model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv2/checkpoint-64000')
 
     enable_train = True
     enable_evaluate = True
     if enable_train:
         train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
     if enable_evaluate:
         evaluate(model, tokenizer, eval_dataset, collate_fn_with_tokenizer)
 
 
     os.chdir(cur_path)
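As an aside, the train_test_split / partial wiring in the __main__ block is standard Hugging Face datasets usage. A self-contained toy version, with stand-in data and a placeholder collator rather than the repo's real ones:

# Toy illustration of the dataset split and collator binding used above.
from functools import partial
from datasets import Dataset

def collate_fn(batch, tokenizer=None):
    # Placeholder: the real collator pads image tensors and LaTeX token ids.
    return batch

dataset = Dataset.from_dict({"latex": [r"\frac{a}{b}", r"x^{2}", r"\sqrt{y}"]})
split = dataset.train_test_split(test_size=0.34, seed=42)  # the repo uses test_size=0.005
train_ds, eval_ds = split["train"], split["test"]
# partial binds the tokenizer once; Trainer then calls collate(batch) directly.
collate = partial(collate_fn, tokenizer=None)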
@@ -16,16 +16,16 @@ CONFIG = {
     #+usually kept consistent with eval_steps
     "logging_nan_inf_filter": False,  # record loss values that are nan or inf
 
-    "num_train_epochs": 10,  # total number of training epochs
+    "num_train_epochs": 2,  # total number of training epochs
     # "max_steps": 3,  # maximum number of training steps; if this is set,
     #+num_train_epochs is ignored (usually used for debugging)
 
     # "label_names": ['your_label_name'],  # name of the label field in the data_loader; defaults to 'labels' if unset
 
-    "per_device_train_batch_size": 64,  # batch size per GPU
-    "per_device_eval_batch_size": 16,  # evaluation batch size per GPU
+    "per_device_train_batch_size": 3,  # batch size per GPU
+    "per_device_eval_batch_size": 6,  # evaluation batch size per GPU
+    # "auto_find_batch_size": True,  # automatically search for a workable batch size (exponential decay)
     "auto_find_batch_size": True,  # automatically search for a workable batch size (exponential decay)
-    # "auto_find_batch_size": False,  # automatically search for a workable batch size (exponential decay)
 
     "optim": "adamw_torch",  # many AdamW variants are also available (more efficient than classic AdamW)
     #+once optim is set, no optimizer needs to be passed to the Trainer
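Since the config leans on auto_find_batch_size, a note on what it does: transformers wraps the inner training loop with accelerate's retry decorator, which lowers the batch size after a CUDA out-of-memory error, the "exponential decay" the comment mentions. A standalone toy version of the same mechanism:

# Toy illustration of the auto_find_batch_size behavior enabled above;
# the Trainer does this internally via accelerate.
from accelerate.utils import find_executable_batch_size

@find_executable_batch_size(starting_batch_size=64)
def train_loop(batch_size):
    # First call runs with batch_size=64; on CUDA OOM the decorator frees
    # cached memory, reduces batch_size, and retries until a size fits.
    print(f"trying batch_size={batch_size}")
    # ... build the dataloader and run the training loop here ...

train_loop()  # batch_size is injected by the decorator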