Merge remote-tracking branch 'origin/dev' into dev
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from transformers import DataCollatorForLanguageModeling
|
||||
from typing import List, Dict, Any
|
||||
from .transforms import train_transform
|
||||
from .transforms import train_transform, inference_transform
|
||||
from ...globals import MIN_HEIGHT, MIN_WIDTH, MAX_TOKEN_SIZE
|
||||
|
||||
|
||||
def left_move(x: torch.Tensor, pad_val):
|
||||
@@ -32,15 +32,28 @@ def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[
|
||||
batch['decoder_input_ids'] = batch.pop('input_ids')
|
||||
batch['decoder_attention_mask'] = batch.pop('attention_mask')
|
||||
|
||||
# left shift labels and decoder_attention_mask, padding with -100
|
||||
# 左移labels和decoder_attention_mask
|
||||
batch['labels'] = left_move(batch['labels'], -100)
|
||||
|
||||
# convert list of Image to tensor with (B, C, H, W)
|
||||
# 把list of Image转成一个tensor with (B, C, H, W)
|
||||
batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
|
||||
return batch
|
||||
|
||||
|
||||
def img_transform_fn(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
|
||||
def img_train_transform(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
|
||||
processed_img = train_transform(samples['pixel_values'])
|
||||
samples['pixel_values'] = processed_img
|
||||
return samples
|
||||
|
||||
|
||||
def img_inf_transform(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
|
||||
processed_img = inference_transform(samples['pixel_values'])
|
||||
samples['pixel_values'] = processed_img
|
||||
return samples
|
||||
|
||||
|
||||
def filter_fn(sample, tokenizer=None) -> bool:
|
||||
return (
|
||||
sample['image'].height > MIN_HEIGHT and sample['image'].width > MIN_WIDTH
|
||||
and len(tokenizer(sample['latex_formula'])['input_ids']) < MAX_TOKEN_SIZE - 10
|
||||
)
|
||||
|
||||
@@ -1,23 +1,17 @@
|
||||
import evaluate
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from transformers import EvalPrediction, RobertaTokenizer
|
||||
from typing import Dict
|
||||
|
||||
|
||||
def bleu_metric(eval_preds: EvalPrediction, tokenizer: RobertaTokenizer) -> Dict:
|
||||
cur_dir = Path(os.getcwd())
|
||||
os.chdir(Path(__file__).resolve().parent)
|
||||
metric = evaluate.load('google_bleu') # Will download the metric from huggingface if not already downloaded
|
||||
os.chdir(cur_dir)
|
||||
def bleu_metric(eval_preds:EvalPrediction, tokenizer:RobertaTokenizer) -> Dict:
|
||||
metric = evaluate.load('/home/lhy/code/TexTeller/src/models/ocr_model/train/google_bleu') # 这里需要联网,所以会卡住
|
||||
|
||||
logits, labels = eval_preds.predictions, eval_preds.label_ids
|
||||
preds = logits
|
||||
# preds = np.argmax(logits, axis=1) # 把logits转成对应的预测标签
|
||||
|
||||
labels = np.where(labels == -100, 1, labels)
|
||||
|
||||
preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||
labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
||||
return metric.compute(predictions=preds, references=labels)
|
||||
return metric.compute(predictions=preds, references=labels)
|
||||
149
src/models/ocr_model/utils/ocr_aug.py
Normal file
149
src/models/ocr_model/utils/ocr_aug.py
Normal file
@@ -0,0 +1,149 @@
|
||||
from augraphy import *
|
||||
import random
|
||||
|
||||
def ocr_augmentation_pipeline():
|
||||
pre_phase = [
|
||||
# Rescale(scale="optimal", target_dpi = 300, p = 1.0),
|
||||
]
|
||||
|
||||
ink_phase = [
|
||||
InkColorSwap(
|
||||
ink_swap_color="lhy_custom",
|
||||
ink_swap_sequence_number_range=(5, 10),
|
||||
ink_swap_min_width_range=(2, 3),
|
||||
ink_swap_max_width_range=(100, 120),
|
||||
ink_swap_min_height_range=(2, 3),
|
||||
ink_swap_max_height_range=(100, 120),
|
||||
ink_swap_min_area_range=(10, 20),
|
||||
ink_swap_max_area_range=(400, 500),
|
||||
p=0.2
|
||||
),
|
||||
LinesDegradation(
|
||||
line_roi=(0.0, 0.0, 1.0, 1.0),
|
||||
line_gradient_range=(32, 255),
|
||||
line_gradient_direction=(0, 2),
|
||||
line_split_probability=(0.2, 0.4),
|
||||
line_replacement_value=(250, 255),
|
||||
line_min_length=(30, 40),
|
||||
line_long_to_short_ratio=(5, 7),
|
||||
line_replacement_probability=(0.4, 0.5),
|
||||
line_replacement_thickness=(1, 3),
|
||||
p=0.2
|
||||
),
|
||||
|
||||
# ============================
|
||||
OneOf(
|
||||
[
|
||||
Dithering(
|
||||
dither="floyd-steinberg",
|
||||
order=(3, 5),
|
||||
),
|
||||
InkBleed(
|
||||
intensity_range=(0.1, 0.2),
|
||||
kernel_size=random.choice([(7, 7), (5, 5), (3, 3)]),
|
||||
severity=(0.4, 0.6),
|
||||
),
|
||||
],
|
||||
p=0.2
|
||||
),
|
||||
# ============================
|
||||
|
||||
# ============================
|
||||
InkShifter(
|
||||
text_shift_scale_range=(18, 27),
|
||||
text_shift_factor_range=(1, 4),
|
||||
text_fade_range=(0, 2),
|
||||
blur_kernel_size=(5, 5),
|
||||
blur_sigma=0,
|
||||
noise_type="perlin",
|
||||
p=0.2
|
||||
),
|
||||
# ============================
|
||||
|
||||
]
|
||||
|
||||
paper_phase = [
|
||||
NoiseTexturize( # tested
|
||||
sigma_range=(3, 10),
|
||||
turbulence_range=(2, 5),
|
||||
texture_width_range=(300, 500),
|
||||
texture_height_range=(300, 500),
|
||||
p=0.2
|
||||
),
|
||||
BrightnessTexturize( # tested
|
||||
texturize_range=(0.9, 0.99),
|
||||
deviation=0.03,
|
||||
p=0.2
|
||||
)
|
||||
]
|
||||
|
||||
post_phase = [
|
||||
ColorShift( # tested
|
||||
color_shift_offset_x_range=(3, 5),
|
||||
color_shift_offset_y_range=(3, 5),
|
||||
color_shift_iterations=(2, 3),
|
||||
color_shift_brightness_range=(0.9, 1.1),
|
||||
color_shift_gaussian_kernel_range=(3, 3),
|
||||
p=0.2
|
||||
),
|
||||
|
||||
DirtyDrum( # tested
|
||||
line_width_range=(1, 6),
|
||||
line_concentration=random.uniform(0.05, 0.15),
|
||||
direction=random.randint(0, 2),
|
||||
noise_intensity=random.uniform(0.6, 0.95),
|
||||
noise_value=(64, 224),
|
||||
ksize=random.choice([(3, 3), (5, 5), (7, 7)]),
|
||||
sigmaX=0,
|
||||
p=0.2
|
||||
),
|
||||
|
||||
# =====================================
|
||||
OneOf(
|
||||
[
|
||||
LightingGradient(
|
||||
light_position=None,
|
||||
direction=None,
|
||||
max_brightness=255,
|
||||
min_brightness=0,
|
||||
mode="gaussian",
|
||||
linear_decay_rate=None,
|
||||
transparency=None,
|
||||
),
|
||||
Brightness(
|
||||
brightness_range=(0.9, 1.1),
|
||||
min_brightness=0,
|
||||
min_brightness_value=(120, 150),
|
||||
),
|
||||
Gamma(
|
||||
gamma_range=(0.9, 1.1),
|
||||
),
|
||||
],
|
||||
p=0.2
|
||||
),
|
||||
# =====================================
|
||||
|
||||
# =====================================
|
||||
OneOf(
|
||||
[
|
||||
SubtleNoise(
|
||||
subtle_range=random.randint(5, 10),
|
||||
),
|
||||
Jpeg(
|
||||
quality_range=(85, 95),
|
||||
),
|
||||
],
|
||||
p=0.2
|
||||
),
|
||||
# =====================================
|
||||
]
|
||||
|
||||
pipeline = AugraphyPipeline(
|
||||
ink_phase=ink_phase,
|
||||
paper_phase=paper_phase,
|
||||
post_phase=post_phase,
|
||||
pre_phase=pre_phase,
|
||||
log=False
|
||||
)
|
||||
|
||||
return pipeline
|
||||
@@ -7,47 +7,96 @@ from torchvision.transforms import v2
|
||||
from typing import List
|
||||
from PIL import Image
|
||||
|
||||
from models.globals import (
|
||||
from ...globals import (
|
||||
IMG_CHANNELS,
|
||||
FIXED_IMG_SIZE,
|
||||
IMAGE_MEAN, IMAGE_STD,
|
||||
MAX_RESIZE_RATIO, MIN_RESIZE_RATIO
|
||||
)
|
||||
from .ocr_aug import ocr_augmentation_pipeline
|
||||
|
||||
# train_pipeline = default_augraphy_pipeline(scan_only=True)
|
||||
train_pipeline = ocr_augmentation_pipeline()
|
||||
|
||||
general_transform_pipeline = v2.Compose([
|
||||
v2.ToImage(),
|
||||
v2.ToDtype(torch.uint8, scale=True),
|
||||
v2.Grayscale(),
|
||||
v2.Resize(
|
||||
size=FIXED_IMG_SIZE - 1,
|
||||
v2.ToImage(), # Convert to tensor, only needed if you had a PIL image
|
||||
#+返回一个List of torchvision.Image,list的长度就是batch_size
|
||||
#+因此在整个Compose pipeline的最后,输出的也是一个List of torchvision.Image
|
||||
#+注意:不是返回一整个torchvision.Image,batch_size的维度是拿出来的
|
||||
v2.ToDtype(torch.uint8, scale=True), # optional, most input are already uint8 at this point
|
||||
v2.Grayscale(), # 转灰度图(视具体任务而定)
|
||||
|
||||
v2.Resize( # 固定resize到一个正方形上
|
||||
size=FIXED_IMG_SIZE - 1, # size必须小于max_size
|
||||
interpolation=v2.InterpolationMode.BICUBIC,
|
||||
max_size=FIXED_IMG_SIZE,
|
||||
antialias=True
|
||||
),
|
||||
v2.ToDtype(torch.float32, scale=True),
|
||||
|
||||
v2.ToDtype(torch.float32, scale=True), # Normalize expects float input
|
||||
v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]),
|
||||
|
||||
# v2.ToPILImage() # 用于观察转换后的结果是否正确(debug用)
|
||||
])
|
||||
|
||||
|
||||
def trim_white_border(image: np.ndarray):
|
||||
# image是一个3维的ndarray,RGB格式,维度分布为[H, W, C](通道维在第三维上)
|
||||
|
||||
# # 检查images中的第一个元素是否是嵌套的列表结构
|
||||
# if isinstance(image, list):
|
||||
# image = np.array(image, dtype=np.uint8)
|
||||
|
||||
# 检查图像是否为RGB格式,同时检查通道维是不是在第三维上
|
||||
if len(image.shape) != 3 or image.shape[2] != 3:
|
||||
raise ValueError("Image is not in RGB format or channel is not in third dimension")
|
||||
|
||||
# 检查图片是否使用 uint8 类型
|
||||
if image.dtype != np.uint8:
|
||||
raise ValueError(f"Image should stored in uint8")
|
||||
|
||||
# 创建与原图像同样大小的纯白背景图像
|
||||
h, w = image.shape[:2]
|
||||
bg = np.full((h, w, 3), 255, dtype=np.uint8)
|
||||
|
||||
# 计算差异
|
||||
diff = cv2.absdiff(image, bg)
|
||||
|
||||
# 只要差值大于1,就全部转化为255
|
||||
_, diff = cv2.threshold(diff, 1, 255, cv2.THRESH_BINARY)
|
||||
|
||||
# 把差值转灰度图
|
||||
gray_diff = cv2.cvtColor(diff, cv2.COLOR_RGB2GRAY)
|
||||
# 计算图像中非零像素点的最小外接矩阵
|
||||
x, y, w, h = cv2.boundingRect(gray_diff)
|
||||
|
||||
# 裁剪图像
|
||||
trimmed_image = image[y:y+h, x:x+w]
|
||||
|
||||
return trimmed_image
|
||||
|
||||
|
||||
def padding(images: List[torch.Tensor], required_size: int):
|
||||
def add_white_border(image: np.ndarray, max_size: int) -> np.ndarray:
|
||||
randi = [random.randint(0, max_size) for _ in range(4)]
|
||||
pad_height_size = randi[1] + randi[3]
|
||||
pad_width_size = randi[0] + randi[2]
|
||||
if (pad_height_size + image.shape[0] < 30):
|
||||
compensate_height = int((30 - (pad_height_size + image.shape[0])) * 0.5) + 1
|
||||
randi[1] += compensate_height
|
||||
randi[3] += compensate_height
|
||||
if (pad_width_size + image.shape[1] < 30):
|
||||
compensate_width = int((30 - (pad_width_size + image.shape[1])) * 0.5) + 1
|
||||
randi[0] += compensate_width
|
||||
randi[2] += compensate_width
|
||||
return v2.functional.pad(
|
||||
torch.from_numpy(image).permute(2, 0, 1),
|
||||
padding=randi,
|
||||
padding_mode='constant',
|
||||
fill=(255, 255, 255)
|
||||
)
|
||||
|
||||
|
||||
def padding(images: List[torch.Tensor], required_size: int) -> List[torch.Tensor]:
|
||||
images = [
|
||||
v2.functional.pad(
|
||||
img,
|
||||
@@ -63,6 +112,13 @@ def random_resize(
|
||||
minr: float,
|
||||
maxr: float
|
||||
) -> List[np.ndarray]:
|
||||
# np.ndarray的格式:3维,RGB格式,维度分布为[H, W, C](通道维在第三维上)
|
||||
|
||||
# # 检查images中的第一个元素是否是嵌套的列表结构
|
||||
# if isinstance(images[0], list):
|
||||
# # 将嵌套的列表结构转换为np.ndarray
|
||||
# images = [np.array(img, dtype=np.uint8) for img in images]
|
||||
|
||||
if len(images[0].shape) != 3 or images[0].shape[2] != 3:
|
||||
raise ValueError("Image is not in RGB format or channel is not in third dimension")
|
||||
|
||||
@@ -73,18 +129,90 @@ def random_resize(
|
||||
]
|
||||
|
||||
|
||||
def general_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
|
||||
def rotate(image: np.ndarray, min_angle: int, max_angle: int) -> np.ndarray:
|
||||
# Get the center of the image to define the point of rotation
|
||||
image_center = tuple(np.array(image.shape[1::-1]) / 2)
|
||||
|
||||
# Generate a random angle within the specified range
|
||||
angle = random.randint(min_angle, max_angle)
|
||||
|
||||
# Get the rotation matrix for rotating the image around its center
|
||||
rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0)
|
||||
|
||||
# Determine the size of the rotated image
|
||||
cos = np.abs(rotation_mat[0, 0])
|
||||
sin = np.abs(rotation_mat[0, 1])
|
||||
new_width = int((image.shape[0] * sin) + (image.shape[1] * cos))
|
||||
new_height = int((image.shape[0] * cos) + (image.shape[1] * sin))
|
||||
|
||||
# Adjust the rotation matrix to take into account translation
|
||||
rotation_mat[0, 2] += (new_width / 2) - image_center[0]
|
||||
rotation_mat[1, 2] += (new_height / 2) - image_center[1]
|
||||
|
||||
# Rotate the image with the specified border color (white in this case)
|
||||
rotated_image = cv2.warpAffine(image, rotation_mat, (new_width, new_height), borderValue=(255, 255, 255))
|
||||
|
||||
return rotated_image
|
||||
|
||||
|
||||
def ocr_aug(image: np.ndarray) -> np.ndarray:
|
||||
# 20%的概率进行随机旋转
|
||||
if random.random() < 0.2:
|
||||
image = rotate(image, -5, 5)
|
||||
# 增加白边
|
||||
image = add_white_border(image, max_size=25).permute(1, 2, 0).numpy()
|
||||
# 数据增强
|
||||
image = train_pipeline(image)
|
||||
return image
|
||||
|
||||
|
||||
def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
|
||||
assert IMG_CHANNELS == 1 , "Only support grayscale images for now"
|
||||
|
||||
images = [np.array(img.convert('RGB')) for img in images]
|
||||
# random resize first
|
||||
images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
|
||||
# 裁剪掉白边
|
||||
images = [trim_white_border(image) for image in images]
|
||||
images = general_transform_pipeline(images)
|
||||
|
||||
# OCR augmentation
|
||||
images = [ocr_aug(image) for image in images]
|
||||
|
||||
# general transform pipeline
|
||||
images = [general_transform_pipeline(image) for image in images]
|
||||
# padding to fixed size
|
||||
images = padding(images, FIXED_IMG_SIZE)
|
||||
return images
|
||||
|
||||
|
||||
def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
|
||||
images = [np.array(img.convert('RGB')) for img in images]
|
||||
images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
|
||||
return general_transform(images)
|
||||
|
||||
|
||||
def inference_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
|
||||
return general_transform(images)
|
||||
assert IMG_CHANNELS == 1 , "Only support grayscale images for now"
|
||||
images = [np.array(img.convert('RGB')) for img in images]
|
||||
# 裁剪掉白边
|
||||
images = [trim_white_border(image) for image in images]
|
||||
# general transform pipeline
|
||||
images = [general_transform_pipeline(image) for image in images] # imgs: List[PIL.Image.Image]
|
||||
# padding to fixed size
|
||||
images = padding(images, FIXED_IMG_SIZE)
|
||||
|
||||
return images
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from pathlib import Path
|
||||
from .helpers import convert2rgb
|
||||
base_dir = Path('/home/lhy/code/TeXify/src/models/ocr_model/model')
|
||||
imgs_path = [
|
||||
base_dir / '1.jpg',
|
||||
base_dir / '2.jpg',
|
||||
base_dir / '3.jpg',
|
||||
base_dir / '4.jpg',
|
||||
base_dir / '5.jpg',
|
||||
base_dir / '6.jpg',
|
||||
base_dir / '7.jpg',
|
||||
]
|
||||
imgs_path = [str(img_path) for img_path in imgs_path]
|
||||
imgs = convert2rgb(imgs_path)
|
||||
res = random_resize(imgs, 0.5, 1.5)
|
||||
pause = 1
|
||||
|
||||
|
||||
Reference in New Issue
Block a user