完成了load1) er.py, 以 2) 部分代码的loader加载路径的更改
This commit is contained in:
@@ -5,6 +5,7 @@ from models.globals import (
|
||||
VOCAB_SIZE,
|
||||
OCR_IMG_SIZE,
|
||||
OCR_IMG_CHANNELS,
|
||||
MAX_TOKEN_SIZE
|
||||
)
|
||||
|
||||
from transformers import (
|
||||
@@ -25,6 +26,7 @@ class TexTeller(VisionEncoderDecoderModel):
|
||||
))
|
||||
decoder = TrOCRForCausalLM(TrOCRConfig(
|
||||
vocab_size=VOCAB_SIZE,
|
||||
max_position_embeddings=MAX_TOKEN_SIZE
|
||||
))
|
||||
super().__init__(encoder=encoder, decoder=decoder)
|
||||
|
||||
|
||||
@@ -65,8 +65,7 @@ if __name__ == '__main__':
|
||||
os.chdir(script_dirpath)
|
||||
|
||||
dataset = load_dataset(
|
||||
'/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
|
||||
'cleaned_formulas'
|
||||
'/home/lhy/code/TexTeller/src/models/ocr_model/train/data/loader.py'
|
||||
)['train']
|
||||
dataset = dataset.filter(lambda x: x['image'].height > MIN_HEIGHT and x['image'].width > MIN_WIDTH)
|
||||
dataset = dataset.shuffle(seed=42)
|
||||
@@ -81,8 +80,8 @@ if __name__ == '__main__':
|
||||
split_dataset = tokenized_dataset.train_test_split(test_size=0.05, seed=42)
|
||||
train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
|
||||
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
|
||||
# model = TexTeller()
|
||||
model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/train_with_random_resize/checkpoint-80000')
|
||||
model = TexTeller()
|
||||
# model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/train_with_random_resize/checkpoint-80000')
|
||||
|
||||
enable_train = False
|
||||
enable_evaluate = True
|
||||
|
||||
Reference in New Issue
Block a user