Initial Commit
This commit is contained in:
@@ -32,6 +32,10 @@ OCR_IMG_CHANNELS = 1 # 灰度图
|
||||
# ocr模型训练数据集的最长token数
|
||||
MAX_TOKEN_SIZE = 600
|
||||
|
||||
# ocr模型训练时随机缩放的比例
|
||||
MAX_RESIZE_RATIO = 1.15
|
||||
MIN_RESIZE_RATIO = 0.75
|
||||
|
||||
# ============================================================================= #
|
||||
|
||||
|
||||
|
||||
@@ -8,43 +8,10 @@ from typing import List
|
||||
|
||||
from .model.TexTeller import TexTeller
|
||||
from .utils.transforms import inference_transform
|
||||
from .utils.helpers import convert2rgb
|
||||
from ...globals import MAX_TOKEN_SIZE
|
||||
|
||||
|
||||
def convert2rgb(image_paths: List[str]) -> List[Image.Image]:
|
||||
processed_images = []
|
||||
|
||||
for path in image_paths:
|
||||
# 读取图片
|
||||
image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
||||
|
||||
if image is None:
|
||||
print(f"Image at {path} could not be read.")
|
||||
continue
|
||||
|
||||
# 检查图片是否使用 uint16 类型
|
||||
if image.dtype == np.uint16:
|
||||
raise ValueError(f"Image at {path} is stored in uint16, which is not supported.")
|
||||
|
||||
# 获取图片通道数
|
||||
channels = 1 if len(image.shape) == 2 else image.shape[2]
|
||||
|
||||
# 如果是 RGBA (4通道), 转换为 RGB
|
||||
if channels == 4:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
|
||||
|
||||
# 如果是 I 模式 (单通道灰度图), 转换为 RGB
|
||||
elif channels == 1:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
||||
|
||||
# 如果是 BGR (3通道), 转换为 RGB
|
||||
elif channels == 3:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
processed_images.append(Image.fromarray(image))
|
||||
|
||||
return processed_images
|
||||
|
||||
|
||||
def inference(model: TexTeller, imgs_path: List[str], tokenizer: RobertaTokenizerFast) -> List[str]:
|
||||
imgs = convert2rgb(imgs_path)
|
||||
imgs = inference_transform(imgs)
|
||||
|
||||
Binary file not shown.
@@ -1,13 +1,15 @@
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
from datasets import load_dataset
|
||||
from transformers import Trainer, TrainingArguments, Seq2SeqTrainer, Seq2SeqTrainingArguments, GenerationConfig
|
||||
|
||||
from .training_args import CONFIG
|
||||
from ..model.TexTeller import TexTeller
|
||||
from ..utils.preprocess import tokenize_fn, collate_fn, img_preprocess
|
||||
from ..utils.functional import tokenize_fn, collate_fn, img_transform_fn
|
||||
from ..utils.metrics import bleu_metric
|
||||
from ....globals import MAX_TOKEN_SIZE
|
||||
|
||||
@@ -38,6 +40,7 @@ def evaluate(model, tokenizer, eval_dataset, collate_fn):
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
bos_token_id=tokenizer.bos_token_id,
|
||||
)
|
||||
# eval_config['use_cpu'] = True
|
||||
eval_config['output_dir'] = 'debug_dir'
|
||||
eval_config['predict_with_generate'] = True
|
||||
eval_config['predict_with_generate'] = True
|
||||
@@ -52,7 +55,7 @@ def evaluate(model, tokenizer, eval_dataset, collate_fn):
|
||||
model,
|
||||
seq2seq_config,
|
||||
|
||||
eval_dataset=eval_dataset.select(range(16)),
|
||||
eval_dataset=eval_dataset,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=collate_fn,
|
||||
compute_metrics=partial(bleu_metric, tokenizer=tokenizer)
|
||||
@@ -70,23 +73,27 @@ if __name__ == '__main__':
|
||||
os.chdir(script_dirpath)
|
||||
|
||||
|
||||
# dataset = load_dataset(
|
||||
# '/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
|
||||
# 'cleaned_formulas'
|
||||
# )['train']
|
||||
dataset = load_dataset(
|
||||
'/home/lhy/code/TeXify/src/models/ocr_model/train/dataset/latex-formulas/latex-formulas.py',
|
||||
'cleaned_formulas'
|
||||
)['train']
|
||||
)['train'].select(range(1000))
|
||||
|
||||
pause = dataset[0]['image']
|
||||
tokenizer = TexTeller.get_tokenizer('/home/lhy/code/TeXify/src/models/tokenizer/roberta-tokenizer-550Kformulas')
|
||||
|
||||
map_fn = partial(tokenize_fn, tokenizer=tokenizer)
|
||||
tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8)
|
||||
tokenized_dataset = tokenized_dataset.with_transform(img_preprocess)
|
||||
# tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8, load_from_cache_file=False)
|
||||
tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=1, load_from_cache_file=False)
|
||||
tokenized_dataset = tokenized_dataset.with_transform(img_transform_fn)
|
||||
|
||||
split_dataset = tokenized_dataset.train_test_split(test_size=0.05, seed=42)
|
||||
train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
|
||||
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
|
||||
# model = TexTeller()
|
||||
model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-57500')
|
||||
model = TexTeller.from_pretrained('/home/lhy/code/TeXify/src/models/ocr_model/train/train_result/checkpoint-80500')
|
||||
|
||||
enable_train = False
|
||||
enable_evaluate = True
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
import numpy as np
|
||||
|
||||
from functools import partial
|
||||
from datasets import load_dataset
|
||||
|
||||
from transformers import DataCollatorForLanguageModeling
|
||||
from typing import List, Dict, Any
|
||||
from ...ocr_model.model.TexTeller import TexTeller
|
||||
from ..model.TexTeller import TexTeller
|
||||
from .transforms import train_transform
|
||||
|
||||
|
||||
@@ -19,7 +21,7 @@ def left_move(x: torch.Tensor, pad_val):
|
||||
def tokenize_fn(samples: Dict[str, List[Any]], tokenizer=None) -> Dict[str, List[Any]]:
|
||||
assert tokenizer is not None, 'tokenizer should not be None'
|
||||
tokenized_formula = tokenizer(samples['latex_formula'], return_special_tokens_mask=True)
|
||||
tokenized_formula['pixel_values'] = samples['image']
|
||||
tokenized_formula['pixel_values'] = [np.array(sample) for sample in samples['image']]
|
||||
return tokenized_formula
|
||||
|
||||
|
||||
@@ -36,14 +38,13 @@ def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[
|
||||
|
||||
# 左移labels和decoder_attention_mask
|
||||
batch['labels'] = left_move(batch['labels'], -100)
|
||||
batch['decoder_attention_mask'] = left_move(batch['decoder_attention_mask'], 0)
|
||||
|
||||
# 把list of Image转成一个tensor with (B, C, H, W)
|
||||
batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
|
||||
return batch
|
||||
|
||||
|
||||
def img_preprocess(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
|
||||
def img_transform_fn(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
|
||||
processed_img = train_transform(samples['pixel_values'])
|
||||
samples['pixel_values'] = processed_img
|
||||
return samples
|
||||
@@ -63,7 +64,7 @@ if __name__ == '__main__':
|
||||
tokenized_formula = tokenized_formula.to_dict()
|
||||
# tokenized_formula['pixel_values'] = dataset['image']
|
||||
# tokenized_formula = dataset.from_dict(tokenized_formula)
|
||||
tokenized_dataset = tokenized_formula.with_transform(img_preprocess)
|
||||
tokenized_dataset = tokenized_formula.with_transform(img_transform_fn)
|
||||
|
||||
dataset_dict = tokenized_dataset[:]
|
||||
dataset_list = [dict(zip(dataset_dict.keys(), x)) for x in zip(*dataset_dict.values())]
|
||||
37
src/models/ocr_model/utils/helpers.py
Normal file
37
src/models/ocr_model/utils/helpers.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from typing import List
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def convert2rgb(image_paths: List[str]) -> List[Image.Image]:
|
||||
processed_images = []
|
||||
|
||||
for path in image_paths:
|
||||
# 读取图片
|
||||
image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
||||
if image is None:
|
||||
print(f"Image at {path} could not be read.")
|
||||
continue
|
||||
|
||||
# 检查图片是否使用 uint16 类型
|
||||
if image.dtype == np.uint16:
|
||||
raise ValueError(f"Image at {path} is stored in uint16, which is not supported.")
|
||||
|
||||
# 获取图片通道数
|
||||
channels = 1 if len(image.shape) == 2 else image.shape[2]
|
||||
|
||||
# 如果是 RGBA (4通道), 转换为 RGB
|
||||
if channels == 4:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
|
||||
|
||||
# 如果是 I 模式 (单通道灰度图), 转换为 RGB
|
||||
elif channels == 1:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
||||
|
||||
# 如果是 BGR (3通道), 转换为 RGB
|
||||
elif channels == 3:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
processed_images.append(Image.fromarray(image))
|
||||
|
||||
return processed_images
|
||||
@@ -1,67 +1,155 @@
|
||||
import torch
|
||||
import random
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
from torchvision.transforms import v2
|
||||
from PIL import ImageChops, Image
|
||||
from typing import List
|
||||
from typing import List, Union
|
||||
|
||||
from ....globals import OCR_IMG_CHANNELS, OCR_IMG_SIZE, OCR_FIX_SIZE, IMAGE_MEAN, IMAGE_STD
|
||||
from ....globals import (
|
||||
OCR_IMG_CHANNELS,
|
||||
OCR_IMG_SIZE,
|
||||
OCR_FIX_SIZE,
|
||||
IMAGE_MEAN, IMAGE_STD,
|
||||
MAX_RESIZE_RATIO, MIN_RESIZE_RATIO
|
||||
)
|
||||
|
||||
|
||||
def trim_white_border(image: Image.Image):
|
||||
if image.mode == 'RGB':
|
||||
bg_color = (255, 255, 255)
|
||||
elif image.mode == 'L':
|
||||
bg_color = 255
|
||||
else:
|
||||
raise ValueError("Only support RGB or L mode")
|
||||
# 创建一个与图片一样大小的白色背景
|
||||
bg = Image.new(image.mode, image.size, bg_color)
|
||||
# 计算原图像与背景图像的差异。如果原图像在边框区域与左上角像素颜色相同,那么这些区域在差异图像中将是黑色的。
|
||||
diff = ImageChops.difference(image, bg)
|
||||
# 这一步增强差异图像中的对比度,使非背景区域更加明显。这对确定边界框有帮助,但参数的选择可能需要根据具体图像进行调整。
|
||||
diff = ImageChops.add(diff, diff, 2.0, -100)
|
||||
# 找到差异图像中非黑色区域的边界框。如果找到,原图将根据这个边界框被裁剪。
|
||||
bbox = diff.getbbox()
|
||||
return image.crop(bbox) if bbox else image
|
||||
general_transform_pipeline = v2.Compose([
|
||||
v2.ToImage(), # Convert to tensor, only needed if you had a PIL image
|
||||
#+返回一个List of torchvision.Image,list的长度就是batch_size
|
||||
#+因此在整个Compose pipeline的最后,输出的也是一个List of torchvision.Image
|
||||
#+注意:不是返回一整个torchvision.Image,batch_size的维度是拿出来的
|
||||
v2.ToDtype(torch.uint8, scale=True), # optional, most input are already uint8 at this point
|
||||
v2.Grayscale(), # 转灰度图(视具体任务而定)
|
||||
|
||||
v2.Resize( # 固定resize到一个正方形上
|
||||
size=OCR_IMG_SIZE - 1, # size必须小于max_size
|
||||
interpolation=v2.InterpolationMode.BICUBIC,
|
||||
max_size=OCR_IMG_SIZE,
|
||||
antialias=True
|
||||
),
|
||||
|
||||
v2.ToDtype(torch.float32, scale=True), # Normalize expects float input
|
||||
v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]),
|
||||
|
||||
# v2.ToPILImage() # 用于观察转换后的结果是否正确(debug用)
|
||||
])
|
||||
|
||||
|
||||
def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
|
||||
assert OCR_IMG_CHANNELS == 1 , "Only support grayscale images for now"
|
||||
assert OCR_FIX_SIZE == True, "Only support fixed size images for now"
|
||||
images = [trim_white_border(image) for image in images]
|
||||
transforms = v2.Compose([
|
||||
v2.ToImage(), # Convert to tensor, only needed if you had a PIL image
|
||||
#+返回一个List of torchvision.Image,list的长度就是batch_size
|
||||
#+因此在整个Compose pipeline的最后,输出的也是一个List of torchvision.Image
|
||||
#+注意:不是返回一整个torchvision.Image,batch_size的维度是拿出来的
|
||||
v2.ToDtype(torch.uint8, scale=True), # optional, most input are already uint8 at this point
|
||||
v2.Grayscale(), # 转灰度图(视具体任务而定)
|
||||
def trim_white_border(image: Union[np.ndarray, List[List[List]]]):
|
||||
# image是一个3维的ndarray,RGB格式,维度分布为[H, W, C](通道维在第三维上)
|
||||
|
||||
v2.Resize( # 固定resize到一个正方形上
|
||||
size=OCR_IMG_SIZE - 1, # size必须小于max_size
|
||||
interpolation=v2.InterpolationMode.BICUBIC,
|
||||
max_size=OCR_IMG_SIZE,
|
||||
antialias=True
|
||||
),
|
||||
# 检查images中的第一个元素是否是嵌套的列表结构
|
||||
if isinstance(image, list):
|
||||
image = np.array(image, dtype=np.uint8)
|
||||
|
||||
v2.ToDtype(torch.float32, scale=True), # Normalize expects float input
|
||||
v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]),
|
||||
# 检查图像是否为RGB格式,同时检查通道维是不是在第三维上
|
||||
if len(image.shape) != 3 or image.shape[2] != 3:
|
||||
raise ValueError("Image is not in RGB format or channel is not in third dimension")
|
||||
|
||||
# v2.ToPILImage() # 用于观察转换后的结果是否正确(debug用)
|
||||
])
|
||||
# 检查图片是否使用 uint8 类型
|
||||
if image.dtype != np.uint8:
|
||||
raise ValueError(f"Image should stored in uint8")
|
||||
|
||||
images = transforms(images) # imgs: List[PIL.Image.Image]
|
||||
# 创建与原图像同样大小的纯白背景图像
|
||||
h, w = image.shape[:2]
|
||||
bg = np.full((h, w, 3), 255, dtype=np.uint8)
|
||||
|
||||
# 计算差异
|
||||
diff = cv2.absdiff(image, bg)
|
||||
|
||||
# 只要差值大于1,就全部转化为255
|
||||
_, diff = cv2.threshold(diff, 1, 255, cv2.THRESH_BINARY)
|
||||
|
||||
# 把差值转灰度图
|
||||
gray_diff = cv2.cvtColor(diff, cv2.COLOR_RGB2GRAY)
|
||||
# 计算图像中非零像素点的最小外接矩阵
|
||||
x, y, w, h = cv2.boundingRect(gray_diff)
|
||||
|
||||
# 裁剪图像
|
||||
trimmed_image = image[y:y+h, x:x+w]
|
||||
|
||||
return trimmed_image
|
||||
|
||||
|
||||
def padding(images: List[torch.Tensor], required_size: int):
|
||||
images = [
|
||||
v2.functional.pad(
|
||||
img,
|
||||
padding=[0, 0, OCR_IMG_SIZE - img.shape[2], OCR_IMG_SIZE - img.shape[1]]
|
||||
padding=[0, 0, required_size - img.shape[2], required_size - img.shape[1]]
|
||||
)
|
||||
for img in images
|
||||
]
|
||||
return images
|
||||
|
||||
|
||||
def random_resize(
|
||||
images: Union[List[np.ndarray], List[List[List[List]]]],
|
||||
minr: float,
|
||||
maxr: float
|
||||
) -> List[np.ndarray]:
|
||||
# np.ndarray的格式:3维,RGB格式,维度分布为[H, W, C](通道维在第三维上)
|
||||
|
||||
# 检查images中的第一个元素是否是嵌套的列表结构
|
||||
if isinstance(images[0], list):
|
||||
# 将嵌套的列表结构转换为np.ndarray
|
||||
images = [np.array(img, dtype=np.uint8) for img in images]
|
||||
|
||||
if len(images[0].shape) != 3 or images[0].shape[2] != 3:
|
||||
raise ValueError("Image is not in RGB format or channel is not in third dimension")
|
||||
|
||||
ratios = [random.uniform(minr, maxr) for _ in range(len(images))]
|
||||
return [
|
||||
cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)), interpolation=cv2.INTER_LANCZOS4) # 抗锯齿
|
||||
for img, r in zip(images, ratios)
|
||||
]
|
||||
|
||||
|
||||
def general_transform(images: List[Image.Image]) -> List[torch.Tensor]:
|
||||
# 裁剪掉白边
|
||||
images = [trim_white_border(image) for image in images]
|
||||
# general transform pipeline
|
||||
images = general_transform_pipeline(images) # imgs: List[PIL.Image.Image]
|
||||
# padding to fixed size
|
||||
images = padding(images, OCR_IMG_SIZE)
|
||||
return images
|
||||
|
||||
|
||||
def train_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
|
||||
assert OCR_IMG_CHANNELS == 1 , "Only support grayscale images for now"
|
||||
assert OCR_FIX_SIZE == True, "Only support fixed size images for now"
|
||||
|
||||
# random resize first
|
||||
# images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
|
||||
return general_transform(images)
|
||||
|
||||
|
||||
def inference_transform(images: List[Image.Image]) -> List[torch.Tensor]:
|
||||
assert OCR_IMG_CHANNELS == 1 , "Only support grayscale images for now"
|
||||
assert OCR_FIX_SIZE == True, "Only support fixed size images for now"
|
||||
return train_transform(images)
|
||||
|
||||
return general_transform(images)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from pathlib import Path
|
||||
from .helpers import convert2rgb
|
||||
base_dir = Path('/home/lhy/code/TeXify/src/models/ocr_model/model')
|
||||
imgs_path = [
|
||||
base_dir / '1.jpg',
|
||||
base_dir / '2.jpg',
|
||||
base_dir / '3.jpg',
|
||||
base_dir / '4.jpg',
|
||||
base_dir / '5.jpg',
|
||||
base_dir / '6.jpg',
|
||||
base_dir / '7.jpg',
|
||||
]
|
||||
imgs_path = [str(img_path) for img_path in imgs_path]
|
||||
imgs = convert2rgb(imgs_path)
|
||||
# res = train_transform(imgs)
|
||||
# res = [v2.functional.to_pil_image(img) for img in res]
|
||||
res = random_resize(imgs, 0.5, 1.5)
|
||||
pause = 1
|
||||
|
||||
|
||||
Reference in New Issue
Block a user