From 38877d90b848b43234d800a18676de4bdfc4ad73 Mon Sep 17 00:00:00 2001
From: 三洋三洋 <1258009915@qq.com>
Date: Sun, 3 Mar 2024 15:59:15 +0000
Subject: [PATCH] Finished 1) loader.py, and 2) updated the loader load paths
 in parts of the code
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .gitignore                              | 3 ++-
 src/models/globals.py                   | 2 +-
 src/models/ocr_model/model/TexTeller.py | 2 ++
 src/models/ocr_model/train/train.py     | 7 +++----
 src/models/tokenizer/train/train.py     | 2 +-
 5 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/.gitignore b/.gitignore
index 37c7cfd..787a849 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,4 +4,5 @@
 
 **/logs
 **/.cache
-**/tmp*
\ No newline at end of file
+**/tmp*
+**/data
\ No newline at end of file

diff --git a/src/models/globals.py b/src/models/globals.py
index 664af8d..3f1ed56 100644
--- a/src/models/globals.py
+++ b/src/models/globals.py
@@ -30,7 +30,7 @@ OCR_IMG_MAX_WIDTH = 768
 OCR_IMG_CHANNELS = 1  # grayscale images
 
 # Maximum number of tokens in the OCR model's training dataset
-MAX_TOKEN_SIZE = 512  # the model's maximum embedding length is set to 512, so this must be 512
+MAX_TOKEN_SIZE = 2048  # the model's maximum embedding length (default: 512)
 # MAX_TOKEN_SIZE = 600
 
 # Random rescaling ratio used when training the OCR model

diff --git a/src/models/ocr_model/model/TexTeller.py b/src/models/ocr_model/model/TexTeller.py
index 142efa6..1a521fa 100644
--- a/src/models/ocr_model/model/TexTeller.py
+++ b/src/models/ocr_model/model/TexTeller.py
@@ -5,6 +5,7 @@ from models.globals import (
     VOCAB_SIZE,
     OCR_IMG_SIZE,
     OCR_IMG_CHANNELS,
+    MAX_TOKEN_SIZE
 )
 
 from transformers import (
@@ -25,6 +26,7 @@ class TexTeller(VisionEncoderDecoderModel):
         ))
         decoder = TrOCRForCausalLM(TrOCRConfig(
             vocab_size=VOCAB_SIZE,
+            max_position_embeddings=MAX_TOKEN_SIZE
         ))
         super().__init__(encoder=encoder, decoder=decoder)
 

diff --git a/src/models/ocr_model/train/train.py b/src/models/ocr_model/train/train.py
index ded62ea..a4edc9c 100644
--- a/src/models/ocr_model/train/train.py
+++ b/src/models/ocr_model/train/train.py
@@ -65,8 +65,7 @@ if __name__ == '__main__':
     os.chdir(script_dirpath)
 
     dataset = load_dataset(
-        '/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
-        'cleaned_formulas'
+        '/home/lhy/code/TexTeller/src/models/ocr_model/train/data/loader.py'
     )['train']
     dataset = dataset.filter(lambda x: x['image'].height > MIN_HEIGHT and x['image'].width > MIN_WIDTH)
     dataset = dataset.shuffle(seed=42)
@@ -81,8 +80,8 @@ if __name__ == '__main__':
     split_dataset = tokenized_dataset.train_test_split(test_size=0.05, seed=42)
     train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
     collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
-    # model = TexTeller()
-    model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/train_with_random_resize/checkpoint-80000')
+    model = TexTeller()
+    # model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/train_with_random_resize/checkpoint-80000')
 
     enable_train = False
     enable_evaluate = True

diff --git a/src/models/tokenizer/train/train.py b/src/models/tokenizer/train/train.py
index 4f9d0bc..ab37915 100644
--- a/src/models/tokenizer/train/train.py
+++ b/src/models/tokenizer/train/train.py
@@ -5,7 +5,7 @@ from ...globals import VOCAB_SIZE
 
 if __name__ == '__main__':
     tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-raw')
-    dataset = load_dataset("/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py", "cleaned_formulas")['train']
+    dataset = load_dataset("/home/lhy/code/TexTeller/src/models/ocr_model/train/data/loader.py")['train']
     new_tokenizer = tokenizer.train_new_from_iterator(text_iterator=dataset['latex_formula'], vocab_size=VOCAB_SIZE)
     new_tokenizer.save_pretrained('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
     pause = 1
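
---

Note on the MAX_TOKEN_SIZE change (not applied by the patch): raising
MAX_TOKEN_SIZE from 512 to 2048 is only safe because TexTeller.py now forwards
it as max_position_embeddings. The TrOCR decoder's learned position-embedding
table is sized by that field, so any label sequence longer than it fails at
embedding lookup during training. A minimal standalone sketch of the
relationship (VOCAB_SIZE is a placeholder here; the real value comes from
src/models/globals.py):

from transformers import TrOCRConfig, TrOCRForCausalLM

MAX_TOKEN_SIZE = 2048  # value introduced by this patch
VOCAB_SIZE = 15000     # placeholder; the project reads this from src/models/globals.py

# Mirrors the decoder construction in TexTeller.py: the position-embedding
# table now holds MAX_TOKEN_SIZE entries instead of TrOCR's default 512.
decoder = TrOCRForCausalLM(TrOCRConfig(
    vocab_size=VOCAB_SIZE,
    max_position_embeddings=MAX_TOKEN_SIZE,
))
assert decoder.config.max_position_embeddings == MAX_TOKEN_SIZE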