[feat] Add texteller training script
10
examples/client_demo.py
Normal file
@@ -0,0 +1,10 @@
import requests

server_url = "http://127.0.0.1:8000/predict"

img_path = "/path/to/your/image"
with open(img_path, 'rb') as img:
    files = {'img': img}
    response = requests.post(server_url, files=files)

print(response.text)
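Editor's note: the demo assumes a prediction server is already listening on 127.0.0.1:8000. For orientation only, a minimal sketch of a compatible endpoint follows; it assumes FastAPI and uvicorn, and run_inference is a hypothetical stand-in for the actual TexTeller model call, so the server shipped with the project may look different.

import uvicorn
from fastapi import FastAPI, File, UploadFile

app = FastAPI()


def run_inference(data: bytes) -> str:
    # Hypothetical placeholder: decode `data` and run the TexTeller model here.
    return r"\[E = mc^2\]"


@app.post("/predict")
async def predict(img: UploadFile = File(...)) -> str:
    # The client posts the image as a multipart form field named 'img'.
    return run_inference(await img.read())


if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8000)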
BIN examples/train_texteller/dataset/train/0.png (new file, 3.1 KiB)
BIN examples/train_texteller/dataset/train/1.png (new file, 8.7 KiB)
BIN examples/train_texteller/dataset/train/10.png (new file, 6.8 KiB)
BIN examples/train_texteller/dataset/train/11.png (new file, 4.1 KiB)
BIN examples/train_texteller/dataset/train/12.png (new file, 5.2 KiB)
BIN examples/train_texteller/dataset/train/13.png (new file, 12 KiB)
BIN examples/train_texteller/dataset/train/14.png (new file, 2.8 KiB)
BIN examples/train_texteller/dataset/train/15.png (new file, 2.2 KiB)
BIN examples/train_texteller/dataset/train/16.png (new file, 2.2 KiB)
BIN examples/train_texteller/dataset/train/17.png (new file, 2.6 KiB)
BIN examples/train_texteller/dataset/train/18.png (new file, 3.1 KiB)
BIN examples/train_texteller/dataset/train/19.png (new file, 2.7 KiB)
BIN examples/train_texteller/dataset/train/2.png (new file, 3.9 KiB)
BIN examples/train_texteller/dataset/train/20.png (new file, 3.9 KiB)
BIN examples/train_texteller/dataset/train/21.png (new file, 2.9 KiB)
BIN examples/train_texteller/dataset/train/22.png (new file, 3.7 KiB)
BIN examples/train_texteller/dataset/train/23.png (new file, 3.5 KiB)
BIN examples/train_texteller/dataset/train/24.png (new file, 3.1 KiB)
BIN examples/train_texteller/dataset/train/25.png (new file, 2.5 KiB)
BIN examples/train_texteller/dataset/train/26.png (new file, 2.2 KiB)
BIN examples/train_texteller/dataset/train/27.png (new file, 3.1 KiB)
BIN examples/train_texteller/dataset/train/28.png (new file, 2.9 KiB)
BIN examples/train_texteller/dataset/train/29.png (new file, 5.3 KiB)
BIN examples/train_texteller/dataset/train/3.png (new file, 4.1 KiB)
BIN examples/train_texteller/dataset/train/30.png (new file, 3.9 KiB)
BIN examples/train_texteller/dataset/train/31.png (new file, 4.9 KiB)
BIN examples/train_texteller/dataset/train/32.png (new file, 2.9 KiB)
BIN examples/train_texteller/dataset/train/33.png (new file, 1.8 KiB)
BIN examples/train_texteller/dataset/train/34.png (new file, 3.2 KiB)
BIN examples/train_texteller/dataset/train/4.png (new file, 5.7 KiB)
BIN examples/train_texteller/dataset/train/5.png (new file, 11 KiB)
BIN examples/train_texteller/dataset/train/6.png (new file, 4.8 KiB)
BIN examples/train_texteller/dataset/train/7.png (new file, 4.5 KiB)
BIN examples/train_texteller/dataset/train/8.png (new file, 2.5 KiB)
BIN examples/train_texteller/dataset/train/9.png (new file, 5.2 KiB)
35
examples/train_texteller/dataset/train/metadata.jsonl
Normal file
@@ -0,0 +1,35 @@
{"file_name": "0.png", "latex_formula": "\\[\\mathbb{C}^{4}\\stackrel{{\\pi_{1}}}{{\\longleftarrow}}\\mathcal{ F}\\stackrel{{\\pi_{2}}}{{\\rightarrow}}\\mathcal{PT},\\]"}
{"file_name": "1.png", "latex_formula": "\\[W^{*}_{Z}(x_{1},x_{2})=W_{f\\lrcorner Z}(y_{1},y_{2})=\\mathcal{P}\\exp\\left( \\int_{\\gamma}A_{\\mu}dx^{\\mu}\\right).\\]"}
{"file_name": "2.png", "latex_formula": "\\[G=W^{*}_{Z}(q,p)=\\tilde{H}H^{-1}\\]"}
{"file_name": "3.png", "latex_formula": "\\[H=W^{*}_{Z}(p,x),\\ \\ \\tilde{H}=W^{*}_{Z}(q,x).\\]"}
{"file_name": "4.png", "latex_formula": "\\[v\\cdot f^{*}A|_{x}=(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)},\\quad x\\in Z, \\ v\\in T_{x}Z.\\]"}
{"file_name": "5.png", "latex_formula": "\\[(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)}=v^{\\alpha\\dot{\\alpha}}\\Big{(} \\frac{\\partial y^{\\beta\\dot{\\beta}}}{\\partial x^{\\alpha\\dot{\\alpha}}}A_{\\beta \\dot{\\beta}}\\Big{)}\\Big{|}_{f\\lrcorner Z(x)},\\ x\\in Z,\\ v\\in T_{x}Z,\\]"}
{"file_name": "6.png", "latex_formula": "\\[\\{T_{i},T_{j}\\}=\\{\\tilde{T}^{i},\\tilde{T}^{j}\\}=0,\\ \\ \\{T_{i},\\tilde{T}^{j}\\}=2i \\delta^{j}_{i}D,\\]"}
{"file_name": "7.png", "latex_formula": "\\[(\\partial_{s},q_{i},\\tilde{q}^{k})\\rightarrow(D,M^{j}_{i}T_{j},\\tilde{M}^{k}_ {l}\\tilde{T}^{l}),\\]"}
{"file_name": "8.png", "latex_formula": "\\[M^{i}_{j}\\tilde{M}^{j}_{k}=\\delta^{i}_{k}.\\]"}
{"file_name": "9.png", "latex_formula": "\\[Q_{i\\alpha}=q_{i\\alpha}+\\omega_{i\\alpha},\\ \\tilde{Q}^{i}_{\\dot{\\alpha}}=q^{i}_{ \\dot{\\alpha}}+\\tilde{\\omega}^{i}_{\\dot{\\alpha}},\\ D_{\\alpha\\dot{\\alpha}}= \\partial_{\\alpha\\dot{\\alpha}}+A_{\\alpha\\dot{\\alpha}}.\\]"}
{"file_name": "10.png", "latex_formula": "\\[\\hat{f}(g,\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{j})=(f(g),[V^{-1}]^ {\\alpha}_{\\beta}\\theta^{i\\beta},[\\tilde{V}^{-1}]^{\\dot{\\alpha}}_{\\dot{\\beta}} \\tilde{\\theta}^{\\dot{\\beta}}_{j}),\\ g\\in{\\cal G},\\]"}
{"file_name": "11.png", "latex_formula": "\\[v^{\\beta\\dot{\\beta}}V^{\\alpha}_{\\beta}\\tilde{V}^{\\dot{\\alpha}}_{\\dot{\\beta}} =((f\\lrcorner L_{0})_{*}v)^{\\alpha\\dot{\\alpha}},\\]"}
{"file_name": "12.png", "latex_formula": "\\[\\omega_{i\\alpha}=\\tilde{\\theta}^{\\dot{\\alpha}}_{i}h_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\ \\ \\tilde{\\omega}^{i}_{\\alpha}=\\theta^{i\\alpha}\\tilde{h}_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\]"}
{"file_name": "13.png", "latex_formula": "\\[\\begin{split}&\\lambda^{\\alpha}\\hat{f}^{*}\\omega_{i\\alpha}(z)= \\tilde{\\theta}^{\\dot{\\beta}}_{i}\\lambda^{\\alpha}\\left(V^{\\beta}_{\\alpha}h_{ \\beta\\dot{\\beta}}(x^{\\prime},\\tau^{\\prime})\\right),\\\\ &\\tilde{\\lambda}^{\\dot{\\alpha}}\\hat{f}^{*}\\tilde{\\omega}^{i}_{ \\dot{\\alpha}}(z)=\\theta^{i\\beta}\\tilde{\\lambda}^{\\dot{\\alpha}}\\left(\\tilde{V}^ {\\dot{\\beta}}_{\\dot{\\alpha}}\\tilde{h}_{\\beta\\dot{\\beta}}(x^{\\prime},\\tau^{ \\prime})\\right),\\end{split}\\]"}
{"file_name": "14.png", "latex_formula": "\\[A_{\\alpha\\dot{\\alpha}}=A_{\\alpha\\dot{\\alpha}}(x^{\\beta\\dot{\\beta}},\\tau^{ \\beta\\dot{\\beta}})\\]"}
{"file_name": "15.png", "latex_formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}D_{\\alpha\\dot{\\alpha}}\\]"}
{"file_name": "16.png", "latex_formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}\\]"}
{"file_name": "17.png", "latex_formula": "\\[[v_{1}\\cdot D^{*},v_{2}\\cdot D^{*}]=0\\]"}
{"file_name": "18.png", "latex_formula": "\\[\\Phi_{A}=(\\omega_{i\\alpha},\\tilde{\\omega}^{i}_{\\dot{\\alpha}},A_{\\alpha\\dot{ \\alpha}})\\]"}
{"file_name": "19.png", "latex_formula": "\\[\\hat{f}:{\\cal F}^{6|4N}\\rightarrow{\\cal F}^{6|4N}\\]"}
{"file_name": "20.png", "latex_formula": "\\[\\sigma=(s,\\xi^{i},\\tilde{\\xi}_{j})\\in\\mathbb{C}^{1|2N}\\]"}
{"file_name": "21.png", "latex_formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}(h_{\\alpha\\dot{\\alpha}}+\\tilde{h}_{\\alpha\\dot{\\alpha} })=0\\]"}
{"file_name": "22.png", "latex_formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}\\rightarrow[V^{-1}]^{\\alpha}_{\\beta}[\\tilde{V}^{-1}]^{ \\dot{\\alpha}}_{\\dot{\\beta}}\\tau^{\\beta\\dot{\\beta}}\\]"}
{"file_name": "23.png", "latex_formula": "\\[\\tau^{\\beta\\dot{\\beta}}=\\sum_{i}\\theta^{i\\beta}\\tilde{\\theta}^{\\dot{\\beta}}_{i}\\]"}
{"file_name": "24.png", "latex_formula": "\\[\\theta^{i\\alpha}\\omega_{i\\alpha}+\\tilde{\\theta}^{i}_{\\dot{\\alpha}}\\tilde{ \\omega}^{\\dot{\\alpha}}_{i}=0\\]"}
{"file_name": "25.png", "latex_formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{Q}^{i}_{\\dot{\\alpha}}\\]"}
{"file_name": "26.png", "latex_formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{q}^{i}_{\\dot{\\alpha}}\\]"}
{"file_name": "27.png", "latex_formula": "\\[\\tilde{\\lambda}^{\\dot{\\alpha}}f^{*}A_{\\alpha\\dot{\\alpha}}=H^{-1}\\tilde{ \\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}H\\]"}
{"file_name": "28.png", "latex_formula": "\\[\\tilde{q}^{i}=\\partial_{\\tilde{\\xi}_{i}}+i\\xi^{i}\\partial_{s}\\]"}
{"file_name": "29.png", "latex_formula": "\\[\\tilde{q}^{i}_{\\dot{\\alpha}}=\\frac{\\partial}{\\partial\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}}+i\\theta^{i\\alpha}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
{"file_name": "30.png", "latex_formula": "\\[f\\lrcorner L(z)=\\pi_{1}\\circ f(z,\\lambda,\\tilde{\\lambda})\\ \\forall z\\in L\\]"}
{"file_name": "31.png", "latex_formula": "\\[q_{i\\alpha}=\\frac{\\partial}{\\partial\\theta^{i\\alpha}}+i\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
{"file_name": "32.png", "latex_formula": "\\[q_{i}=\\partial_{\\xi^{i}}+i\\tilde{\\xi}_{i}\\partial_{s}\\]"}
{"file_name": "33.png", "latex_formula": "\\[v^{\\alpha\\dot{\\alpha}}=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\]"}
{"file_name": "34.png", "latex_formula": "\\[z^{A}=(x^{\\alpha\\dot{\\alpha}},\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{ j})\\]"}
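Editor's note: this layout follows the Hugging Face imagefolder convention: metadata.jsonl sits next to the images, file_name links each row to its image, and every other key (here latex_formula) becomes a dataset column. A quick way to inspect an image/formula pair, assuming the datasets library and running from examples/train_texteller as train.py does:

from datasets import load_dataset

# Loads the images plus the metadata.jsonl columns as one dataset.
ds = load_dataset("imagefolder", data_dir="dataset")["train"]
sample = ds[0]
print(sample["image"].size)     # PIL image decoded from 0.png
print(sample["latex_formula"])  # the matching LaTeX string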
71
examples/train_texteller/train.py
Normal file
@@ -0,0 +1,71 @@
from functools import partial

import yaml
from datasets import load_dataset
from transformers import (
    Trainer,
    TrainingArguments,
)

from texteller import load_model, load_tokenizer
from texteller.constants import MIN_HEIGHT, MIN_WIDTH

from examples.train_texteller.utils import (
    collate_fn,
    filter_fn,
    img_inf_transform,
    img_train_transform,
    tokenize_fn,
)


def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
    # `training_config` is read from module scope; it is loaded from
    # train_config.yaml in the __main__ block below.
    training_args = TrainingArguments(**training_config)
    trainer = Trainer(
        model,
        training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=collate_fn_with_tokenizer,
    )

    trainer.train(resume_from_checkpoint=None)


if __name__ == "__main__":
    dataset = load_dataset("imagefolder", data_dir="dataset")["train"]
    dataset = dataset.filter(
        lambda x: x["image"].height > MIN_HEIGHT and x["image"].width > MIN_WIDTH
    )
    dataset = dataset.shuffle(seed=42)
    dataset = dataset.flatten_indices()

    tokenizer = load_tokenizer()
    # To use your own tokenizer, pass its path instead:
    # tokenizer = load_tokenizer("/path/to/your/tokenizer")
    filter_fn_with_tokenizer = partial(filter_fn, tokenizer=tokenizer)
    dataset = dataset.filter(filter_fn_with_tokenizer, num_proc=8)

    map_fn = partial(tokenize_fn, tokenizer=tokenizer)
    tokenized_dataset = dataset.map(
        map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8
    )

    # Split the dataset into train and eval sets at a 9:1 ratio
    split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
    train_dataset, eval_dataset = split_dataset["train"], split_dataset["test"]
    train_dataset = train_dataset.with_transform(img_train_transform)
    eval_dataset = eval_dataset.with_transform(img_inf_transform)
    collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)

    # Train from scratch
    model = load_model()

    # To fine-tune from a pre-trained model, pass the path to your checkpoint instead:
    # model = load_model("/path/to/your/model_checkpoint")

    enable_train = True
    with open("train_config.yaml") as f:
        training_config = yaml.safe_load(f)
    if enable_train:
        train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
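Editor's note: as written, train() always starts from step zero (resume_from_checkpoint=None). If a run is interrupted, Trainer can pick up from a saved checkpoint; a sketch of the change inside train() above (the checkpoint directory name depends on save_steps):

# Resume from the newest checkpoint found in output_dir:
trainer.train(resume_from_checkpoint=True)
# Or point at a specific checkpoint directory:
trainer.train(resume_from_checkpoint="train_result/checkpoint-500")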
32
examples/train_texteller/train_config.yaml
Normal file
@@ -0,0 +1,32 @@
# For more information, please refer to the official documentation: https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments

seed: 42  # Random seed for reproducibility
use_cpu: false  # Whether to train on CPU (debugging is often easier on CPU when first testing the code)
learning_rate: 5.0e-5  # Learning rate
num_train_epochs: 10  # Total number of training epochs
per_device_train_batch_size: 4  # Batch size per GPU for training
per_device_eval_batch_size: 8  # Batch size per GPU for evaluation
output_dir: "train_result"  # Output directory
overwrite_output_dir: false  # If the output directory exists, do not delete its contents
report_to:
  - tensorboard  # Report logs to TensorBoard
save_strategy: "steps"  # Strategy for saving checkpoints
save_steps: 500  # Interval (in steps) between checkpoints; may be an int, or a float in (0, 1) interpreted as a fraction of total training steps (e.g., 1.0 / 2000)
save_total_limit: 5  # Maximum number of checkpoints to keep; the oldest are deleted once this limit is exceeded
logging_strategy: "steps"  # Log every fixed number of steps
logging_steps: 500  # Number of steps between logs
logging_nan_inf_filter: false  # Also log steps where the loss is NaN or inf
optim: "adamw_torch"  # Optimizer
lr_scheduler_type: "cosine"  # Learning rate scheduler
warmup_ratio: 0.1  # Fraction of total training steps used for warmup (e.g., for 1000 steps, the first 100 ramp the lr from 0 up to the set value)
max_grad_norm: 1.0  # Gradient clipping: ensure the gradient norm does not exceed 1.0 (default 1.0)
fp16: false  # Whether to train in 16-bit floating point (generally not recommended, as the loss can easily overflow)
bf16: false  # Whether to train in bfloat16 (recommended if your hardware supports it)
gradient_accumulation_steps: 1  # Gradient accumulation steps; raise this to emulate a large batch size when the per-device batch size cannot be large
jit_mode_eval: false  # Whether to use PyTorch JIT tracing during eval (can speed up the model, but the model must be static or it will raise errors)
torch_compile: false  # Whether to compile the model with torch.compile (for better training and inference performance)
dataloader_pin_memory: true  # Can speed up data transfer between CPU and GPU
dataloader_num_workers: 1  # Number of subprocesses for data loading; a common choice is 4 * number of GPUs used
evaluation_strategy: "steps"  # Evaluation strategy, either "steps" or "epoch"
eval_steps: 500  # Interval (in steps) between evaluations when evaluation_strategy is "steps"
remove_unused_columns: false  # Don't change this unless you really know what you are doing
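Editor's note: because train.py unpacks this file directly into TrainingArguments(**training_config), every key must be a valid TrainingArguments field. A quick standalone sanity check:

import yaml
from transformers import TrainingArguments

with open("train_config.yaml") as f:
    config = yaml.safe_load(f)

args = TrainingArguments(**config)  # raises TypeError on an unknown key
print(args.output_dir, args.learning_rate, args.lr_scheduler_type)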
17
examples/train_texteller/utils/__init__.py
Normal file
@@ -0,0 +1,17 @@
from .functional import (
    collate_fn,
    filter_fn,
    tokenize_fn,
)
from .transforms import (
    img_train_transform,
    img_inf_transform,
)

__all__ = [
    "collate_fn",
    "filter_fn",
    "tokenize_fn",
    "img_train_transform",
    "img_inf_transform",
]
165
examples/train_texteller/utils/augraphy_pipe.py
Normal file
@@ -0,0 +1,165 @@
"""
Custom augraphy pipeline for training

This file implements a custom augraphy data augmentation pipeline. We found that augraphy's
default pipeline can degrade formula images severely, potentially destroying their semantic
information. We therefore carefully selected several common augmentation effects and tuned
their parameters and combinations to preserve the original semantic information of the
images as much as possible.
"""

import random

from augraphy import (
    InkColorSwap,
    LinesDegradation,
    OneOf,
    Dithering,
    InkBleed,
    InkShifter,
    NoiseTexturize,
    BrightnessTexturize,
    ColorShift,
    DirtyDrum,
    LightingGradient,
    Brightness,
    Gamma,
    SubtleNoise,
    Jpeg,
    AugraphyPipeline,
)


def get_custom_augraphy():
    pre_phase = []

    ink_phase = [
        InkColorSwap(
            ink_swap_color="random",
            ink_swap_sequence_number_range=(5, 10),
            ink_swap_min_width_range=(2, 3),
            ink_swap_max_width_range=(100, 120),
            ink_swap_min_height_range=(2, 3),
            ink_swap_max_height_range=(100, 120),
            ink_swap_min_area_range=(10, 20),
            ink_swap_max_area_range=(400, 500),
            p=0.2,
        ),
        LinesDegradation(
            line_roi=(0.0, 0.0, 1.0, 1.0),
            line_gradient_range=(32, 255),
            line_gradient_direction=(0, 2),
            line_split_probability=(0.2, 0.4),
            line_replacement_value=(250, 255),
            line_min_length=(30, 40),
            line_long_to_short_ratio=(5, 7),
            line_replacement_probability=(0.4, 0.5),
            line_replacement_thickness=(1, 3),
            p=0.2,
        ),
        # ============================
        OneOf(
            [
                Dithering(
                    dither="floyd-steinberg",
                    order=(3, 5),
                ),
                InkBleed(
                    intensity_range=(0.1, 0.2),
                    # Note: random.choice runs once, when the pipeline is built,
                    # not once per image.
                    kernel_size=random.choice([(7, 7), (5, 5), (3, 3)]),
                    severity=(0.4, 0.6),
                ),
            ],
            p=0.2,
        ),
        # ============================
        # ============================
        InkShifter(
            text_shift_scale_range=(18, 27),
            text_shift_factor_range=(1, 4),
            text_fade_range=(0, 2),
            blur_kernel_size=(5, 5),
            blur_sigma=0,
            noise_type="perlin",
            p=0.2,
        ),
        # ============================
    ]

    paper_phase = [
        NoiseTexturize(
            sigma_range=(3, 10),
            turbulence_range=(2, 5),
            texture_width_range=(300, 500),
            texture_height_range=(300, 500),
            p=0.2,
        ),
        BrightnessTexturize(texturize_range=(0.9, 0.99), deviation=0.03, p=0.2),
    ]

    post_phase = [
        ColorShift(
            color_shift_offset_x_range=(3, 5),
            color_shift_offset_y_range=(3, 5),
            color_shift_iterations=(2, 3),
            color_shift_brightness_range=(0.9, 1.1),
            color_shift_gaussian_kernel_range=(3, 3),
            p=0.2,
        ),
        DirtyDrum(
            line_width_range=(1, 6),
            # Note: these random.* values are likewise sampled once at construction time.
            line_concentration=random.uniform(0.05, 0.15),
            direction=random.randint(0, 2),
            noise_intensity=random.uniform(0.6, 0.95),
            noise_value=(64, 224),
            ksize=random.choice([(3, 3), (5, 5), (7, 7)]),
            sigmaX=0,
            p=0.2,
        ),
        # =====================================
        OneOf(
            [
                LightingGradient(
                    light_position=None,
                    direction=None,
                    max_brightness=255,
                    min_brightness=0,
                    mode="gaussian",
                    linear_decay_rate=None,
                    transparency=None,
                ),
                Brightness(
                    brightness_range=(0.9, 1.1),
                    min_brightness=0,
                    min_brightness_value=(120, 150),
                ),
                Gamma(
                    gamma_range=(0.9, 1.1),
                ),
            ],
            p=0.2,
        ),
        # =====================================
        # =====================================
        OneOf(
            [
                SubtleNoise(
                    subtle_range=random.randint(5, 10),
                ),
                Jpeg(
                    quality_range=(70, 95),
                ),
            ],
            p=0.2,
        ),
        # =====================================
    ]

    pipeline = AugraphyPipeline(
        ink_phase=ink_phase,
        paper_phase=paper_phase,
        post_phase=post_phase,
        pre_phase=pre_phase,
        log=False,
    )

    return pipeline
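Editor's note: the returned pipeline is applied per-image on uint8 numpy arrays, exactly as transforms.py does with augraphy_pipeline(image); a minimal standalone sketch:

import numpy as np

pipe = get_custom_augraphy()
img = np.full((64, 256, 3), 255, dtype=np.uint8)  # blank white canvas standing in for a formula crop
augmented = pipe(img)
print(augmented.shape, augmented.dtype)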
47
examples/train_texteller/utils/functional.py
Normal file
@@ -0,0 +1,47 @@
from typing import Any

import torch
from transformers import DataCollatorForLanguageModeling

from texteller.constants import MAX_TOKEN_SIZE, MIN_HEIGHT, MIN_WIDTH


def _left_move(x: torch.Tensor, pad_val):
    # Shift each row one position to the left, filling the last column with pad_val.
    assert len(x.shape) == 2, "x should be 2-dimensional"
    lefted_x = torch.ones_like(x)
    lefted_x[:, :-1] = x[:, 1:]
    lefted_x[:, -1] = pad_val
    return lefted_x


def tokenize_fn(samples: dict[str, list[Any]], tokenizer=None) -> dict[str, list[Any]]:
    assert tokenizer is not None, "tokenizer should not be None"
    tokenized_formula = tokenizer(samples["latex_formula"], return_special_tokens_mask=True)
    tokenized_formula["pixel_values"] = samples["image"]
    return tokenized_formula


def collate_fn(samples: list[dict[str, Any]], tokenizer=None) -> dict[str, list[Any]]:
    assert tokenizer is not None, "tokenizer should not be None"
    pixel_values = [dic.pop("pixel_values") for dic in samples]

    clm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    batch = clm_collator(samples)
    batch["pixel_values"] = pixel_values
    batch["decoder_input_ids"] = batch.pop("input_ids")
    batch["decoder_attention_mask"] = batch.pop("attention_mask")

    # Shift labels left by one so that labels[:, i] is the target for decoder
    # position i; -100 is ignored by the cross-entropy loss.
    batch["labels"] = _left_move(batch["labels"], -100)

    # Convert the list of images to a tensor of shape (B, C, H, W)
    batch["pixel_values"] = torch.stack(batch["pixel_values"], dim=0)
    return batch


def filter_fn(sample, tokenizer=None) -> bool:
    return (
        sample["image"].height > MIN_HEIGHT
        and sample["image"].width > MIN_WIDTH
        and len(tokenizer(sample["latex_formula"])["input_ids"]) < MAX_TOKEN_SIZE - 10
    )
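Editor's note: the label shift in collate_fn is easiest to see on a toy tensor; after the shift, labels[:, i] holds the token the decoder must predict at position i, and -100 marks positions the loss ignores:

import torch

labels = torch.tensor([[101, 7, 8, 9, 102]])  # e.g. [BOS, t1, t2, t3, EOS]
shifted = torch.ones_like(labels)
shifted[:, :-1] = labels[:, 1:]  # same shift as _left_move(labels, -100)
shifted[:, -1] = -100
print(shifted)  # tensor([[   7,    8,    9,  102, -100]])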
154
examples/train_texteller/utils/transforms.py
Normal file
@@ -0,0 +1,154 @@
import torch
import random
import numpy as np
import cv2

from torchvision.transforms import v2
from typing import Any
from PIL import Image
from collections import Counter

from texteller.constants import (
    IMG_CHANNELS,
    MAX_RESIZE_RATIO,
    MIN_RESIZE_RATIO,
)
from texteller.utils import transform as inference_transform
from .augraphy_pipe import get_custom_augraphy

augraphy_pipeline = get_custom_augraphy()


def trim_white_border(image: np.ndarray):
    if len(image.shape) != 3 or image.shape[2] != 3:
        raise ValueError("Image is not in RGB format or channel is not in third dimension")

    if image.dtype != np.uint8:
        raise ValueError("Image must be stored as uint8")

    # Take the most common corner color as the background color
    corners = [tuple(image[0, 0]), tuple(image[0, -1]), tuple(image[-1, 0]), tuple(image[-1, -1])]
    bg_color = Counter(corners).most_common(1)[0][0]
    bg_color_np = np.array(bg_color, dtype=np.uint8)

    h, w = image.shape[:2]
    bg = np.full((h, w, 3), bg_color_np, dtype=np.uint8)

    diff = cv2.absdiff(image, bg)
    mask = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY)

    threshold = 15
    _, diff = cv2.threshold(mask, threshold, 255, cv2.THRESH_BINARY)

    x, y, w, h = cv2.boundingRect(diff)

    trimmed_image = image[y : y + h, x : x + w]

    return trimmed_image


def add_white_border(image: np.ndarray, max_size: int) -> np.ndarray:
    # Random padding for each side: [left, top, right, bottom]
    randi = [random.randint(0, max_size) for _ in range(4)]
    pad_height_size = randi[1] + randi[3]
    pad_width_size = randi[0] + randi[2]
    # Ensure the padded image is at least 30 pixels in each dimension
    if pad_height_size + image.shape[0] < 30:
        compensate_height = int((30 - (pad_height_size + image.shape[0])) * 0.5) + 1
        randi[1] += compensate_height
        randi[3] += compensate_height
    if pad_width_size + image.shape[1] < 30:
        compensate_width = int((30 - (pad_width_size + image.shape[1])) * 0.5) + 1
        randi[0] += compensate_width
        randi[2] += compensate_width
    return v2.functional.pad(
        torch.from_numpy(image).permute(2, 0, 1),
        padding=randi,
        padding_mode="constant",
        fill=(255, 255, 255),
    )


def padding(images: list[torch.Tensor], required_size: int) -> list[torch.Tensor]:
    # Pad each image on the right and bottom up to required_size x required_size
    images = [
        v2.functional.pad(
            img, padding=[0, 0, required_size - img.shape[2], required_size - img.shape[1]]
        )
        for img in images
    ]
    return images


def random_resize(images: list[np.ndarray], minr: float, maxr: float) -> list[np.ndarray]:
    if len(images[0].shape) != 3 or images[0].shape[2] != 3:
        raise ValueError("Image is not in RGB format or channel is not in third dimension")

    ratios = [random.uniform(minr, maxr) for _ in range(len(images))]
    return [
        # Lanczos interpolation for anti-aliasing
        cv2.resize(
            img, (int(img.shape[1] * r), int(img.shape[0] * r)), interpolation=cv2.INTER_LANCZOS4
        )
        for img, r in zip(images, ratios)
    ]


def rotate(image: np.ndarray, min_angle: int, max_angle: int) -> np.ndarray:
    # Get the center of the image to define the point of rotation
    image_center = tuple(np.array(image.shape[1::-1]) / 2)

    # Generate a random angle within the specified range
    angle = random.randint(min_angle, max_angle)

    # Get the rotation matrix for rotating the image around its center
    rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0)

    # Determine the size of the rotated image
    cos = np.abs(rotation_mat[0, 0])
    sin = np.abs(rotation_mat[0, 1])
    new_width = int((image.shape[0] * sin) + (image.shape[1] * cos))
    new_height = int((image.shape[0] * cos) + (image.shape[1] * sin))

    # Adjust the rotation matrix to take the translation into account
    rotation_mat[0, 2] += (new_width / 2) - image_center[0]
    rotation_mat[1, 2] += (new_height / 2) - image_center[1]

    # Rotate the image with the specified border color (white in this case)
    rotated_image = cv2.warpAffine(
        image, rotation_mat, (new_width, new_height), borderValue=(255, 255, 255)
    )

    return rotated_image


def ocr_aug(image: np.ndarray) -> np.ndarray:
    # Randomly rotate (20% chance), pad with a white border, then apply the augraphy pipeline
    if random.random() < 0.2:
        image = rotate(image, -5, 5)
    image = add_white_border(image, max_size=25).permute(1, 2, 0).numpy()
    image = augraphy_pipeline(image)
    return image


def train_transform(images: list[Image.Image]) -> list[torch.Tensor]:
    assert IMG_CHANNELS == 1, "Only grayscale images are supported for now"

    images = [np.array(img.convert("RGB")) for img in images]
    # Random resize first
    images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
    images = [trim_white_border(image) for image in images]

    # OCR augmentation
    images = [ocr_aug(image) for image in images]

    # General (inference) transform pipeline
    images = inference_transform(images)
    return images


def img_train_transform(samples: dict[str, list[Any]]) -> dict[str, list[Any]]:
    processed_img = train_transform(samples["pixel_values"])
    samples["pixel_values"] = processed_img
    return samples


def img_inf_transform(samples: dict[str, list[Any]]) -> dict[str, list[Any]]:
    processed_img = inference_transform(samples["pixel_values"])
    samples["pixel_values"] = processed_img
    return samples
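Editor's note: to eyeball what the training augmentation actually produces, one can run a single dataset image through the same steps and save the result; a sketch assuming Pillow, the functions defined in this file, and the examples/train_texteller working directory:

import numpy as np
from PIL import Image

img = np.array(Image.open("dataset/train/0.png").convert("RGB"))
aug = ocr_aug(trim_white_border(img))  # rotate/pad/augraphy, as in train_transform
Image.fromarray(aug).save("aug_preview.png")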