From 1fba652766149817c2c467b3881002a829fd4c8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=89=E6=B4=8B=E4=B8=89=E6=B4=8B?= <1258009915@qq.com> Date: Wed, 31 Jan 2024 10:11:07 +0000 Subject: [PATCH] Initial Commit --- src/globals.py | 4 + src/models/ocr_model/inference.py | 35 +--- ...t.tfevents.1706615305.ubuntu-xyp.1506187.0 | Bin 0 -> 349 bytes src/models/ocr_model/train/train.py | 21 ++- .../utils/{preprocess.py => functional.py} | 13 +- src/models/ocr_model/utils/helpers.py | 37 ++++ src/models/ocr_model/utils/transforms.py | 172 +++++++++++++----- 7 files changed, 193 insertions(+), 89 deletions(-) create mode 100644 src/models/ocr_model/train/debug_dir/runs/Jan30_11-48-17_ubuntu-xyp/events.out.tfevents.1706615305.ubuntu-xyp.1506187.0 rename src/models/ocr_model/utils/{preprocess.py => functional.py} (88%) create mode 100644 src/models/ocr_model/utils/helpers.py diff --git a/src/globals.py b/src/globals.py index cecb2a6..4ede0ca 100644 --- a/src/globals.py +++ b/src/globals.py @@ -32,6 +32,10 @@ OCR_IMG_CHANNELS = 1 # 灰度图 # ocr模型训练数据集的最长token数 MAX_TOKEN_SIZE = 600 +# ocr模型训练时随机缩放的比例 +MAX_RESIZE_RATIO = 1.15 +MIN_RESIZE_RATIO = 0.75 + # ============================================================================= # diff --git a/src/models/ocr_model/inference.py b/src/models/ocr_model/inference.py index ce80963..776dc9a 100644 --- a/src/models/ocr_model/inference.py +++ b/src/models/ocr_model/inference.py @@ -8,43 +8,10 @@ from typing import List from .model.TexTeller import TexTeller from .utils.transforms import inference_transform +from .utils.helpers import convert2rgb from ...globals import MAX_TOKEN_SIZE -def convert2rgb(image_paths: List[str]) -> List[Image.Image]: - processed_images = [] - - for path in image_paths: - # 读取图片 - image = cv2.imread(path, cv2.IMREAD_UNCHANGED) - - if image is None: - print(f"Image at {path} could not be read.") - continue - - # 检查图片是否使用 uint16 类型 - if image.dtype == np.uint16: - raise ValueError(f"Image at {path} is stored in uint16, which is not supported.") - - # 获取图片通道数 - channels = 1 if len(image.shape) == 2 else image.shape[2] - - # 如果是 RGBA (4通道), 转换为 RGB - if channels == 4: - image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB) - - # 如果是 I 模式 (单通道灰度图), 转换为 RGB - elif channels == 1: - image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) - - # 如果是 BGR (3通道), 转换为 RGB - elif channels == 3: - image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - processed_images.append(Image.fromarray(image)) - - return processed_images - - def inference(model: TexTeller, imgs_path: List[str], tokenizer: RobertaTokenizerFast) -> List[str]: imgs = convert2rgb(imgs_path) imgs = inference_transform(imgs) diff --git a/src/models/ocr_model/train/debug_dir/runs/Jan30_11-48-17_ubuntu-xyp/events.out.tfevents.1706615305.ubuntu-xyp.1506187.0 b/src/models/ocr_model/train/debug_dir/runs/Jan30_11-48-17_ubuntu-xyp/events.out.tfevents.1706615305.ubuntu-xyp.1506187.0 new file mode 100644 index 0000000000000000000000000000000000000000..e721a9108a70114be8ed1c4f3fbeb3a9151e193d GIT binary patch literal 349 zcmeZZfPjCKJmzxd7Uwou}pnqSPGp`e3u9gs&02gO!Sz?ZU zPJVH*XhC(g{m1KND^#G0owOT#IhS@p6-#o7a|wVIr|0LV15HoLNi7wftiRh%pM@=0 z0jfvy?aBa7=WeJTQ7&OF9VcqlGFke OjoUu|wY8SNFdG1-l5u7L literal 0 HcmV?d00001 diff --git a/src/models/ocr_model/train/train.py b/src/models/ocr_model/train/train.py index 5b39b31..a597fb5 100644 --- a/src/models/ocr_model/train/train.py +++ b/src/models/ocr_model/train/train.py @@ -1,13 +1,15 @@ import os +import numpy as np from functools import partial from pathlib import Path from datasets import load_dataset from transformers import Trainer, TrainingArguments, Seq2SeqTrainer, Seq2SeqTrainingArguments, GenerationConfig + from .training_args import CONFIG from ..model.TexTeller import TexTeller -from ..utils.preprocess import tokenize_fn, collate_fn, img_preprocess +from ..utils.functional import tokenize_fn, collate_fn, img_transform_fn from ..utils.metrics import bleu_metric from ....globals import MAX_TOKEN_SIZE @@ -38,6 +40,7 @@ def evaluate(model, tokenizer, eval_dataset, collate_fn): eos_token_id=tokenizer.eos_token_id, bos_token_id=tokenizer.bos_token_id, ) + # eval_config['use_cpu'] = True eval_config['output_dir'] = 'debug_dir' eval_config['predict_with_generate'] = True eval_config['predict_with_generate'] = True @@ -52,7 +55,7 @@ def evaluate(model, tokenizer, eval_dataset, collate_fn): model, seq2seq_config, - eval_dataset=eval_dataset.select(range(16)), + eval_dataset=eval_dataset, tokenizer=tokenizer, data_collator=collate_fn, compute_metrics=partial(bleu_metric, tokenizer=tokenizer) @@ -70,23 +73,27 @@ if __name__ == '__main__': os.chdir(script_dirpath) + # dataset = load_dataset( + # '/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py', + # 'cleaned_formulas' + # )['train'] dataset = load_dataset( '/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py', 'cleaned_formulas' - )['train'] + )['train'].select(range(1000)) - pause = dataset[0]['image'] tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas') map_fn = partial(tokenize_fn, tokenizer=tokenizer) - tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8) - tokenized_dataset = tokenized_dataset.with_transform(img_preprocess) + # tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8, load_from_cache_file=False) + tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=1, load_from_cache_file=False) + tokenized_dataset = tokenized_dataset.with_transform(img_transform_fn) split_dataset = tokenized_dataset.train_test_split(test_size=0.05, seed=42) train_dataset, eval_dataset = split_dataset['train'], split_dataset['test'] collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer) # model = TexTeller() - model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-57500') + model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-80500') enable_train = False enable_evaluate = True diff --git a/src/models/ocr_model/utils/preprocess.py b/src/models/ocr_model/utils/functional.py similarity index 88% rename from src/models/ocr_model/utils/preprocess.py rename to src/models/ocr_model/utils/functional.py index 1892e7a..b7cecc1 100644 --- a/src/models/ocr_model/utils/preprocess.py +++ b/src/models/ocr_model/utils/functional.py @@ -1,10 +1,12 @@ import torch -from datasets import load_dataset +import numpy as np from functools import partial +from datasets import load_dataset + from transformers import DataCollatorForLanguageModeling from typing import List, Dict, Any -from ...ocr_model.model.TexTeller import TexTeller +from ..model.TexTeller import TexTeller from .transforms import train_transform @@ -19,7 +21,7 @@ def left_move(x: torch.Tensor, pad_val): def tokenize_fn(samples: Dict[str, List[Any]], tokenizer=None) -> Dict[str, List[Any]]: assert tokenizer is not None, 'tokenizer should not be None' tokenized_formula = tokenizer(samples['latex_formula'], return_special_tokens_mask=True) - tokenized_formula['pixel_values'] = samples['image'] + tokenized_formula['pixel_values'] = [np.array(sample) for sample in samples['image']] return tokenized_formula @@ -36,14 +38,13 @@ def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[ # 左移labels和decoder_attention_mask batch['labels'] = left_move(batch['labels'], -100) - batch['decoder_attention_mask'] = left_move(batch['decoder_attention_mask'], 0) # 把list of Image转成一个tensor with (B, C, H, W) batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0) return batch -def img_preprocess(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: +def img_transform_fn(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: processed_img = train_transform(samples['pixel_values']) samples['pixel_values'] = processed_img return samples @@ -63,7 +64,7 @@ if __name__ == '__main__': tokenized_formula = tokenized_formula.to_dict() # tokenized_formula['pixel_values'] = dataset['image'] # tokenized_formula = dataset.from_dict(tokenized_formula) - tokenized_dataset = tokenized_formula.with_transform(img_preprocess) + tokenized_dataset = tokenized_formula.with_transform(img_transform_fn) dataset_dict = tokenized_dataset[:] dataset_list = [dict(zip(dataset_dict.keys(), x)) for x in zip(*dataset_dict.values())] diff --git a/src/models/ocr_model/utils/helpers.py b/src/models/ocr_model/utils/helpers.py new file mode 100644 index 0000000..dc82565 --- /dev/null +++ b/src/models/ocr_model/utils/helpers.py @@ -0,0 +1,37 @@ +import cv2 +import numpy as np +from typing import List +from PIL import Image + + +def convert2rgb(image_paths: List[str]) -> List[Image.Image]: + processed_images = [] + + for path in image_paths: + # 读取图片 + image = cv2.imread(path, cv2.IMREAD_UNCHANGED) + if image is None: + print(f"Image at {path} could not be read.") + continue + + # 检查图片是否使用 uint16 类型 + if image.dtype == np.uint16: + raise ValueError(f"Image at {path} is stored in uint16, which is not supported.") + + # 获取图片通道数 + channels = 1 if len(image.shape) == 2 else image.shape[2] + + # 如果是 RGBA (4通道), 转换为 RGB + if channels == 4: + image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB) + + # 如果是 I 模式 (单通道灰度图), 转换为 RGB + elif channels == 1: + image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) + + # 如果是 BGR (3通道), 转换为 RGB + elif channels == 3: + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + processed_images.append(Image.fromarray(image)) + + return processed_images \ No newline at end of file diff --git a/src/models/ocr_model/utils/transforms.py b/src/models/ocr_model/utils/transforms.py index f7d6578..9f744e4 100644 --- a/src/models/ocr_model/utils/transforms.py +++ b/src/models/ocr_model/utils/transforms.py @@ -1,67 +1,155 @@ import torch +import random +import numpy as np +import cv2 from torchvision.transforms import v2 from PIL import ImageChops, Image -from typing import List +from typing import List, Union -from ....globals import OCR_IMG_CHANNELS, OCR_IMG_SIZE, OCR_FIX_SIZE, IMAGE_MEAN, IMAGE_STD +from ....globals import ( + OCR_IMG_CHANNELS, + OCR_IMG_SIZE, + OCR_FIX_SIZE, + IMAGE_MEAN, IMAGE_STD, + MAX_RESIZE_RATIO, MIN_RESIZE_RATIO +) -def trim_white_border(image: Image.Image): - if image.mode == 'RGB': - bg_color = (255, 255, 255) - elif image.mode == 'L': - bg_color = 255 - else: - raise ValueError("Only support RGB or L mode") - # 创建一个与图片一样大小的白色背景 - bg = Image.new(image.mode, image.size, bg_color) - # 计算原图像与背景图像的差异。如果原图像在边框区域与左上角像素颜色相同,那么这些区域在差异图像中将是黑色的。 - diff = ImageChops.difference(image, bg) - # 这一步增强差异图像中的对比度,使非背景区域更加明显。这对确定边界框有帮助,但参数的选择可能需要根据具体图像进行调整。 - diff = ImageChops.add(diff, diff, 2.0, -100) - # 找到差异图像中非黑色区域的边界框。如果找到,原图将根据这个边界框被裁剪。 - bbox = diff.getbbox() - return image.crop(bbox) if bbox else image +general_transform_pipeline = v2.Compose([ + v2.ToImage(), # Convert to tensor, only needed if you had a PIL image + #+返回一个List of torchvision.Image,list的长度就是batch_size + #+因此在整个Compose pipeline的最后,输出的也是一个List of torchvision.Image + #+注意:不是返回一整个torchvision.Image,batch_size的维度是拿出来的 + v2.ToDtype(torch.uint8, scale=True), # optional, most input are already uint8 at this point + v2.Grayscale(), # 转灰度图(视具体任务而定) + + v2.Resize( # 固定resize到一个正方形上 + size=OCR_IMG_SIZE - 1, # size必须小于max_size + interpolation=v2.InterpolationMode.BICUBIC, + max_size=OCR_IMG_SIZE, + antialias=True + ), + + v2.ToDtype(torch.float32, scale=True), # Normalize expects float input + v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]), + + # v2.ToPILImage() # 用于观察转换后的结果是否正确(debug用) +]) -def train_transform(images: List[Image.Image]) -> List[torch.Tensor]: - assert OCR_IMG_CHANNELS == 1 , "Only support grayscale images for now" - assert OCR_FIX_SIZE == True, "Only support fixed size images for now" - images = [trim_white_border(image) for image in images] - transforms = v2.Compose([ - v2.ToImage(), # Convert to tensor, only needed if you had a PIL image - #+返回一个List of torchvision.Image,list的长度就是batch_size - #+因此在整个Compose pipeline的最后,输出的也是一个List of torchvision.Image - #+注意:不是返回一整个torchvision.Image,batch_size的维度是拿出来的 - v2.ToDtype(torch.uint8, scale=True), # optional, most input are already uint8 at this point - v2.Grayscale(), # 转灰度图(视具体任务而定) +def trim_white_border(image: Union[np.ndarray, List[List[List]]]): + # image是一个3维的ndarray,RGB格式,维度分布为[H, W, C](通道维在第三维上) - v2.Resize( # 固定resize到一个正方形上 - size=OCR_IMG_SIZE - 1, # size必须小于max_size - interpolation=v2.InterpolationMode.BICUBIC, - max_size=OCR_IMG_SIZE, - antialias=True - ), + # 检查images中的第一个元素是否是嵌套的列表结构 + if isinstance(image, list): + image = np.array(image, dtype=np.uint8) - v2.ToDtype(torch.float32, scale=True), # Normalize expects float input - v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]), + # 检查图像是否为RGB格式,同时检查通道维是不是在第三维上 + if len(image.shape) != 3 or image.shape[2] != 3: + raise ValueError("Image is not in RGB format or channel is not in third dimension") - # v2.ToPILImage() # 用于观察转换后的结果是否正确(debug用) - ]) + # 检查图片是否使用 uint8 类型 + if image.dtype != np.uint8: + raise ValueError(f"Image should stored in uint8") - images = transforms(images) # imgs: List[PIL.Image.Image] + # 创建与原图像同样大小的纯白背景图像 + h, w = image.shape[:2] + bg = np.full((h, w, 3), 255, dtype=np.uint8) + + # 计算差异 + diff = cv2.absdiff(image, bg) + + # 只要差值大于1,就全部转化为255 + _, diff = cv2.threshold(diff, 1, 255, cv2.THRESH_BINARY) + + # 把差值转灰度图 + gray_diff = cv2.cvtColor(diff, cv2.COLOR_RGB2GRAY) + # 计算图像中非零像素点的最小外接矩阵 + x, y, w, h = cv2.boundingRect(gray_diff) + + # 裁剪图像 + trimmed_image = image[y:y+h, x:x+w] + + return trimmed_image + + +def padding(images: List[torch.Tensor], required_size: int): images = [ v2.functional.pad( img, - padding=[0, 0, OCR_IMG_SIZE - img.shape[2], OCR_IMG_SIZE - img.shape[1]] + padding=[0, 0, required_size - img.shape[2], required_size - img.shape[1]] ) for img in images ] return images +def random_resize( + images: Union[List[np.ndarray], List[List[List[List]]]], + minr: float, + maxr: float +) -> List[np.ndarray]: + # np.ndarray的格式:3维,RGB格式,维度分布为[H, W, C](通道维在第三维上) + + # 检查images中的第一个元素是否是嵌套的列表结构 + if isinstance(images[0], list): + # 将嵌套的列表结构转换为np.ndarray + images = [np.array(img, dtype=np.uint8) for img in images] + + if len(images[0].shape) != 3 or images[0].shape[2] != 3: + raise ValueError("Image is not in RGB format or channel is not in third dimension") + + ratios = [random.uniform(minr, maxr) for _ in range(len(images))] + return [ + cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)), interpolation=cv2.INTER_LANCZOS4) # 抗锯齿 + for img, r in zip(images, ratios) + ] + + +def general_transform(images: List[Image.Image]) -> List[torch.Tensor]: + # 裁剪掉白边 + images = [trim_white_border(image) for image in images] + # general transform pipeline + images = general_transform_pipeline(images) # imgs: List[PIL.Image.Image] + # padding to fixed size + images = padding(images, OCR_IMG_SIZE) + return images + + +def train_transform(images: List[np.ndarray]) -> List[torch.Tensor]: + assert OCR_IMG_CHANNELS == 1 , "Only support grayscale images for now" + assert OCR_FIX_SIZE == True, "Only support fixed size images for now" + + # random resize first + # images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO) + return general_transform(images) + + def inference_transform(images: List[Image.Image]) -> List[torch.Tensor]: assert OCR_IMG_CHANNELS == 1 , "Only support grayscale images for now" assert OCR_FIX_SIZE == True, "Only support fixed size images for now" - return train_transform(images) + + return general_transform(images) + + +if __name__ == '__main__': + from pathlib import Path + from .helpers import convert2rgb + base_dir = Path('/home/lhy/code/TeXify/src/models/ocr_model/model') + imgs_path = [ + base_dir / '1.jpg', + base_dir / '2.jpg', + base_dir / '3.jpg', + base_dir / '4.jpg', + base_dir / '5.jpg', + base_dir / '6.jpg', + base_dir / '7.jpg', + ] + imgs_path = [str(img_path) for img_path in imgs_path] + imgs = convert2rgb(imgs_path) + # res = train_transform(imgs) + # res = [v2.functional.to_pil_image(img) for img in res] + res = random_resize(imgs, 0.5, 1.5) + pause = 1 +