TexTellerv2

This commit is contained in:
三洋三洋
2024-03-25 11:46:43 +00:00
parent a42df1510f
commit ef7cccff03
4 changed files with 23 additions and 22 deletions

1
.gitignore vendored
View File

@@ -6,3 +6,4 @@
**/.cache **/.cache
**/tmp* **/tmp*
**/data **/data
**/ckpt

View File

@@ -1,4 +1,3 @@
from PIL import Image
from pathlib import Path from pathlib import Path
from models.globals import ( from models.globals import (
@@ -43,22 +42,23 @@ class TexTeller(VisionEncoderDecoderModel):
if __name__ == "__main__": if __name__ == "__main__":
# texteller = TexTeller() # texteller = TexTeller()
from ..utils.inference import inference # from ..utils.inference import inference
model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-57500') # model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/model/ckpt')
tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas') # model.save_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/model/ckpt2', safe_serialization=False)
# tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
base = '/home/lhy/code/TeXify/src/models/ocr_model/model' # base = '/home/lhy/code/TeXify/src/models/ocr_model/model'
imgs_path = [ # imgs_path = [
# base + '/1.jpg', # # base + '/1.jpg',
# base + '/2.jpg', # # base + '/2.jpg',
# base + '/3.jpg', # # base + '/3.jpg',
# base + '/4.jpg', # # base + '/4.jpg',
# base + '/5.jpg', # # base + '/5.jpg',
# base + '/6.jpg', # # base + '/6.jpg',
base + '/foo.jpg' # base + '/foo.jpg'
] # ]
# res = inference(model, [img1, img2, img3, img4, img5, img6, img7], tokenizer) # # res = inference(model, [img1, img2, img3, img4, img5, img6, img7], tokenizer)
res = inference(model, imgs_path, tokenizer) # res = inference(model, imgs_path, tokenizer)
pause = 1 # pause = 1

View File

@@ -38,7 +38,7 @@ def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokeniz
) )
# trainer.train(resume_from_checkpoint=None) # trainer.train(resume_from_checkpoint=None)
trainer.train(resume_from_checkpoint='/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv2/checkpoint-64000') trainer.train(resume_from_checkpoint='/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv2/checkpoint-288000')
def evaluate(model, tokenizer, eval_dataset, collate_fn): def evaluate(model, tokenizer, eval_dataset, collate_fn):
@@ -94,9 +94,9 @@ if __name__ == '__main__':
train_dataset, eval_dataset = split_dataset['train'], split_dataset['test'] train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer) collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
# model = TexTeller() # model = TexTeller()
model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv2/checkpoint-64000') model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv2/checkpoint-588000')
enable_train = True enable_train = False
enable_evaluate = True enable_evaluate = True
if enable_train: if enable_train:
train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer) train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)

View File

@@ -4,7 +4,7 @@ from transformers import EvalPrediction, RobertaTokenizer
from typing import Dict from typing import Dict
def bleu_metric(eval_preds:EvalPrediction, tokenizer:RobertaTokenizer) -> Dict: def bleu_metric(eval_preds:EvalPrediction, tokenizer:RobertaTokenizer) -> Dict:
metric = evaluate.load('/home/lhy/code/TeXify/src/models/ocr_model/train/google_bleu/google_bleu.py') # 这里需要联网,所以会卡住 metric = evaluate.load('/home/lhy/code/TexTeller/src/models/ocr_model/train/google_bleu') # 这里需要联网,所以会卡住
logits, labels = eval_preds.predictions, eval_preds.label_ids logits, labels = eval_preds.predictions, eval_preds.label_ids
preds = logits preds = logits