From ef7cccff03e4d8a6585457256a6c927956069488 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=89=E6=B4=8B=E4=B8=89=E6=B4=8B?= <1258009915@qq.com> Date: Mon, 25 Mar 2024 11:46:43 +0000 Subject: [PATCH] TexTellerv2 --- .gitignore | 3 ++- src/models/ocr_model/model/TexTeller.py | 34 ++++++++++++------------- src/models/ocr_model/train/train.py | 6 ++--- src/models/ocr_model/utils/metrics.py | 2 +- 4 files changed, 23 insertions(+), 22 deletions(-) diff --git a/.gitignore b/.gitignore index 787a849..789f903 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ **/logs **/.cache **/tmp* -**/data \ No newline at end of file +**/data +**/ckpt \ No newline at end of file diff --git a/src/models/ocr_model/model/TexTeller.py b/src/models/ocr_model/model/TexTeller.py index 1a521fa..adad913 100644 --- a/src/models/ocr_model/model/TexTeller.py +++ b/src/models/ocr_model/model/TexTeller.py @@ -1,4 +1,3 @@ -from PIL import Image from pathlib import Path from models.globals import ( @@ -43,22 +42,23 @@ class TexTeller(VisionEncoderDecoderModel): if __name__ == "__main__": # texteller = TexTeller() - from ..utils.inference import inference - model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-57500') - tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas') + # from ..utils.inference import inference + # model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/model/ckpt') + # model.save_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/model/ckpt2', safe_serialization=False) + # tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas') - base = '/home/lhy/code/TeXify/src/models/ocr_model/model' - imgs_path = [ - # base + '/1.jpg', - # base + '/2.jpg', - # base + '/3.jpg', - # base + '/4.jpg', - # base + '/5.jpg', - # base + '/6.jpg', - base + '/foo.jpg' - ] + # base = '/home/lhy/code/TeXify/src/models/ocr_model/model' + # imgs_path = [ + # # base + '/1.jpg', + # # base + '/2.jpg', + # # base + '/3.jpg', + # # base + '/4.jpg', + # # base + '/5.jpg', + # # base + '/6.jpg', + # base + '/foo.jpg' + # ] - # res = inference(model, [img1, img2, img3, img4, img5, img6, img7], tokenizer) - res = inference(model, imgs_path, tokenizer) - pause = 1 + # # res = inference(model, [img1, img2, img3, img4, img5, img6, img7], tokenizer) + # res = inference(model, imgs_path, tokenizer) + # pause = 1 diff --git a/src/models/ocr_model/train/train.py b/src/models/ocr_model/train/train.py index 6ab8814..e110c34 100644 --- a/src/models/ocr_model/train/train.py +++ b/src/models/ocr_model/train/train.py @@ -38,7 +38,7 @@ def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokeniz ) # trainer.train(resume_from_checkpoint=None) - trainer.train(resume_from_checkpoint='/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv2/checkpoint-64000') + trainer.train(resume_from_checkpoint='/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv2/checkpoint-288000') def evaluate(model, tokenizer, eval_dataset, collate_fn): @@ -94,9 +94,9 @@ if __name__ == '__main__': train_dataset, eval_dataset = split_dataset['train'], split_dataset['test'] collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer) # model = TexTeller() - model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv2/checkpoint-64000') + model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv2/checkpoint-588000') - enable_train = True + enable_train = False enable_evaluate = True if enable_train: train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer) diff --git a/src/models/ocr_model/utils/metrics.py b/src/models/ocr_model/utils/metrics.py index a21d131..c876af0 100644 --- a/src/models/ocr_model/utils/metrics.py +++ b/src/models/ocr_model/utils/metrics.py @@ -4,7 +4,7 @@ from transformers import EvalPrediction, RobertaTokenizer from typing import Dict def bleu_metric(eval_preds:EvalPrediction, tokenizer:RobertaTokenizer) -> Dict: - metric = evaluate.load('/home/lhy/code/TeXify/src/models/ocr_model/train/google_bleu/google_bleu.py') # 这里需要联网,所以会卡住 + metric = evaluate.load('/home/lhy/code/TexTeller/src/models/ocr_model/train/google_bleu') # 这里需要联网,所以会卡住 logits, labels = eval_preds.predictions, eval_preds.label_ids preds = logits