[feat] Add texteller training script

Author: OleehyO
Date: 2025-04-19 16:29:49 +00:00
Parent: 991d6bc00d
Commit: a7a296025a
43 changed files with 531 additions and 0 deletions

View File: __init__.py

@@ -0,0 +1,17 @@
from .functional import (
    collate_fn,
    filter_fn,
    tokenize_fn,
)
from .transforms import (
    img_train_transform,
    img_inf_transform,
)

__all__ = [
    "collate_fn",
    "filter_fn",
    "tokenize_fn",
    "img_train_transform",
    "img_inf_transform",
]

View File: augraphy_pipe.py

@@ -0,0 +1,165 @@
"""
Custom augraphy pipeline for training
This file implements a custom augraphy data augmentation pipeline. We found that using augraphy's
default pipeline can cause significant degradation to formula images, potentially losing semantic
information. Therefore, we carefully selected several common augmentation effects,
adjusting their parameters and combination methods to preserve the original semantic information
of the images as much as possible.
"""
import random

from augraphy import (
    InkColorSwap,
    LinesDegradation,
    OneOf,
    Dithering,
    InkBleed,
    InkShifter,
    NoiseTexturize,
    BrightnessTexturize,
    ColorShift,
    DirtyDrum,
    LightingGradient,
    Brightness,
    Gamma,
    SubtleNoise,
    Jpeg,
    AugraphyPipeline,
)

def get_custom_augraphy():
    pre_phase = []

    ink_phase = [
        InkColorSwap(
            ink_swap_color="random",
            ink_swap_sequence_number_range=(5, 10),
            ink_swap_min_width_range=(2, 3),
            ink_swap_max_width_range=(100, 120),
            ink_swap_min_height_range=(2, 3),
            ink_swap_max_height_range=(100, 120),
            ink_swap_min_area_range=(10, 20),
            ink_swap_max_area_range=(400, 500),
            p=0.2,
        ),
        LinesDegradation(
            line_roi=(0.0, 0.0, 1.0, 1.0),
            line_gradient_range=(32, 255),
            line_gradient_direction=(0, 2),
            line_split_probability=(0.2, 0.4),
            line_replacement_value=(250, 255),
            line_min_length=(30, 40),
            line_long_to_short_ratio=(5, 7),
            line_replacement_probability=(0.4, 0.5),
            line_replacement_thickness=(1, 3),
            p=0.2,
        ),
        # ============================
        OneOf(
            [
                Dithering(
                    dither="floyd-steinberg",
                    order=(3, 5),
                ),
                InkBleed(
                    intensity_range=(0.1, 0.2),
                    kernel_size=random.choice([(7, 7), (5, 5), (3, 3)]),
                    severity=(0.4, 0.6),
                ),
            ],
            p=0.2,
        ),
        # ============================
        # ============================
        InkShifter(
            text_shift_scale_range=(18, 27),
            text_shift_factor_range=(1, 4),
            text_fade_range=(0, 2),
            blur_kernel_size=(5, 5),
            blur_sigma=0,
            noise_type="perlin",
            p=0.2,
        ),
        # ============================
    ]

    paper_phase = [
        NoiseTexturize(
            sigma_range=(3, 10),
            turbulence_range=(2, 5),
            texture_width_range=(300, 500),
            texture_height_range=(300, 500),
            p=0.2,
        ),
        BrightnessTexturize(texturize_range=(0.9, 0.99), deviation=0.03, p=0.2),
    ]

    post_phase = [
        ColorShift(
            color_shift_offset_x_range=(3, 5),
            color_shift_offset_y_range=(3, 5),
            color_shift_iterations=(2, 3),
            color_shift_brightness_range=(0.9, 1.1),
            color_shift_gaussian_kernel_range=(3, 3),
            p=0.2,
        ),
        DirtyDrum(
            line_width_range=(1, 6),
            line_concentration=random.uniform(0.05, 0.15),
            direction=random.randint(0, 2),
            noise_intensity=random.uniform(0.6, 0.95),
            noise_value=(64, 224),
            ksize=random.choice([(3, 3), (5, 5), (7, 7)]),
            sigmaX=0,
            p=0.2,
        ),
        # =====================================
        OneOf(
            [
                LightingGradient(
                    light_position=None,
                    direction=None,
                    max_brightness=255,
                    min_brightness=0,
                    mode="gaussian",
                    linear_decay_rate=None,
                    transparency=None,
                ),
                Brightness(
                    brightness_range=(0.9, 1.1),
                    min_brightness=0,
                    min_brightness_value=(120, 150),
                ),
                Gamma(
                    gamma_range=(0.9, 1.1),
                ),
            ],
            p=0.2,
        ),
        # =====================================
        # =====================================
        OneOf(
            [
                SubtleNoise(
                    subtle_range=random.randint(5, 10),
                ),
                Jpeg(
                    quality_range=(70, 95),
                ),
            ],
            p=0.2,
        ),
        # =====================================
    ]

    pipeline = AugraphyPipeline(
        ink_phase=ink_phase,
        paper_phase=paper_phase,
        post_phase=post_phase,
        pre_phase=pre_phase,
        log=False,
    )
    return pipeline
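For reference, the pipeline returned here is a plain callable that takes and returns an RGB uint8 numpy array, which is exactly how transforms.py (below) applies it. A minimal usage sketch; the image path is a hypothetical stand-in:

import cv2

pipe = get_custom_augraphy()
img = cv2.imread("formula.png")             # hypothetical sample, loaded as BGR uint8
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # work in RGB, matching the training transforms
aug = pipe(img)                             # returns the augmented image as a numpy array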

View File: functional.py

@@ -0,0 +1,47 @@
from typing import Any

import torch
from transformers import DataCollatorForLanguageModeling

from texteller.constants import MAX_TOKEN_SIZE, MIN_HEIGHT, MIN_WIDTH


def _left_move(x: torch.Tensor, pad_val):
    """Shift a 2D tensor one position left along dim 1, filling the last column with pad_val."""
    assert len(x.shape) == 2, "x should be 2-dimensional"
    lefted_x = torch.ones_like(x)
    lefted_x[:, :-1] = x[:, 1:]
    lefted_x[:, -1] = pad_val
    return lefted_x


def tokenize_fn(samples: dict[str, list[Any]], tokenizer=None) -> dict[str, list[Any]]:
    assert tokenizer is not None, "tokenizer should not be None"
    tokenized_formula = tokenizer(samples["latex_formula"], return_special_tokens_mask=True)
    tokenized_formula["pixel_values"] = samples["image"]
    return tokenized_formula


def collate_fn(samples: list[dict[str, Any]], tokenizer=None) -> dict[str, list[Any]]:
    assert tokenizer is not None, "tokenizer should not be None"
    pixel_values = [dic.pop("pixel_values") for dic in samples]
    clm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    batch = clm_collator(samples)
    batch["pixel_values"] = pixel_values
    batch["decoder_input_ids"] = batch.pop("input_ids")
    batch["decoder_attention_mask"] = batch.pop("attention_mask")
    # Shift labels left by one so position t is supervised with token t + 1
    # (next-token prediction); the vacated last column is set to -100 (ignored by the loss)
    batch["labels"] = _left_move(batch["labels"], -100)
    # Convert the list of image tensors into a single (B, C, H, W) tensor
    batch["pixel_values"] = torch.stack(batch["pixel_values"], dim=0)
    return batch


def filter_fn(sample, tokenizer=None) -> bool:
    assert tokenizer is not None, "tokenizer should not be None"
    return (
        sample["image"].height > MIN_HEIGHT
        and sample["image"].width > MIN_WIDTH
        and len(tokenizer(sample["latex_formula"])["input_ids"]) < MAX_TOKEN_SIZE - 10
    )
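These helpers are meant to be wired into a standard datasets/DataLoader setup. A minimal sketch under assumptions not in this diff: `dataset` is a Hugging Face dataset with `image` and `latex_formula` columns, `tokenizer` is the model's tokenizer, and the batch size and column handling are placeholders:

from functools import partial
from torch.utils.data import DataLoader

dataset = dataset.filter(partial(filter_fn, tokenizer=tokenizer))
dataset = dataset.map(
    partial(tokenize_fn, tokenizer=tokenizer),
    batched=True,
    remove_columns=["image", "latex_formula"],  # keep token fields plus pixel_values
)
dataset = dataset.with_transform(img_train_transform)  # augment images on the fly (see transforms)
loader = DataLoader(dataset, batch_size=8, collate_fn=partial(collate_fn, tokenizer=tokenizer))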

View File: transforms.py

@@ -0,0 +1,154 @@
import random
from collections import Counter
from typing import Any

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import v2

from texteller.constants import (
    IMG_CHANNELS,
    MAX_RESIZE_RATIO,
    MIN_RESIZE_RATIO,
)
from texteller.utils import transform as inference_transform

from .augraphy_pipe import get_custom_augraphy

augraphy_pipeline = get_custom_augraphy()
def trim_white_border(image: np.ndarray):
    if len(image.shape) != 3 or image.shape[2] != 3:
        raise ValueError("Image is not in RGB format or channel is not in third dimension")
    if image.dtype != np.uint8:
        raise ValueError("Image should be stored in uint8")

    # Take the most common corner color as the background color
    corners = [tuple(image[0, 0]), tuple(image[0, -1]), tuple(image[-1, 0]), tuple(image[-1, -1])]
    bg_color = Counter(corners).most_common(1)[0][0]
    bg_color_np = np.array(bg_color, dtype=np.uint8)

    h, w = image.shape[:2]
    bg = np.full((h, w, 3), bg_color_np, dtype=np.uint8)

    # Crop to the bounding box of pixels that differ from the background
    diff = cv2.absdiff(image, bg)
    mask = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY)
    threshold = 15
    _, diff = cv2.threshold(mask, threshold, 255, cv2.THRESH_BINARY)

    x, y, w, h = cv2.boundingRect(diff)
    trimmed_image = image[y : y + h, x : x + w]
    return trimmed_image
def add_white_border(image: np.ndarray, max_size: int) -> np.ndarray:
    # Random [left, top, right, bottom] padding, compensated so the padded
    # image ends up at least 30 pixels in both height and width
    randi = [random.randint(0, max_size) for _ in range(4)]
    pad_height_size = randi[1] + randi[3]
    pad_width_size = randi[0] + randi[2]
    if pad_height_size + image.shape[0] < 30:
        compensate_height = int((30 - (pad_height_size + image.shape[0])) * 0.5) + 1
        randi[1] += compensate_height
        randi[3] += compensate_height
    if pad_width_size + image.shape[1] < 30:
        compensate_width = int((30 - (pad_width_size + image.shape[1])) * 0.5) + 1
        randi[0] += compensate_width
        randi[2] += compensate_width
    return v2.functional.pad(
        torch.from_numpy(image).permute(2, 0, 1),
        padding=randi,
        padding_mode="constant",
        fill=(255, 255, 255),
    )
def padding(images: list[torch.Tensor], required_size: int) -> list[torch.Tensor]:
    # Pad each image on the right and bottom up to required_size x required_size
    images = [
        v2.functional.pad(
            img, padding=[0, 0, required_size - img.shape[2], required_size - img.shape[1]]
        )
        for img in images
    ]
    return images


def random_resize(images: list[np.ndarray], minr: float, maxr: float) -> list[np.ndarray]:
    if len(images[0].shape) != 3 or images[0].shape[2] != 3:
        raise ValueError("Image is not in RGB format or channel is not in third dimension")
    ratios = [random.uniform(minr, maxr) for _ in range(len(images))]
    return [
        # LANCZOS4 interpolation for anti-aliasing
        cv2.resize(
            img, (int(img.shape[1] * r), int(img.shape[0] * r)), interpolation=cv2.INTER_LANCZOS4
        )
        for img, r in zip(images, ratios)
    ]
def rotate(image: np.ndarray, min_angle: int, max_angle: int) -> np.ndarray:
    # Get the center of the image to define the point of rotation
    image_center = tuple(np.array(image.shape[1::-1]) / 2)
    # Generate a random angle within the specified range
    angle = random.randint(min_angle, max_angle)
    # Get the rotation matrix for rotating the image around its center
    rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0)
    # Determine the size of the rotated image
    cos = np.abs(rotation_mat[0, 0])
    sin = np.abs(rotation_mat[0, 1])
    new_width = int((image.shape[0] * sin) + (image.shape[1] * cos))
    new_height = int((image.shape[0] * cos) + (image.shape[1] * sin))
    # Adjust the rotation matrix to take into account translation
    rotation_mat[0, 2] += (new_width / 2) - image_center[0]
    rotation_mat[1, 2] += (new_height / 2) - image_center[1]
    # Rotate the image with the specified border color (white in this case)
    rotated_image = cv2.warpAffine(
        image, rotation_mat, (new_width, new_height), borderValue=(255, 255, 255)
    )
    return rotated_image


def ocr_aug(image: np.ndarray) -> np.ndarray:
    # Small random rotation (20% of the time), random white border, then augraphy effects
    if random.random() < 0.2:
        image = rotate(image, -5, 5)
    image = add_white_border(image, max_size=25).permute(1, 2, 0).numpy()
    image = augraphy_pipeline(image)
    return image
def train_transform(images: list[Image.Image]) -> list[torch.Tensor]:
    assert IMG_CHANNELS == 1, "Only grayscale images are supported for now"
    images = [np.array(img.convert("RGB")) for img in images]
    # Random resize first
    images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
    images = [trim_white_border(image) for image in images]
    # OCR augmentation
    images = [ocr_aug(image) for image in images]
    # General (inference-time) transform pipeline
    images = inference_transform(images)
    return images


def img_train_transform(samples: dict[str, list[Any]]) -> dict[str, list[Any]]:
    samples["pixel_values"] = train_transform(samples["pixel_values"])
    return samples


def img_inf_transform(samples: dict[str, list[Any]]) -> dict[str, list[Any]]:
    samples["pixel_values"] = inference_transform(samples["pixel_values"])
    return samples
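To eyeball what the augmentations actually do, train_transform can be run on a single PIL image. A minimal sketch; the file path is a hypothetical stand-in:

from PIL import Image

img = Image.open("sample_formula.png")  # hypothetical input image
tensors = train_transform([img])        # list with one augmented, normalized image tensor
print(tensors[0].shape, tensors[0].dtype)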