[feat] Add texteller training script
This commit is contained in:
17
examples/train_texteller/utils/__init__.py
Normal file
17
examples/train_texteller/utils/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from .functional import (
|
||||
collate_fn,
|
||||
filter_fn,
|
||||
tokenize_fn,
|
||||
)
|
||||
from .transforms import (
|
||||
img_train_transform,
|
||||
img_inf_transform,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"collate_fn",
|
||||
"filter_fn",
|
||||
"tokenize_fn",
|
||||
"img_train_transform",
|
||||
"img_inf_transform",
|
||||
]
|
||||
165
examples/train_texteller/utils/augraphy_pipe.py
Normal file
165
examples/train_texteller/utils/augraphy_pipe.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""
|
||||
Custom augraphy pipeline for training
|
||||
|
||||
This file implements a custom augraphy data augmentation pipeline. We found that using augraphy's
|
||||
default pipeline can cause significant degradation to formula images, potentially losing semantic
|
||||
information. Therefore, we carefully selected several common augmentation effects,
|
||||
adjusting their parameters and combination methods to preserve the original semantic information
|
||||
of the images as much as possible.
|
||||
"""
|
||||
|
||||
from augraphy import (
|
||||
InkColorSwap,
|
||||
LinesDegradation,
|
||||
OneOf,
|
||||
Dithering,
|
||||
InkBleed,
|
||||
InkShifter,
|
||||
NoiseTexturize,
|
||||
BrightnessTexturize,
|
||||
ColorShift,
|
||||
DirtyDrum,
|
||||
LightingGradient,
|
||||
Brightness,
|
||||
Gamma,
|
||||
SubtleNoise,
|
||||
Jpeg,
|
||||
AugraphyPipeline,
|
||||
)
|
||||
import random
|
||||
|
||||
|
||||
def get_custom_augraphy():
    """Build the custom, conservative augraphy pipeline used for training.

    Each effect is applied with low probability (p=0.2) and mild parameters
    so that augmented formula images keep their semantic content.

    NOTE(review): the ``random.choice`` / ``random.uniform`` /
    ``random.randint`` calls below are evaluated ONCE, when this function
    is called, so those particular parameters are fixed for every image
    processed by the returned pipeline — confirm this is intended
    (passing range tuples would let augraphy re-sample per image).

    Returns:
        AugraphyPipeline: the configured four-phase pipeline.
    """
    # No pre-processing effects.
    pre_phase = []

    # Effects applied to the "ink" (foreground strokes) of the image.
    ink_phase = [
        InkColorSwap(
            ink_swap_color="random",
            ink_swap_sequence_number_range=(5, 10),
            ink_swap_min_width_range=(2, 3),
            ink_swap_max_width_range=(100, 120),
            ink_swap_min_height_range=(2, 3),
            ink_swap_max_height_range=(100, 120),
            ink_swap_min_area_range=(10, 20),
            ink_swap_max_area_range=(400, 500),
            p=0.2,
        ),
        LinesDegradation(
            line_roi=(0.0, 0.0, 1.0, 1.0),
            line_gradient_range=(32, 255),
            line_gradient_direction=(0, 2),
            line_split_probability=(0.2, 0.4),
            line_replacement_value=(250, 255),
            line_min_length=(30, 40),
            line_long_to_short_ratio=(5, 7),
            line_replacement_probability=(0.4, 0.5),
            line_replacement_thickness=(1, 3),
            p=0.2,
        ),
        # ============================
        # At most one of dithering / ink bleed.
        OneOf(
            [
                Dithering(
                    dither="floyd-steinberg",
                    order=(3, 5),
                ),
                InkBleed(
                    intensity_range=(0.1, 0.2),
                    # NOTE(review): kernel size chosen once at construction.
                    kernel_size=random.choice([(7, 7), (5, 5), (3, 3)]),
                    severity=(0.4, 0.6),
                ),
            ],
            p=0.2,
        ),
        # ============================
        # ============================
        InkShifter(
            text_shift_scale_range=(18, 27),
            text_shift_factor_range=(1, 4),
            text_fade_range=(0, 2),
            blur_kernel_size=(5, 5),
            blur_sigma=0,
            noise_type="perlin",
            p=0.2,
        ),
        # ============================
    ]

    # Effects applied to the "paper" (background texture).
    paper_phase = [
        NoiseTexturize(
            sigma_range=(3, 10),
            turbulence_range=(2, 5),
            texture_width_range=(300, 500),
            texture_height_range=(300, 500),
            p=0.2,
        ),
        BrightnessTexturize(texturize_range=(0.9, 0.99), deviation=0.03, p=0.2),
    ]

    # Effects applied after ink and paper are merged.
    post_phase = [
        ColorShift(
            color_shift_offset_x_range=(3, 5),
            color_shift_offset_y_range=(3, 5),
            color_shift_iterations=(2, 3),
            color_shift_brightness_range=(0.9, 1.1),
            color_shift_gaussian_kernel_range=(3, 3),
            p=0.2,
        ),
        DirtyDrum(
            line_width_range=(1, 6),
            # NOTE(review): these values are sampled once at construction.
            line_concentration=random.uniform(0.05, 0.15),
            direction=random.randint(0, 2),
            noise_intensity=random.uniform(0.6, 0.95),
            noise_value=(64, 224),
            ksize=random.choice([(3, 3), (5, 5), (7, 7)]),
            sigmaX=0,
            p=0.2,
        ),
        # =====================================
        # At most one lighting/brightness adjustment.
        OneOf(
            [
                LightingGradient(
                    light_position=None,
                    direction=None,
                    max_brightness=255,
                    min_brightness=0,
                    mode="gaussian",
                    linear_decay_rate=None,
                    transparency=None,
                ),
                Brightness(
                    brightness_range=(0.9, 1.1),
                    min_brightness=0,
                    min_brightness_value=(120, 150),
                ),
                Gamma(
                    gamma_range=(0.9, 1.1),
                ),
            ],
            p=0.2,
        ),
        # =====================================
        # =====================================
        # At most one noise/compression artifact.
        OneOf(
            [
                SubtleNoise(
                    # NOTE(review): sampled once at construction.
                    subtle_range=random.randint(5, 10),
                ),
                Jpeg(
                    quality_range=(70, 95),
                ),
            ],
            p=0.2,
        ),
        # =====================================
    ]

    pipeline = AugraphyPipeline(
        ink_phase=ink_phase,
        paper_phase=paper_phase,
        post_phase=post_phase,
        pre_phase=pre_phase,
        log=False,
    )

    return pipeline
|
||||
47
examples/train_texteller/utils/functional.py
Normal file
47
examples/train_texteller/utils/functional.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from transformers import DataCollatorForLanguageModeling
|
||||
|
||||
from texteller.constants import MAX_TOKEN_SIZE, MIN_HEIGHT, MIN_WIDTH
|
||||
|
||||
|
||||
def _left_move(x: torch.Tensor, pad_val):
|
||||
assert len(x.shape) == 2, "x should be 2-dimensional"
|
||||
lefted_x = torch.ones_like(x)
|
||||
lefted_x[:, :-1] = x[:, 1:]
|
||||
lefted_x[:, -1] = pad_val
|
||||
return lefted_x
|
||||
|
||||
|
||||
def tokenize_fn(samples: dict[str, list[Any]], tokenizer=None) -> dict[str, list[Any]]:
    """Tokenize a batch of LaTeX formulas, carrying the images along.

    Args:
        samples: Batch dict with ``"latex_formula"`` (list of str) and
            ``"image"`` (list of images).
        tokenizer: HuggingFace-style tokenizer callable. Required.

    Returns:
        The tokenizer output dict, with the images attached under
        ``"pixel_values"``.

    Raises:
        ValueError: If ``tokenizer`` is not provided.
    """
    # Raise instead of assert so the check survives `python -O`.
    if tokenizer is None:
        raise ValueError("tokenizer should not be None")
    tokenized_formula = tokenizer(samples["latex_formula"], return_special_tokens_mask=True)
    tokenized_formula["pixel_values"] = samples["image"]
    return tokenized_formula
|
||||
|
||||
|
||||
def collate_fn(samples: list[dict[str, Any]], tokenizer=None) -> dict[str, list[Any]]:
    """Collate tokenized samples into a decoder-style training batch.

    Pads/stacks the token fields with HuggingFace's causal-LM collator,
    renames them to the encoder-decoder convention, left-shifts the labels
    (so label[t] is the token predicted at step t, padded with -100), and
    stacks the images into a single tensor.

    Args:
        samples: List of per-sample dicts, each containing tokenizer output
            plus ``"pixel_values"``.
        tokenizer: HuggingFace-style tokenizer. Required.

    Returns:
        Batch dict with ``decoder_input_ids``, ``decoder_attention_mask``,
        ``labels``, and ``pixel_values`` (a (B, C, H, W) tensor).

    Raises:
        ValueError: If ``tokenizer`` is not provided.
    """
    # Raise instead of assert so the check survives `python -O`.
    if tokenizer is None:
        raise ValueError("tokenizer should not be None")
    # The collator only understands token fields, so detach the images first.
    pixel_values = [dic.pop("pixel_values") for dic in samples]

    clm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    batch = clm_collator(samples)
    batch["decoder_input_ids"] = batch.pop("input_ids")
    batch["decoder_attention_mask"] = batch.pop("attention_mask")

    # -100 is the ignore_index used by the cross-entropy loss.
    batch["labels"] = _left_move(batch["labels"], -100)

    # Convert list of images to a single (B, C, H, W) tensor
    # (previously assigned as a list first, then overwritten — do it once).
    batch["pixel_values"] = torch.stack(pixel_values, dim=0)
    return batch
|
||||
|
||||
|
||||
def filter_fn(sample, tokenizer=None) -> bool:
    """Return True for samples whose image is large enough and whose
    formula fits the token budget."""
    img = sample["image"]
    if img.height <= MIN_HEIGHT or img.width <= MIN_WIDTH:
        return False
    token_ids = tokenizer(sample["latex_formula"])["input_ids"]
    # Headroom of 10 below MAX_TOKEN_SIZE — presumably for special
    # tokens added later; confirm against the training loop.
    return len(token_ids) < MAX_TOKEN_SIZE - 10
|
||||
154
examples/train_texteller/utils/transforms.py
Normal file
154
examples/train_texteller/utils/transforms.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import torch
|
||||
import random
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
from torchvision.transforms import v2
|
||||
from typing import Any
|
||||
from PIL import Image
|
||||
from collections import Counter
|
||||
|
||||
from texteller.constants import (
|
||||
IMG_CHANNELS,
|
||||
MAX_RESIZE_RATIO,
|
||||
MIN_RESIZE_RATIO,
|
||||
)
|
||||
from texteller.utils import transform as inference_transform
|
||||
from .augraphy_pipe import get_custom_augraphy
|
||||
|
||||
augraphy_pipeline = get_custom_augraphy()
|
||||
|
||||
|
||||
def trim_white_border(image: np.ndarray):
    """Crop away the uniform background border of an RGB image.

    The background color is estimated as the most common of the four
    corner pixels; the crop is the bounding rectangle of everything that
    differs from that color by more than a small threshold.

    Args:
        image: H x W x 3 uint8 array.

    Returns:
        The cropped image (a slice view into ``image``).

    Raises:
        ValueError: If the image is not 3-channel or not uint8.
    """
    if len(image.shape) != 3 or image.shape[2] != 3:
        raise ValueError("Image is not in RGB format or channel is not in third dimension")

    if image.dtype != np.uint8:
        # Fix: was an f-string with no placeholders and a grammar error.
        raise ValueError("Image should be stored in uint8")

    # Majority vote over the four corners picks the background color.
    corners = [tuple(image[0, 0]), tuple(image[0, -1]), tuple(image[-1, 0]), tuple(image[-1, -1])]
    bg_color = Counter(corners).most_common(1)[0][0]
    bg_color_np = np.array(bg_color, dtype=np.uint8)

    h, w = image.shape[:2]
    bg = np.full((h, w, 3), bg_color_np, dtype=np.uint8)

    # Per-pixel distance from the background, collapsed to grayscale.
    diff = cv2.absdiff(image, bg)
    mask = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY)

    # Pixels differing by more than 15 are treated as content.
    threshold = 15
    _, diff = cv2.threshold(mask, threshold, 255, cv2.THRESH_BINARY)

    x, y, w, h = cv2.boundingRect(diff)

    trimmed_image = image[y : y + h, x : x + w]

    return trimmed_image
|
||||
|
||||
|
||||
def add_white_border(image: np.ndarray, max_size: int) -> np.ndarray:
    """Pad the image with a random white border on each side.

    Each side gets a padding drawn uniformly from [0, max_size]. If the
    padded height or width would still fall below 30 px, both opposing
    sides are enlarged equally so the result reaches at least 30 px.

    Returns:
        A C x H x W tensor (torchvision pad output).
    """
    # torchvision padding order: [left, top, right, bottom].
    pads = [random.randint(0, max_size) for _ in range(4)]

    height_after = pads[1] + pads[3] + image.shape[0]
    if height_after < 30:
        extra = int((30 - height_after) * 0.5) + 1
        pads[1] += extra
        pads[3] += extra

    width_after = pads[0] + pads[2] + image.shape[1]
    if width_after < 30:
        extra = int((30 - width_after) * 0.5) + 1
        pads[0] += extra
        pads[2] += extra

    # H x W x C numpy -> C x H x W tensor, then constant white padding.
    chw = torch.from_numpy(image).permute(2, 0, 1)
    return v2.functional.pad(
        chw,
        padding=pads,
        padding_mode="constant",
        fill=(255, 255, 255),
    )
|
||||
|
||||
|
||||
def padding(images: list[torch.Tensor], required_size: int) -> list[torch.Tensor]:
    """Pad each C x H x W tensor on the right/bottom up to a
    ``required_size`` x ``required_size`` square (zero fill)."""
    padded = []
    for img in images:
        pad_right = required_size - img.shape[2]
        pad_bottom = required_size - img.shape[1]
        # [left, top, right, bottom] — only right/bottom grow.
        padded.append(v2.functional.pad(img, padding=[0, 0, pad_right, pad_bottom]))
    return padded
|
||||
|
||||
|
||||
def random_resize(images: list[np.ndarray], minr: float, maxr: float) -> list[np.ndarray]:
    """Rescale each RGB image by an independent random factor in
    [minr, maxr], using Lanczos interpolation for anti-aliasing."""
    if len(images[0].shape) != 3 or images[0].shape[2] != 3:
        raise ValueError("Image is not in RGB format or channel is not in third dimension")

    resized = []
    for img in images:
        ratio = random.uniform(minr, maxr)
        # cv2.resize takes (width, height).
        new_size = (int(img.shape[1] * ratio), int(img.shape[0] * ratio))
        resized.append(cv2.resize(img, new_size, interpolation=cv2.INTER_LANCZOS4))
    return resized
|
||||
|
||||
|
||||
def rotate(image: np.ndarray, min_angle: int, max_angle: int) -> np.ndarray:
    """Rotate the image by a random integer angle in [min_angle, max_angle].

    The output canvas is enlarged so the whole rotated image fits, and
    the exposed border is filled with white.
    """
    h, w = image.shape[:2]
    # Rotation pivot: the image center, in (x, y) order.
    center = (w / 2, h / 2)

    angle = random.randint(min_angle, max_angle)

    matrix = cv2.getRotationMatrix2D(center, angle, 1.0)

    # Size of the axis-aligned bounding box of the rotated image.
    abs_cos = abs(matrix[0, 0])
    abs_sin = abs(matrix[0, 1])
    out_w = int((h * abs_sin) + (w * abs_cos))
    out_h = int((h * abs_cos) + (w * abs_sin))

    # Translate so the rotated content stays centered in the new canvas.
    matrix[0, 2] += (out_w / 2) - center[0]
    matrix[1, 2] += (out_h / 2) - center[1]

    # White border fill for the regions the rotation exposes.
    return cv2.warpAffine(
        image, matrix, (out_w, out_h), borderValue=(255, 255, 255)
    )
|
||||
|
||||
|
||||
def ocr_aug(image: np.ndarray) -> np.ndarray:
    """Apply the OCR-oriented augmentation chain to one RGB image."""
    # Occasionally apply a slight rotation (within ±5 degrees).
    if random.random() < 0.2:
        image = rotate(image, -5, 5)
    # Random white border; convert the C x H x W tensor back to an
    # H x W x C numpy array for the augraphy pipeline.
    bordered = add_white_border(image, max_size=25)
    image = bordered.permute(1, 2, 0).numpy()
    return augraphy_pipeline(image)
|
||||
|
||||
|
||||
def train_transform(images: list[Image.Image]) -> list[torch.Tensor]:
    """Training-time transform: random resize, border trim, OCR-style
    augmentation, then the shared inference transform."""
    assert IMG_CHANNELS == 1, "Only support grayscale images for now"

    arrays = [np.array(img.convert("RGB")) for img in images]

    # Random resize first, then strip the uniform white margins.
    arrays = random_resize(arrays, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
    arrays = [trim_white_border(arr) for arr in arrays]

    # OCR augmentation (rotation, random border, augraphy pipeline).
    arrays = [ocr_aug(arr) for arr in arrays]

    # Finish with the general pipeline shared with inference.
    return inference_transform(arrays)
|
||||
|
||||
|
||||
def img_train_transform(samples: dict[str, list[Any]]) -> dict[str, list[Any]]:
    """Dataset map function: run the training transform over the batch's
    ``pixel_values`` in place and return the batch."""
    samples["pixel_values"] = train_transform(samples["pixel_values"])
    return samples
|
||||
|
||||
|
||||
def img_inf_transform(samples: dict[str, list[Any]]) -> dict[str, list[Any]]:
    """Dataset map function: run the inference transform over the batch's
    ``pixel_values`` in place and return the batch."""
    samples["pixel_values"] = inference_transform(samples["pixel_values"])
    return samples
|
||||
Reference in New Issue
Block a user