diff --git a/.swp b/.swp
deleted file mode 100644
index ff496de..0000000
Binary files a/.swp and /dev/null differ
diff --git a/src/models/ocr_model/inference.py b/src/models/ocr_model/inference.py
index bc747ac..ce80963 100644
--- a/src/models/ocr_model/inference.py
+++ b/src/models/ocr_model/inference.py
@@ -1,4 +1,7 @@
 import torch
+import cv2
+import numpy as np
+
 from transformers import RobertaTokenizerFast, GenerationConfig
 from PIL import Image
 from typing import List
@@ -8,20 +11,52 @@
 from .utils.transforms import inference_transform
 from ...globals import MAX_TOKEN_SIZE
 
-def png2jpg(imgs: List[Image.Image]):
-    imgs = [img.convert('RGB') for img in imgs if img.mode in ("RGBA", "P")]
-    return imgs
+def convert2rgb(image_paths: List[str]) -> List[Image.Image]:
+    processed_images = []
+
+    for path in image_paths:
+        # Read the image
+        image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+
+        if image is None:
+            print(f"Image at {path} could not be read.")
+            continue
+
+        # Check whether the image is stored as uint16
+        if image.dtype == np.uint16:
+            raise ValueError(f"Image at {path} is stored in uint16, which is not supported.")
+
+        # Get the number of channels
+        channels = 1 if len(image.shape) == 2 else image.shape[2]
+
+        # RGBA (4 channels): convert to RGB
+        if channels == 4:
+            image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
+
+        # "I" mode (single-channel grayscale): convert to RGB
+        elif channels == 1:
+            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
+
+        # BGR (3 channels): convert to RGB
+        elif channels == 3:
+            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        processed_images.append(Image.fromarray(image))
+
+    return processed_images
 
-def inference(model: TexTeller, imgs: List[Image.Image], tokenizer: RobertaTokenizerFast) -> List[str]:
-    imgs = png2jpg(imgs) if imgs[0].mode in ('RGBA' ,'P') else imgs
+def inference(model: TexTeller, imgs_path: List[str], tokenizer: RobertaTokenizerFast) -> List[str]:
+    imgs = convert2rgb(imgs_path)
     imgs = inference_transform(imgs)
     pixel_values = torch.stack(imgs)
     generate_config = GenerationConfig(
         max_new_tokens=MAX_TOKEN_SIZE,
         num_beams=3,
-        do_sample=False
+        do_sample=False,
+        pad_token_id=tokenizer.pad_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+        bos_token_id=tokenizer.bos_token_id,
     )
     pred = model.generate(pixel_values, generation_config=generate_config)
     res = tokenizer.batch_decode(pred, skip_special_tokens=True)
diff --git a/src/models/ocr_model/model/TexTeller.py b/src/models/ocr_model/model/TexTeller.py
index 7f38ec3..ede1260 100644
--- a/src/models/ocr_model/model/TexTeller.py
+++ b/src/models/ocr_model/model/TexTeller.py
@@ -39,16 +39,21 @@ class TexTeller(VisionEncoderDecoderModel):
 if __name__ == "__main__":
     # texteller = TexTeller()
     from ..inference import inference
-    model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-22500')
+    model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-57500')
     tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
-    img1 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/1.png')
-    img2 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/2.png')
-    img3 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/3.png')
-    img4 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/4.png')
-    img5 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/5.png')
-    img6 = Image.open('/home/lhy/code/TeXify/src/models/ocr_model/model/6.png')
+    base = '/home/lhy/code/TeXify/src/models/ocr_model/model'
+
imgs_path = [ + # base + '/1.jpg', + # base + '/2.jpg', + # base + '/3.jpg', + # base + '/4.jpg', + # base + '/5.jpg', + # base + '/6.jpg', + base + '/foo.jpg' + ] - res = inference(model, [img1, img2, img3, img4, img5, img6], tokenizer) + # res = inference(model, [img1, img2, img3, img4, img5, img6, img7], tokenizer) + res = inference(model, imgs_path, tokenizer) pause = 1 diff --git a/src/models/ocr_model/train/debug_dir/runs/Jan29_11-44-00_ubuntu-xyp/events.out.tfevents.1706528656.ubuntu-xyp.1184426.0 b/src/models/ocr_model/train/debug_dir/runs/Jan29_11-44-00_ubuntu-xyp/events.out.tfevents.1706528656.ubuntu-xyp.1184426.0 new file mode 100644 index 0000000..f9a268a Binary files /dev/null and b/src/models/ocr_model/train/debug_dir/runs/Jan29_11-44-00_ubuntu-xyp/events.out.tfevents.1706528656.ubuntu-xyp.1184426.0 differ diff --git a/src/models/ocr_model/train/debug_dir/runs/Jan29_11-44-43_ubuntu-xyp/events.out.tfevents.1706528694.ubuntu-xyp.1185240.0 b/src/models/ocr_model/train/debug_dir/runs/Jan29_11-44-43_ubuntu-xyp/events.out.tfevents.1706528694.ubuntu-xyp.1185240.0 new file mode 100644 index 0000000..4695c5c Binary files /dev/null and b/src/models/ocr_model/train/debug_dir/runs/Jan29_11-44-43_ubuntu-xyp/events.out.tfevents.1706528694.ubuntu-xyp.1185240.0 differ diff --git a/src/models/ocr_model/train/debug_dir/runs/Jan30_06-17-24_ubuntu-xyp/events.out.tfevents.1706595465.ubuntu-xyp.1434641.0 b/src/models/ocr_model/train/debug_dir/runs/Jan30_06-17-24_ubuntu-xyp/events.out.tfevents.1706595465.ubuntu-xyp.1434641.0 new file mode 100644 index 0000000..6efff40 Binary files /dev/null and b/src/models/ocr_model/train/debug_dir/runs/Jan30_06-17-24_ubuntu-xyp/events.out.tfevents.1706595465.ubuntu-xyp.1434641.0 differ diff --git a/src/models/ocr_model/train/debug_dir/runs/Jan30_06-18-27_ubuntu-xyp/events.out.tfevents.1706595552.ubuntu-xyp.1435357.0 b/src/models/ocr_model/train/debug_dir/runs/Jan30_06-18-27_ubuntu-xyp/events.out.tfevents.1706595552.ubuntu-xyp.1435357.0 new file mode 100644 index 0000000..f8e143a Binary files /dev/null and b/src/models/ocr_model/train/debug_dir/runs/Jan30_06-18-27_ubuntu-xyp/events.out.tfevents.1706595552.ubuntu-xyp.1435357.0 differ diff --git a/src/models/ocr_model/train/google_bleu/google_bleu.py b/src/models/ocr_model/train/google_bleu/google_bleu.py new file mode 100644 index 0000000..adcc0a3 --- /dev/null +++ b/src/models/ocr_model/train/google_bleu/google_bleu.py @@ -0,0 +1,168 @@ +# Copyright 2020 The HuggingFace Evaluate Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Google BLEU (aka GLEU) metric. """ + +from typing import Dict, List + +import datasets +from nltk.translate import gleu_score + +import evaluate +from evaluate import MetricInfo + +from .tokenizer_13a import Tokenizer13a + + +_CITATION = """\ +@misc{wu2016googles, + title={Google's Neural Machine Translation System: Bridging the Gap between Human and Machine Translation}, + author={Yonghui Wu and Mike Schuster and Zhifeng Chen and Quoc V. 
Le and Mohammad Norouzi and Wolfgang Macherey + and Maxim Krikun and Yuan Cao and Qin Gao and Klaus Macherey and Jeff Klingner and Apurva Shah and Melvin + Johnson and Xiaobing Liu and Łukasz Kaiser and Stephan Gouws and Yoshikiyo Kato and Taku Kudo and Hideto + Kazawa and Keith Stevens and George Kurian and Nishant Patil and Wei Wang and Cliff Young and + Jason Smith and Jason Riesa and Alex Rudnick and Oriol Vinyals and Greg Corrado and Macduff Hughes + and Jeffrey Dean}, + year={2016}, + eprint={1609.08144}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +""" + +_DESCRIPTION = """\ +The BLEU score has some undesirable properties when used for single +sentences, as it was designed to be a corpus measure. We therefore +use a slightly different score for our RL experiments which we call +the 'GLEU score'. For the GLEU score, we record all sub-sequences of +1, 2, 3 or 4 tokens in output and target sequence (n-grams). We then +compute a recall, which is the ratio of the number of matching n-grams +to the number of total n-grams in the target (ground truth) sequence, +and a precision, which is the ratio of the number of matching n-grams +to the number of total n-grams in the generated output sequence. Then +GLEU score is simply the minimum of recall and precision. This GLEU +score's range is always between 0 (no matches) and 1 (all match) and +it is symmetrical when switching output and target. According to +our experiments, GLEU score correlates quite well with the BLEU +metric on a corpus level but does not have its drawbacks for our per +sentence reward objective. +""" + +_KWARGS_DESCRIPTION = """\ +Computes corpus-level Google BLEU (GLEU) score of translated segments against one or more references. +Instead of averaging the sentence level GLEU scores (i.e. macro-average precision), Wu et al. (2016) sum up the matching +tokens and the max of hypothesis and reference tokens for each sentence, then compute using the aggregate values. + +Args: + predictions (list of str): list of translations to score. + references (list of list of str): list of lists of references for each translation. + tokenizer : approach used for tokenizing `predictions` and `references`. + The default tokenizer is `tokenizer_13a`, a minimal tokenization approach that is equivalent to `mteval-v13a`, used by WMT. + This can be replaced by any function that takes a string as input and returns a list of tokens as output. + min_len (int): The minimum order of n-gram this function should extract. Defaults to 1. + max_len (int): The maximum order of n-gram this function should extract. Defaults to 4. 
+ +Returns: + 'google_bleu': google_bleu score + +Examples: + Example 1: + >>> predictions = ['It is a guide to action which ensures that the rubber duck always disobeys the commands of the cat', \ + 'he read the book because he was interested in world history'] + >>> references = [['It is the guiding principle which guarantees the rubber duck forces never being under the command of the cat'], \ + ['he was interested in world history because he read the book']] + >>> google_bleu = evaluate.load("google_bleu") + >>> results = google_bleu.compute(predictions=predictions, references=references) + >>> print(round(results["google_bleu"], 2)) + 0.44 + + Example 2: + >>> predictions = ['It is a guide to action which ensures that the rubber duck always disobeys the commands of the cat', \ + 'he read the book because he was interested in world history'] + >>> references = [['It is the guiding principle which guarantees the rubber duck forces never being under the command of the cat', \ + 'It is a guide to action that ensures that the rubber duck will never heed the cat commands', \ + 'It is the practical guide for the rubber duck army never to heed the directions of the cat'], \ + ['he was interested in world history because he read the book']] + >>> google_bleu = evaluate.load("google_bleu") + >>> results = google_bleu.compute(predictions=predictions, references=references) + >>> print(round(results["google_bleu"], 2)) + 0.61 + + Example 3: + >>> predictions = ['It is a guide to action which ensures that the rubber duck always disobeys the commands of the cat', \ + 'he read the book because he was interested in world history'] + >>> references = [['It is the guiding principle which guarantees the rubber duck forces never being under the command of the cat', \ + 'It is a guide to action that ensures that the rubber duck will never heed the cat commands', \ + 'It is the practical guide for the rubber duck army never to heed the directions of the cat'], \ + ['he was interested in world history because he read the book']] + >>> google_bleu = evaluate.load("google_bleu") + >>> results = google_bleu.compute(predictions=predictions, references=references, min_len=2) + >>> print(round(results["google_bleu"], 2)) + 0.53 + + Example 4: + >>> predictions = ['It is a guide to action which ensures that the rubber duck always disobeys the commands of the cat', \ + 'he read the book because he was interested in world history'] + >>> references = [['It is the guiding principle which guarantees the rubber duck forces never being under the command of the cat', \ + 'It is a guide to action that ensures that the rubber duck will never heed the cat commands', \ + 'It is the practical guide for the rubber duck army never to heed the directions of the cat'], \ + ['he was interested in world history because he read the book']] + >>> google_bleu = evaluate.load("google_bleu") + >>> results = google_bleu.compute(predictions=predictions,references=references, min_len=2, max_len=6) + >>> print(round(results["google_bleu"], 2)) + 0.4 +""" + + +@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) +class GoogleBleu(evaluate.Metric): + def _info(self) -> MetricInfo: + return evaluate.MetricInfo( + description=_DESCRIPTION, + citation=_CITATION, + inputs_description=_KWARGS_DESCRIPTION, + features=[ + datasets.Features( + { + "predictions": datasets.Value("string", id="sequence"), + "references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"), + } + ), + 
datasets.Features( + { + "predictions": datasets.Value("string", id="sequence"), + "references": datasets.Value("string", id="sequence"), + } + ), + ], + ) + + def _compute( + self, + predictions: List[str], + references: List[List[str]], + tokenizer=Tokenizer13a(), + min_len: int = 1, + max_len: int = 4, + ) -> Dict[str, float]: + # if only one reference is provided make sure we still use list of lists + if isinstance(references[0], str): + references = [[ref] for ref in references] + + references = [[tokenizer(r) for r in ref] for ref in references] + predictions = [tokenizer(p) for p in predictions] + return { + "google_bleu": gleu_score.corpus_gleu( + list_of_references=references, hypotheses=predictions, min_len=min_len, max_len=max_len + ) + } diff --git a/src/models/ocr_model/train/google_bleu/tokenizer_13a.py b/src/models/ocr_model/train/google_bleu/tokenizer_13a.py new file mode 100644 index 0000000..c7a1b3d --- /dev/null +++ b/src/models/ocr_model/train/google_bleu/tokenizer_13a.py @@ -0,0 +1,100 @@ +# Source: https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/tokenizers/tokenizer_13a.py +# Copyright 2020 SacreBLEU Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from functools import lru_cache + + +class BaseTokenizer: + """A base dummy tokenizer to derive from.""" + + def signature(self): + """ + Returns a signature for the tokenizer. + :return: signature string + """ + return "none" + + def __call__(self, line): + """ + Tokenizes an input line with the tokenizer. + :param line: a segment to tokenize + :return: the tokenized line + """ + return line + + +class TokenizerRegexp(BaseTokenizer): + def signature(self): + return "re" + + def __init__(self): + self._re = [ + # language-dependent part (assuming Western languages) + (re.compile(r"([\{-\~\[-\` -\&\(-\+\:-\@\/])"), r" \1 "), + # tokenize period and comma unless preceded by a digit + (re.compile(r"([^0-9])([\.,])"), r"\1 \2 "), + # tokenize period and comma unless followed by a digit + (re.compile(r"([\.,])([^0-9])"), r" \1 \2"), + # tokenize dash when preceded by a digit + (re.compile(r"([0-9])(-)"), r"\1 \2 "), + # one space only between words + # NOTE: Doing this in Python (below) is faster + # (re.compile(r'\s+'), r' '), + ] + + @lru_cache(maxsize=2**16) + def __call__(self, line): + """Common post-processing tokenizer for `13a` and `zh` tokenizers. 
+        :param line: a segment to tokenize
+        :return: the tokenized line
+        """
+        for (_re, repl) in self._re:
+            line = _re.sub(repl, line)
+
+        # no leading or trailing spaces, single space within words
+        # return ' '.join(line.split())
+        # This line is changed with regards to the original tokenizer (seen above) to return individual words
+        return line.split()
+
+
+class Tokenizer13a(BaseTokenizer):
+    def signature(self):
+        return "13a"
+
+    def __init__(self):
+        self._post_tokenizer = TokenizerRegexp()
+
+    @lru_cache(maxsize=2**16)
+    def __call__(self, line):
+        """Tokenizes an input line using a relatively minimal tokenization
+        that is however equivalent to mteval-v13a, used by WMT.
+
+        :param line: a segment to tokenize
+        :return: the tokenized line
+        """
+
+        # language-independent part:
+        line = line.replace("<skipped>", "")
+        line = line.replace("-\n", "")
+        line = line.replace("\n", " ")
+
+        if "&" in line:
+            line = line.replace("&quot;", '"')
+            line = line.replace("&amp;", "&")
+            line = line.replace("&lt;", "<")
+            line = line.replace("&gt;", ">")
+
+        return self._post_tokenizer(f" {line} ")
diff --git a/src/models/ocr_model/train/train.py b/src/models/ocr_model/train/train.py
index b0c8246..5b39b31 100644
--- a/src/models/ocr_model/train/train.py
+++ b/src/models/ocr_model/train/train.py
@@ -4,120 +4,16 @@
 from functools import partial
 from pathlib import Path
 from datasets import load_dataset
-from transformers import Trainer, TrainingArguments, Seq2SeqTrainer
+from transformers import Trainer, TrainingArguments, Seq2SeqTrainer, Seq2SeqTrainingArguments, GenerationConfig
+from .training_args import CONFIG
 from ..model.TexTeller import TexTeller
 from ..utils.preprocess import tokenize_fn, collate_fn, img_preprocess
-
-training_args = TrainingArguments(
-    seed=42,  # Random seed, for reproducibility of experiments
-    use_cpu=False,  # Whether to run on CPU (running on CPU first makes early debugging easier)
-    # data_seed=42,  # Also fixes the data sampler's sampling
-    # full_determinism=True,  # Makes the whole training run fully deterministic (harmful to training; only for debugging)
-
-    output_dir="train_result",  # Output directory
-    overwrite_output_dir=False,  # If the output directory exists, do not delete its previous contents
-    report_to=["tensorboard"],  # Log to TensorBoard,
-    #+view the logs from the command line with: tensorboard --logdir ./logs
-
-    logging_dir=None,  # Directory for the TensorBoard log files (use the default)
-    log_level="info",  # Other options: 'debug', 'info', 'warning', 'error' and 'critical' (from lowest to highest level)
-    logging_strategy="steps",  # Log every fixed number of steps
-    logging_steps=500,  # Step interval for logging; either an int or a float in (0, 1), where a float is a ratio of the total training steps (e.g. 1.0 / 2000)
-    #+usually kept consistent with eval_steps
-    logging_nan_inf_filter=False,  # Also log steps where the loss is nan or inf
-
-    num_train_epochs=10,  # Total number of training epochs
-    # max_steps=3,  # Maximum number of training steps. If set,
-    #+num_train_epochs is ignored (usually for debugging)
-
-    # label_names = ['your_label_name'],  # Name of the labels in the data_loader; defaults to 'labels' if not set
-
-    per_device_train_batch_size=128,  # Batch size per GPU
-    per_device_eval_batch_size=16,  # Evaluation batch size per GPU
-    auto_find_batch_size=True,  # Automatically search for a workable batch size (exponential decay)
-
-    optim = 'adamw_torch',  # Several AdamW variants are also available (more efficient than the classic AdamW)
-    #+once optim is set, there is no need to pass an optimizer to the Trainer
-    lr_scheduler_type="cosine",  # The lr_scheduler to use
-    warmup_ratio=0.1,  # Fraction of the total training steps used for warmup (with 1000 steps, the first 100 ramp the lr from 0 up to the configured value)
-    # warmup_steps=500,  # Number of warmup steps; this parameter conflicts with warmup_ratio
-    weight_decay=0,  # Weight decay
-    learning_rate=5e-5,  # Learning rate
-    max_grad_norm=1.0,  # Gradient clipping: keeps the gradient norm below 1.0 (default 1.0)
-    fp16=False,  # Whether to train in 16-bit floats (generally not recommended; the loss blows up easily)
-    bf16=False,  # Whether to train in bfloat16 (recommended if the hardware supports it)
-    gradient_accumulation_steps=2,  # Gradient accumulation steps; when the batch size cannot be made large, use this to emulate a large batch size
-    gradient_checkpointing=False,  # When True, some intermediate activations are dropped during forward (recomputed for backward), reducing memory pressure at the cost of a slower forward pass
-    label_smoothing_factor=0.0,  # Soft labels; 0 means disabled
-    # debug='underflow_overflow',  # Check for overflow during training and warn if it happens (usually for debugging)
-    jit_mode_eval=True,  # Whether to use PyTorch JIT tracing for eval (can speed up the model, but the model must be static, otherwise it errors)
-    torch_compile=True,  # Whether to compile the model with torch.compile (for better training and inference performance)
-    #+ requires torch > 2.0; this works very well, turn it on once the model runs correctly
-    # deepspeed='your_json_path',  # Train with DeepSpeed; requires the path to a ds_config.json
-    #+ when using DeepSpeed with the Trainer, make sure ds_config.json matches the Trainer settings (learning rate, batch size, gradient accumulation steps, etc.)
-    #+ otherwise very strange bugs appear (and they are usually hard to find)
-
-    dataloader_pin_memory=True,  # Speeds up data transfer between CPU and GPU
-    dataloader_num_workers=16,  # Multiprocess data loading is off by default; usually set to 4x the number of GPUs used
-    dataloader_drop_last=True,  # Drop the last mini-batch to keep the training gradients stable
-
-    evaluation_strategy="steps",  # Evaluation strategy, either "steps" or "epoch"
-    eval_steps=500,  # if evaluation_strategy="step"
-    #+defaults to the same value as logging_steps; either an int or a float in (0, 1), where a float is a ratio of the total training steps (e.g. 1.0 / 2000)
-
-    save_strategy="steps",  # Checkpoint saving strategy
-    save_steps=500,  # Step interval for saving checkpoints; either an int or a float in (0, 1), where a float is a ratio of the total training steps (e.g. 1.0 / 2000)
-    save_total_limit=5,  # Maximum number of checkpoints to keep. Beyond this limit the oldest ones are deleted
-
-    load_best_model_at_end=True,  # Whether to load the best model at the end of training
-    #+when True, the checkpoint with the best evaluation result is kept
-    #+when True, evaluation_strategy must match save_strategy, and save_steps must be an integer multiple of eval_steps
-    metric_for_best_model="eval_loss",  # Metric used to select the best model (must be used together with load_best_model_at_end)
-    #+can be any value from the dict returned by compute_metrics
-    #+note: the Trainer adds a prefix to the keys of the compute_metrics output, by default "eval_"
-    greater_is_better=False,  # Smaller metric values are better (must be used together with metric_for_best_model)
-
-    do_train=True,  # Whether to train; usually for debugging
-    do_eval=True,  # Whether to evaluate; usually for debugging
-
-    remove_unused_columns=False,  # Whether to drop unused columns (features); defaults to True
-    #+dropping unused columns makes it easier to unpack inputs into the model's call function
-    #+note: remove_unused_columns matches the dataset's column_names against the parameter names of the model's forward method and drops every feature whose name does not appear in forward
-    #+so if you rename data inside dataset.with_transform(..), this removal can delete the original data first, leaving an empty dataset and breaking later slicing/indexing of the dataset
-    #+for example, the image feature of the loaded dataset is called "images" while the corresponding forward parameter is called "pixel_values",
-    #+if you generate the other forward inputs from "images" inside dataset.with_transform(..) and only then rename "images" to "pixel_values", the whole pipeline breaks,
-    #+because with remove_unused_columns=True the column check runs first and the "images" feature is dropped immediately (so with_transform's transform_fn never sees "images")
-    #+so a good practice is: rename such features up front with dataset.rename_column
-
-    push_to_hub=False,  # Whether to push to the Hub after training; first log in with: huggingface-cli login (the credentials are then stored in the cache folder)
-)
+from ..utils.metrics import bleu_metric
+from ....globals import MAX_TOKEN_SIZE
 
-def main():
-    # dataset = load_dataset(
-    #     '/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
-    #     'cleaned_formulas'
-    # )['train'].select(range(500))
-    dataset = load_dataset(
-        '/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
-        'cleaned_formulas'
-    )['train']
-
-    tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
-
-    map_fn = partial(tokenize_fn, tokenizer=tokenizer)
-    collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
-
-    tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8)
-    # tokenized_formula = tokenized_formula.to_dict()
-    # tokenized_formula['pixel_values'] = dataset['image']
-    # tokenized_dataset = dataset.from_dict(tokenized_formula)
-    tokenized_dataset = tokenized_dataset.with_transform(img_preprocess)
-    split_dataset = tokenized_dataset.train_test_split(test_size=0.05, seed=42)
-    train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
-
-    model = TexTeller()
-
+def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
+    training_args = TrainingArguments(**CONFIG)
     trainer = Trainer(
         model,
         training_args,
@@ -132,19 +28,40 @@ def main():
     trainer.train(resume_from_checkpoint=None)
 
-    """
-    Another example of a metric_function:
+def evaluate(model, tokenizer, eval_dataset, collate_fn):
+    eval_config = CONFIG.copy()
+    generate_config = GenerationConfig(
+        max_new_tokens=MAX_TOKEN_SIZE,
+        num_beams=1,
+        do_sample=False,
+        pad_token_id=tokenizer.pad_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+        bos_token_id=tokenizer.bos_token_id,
+    )
+    eval_config['output_dir'] = 'debug_dir'
+    eval_config['predict_with_generate'] = True
+    eval_config['dataloader_num_workers'] = 1
+    eval_config['jit_mode_eval'] = False
+    eval_config['torch_compile'] = False
+    eval_config['auto_find_batch_size'] = False
+    eval_config['generation_config'] = generate_config
+    seq2seq_config = Seq2SeqTrainingArguments(**eval_config)
 
-    # Setup evaluation
-    metric = evaluate.load("accuracy")
+    trainer = Seq2SeqTrainer(
+        model,
+        seq2seq_config,
 
-    def compute_metrics(eval_pred):
-        logits, labels = eval_pred
-        predictions = np.argmax(logits, axis=-1)
-        return metric.compute(predictions=predictions, references=labels)
-    """
+        eval_dataset=eval_dataset.select(range(16)),
+        tokenizer=tokenizer,
+        data_collator=collate_fn,
+        compute_metrics=partial(bleu_metric, tokenizer=tokenizer)
+    )
+
+    res = trainer.evaluate()
     pause = 1
-    model.generate()
+    ...
+

 if __name__ == '__main__':
@@ -152,7 +69,32 @@ if __name__ == '__main__':
     script_dirpath = Path(__file__).resolve().parent
     os.chdir(script_dirpath)
-    main()
+
+    dataset = load_dataset(
+        '/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
+        'cleaned_formulas'
+    )['train']
+
+    pause = dataset[0]['image']
+    tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
+
+    map_fn = partial(tokenize_fn, tokenizer=tokenizer)
+    tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8)
+    tokenized_dataset = tokenized_dataset.with_transform(img_preprocess)
+
+    split_dataset = tokenized_dataset.train_test_split(test_size=0.05, seed=42)
+    train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
+    collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
+    # model = TexTeller()
+    model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-57500')
+
+    enable_train = False
+    enable_evaluate = True
+    if enable_train:
+        train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
+    if enable_evaluate:
+        evaluate(model, tokenizer, eval_dataset, collate_fn_with_tokenizer)
+
     os.chdir(cur_path)
diff --git a/src/models/ocr_model/train/training_args.py b/src/models/ocr_model/train/training_args.py
new file mode 100644
index 0000000..ddec056
--- /dev/null
+++ b/src/models/ocr_model/train/training_args.py
@@ -0,0 +1,83 @@
+CONFIG = {
+    "seed": 42,  # Random seed, for reproducibility of experiments
+    "use_cpu": False,  # Whether to run on CPU (running on CPU first makes early debugging easier)
+    # "data_seed": 42,  # Also fixes the data sampler's sampling
+    # "full_determinism": True,  # Makes the whole training run fully deterministic (harmful to training; only for debugging)
+
+    "output_dir": "train_result",  # Output directory
+    "overwrite_output_dir": False,  # If the output directory exists, do not delete its previous contents
+    "report_to": ["tensorboard"],  # Log to TensorBoard,
+    #+view the logs from the command line with: tensorboard --logdir ./logs
+
"logging_dir": None, # TensorBoard日志文件的存储目录(使用默认值) + "log_level": "info", # 其他可选:‘debug’, ‘info’, ‘warning’, ‘error’ and ‘critical’(由低级别到高级别) + "logging_strategy": "steps", # 每隔一定步数记录一次日志 + "logging_steps": 500, # 记录日志的步数间隔,可以是int也可以是(0~1)的float,当是float时表示总的训练步数的ratio(比方说可以设置成1.0 / 2000) + #+通常与eval_steps一致 + "logging_nan_inf_filter": False, # 对loss=nan或inf进行记录 + + "num_train_epochs": 10, # 总的训练轮数 + # "max_steps": 3, # 训练的最大步骤数。如果设置了这个参数, + #+那么num_train_epochs将被忽略(通常用于调试) + + # "label_names": ['your_label_name'], # 指定data_loader中的标签名,如果不指定则默认为'labels' + + "per_device_train_batch_size": 128, # 每个GPU的batch size + "per_device_eval_batch_size": 16, # 每个GPU的evaluation batch size + "auto_find_batch_size": True, # 自动搜索合适的batch size(指数decay) + + "optim": "adamw_torch", # 还提供了很多AdamW的变体(相较于经典的AdamW更加高效) + #+当设置了optim后,就不需要在Trainer中传入optimizer + "lr_scheduler_type": "cosine", # 设置lr_scheduler + "warmup_ratio": 0.1, # warmup占整个训练steps的比例(假如训练1000步,那么前100步就是从lr=0慢慢长到参数设定的lr) + # "warmup_steps": 500, # 预热步数, 这个参数与warmup_ratio是矛盾的 + "weight_decay": 0, # 权重衰减 + "learning_rate": 5e-5, # 学习率 + "max_grad_norm": 1.0, # 用于梯度裁剪,确保梯度的范数不超过1.0(默认1.0) + "fp16": False, # 是否使用16位浮点数进行训练(一般不推荐,loss很容易炸) + "bf16": False, # 是否使用16位宽浮点数进行训练(如果架构支持的话推荐使用) + "gradient_accumulation_steps": 2, # 梯度累积步数,当batch size无法开很大时,可以考虑这个参数来实现大batch size的效果 + "gradient_checkpointing": False, # 当为True时,会在forward时适当丢弃一些中间量(用于backward),从而减轻显存压力(但会增加forward的时间) + "label_smoothing_factor": 0.0, # softlabel,等于0时表示未开启 + # "debug": "underflow_overflow", # 训练时检查溢出,如果发生,则会发出警告。(该模式通常用于debug) + "jit_mode_eval": True, # 是否在eval的时候使用PyTorch jit trace(可以加速模型,但模型必须是静态的,否则会报错) + "torch_compile": True, # 是否使用torch.compile来编译模型(从而获得更好的训练和推理性能) + #+ 要求torch > 2.0,这个功能很好使,当模型跑通的时候可以开起来 + # "deepspeed": "your_json_path", # 使用deepspeed来训练,需要指定ds_config.json的路径 + #+ 在Trainer中使用Deepspeed时一定要注意ds_config.json中的配置是否与Trainer的一致(如学习率,batch size,梯度累积步数等) + #+ 如果不一致,会出现很奇怪的bug(而且一般还很难发现) + + "dataloader_pin_memory": True, # 可以加快数据在cpu和gpu之间转移的速度 + "dataloader_num_workers": 16, # 默认不会使用多进程来加载数据,通常设成4*所用的显卡数 + "dataloader_drop_last": True, # 丢掉最后一个minibatch,保证训练的梯度稳定 + + "evaluation_strategy": "steps", # 评估策略,可以是"steps"或"epoch" + "eval_steps": 500, # if evaluation_strategy="step" + #+默认情况下与logging_steps一样,可以是int也可以是(0~1)的float,当是float时表示总的训练步数的ratio(比方说可以设置成1.0 / 2000) + + "save_strategy": "steps", # 保存checkpoint的策略 + "save_steps": 500, # checkpoint保存的步数间隔,可以是int也可以是(0~1)的float,当是float时表示总的训练步数的ratio(比方说可以设置成1.0 / 2000) + "save_total_limit": 5, # 保存的模型的最大数量。如果超过这个数量,最旧的模型将被删除 + + "load_best_model_at_end": True, # 训练结束时是否加载最佳模型 + #+当设置True时,会保存训练时评估结果最好的checkpoint + #+当设置True时,evaluation_strategy必须与save_strategy一样,并且save_steps必须是eval_steps的整数倍 + "metric_for_best_model": "eval_loss", # 用于选择最佳模型的指标(必须与load_best_model_at_end一起用) + #+可以使用compute_metrics输出的evaluation的结果中(一个字典)的某个值 + #+注意:Trainer会在compute_metrics输出的字典的键前面加上一个prefix,默认就是“eval_” + "greater_is_better": False, # 指标值越小越好(必须与metric_for_best_model一起用) + + "do_train": True, # 是否进行训练,通常用于调试 + "do_eval": True, # 是否进行评估,通常用于调试 + + "remove_unused_columns": False, # 是否删除没有用到的列(特征),默认为True + #+当删除了没用到的列后,making it easier to unpack inputs into the model’s call function + #+注意:remove_unused_columns去除列的操作会把传入的dataset的columns_names与模型forward方法中的参数名进行配对,对于不存在forward方法中的列名就会直接删掉整个feature + #+因此如果在dataset.with_transform(..)中给数据进行改名,那么这个remove操作会直接把原始的数据直接删掉,从而导致之后会拿到一个空的dataset,导致在对dataset进行切片取值时出问题 + #+例如读进来的dataset图片对应的feature name叫"images",而模型forward方法中对应的参数名叫“pixel_values”, + 
#+此时如果是在data.withtransfrom(..)中根据这个"images"生成其他模型forward方法中需要的参数,然后再把"images"改名成“pixel_values”,那么整个过程就会出问题 + #+因为设置了remove_unused_columns=True后,会先给dataset进行列名检查,然后“images”这个feature会直接被删掉(导致with_transform的transform_fn拿不到“images”这个feature) + #+所以一个good practice就是:对于要改名的特征,先提前使用dataset.rename_column进行改名 + + "push_to_hub": False, # 是否训练完后上传hub,需要先在命令行:huggingface-cli login进行登录认证的配置,配置完后,认证信息会存到cache文件夹里 +} \ No newline at end of file diff --git a/src/models/ocr_model/utils/metrics.py b/src/models/ocr_model/utils/metrics.py new file mode 100644 index 0000000..a21d131 --- /dev/null +++ b/src/models/ocr_model/utils/metrics.py @@ -0,0 +1,17 @@ +import evaluate +import numpy as np +from transformers import EvalPrediction, RobertaTokenizer +from typing import Dict + +def bleu_metric(eval_preds:EvalPrediction, tokenizer:RobertaTokenizer) -> Dict: + metric = evaluate.load('/home/lhy/code/TeXify/src/models/ocr_model/train/google_bleu/google_bleu.py') # 这里需要联网,所以会卡住 + + logits, labels = eval_preds.predictions, eval_preds.label_ids + preds = logits + # preds = np.argmax(logits, axis=1) # 把logits转成对应的预测标签 + + labels = np.where(labels == -100, 1, labels) + + preds = tokenizer.batch_decode(preds, skip_special_tokens=True) + labels = tokenizer.batch_decode(labels, skip_special_tokens=True) + return metric.compute(predictions=preds, references=labels) \ No newline at end of file
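Usage sketch (not part of the patch): a minimal example of how the pieces changed above fit together, i.e. path-based inference through the new convert2rgb/inference API and scoring against a reference with the vendored Google BLEU metric. The checkpoint, tokenizer and image paths are placeholders, and it assumes the repository root is on PYTHONPATH so that src is importable as a package.

    import evaluate

    from src.models.ocr_model.model.TexTeller import TexTeller
    from src.models.ocr_model.inference import inference

    # Placeholder paths; point these at a real checkpoint, tokenizer directory and image.
    ckpt_dir = '/path/to/train_result/checkpoint-XXXX'
    tokenizer_dir = '/path/to/roberta-tokenizer-550Kformulas'
    img_path = '/path/to/formula.jpg'

    model = TexTeller.from_pretrained(ckpt_dir)
    tokenizer = TexTeller.get_tokenizer(tokenizer_dir)

    # inference() now takes image paths; convert2rgb normalizes RGBA/grayscale/BGR input.
    pred_latex = inference(model, [img_path], tokenizer)

    # Score the prediction with the vendored GLEU metric, loaded from the local script
    # the same way utils/metrics.py does.
    gleu = evaluate.load('src/models/ocr_model/train/google_bleu/google_bleu.py')
    print(gleu.compute(predictions=pred_latex, references=[[r'\frac { a } { b }']]))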