Completed 1) loader.py, and 2) updating the loader load paths in the affected code

.gitignore
@@ -4,4 +4,5 @@
 
 **/logs
 **/.cache
 **/tmp*
+**/data

@@ -30,7 +30,7 @@ OCR_IMG_MAX_WIDTH = 768
 OCR_IMG_CHANNELS = 1  # grayscale
 
 # maximum token length for the OCR model's training dataset
-MAX_TOKEN_SIZE = 512  # the model's maximum embedding length is set to 512, so this must be 512
+MAX_TOKEN_SIZE = 2048  # the model's maximum embedding length (default 512)
 # MAX_TOKEN_SIZE = 600
 
 # random scaling ratio used during OCR model training

@@ -5,6 +5,7 @@ from models.globals import (
     VOCAB_SIZE,
     OCR_IMG_SIZE,
     OCR_IMG_CHANNELS,
+    MAX_TOKEN_SIZE
 )
 
 from transformers import (
@@ -25,6 +26,7 @@ class TexTeller(VisionEncoderDecoderModel):
         ))
         decoder = TrOCRForCausalLM(TrOCRConfig(
             vocab_size=VOCAB_SIZE,
+            max_position_embeddings=MAX_TOKEN_SIZE
         ))
         super().__init__(encoder=encoder, decoder=decoder)
 
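
The two hunks above work as a pair with the MAX_TOKEN_SIZE change in globals: raising the limit to 2048 only takes effect if the decoder's position-embedding table is built with the same size, since TrOCRConfig defaults max_position_embeddings to 512. A minimal sketch of that relationship, outside the repository's code and with an illustrative vocabulary size:

from transformers import TrOCRConfig, TrOCRForCausalLM

MAX_TOKEN_SIZE = 2048   # mirrors the new value in globals
VOCAB_SIZE = 15000      # illustrative only; the real value lives in models.globals

# Without max_position_embeddings the config falls back to 512, and any label
# sequence longer than that fails at the position-embedding lookup.
decoder = TrOCRForCausalLM(TrOCRConfig(
    vocab_size=VOCAB_SIZE,
    max_position_embeddings=MAX_TOKEN_SIZE,
))
assert decoder.config.max_position_embeddings == MAX_TOKEN_SIZE
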
@@ -65,8 +65,7 @@ if __name__ == '__main__':
|
|||||||
os.chdir(script_dirpath)
|
os.chdir(script_dirpath)
|
||||||
|
|
||||||
dataset = load_dataset(
|
dataset = load_dataset(
|
||||||
'/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
|
'/home/lhy/code/TexTeller/src/models/ocr_model/train/data/loader.py'
|
||||||
'cleaned_formulas'
|
|
||||||
)['train']
|
)['train']
|
||||||
dataset = dataset.filter(lambda x: x['image'].height > MIN_HEIGHT and x['image'].width > MIN_WIDTH)
|
dataset = dataset.filter(lambda x: x['image'].height > MIN_HEIGHT and x['image'].width > MIN_WIDTH)
|
||||||
dataset = dataset.shuffle(seed=42)
|
dataset = dataset.shuffle(seed=42)
|
||||||
@@ -81,8 +80,8 @@ if __name__ == '__main__':
     split_dataset = tokenized_dataset.train_test_split(test_size=0.05, seed=42)
     train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
     collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
-    # model = TexTeller()
-    model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/train_with_random_resize/checkpoint-80000')
+    model = TexTeller()
+    # model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/train_with_random_resize/checkpoint-80000')
 
     enable_train = False
     enable_evaluate = True
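
Both training scripts now point load_dataset at the new data/loader.py instead of the old latex-formulas.py script with its 'cleaned_formulas' configuration. loader.py itself is not part of this diff view; the skeleton below is only a hypothetical sketch of the kind of datasets loading script load_dataset expects here, with the 'image' and 'latex_formula' columns inferred from the filter and tokenizer code above and the annotation-file layout assumed:

import json
import datasets

# Hypothetical skeleton, not the repository's loader.py.
class Loader(datasets.GeneratorBasedBuilder):
    def _info(self):
        return datasets.DatasetInfo(features=datasets.Features({
            'image': datasets.Image(),
            'latex_formula': datasets.Value('string'),
        }))

    def _split_generators(self, dl_manager):
        # Assumed annotation file sitting next to the script.
        return [datasets.SplitGenerator(
            name=datasets.Split.TRAIN,
            gen_kwargs={'ann_path': 'formulas.jsonl'},
        )]

    def _generate_examples(self, ann_path):
        with open(ann_path) as f:
            for idx, line in enumerate(f):
                rec = json.loads(line)
                yield idx, {'image': rec['image_path'], 'latex_formula': rec['formula']}
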
@@ -5,7 +5,7 @@ from ...globals import VOCAB_SIZE
 
 if __name__ == '__main__':
     tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-raw')
-    dataset = load_dataset("/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py", "cleaned_formulas")['train']
+    dataset = load_dataset("/home/lhy/code/TexTeller/src/models/ocr_model/train/data/loader.py")['train']
     new_tokenizer = tokenizer.train_new_from_iterator(text_iterator=dataset['latex_formula'], vocab_size=VOCAB_SIZE)
     new_tokenizer.save_pretrained('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
     pause = 1
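
The tokenizer script above retrains the RoBERTa tokenizer on the formulas served by the same loader: train_new_from_iterator builds a fresh vocabulary from any iterable of strings while keeping the base tokenizer's special tokens and post-processing. A minimal standalone sketch, assuming a stock roberta-base tokenizer instead of the repository's get_tokenizer helper and placeholder formulas standing in for dataset['latex_formula']:

from transformers import AutoTokenizer

VOCAB_SIZE = 15000   # illustrative; the real value comes from globals

base = AutoTokenizer.from_pretrained('roberta-base')
formulas = ['\\frac{a}{b}', 'x^{2} + y^{2} = z^{2}']   # stand-in corpus
new_tokenizer = base.train_new_from_iterator(formulas, vocab_size=VOCAB_SIZE)
new_tokenizer.save_pretrained('./latex-tokenizer')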