merge v2
@@ -16,7 +16,7 @@ MAX_WIDTH = 1280
TEXIFY_INPUT_DENSITY = 100

# vocabulary size of the OCR model's tokenizer
VOCAB_SIZE = 10000
VOCAB_SIZE = 15000

# whether the OCR model uses a fixed input image size
OCR_FIX_SIZE = True
@@ -30,7 +30,8 @@ OCR_IMG_MAX_WIDTH = 768
OCR_IMG_CHANNELS = 1  # grayscale image

# maximum token length in the OCR model's training dataset
MAX_TOKEN_SIZE = 2048  # maximum embedding length of the model (default 512)
MAX_TOKEN_SIZE = 1024  # maximum embedding length of the model (default 512)
# MAX_TOKEN_SIZE = 2048  # maximum embedding length of the model (default 512)
# MAX_TOKEN_SIZE = 600

# random resize ratio used during OCR model training
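As a rough, hedged sketch (not code from this repo) of how MAX_TOKEN_SIZE is typically enforced when a formula is encoded, using the roberta-tokenizer-7Mformulas files added in this commit:

    from transformers import RobertaTokenizerFast

    MAX_TOKEN_SIZE = 1024  # mirrors the value set above

    tokenizer = RobertaTokenizerFast.from_pretrained(
        '/home/lhy/code/TexTeller/src/models/tokenizer/roberta-tokenizer-7Mformulas'
    )
    # truncate any formula that would exceed the model's maximum embedding length
    ids = tokenizer(r'\frac{a}{b} + \sqrt{c}', truncation=True, max_length=MAX_TOKEN_SIZE)['input_ids']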
@@ -1,4 +1,3 @@
from PIL import Image
from pathlib import Path

from models.globals import (
@@ -43,22 +42,23 @@ class TexTeller(VisionEncoderDecoderModel):

if __name__ == "__main__":
    # texteller = TexTeller()
    from ..utils.inference import inference
    model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-57500')
    tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
    # from ..utils.inference import inference
    # model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/model/ckpt')
    # model.save_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/model/ckpt2', safe_serialization=False)
    # tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')

    base = '/home/lhy/code/TeXify/src/models/ocr_model/model'
    imgs_path = [
        # base + '/1.jpg',
        # base + '/2.jpg',
        # base + '/3.jpg',
        # base + '/4.jpg',
        # base + '/5.jpg',
        # base + '/6.jpg',
        base + '/foo.jpg'
    ]
    # base = '/home/lhy/code/TeXify/src/models/ocr_model/model'
    # imgs_path = [
    #     # base + '/1.jpg',
    #     # base + '/2.jpg',
    #     # base + '/3.jpg',
    #     # base + '/4.jpg',
    #     # base + '/5.jpg',
    #     # base + '/6.jpg',
    #     base + '/foo.jpg'
    # ]

    # res = inference(model, [img1, img2, img3, img4, img5, img6, img7], tokenizer)
    res = inference(model, imgs_path, tokenizer)
    pause = 1
    # # res = inference(model, [img1, img2, img3, img4, img5, img6, img7], tokenizer)
    # res = inference(model, imgs_path, tokenizer)
    # pause = 1
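A minimal, hypothetical wrapper (not part of this commit) around the same inference(...) call pattern shown above, assuming the models.* package layout implied by the relative imports and the checkpoint/tokenizer paths that appear elsewhere in this diff:

    from models.ocr_model.model.TexTeller import TexTeller   # assumed absolute import path
    from models.ocr_model.utils.inference import inference   # assumed absolute import path

    def img_to_latex(img_paths):
        # load the OCR checkpoint and its tokenizer, then run batched inference;
        # inference(...) is assumed to return one LaTeX string per input image
        model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/model/ckpt')
        tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TexTeller/src/models/tokenizer/roberta-tokenizer-7Mformulas')
        return inference(model, img_paths, tokenizer)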
@@ -1,5 +1,4 @@
import os
import numpy as np

from functools import partial
from pathlib import Path
@@ -9,13 +8,24 @@ from transformers import Trainer, TrainingArguments, Seq2SeqTrainer, Seq2SeqTrai

from .training_args import CONFIG
from ..model.TexTeller import TexTeller
from ..utils.functional import tokenize_fn, collate_fn, img_transform_fn
from ..utils.functional import tokenize_fn, collate_fn, img_transform_fn, filter_fn
from ..utils.metrics import bleu_metric
from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT


def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
    training_args = TrainingArguments(**CONFIG)
    debug_mode = False
    if debug_mode:
        training_args.auto_find_batch_size = False
        training_args.num_train_epochs = 2
        # training_args.per_device_train_batch_size = 3
        training_args.per_device_train_batch_size = 2
        training_args.per_device_eval_batch_size = 2 * training_args.per_device_train_batch_size
        training_args.jit_mode_eval = False
        training_args.torch_compile = False
        training_args.dataloader_num_workers = 1

    trainer = Trainer(
        model,
        training_args,
@@ -28,13 +38,14 @@ def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokeniz
    )

    trainer.train(resume_from_checkpoint=None)
    # trainer.train(resume_from_checkpoint='/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv2/checkpoint-288000')


def evaluate(model, tokenizer, eval_dataset, collate_fn):
    eval_config = CONFIG.copy()
    eval_config['predict_with_generate'] = True
    generate_config = GenerationConfig(
        max_new_tokens=MAX_TOKEN_SIZE,
        max_length=MAX_TOKEN_SIZE-100,
        num_beams=1,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
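A hedged sketch (not this repo's code) of what the GenerationConfig above amounts to at inference time, written as a self-contained helper; pixel_values is assumed to be a preprocessed image batch tensor:

    import torch
    from transformers import GenerationConfig

    def greedy_decode(model, tokenizer, pixel_values, max_token_size=1024):
        # greedy search, with generated length capped below the model's maximum embedding length
        gen_cfg = GenerationConfig(
            max_new_tokens=max_token_size - 100,
            num_beams=1,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
        )
        with torch.no_grad():
            out_ids = model.generate(pixel_values, generation_config=gen_cfg)
        return tokenizer.batch_decode(out_ids, skip_special_tokens=True)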
@@ -67,28 +78,29 @@ if __name__ == '__main__':
    dataset = load_dataset(
        '/home/lhy/code/TexTeller/src/models/ocr_model/train/data/loader.py'
    )['train']
    dataset = dataset.filter(lambda x: x['image'].height > MIN_HEIGHT and x['image'].width > MIN_WIDTH)
    tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TexTeller/src/models/tokenizer/roberta-tokenizer-7Mformulas')
    filter_fn_with_tokenizer = partial(filter_fn, tokenizer=tokenizer)

    # dataset = dataset.filter(lambda x: x['image'].height > MIN_HEIGHT and x['image'].width > MIN_WIDTH)
    dataset = dataset.filter(filter_fn_with_tokenizer, num_proc=16)
    dataset = dataset.shuffle(seed=42)
    dataset = dataset.flatten_indices()

    tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')

    map_fn = partial(tokenize_fn, tokenizer=tokenizer)
    tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8, load_from_cache_file=True)
    tokenized_dataset = tokenized_dataset.with_transform(img_transform_fn)

    split_dataset = tokenized_dataset.train_test_split(test_size=0.05, seed=42)
    split_dataset = tokenized_dataset.train_test_split(test_size=0.005, seed=42)
    train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
    collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
    model = TexTeller()
    # model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/train_with_random_resize/checkpoint-80000')
    # model = TexTeller()
    model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/model/ckpt')

    enable_train = False
    enable_train = True
    enable_evaluate = True
    if enable_train:
        train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
    if enable_evaluate:
        evaluate(model, tokenizer, eval_dataset, collate_fn_with_tokenizer)


    os.chdir(cur_path)
@@ -4,7 +4,7 @@ CONFIG = {
    # "data_seed": 42,  # also fixes the data sampler's sampling
    # "full_determinism": True,  # makes the whole training run fully deterministic (harmful to training, debug only)

    "output_dir": "train_result/train_with_random_resize",  # output directory
    "output_dir": "train_result/TexTellerv3",  # output directory
    "overwrite_output_dir": False,  # if the output directory already exists, do not delete its contents
    "report_to": ["tensorboard"],  # send logs to TensorBoard,
    #+view them from the command line with: tensorboard --logdir ./logs
@@ -12,18 +12,19 @@ CONFIG = {
    "logging_dir": None,  # directory for the TensorBoard log files (use the default)
    "log_level": "warning",  # other options: 'debug', 'info', 'warning', 'error' and 'critical' (from lowest to highest)
    "logging_strategy": "steps",  # log every fixed number of steps
    "logging_steps": 500,  # step interval between log records; can be an int or a float in (0, 1), where a float means a ratio of the total training steps (e.g. 1.0 / 2000)
    "logging_steps": 4000,  # step interval between log records; can be an int or a float in (0, 1), where a float means a ratio of the total training steps (e.g. 1.0 / 2000)
    #+usually kept equal to eval_steps
    "logging_nan_inf_filter": False,  # also log steps where the loss is nan or inf

    "num_train_epochs": 10,  # total number of training epochs
    "num_train_epochs": 3,  # total number of training epochs
    # "max_steps": 3,  # maximum number of training steps; if set,
    #+num_train_epochs is ignored (usually used for debugging)

    # "label_names": ['your_label_name'],  # name of the label field in the data_loader; defaults to 'labels' if unset

    "per_device_train_batch_size": 64,  # batch size per GPU
    "per_device_eval_batch_size": 16,  # evaluation batch size per GPU
    "per_device_train_batch_size": 3,  # batch size per GPU
    "per_device_eval_batch_size": 6,  # evaluation batch size per GPU
    # "auto_find_batch_size": True,  # automatically search for a workable batch size (exponential decay)
    "auto_find_batch_size": True,  # automatically search for a workable batch size (exponential decay)

    "optim": "adamw_torch",  # many AdamW variants are also available (more efficient than the classic AdamW)
@@ -52,12 +53,12 @@ CONFIG = {
    "dataloader_drop_last": True,  # drop the last minibatch to keep the training gradients stable

    "evaluation_strategy": "steps",  # evaluation strategy, either "steps" or "epoch"
    "eval_steps": 500,  # if evaluation_strategy="steps"
    "eval_steps": 4000,  # if evaluation_strategy="steps"
    #+defaults to the same value as logging_steps; can be an int or a float in (0, 1), where a float means a ratio of the total training steps (e.g. 1.0 / 2000)

    "save_strategy": "steps",  # checkpoint saving strategy
    "save_steps": 500,  # step interval between checkpoints; can be an int or a float in (0, 1), where a float means a ratio of the total training steps (e.g. 1.0 / 2000)
    "save_total_limit": 5,  # maximum number of checkpoints to keep; once exceeded, the oldest one is deleted
    "save_steps": 4000,  # step interval between checkpoints; can be an int or a float in (0, 1), where a float means a ratio of the total training steps (e.g. 1.0 / 2000)
    "save_total_limit": 10,  # maximum number of checkpoints to keep; once exceeded, the oldest one is deleted

    "load_best_model_at_end": True,  # whether to load the best model at the end of training
    #+when True, the checkpoint with the best evaluation result is kept
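A small illustration (an assumption, not repo code) of how this CONFIG dict is consumed, and of the float-ratio behaviour the logging_steps/eval_steps/save_steps comments describe:

    from transformers import TrainingArguments
    # assumed absolute import; the training script itself uses `from .training_args import CONFIG`
    from models.ocr_model.train.training_args import CONFIG

    cfg = dict(CONFIG)
    cfg["logging_steps"] = 1.0 / 2000  # a float in (0, 1) means: log every (total_training_steps / 2000) steps
    training_args = TrainingArguments(**cfg)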
@@ -1,13 +1,13 @@
import torch
import numpy as np

from functools import partial
from datasets import load_dataset

from transformers import DataCollatorForLanguageModeling
from typing import List, Dict, Any
from ..model.TexTeller import TexTeller
from .transforms import train_transform
from ..model.TexTeller import TexTeller
from ...globals import MIN_HEIGHT, MIN_WIDTH, MAX_TOKEN_SIZE


def left_move(x: torch.Tensor, pad_val):
@@ -50,6 +50,13 @@ def img_transform_fn(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
    return samples


def filter_fn(sample, tokenizer=None) -> bool:
    return (
        sample['image'].height > MIN_HEIGHT and sample['image'].width > MIN_WIDTH
        and len(tokenizer(sample['latex_formula'])['input_ids']) < MAX_TOKEN_SIZE - 10
    )


if __name__ == '__main__':
    dataset = load_dataset(
        '/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
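A quick, hedged demonstration (not repo code) of filter_fn on a single synthetic sample; the import path is assumed, and whether the sample passes depends on the MIN_HEIGHT/MIN_WIDTH values set in globals:

    from functools import partial
    from PIL import Image
    from transformers import RobertaTokenizerFast
    from models.ocr_model.utils.functional import filter_fn  # assumed absolute import path

    tokenizer = RobertaTokenizerFast.from_pretrained(
        '/home/lhy/code/TexTeller/src/models/tokenizer/roberta-tokenizer-7Mformulas'
    )
    keep = partial(filter_fn, tokenizer=tokenizer)

    sample = {'image': Image.new('L', (200, 64)), 'latex_formula': r'e^{i\pi} + 1 = 0'}
    print(keep(sample))  # True iff the 200x64 image clears MIN_WIDTH/MIN_HEIGHT and the formula stays under MAX_TOKEN_SIZE - 10 tokens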
@@ -4,7 +4,7 @@ from transformers import EvalPrediction, RobertaTokenizer
from typing import Dict

def bleu_metric(eval_preds: EvalPrediction, tokenizer: RobertaTokenizer) -> Dict:
    metric = evaluate.load('/home/lhy/code/TeXify/src/models/ocr_model/train/google_bleu/google_bleu.py')  # needs network access here, so it can get stuck
    metric = evaluate.load('/home/lhy/code/TexTeller/src/models/ocr_model/train/google_bleu')  # needs network access here, so it can get stuck

    logits, labels = eval_preds.predictions, eval_preds.label_ids
    preds = logits
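One plausible way (a sketch under assumptions, not necessarily how bleu_metric finishes) to score decoded predictions against decoded labels with the Google BLEU metric loaded above:

    import evaluate

    def google_bleu_score(decoded_preds, decoded_labels,
                          metric_path='/home/lhy/code/TexTeller/src/models/ocr_model/train/google_bleu'):
        # decoded_preds / decoded_labels: lists of LaTeX strings, e.g. from tokenizer.batch_decode(..., skip_special_tokens=True)
        metric = evaluate.load(metric_path)
        result = metric.compute(
            predictions=decoded_preds,
            references=[[ref] for ref in decoded_labels],  # google_bleu expects a list of reference lists
        )
        return {'google_bleu': result['google_bleu']}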
src/models/tokenizer/roberta-tokenizer-7Mformulas/merges.txt (new file, 14740 lines; diff suppressed because it is too large)
@@ -0,0 +1,15 @@
{
  "bos_token": "<s>",
  "cls_token": "<s>",
  "eos_token": "</s>",
  "mask_token": {
    "content": "<mask>",
    "lstrip": true,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": "<pad>",
  "sep_token": "</s>",
  "unk_token": "<unk>"
}
src/models/tokenizer/roberta-tokenizer-7Mformulas/tokenizer.json (new file, 29830 lines; diff suppressed because it is too large)
@@ -0,0 +1,57 @@
{
  "add_prefix_space": false,
  "added_tokens_decoder": {
    "0": {
      "content": "<s>",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "1": {
      "content": "<pad>",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "2": {
      "content": "</s>",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "3": {
      "content": "<unk>",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "4": {
      "content": "<mask>",
      "lstrip": true,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "bos_token": "<s>",
  "clean_up_tokenization_spaces": true,
  "cls_token": "<s>",
  "eos_token": "</s>",
  "errors": "replace",
  "mask_token": "<mask>",
  "model_max_length": 1000000000000000019884624838656,
  "pad_token": "<pad>",
  "sep_token": "</s>",
  "tokenizer_class": "RobertaTokenizer",
  "trim_offsets": true,
  "unk_token": "<unk>"
}
File diff suppressed because one or more lines are too long
@@ -4,8 +4,8 @@ from ...globals import VOCAB_SIZE


if __name__ == '__main__':
    tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-raw')
    tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TexTeller/src/models/tokenizer/roberta-tokenizer-raw')
    dataset = load_dataset("/home/lhy/code/TexTeller/src/models/ocr_model/train/data/loader.py")['train']
    new_tokenizer = tokenizer.train_new_from_iterator(text_iterator=dataset['latex_formula'], vocab_size=VOCAB_SIZE)
    new_tokenizer.save_pretrained('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
    new_tokenizer.save_pretrained('/home/lhy/code/TexTeller/src/models/tokenizer/roberta-tokenizer-7Mformulas')
    pause = 1
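Once saved, the retrained tokenizer can be reloaded and exercised directly; a brief, hedged usage sketch (the exact tokenization output depends on the merges learned from the formula corpus):

    from transformers import RobertaTokenizerFast

    tok = RobertaTokenizerFast.from_pretrained(
        '/home/lhy/code/TexTeller/src/models/tokenizer/roberta-tokenizer-7Mformulas'
    )
    print(tok.vocab_size)                       # expected to match VOCAB_SIZE (15000) from globals
    print(tok.tokenize(r'\int_0^1 x^2 \, dx'))  # subword pieces driven by the formula-trained merges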