TexTellerv2
.gitignore (vendored)
@@ -6,3 +6,4 @@
 **/.cache
 **/tmp*
 **/data
+**/ckpt
@@ -1,4 +1,3 @@
-from PIL import Image
 from pathlib import Path
 
 from models.globals import (
@@ -43,22 +42,23 @@ class TexTeller(VisionEncoderDecoderModel):
 
 if __name__ == "__main__":
     # texteller = TexTeller()
-    from ..utils.inference import inference
-    model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-57500')
-    tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
+    # from ..utils.inference import inference
+    # model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/model/ckpt')
+    # model.save_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/model/ckpt2', safe_serialization=False)
+    # tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
 
-    base = '/home/lhy/code/TeXify/src/models/ocr_model/model'
-    imgs_path = [
-        # base + '/1.jpg',
-        # base + '/2.jpg',
-        # base + '/3.jpg',
-        # base + '/4.jpg',
-        # base + '/5.jpg',
-        # base + '/6.jpg',
-        base + '/foo.jpg'
-    ]
+    # base = '/home/lhy/code/TeXify/src/models/ocr_model/model'
+    # imgs_path = [
+    #     # base + '/1.jpg',
+    #     # base + '/2.jpg',
+    #     # base + '/3.jpg',
+    #     # base + '/4.jpg',
+    #     # base + '/5.jpg',
+    #     # base + '/6.jpg',
+    #     base + '/foo.jpg'
+    # ]
 
-    # res = inference(model, [img1, img2, img3, img4, img5, img6, img7], tokenizer)
-    res = inference(model, imgs_path, tokenizer)
-    pause = 1
+    # # res = inference(model, [img1, img2, img3, img4, img5, img6, img7], tokenizer)
+    # res = inference(model, imgs_path, tokenizer)
+    # pause = 1
 
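For reference, the smoke test this hunk disables follows the pattern below. This is a minimal sketch reconstructed from the commented-out lines: the checkpoint and tokenizer paths are the author's local ones, the absolute import path for `inference` is assumed from the relative `..utils.inference` in the diff, and the return value is taken to be the predicted LaTeX strings, as the surrounding code suggests.

```python
# Sketch of the disabled smoke test (as it would run from this file's __main__).
# All paths are the author's local directories; substitute your own.
from models.ocr_model.utils.inference import inference  # assumed absolute form of ..utils.inference

model = TexTeller.from_pretrained(
    '/home/lhy/code/TexTeller/src/models/ocr_model/model/ckpt')
tokenizer = TexTeller.get_tokenizer(
    '/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')

imgs_path = ['/home/lhy/code/TeXify/src/models/ocr_model/model/foo.jpg']
res = inference(model, imgs_path, tokenizer)  # expected: predicted LaTeX per image
print(res)
```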
@@ -38,7 +38,7 @@ def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokeniz
     )
 
     # trainer.train(resume_from_checkpoint=None)
-    trainer.train(resume_from_checkpoint='/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv2/checkpoint-64000')
+    trainer.train(resume_from_checkpoint='/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv2/checkpoint-288000')
 
 
 def evaluate(model, tokenizer, eval_dataset, collate_fn):
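The only change here bumps the resume checkpoint from step 64000 to 288000. This relies on the stock `transformers.Trainer` resume mechanism: `resume_from_checkpoint` accepts either a specific `checkpoint-<step>` directory or `True`, which picks the newest checkpoint under `output_dir`, restoring model weights plus optimizer, scheduler, and trainer state. A minimal sketch with placeholder paths and arguments, not the project's actual `TrainingArguments`:

```python
# Sketch: resuming a transformers Trainer run from a saved checkpoint.
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir='train_result/TexTellerv2',  # Trainer writes checkpoint-<step>/ dirs here
    save_steps=1000,
)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset,
                  eval_dataset=eval_dataset, data_collator=collate_fn_with_tokenizer)

# Name a specific checkpoint directory...
trainer.train(resume_from_checkpoint='train_result/TexTellerv2/checkpoint-288000')
# ...or pass True to resume from the newest checkpoint in output_dir:
# trainer.train(resume_from_checkpoint=True)
```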
@@ -94,9 +94,9 @@ if __name__ == '__main__':
     train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
     collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
     # model = TexTeller()
-    model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv2/checkpoint-64000')
+    model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv2/checkpoint-588000')
 
-    enable_train = True
+    enable_train = False
     enable_evaluate = True
     if enable_train:
         train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
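Worth noting for the eval-only switch: a `checkpoint-<step>` directory written by `Trainer` also holds optimizer and scheduler state, but `from_pretrained` reads only the config and model weights, which is all an evaluation run needs. A minimal sketch of the flow these flags select, reusing the names from this hunk:

```python
# Sketch: eval-only invocation per the flags this hunk flips.
CKPT = '/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv2/checkpoint-588000'
model = TexTeller.from_pretrained(CKPT)  # loads config + weights; optimizer state is ignored

enable_train = False
enable_evaluate = True
if enable_train:
    train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
if enable_evaluate:
    evaluate(model, tokenizer, eval_dataset, collate_fn_with_tokenizer)  # signature per this file
```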
@@ -4,7 +4,7 @@ from transformers import EvalPrediction, RobertaTokenizer
 from typing import Dict
 
 def bleu_metric(eval_preds:EvalPrediction, tokenizer:RobertaTokenizer) -> Dict:
-    metric = evaluate.load('/home/lhy/code/TeXify/src/models/ocr_model/train/google_bleu/google_bleu.py')  # needs network access here, so it can hang
+    metric = evaluate.load('/home/lhy/code/TexTeller/src/models/ocr_model/train/google_bleu')  # needs network access here, so it can hang
 
     logits, labels = eval_preds.predictions, eval_preds.label_ids
     preds = logits
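About the comment on that line: `evaluate.load` given a hub id fetches the metric script over the network, which is why it can hang. Pointing it at a local directory that already contains `google_bleu.py` sidesteps the download, and setting `HF_EVALUATE_OFFLINE=1` forces local-only resolution. A minimal sketch; the local path is the author's, and the sample strings are placeholders:

```python
# Sketch: computing google_bleu from a local copy of the metric script,
# avoiding the network stall noted in the comment above.
import evaluate

# Directory assumed to contain google_bleu.py (e.g. copied from the
# huggingface/evaluate repo's metrics/google_bleu).
LOCAL_DIR = '/home/lhy/code/TexTeller/src/models/ocr_model/train/google_bleu'
metric = evaluate.load(LOCAL_DIR)

result = metric.compute(
    predictions=[r'\frac { a } { b }'],     # decoded model outputs
    references=[[r'\frac { a } { b }']],    # one list of references per prediction
)
print(result['google_bleu'])  # score in [0, 1]
```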