TexTellerv2
This commit is contained in:
@@ -38,7 +38,7 @@ def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokeniz
|
||||
)
|
||||
|
||||
# trainer.train(resume_from_checkpoint=None)
|
||||
trainer.train(resume_from_checkpoint='/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv2/checkpoint-64000')
|
||||
trainer.train(resume_from_checkpoint='/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv2/checkpoint-288000')
|
||||
|
||||
|
||||
def evaluate(model, tokenizer, eval_dataset, collate_fn):
|
||||
@@ -94,9 +94,9 @@ if __name__ == '__main__':
|
||||
train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
|
||||
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
|
||||
# model = TexTeller()
|
||||
model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv2/checkpoint-64000')
|
||||
model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv2/checkpoint-588000')
|
||||
|
||||
enable_train = True
|
||||
enable_train = False
|
||||
enable_evaluate = True
|
||||
if enable_train:
|
||||
train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
|
||||
|
||||
Reference in New Issue
Block a user