checkpoint

This commit is contained in:
三洋三洋
2024-04-16 13:56:56 +00:00
parent 7d1d8ddd77
commit f81a31a8c9
6 changed files with 238 additions and 24 deletions

View File

@@ -1,6 +1,6 @@
from pathlib import Path
from models.globals import (
from ...globals import (
VOCAB_SIZE,
FIXED_IMG_SIZE,
IMG_CHANNELS,

View File

@@ -38,7 +38,7 @@ def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokeniz
)
# trainer.train(resume_from_checkpoint=None)
trainer.train(resume_from_checkpoint='/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-440000')
trainer.train(resume_from_checkpoint='/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-644000')
def evaluate(model, tokenizer, eval_dataset, collate_fn):
@@ -96,7 +96,7 @@ if __name__ == '__main__':
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
# model = TexTeller()
model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-440000')
model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-644000')
# ================= debug =======================
# foo = train_dataset[:50]