完成了所有代码
This commit is contained in:
@@ -30,7 +30,8 @@ OCR_IMG_MAX_WIDTH = 768
|
|||||||
OCR_IMG_CHANNELS = 1 # 灰度图
|
OCR_IMG_CHANNELS = 1 # 灰度图
|
||||||
|
|
||||||
# ocr模型训练数据集的最长token数
|
# ocr模型训练数据集的最长token数
|
||||||
MAX_TOKEN_SIZE = 600
|
MAX_TOKEN_SIZE = 512 # 模型最长的embedding长度被设置成了512,所以这里必须是512
|
||||||
|
# MAX_TOKEN_SIZE = 600
|
||||||
|
|
||||||
# ocr模型训练时随机缩放的比例
|
# ocr模型训练时随机缩放的比例
|
||||||
MAX_RESIZE_RATIO = 1.15
|
MAX_RESIZE_RATIO = 1.15
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -32,6 +32,7 @@ def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokeniz
|
|||||||
|
|
||||||
def evaluate(model, tokenizer, eval_dataset, collate_fn):
|
def evaluate(model, tokenizer, eval_dataset, collate_fn):
|
||||||
eval_config = CONFIG.copy()
|
eval_config = CONFIG.copy()
|
||||||
|
eval_config['predict_with_generate'] = True
|
||||||
generate_config = GenerationConfig(
|
generate_config = GenerationConfig(
|
||||||
max_new_tokens=MAX_TOKEN_SIZE,
|
max_new_tokens=MAX_TOKEN_SIZE,
|
||||||
num_beams=1,
|
num_beams=1,
|
||||||
@@ -40,31 +41,22 @@ def evaluate(model, tokenizer, eval_dataset, collate_fn):
|
|||||||
eos_token_id=tokenizer.eos_token_id,
|
eos_token_id=tokenizer.eos_token_id,
|
||||||
bos_token_id=tokenizer.bos_token_id,
|
bos_token_id=tokenizer.bos_token_id,
|
||||||
)
|
)
|
||||||
# eval_config['use_cpu'] = True
|
|
||||||
eval_config['output_dir'] = 'debug_dir'
|
|
||||||
eval_config['predict_with_generate'] = True
|
|
||||||
eval_config['predict_with_generate'] = True
|
|
||||||
eval_config['dataloader_num_workers'] = 1
|
|
||||||
eval_config['jit_mode_eval'] = False
|
|
||||||
eval_config['torch_compile'] = False
|
|
||||||
eval_config['auto_find_batch_size'] = False
|
|
||||||
eval_config['generation_config'] = generate_config
|
eval_config['generation_config'] = generate_config
|
||||||
|
eval_config['auto_find_batch_size'] = False
|
||||||
seq2seq_config = Seq2SeqTrainingArguments(**eval_config)
|
seq2seq_config = Seq2SeqTrainingArguments(**eval_config)
|
||||||
|
|
||||||
trainer = Seq2SeqTrainer(
|
trainer = Seq2SeqTrainer(
|
||||||
model,
|
model,
|
||||||
seq2seq_config,
|
seq2seq_config,
|
||||||
|
|
||||||
eval_dataset=eval_dataset.select(range(100)),
|
eval_dataset=eval_dataset,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
data_collator=collate_fn,
|
data_collator=collate_fn,
|
||||||
compute_metrics=partial(bleu_metric, tokenizer=tokenizer)
|
compute_metrics=partial(bleu_metric, tokenizer=tokenizer)
|
||||||
)
|
)
|
||||||
|
|
||||||
res = trainer.evaluate()
|
res = trainer.evaluate()
|
||||||
pause = 1
|
print(res)
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
@@ -72,31 +64,27 @@ if __name__ == '__main__':
|
|||||||
script_dirpath = Path(__file__).resolve().parent
|
script_dirpath = Path(__file__).resolve().parent
|
||||||
os.chdir(script_dirpath)
|
os.chdir(script_dirpath)
|
||||||
|
|
||||||
|
|
||||||
dataset = load_dataset(
|
dataset = load_dataset(
|
||||||
'/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
|
'/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
|
||||||
'cleaned_formulas'
|
'cleaned_formulas'
|
||||||
)['train']
|
)['train']
|
||||||
# dataset = load_dataset(
|
dataset = dataset.shuffle(seed=42)
|
||||||
# '/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
|
dataset = dataset.flatten_indices()
|
||||||
# 'cleaned_formulas'
|
|
||||||
# )['train'].select(range(1000))
|
|
||||||
|
|
||||||
tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
|
tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
|
||||||
|
|
||||||
map_fn = partial(tokenize_fn, tokenizer=tokenizer)
|
map_fn = partial(tokenize_fn, tokenizer=tokenizer)
|
||||||
tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8, load_from_cache_file=True)
|
tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8, load_from_cache_file=True)
|
||||||
# tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=1)
|
|
||||||
tokenized_dataset = tokenized_dataset.with_transform(img_transform_fn)
|
tokenized_dataset = tokenized_dataset.with_transform(img_transform_fn)
|
||||||
|
|
||||||
split_dataset = tokenized_dataset.train_test_split(test_size=0.05, seed=42)
|
split_dataset = tokenized_dataset.train_test_split(test_size=0.05, seed=42)
|
||||||
train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
|
train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
|
||||||
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
|
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
|
||||||
# model = TexTeller()
|
# model = TexTeller()
|
||||||
model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-80500')
|
model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/bugy_train_without_random_resize/checkpoint-82000')
|
||||||
|
|
||||||
enable_train = False
|
enable_train = True
|
||||||
enable_evaluate = True
|
enable_evaluate = False
|
||||||
if enable_train:
|
if enable_train:
|
||||||
train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
|
train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
|
||||||
if enable_evaluate:
|
if enable_evaluate:
|
||||||
@@ -104,42 +92,3 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
|
|
||||||
os.chdir(cur_path)
|
os.chdir(cur_path)
|
||||||
|
|
||||||
|
|
||||||
'''
|
|
||||||
if __name__ == '__main__':
|
|
||||||
cur_path = os.getcwd()
|
|
||||||
script_dirpath = Path(__file__).resolve().parent
|
|
||||||
os.chdir(script_dirpath)
|
|
||||||
|
|
||||||
|
|
||||||
dataset = load_dataset(
|
|
||||||
'/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
|
|
||||||
'cleaned_formulas'
|
|
||||||
)['train']
|
|
||||||
|
|
||||||
pause = dataset[0]['image']
|
|
||||||
tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
|
|
||||||
|
|
||||||
map_fn = partial(tokenize_fn, tokenizer=tokenizer)
|
|
||||||
tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8)
|
|
||||||
tokenized_dataset = tokenized_dataset.with_transform(img_preprocess)
|
|
||||||
|
|
||||||
split_dataset = tokenized_dataset.train_test_split(test_size=0.05, seed=42)
|
|
||||||
train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
|
|
||||||
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
|
|
||||||
# model = TexTeller()
|
|
||||||
model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-81000')
|
|
||||||
|
|
||||||
enable_train = False
|
|
||||||
enable_evaluate = True
|
|
||||||
if enable_train:
|
|
||||||
train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
|
|
||||||
if enable_evaluate:
|
|
||||||
evaluate(model, tokenizer, eval_dataset, collate_fn_with_tokenizer)
|
|
||||||
|
|
||||||
|
|
||||||
os.chdir(cur_path)
|
|
||||||
|
|
||||||
|
|
||||||
'''
|
|
||||||
@@ -4,13 +4,13 @@ CONFIG = {
|
|||||||
# "data_seed": 42, # data sampler的采样也固定
|
# "data_seed": 42, # data sampler的采样也固定
|
||||||
# "full_determinism": True, # 使整个训练完全固定(这个设置会有害于模型训练,只用于debug)
|
# "full_determinism": True, # 使整个训练完全固定(这个设置会有害于模型训练,只用于debug)
|
||||||
|
|
||||||
"output_dir": "train_result", # 输出目录
|
"output_dir": "train_result/train_with_random_resize", # 输出目录
|
||||||
"overwrite_output_dir": False, # 如果输出目录存在,不删除原先的内容
|
"overwrite_output_dir": False, # 如果输出目录存在,不删除原先的内容
|
||||||
"report_to": ["tensorboard"], # 输出日志到TensorBoard,
|
"report_to": ["tensorboard"], # 输出日志到TensorBoard,
|
||||||
#+通过在命令行:tensorboard --logdir ./logs 来查看日志
|
#+通过在命令行:tensorboard --logdir ./logs 来查看日志
|
||||||
|
|
||||||
"logging_dir": None, # TensorBoard日志文件的存储目录(使用默认值)
|
"logging_dir": None, # TensorBoard日志文件的存储目录(使用默认值)
|
||||||
"log_level": "info", # 其他可选:‘debug’, ‘info’, ‘warning’, ‘error’ and ‘critical’(由低级别到高级别)
|
"log_level": "warning", # 其他可选:‘debug’, ‘info’, ‘warning’, ‘error’ and ‘critical’(由低级别到高级别)
|
||||||
"logging_strategy": "steps", # 每隔一定步数记录一次日志
|
"logging_strategy": "steps", # 每隔一定步数记录一次日志
|
||||||
"logging_steps": 500, # 记录日志的步数间隔,可以是int也可以是(0~1)的float,当是float时表示总的训练步数的ratio(比方说可以设置成1.0 / 2000)
|
"logging_steps": 500, # 记录日志的步数间隔,可以是int也可以是(0~1)的float,当是float时表示总的训练步数的ratio(比方说可以设置成1.0 / 2000)
|
||||||
#+通常与eval_steps一致
|
#+通常与eval_steps一致
|
||||||
@@ -22,7 +22,7 @@ CONFIG = {
|
|||||||
|
|
||||||
# "label_names": ['your_label_name'], # 指定data_loader中的标签名,如果不指定则默认为'labels'
|
# "label_names": ['your_label_name'], # 指定data_loader中的标签名,如果不指定则默认为'labels'
|
||||||
|
|
||||||
"per_device_train_batch_size": 128, # 每个GPU的batch size
|
"per_device_train_batch_size": 64, # 每个GPU的batch size
|
||||||
"per_device_eval_batch_size": 16, # 每个GPU的evaluation batch size
|
"per_device_eval_batch_size": 16, # 每个GPU的evaluation batch size
|
||||||
"auto_find_batch_size": True, # 自动搜索合适的batch size(指数decay)
|
"auto_find_batch_size": True, # 自动搜索合适的batch size(指数decay)
|
||||||
|
|
||||||
|
|||||||
@@ -38,7 +38,6 @@ def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[
|
|||||||
|
|
||||||
# 左移labels和decoder_attention_mask
|
# 左移labels和decoder_attention_mask
|
||||||
batch['labels'] = left_move(batch['labels'], -100)
|
batch['labels'] = left_move(batch['labels'], -100)
|
||||||
# batch['decoder_attention_mask'] = left_move(batch['decoder_attention_mask'], 0)
|
|
||||||
|
|
||||||
# 把list of Image转成一个tensor with (B, C, H, W)
|
# 把list of Image转成一个tensor with (B, C, H, W)
|
||||||
batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
|
batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
|
||||||
@@ -76,48 +75,3 @@ if __name__ == '__main__':
|
|||||||
out = model(**batch)
|
out = model(**batch)
|
||||||
|
|
||||||
pause = 1
|
pause = 1
|
||||||
|
|
||||||
|
|
||||||
'''
|
|
||||||
def left_move(x: torch.Tensor, pad_val):
|
|
||||||
assert len(x.shape) == 2, 'x should be 2-dimensional'
|
|
||||||
lefted_x = torch.ones_like(x)
|
|
||||||
lefted_x[:, :-1] = x[:, 1:]
|
|
||||||
lefted_x[:, -1] = pad_val
|
|
||||||
return lefted_x
|
|
||||||
|
|
||||||
|
|
||||||
def tokenize_fn(samples: Dict[str, List[Any]], tokenizer=None) -> Dict[str, List[Any]]:
|
|
||||||
assert tokenizer is not None, 'tokenizer should not be None'
|
|
||||||
tokenized_formula = tokenizer(samples['latex_formula'], return_special_tokens_mask=True)
|
|
||||||
tokenized_formula['pixel_values'] = samples['image']
|
|
||||||
return tokenized_formula
|
|
||||||
|
|
||||||
|
|
||||||
def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[Any]]:
|
|
||||||
assert tokenizer is not None, 'tokenizer should not be None'
|
|
||||||
pixel_values = [dic.pop('pixel_values') for dic in samples]
|
|
||||||
|
|
||||||
clm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
|
||||||
|
|
||||||
batch = clm_collator(samples)
|
|
||||||
batch['pixel_values'] = pixel_values
|
|
||||||
batch['decoder_input_ids'] = batch.pop('input_ids')
|
|
||||||
batch['decoder_attention_mask'] = batch.pop('attention_mask')
|
|
||||||
|
|
||||||
# 左移labels和decoder_attention_mask
|
|
||||||
batch['labels'] = left_move(batch['labels'], -100)
|
|
||||||
batch['decoder_attention_mask'] = left_move(batch['decoder_attention_mask'], 0)
|
|
||||||
|
|
||||||
# 把list of Image转成一个tensor with (B, C, H, W)
|
|
||||||
batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
|
|
||||||
return batch
|
|
||||||
|
|
||||||
|
|
||||||
def img_preprocess(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
|
|
||||||
processed_img = train_transform(samples['pixel_values'])
|
|
||||||
samples['pixel_values'] = processed_img
|
|
||||||
return samples
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import numpy as np
|
|||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
from torchvision.transforms import v2
|
from torchvision.transforms import v2
|
||||||
from PIL import ImageChops, Image
|
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
from ....globals import (
|
from ....globals import (
|
||||||
@@ -107,7 +106,7 @@ def random_resize(
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def general_transform(images: List[Image.Image]) -> List[torch.Tensor]:
|
def general_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
|
||||||
# 裁剪掉白边
|
# 裁剪掉白边
|
||||||
images = [trim_white_border(image) for image in images]
|
images = [trim_white_border(image) for image in images]
|
||||||
# general transform pipeline
|
# general transform pipeline
|
||||||
@@ -117,16 +116,16 @@ def general_transform(images: List[Image.Image]) -> List[torch.Tensor]:
|
|||||||
return images
|
return images
|
||||||
|
|
||||||
|
|
||||||
def train_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
|
def train_transform(images: List[List[List[List]]]) -> List[torch.Tensor]:
|
||||||
assert OCR_IMG_CHANNELS == 1 , "Only support grayscale images for now"
|
assert OCR_IMG_CHANNELS == 1 , "Only support grayscale images for now"
|
||||||
assert OCR_FIX_SIZE == True, "Only support fixed size images for now"
|
assert OCR_FIX_SIZE == True, "Only support fixed size images for now"
|
||||||
|
|
||||||
# random resize first
|
# random resize first
|
||||||
# images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
|
images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
|
||||||
return general_transform(images)
|
return general_transform(images)
|
||||||
|
|
||||||
|
|
||||||
def inference_transform(images: List[Image.Image]) -> List[torch.Tensor]:
|
def inference_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
|
||||||
assert OCR_IMG_CHANNELS == 1 , "Only support grayscale images for now"
|
assert OCR_IMG_CHANNELS == 1 , "Only support grayscale images for now"
|
||||||
assert OCR_FIX_SIZE == True, "Only support fixed size images for now"
|
assert OCR_FIX_SIZE == True, "Only support fixed size images for now"
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user