diff --git a/src/models/ocr_model/train/train.py b/src/models/ocr_model/train/train.py index 40652ee..6ab8814 100644 --- a/src/models/ocr_model/train/train.py +++ b/src/models/ocr_model/train/train.py @@ -1,5 +1,4 @@ import os -import numpy as np from functools import partial from pathlib import Path @@ -16,16 +15,16 @@ from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer): training_args = TrainingArguments(**CONFIG) - debug_mode = True + debug_mode = False if debug_mode: training_args.auto_find_batch_size = False training_args.num_train_epochs = 2 + # training_args.per_device_train_batch_size = 3 training_args.per_device_train_batch_size = 2 training_args.per_device_eval_batch_size = 2 * training_args.per_device_train_batch_size training_args.jit_mode_eval = False training_args.torch_compile = False training_args.dataloader_num_workers = 1 - trainer = Trainer( model, @@ -38,7 +37,8 @@ def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokeniz data_collator=collate_fn_with_tokenizer, ) - trainer.train(resume_from_checkpoint=None) + # trainer.train(resume_from_checkpoint=None) + trainer.train(resume_from_checkpoint='/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv2/checkpoint-64000') def evaluate(model, tokenizer, eval_dataset, collate_fn): @@ -90,18 +90,17 @@ if __name__ == '__main__': tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8, load_from_cache_file=True) tokenized_dataset = tokenized_dataset.with_transform(img_transform_fn) - split_dataset = tokenized_dataset.train_test_split(test_size=0.005, seed=42) + split_dataset = tokenized_dataset.train_test_split(test_size=0.005, seed=42) train_dataset, eval_dataset = split_dataset['train'], split_dataset['test'] collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer) - model = TexTeller() - # model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/train_with_random_resize/checkpoint-80000') + # model = TexTeller() + model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv2/checkpoint-64000') - enable_train = True + enable_train = True enable_evaluate = True if enable_train: - train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer) + train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer) if enable_evaluate: evaluate(model, tokenizer, eval_dataset, collate_fn_with_tokenizer) - - os.chdir(cur_path) \ No newline at end of file + os.chdir(cur_path) diff --git a/src/models/ocr_model/train/training_args.py b/src/models/ocr_model/train/training_args.py index e824459..73055b7 100644 --- a/src/models/ocr_model/train/training_args.py +++ b/src/models/ocr_model/train/training_args.py @@ -16,16 +16,16 @@ CONFIG = { #+通常与eval_steps一致 "logging_nan_inf_filter": False, # 对loss=nan或inf进行记录 - "num_train_epochs": 10, # 总的训练轮数 + "num_train_epochs": 2, # 总的训练轮数 # "max_steps": 3, # 训练的最大步骤数。如果设置了这个参数, #+那么num_train_epochs将被忽略(通常用于调试) # "label_names": ['your_label_name'], # 指定data_loader中的标签名,如果不指定则默认为'labels' - "per_device_train_batch_size": 64, # 每个GPU的batch size - "per_device_eval_batch_size": 16, # 每个GPU的evaluation batch size + "per_device_train_batch_size": 3, # 每个GPU的batch size + "per_device_eval_batch_size": 6, # 每个GPU的evaluation batch size + # "auto_find_batch_size": True, # 自动搜索合适的batch size(指数decay) "auto_find_batch_size": True, # 自动搜索合适的batch size(指数decay) - # "auto_find_batch_size": False, # 自动搜索合适的batch size(指数decay) "optim": "adamw_torch", # 还提供了很多AdamW的变体(相较于经典的AdamW更加高效) #+当设置了optim后,就不需要在Trainer中传入optimizer