加入和推理和评估的代码
This commit is contained in:
@@ -1,4 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from transformers import RobertaTokenizerFast, GenerationConfig
|
from transformers import RobertaTokenizerFast, GenerationConfig
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from typing import List
|
from typing import List
|
||||||
@@ -8,20 +11,52 @@ from .utils.transforms import inference_transform
|
|||||||
from ...globals import MAX_TOKEN_SIZE
|
from ...globals import MAX_TOKEN_SIZE
|
||||||
|
|
||||||
|
|
||||||
def png2jpg(imgs: List[Image.Image]):
|
def convert2rgb(image_paths: List[str]) -> List[Image.Image]:
|
||||||
imgs = [img.convert('RGB') for img in imgs if img.mode in ("RGBA", "P")]
|
processed_images = []
|
||||||
return imgs
|
|
||||||
|
for path in image_paths:
|
||||||
|
# 读取图片
|
||||||
|
image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
||||||
|
|
||||||
|
if image is None:
|
||||||
|
print(f"Image at {path} could not be read.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 检查图片是否使用 uint16 类型
|
||||||
|
if image.dtype == np.uint16:
|
||||||
|
raise ValueError(f"Image at {path} is stored in uint16, which is not supported.")
|
||||||
|
|
||||||
|
# 获取图片通道数
|
||||||
|
channels = 1 if len(image.shape) == 2 else image.shape[2]
|
||||||
|
|
||||||
|
# 如果是 RGBA (4通道), 转换为 RGB
|
||||||
|
if channels == 4:
|
||||||
|
image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
|
||||||
|
|
||||||
|
# 如果是 I 模式 (单通道灰度图), 转换为 RGB
|
||||||
|
elif channels == 1:
|
||||||
|
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
||||||
|
|
||||||
|
# 如果是 BGR (3通道), 转换为 RGB
|
||||||
|
elif channels == 3:
|
||||||
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||||
|
processed_images.append(Image.fromarray(image))
|
||||||
|
|
||||||
|
return processed_images
|
||||||
|
|
||||||
|
|
||||||
def inference(model: TexTeller, imgs: List[Image.Image], tokenizer: RobertaTokenizerFast) -> List[str]:
|
def inference(model: TexTeller, imgs_path: List[str], tokenizer: RobertaTokenizerFast) -> List[str]:
|
||||||
imgs = png2jpg(imgs) if imgs[0].mode in ('RGBA' ,'P') else imgs
|
imgs = convert2rgb(imgs_path)
|
||||||
imgs = inference_transform(imgs)
|
imgs = inference_transform(imgs)
|
||||||
pixel_values = torch.stack(imgs)
|
pixel_values = torch.stack(imgs)
|
||||||
|
|
||||||
generate_config = GenerationConfig(
|
generate_config = GenerationConfig(
|
||||||
max_new_tokens=MAX_TOKEN_SIZE,
|
max_new_tokens=MAX_TOKEN_SIZE,
|
||||||
num_beams=3,
|
num_beams=3,
|
||||||
do_sample=False
|
do_sample=False,
|
||||||
|
pad_token_id=tokenizer.pad_token_id,
|
||||||
|
eos_token_id=tokenizer.eos_token_id,
|
||||||
|
bos_token_id=tokenizer.bos_token_id,
|
||||||
)
|
)
|
||||||
pred = model.generate(pixel_values, generation_config=generate_config)
|
pred = model.generate(pixel_values, generation_config=generate_config)
|
||||||
res = tokenizer.batch_decode(pred, skip_special_tokens=True)
|
res = tokenizer.batch_decode(pred, skip_special_tokens=True)
|
||||||
|
|||||||
@@ -39,16 +39,21 @@ class TexTeller(VisionEncoderDecoderModel):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# texteller = TexTeller()
|
# texteller = TexTeller()
|
||||||
from ..inference import inference
|
from ..inference import inference
|
||||||
model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-22500')
|
model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-57500')
|
||||||
tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
|
tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
|
||||||
|
|
||||||
img1 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/1.png')
|
base = '/home/lhy/code/TeXify/src/models/ocr_model/model'
|
||||||
img2 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/2.png')
|
imgs_path = [
|
||||||
img3 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/3.png')
|
# base + '/1.jpg',
|
||||||
img4 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/4.png')
|
# base + '/2.jpg',
|
||||||
img5 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/5.png')
|
# base + '/3.jpg',
|
||||||
img6 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/6.png')
|
# base + '/4.jpg',
|
||||||
|
# base + '/5.jpg',
|
||||||
|
# base + '/6.jpg',
|
||||||
|
base + '/foo.jpg'
|
||||||
|
]
|
||||||
|
|
||||||
res = inference(model, [img1, img2, img3, img4, img5, img6], tokenizer)
|
# res = inference(model, [img1, img2, img3, img4, img5, img6, img7], tokenizer)
|
||||||
|
res = inference(model, imgs_path, tokenizer)
|
||||||
pause = 1
|
pause = 1
|
||||||
|
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
168
src/models/ocr_model/train/google_bleu/google_bleu.py
Normal file
168
src/models/ocr_model/train/google_bleu/google_bleu.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
# Copyright 2020 The HuggingFace Evaluate Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" Google BLEU (aka GLEU) metric. """
|
||||||
|
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import datasets
|
||||||
|
from nltk.translate import gleu_score
|
||||||
|
|
||||||
|
import evaluate
|
||||||
|
from evaluate import MetricInfo
|
||||||
|
|
||||||
|
from .tokenizer_13a import Tokenizer13a
|
||||||
|
|
||||||
|
|
||||||
|
_CITATION = """\
|
||||||
|
@misc{wu2016googles,
|
||||||
|
title={Google's Neural Machine Translation System: Bridging the Gap between Human and Machine Translation},
|
||||||
|
author={Yonghui Wu and Mike Schuster and Zhifeng Chen and Quoc V. Le and Mohammad Norouzi and Wolfgang Macherey
|
||||||
|
and Maxim Krikun and Yuan Cao and Qin Gao and Klaus Macherey and Jeff Klingner and Apurva Shah and Melvin
|
||||||
|
Johnson and Xiaobing Liu and Łukasz Kaiser and Stephan Gouws and Yoshikiyo Kato and Taku Kudo and Hideto
|
||||||
|
Kazawa and Keith Stevens and George Kurian and Nishant Patil and Wei Wang and Cliff Young and
|
||||||
|
Jason Smith and Jason Riesa and Alex Rudnick and Oriol Vinyals and Greg Corrado and Macduff Hughes
|
||||||
|
and Jeffrey Dean},
|
||||||
|
year={2016},
|
||||||
|
eprint={1609.08144},
|
||||||
|
archivePrefix={arXiv},
|
||||||
|
primaryClass={cs.CL}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
_DESCRIPTION = """\
|
||||||
|
The BLEU score has some undesirable properties when used for single
|
||||||
|
sentences, as it was designed to be a corpus measure. We therefore
|
||||||
|
use a slightly different score for our RL experiments which we call
|
||||||
|
the 'GLEU score'. For the GLEU score, we record all sub-sequences of
|
||||||
|
1, 2, 3 or 4 tokens in output and target sequence (n-grams). We then
|
||||||
|
compute a recall, which is the ratio of the number of matching n-grams
|
||||||
|
to the number of total n-grams in the target (ground truth) sequence,
|
||||||
|
and a precision, which is the ratio of the number of matching n-grams
|
||||||
|
to the number of total n-grams in the generated output sequence. Then
|
||||||
|
GLEU score is simply the minimum of recall and precision. This GLEU
|
||||||
|
score's range is always between 0 (no matches) and 1 (all match) and
|
||||||
|
it is symmetrical when switching output and target. According to
|
||||||
|
our experiments, GLEU score correlates quite well with the BLEU
|
||||||
|
metric on a corpus level but does not have its drawbacks for our per
|
||||||
|
sentence reward objective.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_KWARGS_DESCRIPTION = """\
|
||||||
|
Computes corpus-level Google BLEU (GLEU) score of translated segments against one or more references.
|
||||||
|
Instead of averaging the sentence level GLEU scores (i.e. macro-average precision), Wu et al. (2016) sum up the matching
|
||||||
|
tokens and the max of hypothesis and reference tokens for each sentence, then compute using the aggregate values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
predictions (list of str): list of translations to score.
|
||||||
|
references (list of list of str): list of lists of references for each translation.
|
||||||
|
tokenizer : approach used for tokenizing `predictions` and `references`.
|
||||||
|
The default tokenizer is `tokenizer_13a`, a minimal tokenization approach that is equivalent to `mteval-v13a`, used by WMT.
|
||||||
|
This can be replaced by any function that takes a string as input and returns a list of tokens as output.
|
||||||
|
min_len (int): The minimum order of n-gram this function should extract. Defaults to 1.
|
||||||
|
max_len (int): The maximum order of n-gram this function should extract. Defaults to 4.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
'google_bleu': google_bleu score
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Example 1:
|
||||||
|
>>> predictions = ['It is a guide to action which ensures that the rubber duck always disobeys the commands of the cat', \
|
||||||
|
'he read the book because he was interested in world history']
|
||||||
|
>>> references = [['It is the guiding principle which guarantees the rubber duck forces never being under the command of the cat'], \
|
||||||
|
['he was interested in world history because he read the book']]
|
||||||
|
>>> google_bleu = evaluate.load("google_bleu")
|
||||||
|
>>> results = google_bleu.compute(predictions=predictions, references=references)
|
||||||
|
>>> print(round(results["google_bleu"], 2))
|
||||||
|
0.44
|
||||||
|
|
||||||
|
Example 2:
|
||||||
|
>>> predictions = ['It is a guide to action which ensures that the rubber duck always disobeys the commands of the cat', \
|
||||||
|
'he read the book because he was interested in world history']
|
||||||
|
>>> references = [['It is the guiding principle which guarantees the rubber duck forces never being under the command of the cat', \
|
||||||
|
'It is a guide to action that ensures that the rubber duck will never heed the cat commands', \
|
||||||
|
'It is the practical guide for the rubber duck army never to heed the directions of the cat'], \
|
||||||
|
['he was interested in world history because he read the book']]
|
||||||
|
>>> google_bleu = evaluate.load("google_bleu")
|
||||||
|
>>> results = google_bleu.compute(predictions=predictions, references=references)
|
||||||
|
>>> print(round(results["google_bleu"], 2))
|
||||||
|
0.61
|
||||||
|
|
||||||
|
Example 3:
|
||||||
|
>>> predictions = ['It is a guide to action which ensures that the rubber duck always disobeys the commands of the cat', \
|
||||||
|
'he read the book because he was interested in world history']
|
||||||
|
>>> references = [['It is the guiding principle which guarantees the rubber duck forces never being under the command of the cat', \
|
||||||
|
'It is a guide to action that ensures that the rubber duck will never heed the cat commands', \
|
||||||
|
'It is the practical guide for the rubber duck army never to heed the directions of the cat'], \
|
||||||
|
['he was interested in world history because he read the book']]
|
||||||
|
>>> google_bleu = evaluate.load("google_bleu")
|
||||||
|
>>> results = google_bleu.compute(predictions=predictions, references=references, min_len=2)
|
||||||
|
>>> print(round(results["google_bleu"], 2))
|
||||||
|
0.53
|
||||||
|
|
||||||
|
Example 4:
|
||||||
|
>>> predictions = ['It is a guide to action which ensures that the rubber duck always disobeys the commands of the cat', \
|
||||||
|
'he read the book because he was interested in world history']
|
||||||
|
>>> references = [['It is the guiding principle which guarantees the rubber duck forces never being under the command of the cat', \
|
||||||
|
'It is a guide to action that ensures that the rubber duck will never heed the cat commands', \
|
||||||
|
'It is the practical guide for the rubber duck army never to heed the directions of the cat'], \
|
||||||
|
['he was interested in world history because he read the book']]
|
||||||
|
>>> google_bleu = evaluate.load("google_bleu")
|
||||||
|
>>> results = google_bleu.compute(predictions=predictions,references=references, min_len=2, max_len=6)
|
||||||
|
>>> print(round(results["google_bleu"], 2))
|
||||||
|
0.4
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
||||||
|
class GoogleBleu(evaluate.Metric):
|
||||||
|
def _info(self) -> MetricInfo:
|
||||||
|
return evaluate.MetricInfo(
|
||||||
|
description=_DESCRIPTION,
|
||||||
|
citation=_CITATION,
|
||||||
|
inputs_description=_KWARGS_DESCRIPTION,
|
||||||
|
features=[
|
||||||
|
datasets.Features(
|
||||||
|
{
|
||||||
|
"predictions": datasets.Value("string", id="sequence"),
|
||||||
|
"references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"),
|
||||||
|
}
|
||||||
|
),
|
||||||
|
datasets.Features(
|
||||||
|
{
|
||||||
|
"predictions": datasets.Value("string", id="sequence"),
|
||||||
|
"references": datasets.Value("string", id="sequence"),
|
||||||
|
}
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
def _compute(
|
||||||
|
self,
|
||||||
|
predictions: List[str],
|
||||||
|
references: List[List[str]],
|
||||||
|
tokenizer=Tokenizer13a(),
|
||||||
|
min_len: int = 1,
|
||||||
|
max_len: int = 4,
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
# if only one reference is provided make sure we still use list of lists
|
||||||
|
if isinstance(references[0], str):
|
||||||
|
references = [[ref] for ref in references]
|
||||||
|
|
||||||
|
references = [[tokenizer(r) for r in ref] for ref in references]
|
||||||
|
predictions = [tokenizer(p) for p in predictions]
|
||||||
|
return {
|
||||||
|
"google_bleu": gleu_score.corpus_gleu(
|
||||||
|
list_of_references=references, hypotheses=predictions, min_len=min_len, max_len=max_len
|
||||||
|
)
|
||||||
|
}
|
||||||
100
src/models/ocr_model/train/google_bleu/tokenizer_13a.py
Normal file
100
src/models/ocr_model/train/google_bleu/tokenizer_13a.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
# Source: https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/tokenizers/tokenizer_13a.py
|
||||||
|
# Copyright 2020 SacreBLEU Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import re
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
|
||||||
|
class BaseTokenizer:
|
||||||
|
"""A base dummy tokenizer to derive from."""
|
||||||
|
|
||||||
|
def signature(self):
|
||||||
|
"""
|
||||||
|
Returns a signature for the tokenizer.
|
||||||
|
:return: signature string
|
||||||
|
"""
|
||||||
|
return "none"
|
||||||
|
|
||||||
|
def __call__(self, line):
|
||||||
|
"""
|
||||||
|
Tokenizes an input line with the tokenizer.
|
||||||
|
:param line: a segment to tokenize
|
||||||
|
:return: the tokenized line
|
||||||
|
"""
|
||||||
|
return line
|
||||||
|
|
||||||
|
|
||||||
|
class TokenizerRegexp(BaseTokenizer):
|
||||||
|
def signature(self):
|
||||||
|
return "re"
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._re = [
|
||||||
|
# language-dependent part (assuming Western languages)
|
||||||
|
(re.compile(r"([\{-\~\[-\` -\&\(-\+\:-\@\/])"), r" \1 "),
|
||||||
|
# tokenize period and comma unless preceded by a digit
|
||||||
|
(re.compile(r"([^0-9])([\.,])"), r"\1 \2 "),
|
||||||
|
# tokenize period and comma unless followed by a digit
|
||||||
|
(re.compile(r"([\.,])([^0-9])"), r" \1 \2"),
|
||||||
|
# tokenize dash when preceded by a digit
|
||||||
|
(re.compile(r"([0-9])(-)"), r"\1 \2 "),
|
||||||
|
# one space only between words
|
||||||
|
# NOTE: Doing this in Python (below) is faster
|
||||||
|
# (re.compile(r'\s+'), r' '),
|
||||||
|
]
|
||||||
|
|
||||||
|
@lru_cache(maxsize=2**16)
|
||||||
|
def __call__(self, line):
|
||||||
|
"""Common post-processing tokenizer for `13a` and `zh` tokenizers.
|
||||||
|
:param line: a segment to tokenize
|
||||||
|
:return: the tokenized line
|
||||||
|
"""
|
||||||
|
for (_re, repl) in self._re:
|
||||||
|
line = _re.sub(repl, line)
|
||||||
|
|
||||||
|
# no leading or trailing spaces, single space within words
|
||||||
|
# return ' '.join(line.split())
|
||||||
|
# This line is changed with regards to the original tokenizer (seen above) to return individual words
|
||||||
|
return line.split()
|
||||||
|
|
||||||
|
|
||||||
|
class Tokenizer13a(BaseTokenizer):
|
||||||
|
def signature(self):
|
||||||
|
return "13a"
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._post_tokenizer = TokenizerRegexp()
|
||||||
|
|
||||||
|
@lru_cache(maxsize=2**16)
|
||||||
|
def __call__(self, line):
|
||||||
|
"""Tokenizes an input line using a relatively minimal tokenization
|
||||||
|
that is however equivalent to mteval-v13a, used by WMT.
|
||||||
|
|
||||||
|
:param line: a segment to tokenize
|
||||||
|
:return: the tokenized line
|
||||||
|
"""
|
||||||
|
|
||||||
|
# language-independent part:
|
||||||
|
line = line.replace("<skipped>", "")
|
||||||
|
line = line.replace("-\n", "")
|
||||||
|
line = line.replace("\n", " ")
|
||||||
|
|
||||||
|
if "&" in line:
|
||||||
|
line = line.replace(""", '"')
|
||||||
|
line = line.replace("&", "&")
|
||||||
|
line = line.replace("<", "<")
|
||||||
|
line = line.replace(">", ">")
|
||||||
|
|
||||||
|
return self._post_tokenizer(f" {line} ")
|
||||||
@@ -4,120 +4,16 @@ from functools import partial
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from transformers import Trainer, TrainingArguments, Seq2SeqTrainer
|
from transformers import Trainer, TrainingArguments, Seq2SeqTrainer, Seq2SeqTrainingArguments, GenerationConfig
|
||||||
|
from .training_args import CONFIG
|
||||||
from ..model.TexTeller import TexTeller
|
from ..model.TexTeller import TexTeller
|
||||||
from ..utils.preprocess import tokenize_fn, collate_fn, img_preprocess
|
from ..utils.preprocess import tokenize_fn, collate_fn, img_preprocess
|
||||||
|
from ..utils.metrics import bleu_metric
|
||||||
training_args = TrainingArguments(
|
from ....globals import MAX_TOKEN_SIZE
|
||||||
seed=42, # 随机种子,用于确保实验的可重复性
|
|
||||||
use_cpu=False, # 是否使用cpu(刚开始测试代码的时候先用cpu跑会更容易debug)
|
|
||||||
# data_seed=42, # data sampler的采样也固定
|
|
||||||
# full_determinism=True, # 使整个训练完全固定(这个设置会有害于模型训练,只用于debug)
|
|
||||||
|
|
||||||
output_dir="train_result", # 输出目录
|
|
||||||
overwrite_output_dir=False, # 如果输出目录存在,不删除原先的内容
|
|
||||||
report_to=["tensorboard"], # 输出日志到TensorBoard,
|
|
||||||
#+通过在命令行:tensorboard --logdir ./logs 来查看日志
|
|
||||||
|
|
||||||
logging_dir=None, # TensorBoard日志文件的存储目录(使用默认值)
|
|
||||||
log_level="info", # 其他可选:‘debug’, ‘info’, ‘warning’, ‘error’ and ‘critical’(由低级别到高级别)
|
|
||||||
logging_strategy="steps", # 每隔一定步数记录一次日志
|
|
||||||
logging_steps=500, # 记录日志的步数间隔,可以是int也可以是(0~1)的float,当是float时表示总的训练步数的ratio(比方说可以设置成1.0 / 2000)
|
|
||||||
#+通常与eval_steps一致
|
|
||||||
logging_nan_inf_filter=False, # 对loss=nan或inf进行记录
|
|
||||||
|
|
||||||
num_train_epochs=10, # 总的训练轮数
|
|
||||||
# max_steps=3, # 训练的最大步骤数。如果设置了这个参数,
|
|
||||||
#+那么num_train_epochs将被忽略(通常用于调试)
|
|
||||||
|
|
||||||
# label_names = ['your_label_name'], # 指定data_loader中的标签名,如果不指定则默认为'labels'
|
|
||||||
|
|
||||||
per_device_train_batch_size=128, # 每个GPU的batch size
|
|
||||||
per_device_eval_batch_size=16, # 每个GPU的evaluation batch size
|
|
||||||
auto_find_batch_size=True, # 自动搜索合适的batch size(指数decay)
|
|
||||||
|
|
||||||
optim = 'adamw_torch', # 还提供了很多AdamW的变体(相较于经典的AdamW更加高效)
|
|
||||||
#+当设置了optim后,就不需要在Trainer中传入optimizer
|
|
||||||
lr_scheduler_type="cosine", # 设置lr_scheduler
|
|
||||||
warmup_ratio=0.1, # warmup占整个训练steps的比例(假如训练1000步,那么前100步就是从lr=0慢慢长到参数设定的lr)
|
|
||||||
# warmup_steps=500, # 预热步数, 这个参数与warmup_ratio是矛盾的
|
|
||||||
weight_decay=0, # 权重衰减
|
|
||||||
learning_rate=5e-5, # 学习率
|
|
||||||
max_grad_norm=1.0, # 用于梯度裁剪,确保梯度的范数不超过1.0(默认1.0)
|
|
||||||
fp16=False, # 是否使用16位浮点数进行训练(一般不推荐,loss很容易炸)
|
|
||||||
bf16=False, # 是否使用16位宽浮点数进行训练(如果架构支持的话推荐使用)
|
|
||||||
gradient_accumulation_steps=2, # 梯度累积步数,当batch size无法开很大时,可以考虑这个参数来实现大batch size的效果
|
|
||||||
gradient_checkpointing=False, # 当为True时,会在forward时适当丢弃一些中间量(用于backward),从而减轻显存压力(但会增加forward的时间)
|
|
||||||
label_smoothing_factor=0.0, # softlabel,等于0时表示未开启
|
|
||||||
# debug='underflow_overflow', # 训练时检查溢出,如果发生,则会发出警告。(该模式通常用于debug)
|
|
||||||
jit_mode_eval=True, # 是否在eval的时候使用PyTorch jit trace(可以加速模型,但模型必须是静态的,否则会报错)
|
|
||||||
torch_compile=True, # 是否使用torch.compile来编译模型(从而获得更好的训练和推理性能)
|
|
||||||
#+ 要求torch > 2.0,这个功能很好使,当模型跑通的时候可以开起来
|
|
||||||
# deepspeed='your_json_path', # 使用deepspeed来训练,需要指定ds_config.json的路径
|
|
||||||
#+ 在Trainer中使用Deepspeed时一定要注意ds_config.json中的配置是否与Trainer的一致(如学习率,batch size,梯度累积步数等)
|
|
||||||
#+ 如果不一致,会出现很奇怪的bug(而且一般还很难发现)
|
|
||||||
|
|
||||||
dataloader_pin_memory=True, # 可以加快数据在cpu和gpu之间转移的速度
|
|
||||||
dataloader_num_workers=16, # 默认不会使用多进程来加载数据,通常设成4*所用的显卡数
|
|
||||||
dataloader_drop_last=True, # 丢掉最后一个minibatch,保证训练的梯度稳定
|
|
||||||
|
|
||||||
evaluation_strategy="steps", # 评估策略,可以是"steps"或"epoch"
|
|
||||||
eval_steps=500, # if evaluation_strategy="step"
|
|
||||||
#+默认情况下与logging_steps一样,可以是int也可以是(0~1)的float,当是float时表示总的训练步数的ratio(比方说可以设置成1.0 / 2000)
|
|
||||||
|
|
||||||
save_strategy="steps", # 保存checkpoint的策略
|
|
||||||
save_steps=500, # checkpoint保存的步数间隔,可以是int也可以是(0~1)的float,当是float时表示总的训练步数的ratio(比方说可以设置成1.0 / 2000)
|
|
||||||
save_total_limit=5, # 保存的模型的最大数量。如果超过这个数量,最旧的模型将被删除
|
|
||||||
|
|
||||||
load_best_model_at_end=True, # 训练结束时是否加载最佳模型
|
|
||||||
#+当设置True时,会保存训练时评估结果最好的checkpoint
|
|
||||||
#+当设置True时,evaluation_strategy必须与save_strategy一样,并且save_steps必须是eval_steps的整数倍
|
|
||||||
metric_for_best_model="eval_loss", # 用于选择最佳模型的指标(必须与load_best_model_at_end一起用)
|
|
||||||
#+可以使用compute_metrics输出的evaluation的结果中(一个字典)的某个值
|
|
||||||
#+注意:Trainer会在compute_metrics输出的字典的键前面加上一个prefix,默认就是“eval_”
|
|
||||||
greater_is_better=False, # 指标值越小越好(必须与metric_for_best_model一起用)
|
|
||||||
|
|
||||||
do_train=True, # 是否进行训练,通常用于调试
|
|
||||||
do_eval=True, # 是否进行评估,通常用于调试
|
|
||||||
|
|
||||||
remove_unused_columns=False, # 是否删除没有用到的列(特征),默认为True
|
|
||||||
#+当删除了没用到的列后,making it easier to unpack inputs into the model’s call function
|
|
||||||
#+注意:remove_unused_columns去除列的操作会把传入的dataset的columns_names与模型forward方法中的参数名进行配对,对于不存在forward方法中的列名就会直接删掉整个feature
|
|
||||||
#+因此如果在dataset.with_transform(..)中给数据进行改名,那么这个remove操作会直接把原始的数据直接删掉,从而导致之后会拿到一个空的dataset,导致在对dataset进行切片取值时出问题
|
|
||||||
#+例如读进来的dataset图片对应的feature name叫"images",而模型forward方法中对应的参数名叫“pixel_values”,
|
|
||||||
#+此时如果是在data.withtransfrom(..)中根据这个"images"生成其他模型forward方法中需要的参数,然后再把"images"改名成“pixel_values”,那么整个过程就会出问题
|
|
||||||
#+因为设置了remove_unused_columns=True后,会先给dataset进行列名检查,然后“images”这个feature会直接被删掉(导致with_transform的transform_fn拿不到“images”这个feature)
|
|
||||||
#+所以一个good practice就是:对于要改名的特征,先提前使用dataset.rename_column进行改名
|
|
||||||
|
|
||||||
push_to_hub=False, # 是否训练完后上传hub,需要先在命令行:huggingface-cli login进行登录认证的配置,配置完后,认证信息会存到cache文件夹里
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
|
||||||
# dataset = load_dataset(
|
training_args = TrainingArguments(**CONFIG)
|
||||||
# '/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
|
|
||||||
# 'cleaned_formulas'
|
|
||||||
# )['train'].select(range(500))
|
|
||||||
dataset = load_dataset(
|
|
||||||
'/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
|
|
||||||
'cleaned_formulas'
|
|
||||||
)['train']
|
|
||||||
|
|
||||||
tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
|
|
||||||
|
|
||||||
map_fn = partial(tokenize_fn, tokenizer=tokenizer)
|
|
||||||
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
|
|
||||||
|
|
||||||
tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8)
|
|
||||||
# tokenized_formula = tokenized_formula.to_dict()
|
|
||||||
# tokenized_formula['pixel_values'] = dataset['image']
|
|
||||||
# tokenized_dataset = dataset.from_dict(tokenized_formula)
|
|
||||||
tokenized_dataset = tokenized_dataset.with_transform(img_preprocess)
|
|
||||||
split_dataset = tokenized_dataset.train_test_split(test_size=0.05, seed=42)
|
|
||||||
train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
|
|
||||||
|
|
||||||
model = TexTeller()
|
|
||||||
|
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
model,
|
model,
|
||||||
training_args,
|
training_args,
|
||||||
@@ -132,19 +28,40 @@ def main():
|
|||||||
trainer.train(resume_from_checkpoint=None)
|
trainer.train(resume_from_checkpoint=None)
|
||||||
|
|
||||||
|
|
||||||
"""
|
def evaluate(model, tokenizer, eval_dataset, collate_fn):
|
||||||
一个metric_function的另一个case:
|
eval_config = CONFIG.copy()
|
||||||
|
generate_config = GenerationConfig(
|
||||||
|
max_new_tokens=MAX_TOKEN_SIZE,
|
||||||
|
num_beams=1,
|
||||||
|
do_sample=False,
|
||||||
|
pad_token_id=tokenizer.pad_token_id,
|
||||||
|
eos_token_id=tokenizer.eos_token_id,
|
||||||
|
bos_token_id=tokenizer.bos_token_id,
|
||||||
|
)
|
||||||
|
eval_config['output_dir'] = 'debug_dir'
|
||||||
|
eval_config['predict_with_generate'] = True
|
||||||
|
eval_config['predict_with_generate'] = True
|
||||||
|
eval_config['dataloader_num_workers'] = 1
|
||||||
|
eval_config['jit_mode_eval'] = False
|
||||||
|
eval_config['torch_compile'] = False
|
||||||
|
eval_config['auto_find_batch_size'] = False
|
||||||
|
eval_config['generation_config'] = generate_config
|
||||||
|
seq2seq_config = Seq2SeqTrainingArguments(**eval_config)
|
||||||
|
|
||||||
# Setup evaluation
|
trainer = Seq2SeqTrainer(
|
||||||
metric = evaluate.load("accuracy")
|
model,
|
||||||
|
seq2seq_config,
|
||||||
|
|
||||||
def compute_metrics(eval_pred):
|
eval_dataset=eval_dataset.select(range(16)),
|
||||||
logits, labels = eval_pred
|
tokenizer=tokenizer,
|
||||||
predictions = np.argmax(logits, axis=-1)
|
data_collator=collate_fn,
|
||||||
return metric.compute(predictions=predictions, references=labels)
|
compute_metrics=partial(bleu_metric, tokenizer=tokenizer)
|
||||||
"""
|
)
|
||||||
|
|
||||||
|
res = trainer.evaluate()
|
||||||
pause = 1
|
pause = 1
|
||||||
model.generate()
|
...
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
@@ -152,7 +69,32 @@ if __name__ == '__main__':
|
|||||||
script_dirpath = Path(__file__).resolve().parent
|
script_dirpath = Path(__file__).resolve().parent
|
||||||
os.chdir(script_dirpath)
|
os.chdir(script_dirpath)
|
||||||
|
|
||||||
main()
|
|
||||||
|
dataset = load_dataset(
|
||||||
|
'/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
|
||||||
|
'cleaned_formulas'
|
||||||
|
)['train']
|
||||||
|
|
||||||
|
pause = dataset[0]['image']
|
||||||
|
tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
|
||||||
|
|
||||||
|
map_fn = partial(tokenize_fn, tokenizer=tokenizer)
|
||||||
|
tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8)
|
||||||
|
tokenized_dataset = tokenized_dataset.with_transform(img_preprocess)
|
||||||
|
|
||||||
|
split_dataset = tokenized_dataset.train_test_split(test_size=0.05, seed=42)
|
||||||
|
train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
|
||||||
|
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
|
||||||
|
# model = TexTeller()
|
||||||
|
model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-57500')
|
||||||
|
|
||||||
|
enable_train = False
|
||||||
|
enable_evaluate = True
|
||||||
|
if enable_train:
|
||||||
|
train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
|
||||||
|
if enable_evaluate:
|
||||||
|
evaluate(model, tokenizer, eval_dataset, collate_fn_with_tokenizer)
|
||||||
|
|
||||||
|
|
||||||
os.chdir(cur_path)
|
os.chdir(cur_path)
|
||||||
|
|
||||||
|
|||||||
83
src/models/ocr_model/train/training_args.py
Normal file
83
src/models/ocr_model/train/training_args.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
CONFIG = {
|
||||||
|
"seed": 42, # 随机种子,用于确保实验的可重复性
|
||||||
|
"use_cpu": False, # 是否使用cpu(刚开始测试代码的时候先用cpu跑会更容易debug)
|
||||||
|
# "data_seed": 42, # data sampler的采样也固定
|
||||||
|
# "full_determinism": True, # 使整个训练完全固定(这个设置会有害于模型训练,只用于debug)
|
||||||
|
|
||||||
|
"output_dir": "train_result", # 输出目录
|
||||||
|
"overwrite_output_dir": False, # 如果输出目录存在,不删除原先的内容
|
||||||
|
"report_to": ["tensorboard"], # 输出日志到TensorBoard,
|
||||||
|
#+通过在命令行:tensorboard --logdir ./logs 来查看日志
|
||||||
|
|
||||||
|
"logging_dir": None, # TensorBoard日志文件的存储目录(使用默认值)
|
||||||
|
"log_level": "info", # 其他可选:‘debug’, ‘info’, ‘warning’, ‘error’ and ‘critical’(由低级别到高级别)
|
||||||
|
"logging_strategy": "steps", # 每隔一定步数记录一次日志
|
||||||
|
"logging_steps": 500, # 记录日志的步数间隔,可以是int也可以是(0~1)的float,当是float时表示总的训练步数的ratio(比方说可以设置成1.0 / 2000)
|
||||||
|
#+通常与eval_steps一致
|
||||||
|
"logging_nan_inf_filter": False, # 对loss=nan或inf进行记录
|
||||||
|
|
||||||
|
"num_train_epochs": 10, # 总的训练轮数
|
||||||
|
# "max_steps": 3, # 训练的最大步骤数。如果设置了这个参数,
|
||||||
|
#+那么num_train_epochs将被忽略(通常用于调试)
|
||||||
|
|
||||||
|
# "label_names": ['your_label_name'], # 指定data_loader中的标签名,如果不指定则默认为'labels'
|
||||||
|
|
||||||
|
"per_device_train_batch_size": 128, # 每个GPU的batch size
|
||||||
|
"per_device_eval_batch_size": 16, # 每个GPU的evaluation batch size
|
||||||
|
"auto_find_batch_size": True, # 自动搜索合适的batch size(指数decay)
|
||||||
|
|
||||||
|
"optim": "adamw_torch", # 还提供了很多AdamW的变体(相较于经典的AdamW更加高效)
|
||||||
|
#+当设置了optim后,就不需要在Trainer中传入optimizer
|
||||||
|
"lr_scheduler_type": "cosine", # 设置lr_scheduler
|
||||||
|
"warmup_ratio": 0.1, # warmup占整个训练steps的比例(假如训练1000步,那么前100步就是从lr=0慢慢长到参数设定的lr)
|
||||||
|
# "warmup_steps": 500, # 预热步数, 这个参数与warmup_ratio是矛盾的
|
||||||
|
"weight_decay": 0, # 权重衰减
|
||||||
|
"learning_rate": 5e-5, # 学习率
|
||||||
|
"max_grad_norm": 1.0, # 用于梯度裁剪,确保梯度的范数不超过1.0(默认1.0)
|
||||||
|
"fp16": False, # 是否使用16位浮点数进行训练(一般不推荐,loss很容易炸)
|
||||||
|
"bf16": False, # 是否使用16位宽浮点数进行训练(如果架构支持的话推荐使用)
|
||||||
|
"gradient_accumulation_steps": 2, # 梯度累积步数,当batch size无法开很大时,可以考虑这个参数来实现大batch size的效果
|
||||||
|
"gradient_checkpointing": False, # 当为True时,会在forward时适当丢弃一些中间量(用于backward),从而减轻显存压力(但会增加forward的时间)
|
||||||
|
"label_smoothing_factor": 0.0, # softlabel,等于0时表示未开启
|
||||||
|
# "debug": "underflow_overflow", # 训练时检查溢出,如果发生,则会发出警告。(该模式通常用于debug)
|
||||||
|
"jit_mode_eval": True, # 是否在eval的时候使用PyTorch jit trace(可以加速模型,但模型必须是静态的,否则会报错)
|
||||||
|
"torch_compile": True, # 是否使用torch.compile来编译模型(从而获得更好的训练和推理性能)
|
||||||
|
#+ 要求torch > 2.0,这个功能很好使,当模型跑通的时候可以开起来
|
||||||
|
# "deepspeed": "your_json_path", # 使用deepspeed来训练,需要指定ds_config.json的路径
|
||||||
|
#+ 在Trainer中使用Deepspeed时一定要注意ds_config.json中的配置是否与Trainer的一致(如学习率,batch size,梯度累积步数等)
|
||||||
|
#+ 如果不一致,会出现很奇怪的bug(而且一般还很难发现)
|
||||||
|
|
||||||
|
"dataloader_pin_memory": True, # 可以加快数据在cpu和gpu之间转移的速度
|
||||||
|
"dataloader_num_workers": 16, # 默认不会使用多进程来加载数据,通常设成4*所用的显卡数
|
||||||
|
"dataloader_drop_last": True, # 丢掉最后一个minibatch,保证训练的梯度稳定
|
||||||
|
|
||||||
|
"evaluation_strategy": "steps", # 评估策略,可以是"steps"或"epoch"
|
||||||
|
"eval_steps": 500, # if evaluation_strategy="step"
|
||||||
|
#+默认情况下与logging_steps一样,可以是int也可以是(0~1)的float,当是float时表示总的训练步数的ratio(比方说可以设置成1.0 / 2000)
|
||||||
|
|
||||||
|
"save_strategy": "steps", # 保存checkpoint的策略
|
||||||
|
"save_steps": 500, # checkpoint保存的步数间隔,可以是int也可以是(0~1)的float,当是float时表示总的训练步数的ratio(比方说可以设置成1.0 / 2000)
|
||||||
|
"save_total_limit": 5, # 保存的模型的最大数量。如果超过这个数量,最旧的模型将被删除
|
||||||
|
|
||||||
|
"load_best_model_at_end": True, # 训练结束时是否加载最佳模型
|
||||||
|
#+当设置True时,会保存训练时评估结果最好的checkpoint
|
||||||
|
#+当设置True时,evaluation_strategy必须与save_strategy一样,并且save_steps必须是eval_steps的整数倍
|
||||||
|
"metric_for_best_model": "eval_loss", # 用于选择最佳模型的指标(必须与load_best_model_at_end一起用)
|
||||||
|
#+可以使用compute_metrics输出的evaluation的结果中(一个字典)的某个值
|
||||||
|
#+注意:Trainer会在compute_metrics输出的字典的键前面加上一个prefix,默认就是“eval_”
|
||||||
|
"greater_is_better": False, # 指标值越小越好(必须与metric_for_best_model一起用)
|
||||||
|
|
||||||
|
"do_train": True, # 是否进行训练,通常用于调试
|
||||||
|
"do_eval": True, # 是否进行评估,通常用于调试
|
||||||
|
|
||||||
|
"remove_unused_columns": False, # 是否删除没有用到的列(特征),默认为True
|
||||||
|
#+当删除了没用到的列后,making it easier to unpack inputs into the model’s call function
|
||||||
|
#+注意:remove_unused_columns去除列的操作会把传入的dataset的columns_names与模型forward方法中的参数名进行配对,对于不存在forward方法中的列名就会直接删掉整个feature
|
||||||
|
#+因此如果在dataset.with_transform(..)中给数据进行改名,那么这个remove操作会直接把原始的数据直接删掉,从而导致之后会拿到一个空的dataset,导致在对dataset进行切片取值时出问题
|
||||||
|
#+例如读进来的dataset图片对应的feature name叫"images",而模型forward方法中对应的参数名叫“pixel_values”,
|
||||||
|
#+此时如果是在data.withtransfrom(..)中根据这个"images"生成其他模型forward方法中需要的参数,然后再把"images"改名成“pixel_values”,那么整个过程就会出问题
|
||||||
|
#+因为设置了remove_unused_columns=True后,会先给dataset进行列名检查,然后“images”这个feature会直接被删掉(导致with_transform的transform_fn拿不到“images”这个feature)
|
||||||
|
#+所以一个good practice就是:对于要改名的特征,先提前使用dataset.rename_column进行改名
|
||||||
|
|
||||||
|
"push_to_hub": False, # 是否训练完后上传hub,需要先在命令行:huggingface-cli login进行登录认证的配置,配置完后,认证信息会存到cache文件夹里
|
||||||
|
}
|
||||||
17
src/models/ocr_model/utils/metrics.py
Normal file
17
src/models/ocr_model/utils/metrics.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
import evaluate
|
||||||
|
import numpy as np
|
||||||
|
from transformers import EvalPrediction, RobertaTokenizer
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
def bleu_metric(eval_preds:EvalPrediction, tokenizer:RobertaTokenizer) -> Dict:
|
||||||
|
metric = evaluate.load('/home/lhy/code/TeXify/src/models/ocr_model/train/google_bleu/google_bleu.py') # 这里需要联网,所以会卡住
|
||||||
|
|
||||||
|
logits, labels = eval_preds.predictions, eval_preds.label_ids
|
||||||
|
preds = logits
|
||||||
|
# preds = np.argmax(logits, axis=1) # 把logits转成对应的预测标签
|
||||||
|
|
||||||
|
labels = np.where(labels == -100, 1, labels)
|
||||||
|
|
||||||
|
preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||||
|
labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
||||||
|
return metric.compute(predictions=preds, references=labels)
|
||||||
Reference in New Issue
Block a user