加入和推理和评估的代码

This commit is contained in:
三洋三洋
2024-01-30 08:36:23 +00:00
parent e03b877ed1
commit b7bf5c444f
12 changed files with 485 additions and 135 deletions

BIN
.swp

Binary file not shown.

View File

@@ -1,4 +1,7 @@
import torch
import cv2
import numpy as np
from transformers import RobertaTokenizerFast, GenerationConfig
from PIL import Image
from typing import List
@@ -8,20 +11,52 @@ from .utils.transforms import inference_transform
from ...globals import MAX_TOKEN_SIZE
def png2jpg(imgs: List[Image.Image]):
imgs = [img.convert('RGB') for img in imgs if img.mode in ("RGBA", "P")]
return imgs
def convert2rgb(image_paths: List[str]) -> List[Image.Image]:
processed_images = []
for path in image_paths:
# 读取图片
image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
if image is None:
print(f"Image at {path} could not be read.")
continue
# 检查图片是否使用 uint16 类型
if image.dtype == np.uint16:
raise ValueError(f"Image at {path} is stored in uint16, which is not supported.")
# 获取图片通道数
channels = 1 if len(image.shape) == 2 else image.shape[2]
# 如果是 RGBA (4通道), 转换为 RGB
if channels == 4:
image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
# 如果是 I 模式 (单通道灰度图), 转换为 RGB
elif channels == 1:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
# 如果是 BGR (3通道), 转换为 RGB
elif channels == 3:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
processed_images.append(Image.fromarray(image))
return processed_images
def inference(model: TexTeller, imgs: List[Image.Image], tokenizer: RobertaTokenizerFast) -> List[str]:
imgs = png2jpg(imgs) if imgs[0].mode in ('RGBA' ,'P') else imgs
def inference(model: TexTeller, imgs_path: List[str], tokenizer: RobertaTokenizerFast) -> List[str]:
imgs = convert2rgb(imgs_path)
imgs = inference_transform(imgs)
pixel_values = torch.stack(imgs)
generate_config = GenerationConfig(
max_new_tokens=MAX_TOKEN_SIZE,
num_beams=3,
do_sample=False
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
bos_token_id=tokenizer.bos_token_id,
)
pred = model.generate(pixel_values, generation_config=generate_config)
res = tokenizer.batch_decode(pred, skip_special_tokens=True)

View File

@@ -39,16 +39,21 @@ class TexTeller(VisionEncoderDecoderModel):
if __name__ == "__main__":
# texteller = TexTeller()
from ..inference import inference
model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-22500')
model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-57500')
tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
img1 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/1.png')
img2 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/2.png')
img3 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/3.png')
img4 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/4.png')
img5 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/5.png')
img6 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/6.png')
base = '/home/lhy/code/TeXify/src/models/ocr_model/model'
imgs_path = [
# base + '/1.jpg',
# base + '/2.jpg',
# base + '/3.jpg',
# base + '/4.jpg',
# base + '/5.jpg',
# base + '/6.jpg',
base + '/foo.jpg'
]
res = inference(model, [img1, img2, img3, img4, img5, img6], tokenizer)
# res = inference(model, [img1, img2, img3, img4, img5, img6, img7], tokenizer)
res = inference(model, imgs_path, tokenizer)
pause = 1

View File

@@ -0,0 +1,168 @@
# Copyright 2020 The HuggingFace Evaluate Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Google BLEU (aka GLEU) metric. """
from typing import Dict, List
import datasets
from nltk.translate import gleu_score
import evaluate
from evaluate import MetricInfo
from .tokenizer_13a import Tokenizer13a
_CITATION = """\
@misc{wu2016googles,
title={Google's Neural Machine Translation System: Bridging the Gap between Human and Machine Translation},
author={Yonghui Wu and Mike Schuster and Zhifeng Chen and Quoc V. Le and Mohammad Norouzi and Wolfgang Macherey
and Maxim Krikun and Yuan Cao and Qin Gao and Klaus Macherey and Jeff Klingner and Apurva Shah and Melvin
Johnson and Xiaobing Liu and Łukasz Kaiser and Stephan Gouws and Yoshikiyo Kato and Taku Kudo and Hideto
Kazawa and Keith Stevens and George Kurian and Nishant Patil and Wei Wang and Cliff Young and
Jason Smith and Jason Riesa and Alex Rudnick and Oriol Vinyals and Greg Corrado and Macduff Hughes
and Jeffrey Dean},
year={2016},
eprint={1609.08144},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
"""
_DESCRIPTION = """\
The BLEU score has some undesirable properties when used for single
sentences, as it was designed to be a corpus measure. We therefore
use a slightly different score for our RL experiments which we call
the 'GLEU score'. For the GLEU score, we record all sub-sequences of
1, 2, 3 or 4 tokens in output and target sequence (n-grams). We then
compute a recall, which is the ratio of the number of matching n-grams
to the number of total n-grams in the target (ground truth) sequence,
and a precision, which is the ratio of the number of matching n-grams
to the number of total n-grams in the generated output sequence. Then
GLEU score is simply the minimum of recall and precision. This GLEU
score's range is always between 0 (no matches) and 1 (all match) and
it is symmetrical when switching output and target. According to
our experiments, GLEU score correlates quite well with the BLEU
metric on a corpus level but does not have its drawbacks for our per
sentence reward objective.
"""
_KWARGS_DESCRIPTION = """\
Computes corpus-level Google BLEU (GLEU) score of translated segments against one or more references.
Instead of averaging the sentence level GLEU scores (i.e. macro-average precision), Wu et al. (2016) sum up the matching
tokens and the max of hypothesis and reference tokens for each sentence, then compute using the aggregate values.
Args:
predictions (list of str): list of translations to score.
references (list of list of str): list of lists of references for each translation.
tokenizer : approach used for tokenizing `predictions` and `references`.
The default tokenizer is `tokenizer_13a`, a minimal tokenization approach that is equivalent to `mteval-v13a`, used by WMT.
This can be replaced by any function that takes a string as input and returns a list of tokens as output.
min_len (int): The minimum order of n-gram this function should extract. Defaults to 1.
max_len (int): The maximum order of n-gram this function should extract. Defaults to 4.
Returns:
'google_bleu': google_bleu score
Examples:
Example 1:
>>> predictions = ['It is a guide to action which ensures that the rubber duck always disobeys the commands of the cat', \
'he read the book because he was interested in world history']
>>> references = [['It is the guiding principle which guarantees the rubber duck forces never being under the command of the cat'], \
['he was interested in world history because he read the book']]
>>> google_bleu = evaluate.load("google_bleu")
>>> results = google_bleu.compute(predictions=predictions, references=references)
>>> print(round(results["google_bleu"], 2))
0.44
Example 2:
>>> predictions = ['It is a guide to action which ensures that the rubber duck always disobeys the commands of the cat', \
'he read the book because he was interested in world history']
>>> references = [['It is the guiding principle which guarantees the rubber duck forces never being under the command of the cat', \
'It is a guide to action that ensures that the rubber duck will never heed the cat commands', \
'It is the practical guide for the rubber duck army never to heed the directions of the cat'], \
['he was interested in world history because he read the book']]
>>> google_bleu = evaluate.load("google_bleu")
>>> results = google_bleu.compute(predictions=predictions, references=references)
>>> print(round(results["google_bleu"], 2))
0.61
Example 3:
>>> predictions = ['It is a guide to action which ensures that the rubber duck always disobeys the commands of the cat', \
'he read the book because he was interested in world history']
>>> references = [['It is the guiding principle which guarantees the rubber duck forces never being under the command of the cat', \
'It is a guide to action that ensures that the rubber duck will never heed the cat commands', \
'It is the practical guide for the rubber duck army never to heed the directions of the cat'], \
['he was interested in world history because he read the book']]
>>> google_bleu = evaluate.load("google_bleu")
>>> results = google_bleu.compute(predictions=predictions, references=references, min_len=2)
>>> print(round(results["google_bleu"], 2))
0.53
Example 4:
>>> predictions = ['It is a guide to action which ensures that the rubber duck always disobeys the commands of the cat', \
'he read the book because he was interested in world history']
>>> references = [['It is the guiding principle which guarantees the rubber duck forces never being under the command of the cat', \
'It is a guide to action that ensures that the rubber duck will never heed the cat commands', \
'It is the practical guide for the rubber duck army never to heed the directions of the cat'], \
['he was interested in world history because he read the book']]
>>> google_bleu = evaluate.load("google_bleu")
>>> results = google_bleu.compute(predictions=predictions,references=references, min_len=2, max_len=6)
>>> print(round(results["google_bleu"], 2))
0.4
"""
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class GoogleBleu(evaluate.Metric):
def _info(self) -> MetricInfo:
return evaluate.MetricInfo(
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
features=[
datasets.Features(
{
"predictions": datasets.Value("string", id="sequence"),
"references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"),
}
),
datasets.Features(
{
"predictions": datasets.Value("string", id="sequence"),
"references": datasets.Value("string", id="sequence"),
}
),
],
)
def _compute(
self,
predictions: List[str],
references: List[List[str]],
tokenizer=Tokenizer13a(),
min_len: int = 1,
max_len: int = 4,
) -> Dict[str, float]:
# if only one reference is provided make sure we still use list of lists
if isinstance(references[0], str):
references = [[ref] for ref in references]
references = [[tokenizer(r) for r in ref] for ref in references]
predictions = [tokenizer(p) for p in predictions]
return {
"google_bleu": gleu_score.corpus_gleu(
list_of_references=references, hypotheses=predictions, min_len=min_len, max_len=max_len
)
}

View File

@@ -0,0 +1,100 @@
# Source: https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/tokenizers/tokenizer_13a.py
# Copyright 2020 SacreBLEU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from functools import lru_cache
class BaseTokenizer:
"""A base dummy tokenizer to derive from."""
def signature(self):
"""
Returns a signature for the tokenizer.
:return: signature string
"""
return "none"
def __call__(self, line):
"""
Tokenizes an input line with the tokenizer.
:param line: a segment to tokenize
:return: the tokenized line
"""
return line
class TokenizerRegexp(BaseTokenizer):
def signature(self):
return "re"
def __init__(self):
self._re = [
# language-dependent part (assuming Western languages)
(re.compile(r"([\{-\~\[-\` -\&\(-\+\:-\@\/])"), r" \1 "),
# tokenize period and comma unless preceded by a digit
(re.compile(r"([^0-9])([\.,])"), r"\1 \2 "),
# tokenize period and comma unless followed by a digit
(re.compile(r"([\.,])([^0-9])"), r" \1 \2"),
# tokenize dash when preceded by a digit
(re.compile(r"([0-9])(-)"), r"\1 \2 "),
# one space only between words
# NOTE: Doing this in Python (below) is faster
# (re.compile(r'\s+'), r' '),
]
@lru_cache(maxsize=2**16)
def __call__(self, line):
"""Common post-processing tokenizer for `13a` and `zh` tokenizers.
:param line: a segment to tokenize
:return: the tokenized line
"""
for (_re, repl) in self._re:
line = _re.sub(repl, line)
# no leading or trailing spaces, single space within words
# return ' '.join(line.split())
# This line is changed with regards to the original tokenizer (seen above) to return individual words
return line.split()
class Tokenizer13a(BaseTokenizer):
def signature(self):
return "13a"
def __init__(self):
self._post_tokenizer = TokenizerRegexp()
@lru_cache(maxsize=2**16)
def __call__(self, line):
"""Tokenizes an input line using a relatively minimal tokenization
that is however equivalent to mteval-v13a, used by WMT.
:param line: a segment to tokenize
:return: the tokenized line
"""
# language-independent part:
line = line.replace("<skipped>", "")
line = line.replace("-\n", "")
line = line.replace("\n", " ")
if "&" in line:
line = line.replace("&quot;", '"')
line = line.replace("&amp;", "&")
line = line.replace("&lt;", "<")
line = line.replace("&gt;", ">")
return self._post_tokenizer(f" {line} ")

View File

@@ -4,120 +4,16 @@ from functools import partial
from pathlib import Path
from datasets import load_dataset
from transformers import Trainer, TrainingArguments, Seq2SeqTrainer
from transformers import Trainer, TrainingArguments, Seq2SeqTrainer, Seq2SeqTrainingArguments, GenerationConfig
from .training_args import CONFIG
from ..model.TexTeller import TexTeller
from ..utils.preprocess import tokenize_fn, collate_fn, img_preprocess
training_args = TrainingArguments(
seed=42, # 随机种子,用于确保实验的可重复性
use_cpu=False, # 是否使用cpu刚开始测试代码的时候先用cpu跑会更容易debug
# data_seed=42, # data sampler的采样也固定
# full_determinism=True, # 使整个训练完全固定这个设置会有害于模型训练只用于debug
output_dir="train_result", # 输出目录
overwrite_output_dir=False, # 如果输出目录存在,不删除原先的内容
report_to=["tensorboard"], # 输出日志到TensorBoard
#+通过在命令行tensorboard --logdir ./logs 来查看日志
logging_dir=None, # TensorBoard日志文件的存储目录(使用默认值)
log_level="info", # 其他可选:debug, info, warning, error and critical由低级别到高级别
logging_strategy="steps", # 每隔一定步数记录一次日志
logging_steps=500, # 记录日志的步数间隔可以是int也可以是(0~1)的float当是float时表示总的训练步数的ratio(比方说可以设置成1.0 / 2000)
#+通常与eval_steps一致
logging_nan_inf_filter=False, # 对loss=nan或inf进行记录
num_train_epochs=10, # 总的训练轮数
# max_steps=3, # 训练的最大步骤数。如果设置了这个参数,
#+那么num_train_epochs将被忽略通常用于调试
# label_names = ['your_label_name'], # 指定data_loader中的标签名如果不指定则默认为'labels'
per_device_train_batch_size=128, # 每个GPU的batch size
per_device_eval_batch_size=16, # 每个GPU的evaluation batch size
auto_find_batch_size=True, # 自动搜索合适的batch size指数decay
optim = 'adamw_torch', # 还提供了很多AdamW的变体相较于经典的AdamW更加高效
#+当设置了optim后就不需要在Trainer中传入optimizer
lr_scheduler_type="cosine", # 设置lr_scheduler
warmup_ratio=0.1, # warmup占整个训练steps的比例(假如训练1000步那么前100步就是从lr=0慢慢长到参数设定的lr)
# warmup_steps=500, # 预热步数, 这个参数与warmup_ratio是矛盾的
weight_decay=0, # 权重衰减
learning_rate=5e-5, # 学习率
max_grad_norm=1.0, # 用于梯度裁剪确保梯度的范数不超过1.0默认1.0
fp16=False, # 是否使用16位浮点数进行训练一般不推荐loss很容易炸
bf16=False, # 是否使用16位宽浮点数进行训练如果架构支持的话推荐使用
gradient_accumulation_steps=2, # 梯度累积步数当batch size无法开很大时可以考虑这个参数来实现大batch size的效果
gradient_checkpointing=False, # 当为True时会在forward时适当丢弃一些中间量用于backward从而减轻显存压力但会增加forward的时间
label_smoothing_factor=0.0, # softlabel等于0时表示未开启
# debug='underflow_overflow', # 训练时检查溢出如果发生则会发出警告。该模式通常用于debug
jit_mode_eval=True, # 是否在eval的时候使用PyTorch jit trace可以加速模型但模型必须是静态的否则会报错
torch_compile=True, # 是否使用torch.compile来编译模型从而获得更好的训练和推理性能
#+ 要求torch > 2.0,这个功能很好使,当模型跑通的时候可以开起来
# deepspeed='your_json_path', # 使用deepspeed来训练需要指定ds_config.json的路径
#+ 在Trainer中使用Deepspeed时一定要注意ds_config.json中的配置是否与Trainer的一致如学习率batch size梯度累积步数等
#+ 如果不一致会出现很奇怪的bug而且一般还很难发现
dataloader_pin_memory=True, # 可以加快数据在cpu和gpu之间转移的速度
dataloader_num_workers=16, # 默认不会使用多进程来加载数据通常设成4*所用的显卡数
dataloader_drop_last=True, # 丢掉最后一个minibatch保证训练的梯度稳定
evaluation_strategy="steps", # 评估策略,可以是"steps"或"epoch"
eval_steps=500, # if evaluation_strategy="step"
#+默认情况下与logging_steps一样可以是int也可以是(0~1)的float当是float时表示总的训练步数的ratio(比方说可以设置成1.0 / 2000)
save_strategy="steps", # 保存checkpoint的策略
save_steps=500, # checkpoint保存的步数间隔可以是int也可以是(0~1)的float当是float时表示总的训练步数的ratio(比方说可以设置成1.0 / 2000)
save_total_limit=5, # 保存的模型的最大数量。如果超过这个数量,最旧的模型将被删除
load_best_model_at_end=True, # 训练结束时是否加载最佳模型
#+当设置True时会保存训练时评估结果最好的checkpoint
#+当设置True时evaluation_strategy必须与save_strategy一样并且save_steps必须是eval_steps的整数倍
metric_for_best_model="eval_loss", # 用于选择最佳模型的指标(必须与load_best_model_at_end一起用)
#+可以使用compute_metrics输出的evaluation的结果中一个字典的某个值
#+注意Trainer会在compute_metrics输出的字典的键前面加上一个prefix默认就是“eval_”
greater_is_better=False, # 指标值越小越好(必须与metric_for_best_model一起用)
do_train=True, # 是否进行训练,通常用于调试
do_eval=True, # 是否进行评估,通常用于调试
remove_unused_columns=False, # 是否删除没有用到的列特征默认为True
#+当删除了没用到的列后making it easier to unpack inputs into the models call function
#+注意remove_unused_columns去除列的操作会把传入的dataset的columns_names与模型forward方法中的参数名进行配对对于不存在forward方法中的列名就会直接删掉整个feature
#+因此如果在dataset.with_transform(..)中给数据进行改名那么这个remove操作会直接把原始的数据直接删掉从而导致之后会拿到一个空的dataset导致在对dataset进行切片取值时出问题
#+例如读进来的dataset图片对应的feature name叫"images"而模型forward方法中对应的参数名叫“pixel_values”
#+此时如果是在data.withtransfrom(..)中根据这个"images"生成其他模型forward方法中需要的参数然后再把"images"改名成“pixel_values”那么整个过程就会出问题
#+因为设置了remove_unused_columns=True后会先给dataset进行列名检查然后“images”这个feature会直接被删掉导致with_transform的transform_fn拿不到“images”这个feature
#+所以一个good practice就是对于要改名的特征先提前使用dataset.rename_column进行改名
push_to_hub=False, # 是否训练完后上传hub需要先在命令行huggingface-cli login进行登录认证的配置配置完后认证信息会存到cache文件夹里
)
from ..utils.metrics import bleu_metric
from ....globals import MAX_TOKEN_SIZE
def main():
# dataset = load_dataset(
# '/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
# 'cleaned_formulas'
# )['train'].select(range(500))
dataset = load_dataset(
'/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
'cleaned_formulas'
)['train']
tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
map_fn = partial(tokenize_fn, tokenizer=tokenizer)
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8)
# tokenized_formula = tokenized_formula.to_dict()
# tokenized_formula['pixel_values'] = dataset['image']
# tokenized_dataset = dataset.from_dict(tokenized_formula)
tokenized_dataset = tokenized_dataset.with_transform(img_preprocess)
split_dataset = tokenized_dataset.train_test_split(test_size=0.05, seed=42)
train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
model = TexTeller()
def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
training_args = TrainingArguments(**CONFIG)
trainer = Trainer(
model,
training_args,
@@ -132,19 +28,40 @@ def main():
trainer.train(resume_from_checkpoint=None)
"""
一个metric_function的另一个case
def evaluate(model, tokenizer, eval_dataset, collate_fn):
eval_config = CONFIG.copy()
generate_config = GenerationConfig(
max_new_tokens=MAX_TOKEN_SIZE,
num_beams=1,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
bos_token_id=tokenizer.bos_token_id,
)
eval_config['output_dir'] = 'debug_dir'
eval_config['predict_with_generate'] = True
eval_config['predict_with_generate'] = True
eval_config['dataloader_num_workers'] = 1
eval_config['jit_mode_eval'] = False
eval_config['torch_compile'] = False
eval_config['auto_find_batch_size'] = False
eval_config['generation_config'] = generate_config
seq2seq_config = Seq2SeqTrainingArguments(**eval_config)
# Setup evaluation
metric = evaluate.load("accuracy")
trainer = Seq2SeqTrainer(
model,
seq2seq_config,
def compute_metrics(eval_pred):
logits, labels = eval_pred
predictions = np.argmax(logits, axis=-1)
return metric.compute(predictions=predictions, references=labels)
"""
eval_dataset=eval_dataset.select(range(16)),
tokenizer=tokenizer,
data_collator=collate_fn,
compute_metrics=partial(bleu_metric, tokenizer=tokenizer)
)
res = trainer.evaluate()
pause = 1
model.generate()
...
if __name__ == '__main__':
@@ -152,7 +69,32 @@ if __name__ == '__main__':
script_dirpath = Path(__file__).resolve().parent
os.chdir(script_dirpath)
main()
dataset = load_dataset(
'/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
'cleaned_formulas'
)['train']
pause = dataset[0]['image']
tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
map_fn = partial(tokenize_fn, tokenizer=tokenizer)
tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8)
tokenized_dataset = tokenized_dataset.with_transform(img_preprocess)
split_dataset = tokenized_dataset.train_test_split(test_size=0.05, seed=42)
train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
# model = TexTeller()
model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-57500')
enable_train = False
enable_evaluate = True
if enable_train:
train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
if enable_evaluate:
evaluate(model, tokenizer, eval_dataset, collate_fn_with_tokenizer)
os.chdir(cur_path)

View File

@@ -0,0 +1,83 @@
CONFIG = {
"seed": 42, # 随机种子,用于确保实验的可重复性
"use_cpu": False, # 是否使用cpu刚开始测试代码的时候先用cpu跑会更容易debug
# "data_seed": 42, # data sampler的采样也固定
# "full_determinism": True, # 使整个训练完全固定这个设置会有害于模型训练只用于debug
"output_dir": "train_result", # 输出目录
"overwrite_output_dir": False, # 如果输出目录存在,不删除原先的内容
"report_to": ["tensorboard"], # 输出日志到TensorBoard
#+通过在命令行tensorboard --logdir ./logs 来查看日志
"logging_dir": None, # TensorBoard日志文件的存储目录(使用默认值)
"log_level": "info", # 其他可选:debug, info, warning, error and critical由低级别到高级别
"logging_strategy": "steps", # 每隔一定步数记录一次日志
"logging_steps": 500, # 记录日志的步数间隔可以是int也可以是(0~1)的float当是float时表示总的训练步数的ratio(比方说可以设置成1.0 / 2000)
#+通常与eval_steps一致
"logging_nan_inf_filter": False, # 对loss=nan或inf进行记录
"num_train_epochs": 10, # 总的训练轮数
# "max_steps": 3, # 训练的最大步骤数。如果设置了这个参数,
#+那么num_train_epochs将被忽略通常用于调试
# "label_names": ['your_label_name'], # 指定data_loader中的标签名如果不指定则默认为'labels'
"per_device_train_batch_size": 128, # 每个GPU的batch size
"per_device_eval_batch_size": 16, # 每个GPU的evaluation batch size
"auto_find_batch_size": True, # 自动搜索合适的batch size指数decay
"optim": "adamw_torch", # 还提供了很多AdamW的变体相较于经典的AdamW更加高效
#+当设置了optim后就不需要在Trainer中传入optimizer
"lr_scheduler_type": "cosine", # 设置lr_scheduler
"warmup_ratio": 0.1, # warmup占整个训练steps的比例(假如训练1000步那么前100步就是从lr=0慢慢长到参数设定的lr)
# "warmup_steps": 500, # 预热步数, 这个参数与warmup_ratio是矛盾的
"weight_decay": 0, # 权重衰减
"learning_rate": 5e-5, # 学习率
"max_grad_norm": 1.0, # 用于梯度裁剪确保梯度的范数不超过1.0默认1.0
"fp16": False, # 是否使用16位浮点数进行训练一般不推荐loss很容易炸
"bf16": False, # 是否使用16位宽浮点数进行训练如果架构支持的话推荐使用
"gradient_accumulation_steps": 2, # 梯度累积步数当batch size无法开很大时可以考虑这个参数来实现大batch size的效果
"gradient_checkpointing": False, # 当为True时会在forward时适当丢弃一些中间量用于backward从而减轻显存压力但会增加forward的时间
"label_smoothing_factor": 0.0, # softlabel等于0时表示未开启
# "debug": "underflow_overflow", # 训练时检查溢出如果发生则会发出警告。该模式通常用于debug
"jit_mode_eval": True, # 是否在eval的时候使用PyTorch jit trace可以加速模型但模型必须是静态的否则会报错
"torch_compile": True, # 是否使用torch.compile来编译模型从而获得更好的训练和推理性能
#+ 要求torch > 2.0,这个功能很好使,当模型跑通的时候可以开起来
# "deepspeed": "your_json_path", # 使用deepspeed来训练需要指定ds_config.json的路径
#+ 在Trainer中使用Deepspeed时一定要注意ds_config.json中的配置是否与Trainer的一致如学习率batch size梯度累积步数等
#+ 如果不一致会出现很奇怪的bug而且一般还很难发现
"dataloader_pin_memory": True, # 可以加快数据在cpu和gpu之间转移的速度
"dataloader_num_workers": 16, # 默认不会使用多进程来加载数据通常设成4*所用的显卡数
"dataloader_drop_last": True, # 丢掉最后一个minibatch保证训练的梯度稳定
"evaluation_strategy": "steps", # 评估策略,可以是"steps"或"epoch"
"eval_steps": 500, # if evaluation_strategy="step"
#+默认情况下与logging_steps一样可以是int也可以是(0~1)的float当是float时表示总的训练步数的ratio(比方说可以设置成1.0 / 2000)
"save_strategy": "steps", # 保存checkpoint的策略
"save_steps": 500, # checkpoint保存的步数间隔可以是int也可以是(0~1)的float当是float时表示总的训练步数的ratio(比方说可以设置成1.0 / 2000)
"save_total_limit": 5, # 保存的模型的最大数量。如果超过这个数量,最旧的模型将被删除
"load_best_model_at_end": True, # 训练结束时是否加载最佳模型
#+当设置True时会保存训练时评估结果最好的checkpoint
#+当设置True时evaluation_strategy必须与save_strategy一样并且save_steps必须是eval_steps的整数倍
"metric_for_best_model": "eval_loss", # 用于选择最佳模型的指标(必须与load_best_model_at_end一起用)
#+可以使用compute_metrics输出的evaluation的结果中一个字典的某个值
#+注意Trainer会在compute_metrics输出的字典的键前面加上一个prefix默认就是“eval_”
"greater_is_better": False, # 指标值越小越好(必须与metric_for_best_model一起用)
"do_train": True, # 是否进行训练,通常用于调试
"do_eval": True, # 是否进行评估,通常用于调试
"remove_unused_columns": False, # 是否删除没有用到的列特征默认为True
#+当删除了没用到的列后making it easier to unpack inputs into the models call function
#+注意remove_unused_columns去除列的操作会把传入的dataset的columns_names与模型forward方法中的参数名进行配对对于不存在forward方法中的列名就会直接删掉整个feature
#+因此如果在dataset.with_transform(..)中给数据进行改名那么这个remove操作会直接把原始的数据直接删掉从而导致之后会拿到一个空的dataset导致在对dataset进行切片取值时出问题
#+例如读进来的dataset图片对应的feature name叫"images"而模型forward方法中对应的参数名叫“pixel_values”
#+此时如果是在data.withtransfrom(..)中根据这个"images"生成其他模型forward方法中需要的参数然后再把"images"改名成“pixel_values”那么整个过程就会出问题
#+因为设置了remove_unused_columns=True后会先给dataset进行列名检查然后“images”这个feature会直接被删掉导致with_transform的transform_fn拿不到“images”这个feature
#+所以一个good practice就是对于要改名的特征先提前使用dataset.rename_column进行改名
"push_to_hub": False, # 是否训练完后上传hub需要先在命令行huggingface-cli login进行登录认证的配置配置完后认证信息会存到cache文件夹里
}

View File

@@ -0,0 +1,17 @@
import evaluate
import numpy as np
from transformers import EvalPrediction, RobertaTokenizer
from typing import Dict
def bleu_metric(eval_preds:EvalPrediction, tokenizer:RobertaTokenizer) -> Dict:
metric = evaluate.load('/home/lhy/code/TeXify/src/models/ocr_model/train/google_bleu/google_bleu.py') # 这里需要联网,所以会卡住
logits, labels = eval_preds.predictions, eval_preds.label_ids
preds = logits
# preds = np.argmax(logits, axis=1) # 把logits转成对应的预测标签
labels = np.where(labels == -100, 1, labels)
preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
return metric.compute(predictions=preds, references=labels)