[chore] exclude paddleocr directory from pre-commit hooks

This commit is contained in:
三洋三洋
2025-02-28 19:56:49 +08:00
parent a8a005ae10
commit 3d546f9993
130 changed files with 592 additions and 739 deletions

View File

@@ -4,8 +4,10 @@ repos:
hooks: hooks:
- id: ruff - id: ruff
args: [--fix, --respect-gitignore, --config=pyproject.toml] args: [--fix, --respect-gitignore, --config=pyproject.toml]
exclude: ^texteller/models/thrid_party/paddleocr/
- id: ruff-format - id: ruff-format
args: [--config=pyproject.toml] args: [--config=pyproject.toml]
exclude: ^texteller/models/thrid_party/paddleocr/
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0 rev: v4.5.0

View File

@@ -1,35 +0,0 @@
{"img_name": "0.png", "formula": "\\[\\mathbb{C}^{4}\\stackrel{{\\pi_{1}}}{{\\longleftarrow}}\\mathcal{ F}\\stackrel{{\\pi_{2}}}{{\\rightarrow}}\\mathcal{PT},\\]"}
{"img_name": "1.png", "formula": "\\[W^{*}_{Z}(x_{1},x_{2})=W_{f\\lrcorner Z}(y_{1},y_{2})=\\mathcal{P}\\exp\\left( \\int_{\\gamma}A_{\\mu}dx^{\\mu}\\right).\\]"}
{"img_name": "2.png", "formula": "\\[G=W^{*}_{Z}(q,p)=\\tilde{H}H^{-1}\\]"}
{"img_name": "3.png", "formula": "\\[H=W^{*}_{Z}(p,x),\\ \\ \\tilde{H}=W^{*}_{Z}(q,x).\\]"}
{"img_name": "4.png", "formula": "\\[v\\cdot f^{*}A|_{x}=(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)},\\quad x\\in Z, \\ v\\in T_{x}Z.\\]"}
{"img_name": "5.png", "formula": "\\[(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)}=v^{\\alpha\\dot{\\alpha}}\\Big{(} \\frac{\\partial y^{\\beta\\dot{\\beta}}}{\\partial x^{\\alpha\\dot{\\alpha}}}A_{\\beta \\dot{\\beta}}\\Big{)}\\Big{|}_{f\\lrcorner Z(x)},\\ x\\in Z,\\ v\\in T_{x}Z,\\]"}
{"img_name": "6.png", "formula": "\\[\\{T_{i},T_{j}\\}=\\{\\tilde{T}^{i},\\tilde{T}^{j}\\}=0,\\ \\ \\{T_{i},\\tilde{T}^{j}\\}=2i \\delta^{j}_{i}D,\\]"}
{"img_name": "7.png", "formula": "\\[(\\partial_{s},q_{i},\\tilde{q}^{k})\\rightarrow(D,M^{j}_{i}T_{j},\\tilde{M}^{k}_ {l}\\tilde{T}^{l}),\\]"}
{"img_name": "8.png", "formula": "\\[M^{i}_{j}\\tilde{M}^{j}_{k}=\\delta^{i}_{k}.\\]"}
{"img_name": "9.png", "formula": "\\[Q_{i\\alpha}=q_{i\\alpha}+\\omega_{i\\alpha},\\ \\tilde{Q}^{i}_{\\dot{\\alpha}}=q^{i}_{ \\dot{\\alpha}}+\\tilde{\\omega}^{i}_{\\dot{\\alpha}},\\ D_{\\alpha\\dot{\\alpha}}= \\partial_{\\alpha\\dot{\\alpha}}+A_{\\alpha\\dot{\\alpha}}.\\]"}
{"img_name": "10.png", "formula": "\\[\\hat{f}(g,\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{j})=(f(g),[V^{-1}]^ {\\alpha}_{\\beta}\\theta^{i\\beta},[\\tilde{V}^{-1}]^{\\dot{\\alpha}}_{\\dot{\\beta}} \\tilde{\\theta}^{\\dot{\\beta}}_{j}),\\ g\\in{\\cal G},\\]"}
{"img_name": "11.png", "formula": "\\[v^{\\beta\\dot{\\beta}}V^{\\alpha}_{\\beta}\\tilde{V}^{\\dot{\\alpha}}_{\\dot{\\beta}} =((f\\lrcorner L_{0})_{*}v)^{\\alpha\\dot{\\alpha}},\\]"}
{"img_name": "12.png", "formula": "\\[\\omega_{i\\alpha}=\\tilde{\\theta}^{\\dot{\\alpha}}_{i}h_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\ \\ \\tilde{\\omega}^{i}_{\\alpha}=\\theta^{i\\alpha}\\tilde{h}_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\]"}
{"img_name": "13.png", "formula": "\\[\\begin{split}&\\lambda^{\\alpha}\\hat{f}^{*}\\omega_{i\\alpha}(z)= \\tilde{\\theta}^{\\dot{\\beta}}_{i}\\lambda^{\\alpha}\\left(V^{\\beta}_{\\alpha}h_{ \\beta\\dot{\\beta}}(x^{\\prime},\\tau^{\\prime})\\right),\\\\ &\\tilde{\\lambda}^{\\dot{\\alpha}}\\hat{f}^{*}\\tilde{\\omega}^{i}_{ \\dot{\\alpha}}(z)=\\theta^{i\\beta}\\tilde{\\lambda}^{\\dot{\\alpha}}\\left(\\tilde{V}^ {\\dot{\\beta}}_{\\dot{\\alpha}}\\tilde{h}_{\\beta\\dot{\\beta}}(x^{\\prime},\\tau^{ \\prime})\\right),\\end{split}\\]"}
{"img_name": "14.png", "formula": "\\[A_{\\alpha\\dot{\\alpha}}=A_{\\alpha\\dot{\\alpha}}(x^{\\beta\\dot{\\beta}},\\tau^{ \\beta\\dot{\\beta}})\\]"}
{"img_name": "15.png", "formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}D_{\\alpha\\dot{\\alpha}}\\]"}
{"img_name": "16.png", "formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}\\]"}
{"img_name": "17.png", "formula": "\\[[v_{1}\\cdot D^{*},v_{2}\\cdot D^{*}]=0\\]"}
{"img_name": "18.png", "formula": "\\[\\Phi_{A}=(\\omega_{i\\alpha},\\tilde{\\omega}^{i}_{\\dot{\\alpha}},A_{\\alpha\\dot{ \\alpha}})\\]"}
{"img_name": "19.png", "formula": "\\[\\hat{f}:{\\cal F}^{6|4N}\\rightarrow{\\cal F}^{6|4N}\\]"}
{"img_name": "20.png", "formula": "\\[\\sigma=(s,\\xi^{i},\\tilde{\\xi}_{j})\\in\\mathbb{C}^{1|2N}\\]"}
{"img_name": "21.png", "formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}(h_{\\alpha\\dot{\\alpha}}+\\tilde{h}_{\\alpha\\dot{\\alpha} })=0\\]"}
{"img_name": "22.png", "formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}\\rightarrow[V^{-1}]^{\\alpha}_{\\beta}[\\tilde{V}^{-1}]^{ \\dot{\\alpha}}_{\\dot{\\beta}}\\tau^{\\beta\\dot{\\beta}}\\]"}
{"img_name": "23.png", "formula": "\\[\\tau^{\\beta\\dot{\\beta}}=\\sum_{i}\\theta^{i\\beta}\\tilde{\\theta}^{\\dot{\\beta}}_{i}\\]"}
{"img_name": "24.png", "formula": "\\[\\theta^{i\\alpha}\\omega_{i\\alpha}+\\tilde{\\theta}^{i}_{\\dot{\\alpha}}\\tilde{ \\omega}^{\\dot{\\alpha}}_{i}=0\\]"}
{"img_name": "25.png", "formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{Q}^{i}_{\\dot{\\alpha}}\\]"}
{"img_name": "26.png", "formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{q}^{i}_{\\dot{\\alpha}}\\]"}
{"img_name": "27.png", "formula": "\\[\\tilde{\\lambda}^{\\dot{\\alpha}}f^{*}A_{\\alpha\\dot{\\alpha}}=H^{-1}\\tilde{ \\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}H\\]"}
{"img_name": "28.png", "formula": "\\[\\tilde{q}^{i}=\\partial_{\\tilde{\\xi}_{i}}+i\\xi^{i}\\partial_{s}\\]"}
{"img_name": "29.png", "formula": "\\[\\tilde{q}^{i}_{\\dot{\\alpha}}=\\frac{\\partial}{\\partial\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}}+i\\theta^{i\\alpha}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
{"img_name": "30.png", "formula": "\\[f\\lrcorner L(z)=\\pi_{1}\\circ f(z,\\lambda,\\tilde{\\lambda})\\ \\forall z\\in L\\]"}
{"img_name": "31.png", "formula": "\\[q_{i\\alpha}=\\frac{\\partial}{\\partial\\theta^{i\\alpha}}+i\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
{"img_name": "32.png", "formula": "\\[q_{i}=\\partial_{\\xi^{i}}+i\\tilde{\\xi}_{i}\\partial_{s}\\]"}
{"img_name": "33.png", "formula": "\\[v^{\\alpha\\dot{\\alpha}}=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\]"}
{"img_name": "34.png", "formula": "\\[z^{A}=(x^{\\alpha\\dot{\\alpha}},\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{ j})\\]"}

View File

@@ -1,50 +0,0 @@
from PIL import Image
from pathlib import Path
import datasets
import json
DIR_URL = Path('absolute/path/to/dataset/directory')
# e.g. DIR_URL = Path('/home/OleehyO/TeXTeller/src/models/ocr_model/train/dataset')
class LatexFormulas(datasets.GeneratorBasedBuilder):
BUILDER_CONFIGS = []
def _info(self):
return datasets.DatasetInfo(
features=datasets.Features({
"image": datasets.Image(),
"latex_formula": datasets.Value("string")
})
)
def _split_generators(self, dl_manager: datasets.DownloadManager):
dir_path = Path(dl_manager.download(str(DIR_URL)))
assert dir_path.is_dir()
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={
'dir_path': dir_path,
}
)
]
def _generate_examples(self, dir_path: Path):
images_path = dir_path / 'images'
formulas_path = dir_path / 'formulas.jsonl'
img2formula = {}
with formulas_path.open('r', encoding='utf-8') as f:
for line in f:
single_json = json.loads(line)
img2formula[single_json['img_name']] = single_json['formula']
for img_path in images_path.iterdir():
if img_path.suffix not in ['.jpg', '.png']:
continue
yield str(img_path), {
"image": Image.open(img_path),
"latex_formula": img2formula[img_path.name]
}

View File

@@ -1,38 +0,0 @@
CONFIG = {
"seed": 42, # Random seed for reproducibility
"use_cpu": False, # Whether to use CPU (it's easier to debug with CPU when starting to test the code)
"learning_rate": 5e-5, # Learning rate
"num_train_epochs": 10, # Total number of training epochs
"per_device_train_batch_size": 4, # Batch size per GPU for training
"per_device_eval_batch_size": 8, # Batch size per GPU for evaluation
"output_dir": "train_result", # Output directory
"overwrite_output_dir": False, # If the output directory exists, do not delete its content
"report_to": ["tensorboard"], # Report logs to TensorBoard
"save_strategy": "steps", # Strategy to save checkpoints
"save_steps": 500, # Interval of steps to save checkpoints, can be int or a float (0~1), when float it represents the ratio of total training steps (e.g., can set to 1.0 / 2000)
"save_total_limit": 5, # Maximum number of models to save. The oldest models will be deleted if this number is exceeded
"logging_strategy": "steps", # Log every certain number of steps
"logging_steps": 500, # Number of steps between each log
"logging_nan_inf_filter": False, # Record logs for loss=nan or inf
"optim": "adamw_torch", # Optimizer
"lr_scheduler_type": "cosine", # Learning rate scheduler
"warmup_ratio": 0.1, # Ratio of warmup steps in total training steps (e.g., for 1000 steps, the first 100 steps gradually increase lr from 0 to the set lr)
"max_grad_norm": 1.0, # For gradient clipping, ensure the norm of the gradients does not exceed 1.0 (default 1.0)
"fp16": False, # Whether to use 16-bit floating point for training (generally not recommended, as loss can easily explode)
"bf16": False, # Whether to use Brain Floating Point (bfloat16) for training (recommended if architecture supports it)
"gradient_accumulation_steps": 1, # Gradient accumulation steps, consider this parameter to achieve large batch size effects when batch size cannot be large
"jit_mode_eval": False, # Whether to use PyTorch jit trace during eval (can speed up the model, but the model must be static, otherwise will throw errors)
"torch_compile": False, # Whether to use torch.compile to compile the model (for better training and inference performance)
"dataloader_pin_memory": True, # Can speed up data transfer between CPU and GPU
"dataloader_num_workers": 1, # Default is not to use multiprocessing for data loading, usually set to 4*number of GPUs used
"evaluation_strategy": "steps", # Evaluation strategy, can be "steps" or "epoch"
"eval_steps": 500, # If evaluation_strategy="step"
"remove_unused_columns": False, # Don't change this unless you really know what you are doing.
}

View File

@@ -1 +0,0 @@
from .mix_inference import mix_inference

View File

@@ -1,65 +0,0 @@
import os
import argparse
import cv2 as cv
from pathlib import Path
from models.ocr_model.utils.to_katex import to_katex
from models.ocr_model.utils.inference import inference as latex_inference
from models.ocr_model.model.TexTeller import TexTeller
if __name__ == '__main__':
os.chdir(Path(__file__).resolve().parent)
parser = argparse.ArgumentParser()
parser.add_argument(
'-img_dir',
type=str,
help='path to the input image',
default='./detect_results/subimages'
)
parser.add_argument(
'-output_dir',
type=str,
help='path to the output dir',
default='./rec_results'
)
parser.add_argument(
'--inference-mode',
type=str,
default='cpu',
help='Inference mode, select one of cpu, cuda, or mps'
)
parser.add_argument(
'--num-beam',
type=int,
default=1,
help='number of beam search for decoding'
)
args = parser.parse_args()
print('Loading model and tokenizer...')
latex_rec_model = TexTeller.from_pretrained()
tokenizer = TexTeller.get_tokenizer()
print('Model and tokenizer loaded.')
# Create the output directory if it doesn't exist
os.makedirs(args.output_dir, exist_ok=True)
# Loop through all images in the input directory
for filename in os.listdir(args.img_dir):
img_path = os.path.join(args.img_dir, filename)
img = cv.imread(img_path)
if img is not None:
print(f'Inference for {filename}...')
res = latex_inference(latex_rec_model, tokenizer, [img], accelerator=args.inference_mode, num_beams=args.num_beam)
res = to_katex(res[0])
# Save the recognition result to a text file
output_file = os.path.join(args.output_dir, os.path.splitext(filename)[0] + '.txt')
with open(output_file, 'w') as f:
f.write(res)
print(f'Result saved to {output_file}')
else:
print(f"Warning: Could not read image {img_path}. Skipping...")

View File

@@ -1,85 +1,96 @@
import os import os
import argparse import argparse
import glob import glob
import subprocess import subprocess
import onnxruntime import onnxruntime
from pathlib import Path from pathlib import Path
from models.det_model.inference import PredictConfig, predict_image from models.det_model.inference import PredictConfig, predict_image
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--infer_cfg", type=str, help="infer_cfg.yml", parser.add_argument(
default="./models/det_model/model/infer_cfg.yml") "--infer_cfg", type=str, help="infer_cfg.yml", default="./models/det_model/model/infer_cfg.yml"
parser.add_argument('--onnx_file', type=str, help="onnx model file path", )
default="./models/det_model/model/rtdetr_r50vd_6x_coco.onnx") parser.add_argument(
parser.add_argument("--image_dir", type=str, default='./testImgs') '--onnx_file',
parser.add_argument("--image_file", type=str) type=str,
parser.add_argument("--imgsave_dir", type=str, default="./detect_results") help="onnx model file path",
parser.add_argument('--use_gpu', action='store_true', help='Whether to use GPU for inference', default=True) default="./models/det_model/model/rtdetr_r50vd_6x_coco.onnx",
)
parser.add_argument("--image_dir", type=str, default='./testImgs')
def get_test_images(infer_dir, infer_img): parser.add_argument("--image_file", type=str)
""" parser.add_argument("--imgsave_dir", type=str, default="./detect_results")
Get image path list in TEST mode parser.add_argument(
""" '--use_gpu', action='store_true', help='Whether to use GPU for inference', default=True
assert infer_img is not None or infer_dir is not None, \ )
"--image_file or --image_dir should be set"
assert infer_img is None or os.path.isfile(infer_img), \
"{} is not a file".format(infer_img) def get_test_images(infer_dir, infer_img):
assert infer_dir is None or os.path.isdir(infer_dir), \ """
"{} is not a directory".format(infer_dir) Get image path list in TEST mode
"""
# infer_img has a higher priority assert (
if infer_img and os.path.isfile(infer_img): infer_img is not None or infer_dir is not None
return [infer_img] ), "--image_file or --image_dir should be set"
assert infer_img is None or os.path.isfile(infer_img), "{} is not a file".format(infer_img)
images = set() assert infer_dir is None or os.path.isdir(infer_dir), "{} is not a directory".format(infer_dir)
infer_dir = os.path.abspath(infer_dir)
assert os.path.isdir(infer_dir), \ # infer_img has a higher priority
"infer_dir {} is not a directory".format(infer_dir) if infer_img and os.path.isfile(infer_img):
exts = ['jpg', 'jpeg', 'png', 'bmp'] return [infer_img]
exts += [ext.upper() for ext in exts]
for ext in exts: images = set()
images.update(glob.glob('{}/*.{}'.format(infer_dir, ext))) infer_dir = os.path.abspath(infer_dir)
images = list(images) assert os.path.isdir(infer_dir), "infer_dir {} is not a directory".format(infer_dir)
exts = ['jpg', 'jpeg', 'png', 'bmp']
assert len(images) > 0, "no image found in {}".format(infer_dir) exts += [ext.upper() for ext in exts]
print("Found {} inference images in total.".format(len(images))) for ext in exts:
images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
return images images = list(images)
def download_file(url, filename): assert len(images) > 0, "no image found in {}".format(infer_dir)
print(f"Downloading {filename}...") print("Found {} inference images in total.".format(len(images)))
subprocess.run(["wget", "-q", "--show-progress", "-O", filename, url], check=True)
print("Download complete.") return images
if __name__ == '__main__':
cur_path = os.getcwd() def download_file(url, filename):
script_dirpath = Path(__file__).resolve().parent print(f"Downloading {filename}...")
os.chdir(script_dirpath) subprocess.run(["wget", "-q", "--show-progress", "-O", filename, url], check=True)
print("Download complete.")
FLAGS = parser.parse_args()
if not os.path.exists(FLAGS.infer_cfg): if __name__ == '__main__':
infer_cfg_url = "https://huggingface.co/TonyLee1256/texteller_det/resolve/main/infer_cfg.yml?download=true" cur_path = os.getcwd()
download_file(infer_cfg_url, FLAGS.infer_cfg) script_dirpath = Path(__file__).resolve().parent
os.chdir(script_dirpath)
if not os.path.exists(FLAGS.onnx_file):
onnx_file_url = "https://huggingface.co/TonyLee1256/texteller_det/resolve/main/rtdetr_r50vd_6x_coco.onnx?download=true" FLAGS = parser.parse_args()
download_file(onnx_file_url, FLAGS.onnx_file)
if not os.path.exists(FLAGS.infer_cfg):
# load image list infer_cfg_url = "https://huggingface.co/TonyLee1256/texteller_det/resolve/main/infer_cfg.yml?download=true"
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file) download_file(infer_cfg_url, FLAGS.infer_cfg)
if FLAGS.use_gpu: if not os.path.exists(FLAGS.onnx_file):
predictor = onnxruntime.InferenceSession(FLAGS.onnx_file, providers=['CUDAExecutionProvider']) onnx_file_url = "https://huggingface.co/TonyLee1256/texteller_det/resolve/main/rtdetr_r50vd_6x_coco.onnx?download=true"
else: download_file(onnx_file_url, FLAGS.onnx_file)
predictor = onnxruntime.InferenceSession(FLAGS.onnx_file, providers=['CPUExecutionProvider'])
# load infer config # load image list
infer_config = PredictConfig(FLAGS.infer_cfg) img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
predict_image(FLAGS.imgsave_dir, infer_config, predictor, img_list) if FLAGS.use_gpu:
predictor = onnxruntime.InferenceSession(
os.chdir(cur_path) FLAGS.onnx_file, providers=['CUDAExecutionProvider']
)
else:
predictor = onnxruntime.InferenceSession(
FLAGS.onnx_file, providers=['CPUExecutionProvider']
)
# load infer config
infer_config = PredictConfig(FLAGS.infer_cfg)
predict_image(FLAGS.imgsave_dir, infer_config, predictor, img_list)
os.chdir(cur_path)

View File

@@ -18,32 +18,20 @@ from models.det_model.inference import PredictConfig
if __name__ == '__main__': if __name__ == '__main__':
os.chdir(Path(__file__).resolve().parent) os.chdir(Path(__file__).resolve().parent)
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-img', type=str, required=True, help='path to the input image')
parser.add_argument( parser.add_argument(
'-img', '--inference-mode',
type=str,
required=True,
help='path to the input image'
)
parser.add_argument(
'--inference-mode',
type=str, type=str,
default='cpu', default='cpu',
help='Inference mode, select one of cpu, cuda, or mps' help='Inference mode, select one of cpu, cuda, or mps',
) )
parser.add_argument( parser.add_argument(
'--num-beam', '--num-beam', type=int, default=1, help='number of beam search for decoding'
type=int,
default=1,
help='number of beam search for decoding'
) )
parser.add_argument( parser.add_argument('-mix', action='store_true', help='use mix mode')
'-mix',
action='store_true',
help='use mix mode'
)
args = parser.parse_args() args = parser.parse_args()
# You can use your own checkpoint and tokenizer path. # You can use your own checkpoint and tokenizer path.
print('Loading model and tokenizer...') print('Loading model and tokenizer...')
latex_rec_model = TexTeller.from_pretrained() latex_rec_model = TexTeller.from_pretrained()
@@ -63,8 +51,8 @@ if __name__ == '__main__':
use_gpu = args.inference_mode == 'cuda' use_gpu = args.inference_mode == 'cuda'
SIZE_LIMIT = 20 * 1024 * 1024 SIZE_LIMIT = 20 * 1024 * 1024
det_model_dir = "./models/thrid_party/paddleocr/checkpoints/det/default_model.onnx" det_model_dir = "./models/thrid_party/paddleocr/checkpoints/det/default_model.onnx"
rec_model_dir = "./models/thrid_party/paddleocr/checkpoints/rec/default_model.onnx" rec_model_dir = "./models/thrid_party/paddleocr/checkpoints/rec/default_model.onnx"
# The CPU inference of the detection model will be faster than the GPU inference (in onnxruntime) # The CPU inference of the detection model will be faster than the GPU inference (in onnxruntime)
det_use_gpu = False det_use_gpu = False
rec_use_gpu = use_gpu and not (os.path.getsize(rec_model_dir) < SIZE_LIMIT) rec_use_gpu = use_gpu and not (os.path.getsize(rec_model_dir) < SIZE_LIMIT)
@@ -78,8 +66,16 @@ if __name__ == '__main__':
detector = predict_det.TextDetector(paddleocr_args) detector = predict_det.TextDetector(paddleocr_args)
paddleocr_args.use_gpu = rec_use_gpu paddleocr_args.use_gpu = rec_use_gpu
recognizer = predict_rec.TextRecognizer(paddleocr_args) recognizer = predict_rec.TextRecognizer(paddleocr_args)
lang_ocr_models = [detector, recognizer] lang_ocr_models = [detector, recognizer]
latex_rec_models = [latex_rec_model, tokenizer] latex_rec_models = [latex_rec_model, tokenizer]
res = mix_inference(img_path, infer_config, latex_det_model, lang_ocr_models, latex_rec_models, args.inference_mode, args.num_beam) res = mix_inference(
img_path,
infer_config,
latex_det_model,
lang_ocr_models,
latex_rec_models,
args.inference_mode,
args.num_beam,
)
print(res) print(res)

Binary file not shown.

View File

@@ -9,7 +9,7 @@ class Point:
def __init__(self, x: int, y: int): def __init__(self, x: int, y: int):
self.x = int(x) self.x = int(x)
self.y = int(y) self.y = int(y)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"Point(x={self.x}, y={self.y})" return f"Point(x={self.x}, y={self.y})"
@@ -28,30 +28,28 @@ class Bbox:
@property @property
def ul_point(self) -> Point: def ul_point(self) -> Point:
return self.p return self.p
@property @property
def ur_point(self) -> Point: def ur_point(self) -> Point:
return Point(self.p.x + self.w, self.p.y) return Point(self.p.x + self.w, self.p.y)
@property @property
def ll_point(self) -> Point: def ll_point(self) -> Point:
return Point(self.p.x, self.p.y + self.h) return Point(self.p.x, self.p.y + self.h)
@property @property
def lr_point(self) -> Point: def lr_point(self) -> Point:
return Point(self.p.x + self.w, self.p.y + self.h) return Point(self.p.x + self.w, self.p.y + self.h)
def same_row(self, other) -> bool: def same_row(self, other) -> bool:
if ( if (self.p.y >= other.p.y and self.ll_point.y <= other.ll_point.y) or (
(self.p.y >= other.p.y and self.ll_point.y <= other.ll_point.y) self.p.y <= other.p.y and self.ll_point.y >= other.ll_point.y
or (self.p.y <= other.p.y and self.ll_point.y >= other.ll_point.y)
): ):
return True return True
if self.ll_point.y <= other.p.y or self.p.y >= other.ll_point.y: if self.ll_point.y <= other.p.y or self.p.y >= other.ll_point.y:
return False return False
return 1.0 * abs(self.p.y - other.p.y) / max(self.h, other.h) < self.THREADHOLD return 1.0 * abs(self.p.y - other.p.y) / max(self.h, other.h) < self.THREADHOLD
def __lt__(self, other) -> bool: def __lt__(self, other) -> bool:
''' '''
from top to bottom, from left to right from top to bottom, from left to right
@@ -60,7 +58,7 @@ class Bbox:
return self.p.y < other.p.y return self.p.y < other.p.y
else: else:
return self.p.x < other.p.x return self.p.x < other.p.x
def __repr__(self) -> str: def __repr__(self) -> str:
return f"Bbox(upper_left_point={self.p}, h={self.h}, w={self.w}), label={self.label}, confident={self.confidence}, content={self.content})" return f"Bbox(upper_left_point={self.p}, h={self.h}, w={self.w}), label={self.label}, confident={self.confidence}, content={self.content})"
@@ -76,16 +74,16 @@ def draw_bboxes(img: Image.Image, bboxes: List[Bbox], name="annotated_image.png"
top = bbox.p.y top = bbox.p.y
right = bbox.p.x + bbox.w right = bbox.p.x + bbox.w
bottom = bbox.p.y + bbox.h bottom = bbox.p.y + bbox.h
# Draw the rectangle on the image # Draw the rectangle on the image
drawer.rectangle([left, top, right, bottom], outline="green", width=1) drawer.rectangle([left, top, right, bottom], outline="green", width=1)
# Optionally, add text label if it exists # Optionally, add text label if it exists
if bbox.label: if bbox.label:
drawer.text((left, top), bbox.label, fill="blue") drawer.text((left, top), bbox.label, fill="blue")
if bbox.content: if bbox.content:
drawer.text((left, bottom - 10), bbox.content[:10], fill="red") drawer.text((left, bottom - 10), bbox.content[:10], fill="red")
# Save the image with drawn rectangles # Save the image with drawn rectangles
img.save(log_dir / name) img.save(log_dir / name)

View File

@@ -12,10 +12,28 @@ from .Bbox import Bbox
# Global dictionary # Global dictionary
SUPPORT_MODELS = { SUPPORT_MODELS = {
'YOLO', 'PPYOLOE', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet', 'YOLO',
'S2ANet', 'JDE', 'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet', 'PPYOLOE',
'TOOD', 'RetinaNet', 'StrongBaseline', 'STGCN', 'YOLOX', 'HRNet', 'RCNN',
'DETR' 'SSD',
'Face',
'FCOS',
'SOLOv2',
'TTFNet',
'S2ANet',
'JDE',
'FairMOT',
'DeepSORT',
'GFL',
'PicoDet',
'CenterNet',
'TOOD',
'RetinaNet',
'StrongBaseline',
'STGCN',
'YOLOX',
'HRNet',
'DETR',
} }
@@ -42,12 +60,12 @@ class PredictConfig(object):
self.fpn_stride = yml_conf.get("fpn_stride", None) self.fpn_stride = yml_conf.get("fpn_stride", None)
color_pool = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)] color_pool = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]
self.colors = {label: color_pool[i % len(color_pool)] for i, label in enumerate(self.label_list)} self.colors = {
label: color_pool[i % len(color_pool)] for i, label in enumerate(self.label_list)
}
if self.arch == 'RCNN' and yml_conf.get('export_onnx', False): if self.arch == 'RCNN' and yml_conf.get('export_onnx', False):
print( print('The RCNN export model is used for ONNX and it only supports batch_size = 1')
'The RCNN export model is used for ONNX and it only supports batch_size = 1'
)
self.print_config() self.print_config()
def check_model(self, yml_conf): def check_model(self, yml_conf):
@@ -58,8 +76,7 @@ class PredictConfig(object):
for support_model in SUPPORT_MODELS: for support_model in SUPPORT_MODELS:
if support_model in yml_conf['arch']: if support_model in yml_conf['arch']:
return True return True
raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[ raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf['arch'], SUPPORT_MODELS))
'arch'], SUPPORT_MODELS))
def print_config(self): def print_config(self):
print('----------- Model Configuration -----------') print('----------- Model Configuration -----------')
@@ -77,8 +94,15 @@ def draw_bbox(image, outputs, infer_config):
label = infer_config.label_list[int(cls_id)] label = infer_config.label_list[int(cls_id)]
color = infer_config.colors[label] color = infer_config.colors[label]
cv2.rectangle(image, (int(xmin), int(ymin)), (int(xmax), int(ymax)), color, 2) cv2.rectangle(image, (int(xmin), int(ymin)), (int(xmax), int(ymax)), color, 2)
cv2.putText(image, "{}: {:.2f}".format(label, score), cv2.putText(
(int(xmin), int(ymin - 5)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2) image,
"{}: {:.2f}".format(label, score),
(int(xmin), int(ymin - 5)),
cv2.FONT_HERSHEY_SIMPLEX,
0.5,
color,
2,
)
return image return image
@@ -104,7 +128,7 @@ def predict_image(imgsave_dir, infer_config, predictor, img_list):
inputs = transforms(img_path) inputs = transforms(img_path)
inputs_name = [var.name for var in predictor.get_inputs()] inputs_name = [var.name for var in predictor.get_inputs()]
inputs = {k: inputs[k][None, ] for k in inputs_name} inputs = {k: inputs[k][None,] for k in inputs_name}
# Start timing # Start timing
start_time = time.time() start_time = time.time()
@@ -119,7 +143,9 @@ def predict_image(imgsave_dir, infer_config, predictor, img_list):
else: else:
total_time += inference_time total_time += inference_time
num_images += 1 num_images += 1
print(f"ONNXRuntime predict time for {os.path.basename(img_path)}: {inference_time:.4f} seconds") print(
f"ONNXRuntime predict time for {os.path.basename(img_path)}: {inference_time:.4f} seconds"
)
print("ONNXRuntime predict: ") print("ONNXRuntime predict: ")
if infer_config.arch in ["HRNet"]: if infer_config.arch in ["HRNet"]:
@@ -128,8 +154,7 @@ def predict_image(imgsave_dir, infer_config, predictor, img_list):
bboxes = np.array(outputs[0]) bboxes = np.array(outputs[0])
for bbox in bboxes: for bbox in bboxes:
if bbox[0] > -1 and bbox[1] > infer_config.draw_threshold: if bbox[0] > -1 and bbox[1] > infer_config.draw_threshold:
print(f"{int(bbox[0])} {bbox[1]} " print(f"{int(bbox[0])} {bbox[1]} " f"{bbox[2]} {bbox[3]} {bbox[4]} {bbox[5]}")
f"{bbox[2]} {bbox[3]} {bbox[4]} {bbox[5]}")
# Save the subimages (crop from the original image) # Save the subimages (crop from the original image)
subimg_counter = 1 subimg_counter = 1
@@ -137,7 +162,7 @@ def predict_image(imgsave_dir, infer_config, predictor, img_list):
cls_id, score, xmin, ymin, xmax, ymax = output cls_id, score, xmin, ymin, xmax, ymax = output
if score > infer_config.draw_threshold: if score > infer_config.draw_threshold:
label = infer_config.label_list[int(cls_id)] label = infer_config.label_list[int(cls_id)]
subimg = img[int(max(ymin, 0)):int(ymax), int(max(xmin, 0)):int(xmax)] subimg = img[int(max(ymin, 0)) : int(ymax), int(max(xmin, 0)) : int(xmax)]
if len(subimg) == 0: if len(subimg) == 0:
continue continue
@@ -151,8 +176,14 @@ def predict_image(imgsave_dir, infer_config, predictor, img_list):
for output in np.array(outputs[0]): for output in np.array(outputs[0]):
cls_id, score, xmin, ymin, xmax, ymax = output cls_id, score, xmin, ymin, xmax, ymax = output
if score > infer_config.draw_threshold: if score > infer_config.draw_threshold:
cv2.rectangle(img_with_mask, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (255, 255, 255), -1) # 盖白 cv2.rectangle(
img_with_mask,
(int(xmin), int(ymin)),
(int(xmax), int(ymax)),
(255, 255, 255),
-1,
) # 盖白
img_with_bbox = draw_bbox(img, np.array(outputs[0]), infer_config) img_with_bbox = draw_bbox(img, np.array(outputs[0]), infer_config)
output_dir = imgsave_dir output_dir = imgsave_dir
@@ -178,7 +209,7 @@ def predict(img_path: str, predictor, infer_config) -> List[Bbox]:
transforms = Compose(infer_config.preprocess_infos) transforms = Compose(infer_config.preprocess_infos)
inputs = transforms(img_path) inputs = transforms(img_path)
inputs_name = [var.name for var in predictor.get_inputs()] inputs_name = [var.name for var in predictor.get_inputs()]
inputs = {k: inputs[k][None, ] for k in inputs_name} inputs = {k: inputs[k][None,] for k in inputs_name}
outputs = predictor.run(output_names=None, input_feed=inputs)[0] outputs = predictor.run(output_names=None, input_feed=inputs)[0]
res = [] res = []

View File

@@ -15,10 +15,8 @@ def decode_image(img_path):
im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
img_info = { img_info = {
"im_shape": np.array( "im_shape": np.array(im.shape[:2], dtype=np.float32),
im.shape[:2], dtype=np.float32), "scale_factor": np.array([1.0, 1.0], dtype=np.float32),
"scale_factor": np.array(
[1., 1.], dtype=np.float32)
} }
return im, img_info return im, img_info
@@ -51,16 +49,9 @@ class Resize(object):
assert self.target_size[0] > 0 and self.target_size[1] > 0 assert self.target_size[0] > 0 and self.target_size[1] > 0
im_channel = im.shape[2] im_channel = im.shape[2]
im_scale_y, im_scale_x = self.generate_scale(im) im_scale_y, im_scale_x = self.generate_scale(im)
im = cv2.resize( im = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=self.interp)
im,
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=self.interp)
im_info['im_shape'] = np.array(im.shape[:2]).astype('float32') im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
im_info['scale_factor'] = np.array( im_info['scale_factor'] = np.array([im_scale_y, im_scale_x]).astype('float32')
[im_scale_y, im_scale_x]).astype('float32')
return im, im_info return im, im_info
def generate_scale(self, im): def generate_scale(self, im):
@@ -134,7 +125,9 @@ class Permute(object):
channel_first (bool): whether convert HWC to CHW channel_first (bool): whether convert HWC to CHW
""" """
def __init__(self, ): def __init__(
self,
):
super(Permute, self).__init__() super(Permute, self).__init__()
def __call__(self, im, im_info): def __call__(self, im, im_info):
@@ -151,7 +144,7 @@ class Permute(object):
class PadStride(object): class PadStride(object):
""" padding image for model with FPN, instead PadBatch(pad_to_stride) in original config """padding image for model with FPN, instead PadBatch(pad_to_stride) in original config
Args: Args:
stride (bool): model with FPN need image shape % stride == 0 stride (bool): model with FPN need image shape % stride == 0
""" """
@@ -198,18 +191,16 @@ class LetterBoxResize(object):
ratio_h = float(height) / shape[0] ratio_h = float(height) / shape[0]
ratio_w = float(width) / shape[1] ratio_w = float(width) / shape[1]
ratio = min(ratio_h, ratio_w) ratio = min(ratio_h, ratio_w)
new_shape = (round(shape[1] * ratio), new_shape = (round(shape[1] * ratio), round(shape[0] * ratio)) # [width, height]
round(shape[0] * ratio)) # [width, height]
padw = (width - new_shape[0]) / 2 padw = (width - new_shape[0]) / 2
padh = (height - new_shape[1]) / 2 padh = (height - new_shape[1]) / 2
top, bottom = round(padh - 0.1), round(padh + 0.1) top, bottom = round(padh - 0.1), round(padh + 0.1)
left, right = round(padw - 0.1), round(padw + 0.1) left, right = round(padw - 0.1), round(padw + 0.1)
img = cv2.resize( img = cv2.resize(img, new_shape, interpolation=cv2.INTER_AREA) # resized, no border
img, new_shape, interpolation=cv2.INTER_AREA) # resized, no border
img = cv2.copyMakeBorder( img = cv2.copyMakeBorder(
img, top, bottom, left, right, cv2.BORDER_CONSTANT, img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
value=color) # padded rectangular ) # padded rectangular
return img, ratio, padw, padh return img, ratio, padw, padh
def __call__(self, im, im_info): def __call__(self, im, im_info):
@@ -302,12 +293,7 @@ def _get_3rd_point(a, b):
return third_pt return third_pt
def get_affine_transform(center, def get_affine_transform(center, input_size, rot, output_size, shift=(0.0, 0.0), inv=False):
input_size,
rot,
output_size,
shift=(0., 0.),
inv=False):
"""Get the affine transform matrix, given the center/scale/rot/output_size. """Get the affine transform matrix, given the center/scale/rot/output_size.
Args: Args:
@@ -337,8 +323,8 @@ def get_affine_transform(center,
dst_h = output_size[1] dst_h = output_size[1]
rot_rad = np.pi * rot / 180 rot_rad = np.pi * rot / 180
src_dir = rotate_point([0., src_w * -0.5], rot_rad) src_dir = rotate_point([0.0, src_w * -0.5], rot_rad)
dst_dir = np.array([0., dst_w * -0.5]) dst_dir = np.array([0.0, dst_w * -0.5])
src = np.zeros((3, 2), dtype=np.float32) src = np.zeros((3, 2), dtype=np.float32)
src[0, :] = center + scale_tmp * shift src[0, :] = center + scale_tmp * shift
@@ -359,16 +345,9 @@ def get_affine_transform(center,
class WarpAffine(object): class WarpAffine(object):
"""Warp affine the image """Warp affine the image"""
"""
def __init__(self, def __init__(self, keep_res=False, pad=31, input_h=512, input_w=512, scale=0.4, shift=0.1):
keep_res=False,
pad=31,
input_h=512,
input_w=512,
scale=0.4,
shift=0.1):
self.keep_res = keep_res self.keep_res = keep_res
self.pad = pad self.pad = pad
self.input_h = input_h self.input_h = input_h
@@ -398,12 +377,11 @@ class WarpAffine(object):
else: else:
s = max(h, w) * 1.0 s = max(h, w) * 1.0
input_h, input_w = self.input_h, self.input_w input_h, input_w = self.input_h, self.input_w
c = np.array([w / 2., h / 2.], dtype=np.float32) c = np.array([w / 2.0, h / 2.0], dtype=np.float32)
trans_input = get_affine_transform(c, s, 0, [input_w, input_h]) trans_input = get_affine_transform(c, s, 0, [input_w, input_h])
img = cv2.resize(img, (w, h)) img = cv2.resize(img, (w, h))
inp = cv2.warpAffine( inp = cv2.warpAffine(img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR)
img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR)
return inp, im_info return inp, im_info
@@ -432,13 +410,17 @@ def get_warp_matrix(theta, size_input, size_dst, size_target):
matrix[0, 0] = np.cos(theta) * scale_x matrix[0, 0] = np.cos(theta) * scale_x
matrix[0, 1] = -np.sin(theta) * scale_x matrix[0, 1] = -np.sin(theta) * scale_x
matrix[0, 2] = scale_x * ( matrix[0, 2] = scale_x * (
-0.5 * size_input[0] * np.cos(theta) + 0.5 * size_input[1] * -0.5 * size_input[0] * np.cos(theta)
np.sin(theta) + 0.5 * size_target[0]) + 0.5 * size_input[1] * np.sin(theta)
+ 0.5 * size_target[0]
)
matrix[1, 0] = np.sin(theta) * scale_y matrix[1, 0] = np.sin(theta) * scale_y
matrix[1, 1] = np.cos(theta) * scale_y matrix[1, 1] = np.cos(theta) * scale_y
matrix[1, 2] = scale_y * ( matrix[1, 2] = scale_y * (
-0.5 * size_input[0] * np.sin(theta) - 0.5 * size_input[1] * -0.5 * size_input[0] * np.sin(theta)
np.cos(theta) + 0.5 * size_target[1]) - 0.5 * size_input[1] * np.cos(theta)
+ 0.5 * size_target[1]
)
return matrix return matrix
@@ -462,22 +444,26 @@ class TopDownEvalAffine(object):
def __call__(self, image, im_info): def __call__(self, image, im_info):
rot = 0 rot = 0
imshape = im_info['im_shape'][::-1] imshape = im_info['im_shape'][::-1]
center = im_info['center'] if 'center' in im_info else imshape / 2. center = im_info['center'] if 'center' in im_info else imshape / 2.0
scale = im_info['scale'] if 'scale' in im_info else imshape scale = im_info['scale'] if 'scale' in im_info else imshape
if self.use_udp: if self.use_udp:
trans = get_warp_matrix( trans = get_warp_matrix(
rot, center * 2.0, rot, center * 2.0, [self.trainsize[0] - 1.0, self.trainsize[1] - 1.0], scale
[self.trainsize[0] - 1.0, self.trainsize[1] - 1.0], scale) )
image = cv2.warpAffine( image = cv2.warpAffine(
image, image,
trans, (int(self.trainsize[0]), int(self.trainsize[1])), trans,
flags=cv2.INTER_LINEAR) (int(self.trainsize[0]), int(self.trainsize[1])),
flags=cv2.INTER_LINEAR,
)
else: else:
trans = get_affine_transform(center, scale, rot, self.trainsize) trans = get_affine_transform(center, scale, rot, self.trainsize)
image = cv2.warpAffine( image = cv2.warpAffine(
image, image,
trans, (int(self.trainsize[0]), int(self.trainsize[1])), trans,
flags=cv2.INTER_LINEAR) (int(self.trainsize[0]), int(self.trainsize[1])),
flags=cv2.INTER_LINEAR,
)
return image, im_info return image, im_info

View File

@@ -1,6 +1,6 @@
# Formula image(grayscale) mean and variance # Formula image(grayscale) mean and variance
IMAGE_MEAN = 0.9545467 IMAGE_MEAN = 0.9545467
IMAGE_STD = 0.15394445 IMAGE_STD = 0.15394445
# Vocabulary size for TexTeller # Vocabulary size for TexTeller
VOCAB_SIZE = 15000 VOCAB_SIZE = 15000
@@ -20,4 +20,4 @@ MIN_RESIZE_RATIO = 0.75
# Minimum height and width for input image for TexTeller # Minimum height and width for input image for TexTeller
MIN_HEIGHT = 12 MIN_HEIGHT = 12
MIN_WIDTH = 30 MIN_WIDTH = 30

View File

@@ -1,30 +1,24 @@
from pathlib import Path from pathlib import Path
from ...globals import ( from ...globals import VOCAB_SIZE, FIXED_IMG_SIZE, IMG_CHANNELS, MAX_TOKEN_SIZE
VOCAB_SIZE,
FIXED_IMG_SIZE,
IMG_CHANNELS,
MAX_TOKEN_SIZE
)
from transformers import ( from transformers import RobertaTokenizerFast, VisionEncoderDecoderModel, VisionEncoderDecoderConfig
RobertaTokenizerFast,
VisionEncoderDecoderModel,
VisionEncoderDecoderConfig
)
class TexTeller(VisionEncoderDecoderModel): class TexTeller(VisionEncoderDecoderModel):
REPO_NAME = 'OleehyO/TexTeller' REPO_NAME = 'OleehyO/TexTeller'
def __init__(self): def __init__(self):
config = VisionEncoderDecoderConfig.from_pretrained(Path(__file__).resolve().parent / "config.json") config = VisionEncoderDecoderConfig.from_pretrained(
config.encoder.image_size = FIXED_IMG_SIZE Path(__file__).resolve().parent / "config.json"
config.encoder.num_channels = IMG_CHANNELS )
config.decoder.vocab_size = VOCAB_SIZE config.encoder.image_size = FIXED_IMG_SIZE
config.encoder.num_channels = IMG_CHANNELS
config.decoder.vocab_size = VOCAB_SIZE
config.decoder.max_position_embeddings = MAX_TOKEN_SIZE config.decoder.max_position_embeddings = MAX_TOKEN_SIZE
super().__init__(config=config) super().__init__(config=config)
@classmethod @classmethod
def from_pretrained(cls, model_path: str = None, use_onnx=False, onnx_provider=None): def from_pretrained(cls, model_path: str = None, use_onnx=False, onnx_provider=None):
if model_path is None or model_path == 'default': if model_path is None or model_path == 'default':
@@ -32,8 +26,12 @@ class TexTeller(VisionEncoderDecoderModel):
return VisionEncoderDecoderModel.from_pretrained(cls.REPO_NAME) return VisionEncoderDecoderModel.from_pretrained(cls.REPO_NAME)
else: else:
from optimum.onnxruntime import ORTModelForVision2Seq from optimum.onnxruntime import ORTModelForVision2Seq
use_gpu = True if onnx_provider == 'cuda' else False use_gpu = True if onnx_provider == 'cuda' else False
return ORTModelForVision2Seq.from_pretrained(cls.REPO_NAME, provider="CUDAExecutionProvider" if use_gpu else "CPUExecutionProvider") return ORTModelForVision2Seq.from_pretrained(
cls.REPO_NAME,
provider="CUDAExecutionProvider" if use_gpu else "CPUExecutionProvider",
)
model_path = Path(model_path).resolve() model_path = Path(model_path).resolve()
return VisionEncoderDecoderModel.from_pretrained(str(model_path)) return VisionEncoderDecoderModel.from_pretrained(str(model_path))

Binary file not shown.

After

Width:  |  Height:  |  Size: 14 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 12 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 43 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 28 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 13 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 33 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 10 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 30 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 9.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 15 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 17 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 28 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 18 KiB

View File

Before

Width:  |  Height:  |  Size: 3.1 KiB

After

Width:  |  Height:  |  Size: 3.1 KiB

View File

Before

Width:  |  Height:  |  Size: 8.7 KiB

After

Width:  |  Height:  |  Size: 8.7 KiB

View File

Before

Width:  |  Height:  |  Size: 6.8 KiB

After

Width:  |  Height:  |  Size: 6.8 KiB

View File

Before

Width:  |  Height:  |  Size: 4.1 KiB

After

Width:  |  Height:  |  Size: 4.1 KiB

View File

Before

Width:  |  Height:  |  Size: 5.2 KiB

After

Width:  |  Height:  |  Size: 5.2 KiB

View File

Before

Width:  |  Height:  |  Size: 12 KiB

After

Width:  |  Height:  |  Size: 12 KiB

View File

Before

Width:  |  Height:  |  Size: 2.8 KiB

After

Width:  |  Height:  |  Size: 2.8 KiB

View File

Before

Width:  |  Height:  |  Size: 2.2 KiB

After

Width:  |  Height:  |  Size: 2.2 KiB

View File

Before

Width:  |  Height:  |  Size: 2.2 KiB

After

Width:  |  Height:  |  Size: 2.2 KiB

View File

Before

Width:  |  Height:  |  Size: 2.6 KiB

After

Width:  |  Height:  |  Size: 2.6 KiB

View File

Before

Width:  |  Height:  |  Size: 3.1 KiB

After

Width:  |  Height:  |  Size: 3.1 KiB

View File

Before

Width:  |  Height:  |  Size: 2.7 KiB

After

Width:  |  Height:  |  Size: 2.7 KiB

View File

Before

Width:  |  Height:  |  Size: 3.9 KiB

After

Width:  |  Height:  |  Size: 3.9 KiB

View File

Before

Width:  |  Height:  |  Size: 3.9 KiB

After

Width:  |  Height:  |  Size: 3.9 KiB

View File

Before

Width:  |  Height:  |  Size: 2.9 KiB

After

Width:  |  Height:  |  Size: 2.9 KiB

View File

Before

Width:  |  Height:  |  Size: 3.7 KiB

After

Width:  |  Height:  |  Size: 3.7 KiB

View File

Before

Width:  |  Height:  |  Size: 3.5 KiB

After

Width:  |  Height:  |  Size: 3.5 KiB

View File

Before

Width:  |  Height:  |  Size: 3.1 KiB

After

Width:  |  Height:  |  Size: 3.1 KiB

View File

Before

Width:  |  Height:  |  Size: 2.5 KiB

After

Width:  |  Height:  |  Size: 2.5 KiB

View File

Before

Width:  |  Height:  |  Size: 2.2 KiB

After

Width:  |  Height:  |  Size: 2.2 KiB

View File

Before

Width:  |  Height:  |  Size: 3.1 KiB

After

Width:  |  Height:  |  Size: 3.1 KiB

View File

Before

Width:  |  Height:  |  Size: 2.9 KiB

After

Width:  |  Height:  |  Size: 2.9 KiB

View File

Before

Width:  |  Height:  |  Size: 5.3 KiB

After

Width:  |  Height:  |  Size: 5.3 KiB

View File

Before

Width:  |  Height:  |  Size: 4.1 KiB

After

Width:  |  Height:  |  Size: 4.1 KiB

View File

Before

Width:  |  Height:  |  Size: 3.9 KiB

After

Width:  |  Height:  |  Size: 3.9 KiB

View File

Before

Width:  |  Height:  |  Size: 4.9 KiB

After

Width:  |  Height:  |  Size: 4.9 KiB

View File

Before

Width:  |  Height:  |  Size: 2.9 KiB

After

Width:  |  Height:  |  Size: 2.9 KiB

View File

Before

Width:  |  Height:  |  Size: 1.8 KiB

After

Width:  |  Height:  |  Size: 1.8 KiB

View File

Before

Width:  |  Height:  |  Size: 3.2 KiB

After

Width:  |  Height:  |  Size: 3.2 KiB

View File

Before

Width:  |  Height:  |  Size: 5.7 KiB

After

Width:  |  Height:  |  Size: 5.7 KiB

View File

Before

Width:  |  Height:  |  Size: 11 KiB

After

Width:  |  Height:  |  Size: 11 KiB

View File

Before

Width:  |  Height:  |  Size: 4.8 KiB

After

Width:  |  Height:  |  Size: 4.8 KiB

View File

Before

Width:  |  Height:  |  Size: 4.5 KiB

After

Width:  |  Height:  |  Size: 4.5 KiB

View File

Before

Width:  |  Height:  |  Size: 2.5 KiB

After

Width:  |  Height:  |  Size: 2.5 KiB

View File

Before

Width:  |  Height:  |  Size: 5.2 KiB

After

Width:  |  Height:  |  Size: 5.2 KiB

View File

@@ -0,0 +1,35 @@
{"file_name": "0.png", "latex_formula": "\\[\\mathbb{C}^{4}\\stackrel{{\\pi_{1}}}{{\\longleftarrow}}\\mathcal{ F}\\stackrel{{\\pi_{2}}}{{\\rightarrow}}\\mathcal{PT},\\]"}
{"file_name": "1.png", "latex_formula": "\\[W^{*}_{Z}(x_{1},x_{2})=W_{f\\lrcorner Z}(y_{1},y_{2})=\\mathcal{P}\\exp\\left( \\int_{\\gamma}A_{\\mu}dx^{\\mu}\\right).\\]"}
{"file_name": "2.png", "latex_formula": "\\[G=W^{*}_{Z}(q,p)=\\tilde{H}H^{-1}\\]"}
{"file_name": "3.png", "latex_formula": "\\[H=W^{*}_{Z}(p,x),\\ \\ \\tilde{H}=W^{*}_{Z}(q,x).\\]"}
{"file_name": "4.png", "latex_formula": "\\[v\\cdot f^{*}A|_{x}=(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)},\\quad x\\in Z, \\ v\\in T_{x}Z.\\]"}
{"file_name": "5.png", "latex_formula": "\\[(f\\lrcorner Z)_{*}v\\cdot A|_{f\\lrcorner Z(x)}=v^{\\alpha\\dot{\\alpha}}\\Big{(} \\frac{\\partial y^{\\beta\\dot{\\beta}}}{\\partial x^{\\alpha\\dot{\\alpha}}}A_{\\beta \\dot{\\beta}}\\Big{)}\\Big{|}_{f\\lrcorner Z(x)},\\ x\\in Z,\\ v\\in T_{x}Z,\\]"}
{"file_name": "6.png", "latex_formula": "\\[\\{T_{i},T_{j}\\}=\\{\\tilde{T}^{i},\\tilde{T}^{j}\\}=0,\\ \\ \\{T_{i},\\tilde{T}^{j}\\}=2i \\delta^{j}_{i}D,\\]"}
{"file_name": "7.png", "latex_formula": "\\[(\\partial_{s},q_{i},\\tilde{q}^{k})\\rightarrow(D,M^{j}_{i}T_{j},\\tilde{M}^{k}_ {l}\\tilde{T}^{l}),\\]"}
{"file_name": "8.png", "latex_formula": "\\[M^{i}_{j}\\tilde{M}^{j}_{k}=\\delta^{i}_{k}.\\]"}
{"file_name": "9.png", "latex_formula": "\\[Q_{i\\alpha}=q_{i\\alpha}+\\omega_{i\\alpha},\\ \\tilde{Q}^{i}_{\\dot{\\alpha}}=q^{i}_{ \\dot{\\alpha}}+\\tilde{\\omega}^{i}_{\\dot{\\alpha}},\\ D_{\\alpha\\dot{\\alpha}}= \\partial_{\\alpha\\dot{\\alpha}}+A_{\\alpha\\dot{\\alpha}}.\\]"}
{"file_name": "10.png", "latex_formula": "\\[\\hat{f}(g,\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{j})=(f(g),[V^{-1}]^ {\\alpha}_{\\beta}\\theta^{i\\beta},[\\tilde{V}^{-1}]^{\\dot{\\alpha}}_{\\dot{\\beta}} \\tilde{\\theta}^{\\dot{\\beta}}_{j}),\\ g\\in{\\cal G},\\]"}
{"file_name": "11.png", "latex_formula": "\\[v^{\\beta\\dot{\\beta}}V^{\\alpha}_{\\beta}\\tilde{V}^{\\dot{\\alpha}}_{\\dot{\\beta}} =((f\\lrcorner L_{0})_{*}v)^{\\alpha\\dot{\\alpha}},\\]"}
{"file_name": "12.png", "latex_formula": "\\[\\omega_{i\\alpha}=\\tilde{\\theta}^{\\dot{\\alpha}}_{i}h_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\ \\ \\tilde{\\omega}^{i}_{\\alpha}=\\theta^{i\\alpha}\\tilde{h}_{\\alpha\\dot{\\alpha}}(x^{ \\beta\\dot{\\beta}},\\tau^{\\beta\\dot{\\beta}}),\\]"}
{"file_name": "13.png", "latex_formula": "\\[\\begin{split}&\\lambda^{\\alpha}\\hat{f}^{*}\\omega_{i\\alpha}(z)= \\tilde{\\theta}^{\\dot{\\beta}}_{i}\\lambda^{\\alpha}\\left(V^{\\beta}_{\\alpha}h_{ \\beta\\dot{\\beta}}(x^{\\prime},\\tau^{\\prime})\\right),\\\\ &\\tilde{\\lambda}^{\\dot{\\alpha}}\\hat{f}^{*}\\tilde{\\omega}^{i}_{ \\dot{\\alpha}}(z)=\\theta^{i\\beta}\\tilde{\\lambda}^{\\dot{\\alpha}}\\left(\\tilde{V}^ {\\dot{\\beta}}_{\\dot{\\alpha}}\\tilde{h}_{\\beta\\dot{\\beta}}(x^{\\prime},\\tau^{ \\prime})\\right),\\end{split}\\]"}
{"file_name": "14.png", "latex_formula": "\\[A_{\\alpha\\dot{\\alpha}}=A_{\\alpha\\dot{\\alpha}}(x^{\\beta\\dot{\\beta}},\\tau^{ \\beta\\dot{\\beta}})\\]"}
{"file_name": "15.png", "latex_formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}D_{\\alpha\\dot{\\alpha}}\\]"}
{"file_name": "16.png", "latex_formula": "\\[D=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}\\]"}
{"file_name": "17.png", "latex_formula": "\\[[v_{1}\\cdot D^{*},v_{2}\\cdot D^{*}]=0\\]"}
{"file_name": "18.png", "latex_formula": "\\[\\Phi_{A}=(\\omega_{i\\alpha},\\tilde{\\omega}^{i}_{\\dot{\\alpha}},A_{\\alpha\\dot{ \\alpha}})\\]"}
{"file_name": "19.png", "latex_formula": "\\[\\hat{f}:{\\cal F}^{6|4N}\\rightarrow{\\cal F}^{6|4N}\\]"}
{"file_name": "20.png", "latex_formula": "\\[\\sigma=(s,\\xi^{i},\\tilde{\\xi}_{j})\\in\\mathbb{C}^{1|2N}\\]"}
{"file_name": "21.png", "latex_formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}(h_{\\alpha\\dot{\\alpha}}+\\tilde{h}_{\\alpha\\dot{\\alpha} })=0\\]"}
{"file_name": "22.png", "latex_formula": "\\[\\tau^{\\alpha\\dot{\\alpha}}\\rightarrow[V^{-1}]^{\\alpha}_{\\beta}[\\tilde{V}^{-1}]^{ \\dot{\\alpha}}_{\\dot{\\beta}}\\tau^{\\beta\\dot{\\beta}}\\]"}
{"file_name": "23.png", "latex_formula": "\\[\\tau^{\\beta\\dot{\\beta}}=\\sum_{i}\\theta^{i\\beta}\\tilde{\\theta}^{\\dot{\\beta}}_{i}\\]"}
{"file_name": "24.png", "latex_formula": "\\[\\theta^{i\\alpha}\\omega_{i\\alpha}+\\tilde{\\theta}^{i}_{\\dot{\\alpha}}\\tilde{ \\omega}^{\\dot{\\alpha}}_{i}=0\\]"}
{"file_name": "25.png", "latex_formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{Q}^{i}_{\\dot{\\alpha}}\\]"}
{"file_name": "26.png", "latex_formula": "\\[\\tilde{T}^{i}=\\tilde{\\lambda}^{\\dot{\\alpha}}\\tilde{q}^{i}_{\\dot{\\alpha}}\\]"}
{"file_name": "27.png", "latex_formula": "\\[\\tilde{\\lambda}^{\\dot{\\alpha}}f^{*}A_{\\alpha\\dot{\\alpha}}=H^{-1}\\tilde{ \\lambda}^{\\dot{\\alpha}}\\partial_{\\alpha\\dot{\\alpha}}H\\]"}
{"file_name": "28.png", "latex_formula": "\\[\\tilde{q}^{i}=\\partial_{\\tilde{\\xi}_{i}}+i\\xi^{i}\\partial_{s}\\]"}
{"file_name": "29.png", "latex_formula": "\\[\\tilde{q}^{i}_{\\dot{\\alpha}}=\\frac{\\partial}{\\partial\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}}+i\\theta^{i\\alpha}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
{"file_name": "30.png", "latex_formula": "\\[f\\lrcorner L(z)=\\pi_{1}\\circ f(z,\\lambda,\\tilde{\\lambda})\\ \\forall z\\in L\\]"}
{"file_name": "31.png", "latex_formula": "\\[q_{i\\alpha}=\\frac{\\partial}{\\partial\\theta^{i\\alpha}}+i\\tilde{\\theta}^{\\dot{ \\alpha}}_{i}\\frac{\\partial}{\\partial x^{\\alpha\\dot{\\alpha}}}\\]"}
{"file_name": "32.png", "latex_formula": "\\[q_{i}=\\partial_{\\xi^{i}}+i\\tilde{\\xi}_{i}\\partial_{s}\\]"}
{"file_name": "33.png", "latex_formula": "\\[v^{\\alpha\\dot{\\alpha}}=\\lambda^{\\alpha}\\tilde{\\lambda}^{\\dot{\\alpha}}\\]"}
{"file_name": "34.png", "latex_formula": "\\[z^{A}=(x^{\\alpha\\dot{\\alpha}},\\theta^{i\\alpha},\\tilde{\\theta}^{\\dot{\\alpha}}_{ j})\\]"}

View File

@@ -5,18 +5,24 @@ from pathlib import Path
from datasets import load_dataset from datasets import load_dataset
from transformers import ( from transformers import (
Trainer, Trainer,
TrainingArguments, TrainingArguments,
Seq2SeqTrainer, Seq2SeqTrainer,
Seq2SeqTrainingArguments, Seq2SeqTrainingArguments,
GenerationConfig GenerationConfig,
) )
from .training_args import CONFIG from .training_args import CONFIG
from ..model.TexTeller import TexTeller from ..model.TexTeller import TexTeller
from ..utils.functional import tokenize_fn, collate_fn, img_train_transform, img_inf_transform, filter_fn from ..utils.functional import (
tokenize_fn,
collate_fn,
img_train_transform,
img_inf_transform,
filter_fn,
)
from ..utils.metrics import bleu_metric from ..utils.metrics import bleu_metric
from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT from ...globals import MAX_TOKEN_SIZE, MIN_WIDTH, MIN_HEIGHT
def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer): def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer):
@@ -24,11 +30,9 @@ def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokeniz
trainer = Trainer( trainer = Trainer(
model, model,
training_args, training_args,
train_dataset=train_dataset, train_dataset=train_dataset,
eval_dataset=eval_dataset, eval_dataset=eval_dataset,
tokenizer=tokenizer,
tokenizer=tokenizer,
data_collator=collate_fn_with_tokenizer, data_collator=collate_fn_with_tokenizer,
) )
@@ -52,43 +56,44 @@ def evaluate(model, tokenizer, eval_dataset, collate_fn):
trainer = Seq2SeqTrainer( trainer = Seq2SeqTrainer(
model, model,
seq2seq_config, seq2seq_config,
eval_dataset=eval_dataset, eval_dataset=eval_dataset,
tokenizer=tokenizer, tokenizer=tokenizer,
data_collator=collate_fn, data_collator=collate_fn,
compute_metrics=partial(bleu_metric, tokenizer=tokenizer) compute_metrics=partial(bleu_metric, tokenizer=tokenizer),
) )
eval_res = trainer.evaluate() eval_res = trainer.evaluate()
print(eval_res) print(eval_res)
if __name__ == '__main__': if __name__ == '__main__':
script_dirpath = Path(__file__).resolve().parent script_dirpath = Path(__file__).resolve().parent
os.chdir(script_dirpath) os.chdir(script_dirpath)
dataset = load_dataset(str(Path('./dataset/loader.py').resolve()))['train'] # dataset = load_dataset(str(Path('./dataset/loader.py').resolve()))['train']
dataset = dataset.filter(lambda x: x['image'].height > MIN_HEIGHT and x['image'].width > MIN_WIDTH) dataset = load_dataset("imagefolder", data_dir=str(script_dirpath / 'dataset'))['train']
dataset = dataset.filter(
lambda x: x['image'].height > MIN_HEIGHT and x['image'].width > MIN_WIDTH
)
dataset = dataset.shuffle(seed=42) dataset = dataset.shuffle(seed=42)
dataset = dataset.flatten_indices() dataset = dataset.flatten_indices()
tokenizer = TexTeller.get_tokenizer() tokenizer = TexTeller.get_tokenizer()
# If you want use your own tokenizer, please modify the path to your tokenizer # If you want use your own tokenizer, please modify the path to your tokenizer
#+tokenizer = TexTeller.get_tokenizer('/path/to/your/tokenizer') # +tokenizer = TexTeller.get_tokenizer('/path/to/your/tokenizer')
filter_fn_with_tokenizer = partial(filter_fn, tokenizer=tokenizer) filter_fn_with_tokenizer = partial(filter_fn, tokenizer=tokenizer)
dataset = dataset.filter( dataset = dataset.filter(filter_fn_with_tokenizer, num_proc=8)
filter_fn_with_tokenizer,
num_proc=8
)
map_fn = partial(tokenize_fn, tokenizer=tokenizer) map_fn = partial(tokenize_fn, tokenizer=tokenizer)
tokenized_dataset = dataset.map(map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8) tokenized_dataset = dataset.map(
map_fn, batched=True, remove_columns=dataset.column_names, num_proc=8
)
# Split dataset into train and eval, ratio 9:1 # Split dataset into train and eval, ratio 9:1
split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42) split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset, eval_dataset = split_dataset['train'], split_dataset['test'] train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
train_dataset = train_dataset.with_transform(img_train_transform) train_dataset = train_dataset.with_transform(img_train_transform)
eval_dataset = eval_dataset.with_transform(img_inf_transform) eval_dataset = eval_dataset.with_transform(img_inf_transform)
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer) collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
# Train from scratch # Train from scratch
@@ -96,14 +101,14 @@ if __name__ == '__main__':
# or train from TexTeller pre-trained model: model = TexTeller.from_pretrained() # or train from TexTeller pre-trained model: model = TexTeller.from_pretrained()
# If you want to train from pre-trained model, please modify the path to your pre-trained checkpoint # If you want to train from pre-trained model, please modify the path to your pre-trained checkpoint
#+e.g. # +e.g.
#+model = TexTeller.from_pretrained( # +model = TexTeller.from_pretrained(
#+ '/path/to/your/model_checkpoint' # + '/path/to/your/model_checkpoint'
#+) # +)
enable_train = True enable_train = True
enable_evaluate = False enable_evaluate = False
if enable_train: if enable_train:
train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer) train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokenizer)
if enable_evaluate and len(eval_dataset) > 0: if enable_evaluate and len(eval_dataset) > 0:
evaluate(model, tokenizer, eval_dataset, collate_fn_with_tokenizer) evaluate(model, tokenizer, eval_dataset, collate_fn_with_tokenizer)

View File

@@ -0,0 +1,31 @@
CONFIG = {
"seed": 42, # Random seed for reproducibility
"use_cpu": False, # Whether to use CPU (it's easier to debug with CPU when starting to test the code)
"learning_rate": 5e-5, # Learning rate
"num_train_epochs": 10, # Total number of training epochs
"per_device_train_batch_size": 4, # Batch size per GPU for training
"per_device_eval_batch_size": 8, # Batch size per GPU for evaluation
"output_dir": "train_result", # Output directory
"overwrite_output_dir": False, # If the output directory exists, do not delete its content
"report_to": ["tensorboard"], # Report logs to TensorBoard
"save_strategy": "steps", # Strategy to save checkpoints
"save_steps": 500, # Interval of steps to save checkpoints, can be int or a float (0~1), when float it represents the ratio of total training steps (e.g., can set to 1.0 / 2000)
"save_total_limit": 5, # Maximum number of models to save. The oldest models will be deleted if this number is exceeded
"logging_strategy": "steps", # Log every certain number of steps
"logging_steps": 500, # Number of steps between each log
"logging_nan_inf_filter": False, # Record logs for loss=nan or inf
"optim": "adamw_torch", # Optimizer
"lr_scheduler_type": "cosine", # Learning rate scheduler
"warmup_ratio": 0.1, # Ratio of warmup steps in total training steps (e.g., for 1000 steps, the first 100 steps gradually increase lr from 0 to the set lr)
"max_grad_norm": 1.0, # For gradient clipping, ensure the norm of the gradients does not exceed 1.0 (default 1.0)
"fp16": False, # Whether to use 16-bit floating point for training (generally not recommended, as loss can easily explode)
"bf16": False, # Whether to use Brain Floating Point (bfloat16) for training (recommended if architecture supports it)
"gradient_accumulation_steps": 1, # Gradient accumulation steps, consider this parameter to achieve large batch size effects when batch size cannot be large
"jit_mode_eval": False, # Whether to use PyTorch jit trace during eval (can speed up the model, but the model must be static, otherwise will throw errors)
"torch_compile": False, # Whether to use torch.compile to compile the model (for better training and inference performance)
"dataloader_pin_memory": True, # Can speed up data transfer between CPU and GPU
"dataloader_num_workers": 1, # Default is not to use multiprocessing for data loading, usually set to 4*number of GPUs used
"evaluation_strategy": "steps", # Evaluation strategy, can be "steps" or "epoch"
"eval_steps": 500, # If evaluation_strategy="step"
"remove_unused_columns": False, # Don't change this unless you really know what you are doing.
}

View File

@@ -26,7 +26,7 @@ def collate_fn(samples: List[Dict[str, Any]], tokenizer=None) -> Dict[str, List[
pixel_values = [dic.pop('pixel_values') for dic in samples] pixel_values = [dic.pop('pixel_values') for dic in samples]
clm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) clm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
batch = clm_collator(samples) batch = clm_collator(samples)
batch['pixel_values'] = pixel_values batch['pixel_values'] = pixel_values
batch['decoder_input_ids'] = batch.pop('input_ids') batch['decoder_input_ids'] = batch.pop('input_ids')
@@ -54,6 +54,7 @@ def img_inf_transform(samples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
def filter_fn(sample, tokenizer=None) -> bool: def filter_fn(sample, tokenizer=None) -> bool:
return ( return (
sample['image'].height > MIN_HEIGHT and sample['image'].width > MIN_WIDTH sample['image'].height > MIN_HEIGHT
and sample['image'].width > MIN_WIDTH
and len(tokenizer(sample['latex_formula'])['input_ids']) < MAX_TOKEN_SIZE - 10 and len(tokenizer(sample['latex_formula'])['input_ids']) < MAX_TOKEN_SIZE - 10
) )

View File

@@ -12,7 +12,7 @@ def convert2rgb(image_paths: List[str]) -> List[np.ndarray]:
continue continue
if image.dtype == np.uint16: if image.dtype == np.uint16:
print(f'Converting {path} to 8-bit, image may be lossy.') print(f'Converting {path} to 8-bit, image may be lossy.')
image = cv2.convertScaleAbs(image, alpha=(255.0/65535.0)) image = cv2.convertScaleAbs(image, alpha=(255.0 / 65535.0))
channels = 1 if len(image.shape) == 2 else image.shape[2] channels = 1 if len(image.shape) == 2 else image.shape[2]
if channels == 4: if channels == 4:

Some files were not shown because too many files have changed in this diff Show More