1) Completed loader.py, and 2) updated the loader load paths in parts of the code

This commit is contained in:
三洋三洋
2024-03-03 15:59:15 +00:00
parent 69b10eccc7
commit 38877d90b8
5 changed files with 9 additions and 7 deletions

.gitignore vendored
View File

@@ -4,4 +4,5 @@
 **/logs
 **/.cache
 **/tmp*
+**/data

View File

@@ -30,7 +30,7 @@ OCR_IMG_MAX_WIDTH = 768
 OCR_IMG_CHANNELS = 1  # grayscale image
 # maximum number of tokens in the OCR model's training dataset
-MAX_TOKEN_SIZE = 512  # the model's maximum embedding length was set to 512, so this must be 512
+MAX_TOKEN_SIZE = 2048  # the model's maximum embedding length (default 512)
 # MAX_TOKEN_SIZE = 600
 # random resize ratio used when training the OCR model

View File

@@ -5,6 +5,7 @@ from models.globals import (
     VOCAB_SIZE,
     OCR_IMG_SIZE,
     OCR_IMG_CHANNELS,
+    MAX_TOKEN_SIZE
 )
 from transformers import (
@@ -25,6 +26,7 @@ class TexTeller(VisionEncoderDecoderModel):
         ))
         decoder = TrOCRForCausalLM(TrOCRConfig(
             vocab_size=VOCAB_SIZE,
+            max_position_embeddings=MAX_TOKEN_SIZE
         ))
         super().__init__(encoder=encoder, decoder=decoder)
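
Note: the two hunks above work as a pair. TrOCR's decoder allocates its learned position-embedding table when the config is instantiated, so raising MAX_TOKEN_SIZE to 2048 in globals only takes effect because the decoder is now built with max_position_embeddings=MAX_TOKEN_SIZE. A minimal sketch of that interaction, with VOCAB_SIZE as a placeholder (the real value lives in models/globals.py and is not shown in this diff):

from transformers import TrOCRConfig, TrOCRForCausalLM

MAX_TOKEN_SIZE = 2048  # mirrors the globals change in this commit
VOCAB_SIZE = 15000     # placeholder; the actual value is not shown here

decoder = TrOCRForCausalLM(TrOCRConfig(
    vocab_size=VOCAB_SIZE,
    max_position_embeddings=MAX_TOKEN_SIZE,
))
# The embedding table is sized at construction time, so the config argument
# (and not just the global constant) has to change.
print(decoder.config.max_position_embeddings)  # 2048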

View File

@@ -65,8 +65,7 @@ if __name__ == '__main__':
     os.chdir(script_dirpath)
     dataset = load_dataset(
-        '/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
-        'cleaned_formulas'
+        '/home/lhy/code/TexTeller/src/models/ocr_model/train/data/loader.py'
     )['train']
     dataset = dataset.filter(lambda x: x['image'].height > MIN_HEIGHT and x['image'].width > MIN_WIDTH)
     dataset = dataset.shuffle(seed=42)
@@ -81,8 +80,8 @@ if __name__ == '__main__':
     split_dataset = tokenized_dataset.train_test_split(test_size=0.05, seed=42)
     train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
     collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
-    # model = TexTeller()
-    model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/train_with_random_resize/checkpoint-80000')
+    model = TexTeller()
+    # model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/train_with_random_resize/checkpoint-80000')
     enable_train = False
     enable_evaluate = True
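
Note: loader.py itself is not part of this commit's diff, so its contents are unknown. The sketch below is only an assumption about what a local datasets loading script of that shape typically looks like, given that load_dataset() is pointed at a .py file and the downstream code expects 'image' and 'latex_formula' columns; every class name, file name, and field in it is hypothetical.

# Hypothetical loader.py: the annotation format and paths are assumptions.
import json

import datasets

class LatexFormulas(datasets.GeneratorBasedBuilder):
    def _info(self):
        return datasets.DatasetInfo(
            features=datasets.Features({
                "image": datasets.Image(),                  # decoded from a file path
                "latex_formula": datasets.Value("string"),
            })
        )

    def _split_generators(self, dl_manager):
        # Assumed layout: a JSON-lines annotation file next to the images.
        return [datasets.SplitGenerator(
            name=datasets.Split.TRAIN,
            gen_kwargs={"annotations": "annotations.jsonl"},
        )]

    def _generate_examples(self, annotations):
        with open(annotations, encoding="utf-8") as f:
            for idx, line in enumerate(f):
                record = json.loads(line)
                yield idx, {
                    "image": record["image_path"],
                    "latex_formula": record["formula"],
                }

The old call also passed a config name ('cleaned_formulas') while the new one does not, so the new loader presumably exposes a single default configuration.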

View File

@@ -5,7 +5,7 @@ from ...globals import VOCAB_SIZE
 if __name__ == '__main__':
     tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-raw')
-    dataset = load_dataset("/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py", "cleaned_formulas")['train']
+    dataset = load_dataset("/home/lhy/code/TexTeller/src/models/ocr_model/train/data/loader.py")['train']
     new_tokenizer = tokenizer.train_new_from_iterator(text_iterator=dataset['latex_formula'], vocab_size=VOCAB_SIZE)
     new_tokenizer.save_pretrained('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
     pause = 1
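
Note: train_new_from_iterator is the standard fast-tokenizer method that retrains the subword vocabulary on a new corpus (here, the formula strings) while keeping the original tokenizer's pipeline and special tokens. A hedged usage sketch for reloading the result saved above; the sample formula is purely illustrative:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(
    '/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas'
)
print(tok.tokenize(r"\frac{a}{b} = \sqrt{x^2 + y^2}"))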