diff --git a/examples/client_demo.py b/examples/client_demo.py
new file mode 100644
index 0000000..a6445ad
--- /dev/null
+++ b/examples/client_demo.py
@@ -0,0 +1,10 @@
+import requests
+
+server_url = "http://127.0.0.1:8000/predict"
+
+img_path = "/path/to/your/image"
+with open(img_path, 'rb') as img:
+    files = {'img': img}
+    response = requests.post(server_url, files=files)
+
+print(response.text)
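
Note: the demo assumes a TexTeller server already listening on port 8000. A minimal hardened sketch of the same request — the `timeout` argument and `raise_for_status()` are standard `requests` features, not part of the original script:

```python
import requests

server_url = "http://127.0.0.1:8000/predict"  # same endpoint as client_demo.py

with open("/path/to/your/image", "rb") as img:
    # Bound the wait instead of hanging forever
    response = requests.post(server_url, files={"img": img}, timeout=30)
response.raise_for_status()  # surface HTTP errors as exceptions
print(response.text)
```
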
diff --git a/examples/train_texteller/dataset/train/0.png b/examples/train_texteller/dataset/train/0.png
new file mode 100644
index 0000000..9f27321
Binary files /dev/null and b/examples/train_texteller/dataset/train/0.png differ
diff --git a/examples/train_texteller/dataset/train/1.png b/examples/train_texteller/dataset/train/1.png
new file mode 100644
index 0000000..bc65c5f
Binary files /dev/null and b/examples/train_texteller/dataset/train/1.png differ
diff --git a/examples/train_texteller/dataset/train/10.png b/examples/train_texteller/dataset/train/10.png
new file mode 100644
index 0000000..b2306ab
Binary files /dev/null and b/examples/train_texteller/dataset/train/10.png differ
diff --git a/examples/train_texteller/dataset/train/11.png b/examples/train_texteller/dataset/train/11.png
new file mode 100644
index 0000000..f8b20a1
Binary files /dev/null and b/examples/train_texteller/dataset/train/11.png differ
diff --git a/examples/train_texteller/dataset/train/12.png b/examples/train_texteller/dataset/train/12.png
new file mode 100644
index 0000000..5b3b285
Binary files /dev/null and b/examples/train_texteller/dataset/train/12.png differ
diff --git a/examples/train_texteller/dataset/train/13.png b/examples/train_texteller/dataset/train/13.png
new file mode 100644
index 0000000..692fcc2
Binary files /dev/null and b/examples/train_texteller/dataset/train/13.png differ
diff --git a/examples/train_texteller/dataset/train/14.png b/examples/train_texteller/dataset/train/14.png
new file mode 100644
index 0000000..e7fe2fd
Binary files /dev/null and b/examples/train_texteller/dataset/train/14.png differ
diff --git a/examples/train_texteller/dataset/train/15.png b/examples/train_texteller/dataset/train/15.png
new file mode 100644
index 0000000..fbbeb82
Binary files /dev/null and b/examples/train_texteller/dataset/train/15.png differ
diff --git a/examples/train_texteller/dataset/train/16.png b/examples/train_texteller/dataset/train/16.png
new file mode 100644
index 0000000..be56e99
Binary files /dev/null and b/examples/train_texteller/dataset/train/16.png differ
diff --git a/examples/train_texteller/dataset/train/17.png b/examples/train_texteller/dataset/train/17.png
new file mode 100644
index 0000000..4f30cf1
Binary files /dev/null and b/examples/train_texteller/dataset/train/17.png differ
diff --git a/examples/train_texteller/dataset/train/18.png b/examples/train_texteller/dataset/train/18.png
new file mode 100644
index 0000000..8774d25
Binary files /dev/null and b/examples/train_texteller/dataset/train/18.png differ
diff --git a/examples/train_texteller/dataset/train/19.png b/examples/train_texteller/dataset/train/19.png
new file mode 100644
index 0000000..4d3daa5
Binary files /dev/null and b/examples/train_texteller/dataset/train/19.png differ
diff --git a/examples/train_texteller/dataset/train/2.png b/examples/train_texteller/dataset/train/2.png
new file mode 100644
index 0000000..8fe5dd9
Binary files /dev/null and b/examples/train_texteller/dataset/train/2.png differ
diff --git a/examples/train_texteller/dataset/train/20.png b/examples/train_texteller/dataset/train/20.png
new file mode 100644
index 0000000..45c400d
Binary files /dev/null and b/examples/train_texteller/dataset/train/20.png differ
diff --git a/examples/train_texteller/dataset/train/21.png b/examples/train_texteller/dataset/train/21.png
new file mode 100644
index 0000000..311c1fd
Binary files /dev/null and b/examples/train_texteller/dataset/train/21.png differ
diff --git a/examples/train_texteller/dataset/train/22.png b/examples/train_texteller/dataset/train/22.png
new file mode 100644
index 0000000..6273383
Binary files /dev/null and b/examples/train_texteller/dataset/train/22.png differ
diff --git a/examples/train_texteller/dataset/train/23.png b/examples/train_texteller/dataset/train/23.png
new file mode 100644
index 0000000..06dfcdb
Binary files /dev/null and b/examples/train_texteller/dataset/train/23.png differ
diff --git a/examples/train_texteller/dataset/train/24.png b/examples/train_texteller/dataset/train/24.png
new file mode 100644
index 0000000..c718fd5
Binary files /dev/null and b/examples/train_texteller/dataset/train/24.png differ
diff --git a/examples/train_texteller/dataset/train/25.png b/examples/train_texteller/dataset/train/25.png
new file mode 100644
index 0000000..b90ab45
Binary files /dev/null and b/examples/train_texteller/dataset/train/25.png differ
diff --git a/examples/train_texteller/dataset/train/26.png b/examples/train_texteller/dataset/train/26.png
new file mode 100644
index 0000000..087e6de
Binary files /dev/null and b/examples/train_texteller/dataset/train/26.png differ
diff --git a/examples/train_texteller/dataset/train/27.png b/examples/train_texteller/dataset/train/27.png
new file mode 100644
index 0000000..67f552c
Binary files /dev/null and b/examples/train_texteller/dataset/train/27.png differ
diff --git a/examples/train_texteller/dataset/train/28.png b/examples/train_texteller/dataset/train/28.png
new file mode 100644
index 0000000..3b29359
Binary files /dev/null and b/examples/train_texteller/dataset/train/28.png differ
diff --git a/examples/train_texteller/dataset/train/29.png b/examples/train_texteller/dataset/train/29.png
new file mode 100644
index 0000000..917e0ed
Binary files /dev/null and b/examples/train_texteller/dataset/train/29.png differ
diff --git a/examples/train_texteller/dataset/train/3.png b/examples/train_texteller/dataset/train/3.png
new file mode 100644
index 0000000..0354b68
Binary files /dev/null and b/examples/train_texteller/dataset/train/3.png differ
diff --git a/examples/train_texteller/dataset/train/30.png b/examples/train_texteller/dataset/train/30.png
new file mode 100644
index 0000000..cb38168
Binary files /dev/null and b/examples/train_texteller/dataset/train/30.png differ
diff --git a/examples/train_texteller/dataset/train/31.png b/examples/train_texteller/dataset/train/31.png
new file mode 100644
index 0000000..973f951
Binary files /dev/null and b/examples/train_texteller/dataset/train/31.png differ
diff --git a/examples/train_texteller/dataset/train/32.png b/examples/train_texteller/dataset/train/32.png
new file mode 100644
index 0000000..7c019a5
Binary files /dev/null and b/examples/train_texteller/dataset/train/32.png differ
diff --git a/examples/train_texteller/dataset/train/33.png b/examples/train_texteller/dataset/train/33.png
new file mode 100644
index 0000000..172ff55
Binary files /dev/null and b/examples/train_texteller/dataset/train/33.png differ
diff --git a/examples/train_texteller/dataset/train/34.png b/examples/train_texteller/dataset/train/34.png
new file mode 100644
index 0000000..013c1cc
Binary files /dev/null and b/examples/train_texteller/dataset/train/34.png differ
diff --git a/examples/train_texteller/dataset/train/4.png b/examples/train_texteller/dataset/train/4.png
new file mode 100644
index 0000000..b8b0e39
Binary files /dev/null and b/examples/train_texteller/dataset/train/4.png differ
diff --git a/examples/train_texteller/dataset/train/5.png b/examples/train_texteller/dataset/train/5.png
new file mode 100644
index 0000000..db3af1f
Binary files /dev/null and b/examples/train_texteller/dataset/train/5.png differ
diff --git a/examples/train_texteller/dataset/train/6.png b/examples/train_texteller/dataset/train/6.png
new file mode 100644
index 0000000..c171137
Binary files /dev/null and b/examples/train_texteller/dataset/train/6.png differ
diff --git a/examples/train_texteller/dataset/train/7.png b/examples/train_texteller/dataset/train/7.png
new file mode 100644
index 0000000..9c2f9a6
Binary files /dev/null and b/examples/train_texteller/dataset/train/7.png differ
diff --git a/examples/train_texteller/dataset/train/8.png b/examples/train_texteller/dataset/train/8.png
new file mode 100644
index 0000000..54e300a
Binary files /dev/null and b/examples/train_texteller/dataset/train/8.png differ
diff --git a/examples/train_texteller/dataset/train/9.png b/examples/train_texteller/dataset/train/9.png
new file mode 100644
index 0000000..9bf24fb
Binary files /dev/null and b/examples/train_texteller/dataset/train/9.png differ
diff --git a/examples/train_texteller/dataset/train/metadata.jsonl b/examples/train_texteller/dataset/train/metadata.jsonl
new file mode 100644
index 0000000..23279de
--- /dev/null
+++ b/examples/train_texteller/dataset/train/metadata.jsonl
@@ -0,0 +1,35 @@
+{"file_name": "0.png", "latex_formula": "\\[\\mathbb{C}^{4}\\stackrel{{\\pi_{1}}}{{\\longleftarrow}}\\mathcal{ F}\\stackrel{{\\pi_{2}}}{{\\rightarrow}}\\mathcal{PT},\\]"}
+{"file_name": "1.png", "latex_formula": "\\[W^{*}_{Z}(x_{1},x_{2})=W_{f\\lrcorner Z}(y_{1},y_{2})=\\mathcal{P}\\exp\\left( \\int_{\\gamma}A_{\\mu}dx^{\\mu}\\right).\\]"}
+{"file_name": "2.png", "latex_formula": "\\[G=W^{*}_{Z}(q,p)=\\tilde{H}H^{-1}\\]"}
+{"file_name": "3.png", "latex_formula": "\\[H=W^{*}_{Z}(p,x),\\ \\ \\tilde{H}=W^{*}_{Z}(q,x).\\]"}
+{"file_name": "4.png", "latex_formula": "\\[v\\cdot f^{*}A|_{x}=(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)},\\quad x\\in Z, \\ v\\in T_{x}Z.\\]"}
+{"file_name": "5.png", "latex_formula": "\\[(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)}=v^{\\alpha\\dot{\\alpha}}\\Big{(} \\frac{\\partial y^{\\beta\\dot{\\beta}}}{\\partial x^{\\alpha\\dot{\\alpha}}}A_{\\beta \\dot{\\beta}}\\Big{)}\\Big{|}_{f\\lrcorner Z(x)},\\ x\\in Z,\\ v\\in T_{x}Z,\\]"}
+{"file_name": "6.png", "latex_formula": "\\[\\{T_{i},T_{j}\\}=\\{\\tilde{T}^{i},\\tilde{T}^{j}\\}=0,\\ \\ \\{T_{i},\\tilde{T}^{j}\\}=2i \\delta^{j}_{i}D,\\]"}
+{"file_name": "7.png", "latex_formula": "\\[(\\partial_{s},q_{i},\\tilde{q}^{k})\\rightarrow(D,M^{j}_{i}T_{j},\\tilde{M}^{k}_ {l}\\tilde{T}^{l}),\\]"}
+{"file_name": "8.png", "latex_formula": "\\[M^{i}_{j}\\tilde{M}^{j}_{k}=\\delta^{i}_{k}.\\]"}
+{"file_name": "9.png", "latex_formula": "\\[Q_{i\\alpha}=q_{i\\alpha}+\\omega_{i\\alpha},\\ \\tilde{Q}^{i}_{\\dot{\\alpha}}=q^{i}_{ \\dot{\\alpha}}+\\tilde{\\omega}^{i}_{\\dot{\\alpha}},\\ D_{\\alpha\\dot{\\alpha}}= \\partial_{\\alpha\\dot{\\alpha}}+A_{\\alpha\\dot{\\alpha}}.\\]"}
+{"file_name": "10.png", "latex_formula": "\\[\\hat{f}(g,\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{j})=(f(g),[V^{-1}]^ {\\alpha}_{\\beta}\\theta^{i\\beta},[\\tilde{V}^{-1}]^{\\dot{\\alpha}}_{\\dot{\\beta}} \\tilde{\\theta}^{\\dot{\\beta}}_{j}),\\ g\\in{\\cal G},\\]"}
+{"file_name": "11.png", "latex_formula": "\\[v^{\\beta\\dot{\\beta}}V^{\\alpha}_{\\beta}\\tilde{V}^{\\dot{\\alpha}}_{\\dot{\\beta}} =((f\\lrcorner L_{0})_{*}v)^{\\alpha\\dot{\\alpha}},\\]"}
+{"file_name": "12.png", "latex_formula": "\\[\\omega_{i\\alpha}=\\tilde{\\theta}^{\\dot{\\alpha}}_{i}h_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\ \\ \\tilde{\\omega}^{i}_{\\alpha}=\\theta^{i\\alpha}\\tilde{h}_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\]"}
+{"file_name": "13.png", "latex_formula": "\\[\\begin{split}&\\lambda^{\\alpha}\\hat{f}^{*}\\omega_{i\\alpha}(z)= \\tilde{\\theta}^{\\dot{\\beta}}_{i}\\lambda^{\\alpha}\\left(V^{\\beta}_{\\alpha}h_{ \\beta\\dot{\\beta}}(x^{\\prime},\\tau^{\\prime})\\right),\\\\ &\\tilde{\\lambda}^{\\dot{\\alpha}}\\hat{f}^{*}\\tilde{\\omega}^{i}_{ \\dot{\\alpha}}(z)=\\theta^{i\\beta}\\tilde{\\lambda}^{\\dot{\\alpha}}\\left(\\tilde{V}^ {\\dot{\\beta}}_{\\dot{\\alpha}}\\tilde{h}_{\\beta\\dot{\\beta}}(x^{\\prime},\\tau^{ \\prime})\\right),\\end{split}\\]"}
+{"file_name": "14.png", "latex_formula": "\\[A_{\\alpha\\dot{\\alpha}}=A_{\\alpha\\dot{\\alpha}}(x^{\\beta\\dot{\\beta}},\\tau^{ \\beta\\dot{\\beta}})\\]"}
+{"file_name": "15.png", "latex_formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}D_{\\alpha\\dot{\\alpha}}\\]"}
+{"file_name": "16.png", "latex_formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}\\]"}
+{"file_name": "17.png", "latex_formula": "\\[[v_{1}\\cdot D^{*},v_{2}\\cdot D^{*}]=0\\]"}
+{"file_name": "18.png", "latex_formula": "\\[\\Phi_{A}=(\\omega_{i\\alpha},\\tilde{\\omega}^{i}_{\\dot{\\alpha}},A_{\\alpha\\dot{ \\alpha}})\\]"}
+{"file_name": "19.png", "latex_formula": "\\[\\hat{f}:{\\cal F}^{6|4N}\\rightarrow{\\cal F}^{6|4N}\\]"}
+{"file_name": "20.png", "latex_formula": "\\[\\sigma=(s,\\xi^{i},\\tilde{\\xi}_{j})\\in\\mathbb{C}^{1|2N}\\]"}
+{"file_name": "21.png", "latex_formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}(h_{\\alpha\\dot{\\alpha}}+\\tilde{h}_{\\alpha\\dot{\\alpha} })=0\\]"}
+{"file_name": "22.png", "latex_formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}\\rightarrow[V^{-1}]^{\\alpha}_{\\beta}[\\tilde{V}^{-1}]^{ \\dot{\\alpha}}_{\\dot{\\beta}}\\tau^{\\beta\\dot{\\beta}}\\]"}
+{"file_name": "23.png", "latex_formula": "\\[\\tau^{\\beta\\dot{\\beta}}=\\sum_{i}\\theta^{i\\beta}\\tilde{\\theta}^{\\dot{\\beta}}_{i}\\]"}
+{"file_name": "24.png", "latex_formula": "\\[\\theta^{i\\alpha}\\omega_{i\\alpha}+\\tilde{\\theta}^{i}_{\\dot{\\alpha}}\\tilde{ \\omega}^{\\dot{\\alpha}}_{i}=0\\]"}
+{"file_name": "25.png", "latex_formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{Q}^{i}_{\\dot{\\alpha}}\\]"}
+{"file_name": "26.png", "latex_formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{q}^{i}_{\\dot{\\alpha}}\\]"}
+{"file_name": "27.png", "latex_formula": "\\[\\tilde{\\lambda}^{\\dot{\\alpha}}f^{*}A_{\\alpha\\dot{\\alpha}}=H^{-1}\\tilde{ \\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}H\\]"}
+{"file_name": "28.png", "latex_formula": "\\[\\tilde{q}^{i}=\\partial_{\\tilde{\\xi}_{i}}+i\\xi^{i}\\partial_{s}\\]"}
+{"file_name": "29.png", "latex_formula": "\\[\\tilde{q}^{i}_{\\dot{\\alpha}}=\\frac{\\partial}{\\partial\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}}+i\\theta^{i\\alpha}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
+{"file_name": "30.png", "latex_formula": "\\[f\\lrcorner L(z)=\\pi_{1}\\circ f(z,\\lambda,\\tilde{\\lambda})\\ \\forall z\\in L\\]"}
+{"file_name": "31.png", "latex_formula": "\\[q_{i\\alpha}=\\frac{\\partial}{\\partial\\theta^{i\\alpha}}+i\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
+{"file_name": "32.png", "latex_formula": "\\[q_{i}=\\partial_{\\xi^{i}}+i\\tilde{\\xi}_{i}\\partial_{s}\\]"}
+{"file_name": "33.png", "latex_formula": "\\[v^{\\alpha\\dot{\\alpha}}=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\]"}
+{"file_name": "34.png", "latex_formula": "\\[z^{A}=(x^{\\alpha\\dot{\\alpha}},\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{ j})\\]"}
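
Note: metadata.jsonl follows the Hugging Face `imagefolder` convention — the `file_name` column pairs each image with extra columns, here `latex_formula`. A minimal sketch (assuming the `datasets` library and this directory layout) to sanity-check that the pairing is picked up:

```python
from datasets import load_dataset

# "imagefolder" reads metadata.jsonl automatically via its file_name column
ds = load_dataset("imagefolder", data_dir="dataset")["train"]
sample = ds[0]
print(sample["image"].size)     # PIL image decoded from the PNG
print(sample["latex_formula"])  # the paired LaTeX ground truth
```
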
+{"file_name": "30.png", "latex_formula": "\\[f\\lrcorner L(z)=\\pi_{1}\\circ f(z,\\lambda,\\tilde{\\lambda})\\ \\forall z\\in L\\]"} +{"file_name": "31.png", "latex_formula": "\\[q_{i\\alpha}=\\frac{\\partial}{\\partial\\theta^{i\\alpha}}+i\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"} +{"file_name": "32.png", "latex_formula": "\\[q_{i}=\\partial_{\\xi^{i}}+i\\tilde{\\xi}_{i}\\partial_{s}\\]"} +{"file_name": "33.png", "latex_formula": "\\[v^{\\alpha\\dot{\\alpha}}=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\]"} +{"file_name": "34.png", "latex_formula": "\\[z^{A}=(x^{\\alpha\\dot{\\alpha}},\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{ j})\\]"} diff --git a/examples/train_texteller/train.py b/examples/train_texteller/train.py new file mode 100644 index 0000000..10c5996 --- /dev/null +++ b/examples/train_texteller/train.py @@ -0,0 +1,71 @@ +from functools import partial + +import yaml +from datasets import load_dataset +from transformers import ( + Trainer, + TrainingArguments, +) + +from texteller import load_model, load_tokenizer +from texteller.constants import MIN_HEIGHT, MIN_WIDTH + +from examples.train_texteller.utils import ( + collate_fn, + filter_fn, + img_inf_transform, + img_train_transform, + tokenize_fn, +) + + +def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer): + training_args = TrainingArguments(**training_config) + trainer = Trainer( + model, + training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + data_collator=collate_fn_with_tokenizer, + ) + + trainer.train(resume_from_checkpoint=None) + + +if __name__ == "__main__": + dataset = load_dataset("imagefolder", data_dir="dataset")["train"] + dataset = dataset.filter( + lambda x: x["image"].height > MIN_HEIGHT and x["image"].width > MIN_WIDTH + ) + dataset = dataset.shuffle(seed=42) + dataset = dataset.flatten_indices() + + tokenizer = load_tokenizer() + # If you want use your own tokenizer, please modify the path to your tokenizer + # tokenizer = load_tokenizer("/path/to/your/tokenizer") + filter_fn_with_tokenizer = partial(filter_fn, tokenizer=tokenizer) + dataset = dataset.filter(filter_fn_with_tokenizer, num_proc=8) + + map_fn = partial(tokenize_fn, tokenizer=tokenizer) + tokenized_dataset = dataset.map( + map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8 + ) + + # Split dataset into train and eval, ratio 9:1 + split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42) + train_dataset, eval_dataset = split_dataset["train"], split_dataset["test"] + train_dataset = train_dataset.with_transform(img_train_transform) + eval_dataset = eval_dataset.with_transform(img_inf_transform) + collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer) + + # Train from scratch + model = load_model() + + # If you want to train from pre-trained model, please modify the path to your pre-trained checkpoint + # model = load_model("/path/to/your/model_checkpoint") + + enable_train = True + training_config = yaml.safe_load(open("train_config.yaml")) + if enable_train: + train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer) diff --git a/examples/train_texteller/train_config.yaml b/examples/train_texteller/train_config.yaml new file mode 100644 index 0000000..5bd7953 --- /dev/null +++ b/examples/train_texteller/train_config.yaml @@ -0,0 +1,32 @@ +# For more information, please refer to the official documentation: 
diff --git a/examples/train_texteller/train_config.yaml b/examples/train_texteller/train_config.yaml
new file mode 100644
index 0000000..5bd7953
--- /dev/null
+++ b/examples/train_texteller/train_config.yaml
@@ -0,0 +1,32 @@
+# For more information, please refer to the official documentation: https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments
+
+seed: 42  # Random seed for reproducibility
+use_cpu: false  # Whether to train on the CPU (often easier for debugging when first testing the code)
+learning_rate: 5.0e-5  # Learning rate
+num_train_epochs: 10  # Total number of training epochs
+per_device_train_batch_size: 4  # Training batch size per GPU
+per_device_eval_batch_size: 8  # Evaluation batch size per GPU
+output_dir: "train_result"  # Output directory
+overwrite_output_dir: false  # Do not delete the output directory's contents if it already exists
+report_to:
+  - tensorboard  # Report logs to TensorBoard
+save_strategy: "steps"  # Strategy for saving checkpoints
+save_steps: 500  # Checkpoint interval: an int is a step count; a float in (0, 1) is a fraction of total training steps
+save_total_limit: 5  # Maximum number of checkpoints to keep; the oldest are deleted once the limit is exceeded
+logging_strategy: "steps"  # Log every logging_steps steps
+logging_steps: 500  # Number of steps between logs
+logging_nan_inf_filter: false  # When false, loss=nan/inf values are logged as-is instead of being filtered
+optim: "adamw_torch"  # Optimizer
+lr_scheduler_type: "cosine"  # Learning rate scheduler
+warmup_ratio: 0.1  # Fraction of total steps used for warmup (e.g., with 1000 steps, lr ramps from 0 over the first 100)
+max_grad_norm: 1.0  # Gradient clipping: the gradient norm is capped at this value (default 1.0)
+fp16: false  # Whether to train in float16 (generally not recommended; the loss can easily overflow)
+bf16: false  # Whether to train in bfloat16 (recommended when the hardware supports it)
+gradient_accumulation_steps: 1  # Accumulate gradients over this many steps to emulate a larger batch size
+jit_mode_eval: false  # Use PyTorch JIT tracing during eval (faster, but the model must be static or tracing fails)
+torch_compile: false  # Whether to compile the model with torch.compile for better training and inference performance
+dataloader_pin_memory: true  # Speeds up host-to-GPU data transfer
+dataloader_num_workers: 1  # Number of data-loading workers (0 disables multiprocessing; a common choice is 4x the number of GPUs)
+evaluation_strategy: "steps"  # Evaluation strategy, either "steps" or "epoch"
+eval_steps: 500  # Evaluate every eval_steps steps when evaluation_strategy="steps"
+remove_unused_columns: false  # Don't change this unless you really know what you are doing
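
Note: every key in this YAML must be a valid `TrainingArguments` field, and a typo only surfaces at launch. A minimal sketch to validate the file up front (assumes it sits in the working directory):

```python
import yaml
from transformers import TrainingArguments

with open("train_config.yaml") as f:
    config = yaml.safe_load(f)

# TrainingArguments raises TypeError on any key it does not recognize
args = TrainingArguments(**config)
print(args.learning_rate, args.num_train_epochs)
```
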
diff --git a/examples/train_texteller/utils/__init__.py b/examples/train_texteller/utils/__init__.py
new file mode 100644
index 0000000..6ae22c1
--- /dev/null
+++ b/examples/train_texteller/utils/__init__.py
@@ -0,0 +1,17 @@
+from .functional import (
+    collate_fn,
+    filter_fn,
+    tokenize_fn,
+)
+from .transforms import (
+    img_train_transform,
+    img_inf_transform,
+)
+
+__all__ = [
+    "collate_fn",
+    "filter_fn",
+    "tokenize_fn",
+    "img_train_transform",
+    "img_inf_transform",
+]
diff --git a/examples/train_texteller/utils/augraphy_pipe.py b/examples/train_texteller/utils/augraphy_pipe.py
new file mode 100644
index 0000000..fa038ae
--- /dev/null
+++ b/examples/train_texteller/utils/augraphy_pipe.py
@@ -0,0 +1,165 @@
+"""
+Custom augraphy pipeline for training.
+
+This file implements a custom augraphy data-augmentation pipeline. We found that
+augraphy's default pipeline can significantly degrade formula images, potentially
+destroying semantic information. We therefore hand-picked several common
+augmentation effects and tuned their parameters and combinations to preserve the
+semantic content of the images as much as possible.
+"""
+
+from augraphy import (
+    InkColorSwap,
+    LinesDegradation,
+    OneOf,
+    Dithering,
+    InkBleed,
+    InkShifter,
+    NoiseTexturize,
+    BrightnessTexturize,
+    ColorShift,
+    DirtyDrum,
+    LightingGradient,
+    Brightness,
+    Gamma,
+    SubtleNoise,
+    Jpeg,
+    AugraphyPipeline,
+)
+import random
+
+
+def get_custom_augraphy():
+    # NOTE: the random.choice/random.uniform calls below are evaluated once,
+    # when the pipeline is constructed, not once per image.
+    pre_phase = []
+
+    ink_phase = [
+        InkColorSwap(
+            ink_swap_color="random",
+            ink_swap_sequence_number_range=(5, 10),
+            ink_swap_min_width_range=(2, 3),
+            ink_swap_max_width_range=(100, 120),
+            ink_swap_min_height_range=(2, 3),
+            ink_swap_max_height_range=(100, 120),
+            ink_swap_min_area_range=(10, 20),
+            ink_swap_max_area_range=(400, 500),
+            p=0.2,
+        ),
+        LinesDegradation(
+            line_roi=(0.0, 0.0, 1.0, 1.0),
+            line_gradient_range=(32, 255),
+            line_gradient_direction=(0, 2),
+            line_split_probability=(0.2, 0.4),
+            line_replacement_value=(250, 255),
+            line_min_length=(30, 40),
+            line_long_to_short_ratio=(5, 7),
+            line_replacement_probability=(0.4, 0.5),
+            line_replacement_thickness=(1, 3),
+            p=0.2,
+        ),
+        # ============================
+        OneOf(
+            [
+                Dithering(
+                    dither="floyd-steinberg",
+                    order=(3, 5),
+                ),
+                InkBleed(
+                    intensity_range=(0.1, 0.2),
+                    kernel_size=random.choice([(7, 7), (5, 5), (3, 3)]),
+                    severity=(0.4, 0.6),
+                ),
+            ],
+            p=0.2,
+        ),
+        # ============================
+        InkShifter(
+            text_shift_scale_range=(18, 27),
+            text_shift_factor_range=(1, 4),
+            text_fade_range=(0, 2),
+            blur_kernel_size=(5, 5),
+            blur_sigma=0,
+            noise_type="perlin",
+            p=0.2,
+        ),
+        # ============================
+    ]
+
+    paper_phase = [
+        NoiseTexturize(
+            sigma_range=(3, 10),
+            turbulence_range=(2, 5),
+            texture_width_range=(300, 500),
+            texture_height_range=(300, 500),
+            p=0.2,
+        ),
+        BrightnessTexturize(texturize_range=(0.9, 0.99), deviation=0.03, p=0.2),
+    ]
+
+    post_phase = [
+        ColorShift(
+            color_shift_offset_x_range=(3, 5),
+            color_shift_offset_y_range=(3, 5),
+            color_shift_iterations=(2, 3),
+            color_shift_brightness_range=(0.9, 1.1),
+            color_shift_gaussian_kernel_range=(3, 3),
+            p=0.2,
+        ),
+        DirtyDrum(
+            line_width_range=(1, 6),
+            line_concentration=random.uniform(0.05, 0.15),
+            direction=random.randint(0, 2),
+            noise_intensity=random.uniform(0.6, 0.95),
+            noise_value=(64, 224),
+            ksize=random.choice([(3, 3), (5, 5), (7, 7)]),
+            sigmaX=0,
+            p=0.2,
+        ),
+        # =====================================
+        OneOf(
+            [
+                LightingGradient(
+                    light_position=None,
+                    direction=None,
+                    max_brightness=255,
+                    min_brightness=0,
+                    mode="gaussian",
+                    linear_decay_rate=None,
+                    transparency=None,
+                ),
+                Brightness(
+                    brightness_range=(0.9, 1.1),
+                    min_brightness=0,
+                    min_brightness_value=(120, 150),
+                ),
+                Gamma(
+                    gamma_range=(0.9, 1.1),
+                ),
+            ],
+            p=0.2,
+        ),
+        # =====================================
+        OneOf(
+            [
+                SubtleNoise(
+                    subtle_range=random.randint(5, 10),
+                ),
+                Jpeg(
+                    quality_range=(70, 95),
+                ),
+            ],
+            p=0.2,
+        ),
+        # =====================================
+    ]
+
+    pipeline = AugraphyPipeline(
+        ink_phase=ink_phase,
+        paper_phase=paper_phase,
+        post_phase=post_phase,
+        pre_phase=pre_phase,
+        log=False,
+    )
+
+    return pipeline
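
Note: the returned pipeline is applied directly to an RGB uint8 array, exactly as `ocr_aug` in transforms.py does. A minimal sketch on a synthetic white image:

```python
import numpy as np

from examples.train_texteller.utils.augraphy_pipe import get_custom_augraphy

pipe = get_custom_augraphy()
img = np.full((64, 256, 3), 255, dtype=np.uint8)  # blank white "formula" canvas
aug = pipe(img)  # each effect fires independently with its configured probability (p=0.2)
print(aug.shape, aug.dtype)
```
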
diff --git a/examples/train_texteller/utils/functional.py b/examples/train_texteller/utils/functional.py
new file mode 100644
index 0000000..6f5c09e
--- /dev/null
+++ b/examples/train_texteller/utils/functional.py
@@ -0,0 +1,50 @@
+from typing import Any
+
+import torch
+from transformers import DataCollatorForLanguageModeling
+
+from texteller.constants import MAX_TOKEN_SIZE, MIN_HEIGHT, MIN_WIDTH
+
+
+def _left_move(x: torch.Tensor, pad_val):
+    # Shift each row one position to the left and pad the tail with pad_val,
+    # so that labels[t] becomes the target token for decoder step t.
+    assert len(x.shape) == 2, "x should be 2-dimensional"
+    lefted_x = torch.ones_like(x)
+    lefted_x[:, :-1] = x[:, 1:]
+    lefted_x[:, -1] = pad_val
+    return lefted_x
+
+
+def tokenize_fn(samples: dict[str, list[Any]], tokenizer=None) -> dict[str, list[Any]]:
+    assert tokenizer is not None, "tokenizer should not be None"
+    tokenized_formula = tokenizer(samples["latex_formula"], return_special_tokens_mask=True)
+    tokenized_formula["pixel_values"] = samples["image"]
+    return tokenized_formula
+
+
+def collate_fn(samples: list[dict[str, Any]], tokenizer=None) -> dict[str, list[Any]]:
+    assert tokenizer is not None, "tokenizer should not be None"
+    pixel_values = [dic.pop("pixel_values") for dic in samples]
+
+    clm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+
+    batch = clm_collator(samples)
+    batch["pixel_values"] = pixel_values
+    batch["decoder_input_ids"] = batch.pop("input_ids")
+    batch["decoder_attention_mask"] = batch.pop("attention_mask")
+
+    # Shift labels so the model predicts the next token; -100 is ignored by the loss
+    batch["labels"] = _left_move(batch["labels"], -100)
+
+    # Convert the list of images into a single (B, C, H, W) tensor
+    batch["pixel_values"] = torch.stack(batch["pixel_values"], dim=0)
+    return batch
+
+
+def filter_fn(sample, tokenizer=None) -> bool:
+    return (
+        sample["image"].height > MIN_HEIGHT
+        and sample["image"].width > MIN_WIDTH
+        and len(tokenizer(sample["latex_formula"])["input_ids"]) < MAX_TOKEN_SIZE - 10
+    )
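
Note: the label shift in `collate_fn` is easiest to see on a toy tensor — `_left_move` drops the first token and pads the tail with -100, which the cross-entropy loss ignores. A minimal sketch:

```python
import torch

from examples.train_texteller.utils.functional import _left_move

labels = torch.tensor([[101, 7, 8, 9, 102]])  # e.g. <bos> tok tok tok <eos>
print(_left_move(labels, -100))
# tensor([[   7,    8,    9,  102, -100]])
```
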
diff --git a/examples/train_texteller/utils/transforms.py b/examples/train_texteller/utils/transforms.py
new file mode 100644
index 0000000..b72998c
--- /dev/null
+++ b/examples/train_texteller/utils/transforms.py
@@ -0,0 +1,154 @@
+import torch
+import random
+import numpy as np
+import cv2
+
+from torchvision.transforms import v2
+from typing import Any
+from PIL import Image
+from collections import Counter
+
+from texteller.constants import (
+    IMG_CHANNELS,
+    MAX_RESIZE_RATIO,
+    MIN_RESIZE_RATIO,
+)
+from texteller.utils import transform as inference_transform
+from .augraphy_pipe import get_custom_augraphy
+
+augraphy_pipeline = get_custom_augraphy()
+
+
+def trim_white_border(image: np.ndarray):
+    if len(image.shape) != 3 or image.shape[2] != 3:
+        raise ValueError("Image is not in RGB format or channel is not in third dimension")
+
+    if image.dtype != np.uint8:
+        raise ValueError("Image should be stored in uint8")
+
+    corners = [tuple(image[0, 0]), tuple(image[0, -1]), tuple(image[-1, 0]), tuple(image[-1, -1])]
+    bg_color = Counter(corners).most_common(1)[0][0]
+    bg_color_np = np.array(bg_color, dtype=np.uint8)
+
+    h, w = image.shape[:2]
+    bg = np.full((h, w, 3), bg_color_np, dtype=np.uint8)
+
+    diff = cv2.absdiff(image, bg)
+    mask = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY)
+
+    threshold = 15
+    _, diff = cv2.threshold(mask, threshold, 255, cv2.THRESH_BINARY)
+
+    x, y, w, h = cv2.boundingRect(diff)
+
+    trimmed_image = image[y : y + h, x : x + w]
+
+    return trimmed_image
+
+
+def add_white_border(image: np.ndarray, max_size: int) -> np.ndarray:
+    # Random [left, top, right, bottom] padding, compensated so the result is at least 30px per side
+    randi = [random.randint(0, max_size) for _ in range(4)]
+    pad_height_size = randi[1] + randi[3]
+    pad_width_size = randi[0] + randi[2]
+    if pad_height_size + image.shape[0] < 30:
+        compensate_height = int((30 - (pad_height_size + image.shape[0])) * 0.5) + 1
+        randi[1] += compensate_height
+        randi[3] += compensate_height
+    if pad_width_size + image.shape[1] < 30:
+        compensate_width = int((30 - (pad_width_size + image.shape[1])) * 0.5) + 1
+        randi[0] += compensate_width
+        randi[2] += compensate_width
+    return v2.functional.pad(
+        torch.from_numpy(image).permute(2, 0, 1),
+        padding=randi,
+        padding_mode="constant",
+        fill=(255, 255, 255),
+    )
+
+
+def padding(images: list[torch.Tensor], required_size: int) -> list[torch.Tensor]:
+    images = [
+        v2.functional.pad(
+            img, padding=[0, 0, required_size - img.shape[2], required_size - img.shape[1]]
+        )
+        for img in images
+    ]
+    return images
+
+
+def random_resize(images: list[np.ndarray], minr: float, maxr: float) -> list[np.ndarray]:
+    if len(images[0].shape) != 3 or images[0].shape[2] != 3:
+        raise ValueError("Image is not in RGB format or channel is not in third dimension")
+
+    ratios = [random.uniform(minr, maxr) for _ in range(len(images))]
+    return [
+        # LANCZOS4 interpolation for anti-aliasing
+        cv2.resize(
+            img, (int(img.shape[1] * r), int(img.shape[0] * r)), interpolation=cv2.INTER_LANCZOS4
+        )
+        for img, r in zip(images, ratios)
+    ]
+
+
+def rotate(image: np.ndarray, min_angle: int, max_angle: int) -> np.ndarray:
+    # Get the center of the image to define the point of rotation
+    image_center = tuple(np.array(image.shape[1::-1]) / 2)
+
+    # Generate a random angle within the specified range
+    angle = random.randint(min_angle, max_angle)
+
+    # Get the rotation matrix for rotating the image around its center
+    rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0)
+
+    # Determine the size of the rotated image
+    cos = np.abs(rotation_mat[0, 0])
+    sin = np.abs(rotation_mat[0, 1])
+    new_width = int((image.shape[0] * sin) + (image.shape[1] * cos))
+    new_height = int((image.shape[0] * cos) + (image.shape[1] * sin))
+
+    # Adjust the rotation matrix to take the translation into account
+    rotation_mat[0, 2] += (new_width / 2) - image_center[0]
+    rotation_mat[1, 2] += (new_height / 2) - image_center[1]
+
+    # Rotate the image, filling the border with white
+    rotated_image = cv2.warpAffine(
+        image, rotation_mat, (new_width, new_height), borderValue=(255, 255, 255)
+    )
+
+    return rotated_image
+
+
+def ocr_aug(image: np.ndarray) -> np.ndarray:
+    if random.random() < 0.2:
+        image = rotate(image, -5, 5)
+    image = add_white_border(image, max_size=25).permute(1, 2, 0).numpy()
+    image = augraphy_pipeline(image)
+    return image
+
+
+def train_transform(images: list[Image.Image]) -> list[torch.Tensor]:
+    assert IMG_CHANNELS == 1, "Only grayscale images are supported for now"
+
+    images = [np.array(img.convert("RGB")) for img in images]
+    # Random resize first
+    images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
+    images = [trim_white_border(image) for image in images]
+
+    # OCR augmentation
+    images = [ocr_aug(image) for image in images]
+
+    # General transform pipeline
+    images = inference_transform(images)
+    return images
+
+
+def img_train_transform(samples: dict[str, list[Any]]) -> dict[str, list[Any]]:
+    processed_img = train_transform(samples["pixel_values"])
+    samples["pixel_values"] = processed_img
+    return samples
+
+
+def img_inf_transform(samples: dict[str, list[Any]]) -> dict[str, list[Any]]:
+    processed_img = inference_transform(samples["pixel_values"])
+    samples["pixel_values"] = processed_img
+    return samples
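
Note: `trim_white_border` crops to the bounding box of everything that differs from the corner-majority background color by more than the threshold. A quick sketch on a synthetic image:

```python
import numpy as np

from examples.train_texteller.utils.transforms import trim_white_border

canvas = np.full((100, 200, 3), 255, dtype=np.uint8)  # white canvas
canvas[40:60, 80:120] = 0  # dark "formula" block in the middle
print(trim_white_border(canvas).shape)  # (20, 40, 3): cropped to the block
```
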