Initial commit

Author: 三洋三洋
Date: 2024-02-11 08:06:50 +00:00
Commit: f057490bdb
56 changed files with 815 additions and 0 deletions

7
.gitignore vendored Normal file

@@ -0,0 +1,7 @@
**/__pycache__
**/.vscode
**/train_result
**/logs
**/.cache
**/tmp*

21
LICENSE Normal file

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2024 OleehyO
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

5
README.md Normal file

@@ -0,0 +1,5 @@
# TexTeller
## Prerequisites
python=3.10
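The committed README stops at the Python requirement. A minimal usage sketch based on the scripts in this commit (install requirements.txt first, then run from the src/ directory so the models package is importable; the image path is a placeholder) might look like:

from models.ocr_model.model.TexTeller import TexTeller
from models.ocr_model.utils.inference import inference

model = TexTeller.from_pretrained()      # no argument: pulls the OleehyO/TexTeller weights from the Hugging Face Hub
tokenizer = TexTeller.get_tokenizer()
print(inference(model, tokenizer, ['/path/to/formula.png'], use_cuda=False)[0])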

9
requirements.txt Normal file

@@ -0,0 +1,9 @@
transformers
datasets
evaluate
streamlit
opencv-python
ray[serve]
accelerate
tensorboardX
nltk

11
src/client_demo.py Normal file

@@ -0,0 +1,11 @@
import requests
url = "http://127.0.0.1:8000/predict"
img_path = "/your/image/path/"
data = {"img_path": img_path}
response = requests.post(url, json=data)
print(response.text)
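The demo assumes the Ray Serve endpoint defined in src/server.py is already listening on port 8000. A slightly more defensive variant (illustrative only; the image path is a placeholder) checks the HTTP status before printing the predicted LaTeX:

import requests

url = "http://127.0.0.1:8000/predict"
data = {"img_path": "/absolute/path/to/formula.png"}
response = requests.post(url, json=data)
if response.ok:
    print(response.text)   # the predicted LaTeX string
else:
    print(f"Request failed with status {response.status_code}")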

36
src/inference.py Normal file

@@ -0,0 +1,36 @@
import os
import argparse
from pathlib import Path

from models.ocr_model.utils.inference import inference
from models.ocr_model.model.TexTeller import TexTeller

if __name__ == '__main__':
    os.chdir(Path(__file__).resolve().parent)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-img',
        type=str,
        required=True,
        help='path to the input image'
    )
    parser.add_argument(
        '-cuda',
        default=False,
        action='store_true',
        help='use cuda or not'
    )
    args = parser.parse_args()

    # You can use your own checkpoint and tokenizer path.
    print('Loading model and tokenizer...')
    model = TexTeller.from_pretrained()
    tokenizer = TexTeller.get_tokenizer()
    print('Model and tokenizer loaded.')

    img_path = [args.img]
    print('Inference...')
    res = inference(model, tokenizer, img_path, args.cuda)
    print(res[0])
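Run from the src/ directory, a typical invocation would be python inference.py -img /path/to/formula.png, with -cuda appended to decode on a GPU.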

26
src/models/globals.py Normal file

@@ -0,0 +1,26 @@
# Formula image (grayscale) mean and variance
IMAGE_MEAN = 0.9545467
IMAGE_STD = 0.15394445

# Density value for PDF-to-image conversion
TEXTELL_DENSITY = 200

# Vocabulary size for TexTeller
VOCAB_SIZE = 10000

# Fixed input image size for TexTeller
FIXED_IMG_SIZE = 448

# Number of image channels for TexTeller
IMG_CHANNELS = 1  # grayscale image

# Maximum token sequence length for the decoder
MAX_TOKEN_SIZE = 512

# Scaling ratio range for random resizing during training
MAX_RESIZE_RATIO = 1.15
MIN_RESIZE_RATIO = 0.75

# Minimum height and width for a TexTeller input image
MIN_HEIGHT = 12
MIN_WIDTH = 30

43
src/models/ocr_model/model/TexTeller.py Normal file

@@ -0,0 +1,43 @@
from pathlib import Path

from models.globals import (
    VOCAB_SIZE,
    FIXED_IMG_SIZE,
    IMG_CHANNELS,
)

from transformers import (
    ViTConfig,
    ViTModel,
    TrOCRConfig,
    TrOCRForCausalLM,
    RobertaTokenizerFast,
    VisionEncoderDecoderModel,
)


class TexTeller(VisionEncoderDecoderModel):
    REPO_NAME = 'OleehyO/TexTeller'

    def __init__(self, decoder_path=None, tokenizer_path=None):
        encoder = ViTModel(ViTConfig(
            image_size=FIXED_IMG_SIZE,
            num_channels=IMG_CHANNELS
        ))
        decoder = TrOCRForCausalLM(TrOCRConfig(
            vocab_size=VOCAB_SIZE,
        ))
        super().__init__(encoder=encoder, decoder=decoder)

    @classmethod
    def from_pretrained(cls, model_path: str = None):
        if model_path is None or model_path == cls.REPO_NAME:
            return VisionEncoderDecoderModel.from_pretrained(cls.REPO_NAME)
        model_path = Path(model_path).resolve()
        return VisionEncoderDecoderModel.from_pretrained(str(model_path))

    @classmethod
    def get_tokenizer(cls, tokenizer_path: str = None) -> RobertaTokenizerFast:
        if tokenizer_path is None or tokenizer_path == cls.REPO_NAME:
            return RobertaTokenizerFast.from_pretrained(cls.REPO_NAME)
        tokenizer_path = Path(tokenizer_path).resolve()
        return RobertaTokenizerFast.from_pretrained(str(tokenizer_path))

35
src/models/ocr_model/train/dataset/formulas.jsonl Normal file

@@ -0,0 +1,35 @@
{"img_name": "0.png", "formula": "\\[\\mathbb{C}^{4}\\stackrel{{\\pi_{1}}}{{\\longleftarrow}}\\mathcal{ F}\\stackrel{{\\pi_{2}}}{{\\rightarrow}}\\mathcal{PT},\\]"}
{"img_name": "1.png", "formula": "\\[W^{*}_{Z}(x_{1},x_{2})=W_{f\\lrcorner Z}(y_{1},y_{2})=\\mathcal{P}\\exp\\left( \\int_{\\gamma}A_{\\mu}dx^{\\mu}\\right).\\]"}
{"img_name": "2.png", "formula": "\\[G=W^{*}_{Z}(q,p)=\\tilde{H}H^{-1}\\]"}
{"img_name": "3.png", "formula": "\\[H=W^{*}_{Z}(p,x),\\ \\ \\tilde{H}=W^{*}_{Z}(q,x).\\]"}
{"img_name": "4.png", "formula": "\\[v\\cdot f^{*}A|_{x}=(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)},\\quad x\\in Z, \\ v\\in T_{x}Z.\\]"}
{"img_name": "5.png", "formula": "\\[(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)}=v^{\\alpha\\dot{\\alpha}}\\Big{(} \\frac{\\partial y^{\\beta\\dot{\\beta}}}{\\partial x^{\\alpha\\dot{\\alpha}}}A_{\\beta \\dot{\\beta}}\\Big{)}\\Big{|}_{f\\lrcorner Z(x)},\\ x\\in Z,\\ v\\in T_{x}Z,\\]"}
{"img_name": "6.png", "formula": "\\[\\{T_{i},T_{j}\\}=\\{\\tilde{T}^{i},\\tilde{T}^{j}\\}=0,\\ \\ \\{T_{i},\\tilde{T}^{j}\\}=2i \\delta^{j}_{i}D,\\]"}
{"img_name": "7.png", "formula": "\\[(\\partial_{s},q_{i},\\tilde{q}^{k})\\rightarrow(D,M^{j}_{i}T_{j},\\tilde{M}^{k}_ {l}\\tilde{T}^{l}),\\]"}
{"img_name": "8.png", "formula": "\\[M^{i}_{j}\\tilde{M}^{j}_{k}=\\delta^{i}_{k}.\\]"}
{"img_name": "9.png", "formula": "\\[Q_{i\\alpha}=q_{i\\alpha}+\\omega_{i\\alpha},\\ \\tilde{Q}^{i}_{\\dot{\\alpha}}=q^{i}_{ \\dot{\\alpha}}+\\tilde{\\omega}^{i}_{\\dot{\\alpha}},\\ D_{\\alpha\\dot{\\alpha}}= \\partial_{\\alpha\\dot{\\alpha}}+A_{\\alpha\\dot{\\alpha}}.\\]"}
{"img_name": "10.png", "formula": "\\[\\hat{f}(g,\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{j})=(f(g),[V^{-1}]^ {\\alpha}_{\\beta}\\theta^{i\\beta},[\\tilde{V}^{-1}]^{\\dot{\\alpha}}_{\\dot{\\beta}} \\tilde{\\theta}^{\\dot{\\beta}}_{j}),\\ g\\in{\\cal G},\\]"}
{"img_name": "11.png", "formula": "\\[v^{\\beta\\dot{\\beta}}V^{\\alpha}_{\\beta}\\tilde{V}^{\\dot{\\alpha}}_{\\dot{\\beta}} =((f\\lrcorner L_{0})_{*}v)^{\\alpha\\dot{\\alpha}},\\]"}
{"img_name": "12.png", "formula": "\\[\\omega_{i\\alpha}=\\tilde{\\theta}^{\\dot{\\alpha}}_{i}h_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\ \\ \\tilde{\\omega}^{i}_{\\alpha}=\\theta^{i\\alpha}\\tilde{h}_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\]"}
{"img_name": "13.png", "formula": "\\[\\begin{split}&\\lambda^{\\alpha}\\hat{f}^{*}\\omega_{i\\alpha}(z)= \\tilde{\\theta}^{\\dot{\\beta}}_{i}\\lambda^{\\alpha}\\left(V^{\\beta}_{\\alpha}h_{ \\beta\\dot{\\beta}}(x^{\\prime},\\tau^{\\prime})\\right),\\\\ &\\tilde{\\lambda}^{\\dot{\\alpha}}\\hat{f}^{*}\\tilde{\\omega}^{i}_{ \\dot{\\alpha}}(z)=\\theta^{i\\beta}\\tilde{\\lambda}^{\\dot{\\alpha}}\\left(\\tilde{V}^ {\\dot{\\beta}}_{\\dot{\\alpha}}\\tilde{h}_{\\beta\\dot{\\beta}}(x^{\\prime},\\tau^{ \\prime})\\right),\\end{split}\\]"}
{"img_name": "14.png", "formula": "\\[A_{\\alpha\\dot{\\alpha}}=A_{\\alpha\\dot{\\alpha}}(x^{\\beta\\dot{\\beta}},\\tau^{ \\beta\\dot{\\beta}})\\]"}
{"img_name": "15.png", "formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}D_{\\alpha\\dot{\\alpha}}\\]"}
{"img_name": "16.png", "formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}\\]"}
{"img_name": "17.png", "formula": "\\[[v_{1}\\cdot D^{*},v_{2}\\cdot D^{*}]=0\\]"}
{"img_name": "18.png", "formula": "\\[\\Phi_{A}=(\\omega_{i\\alpha},\\tilde{\\omega}^{i}_{\\dot{\\alpha}},A_{\\alpha\\dot{ \\alpha}})\\]"}
{"img_name": "19.png", "formula": "\\[\\hat{f}:{\\cal F}^{6|4N}\\rightarrow{\\cal F}^{6|4N}\\]"}
{"img_name": "20.png", "formula": "\\[\\sigma=(s,\\xi^{i},\\tilde{\\xi}_{j})\\in\\mathbb{C}^{1|2N}\\]"}
{"img_name": "21.png", "formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}(h_{\\alpha\\dot{\\alpha}}+\\tilde{h}_{\\alpha\\dot{\\alpha} })=0\\]"}
{"img_name": "22.png", "formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}\\rightarrow[V^{-1}]^{\\alpha}_{\\beta}[\\tilde{V}^{-1}]^{ \\dot{\\alpha}}_{\\dot{\\beta}}\\tau^{\\beta\\dot{\\beta}}\\]"}
{"img_name": "23.png", "formula": "\\[\\tau^{\\beta\\dot{\\beta}}=\\sum_{i}\\theta^{i\\beta}\\tilde{\\theta}^{\\dot{\\beta}}_{i}\\]"}
{"img_name": "24.png", "formula": "\\[\\theta^{i\\alpha}\\omega_{i\\alpha}+\\tilde{\\theta}^{i}_{\\dot{\\alpha}}\\tilde{ \\omega}^{\\dot{\\alpha}}_{i}=0\\]"}
{"img_name": "25.png", "formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{Q}^{i}_{\\dot{\\alpha}}\\]"}
{"img_name": "26.png", "formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{q}^{i}_{\\dot{\\alpha}}\\]"}
{"img_name": "27.png", "formula": "\\[\\tilde{\\lambda}^{\\dot{\\alpha}}f^{*}A_{\\alpha\\dot{\\alpha}}=H^{-1}\\tilde{ \\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}H\\]"}
{"img_name": "28.png", "formula": "\\[\\tilde{q}^{i}=\\partial_{\\tilde{\\xi}_{i}}+i\\xi^{i}\\partial_{s}\\]"}
{"img_name": "29.png", "formula": "\\[\\tilde{q}^{i}_{\\dot{\\alpha}}=\\frac{\\partial}{\\partial\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}}+i\\theta^{i\\alpha}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
{"img_name": "30.png", "formula": "\\[f\\lrcorner L(z)=\\pi_{1}\\circ f(z,\\lambda,\\tilde{\\lambda})\\ \\forall z\\in L\\]"}
{"img_name": "31.png", "formula": "\\[q_{i\\alpha}=\\frac{\\partial}{\\partial\\theta^{i\\alpha}}+i\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
{"img_name": "32.png", "formula": "\\[q_{i}=\\partial_{\\xi^{i}}+i\\tilde{\\xi}_{i}\\partial_{s}\\]"}
{"img_name": "33.png", "formula": "\\[v^{\\alpha\\dot{\\alpha}}=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\]"}
{"img_name": "34.png", "formula": "\\[z^{A}=(x^{\\alpha\\dot{\\alpha}},\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{ j})\\]"}

35 binary image files not shown: the sample formula images 0.png through 34.png referenced by formulas.jsonl (roughly 1.8 KiB to 12 KiB each), presumably added under src/models/ocr_model/train/dataset/images/, the images/ directory that loader.py reads.

50
src/models/ocr_model/train/dataset/loader.py Normal file

@@ -0,0 +1,50 @@
from PIL import Image
from pathlib import Path

import datasets
import json

DIR_URL = Path('absolute/path/to/dataset/directory')
# e.g. DIR_URL = Path('/home/OleehyO/TeXTeller/src/models/ocr_model/train/dataset')


class LatexFormulas(datasets.GeneratorBasedBuilder):
    BUILDER_CONFIGS = []

    def _info(self):
        return datasets.DatasetInfo(
            features=datasets.Features({
                "image": datasets.Image(),
                "latex_formula": datasets.Value("string")
            })
        )

    def _split_generators(self, dl_manager: datasets.DownloadManager):
        dir_path = Path(dl_manager.download(str(DIR_URL)))
        assert dir_path.is_dir()
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={
                    'dir_path': dir_path,
                }
            )
        ]

    def _generate_examples(self, dir_path: Path):
        images_path = dir_path / 'images'
        formulas_path = dir_path / 'formulas.jsonl'
        img2formula = {}
        with formulas_path.open('r', encoding='utf-8') as f:
            for line in f:
                single_json = json.loads(line)
                img2formula[single_json['img_name']] = single_json['formula']
        for img_path in images_path.iterdir():
            if img_path.suffix not in ['.jpg', '.png']:
                continue
            yield str(img_path), {
                "image": Image.open(img_path),
                "latex_formula": img2formula[img_path.name]
            }
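For this loader to work, DIR_URL must point at a directory laid out like the one committed here: an images/ subdirectory of .png or .jpg files plus a formulas.jsonl whose img_name fields match those file names, as in the sample records above.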

103
src/models/ocr_model/train/train.py Normal file

@@ -0,0 +1,103 @@
import os
from functools import partial
from pathlib import Path

from datasets import load_dataset
from transformers import (
    Trainer,
    TrainingArguments,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    GenerationConfig
)

from .training_args import CONFIG
from ..model.TexTeller import TexTeller
from ..utils.functional import tokenize_fn, collate_fn, img_transform_fn
from ..utils.metrics import bleu_metric
from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT


def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
    training_args = TrainingArguments(**CONFIG)
    trainer = Trainer(
        model,
        training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=collate_fn_with_tokenizer,
    )
    trainer.train(resume_from_checkpoint=None)


def evaluate(model, tokenizer, eval_dataset, collate_fn):
    eval_config = CONFIG.copy()
    eval_config['predict_with_generate'] = True
    generate_config = GenerationConfig(
        max_new_tokens=MAX_TOKEN_SIZE,
        num_beams=1,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
    )
    eval_config['generation_config'] = generate_config
    seq2seq_config = Seq2SeqTrainingArguments(**eval_config)
    trainer = Seq2SeqTrainer(
        model,
        seq2seq_config,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=collate_fn,
        compute_metrics=partial(bleu_metric, tokenizer=tokenizer)
    )
    eval_res = trainer.evaluate()
    print(eval_res)


if __name__ == '__main__':
    script_dirpath = Path(__file__).resolve().parent
    os.chdir(script_dirpath)

    dataset = load_dataset(str(Path('./dataset/loader.py').resolve()))['train']
    dataset = dataset.filter(lambda x: x['image'].height > MIN_HEIGHT and x['image'].width > MIN_WIDTH)
    dataset = dataset.shuffle(seed=42)
    dataset = dataset.flatten_indices()

    tokenizer = TexTeller.get_tokenizer()
    # If you want to use your own tokenizer, point this at its path instead:
    # tokenizer = TexTeller.get_tokenizer('/path/to/your/tokenizer')

    map_fn = partial(tokenize_fn, tokenizer=tokenizer)
    tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8)
    tokenized_dataset = tokenized_dataset.with_transform(img_transform_fn)

    # Split dataset into train and eval, ratio 9:1
    split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
    train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
    collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)

    # Train from scratch,
    model = TexTeller()
    # or start from the TexTeller pre-trained model: model = TexTeller.from_pretrained()
    # If you want to train from your own pre-trained checkpoint, pass its path, e.g.:
    # model = TexTeller.from_pretrained('/path/to/your/model_checkpoint')

    enable_train = True
    enable_evaluate = True
    if enable_train:
        train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
    if enable_evaluate and len(eval_dataset) > 0:
        evaluate(model, tokenizer, eval_dataset, collate_fn_with_tokenizer)
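Because this script sits inside the models package and relies on relative imports, it is presumably launched as a module from the src/ directory (for example, python -m models.ocr_model.train.train) after DIR_URL in dataset/loader.py has been pointed at the training data.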

38
src/models/ocr_model/train/training_args.py Normal file

@@ -0,0 +1,38 @@
CONFIG = {
    "seed": 42,                           # Random seed for reproducibility
    "use_cpu": False,                     # Whether to use CPU (it's easier to debug on CPU when first testing the code)

    "learning_rate": 5e-5,                # Learning rate
    "num_train_epochs": 10,               # Total number of training epochs
    "per_device_train_batch_size": 4,     # Batch size per GPU for training
    "per_device_eval_batch_size": 8,      # Batch size per GPU for evaluation

    "output_dir": "train_result",         # Output directory
    "overwrite_output_dir": False,        # If the output directory exists, do not delete its contents
    "report_to": ["tensorboard"],         # Report logs to TensorBoard

    "save_strategy": "steps",             # Strategy for saving checkpoints
    "save_steps": 500,                    # Interval (in steps) for saving checkpoints; can be an int or a float in (0, 1), in which case it is a ratio of the total training steps (e.g. 1.0 / 2000)
    "save_total_limit": 5,                # Maximum number of checkpoints to keep; the oldest are deleted once this limit is exceeded

    "logging_strategy": "steps",          # Log every fixed number of steps
    "logging_steps": 500,                 # Number of steps between logs
    "logging_nan_inf_filter": False,      # Keep log entries even when the loss is nan or inf

    "optim": "adamw_torch",               # Optimizer
    "lr_scheduler_type": "cosine",        # Learning rate scheduler
    "warmup_ratio": 0.1,                  # Ratio of warmup steps to total training steps (e.g. for 1000 steps, the first 100 ramp the lr from 0 to the configured value)
    "max_grad_norm": 1.0,                 # Gradient clipping: keep the gradient norm at or below 1.0 (default 1.0)

    "fp16": False,                        # Whether to train in 16-bit floating point (generally not recommended, the loss can easily explode)
    "bf16": False,                        # Whether to train in bfloat16 (recommended if the hardware supports it)
    "gradient_accumulation_steps": 1,     # Gradient accumulation steps; use this to emulate a large batch size when the per-device batch size cannot be increased

    "jit_mode_eval": False,               # Whether to use PyTorch JIT tracing during eval (can speed up the model, but the model must be static or it will raise errors)
    "torch_compile": False,               # Whether to compile the model with torch.compile (for better training and inference performance)
    "dataloader_pin_memory": True,        # Can speed up data transfer between CPU and GPU
    "dataloader_num_workers": 1,          # The HF default is 0 (no multiprocessing); usually set to 4 * number of GPUs used

    "evaluation_strategy": "steps",       # Evaluation strategy, can be "steps" or "epoch"
    "eval_steps": 500,                    # Used when evaluation_strategy="steps": run evaluation every this many steps

    "remove_unused_columns": False,       # Don't change this unless you really know what you are doing.
}

46
src/models/ocr_model/utils/functional.py Normal file

@@ -0,0 +1,46 @@
import torch
import numpy as np
from transformers import DataCollatorForLanguageModeling
from typing import List, Dict, Any

from .transforms import train_transform


def left_move(x: torch.Tensor, pad_val):
    assert len(x.shape) == 2, 'x should be 2-dimensional'
    lefted_x = torch.ones_like(x)
    lefted_x[:, :-1] = x[:, 1:]
    lefted_x[:, -1] = pad_val
    return lefted_x


def tokenize_fn(samples: Dict[str, List[Any]], tokenizer=None) -> Dict[str, List[Any]]:
    assert tokenizer is not None, 'tokenizer should not be None'
    tokenized_formula = tokenizer(samples['latex_formula'], return_special_tokens_mask=True)
    tokenized_formula['pixel_values'] = samples['image']
    return tokenized_formula


def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[Any]]:
    assert tokenizer is not None, 'tokenizer should not be None'
    pixel_values = [dic.pop('pixel_values') for dic in samples]
    clm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    batch = clm_collator(samples)
    batch['pixel_values'] = pixel_values
    batch['decoder_input_ids'] = batch.pop('input_ids')
    batch['decoder_attention_mask'] = batch.pop('attention_mask')
    # Left-shift the labels so they line up with the decoder predictions, padding the last position with -100
    batch['labels'] = left_move(batch['labels'], -100)
    # Convert the list of images to a tensor of shape (B, C, H, W)
    batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
    return batch


def img_transform_fn(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
    processed_img = train_transform(samples['pixel_values'])
    samples['pixel_values'] = processed_img
    return samples
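A toy illustration of the label shift done in collate_fn (not part of the file; it reuses the left_move defined above): the first token is dropped and the sequence is padded on the right with -100, so each label lines up with the token the decoder should predict at that position.

import torch
labels = torch.tensor([[0, 11, 12, 13, 2]])   # e.g. <s> tok tok tok </s>
print(left_move(labels, -100))                # -> tensor([[11, 12, 13, 2, -100]])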

26
src/models/ocr_model/utils/helpers.py Normal file

@@ -0,0 +1,26 @@
import cv2
import numpy as np
from typing import List


def convert2rgb(image_paths: List[str]) -> List[np.ndarray]:
    processed_images = []
    for path in image_paths:
        image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if image is None:
            print(f"Image at {path} could not be read.")
            continue
        if image.dtype == np.uint16:
            print(f'Converting {path} to 8-bit, image may be lossy.')
            image = cv2.convertScaleAbs(image, alpha=(255.0 / 65535.0))
        channels = 1 if len(image.shape) == 2 else image.shape[2]
        if channels == 4:
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
        elif channels == 1:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif channels == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        processed_images.append(image)
    return processed_images

38
src/models/ocr_model/utils/inference.py Normal file

@@ -0,0 +1,38 @@
import torch

from transformers import RobertaTokenizerFast, GenerationConfig
from typing import List

from models.ocr_model.model.TexTeller import TexTeller
from models.ocr_model.utils.transforms import inference_transform
from models.ocr_model.utils.helpers import convert2rgb
from models.globals import MAX_TOKEN_SIZE


def inference(
    model: TexTeller,
    tokenizer: RobertaTokenizerFast,
    imgs_path: List[str],
    use_cuda: bool,
    num_beams: int = 1,
) -> List[str]:
    model.eval()
    imgs = convert2rgb(imgs_path)
    imgs = inference_transform(imgs)
    pixel_values = torch.stack(imgs)

    if use_cuda:
        model = model.to('cuda')
        pixel_values = pixel_values.to('cuda')

    generate_config = GenerationConfig(
        max_new_tokens=MAX_TOKEN_SIZE,
        num_beams=num_beams,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
    )
    pred = model.generate(pixel_values, generation_config=generate_config)
    res = tokenizer.batch_decode(pred, skip_special_tokens=True)
    return res

23
src/models/ocr_model/utils/metrics.py Normal file

@@ -0,0 +1,23 @@
import evaluate
import numpy as np
import os

from pathlib import Path
from typing import Dict
from transformers import EvalPrediction, RobertaTokenizer


def bleu_metric(eval_preds: EvalPrediction, tokenizer: RobertaTokenizer) -> Dict:
    cur_dir = Path(os.getcwd())
    os.chdir(Path(__file__).resolve().parent)
    metric = evaluate.load('google_bleu')  # Will download the metric from huggingface if not already downloaded
    os.chdir(cur_dir)

    logits, labels = eval_preds.predictions, eval_preds.label_ids
    preds = logits
    labels = np.where(labels == -100, 1, labels)
    preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    return metric.compute(predictions=preds, references=labels)

90
src/models/ocr_model/utils/transforms.py Normal file

@@ -0,0 +1,90 @@
import torch
import random
import numpy as np
import cv2

from torchvision.transforms import v2
from typing import List
from PIL import Image

from models.globals import (
    FIXED_IMG_SIZE,
    IMAGE_MEAN, IMAGE_STD,
    MAX_RESIZE_RATIO, MIN_RESIZE_RATIO
)

general_transform_pipeline = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.uint8, scale=True),
    v2.Grayscale(),
    v2.Resize(
        size=FIXED_IMG_SIZE - 1,
        interpolation=v2.InterpolationMode.BICUBIC,
        max_size=FIXED_IMG_SIZE,
        antialias=True
    ),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[IMAGE_MEAN], std=[IMAGE_STD]),
])


def trim_white_border(image: np.ndarray):
    if len(image.shape) != 3 or image.shape[2] != 3:
        raise ValueError("Image is not in RGB format or channel is not in third dimension")
    if image.dtype != np.uint8:
        raise ValueError("Image should be stored in uint8")
    h, w = image.shape[:2]
    bg = np.full((h, w, 3), 255, dtype=np.uint8)
    diff = cv2.absdiff(image, bg)
    _, diff = cv2.threshold(diff, 1, 255, cv2.THRESH_BINARY)
    gray_diff = cv2.cvtColor(diff, cv2.COLOR_RGB2GRAY)
    x, y, w, h = cv2.boundingRect(gray_diff)
    trimmed_image = image[y:y+h, x:x+w]
    return trimmed_image


def padding(images: List[torch.Tensor], required_size: int):
    images = [
        v2.functional.pad(
            img,
            padding=[0, 0, required_size - img.shape[2], required_size - img.shape[1]]
        )
        for img in images
    ]
    return images


def random_resize(
    images: List[np.ndarray],
    minr: float,
    maxr: float
) -> List[np.ndarray]:
    if len(images[0].shape) != 3 or images[0].shape[2] != 3:
        raise ValueError("Image is not in RGB format or channel is not in third dimension")
    ratios = [random.uniform(minr, maxr) for _ in range(len(images))]
    return [
        cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)), interpolation=cv2.INTER_LANCZOS4)  # anti-aliasing
        for img, r in zip(images, ratios)
    ]


def general_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
    images = [trim_white_border(image) for image in images]
    images = general_transform_pipeline(images)
    images = padding(images, FIXED_IMG_SIZE)
    return images


def train_transform(images: List[Image.Image]) -> List[torch.Tensor]:
    images = [np.array(img.convert('RGB')) for img in images]
    images = random_resize(images, MIN_RESIZE_RATIO, MAX_RESIZE_RATIO)
    return general_transform(images)


def inference_transform(images: List[np.ndarray]) -> List[torch.Tensor]:
    return general_transform(images)
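In short: every image has its white border trimmed, is converted to grayscale, resized to fit within FIXED_IMG_SIZE (448) pixels, normalized with IMAGE_MEAN and IMAGE_STD, and padded on the right and bottom to a fixed 448x448 canvas; during training, a random resize by a factor between MIN_RESIZE_RATIO and MAX_RESIZE_RATIO is applied first.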

25
src/models/tokenizer/train.py Normal file

@@ -0,0 +1,25 @@
import os
from pathlib import Path

from datasets import load_dataset

from ..ocr_model.model.TexTeller import TexTeller
from ..globals import VOCAB_SIZE

if __name__ == '__main__':
    script_dirpath = Path(__file__).resolve().parent
    os.chdir(script_dirpath)

    tokenizer = TexTeller.get_tokenizer()
    # Don't forget to configure your dataset path in loader.py
    dataset = load_dataset('../ocr_model/train/dataset/loader.py')['train']
    new_tokenizer = tokenizer.train_new_from_iterator(
        text_iterator=dataset['latex_formula'],
        # If you want a different vocab size, change VOCAB_SIZE in globals.py
        vocab_size=VOCAB_SIZE
    )
    # Save the new tokenizer for later training and inference
    new_tokenizer.save_pretrained('./your_dir_name')
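Once saved, the new tokenizer directory can be passed to TexTeller.get_tokenizer() in the training and inference scripts (the commented hint in train.py above); VOCAB_SIZE in globals.py should match it, since the decoder's vocabulary size is built from that constant.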

81
src/server.py Normal file

@@ -0,0 +1,81 @@
import argparse
import time

from starlette.requests import Request
from ray import serve
from ray.serve.handle import DeploymentHandle

from models.ocr_model.utils.inference import inference
from models.ocr_model.model.TexTeller import TexTeller

parser = argparse.ArgumentParser()
parser.add_argument(
    '-ckpt', '--checkpoint_dir', type=str
)
parser.add_argument(
    '-tknz', '--tokenizer_dir', type=str
)
parser.add_argument('-port', '--server_port', type=int, default=8000)
parser.add_argument('--num_replicas', type=int, default=1)
parser.add_argument('--ncpu_per_replica', type=float, default=1.0)
parser.add_argument('--ngpu_per_replica', type=float, default=0.0)
parser.add_argument('--use_cuda', action='store_true', default=False)
parser.add_argument('--num_beam', type=int, default=1)
args = parser.parse_args()

if args.ngpu_per_replica > 0 and not args.use_cuda:
    raise ValueError("use_cuda must be True if ngpu_per_replica > 0")


@serve.deployment(
    num_replicas=args.num_replicas,
    ray_actor_options={
        "num_cpus": args.ncpu_per_replica,
        "num_gpus": args.ngpu_per_replica
    }
)
class TexTellerServer:
    def __init__(
        self,
        checkpoint_path: str,
        tokenizer_path: str,
        use_cuda: bool = False,
        num_beam: int = 1
    ) -> None:
        self.model = TexTeller.from_pretrained(checkpoint_path)
        self.tokenizer = TexTeller.get_tokenizer(tokenizer_path)
        self.use_cuda = use_cuda
        self.num_beam = num_beam
        self.model = self.model.to('cuda') if use_cuda else self.model

    def predict(self, image_path: str) -> str:
        return inference(self.model, self.tokenizer, [image_path], self.use_cuda, self.num_beam)[0]


@serve.deployment()
class Ingress:
    def __init__(self, texteller_server: DeploymentHandle) -> None:
        self.texteller_server = texteller_server

    async def __call__(self, request: Request) -> str:
        msg = await request.json()
        img_path: str = msg['img_path']
        pred = await self.texteller_server.predict.remote(img_path)
        return pred


if __name__ == '__main__':
    ckpt_dir = args.checkpoint_dir
    tknz_dir = args.tokenizer_dir

    serve.start(http_options={"port": args.server_port})
    texteller_server = TexTellerServer.bind(ckpt_dir, tknz_dir, use_cuda=args.use_cuda, num_beam=args.num_beam)
    ingress = Ingress.bind(texteller_server)
    ingress_handle = serve.run(ingress, route_prefix="/predict")

    while True:
        time.sleep(1)
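A typical launch from the src/ directory would be python server.py -ckpt /path/to/checkpoint -tknz /path/to/tokenizer (omitting both falls back to the OleehyO/TexTeller weights and tokenizer); add --use_cuda together with --ngpu_per_replica 1 to serve on a GPU. Once Serve reports the application as running, src/client_demo.py can POST image paths to http://127.0.0.1:8000/predict.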

9
src/start_web.sh Executable file

@@ -0,0 +1,9 @@
#!/usr/bin/env bash
set -exu
export CHECKPOINT_DIR="OleehyO/TexTeller"
export TOKENIZER_DIR="OleehyO/TexTeller"
export USE_CUDA=False # True or False (case-sensitive)
export NUM_BEAM=1
streamlit run web.py
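web.py reads CHECKPOINT_DIR, TOKENIZER_DIR, USE_CUDA, and NUM_BEAM from the environment via os.environ, so all four exports above are required; the script is presumably run from inside src/ so that streamlit run web.py resolves.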

93
src/web.py Normal file

@@ -0,0 +1,93 @@
import os
import io
import base64
import tempfile
import streamlit as st

from PIL import Image

from models.ocr_model.utils.inference import inference
from models.ocr_model.model.TexTeller import TexTeller


@st.cache_resource
def get_model():
    return TexTeller.from_pretrained(os.environ['CHECKPOINT_DIR'])


@st.cache_resource
def get_tokenizer():
    return TexTeller.get_tokenizer(os.environ['TOKENIZER_DIR'])


model = get_model()
tokenizer = get_tokenizer()

# ============================ pages =============================== #
html_string = '''
<h1 style="color: orange; text-align: center;">
✨ TexTeller ✨
</h1>
'''
st.markdown(html_string, unsafe_allow_html=True)

if "start" not in st.session_state:
    st.balloons()
    st.session_state["start"] = 1

uploaded_file = st.file_uploader("", type=['jpg', 'png'])

if uploaded_file:
    img = Image.open(uploaded_file)
    temp_dir = tempfile.mkdtemp()
    png_file_path = os.path.join(temp_dir, 'image.png')
    img.save(png_file_path, 'PNG')

    def get_image_base64(img_file):
        buffered = io.BytesIO()
        img_file.seek(0)
        img = Image.open(img_file)
        img.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode()

    img_base64 = get_image_base64(uploaded_file)
    st.markdown(f"""
<style>
.centered-container {{
text-align: center;
}}
.centered-image {{
display: block;
margin-left: auto;
margin-right: auto;
max-width: 700px;
}}
</style>
<div class="centered-container">
<img src="data:image/png;base64,{img_base64}" class="centered-image" alt="Input image">
<p style="color:gray;">Input image ({img.height}✖️{img.width})</p>
</div>
""", unsafe_allow_html=True)

    st.write("")
    st.write("")

    with st.spinner("Predicting..."):
        uploaded_file.seek(0)
        TeXTeller_result = inference(
            model,
            tokenizer,
            [png_file_path],
            True if os.environ['USE_CUDA'] == 'True' else False,
            int(os.environ['NUM_BEAM'])
        )[0]
        # st.subheader(':rainbow[Predict] :sunglasses:', divider='rainbow')
        st.subheader(':sunglasses:', divider='gray')
        st.latex(TeXTeller_result)
        st.code(TeXTeller_result, language='latex')
        st.success('Done!')
# ============================ pages =============================== #