Initial Commit

三洋三洋
2024-01-31 10:11:07 +00:00
parent b7bf5c444f
commit 1fba652766
7 changed files with 193 additions and 89 deletions

View File

@@ -32,6 +32,10 @@ OCR_IMG_CHANNELS = 1  # grayscale
 # maximum number of tokens per sample in the OCR model's training dataset
 MAX_TOKEN_SIZE = 600
+# random resize ratio range used during OCR model training
+MAX_RESIZE_RATIO = 1.15
+MIN_RESIZE_RATIO = 0.75
 # ============================================================================= #
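These bounds feed the random rescaling introduced in transforms.py below; a minimal sketch of how a per-image ratio would be drawn from them (sample_ratio is illustrative, not code from this commit):

import random

MAX_RESIZE_RATIO = 1.15
MIN_RESIZE_RATIO = 0.75

def sample_ratio() -> float:
    # uniform draw from [MIN_RESIZE_RATIO, MAX_RESIZE_RATIO]
    return random.uniform(MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)

print(sample_ratio())  # e.g. 0.93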

View File

@@ -8,43 +8,10 @@ from typing import List
 from .model.TexTeller import TexTeller
 from .utils.transforms import inference_transform
+from .utils.helpers import convert2rgb
 from ...globals import MAX_TOKEN_SIZE

-def convert2rgb(image_paths: List[str]) -> List[Image.Image]:
-    processed_images = []
-    for path in image_paths:
-        # read the image
-        image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
-        if image is None:
-            print(f"Image at {path} could not be read.")
-            continue
-        # reject images stored as uint16
-        if image.dtype == np.uint16:
-            raise ValueError(f"Image at {path} is stored in uint16, which is not supported.")
-        # number of channels
-        channels = 1 if len(image.shape) == 2 else image.shape[2]
-        # RGBA (4 channels): convert to RGB
-        if channels == 4:
-            image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
-        # mode I (single-channel grayscale): convert to RGB
-        elif channels == 1:
-            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
-        # BGR (3 channels): convert to RGB
-        elif channels == 3:
-            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-        processed_images.append(Image.fromarray(image))
-    return processed_images

 def inference(model: TexTeller, imgs_path: List[str], tokenizer: RobertaTokenizerFast) -> List[str]:
     imgs = convert2rgb(imgs_path)
     imgs = inference_transform(imgs)
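The hunk cuts off before the rest of inference(); for orientation, a minimal sketch of the generate-and-decode step such a function typically ends with, assuming a standard transformers flow (decode_batch and its signature are hypothetical, not code from this commit):

import torch

def decode_batch(model, tokenizer, pixel_values: torch.Tensor, max_tokens: int):
    # greedy generation capped at max_tokens, then detokenization
    with torch.no_grad():
        ids = model.generate(pixel_values, max_new_tokens=max_tokens)
    return tokenizer.batch_decode(ids, skip_special_tokens=True)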

View File

@@ -1,13 +1,15 @@
 import os
+import numpy as np
 from functools import partial
 from pathlib import Path
 from datasets import load_dataset
 from transformers import Trainer, TrainingArguments, Seq2SeqTrainer, Seq2SeqTrainingArguments, GenerationConfig
 from .training_args import CONFIG
 from ..model.TexTeller import TexTeller
-from ..utils.preprocess import tokenize_fn, collate_fn, img_preprocess
+from ..utils.functional import tokenize_fn, collate_fn, img_transform_fn
 from ..utils.metrics import bleu_metric
 from ....globals import MAX_TOKEN_SIZE
@@ -38,6 +40,7 @@ def evaluate(model, tokenizer, eval_dataset, collate_fn):
         eos_token_id=tokenizer.eos_token_id,
         bos_token_id=tokenizer.bos_token_id,
     )
+    # eval_config['use_cpu'] = True
    eval_config['output_dir'] = 'debug_dir'
    eval_config['predict_with_generate'] = True
    eval_config['predict_with_generate'] = True
@@ -52,7 +55,7 @@ def evaluate(model, tokenizer, eval_dataset, collate_fn):
         model,
         seq2seq_config,
-        eval_dataset=eval_dataset.select(range(16)),
+        eval_dataset=eval_dataset,
         tokenizer=tokenizer,
         data_collator=collate_fn,
         compute_metrics=partial(bleu_metric, tokenizer=tokenizer)
@@ -70,23 +73,27 @@ if __name__ == '__main__':
     os.chdir(script_dirpath)

+    # dataset = load_dataset(
+    #     '/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
+    #     'cleaned_formulas'
+    # )['train']
     dataset = load_dataset(
         '/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
         'cleaned_formulas'
-    )['train']
-    pause = dataset[0]['image']
+    )['train'].select(range(1000))

     tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
     map_fn = partial(tokenize_fn, tokenizer=tokenizer)
-    tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8)
-    tokenized_dataset = tokenized_dataset.with_transform(img_preprocess)
+    # tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8, load_from_cache_file=False)
+    tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=1, load_from_cache_file=False)
+    tokenized_dataset = tokenized_dataset.with_transform(img_transform_fn)
     split_dataset = tokenized_dataset.train_test_split(test_size=0.05, seed=42)
     train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
     collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)

     # model = TexTeller()
-    model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-57500')
+    model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-80500')

     enable_train = False
     enable_evaluate = True
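Note the pipeline split here: map() tokenizes once and caches the result, while with_transform(img_transform_fn) reapplies the image transform on every access, which is what a random augmentation needs. A toy illustration with stand-in functions (not the project's dataset or transforms):

from datasets import Dataset

ds = Dataset.from_dict({'latex_formula': ['a+b', 'c^2'], 'image': [[[1, 2]], [[3, 4]]]})

def fake_tokenize(batch):
    # stands in for tokenize_fn: precomputed once by map() and cached
    return {'input_ids': [[len(s)] for s in batch['latex_formula']],
            'pixel_values': batch['image']}

def fake_transform(batch):
    # stands in for img_transform_fn: reapplied lazily on every access
    batch['pixel_values'] = [[[x * 2 for x in row] for row in img]
                             for img in batch['pixel_values']]
    return batch

tok = ds.map(fake_tokenize, batched=True, remove_columns=ds.column_names)
tok = tok.with_transform(fake_transform)
print(tok[0])  # the transform runs at access time, not at map() time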

View File

@@ -1,10 +1,12 @@
 import torch
-from datasets import load_dataset
+import numpy as np
 from functools import partial
+from datasets import load_dataset
 from transformers import DataCollatorForLanguageModeling
 from typing import List, Dict, Any
-from ...ocr_model.model.TexTeller import TexTeller
+from ..model.TexTeller import TexTeller
 from .transforms import train_transform
@@ -19,7 +21,7 @@ def left_move(x: torch.Tensor, pad_val):
 def tokenize_fn(samples: Dict[str, List[Any]], tokenizer=None) -> Dict[str, List[Any]]:
     assert tokenizer is not None, 'tokenizer should not be None'
     tokenized_formula = tokenizer(samples['latex_formula'], return_special_tokens_mask=True)
-    tokenized_formula['pixel_values'] = samples['image']
+    tokenized_formula['pixel_values'] = [np.array(sample) for sample in samples['image']]
     return tokenized_formula
@@ -36,14 +38,13 @@ def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[
     # left-shift labels and decoder_attention_mask
     batch['labels'] = left_move(batch['labels'], -100)
-    batch['decoder_attention_mask'] = left_move(batch['decoder_attention_mask'], 0)

     # turn the list of Images into one tensor with shape (B, C, H, W)
     batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
     return batch

-def img_preprocess(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+def img_transform_fn(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
     processed_img = train_transform(samples['pixel_values'])
     samples['pixel_values'] = processed_img
     return samples
@@ -63,7 +64,7 @@ if __name__ == '__main__':
     tokenized_formula = tokenized_formula.to_dict()
     # tokenized_formula['pixel_values'] = dataset['image']
     # tokenized_formula = dataset.from_dict(tokenized_formula)
-    tokenized_dataset = tokenized_formula.with_transform(img_preprocess)
+    tokenized_dataset = tokenized_formula.with_transform(img_transform_fn)
     dataset_dict = tokenized_dataset[:]
     dataset_list = [dict(zip(dataset_dict.keys(), x)) for x in zip(*dataset_dict.values())]
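left_move itself sits outside the hunk; judging from the call sites above, a plausible implementation shifts each row one position left and backfills with pad_val (-100 for labels so the loss ignores the slot, 0 for attention masks). A sketch under that assumption, not the project's code:

import torch

def left_move(x: torch.Tensor, pad_val) -> torch.Tensor:
    shifted = torch.roll(x, shifts=-1, dims=1)   # shift every row left by one
    shifted[:, -1] = pad_val                     # backfill the vacated column
    return shifted

labels = torch.tensor([[101, 7, 8, 102]])
print(left_move(labels, -100))  # tensor([[   7,    8,  102, -100]])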

View File

@@ -0,0 +1,37 @@
+import cv2
+import numpy as np
+from typing import List
+from PIL import Image
+
+
+def convert2rgb(image_paths: List[str]) -> List[Image.Image]:
+    processed_images = []
+    for path in image_paths:
+        # read the image
+        image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+        if image is None:
+            print(f"Image at {path} could not be read.")
+            continue
+        # reject images stored as uint16
+        if image.dtype == np.uint16:
+            raise ValueError(f"Image at {path} is stored in uint16, which is not supported.")
+        # number of channels
+        channels = 1 if len(image.shape) == 2 else image.shape[2]
+        # RGBA (4 channels): convert to RGB
+        if channels == 4:
+            image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
+        # mode I (single-channel grayscale): convert to RGB
+        elif channels == 1:
+            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
+        # BGR (3 channels): convert to RGB
+        elif channels == 3:
+            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        processed_images.append(Image.fromarray(image))
+    return processed_images
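A quick way to sanity-check the new helper, using a synthetic BGR image written to disk (filename arbitrary; convert2rgb as defined above):

import cv2
import numpy as np

bgr = np.zeros((8, 8, 3), dtype=np.uint8)
bgr[..., 0] = 255                       # pure blue in OpenCV's BGR order
cv2.imwrite('blue.png', bgr)

imgs = convert2rgb(['blue.png'])
print(imgs[0].mode, imgs[0].getpixel((0, 0)))  # RGB (0, 0, 255)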

View File

@@ -1,67 +1,155 @@
 import torch
+import random
+import numpy as np
+import cv2
 from torchvision.transforms import v2
 from PIL import ImageChops, Image
-from typing import List
-from ....globals import OCR_IMG_CHANNELS, OCR_IMG_SIZE, OCR_FIX_SIZE, IMAGE_MEAN, IMAGE_STD
+from typing import List, Union
+from ....globals import (
+    OCR_IMG_CHANNELS,
+    OCR_IMG_SIZE,
+    OCR_FIX_SIZE,
+    IMAGE_MEAN, IMAGE_STD,
+    MAX_RESIZE_RATIO, MIN_RESIZE_RATIO
+)

+general_transform_pipeline = v2.Compose([
+    v2.ToImage(),  # Convert to tensor, only needed if you had a PIL image
+    #+ returns a List of torchvision Images; the length of the list is the batch_size,
+    #+ so the output of the whole Compose pipeline is also a List of torchvision Images
+    #+ (note: not one single torchvision Image; the batch_size dimension is kept outside)
+
+    v2.ToDtype(torch.uint8, scale=True),  # optional, most input are already uint8 at this point
+    v2.Grayscale(),  # convert to grayscale (depends on the specific task)
+
+    v2.Resize(  # resize onto a fixed square
+        size=OCR_IMG_SIZE - 1,  # size must be smaller than max_size
+        interpolation=v2.InterpolationMode.BICUBIC,
+        max_size=OCR_IMG_SIZE,
+        antialias=True
+    ),
+
+    v2.ToDtype(torch.float32, scale=True),  # Normalize expects float input
+    v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]),
+
+    # v2.ToPILImage()  # for checking whether the transformed result is correct (debug only)
+])

-def trim_white_border(image: Image.Image):
-    if image.mode == 'RGB':
-        bg_color = (255, 255, 255)
-    elif image.mode == 'L':
-        bg_color = 255
-    else:
-        raise ValueError("Only support RGB or L mode")
-    # create a white background the same size as the image
-    bg = Image.new(image.mode, image.size, bg_color)
-    # compute the difference between the image and the background; border regions matching the top-left pixel color come out black in the difference image
-    diff = ImageChops.difference(image, bg)
-    # boost the contrast of the difference image so non-background regions stand out; this helps determine the bounding box, though the parameters may need per-image tuning
-    diff = ImageChops.add(diff, diff, 2.0, -100)
-    # find the bounding box of the non-black region in the difference image; if found, crop the original to it
-    bbox = diff.getbbox()
-    return image.crop(bbox) if bbox else image

-def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
-    assert OCR_IMG_CHANNELS == 1 , "Only support grayscale images for now"
-    assert OCR_FIX_SIZE == True, "Only support fixed size images for now"
-    images = [trim_white_border(image) for image in images]
-    transforms = v2.Compose([
-        v2.ToImage(),  # Convert to tensor, only needed if you had a PIL image
-        #+ returns a List of torchvision Images; the length of the list is the batch_size,
-        #+ so the output of the whole Compose pipeline is also a List of torchvision Images
-        #+ (note: not one single torchvision Image; the batch_size dimension is kept outside)
-        v2.ToDtype(torch.uint8, scale=True),  # optional, most input are already uint8 at this point
-        v2.Grayscale(),  # convert to grayscale (depends on the specific task)
-        v2.Resize(  # resize onto a fixed square
-            size=OCR_IMG_SIZE - 1,  # size must be smaller than max_size
-            interpolation=v2.InterpolationMode.BICUBIC,
-            max_size=OCR_IMG_SIZE,
-            antialias=True
-        ),
-        v2.ToDtype(torch.float32, scale=True),  # Normalize expects float input
-        v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]),
-        # v2.ToPILImage()  # for checking whether the transformed result is correct (debug only)
-    ])
-    images = transforms(images)  # imgs: List[PIL.Image.Image]
-    images = [
-        v2.functional.pad(
-            img,
-            padding=[0, 0, OCR_IMG_SIZE - img.shape[2], OCR_IMG_SIZE - img.shape[1]]
-        )
-        for img in images
-    ]
-    return images

+def trim_white_border(image: Union[np.ndarray, List[List[List]]]):
+    # image is a 3-D ndarray in RGB format, laid out as [H, W, C] (channel dimension third)
+
+    # if the input is a nested list structure, convert it to an ndarray
+    if isinstance(image, list):
+        image = np.array(image, dtype=np.uint8)
+
+    # check the image is RGB with the channel dimension third
+    if len(image.shape) != 3 or image.shape[2] != 3:
+        raise ValueError("Image is not in RGB format or channel is not in third dimension")
+
+    # check the image is stored as uint8
+    if image.dtype != np.uint8:
+        raise ValueError("Image should be stored in uint8")
+
+    # create a pure-white background the same size as the original image
+    h, w = image.shape[:2]
+    bg = np.full((h, w, 3), 255, dtype=np.uint8)
+
+    # compute the difference
+    diff = cv2.absdiff(image, bg)
+
+    # map any difference greater than 1 to 255
+    _, diff = cv2.threshold(diff, 1, 255, cv2.THRESH_BINARY)
+
+    # convert the difference to grayscale
+    gray_diff = cv2.cvtColor(diff, cv2.COLOR_RGB2GRAY)
+
+    # bounding rectangle of the non-zero pixels
+    x, y, w, h = cv2.boundingRect(gray_diff)
+
+    # crop the image
+    trimmed_image = image[y:y+h, x:x+w]
+    return trimmed_image

+def padding(images: List[torch.Tensor], required_size: int):
+    images = [
+        v2.functional.pad(
+            img,
+            padding=[0, 0, required_size - img.shape[2], required_size - img.shape[1]]
+        )
+        for img in images
+    ]
+    return images

+def random_resize(
+    images: Union[List[np.ndarray], List[List[List[List]]]],
+    minr: float,
+    maxr: float
+) -> List[np.ndarray]:
+    # each np.ndarray is 3-D, RGB, laid out as [H, W, C] (channel dimension third)
+
+    # check whether the first element of images is a nested list structure
+    if isinstance(images[0], list):
+        # convert the nested lists to np.ndarray
+        images = [np.array(img, dtype=np.uint8) for img in images]
+    if len(images[0].shape) != 3 or images[0].shape[2] != 3:
+        raise ValueError("Image is not in RGB format or channel is not in third dimension")
+
+    ratios = [random.uniform(minr, maxr) for _ in range(len(images))]
+    return [
+        cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)), interpolation=cv2.INTER_LANCZOS4)  # anti-aliased
+        for img, r in zip(images, ratios)
+    ]

+def general_transform(images: List[Image.Image]) -> List[torch.Tensor]:
+    # trim the white border
+    images = [trim_white_border(image) for image in images]
+    # general transform pipeline
+    images = general_transform_pipeline(images)  # imgs: List[PIL.Image.Image]
+    # pad to a fixed size
+    images = padding(images, OCR_IMG_SIZE)
+    return images

+def train_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
+    assert OCR_IMG_CHANNELS == 1 , "Only support grayscale images for now"
+    assert OCR_FIX_SIZE == True, "Only support fixed size images for now"
+
+    # random resize first
+    # images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
+    return general_transform(images)

 def inference_transform(images: List[Image.Image]) -> List[torch.Tensor]:
     assert OCR_IMG_CHANNELS == 1 , "Only support grayscale images for now"
     assert OCR_FIX_SIZE == True, "Only support fixed size images for now"
-    return train_transform(images)
+    return general_transform(images)

+if __name__ == '__main__':
+    from pathlib import Path
+    from .helpers import convert2rgb
+
+    base_dir = Path('/home/lhy/code/TeXify/src/models/ocr_model/model')
+    imgs_path = [
+        base_dir / '1.jpg',
+        base_dir / '2.jpg',
+        base_dir / '3.jpg',
+        base_dir / '4.jpg',
+        base_dir / '5.jpg',
+        base_dir / '6.jpg',
+        base_dir / '7.jpg',
+    ]
+    imgs_path = [str(img_path) for img_path in imgs_path]
+    imgs = convert2rgb(imgs_path)
+    # res = train_transform(imgs)
+    # res = [v2.functional.to_pil_image(img) for img in res]
+    res = random_resize(imgs, 0.5, 1.5)
+    pause = 1
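A small smoke test for the new ndarray-based pieces, on a synthetic black square over a white canvas (sizes arbitrary; 448 stands in for OCR_IMG_SIZE; trim_white_border and padding as defined above):

import numpy as np
import torch

canvas = np.full((64, 64, 3), 255, dtype=np.uint8)
canvas[16:32, 8:24] = 0                      # 16x16 black block

trimmed = trim_white_border(canvas)          # crops down to the block
print(trimmed.shape)                         # (16, 16, 3)

t = torch.zeros(1, 16, 16)                   # stand-in transformed tensor (C, H, W)
padded = padding([t], required_size=448)     # pads right and bottom with zeros
print(padded[0].shape)                       # torch.Size([1, 448, 448])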